import gradio as gr
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    pipeline,
)
# -------------------------
# Model repo IDs (change if needed)
# -------------------------
BERT_ID = "hurairamuzammal/bert_fintetuned_sentiment"
GPT2_ID = "hurairamuzammal/gpt2_fintetuned_code_gen"
PARROT_ID = "hurairamuzammal/parrot_finetuned_code_gen"
SUMMARIZER_ID = "hurairamuzammal/encoder_decoder_T5_summarizer"  # T5 encoder-decoder
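# NB: the "fintetuned" spelling above matches the actual Hub repo names; do not "fix" it here.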
# -------------------------
# Device setup
# -------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"
device_index = 0 if device == "cuda" else -1
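# pipeline() expects an integer device index (0 = first CUDA device, -1 = CPU),
# while model.to() takes the device string.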
# -------------------------
# Load BERT (Task 1)
# -------------------------
try:
    bert_tokenizer = AutoTokenizer.from_pretrained(BERT_ID)
    bert_model = AutoModelForSequenceClassification.from_pretrained(BERT_ID)
    sentiment_pipeline = pipeline(
        "text-classification",
        model=bert_model,
        tokenizer=bert_tokenizer,
        device=device_index,
    )
    bert_loaded = True
    bert_error = None
except Exception as e:
    bert_loaded = False
    bert_error = str(e)
# -------------------------
# Load GPT-2 and Parrot (Task 2)
# -------------------------
def safe_load_causal(repo_id):
    """Load a causal LM, using fp16 on GPU to halve memory."""
    try:
        tok = AutoTokenizer.from_pretrained(repo_id)
        # GPT-2-style tokenizers ship without a pad token; reuse EOS so
        # later tokenizer calls with padding=True don't raise.
        if tok.pad_token is None:
            tok.pad_token = tok.eos_token
        if device == "cuda":
            model = AutoModelForCausalLM.from_pretrained(repo_id, torch_dtype=torch.float16)
            model.to(device)
        else:
            model = AutoModelForCausalLM.from_pretrained(repo_id)
        return tok, model, None
    except Exception as e:
        return None, None, str(e)
gpt2_tok, gpt2_model, gpt2_err = safe_load_causal(GPT2_ID)
parrot_tok, parrot_model, parrot_err = safe_load_causal(PARROT_ID)
gpt2_loaded = gpt2_model is not None
parrot_loaded = parrot_model is not None
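# Each model tracks its own load status so one missing or private repo degrades
# only its tab instead of crashing the whole Space at import time.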
# -------------------------
# Load summarizer (Task 3) - Simple merged model
# -------------------------
try:
    sum_tok = AutoTokenizer.from_pretrained(SUMMARIZER_ID)
    sum_model = AutoModelForSeq2SeqLM.from_pretrained(SUMMARIZER_ID)
    if device == "cuda":
        sum_model.to(device)
    summarizer_loaded = True
    summarizer_error = None
except Exception as e:
    summarizer_loaded = False
    summarizer_error = str(e)
    sum_tok = None
    sum_model = None
# -------------------------
# Inference helpers
# -------------------------
def classify_sentiment(text):
    """Return Positive or Negative sentiment with a confidence score."""
    if not bert_loaded:
        return f"Error: BERT model not loaded - {bert_error}"
    if not text.strip():
        return "Please enter some text to analyze"
    try:
        result = sentiment_pipeline(text)[0]
        label = result["label"].upper()
        score = result["score"]
        # Map raw labels ("POSITIVE", "LABEL_1", "1", ...) to Positive/Negative
        if "POS" in label or label.endswith("1"):
            sentiment = "Positive"
        else:
            sentiment = "Negative"
        return f"**{sentiment}** (Confidence: {score:.2%})"
    except Exception as e:
        return f"Error during classification: {str(e)}"
def extract_code_from_generated(decoded, prompt_marker=None, code_marker="<|code|>", end_marker="<|end|>"):
    """Strip the prompt and special markers from a decoded generation."""
    if code_marker in decoded:
        code_part = decoded.split(code_marker, 1)[1]
        if end_marker in code_part:
            code_part = code_part.split(end_marker, 1)[0]
        return code_part.strip()
    if prompt_marker and prompt_marker in decoded:
        return decoded.split(prompt_marker, 1)[1].strip()
    return decoded.strip()
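# e.g. extract_code_from_generated("<|pseudocode|>print hi<|code|>print('hi')<|end|>")
# -> "print('hi')"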
def generate_code(model_name, pseudo, max_new_tokens=200, temperature=0.7, top_p=0.95):
    """Generate Python code from pseudo-code with the selected decoder-only model."""
    pseudo = pseudo.strip()
    if not pseudo:
        return "Please enter pseudo-code"
    if model_name == "GPT-2":
        if not gpt2_loaded:
            return f"GPT-2 not loaded: {gpt2_err}"
        tok, model = gpt2_tok, gpt2_model
    else:
        if not parrot_loaded:
            return f"CodeParrot not loaded: {parrot_err}"
        tok, model = parrot_tok, parrot_model
    # Wrap bare descriptions in the markers used during fine-tuning
    prompt = pseudo
    if "<|pseudocode|>" not in pseudo:
        prompt = f"<|pseudocode|>{pseudo}<|code|>"
    try:
        inputs = tok(prompt, return_tensors="pt", truncation=True, padding=True).to(model.device)
        with torch.no_grad():
            out = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=temperature,
                top_p=top_p,
                pad_token_id=tok.eos_token_id,
            )
        decoded = tok.decode(out[0], skip_special_tokens=False)
        return extract_code_from_generated(decoded, prompt_marker=prompt)
    except Exception as e:
        return f"Error during generation: {str(e)}"
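# Sampling note: lower temperature/top-p values make outputs more deterministic;
# higher values add variety but produce more malformed code.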
def summarize_text(text, max_length=120, min_length=30):
    if not summarizer_loaded:
        return f"Summarization model not loaded: {summarizer_error}"
    if not text.strip():
        return "Please enter text to summarize"
    try:
        # Add the "summarize: " task prefix expected by T5 checkpoints
        input_text = "summarize: " + text
        inputs = sum_tok(
            input_text,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=512,
        ).to(sum_model.device)
        with torch.no_grad():
            out = sum_model.generate(
                **inputs,
                max_length=max_length,
                min_length=min_length,
                num_beams=4,
                early_stopping=True,
            )
        return sum_tok.decode(out[0], skip_special_tokens=True)
    except Exception as e:
        return f"Error during summarization: {str(e)}"
# -------------------------
# CSS styling
# -------------------------
css = """
body, html {
    overflow-y: auto !important;
}
.gradio-container {
    max-width: 1200px !important;
    margin: auto !important;
}
pre, code {
    white-space: pre-wrap !important;
    word-wrap: break-word !important;
}
"""
# -------------------------
# Gradio Interface
# -------------------------
with gr.Blocks(css=css, title="Project 03 - Transformer Tasks", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# Assignment 03 — Fine-Tuning Models")
    gr.Markdown("*Fine-tuned models for Sentiment Analysis, Code Generation, and Summarization*")
    # Task 1: Sentiment analysis (encoder-only BERT)
    with gr.Tab("Sentiment Analysis"):
        gr.Markdown("### BERT Sentiment Classification")
        gr.Markdown("Analyze the sentiment of customer reviews as **Positive** or **Negative**.")
        t1_input = gr.Textbox(
            lines=4,
            placeholder="Enter customer review here...",
            label="Customer Review",
        )
        t1_button = gr.Button("Analyze Sentiment", variant="primary", size="lg")
        t1_output = gr.Markdown(label="Sentiment Result")
        gr.Markdown("#### Example Reviews")
        gr.Examples(
            examples=[
                ["I love this product! Fast delivery and great quality."],
                ["The package arrived broken and support did not reply."],
                ["Product is okay but could be improved."],
                ["Absolutely fantastic! Exceeded my expectations."],
                ["Worst purchase ever. Complete waste of money."],
            ],
            inputs=t1_input,
        )
        t1_button.click(fn=classify_sentiment, inputs=t1_input, outputs=t1_output)
    # Task 2: Code generation (decoder-only GPT-2 / CodeParrot)
    with gr.Tab("Code Generation"):
        gr.Markdown("### Pseudo-code to Python Code")
        gr.Markdown("Convert natural language descriptions into Python code.")
        model_dd = gr.Dropdown(
            ["GPT-2", "CodeParrot"],
            value="GPT-2",
            label="Model Selection",
        )
        t2_input = gr.Textbox(
            lines=4,
            placeholder="Describe what you want the code to do...",
            label="Pseudo-code / Description",
        )
        with gr.Accordion("Generation Parameters", open=False):
            t2_max_tokens = gr.Slider(32, 512, step=32, value=200, label="Max new tokens")
            t2_temp = gr.Slider(0.1, 1.0, step=0.05, value=0.7, label="Temperature")
            t2_top_p = gr.Slider(0.1, 1.0, step=0.05, value=0.95, label="Top-p")
        t2_button = gr.Button("Generate Code", variant="primary", size="lg")
        t2_output = gr.Code(language="python", label="Generated Python Code", lines=12)
        gr.Markdown("#### Example Prompts")
        gr.Examples(
            examples=[
                ["GPT-2", "Loop through a list and print each number."],
                ["GPT-2", "Function to compute factorial using recursion."],
                ["CodeParrot", "Function to check if string is palindrome."],
                ["CodeParrot", "Class Point with x, y coordinates and distance_to method."],
            ],
            inputs=[model_dd, t2_input],
        )
        t2_button.click(
            fn=generate_code,
            inputs=[model_dd, t2_input, t2_max_tokens, t2_temp, t2_top_p],
            outputs=t2_output,
        )
    # Task 3: Summarization (encoder-decoder T5)
    with gr.Tab("Text Summarization"):
        gr.Markdown("### T5 Text Summarization")
        gr.Markdown("Generate concise summaries of long articles or documents")
        t3_input = gr.Textbox(
            lines=10,
            placeholder="Paste your article or long text here...",
            label="Article Text",
        )
        with gr.Row():
            t3_min = gr.Slider(10, 80, step=5, value=30, label="Minimum length")
            t3_max = gr.Slider(40, 300, step=5, value=120, label="Maximum length")
        t3_button = gr.Button("Generate Summary", variant="primary", size="lg")
        t3_output = gr.Textbox(label="Summary", lines=6)
        gr.Markdown("#### Example Texts")
        gr.Examples(
            examples=[
                ["Artificial Intelligence (AI) has rapidly evolved from a futuristic concept to an integral part of everyday life, reshaping industries, decision-making processes, and human interaction. Powered by machine learning, natural language processing, and vast datasets, AI systems now assist in everything from diagnosing diseases and personalizing education to optimizing logistics and generating creative content. While AI promises enhanced efficiency and innovation, it also raises ethical concerns around privacy, bias, job displacement, and the concentration of technological power. As developers and policymakers work to establish responsible frameworks, the global conversation continues to balance the immense potential of AI with the need to safeguard human values and societal well-being."],
                ["The world has increasingly transformed into what sociologist Marshall McLuhan famously termed a “global village”—a metaphor made more relevant today through the lens of urbanization and digital connectivity. Cities have become dense hubs of culture, commerce, and communication, where people from diverse backgrounds live in closer proximity than ever before. Simultaneously, digital technologies—social media, instant messaging, and real-time news feeds—have collapsed geographical distances, allowing individuals in Tokyo, Nairobi, or São Paulo to share experiences and ideas instantaneously. This urban-digital convergence has fostered unprecedented interdependence, blurring traditional boundaries and creating a shared global consciousness, while also presenting challenges related to inequality, cultural homogenization, and the sustainability of urban ecosystems."],
            ],
            inputs=t3_input,
        )
        t3_button.click(
            fn=lambda text, max_l, min_l: summarize_text(text, max_length=max_l, min_length=min_l),
            inputs=[t3_input, t3_max, t3_min],
            outputs=t3_output,
        )
gr.Markdown("---")
gr.Markdown("**Note:** Models may take a moment to load on first startup. Running on: `" + device + "`")
if __name__ == "__main__":
demo.launch(server_name="0.0.0.0", share=False)