from transformers import MarianMTModel, MarianTokenizer
import gradio as gr
import torch

# Define available models and BLEU scores
MODEL_OPTIONS = {
    "Version v1 (EN→RU BLEU: 35.93 | RU→EN BLEU: 41.11)": {
        "key": "v1",
        "en_ru": "kafarasi/marian-en-ru-finetuned",
        "ru_en": "kafarasi/marian-ru-en-finetuned"
    },
    "Version v2 (EN→RU BLEU: 36.61 | RU→EN BLEU: 50.00)": {
        "key": "v2",
        "en_ru": "kafarasi/marian-en-ru-finetunedv2",
        "ru_en": "kafarasi/marian-ru-en-finetunedv2"
    },
    "Version v3 (EN→RU BLEU: 37.52 | RU→EN BLEU: 47.06)": {
        "key": "v3",
        "en_ru": "kafarasi/marian-en-ru-finetunedv4",
        "ru_en": "kafarasi/marian-ru-en-finetunedv4"
    }
}

# Select device: prefer GPU with float16 if available, otherwise fall back to CPU (e.g. on free Hugging Face Spaces)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
use_float16 = device.type == "cuda"
loaded_models = {}

# Load and cache models
def get_model_and_tokenizer(display_label, direction):
    version_info = MODEL_OPTIONS[display_label]
    version_key = version_info["key"]
    cache_key = f"{version_key}_{direction}"

    if cache_key in loaded_models:
        return loaded_models[cache_key]

    model_name = version_info["en_ru"] if direction == "English → Runyankore" else version_info["ru_en"]
    tokenizer = MarianTokenizer.from_pretrained(model_name)
    model = MarianMTModel.from_pretrained(
        model_name,
        torch_dtype=torch.float16 if use_float16 else torch.float32
    ).to(device)

    loaded_models[cache_key] = (tokenizer, model)
    return tokenizer, model

# Translation function (translates the whole input as one sequence; no sentence splitting)
def translate_text(text, direction, version_label):
    text = text.strip()
    if not text:
        return ""
    tokenizer, model = get_model_and_tokenizer(version_label, direction)
    inputs = tokenizer([text], return_tensors="pt", padding=True, truncation=True).to(device)
    with torch.no_grad():
        outputs = model.generate(**inputs, max_length=128, num_beams=1)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

# Clear inputs
def clear_fields():
    return "", ""

# UI Styling
custom_css = """
body {
    background: linear-gradient(to right, #e3f2fd, #fce4ec);
    font-family: 'Segoe UI', sans-serif;
}
h1 {
    color: #2c3e50;
    font-size: 32px;
    text-align: center;
    margin-bottom: 10px;
}
p {
    text-align: center;
    font-size: 16px;
    color: #4e4e4e;
}
.gradio-container {
    max-width: 900px;
    margin: auto;
    padding: 30px;
    background: #ffffff;
    box-shadow: 0 4px 15px rgba(0, 0, 0, 0.1);
    border-radius: 16px;
}
textarea, input {
    font-size: 16px !important;
    border: 2px solid #2980b9;
    border-radius: 8px !important;
}
.gr-button {
    background-color: #3498db !important;
    color: white !important;
    border-radius: 8px !important;
    font-size: 16px !important;
    padding: 10px 20px !important;
    transition: background-color 0.3s ease;
}
.gr-button:hover {
    background-color: #2c81ba !important;
}
"""

# Gradio Interface
with gr.Blocks(css=custom_css) as iface:
    gr.Markdown("<h1>Runyankore ↔ English Translator</h1>")
    gr.Markdown(
        "<p>Select a model version and translation direction. Input text will be translated efficiently even on CPU.</p>"
    )

    with gr.Row():
        with gr.Column(scale=1):
            model_selector = gr.Dropdown(list(MODEL_OPTIONS.keys()), label="Model Version", value=list(MODEL_OPTIONS.keys())[1])
            text_input = gr.Textbox(lines=5, label="Input Text", placeholder="Enter text...", interactive=True, show_copy_button=True)
            direction = gr.Radio(["English → Runyankore", "Runyankore → English"], label="Translation Direction", value="English → Runyankore")
            with gr.Row():
                translate_btn = gr.Button("🔄 Translate")
                clear_btn = gr.Button("🗑️ Clear")

        with gr.Column(scale=1):
            output_text = gr.Textbox(lines=5, label="Translated Output", interactive=False, show_copy_button=True)

    translate_btn.click(fn=translate_text, inputs=[text_input, direction, model_selector], outputs=output_text)
    clear_btn.click(fn=clear_fields, outputs=[text_input, output_text])

    gr.Markdown(
        "<p style='text-align: center;'>💬 Feedback? "
        "<a href='https://forms.gle/WD1sFUD6JHSxrr3u6' target='_blank' style='color: #2980b9;'>Click here to help improve</a></p>"
    )

# Launch
iface.launch()
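
# --- Optional smoke test (a minimal sketch, not part of the deployed app) ---
# Uncomment to verify a single translation from the command line instead of
# launching the Gradio UI. This assumes the kafarasi model repos above are
# reachable and will be downloaded on first use.
#
# if __name__ == "__main__":
#     example = translate_text(
#         "Good morning",                  # hypothetical sample input
#         "English → Runyankore",
#         list(MODEL_OPTIONS.keys())[1],   # the v2 entry used as the dropdown default
#     )
#     print(example)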