afriontranslate / app.py
kafarasi's picture
Update app.py
5d0d576 verified
from transformers import MarianMTModel, MarianTokenizer
import gradio as gr
import torch
# Registry of selectable model versions.
# Each dropdown label (which embeds the BLEU scores shown to the user) maps to
# a short version "key" plus one checkpoint id per translation direction.
# Insertion order matters: the UI picks its default entry by position.
MODEL_OPTIONS = {
    "Version v1 (EN→RU BLEU: 35.93 | RU→EN BLEU: 41.11)": {
        "key": "v1",
        "en_ru": "kafarasi/marian-en-ru-finetuned",
        "ru_en": "kafarasi/marian-ru-en-finetuned",
    },
    "Version v3 (EN→RU BLEU: 37.52 | RU→EN BLEU: 47.06)": {
        "key": "v3",
        "en_ru": "kafarasi/marian-en-ru-finetunedv4",
        "ru_en": "kafarasi/marian-ru-en-finetunedv4",
    },
}
# Runtime configuration: use the GPU when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Half precision is only enabled on CUDA; on CPU the models stay in float32.
use_float16 = device.type == "cuda"

# Cache of (tokenizer, model) pairs that have already been loaded.
loaded_models = {}
# Load and cache models
def get_model_and_tokenizer(display_label, direction):
    """Return the ``(tokenizer, model)`` pair for a model version and direction.

    Pairs are loaded on first use and kept in ``loaded_models`` so later
    requests reuse them.

    Args:
        display_label: A key of ``MODEL_OPTIONS`` (the dropdown label).
        direction: Translation direction string. ``"English → Runyankore"``
            selects the ``"en_ru"`` checkpoint; any other value (including
            ``None``) selects the ``"ru_en"`` checkpoint.

    Returns:
        A ``(tokenizer, model)`` tuple; the model has been moved to ``device``.

    Raises:
        KeyError: If ``display_label`` is not a key of ``MODEL_OPTIONS``.
    """
    version_info = MODEL_OPTIONS[display_label]
    version_key = version_info["key"]

    # Resolve the direction to its checkpoint slot *before* building the cache
    # key. Keying on the raw direction string would let different values that
    # resolve to the same checkpoint (e.g. None and "Runyankore → English")
    # each load and hold their own copy of the model.
    slot = "en_ru" if direction == "English → Runyankore" else "ru_en"
    cache_key = f"{version_key}_{slot}"

    cached = loaded_models.get(cache_key)
    if cached is not None:
        return cached

    model_name = version_info[slot]
    tokenizer = MarianTokenizer.from_pretrained(model_name)
    model = MarianMTModel.from_pretrained(
        model_name,
        torch_dtype=torch.float16 if use_float16 else torch.float32,
    ).to(device)

    loaded_models[cache_key] = (tokenizer, model)
    return tokenizer, model
# Translation function (no splitting)
def translate_text(text, direction, version_label):
    """Translate ``text`` with the selected model version and direction.

    The whole input is passed to the model as a single sequence (it is not
    split into sentences); the tokenizer is called with ``truncation=True``,
    so over-long input is truncated rather than rejected.

    Args:
        text: Text to translate. ``None`` or whitespace-only input yields
            ``""`` without loading any model.
        direction: Translation direction string, forwarded to
            ``get_model_and_tokenizer``.
        version_label: Model version label, forwarded to
            ``get_model_and_tokenizer``.

    Returns:
        The decoded translation with special tokens removed, or ``""`` for
        empty input.
    """
    # Validate the input first: loading a checkpoint is expensive and pointless
    # when there is nothing to translate.
    text = (text or "").strip()
    if not text:
        return ""

    tokenizer, model = get_model_and_tokenizer(version_label, direction)
    inputs = tokenizer([text], return_tensors="pt", padding=True, truncation=True).to(device)
    # Inference only: no gradients are needed.
    with torch.no_grad():
        outputs = model.generate(**inputs, max_length=128, num_beams=1)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
# Clear inputs
def clear_fields():
    """Return a pair of empty strings, one for each textbox to be reset."""
    blank = ""
    return blank, blank
# UI styling: CSS handed to gr.Blocks(css=...) when the interface is built.
# It sets the page background gradient and font, centers the h1/p text,
# styles the .gradio-container card (width, padding, shadow, rounded corners),
# and restyles text inputs and .gr-button buttons (including the hover color).
custom_css = """
body {
background: linear-gradient(to right, #e3f2fd, #fce4ec);
font-family: 'Segoe UI', sans-serif;
}
h1 {
color: #2c3e50;
font-size: 32px;
text-align: center;
margin-bottom: 10px;
}
p {
text-align: center;
font-size: 16px;
color: #4e4e4e;
}
.gradio-container {
max-width: 900px;
margin: auto;
padding: 30px;
background: #ffffff;
box-shadow: 0 4px 15px rgba(0, 0, 0, 0.1);
border-radius: 16px;
}
textarea, input {
font-size: 16px !important;
border: 2px solid #2980b9;
border-radius: 8px !important;
}
.gr-button {
background-color: #3498db !important;
color: white !important;
border-radius: 8px !important;
font-size: 16px !important;
padding: 10px 20px !important;
transition: background-color 0.3s ease;
}
.gr-button:hover {
background-color: #2c81ba !important;
}
"""
# Gradio Interface
# Layout: a title and intro, then one row with two columns. The left column
# holds the controls (model version, input text, direction, buttons); the right
# column holds the read-only translation output.
with gr.Blocks(css=custom_css) as iface:
    gr.Markdown("<h1>Runyankore ↔ English Translator</h1>")
    gr.Markdown(
        "<p>Select a model version and translation direction. Input text will be translated efficiently even on CPU.</p>"
    )
    with gr.Row():
        with gr.Column(scale=1):
            model_selector = gr.Dropdown(
                list(MODEL_OPTIONS.keys()),
                label="Model Version",
                value=list(MODEL_OPTIONS.keys())[1],
            )
            text_input = gr.Textbox(
                lines=5,
                label="Input Text",
                placeholder="Enter text...",
                interactive=True,
                show_copy_button=True,
            )
            # Preselect a direction. Without an initial value the radio sends
            # None on the first click, and the translation then silently runs
            # Runyankore → English regardless of what the user intended.
            direction = gr.Radio(
                ["English → Runyankore", "Runyankore → English"],
                label="Translation Direction",
                value="English → Runyankore",
            )
            with gr.Row():
                translate_btn = gr.Button("🔄 Translate")
                clear_btn = gr.Button("🗑️ Clear")
        with gr.Column(scale=1):
            output_text = gr.Textbox(
                lines=5,
                label="Translated Output",
                interactive=False,
                show_copy_button=True,
            )

    # Wire the buttons: translate reads the three controls and fills the output
    # box; clear resets both the input and the output textboxes.
    translate_btn.click(fn=translate_text, inputs=[text_input, direction, model_selector], outputs=output_text)
    clear_btn.click(fn=clear_fields, outputs=[text_input, output_text])

    gr.Markdown(
        "<p style='text-align: center;'>💬 Feedback? "
        "<a href='https://forms.gle/WD1sFUD6JHSxrr3u6' target='_blank' style='color: #2980b9;'>Click here to help improve</a></p>"
    )

# Launch
iface.launch()