|
|
import gradio as gr |
|
|
import torch |
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
from peft import PeftModel |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# --- Model configuration --------------------------------------------------
# Base Tamil LLaMA model and the QLoRA adapter fine-tuned for metaphor
# generation, both pulled from the Hugging Face Hub.
BASE_MODEL = "abhinand/tamil-llama-7b-base-v0.1"
ADAPTER_MODEL = "Vinushaanth/fine_tuned_metaphor_model_qlora"

print("🔄 Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)

print("🔄 Loading base model on CPU (this may take a while)...")
# NOTE(review): the original passed `load_in_8bit=True` unconditionally, but
# 8-bit quantization needs the bitsandbytes library and a CUDA device — on a
# CPU-only machine (which this script explicitly supports, see the dtype
# branch below) that call fails. Only request 8-bit when CUDA is available.
_has_cuda = torch.cuda.is_available()
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    load_in_8bit=_has_cuda,
    device_map="auto",
    # fp16 on GPU for memory savings; fp32 on CPU where fp16 is slow or
    # unsupported for many ops.
    torch_dtype=torch.float16 if _has_cuda else torch.float32,
)

print("🔄 Loading adapter (QLoRA)...")
# Attach the LoRA adapter weights on top of the (possibly quantized) base model.
model = PeftModel.from_pretrained(base_model, ADAPTER_MODEL, device_map="auto")
model.eval()  # inference only — disables dropout etc.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def generate_metaphors(vehicle, tenor, tamil_context, num_metaphors):
    """Generate Tamil metaphors with the QLoRA fine-tuned model.

    Parameters
    ----------
    vehicle : str
        Source concept of the metaphor.
    tenor : str
        Target concept of the metaphor.
    tamil_context : str
        Free-text context guiding the generation.
    num_metaphors : int | float
        Number of sequences to sample (Gradio sliders may pass a float).

    Returns
    -------
    dict
        ``{"error": ...}`` when a field is blank, otherwise the prompt,
        the count, and a list of ``{"id", "text"}`` metaphor entries.
    """
    # Guard clause: all three text fields must contain non-whitespace input.
    if not (vehicle.strip() and tenor.strip() and tamil_context.strip()):
        return {"error": "⚠️ Please fill all fields properly."}

    # Prompt format must match the fine-tuning data: fields then a <sep> marker.
    prompt = f"Vehicle:{vehicle} Tenor:{tenor} Context:{tamil_context} <sep>"

    # Send inputs to the same device the model was dispatched to instead of
    # hard-coding "cpu", which breaks when device_map placed the model on GPU.
    device = next(model.parameters()).device
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    num_metaphors = int(num_metaphors)  # slider values may arrive as floats

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=60,
            num_return_sequences=num_metaphors,
            do_sample=True,
            top_p=0.9,
            temperature=0.8,
            pad_token_id=tokenizer.eos_token_id,
            # early_stopping removed: it only applies to beam search and
            # merely emits a warning when do_sample=True.
        )

    # Strip the prompt by slicing off the input tokens before decoding.
    # Decoding the full sequence and then str.replace(prompt, "") is
    # unreliable: skip_special_tokens can alter the decoded prompt text
    # (e.g. the <sep> marker), so the replace silently fails.
    prompt_len = inputs["input_ids"].shape[1]
    metaphors = [
        tokenizer.decode(seq[prompt_len:], skip_special_tokens=True).strip()
        for seq in outputs
    ]

    return {
        "input_text": prompt,
        "num_metaphors": num_metaphors,
        "metaphors": [{"id": i + 1, "text": m} for i, m in enumerate(metaphors)],
    }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# --- Gradio UI ------------------------------------------------------------
# Three text inputs (vehicle, tenor, context) plus a count slider feed the
# generator; results are shown as raw JSON.
with gr.Blocks(title="Tamil Metaphor Generation (QLoRA - CPU)") as demo:
    gr.Markdown("## 🌸 Tamil Metaphor Generation (QLoRA Fine-tuned Model, CPU)")
    gr.Markdown("Generate **creative Tamil metaphors** using your QLoRA adapter model. JSON output included.")

    with gr.Row():
        vehicle_box = gr.Textbox(label="🔹 Vehicle (Source Concept)", placeholder="e.g., கண்கள் (Eyes)")
        tenor_box = gr.Textbox(label="🔹 Tenor (Target Concept)", placeholder="e.g., ஏரி (Lake)")
        context_box = gr.Textbox(label="🔹 Context", placeholder="e.g., Comparing a woman to nature.")
        count_slider = gr.Slider(1, 5, value=2, step=1, label="Number of metaphors")

    json_output = gr.JSON(label="Generated Output (JSON)")
    generate_btn = gr.Button("✨ Generate Metaphors")

    # Wire the button to the generator: four inputs in, one JSON payload out.
    generate_btn.click(
        fn=generate_metaphors,
        inputs=[vehicle_box, tenor_box, context_box, count_slider],
        outputs=json_output,
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Launch the Gradio web server only when executed as a script,
    # not when this module is imported.
    demo.launch()
|
|
|