# app.py — Tamil metaphor generation demo (RavinduW, commit 87e2094, verified)
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
# --------------------------------------------------
# Model setup
# --------------------------------------------------
BASE_MODEL = "abhinand/tamil-llama-7b-base-v0.1"
ADAPTER_MODEL = "Vinushaanth/fine_tuned_metaphor_model_qlora"  # change to your adapter repo name

# bitsandbytes 8-bit quantization only works on a CUDA GPU; detect once and
# reuse the flag for both quantization and dtype selection.
USE_CUDA = torch.cuda.is_available()

print("🔄 Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)

print("🔄 Loading base model on CPU (this may take a while)...")
# BUG FIX: the original passed load_in_8bit=True unconditionally, but 8-bit
# loading requires bitsandbytes + a CUDA device — on a CPU-only host (which
# the surrounding messages advertise) from_pretrained() raises at startup.
# Quantize only when a GPU is present; otherwise load full-precision fp32.
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    load_in_8bit=USE_CUDA,
    device_map="auto",
    torch_dtype=torch.float16 if USE_CUDA else torch.float32,
)

print("🔄 Loading adapter (QLoRA)...")
# Attach the LoRA adapter weights on top of the (possibly quantized) base.
model = PeftModel.from_pretrained(base_model, ADAPTER_MODEL, device_map="auto")
model.eval()  # inference mode: disables dropout etc.
# --------------------------------------------------
# Inference
# --------------------------------------------------
def generate_metaphors(vehicle, tenor, tamil_context, num_metaphors):
    """Generate Tamil metaphors for a (vehicle, tenor) pair in a given context.

    Args:
        vehicle: Source concept of the metaphor (free text).
        tenor: Target concept of the metaphor (free text).
        tamil_context: Context sentence guiding the generation.
        num_metaphors: Number of candidates to sample. Gradio sliders deliver
            floats, so the value is coerced to int before generation.

    Returns:
        dict: ``{"error": ...}`` when any field is blank/whitespace, otherwise
        a JSON-style payload with the prompt, the requested count, and the
        decoded metaphor candidates.
    """
    if not (vehicle.strip() and tenor.strip() and tamil_context.strip()):
        return {"error": "⚠️ Please fill all fields properly."}

    # num_return_sequences must be an int; the slider may pass e.g. 2.0.
    count = int(num_metaphors)
    prompt = f"Vehicle:{vehicle} Tenor:{tenor} Context:{tamil_context} <sep>"

    # BUG FIX: inputs were hard-coded to "cpu" while the model is placed with
    # device_map="auto" — on a GPU host that raises a device-mismatch error.
    # Send the tensors to wherever accelerate actually put the model.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=60,
            num_return_sequences=count,
            do_sample=True,
            top_p=0.9,
            temperature=0.8,
            pad_token_id=tokenizer.eos_token_id,
            # NOTE: early_stopping removed — it applies only to beam search
            # and is ignored (with a warning) under sampling.
        )

    # Decode only the newly generated tokens. This is more robust than the
    # old str.replace(prompt, "") approach, which breaks whenever the
    # tokenizer round-trip changes the prompt's surface form.
    prompt_len = inputs["input_ids"].shape[1]
    metaphors = [
        tokenizer.decode(seq[prompt_len:], skip_special_tokens=True).strip()
        for seq in outputs
    ]

    return {
        "input_text": prompt,
        "num_metaphors": count,
        "metaphors": [{"id": i + 1, "text": m} for i, m in enumerate(metaphors)],
    }
# --------------------------------------------------
# Gradio UI
# --------------------------------------------------
# --------------------------------------------------
# Gradio UI
# --------------------------------------------------
# Layout: a row of three text fields, a count slider, a JSON viewer for the
# structured result, and a single generate button wired to the inference fn.
with gr.Blocks(title="Tamil Metaphor Generation (QLoRA - CPU)") as demo:
    gr.Markdown("## 🌸 Tamil Metaphor Generation (QLoRA Fine-tuned Model, CPU)")
    gr.Markdown("Generate **creative Tamil metaphors** using your QLoRA adapter model. JSON output included.")

    with gr.Row():
        txt_vehicle = gr.Textbox(label="🔹 Vehicle (Source Concept)", placeholder="e.g., கண்கள் (Eyes)")
        txt_tenor = gr.Textbox(label="🔹 Tenor (Target Concept)", placeholder="e.g., ஏரி (Lake)")
        txt_context = gr.Textbox(label="🔹 Context", placeholder="e.g., Comparing a woman to nature.")

    slider_count = gr.Slider(1, 5, value=2, step=1, label="Number of metaphors")
    json_output = gr.JSON(label="Generated Output (JSON)")
    btn_generate = gr.Button("✨ Generate Metaphors")

    # Clicking the button runs generate_metaphors over the four inputs and
    # renders the returned dict in the JSON component.
    btn_generate.click(
        fn=generate_metaphors,
        inputs=[txt_vehicle, txt_tenor, txt_context, slider_count],
        outputs=json_output,
    )
# --------------------------------------------------
# Launch for local testing
# --------------------------------------------------
if __name__ == "__main__":
    # Start the Gradio server only when run as a script, not when imported
    # (e.g. by a Spaces runner that manages the launch itself).
    demo.launch()