import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

# --------------------------------------------------
# Model setup
# --------------------------------------------------
BASE_MODEL = "abhinand/tamil-llama-7b-base-v0.1"
ADAPTER_MODEL = "Vinushaanth/fine_tuned_metaphor_model_qlora"  # change to your adapter repo name

print("🔄 Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)

# Note: bitsandbytes 8-bit quantization (load_in_8bit=True) requires a CUDA
# GPU and would fail on a CPU-only machine, so the base model is loaded in
# full precision here; float16 is used only when a GPU happens to be present.
print("🔄 Loading base model on CPU (this may take a while)...")
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    device_map="auto",
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
)

print("🔄 Loading adapter (QLoRA)...")
model = PeftModel.from_pretrained(base_model, ADAPTER_MODEL)
model.eval()

# --------------------------------------------------
# Inference
# --------------------------------------------------
def generate_metaphors(vehicle, tenor, tamil_context, num_metaphors):
    if not (vehicle.strip() and tenor.strip() and tamil_context.strip()):
        return {"error": "⚠️ Please fill all fields properly."}

    num_metaphors = int(num_metaphors)  # the Gradio slider may pass a float
    prompt = f"Vehicle:{vehicle} Tenor:{tenor} Context:{tamil_context} "
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=60,
            num_return_sequences=num_metaphors,
            do_sample=True,
            top_p=0.9,
            temperature=0.8,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens so the prompt is not echoed back
    # (string-level prompt stripping can miss tokenization differences).
    prompt_len = inputs["input_ids"].shape[1]
    metaphors = [
        tokenizer.decode(output[prompt_len:], skip_special_tokens=True).strip()
        for output in outputs
    ]

    return {
        "input_text": prompt,
        "num_metaphors": num_metaphors,
        "metaphors": [{"id": i + 1, "text": m} for i, m in enumerate(metaphors)],
    }

# --------------------------------------------------
# Gradio UI
# --------------------------------------------------
with gr.Blocks(title="Tamil Metaphor Generation (QLoRA - CPU)") as demo:
    gr.Markdown("## 🌸 Tamil Metaphor Generation (QLoRA Fine-tuned Model, CPU)")
    gr.Markdown("Generate **creative Tamil metaphors** using your QLoRA adapter model. JSON output included.")

    with gr.Row():
        vehicle_input = gr.Textbox(label="🔹 Vehicle (Source Concept)", placeholder="e.g., கண்கள் (Eyes)")
        tenor_input = gr.Textbox(label="🔹 Tenor (Target Concept)", placeholder="e.g., ஏரி (Lake)")
        tamil_context_input = gr.Textbox(label="🔹 Context", placeholder="e.g., Comparing a woman to nature.")

    num_input = gr.Slider(1, 5, value=2, step=1, label="Number of metaphors")
    output_box = gr.JSON(label="Generated Output (JSON)")
    submit_btn = gr.Button("✨ Generate Metaphors")

    submit_btn.click(
        fn=generate_metaphors,
        inputs=[vehicle_input, tenor_input, tamil_context_input, num_input],
        outputs=output_box,
    )

# --------------------------------------------------
# Launch for local testing
# --------------------------------------------------
if __name__ == "__main__":
    demo.launch()
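
# --------------------------------------------------
# Optional smoke test (sketch)
# --------------------------------------------------
# A minimal way to exercise generate_metaphors() directly, bypassing the UI;
# the Tamil vehicle/tenor/context values below are illustrative placeholders,
# not values from a real run. Uncomment and run this instead of demo.launch()
# above (launch() blocks, so the two should not run together).
#
# result = generate_metaphors("கண்கள்", "ஏரி", "Comparing a woman to nature.", 2)
# for m in result.get("metaphors", []):
#     print(m["id"], m["text"])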