Spaces:
Running
Running
File size: 4,454 Bytes
fd3aa38 c0825f1 fd3aa38 c0825f1 fd3aa38 c0825f1 fd3aa38 c0825f1 fd3aa38 c0825f1 fd3aa38 c0825f1 fd3aa38 c0825f1 fd3aa38 c0825f1 fd3aa38 c0825f1 fd3aa38 c0825f1 fd3aa38 c0825f1 fd3aa38 c0825f1 fd3aa38 c0825f1 fd3aa38 c0825f1 fd3aa38 c0825f1 fd3aa38 c0825f1 fd3aa38 c0825f1 fd3aa38 c0825f1 fd3aa38 c0825f1 fd3aa38 c0825f1 fd3aa38 c0825f1 fd3aa38 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 |
from transformers import MarianMTModel, MarianTokenizer
import gradio as gr
import torch
# Available model versions, keyed by the label shown in the "Model Version"
# dropdown (the label also carries each version's BLEU scores).
# Each entry holds:
#   key   - short id, used as part of the model cache key
#   en_ru - Hub repo id used for the "English → Runyankore" direction
#   ru_en - Hub repo id used for the "Runyankore → English" direction
# ("RU" in these labels and repo ids stands for Runyankore, not Russian.)
# Insertion order matters: the dropdown lists labels in this order and
# defaults to the second one.
# NOTE(review): the "v3" label points at the *finetunedv4 repos — confirm
# this label/checkpoint pairing is intentional.
MODEL_OPTIONS = {
    "Version v1 (EN→RU BLEU: 35.93 | RU→EN BLEU: 41.11)": dict(
        key="v1",
        en_ru="kafarasi/marian-en-ru-finetuned",
        ru_en="kafarasi/marian-ru-en-finetuned",
    ),
    "Version v2 (EN→RU BLEU: 36.61 | RU→EN BLEU: 50.00)": dict(
        key="v2",
        en_ru="kafarasi/marian-en-ru-finetunedv2",
        ru_en="kafarasi/marian-ru-en-finetunedv2",
    ),
    "Version v3 (EN→RU BLEU: 37.52 | RU→EN BLEU: 47.06)": dict(
        key="v3",
        en_ru="kafarasi/marian-en-ru-finetunedv4",
        ru_en="kafarasi/marian-ru-en-finetunedv4",
    ),
}
# Run on a CUDA GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Half precision is only used on the GPU; on CPU the models are loaded in float32.
use_float16 = device.type == "cuda"

# In-process model cache: "<version key>_<direction label>" -> (tokenizer, model).
loaded_models = dict()
# Load and cache models

# Translation direction label (as shown in the UI) -> MODEL_OPTIONS field that
# holds the repo id for that direction.
_DIRECTION_TO_FIELD = {
    "English → Runyankore": "en_ru",
    "Runyankore → English": "ru_en",
}


def get_model_and_tokenizer(display_label, direction):
    """Return the cached ``(tokenizer, model)`` pair for a version and direction.

    On first use the pair is loaded with ``from_pretrained`` and moved to
    ``device``, in float16 when ``use_float16`` is set and float32 otherwise.
    It is then memoized in ``loaded_models`` so later calls reuse it.

    Args:
        display_label: A key of ``MODEL_OPTIONS`` (the dropdown label).
        direction: ``"English → Runyankore"`` or ``"Runyankore → English"``.

    Returns:
        A ``(MarianTokenizer, MarianMTModel)`` tuple.

    Raises:
        ValueError: If ``direction`` is not one of the supported labels
            (for example ``None`` when no direction was selected).
        KeyError: If ``display_label`` is not a key of ``MODEL_OPTIONS``.
    """
    # Validate the direction up front. Previously any unrecognised value
    # (including None) silently fell through to the Runyankore → English
    # model and was cached under a bogus key such as "v2_None".
    field = _DIRECTION_TO_FIELD.get(direction)
    if field is None:
        raise ValueError(
            f"Unsupported translation direction {direction!r}; "
            f"expected one of {list(_DIRECTION_TO_FIELD)}"
        )

    version_info = MODEL_OPTIONS[display_label]
    cache_key = f"{version_info['key']}_{direction}"
    cached = loaded_models.get(cache_key)
    if cached is not None:
        return cached

    model_name = version_info[field]
    tokenizer = MarianTokenizer.from_pretrained(model_name)
    model = MarianMTModel.from_pretrained(
        model_name,
        torch_dtype=torch.float16 if use_float16 else torch.float32,
    ).to(device)
    loaded_models[cache_key] = (tokenizer, model)
    return tokenizer, model
# Translation function (no splitting)
def translate_text(text, direction, version_label, *, max_length=128, num_beams=1):
    """Translate ``text`` with the selected model version and direction.

    The whole stripped input is encoded as a single sequence; it is not split
    into sentences. Input longer than the tokenizer's maximum length is
    truncated. The result is produced with ``model.generate``.

    Args:
        text: Text to translate. ``None`` or whitespace-only input yields ``""``.
        direction: Direction label, e.g. ``"English → Runyankore"``.
        version_label: A ``MODEL_OPTIONS`` key selecting the model version.
        max_length: Maximum length of the generated sequence, forwarded to
            ``generate`` (default 128, the previous hard-coded value).
        num_beams: Beam width forwarded to ``generate``. The default of 1
            disables beam search, as before.

    Returns:
        The decoded translation with special tokens removed, or ``""`` when
        there is nothing to translate.
    """
    # Check for empty input *before* touching the model, so a blank request
    # never triggers a (potentially slow) model download/load. Also tolerate
    # None instead of crashing on .strip().
    text = (text or "").strip()
    if not text:
        return ""

    tokenizer, model = get_model_and_tokenizer(version_label, direction)
    inputs = tokenizer(
        [text], return_tensors="pt", padding=True, truncation=True
    ).to(device)
    with torch.no_grad():  # inference only; no autograd bookkeeping needed
        outputs = model.generate(
            **inputs, max_length=max_length, num_beams=num_beams
        )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
# Clear inputs
def clear_fields():
    """Reset the UI: return blank values for the input and output text boxes."""
    blank = ""
    return blank, blank
# UI styling passed to gr.Blocks(css=...). It sets a gradient page background
# and font, centred headings and paragraphs, and a white card-style main
# container capped at 900px. It also enlarges text inputs and makes the
# buttons blue.
# NOTE(review): `.gr-button` is a class name used by older Gradio releases.
# Confirm it still matches the buttons rendered by the installed Gradio
# version; otherwise the two button rules below have no effect.
# NOTE(review): the `p` rule centres every paragraph on the page, not only the
# header text. Confirm that is intended.
custom_css = """
body {
background: linear-gradient(to right, #e3f2fd, #fce4ec);
font-family: 'Segoe UI', sans-serif;
}
h1 {
color: #2c3e50;
font-size: 32px;
text-align: center;
margin-bottom: 10px;
}
p {
text-align: center;
font-size: 16px;
color: #4e4e4e;
}
.gradio-container {
max-width: 900px;
margin: auto;
padding: 30px;
background: #ffffff;
box-shadow: 0 4px 15px rgba(0, 0, 0, 0.1);
border-radius: 16px;
}
textarea, input {
font-size: 16px !important;
border: 2px solid #2980b9;
border-radius: 8px !important;
}
.gr-button {
background-color: #3498db !important;
color: white !important;
border-radius: 8px !important;
font-size: 16px !important;
padding: 10px 20px !important;
transition: background-color 0.3s ease;
}
.gr-button:hover {
background-color: #2c81ba !important;
}
"""
# Gradio Interface
with gr.Blocks(css=custom_css) as iface:
    gr.Markdown("<h1>Runyankore ↔ English Translator</h1>")
    gr.Markdown(
        "<p>Select a model version and translation direction. Input text will be translated efficiently even on CPU.</p>"
    )
    with gr.Row():
        # Left column: version/direction controls, the input box and the buttons.
        with gr.Column(scale=1):
            model_selector = gr.Dropdown(
                list(MODEL_OPTIONS.keys()),
                label="Model Version",
                # The second MODEL_OPTIONS entry (v2) is preselected.
                value=list(MODEL_OPTIONS.keys())[1],
            )
            text_input = gr.Textbox(
                lines=5,
                label="Input Text",
                placeholder="Enter text...",
                interactive=True,
                show_copy_button=True,
            )
            # Preselect a direction so the handler never receives None when the
            # user submits without explicitly choosing one.
            direction = gr.Radio(
                ["English → Runyankore", "Runyankore → English"],
                label="Translation Direction",
                value="English → Runyankore",
            )
            with gr.Row():
                translate_btn = gr.Button("🔄 Translate")
                clear_btn = gr.Button("🗑️ Clear")
        # Right column: read-only translation output.
        with gr.Column(scale=1):
            output_text = gr.Textbox(
                lines=5,
                label="Translated Output",
                interactive=False,
                show_copy_button=True,
            )
    # Event wiring: inputs are passed positionally as (text, direction, version).
    translate_btn.click(
        fn=translate_text,
        inputs=[text_input, direction, model_selector],
        outputs=output_text,
    )
    clear_btn.click(fn=clear_fields, outputs=[text_input, output_text])
    gr.Markdown(
        "<p style='text-align: center;'>💬 Feedback? "
        "<a href='https://forms.gle/WD1sFUD6JHSxrr3u6' target='_blank' style='color: #2980b9;'>Click here to help improve</a></p>"
    )

# Launch only when run as a script (e.g. `python app.py`), so importing this
# module does not start a server.
if __name__ == "__main__":
    iface.launch()
|