Spaces:
Running
Running
| from transformers import MarianMTModel, MarianTokenizer | |
| import gradio as gr | |
| import torch | |
# Model versions offered in the dropdown. Each display label (which also shows
# the BLEU scores) maps to an internal cache key plus the Hugging Face repo id
# for each translation direction. "ru" here means Runyankore: "en_ru" is the
# English → Runyankore model and "ru_en" is Runyankore → English.
# NOTE(review): the "v3" entry loads the "...finetunedv4" repos — confirm this
# naming mismatch is intended.
MODEL_OPTIONS = {}
MODEL_OPTIONS["Version v1 (EN→RU BLEU: 35.93 | RU→EN BLEU: 41.11)"] = dict(
    key="v1",
    en_ru="kafarasi/marian-en-ru-finetuned",
    ru_en="kafarasi/marian-ru-en-finetuned",
)
MODEL_OPTIONS["Version v3 (EN→RU BLEU: 37.52 | RU→EN BLEU: 47.06)"] = dict(
    key="v3",
    en_ru="kafarasi/marian-en-ru-finetunedv4",
    ru_en="kafarasi/marian-ru-en-finetunedv4",
)

# Run on a CUDA GPU when one is visible, otherwise on CPU (e.g. a CPU-only
# Space). Half precision is only enabled on CUDA.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
use_float16 = (device.type == "cuda")

# Lazily filled cache: "<version key>_<direction>" -> (tokenizer, model).
loaded_models = dict()
# Load and cache models
def get_model_and_tokenizer(display_label, direction):
    """Load (or fetch from cache) the tokenizer and model for one version/direction.

    Args:
        display_label: Dropdown label; must be a key of ``MODEL_OPTIONS``.
        direction: Either "English → Runyankore" or "Runyankore → English".

    Returns:
        A ``(MarianTokenizer, MarianMTModel)`` tuple. The model is moved to
        ``device`` and loaded as float16 when ``use_float16`` is true, float32
        otherwise. Results are memoised in ``loaded_models``.

    Raises:
        ValueError: If ``direction`` is not one of the two supported labels.
        KeyError: If ``display_label`` is not a key of ``MODEL_OPTIONS``.
    """
    # Map the UI direction label onto the MODEL_OPTIONS sub-key. Validate it
    # explicitly: an unset Radio can submit None, which must not silently fall
    # through to the Runyankore → English model (and get its own cache entry).
    direction_to_key = {
        "English → Runyankore": "en_ru",
        "Runyankore → English": "ru_en",
    }
    if direction not in direction_to_key:
        raise ValueError(f"Unsupported translation direction: {direction!r}")

    version_info = MODEL_OPTIONS[display_label]
    cache_key = f"{version_info['key']}_{direction}"
    cached = loaded_models.get(cache_key)
    if cached is not None:
        return cached

    model_name = version_info[direction_to_key[direction]]
    tokenizer = MarianTokenizer.from_pretrained(model_name)
    model = MarianMTModel.from_pretrained(
        model_name,
        torch_dtype=torch.float16 if use_float16 else torch.float32,
    ).to(device)
    loaded_models[cache_key] = (tokenizer, model)
    return tokenizer, model
# Translation function (no splitting)
def translate_text(text, direction, version_label, max_length=128, num_beams=1):
    """Translate ``text`` with the model chosen by ``direction`` and ``version_label``.

    The whole (stripped) input is translated in a single pass with no sentence
    splitting; input longer than the tokenizer's maximum length is truncated
    (``truncation=True``).

    Args:
        text: Source text from the input box; may be empty, blank or None.
        direction: "English → Runyankore" or "Runyankore → English"; may be
            None if no direction has been selected.
        version_label: Dropdown label, a key of ``MODEL_OPTIONS``.
        max_length: Maximum length passed to ``generate`` (default 128).
        num_beams: Beam width for ``generate`` (default 1, i.e. greedy).

    Returns:
        The decoded translation, "" for blank input, or a prompt asking the
        user to pick a direction when none was given.
    """
    # Handle the cheap cases before touching the model, so blank input never
    # triggers a model download/load and None text does not crash on .strip().
    if not text or not text.strip():
        return ""
    if not direction:
        return "Please select a translation direction."

    tokenizer, model = get_model_and_tokenizer(version_label, direction)
    inputs = tokenizer(
        [text.strip()], return_tensors="pt", padding=True, truncation=True
    ).to(device)
    with torch.no_grad():
        outputs = model.generate(**inputs, max_length=max_length, num_beams=num_beams)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
# Clear inputs
def clear_fields():
    """Reset callback for the Clear button.

    Returns:
        A pair of empty strings: one for the input box, one for the output box.
    """
    empty_input, empty_output = "", ""
    return empty_input, empty_output
# UI styling: stylesheet passed to gr.Blocks(css=...) below.
# NOTE(review): `.gr-button` is the class name older Gradio releases put on
# buttons; newer releases may not add it automatically, in which case the
# button rules only apply to components that opt in with
# elem_classes="gr-button" — confirm against the installed Gradio version.
# The `!important` flags are presumably there to beat Gradio's theme styles.
custom_css = """
body {
background: linear-gradient(to right, #e3f2fd, #fce4ec);
font-family: 'Segoe UI', sans-serif;
}
h1 {
color: #2c3e50;
font-size: 32px;
text-align: center;
margin-bottom: 10px;
}
p {
text-align: center;
font-size: 16px;
color: #4e4e4e;
}
.gradio-container {
max-width: 900px;
margin: auto;
padding: 30px;
background: #ffffff;
box-shadow: 0 4px 15px rgba(0, 0, 0, 0.1);
border-radius: 16px;
}
textarea, input {
font-size: 16px !important;
border: 2px solid #2980b9;
border-radius: 8px !important;
}
.gr-button {
background-color: #3498db !important;
color: white !important;
border-radius: 8px !important;
font-size: 16px !important;
padding: 10px 20px !important;
transition: background-color 0.3s ease;
}
.gr-button:hover {
background-color: #2c81ba !important;
}
"""
# Gradio Interface
with gr.Blocks(css=custom_css) as iface:
    gr.Markdown("<h1>Runyankore ↔ English Translator</h1>")
    gr.Markdown(
        "<p>Select a model version and translation direction. Input text will be translated efficiently even on CPU.</p>"
    )
    model_labels = list(MODEL_OPTIONS.keys())
    with gr.Row():
        with gr.Column(scale=1):
            # Index 1 is the "Version v3" entry, the higher-BLEU model.
            model_selector = gr.Dropdown(model_labels, label="Model Version", value=model_labels[1])
            text_input = gr.Textbox(lines=5, label="Input Text", placeholder="Enter text...", interactive=True, show_copy_button=True)
            # Pre-select a direction: without a default the Radio starts unset
            # and a first click on Translate would send None to translate_text.
            direction = gr.Radio(
                ["English → Runyankore", "Runyankore → English"],
                label="Translation Direction",
                value="English → Runyankore",
            )
            with gr.Row():
                # Opt the buttons into the `.gr-button` CSS rules explicitly, since
                # newer Gradio versions may not add that class on their own.
                translate_btn = gr.Button("🔄 Translate", elem_classes="gr-button")
                clear_btn = gr.Button("🗑️ Clear", elem_classes="gr-button")
        with gr.Column(scale=1):
            output_text = gr.Textbox(lines=5, label="Translated Output", interactive=False, show_copy_button=True)
    translate_btn.click(fn=translate_text, inputs=[text_input, direction, model_selector], outputs=output_text)
    clear_btn.click(fn=clear_fields, outputs=[text_input, output_text])
    gr.Markdown(
        "<p style='text-align: center;'>💬 Feedback? "
        "<a href='https://forms.gle/WD1sFUD6JHSxrr3u6' target='_blank' style='color: #2980b9;'>Click here to help improve</a></p>"
    )
# Launch
iface.launch()