kafarasi committed
Commit fd3aa38 · verified · 1 Parent(s): 664221e

Rename backup to backup_selectmodel

Files changed (2)
  1. backup +0 -0
  2. backup_selectmodel +151 -0
backup DELETED
File without changes
backup_selectmodel ADDED
@@ -0,0 +1,151 @@
+ from transformers import MarianMTModel, MarianTokenizer
+ import gradio as gr
+ import torch
+ import re
+
+ # Define available models and BLEU scores
+ MODEL_OPTIONS = {
+     "Version v1 (EN→RU BLEU: 35.93 | RU→EN BLEU: 41.11)": {
+         "key": "Version v1",
+         "en_ru": "kafarasi/marian-en-ru-finetuned",
+         "ru_en": "kafarasi/marian-ru-en-finetuned"
+     },
+     "Version v2 (EN→RU BLEU: 36.61 | RU→EN BLEU: 50.00)": {
+         "key": "Version v2",
+         "en_ru": "kafarasi/marian-en-ru-finetunedv2",
+         "ru_en": "kafarasi/marian-ru-en-finetunedv2"
+     },
+     "Version v3 (EN→RU BLEU: 37.52 | RU→EN BLEU: 47.06)": {
+         "key": "Version v4",
+         "en_ru": "kafarasi/marian-en-ru-finetunedv4",
+         "ru_en": "kafarasi/marian-ru-en-finetunedv4"
+     }
+ }
+
+ # Auto-select device
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ use_float16 = device.type == "cuda"
+ loaded_models = {}
+
+ # Load and cache model/tokenizer
+ def get_model_and_tokenizer(display_label, direction):
+     version_info = MODEL_OPTIONS[display_label]
+     version_key = version_info["key"]
+     cache_key = f"{version_key}_{direction}"
+
+     if cache_key in loaded_models:
+         return loaded_models[cache_key]
+
+     model_name = version_info["en_ru"] if direction == "English → Runyankore" else version_info["ru_en"]
+     tokenizer = MarianTokenizer.from_pretrained(model_name)
+
+     model = MarianMTModel.from_pretrained(
+         model_name, torch_dtype=torch.float16 if use_float16 else torch.float32
+     ).to(device)
+
+     loaded_models[cache_key] = (tokenizer, model)
+     return tokenizer, model
+
+ # Sentence splitter
+ def split_sentences(text):
+     sentences = re.split(r'([.!?])\s+', text.strip())
+     merged = [sentences[i] + sentences[i + 1] for i in range(0, len(sentences) - 1, 2)]
+     if len(sentences) % 2 != 0:
+         merged.append(sentences[-1])
+     return merged
+
+ # Main translation logic
+ def translate_text(text, direction, version_label):
+     tokenizer, model = get_model_and_tokenizer(version_label, direction)
+     sentences = split_sentences(text)
+     sentences = [s.strip() for s in sentences if s.strip()]
+     inputs = tokenizer(sentences, return_tensors="pt", padding=True, truncation=True).to(device)
+     with torch.no_grad():
+         outputs = model.generate(**inputs, max_length=256)
+     translations = [tokenizer.decode(t, skip_special_tokens=True) for t in outputs]
+     return " ".join(translations)
+
+ def clear_fields():
+     return "", ""
+
+ # UI CSS
+ custom_css = """
+ body {
+     background: linear-gradient(to right, #e3f2fd, #fce4ec);
+     font-family: 'Segoe UI', sans-serif;
+ }
+ h1 {
+     color: #2c3e50;
+     font-size: 32px;
+     font-weight: bold;
+     margin-bottom: 5px;
+ }
+ p {
+     font-size: 16px;
+     color: #4e4e4e;
+ }
+ .gradio-container {
+     max-width: 3100px;
+     margin: auto;
+     padding: 30px;
+     background: #ffffff;
+     box-shadow: 0 4px 15px rgba(0, 0, 0, 0.1);
+     border-radius: 16px;
+ }
+ textarea, input {
+     font-size: 16px !important;
+     border: 2px solid #2980b9;
+     border-radius: 8px !important;
+ }
+ .gr-button {
+     background-color: #3498db !important;
+     color: white !important;
+     border-radius: 8px !important;
+     font-size: 16px !important;
+     padding: 10px 20px !important;
+     transition: background-color 0.3s ease;
+ }
+ .gr-button:hover {
+     background-color: #2c81ba !important;
+ }
+ """
+
+ # Gradio UI
+ with gr.Blocks(css=custom_css) as iface:
+     gr.Markdown("<h1 style='text-align: center;'>Runyankore ↔ English Translator</h1>")
+     gr.Markdown(
+         "<p style='text-align: center;'>Choose a model version to compare translation quality. "
+         "You can also share suggestions to help us improve accuracy.</p>"
+     )
+
+     with gr.Row():
+         with gr.Column(scale=1):
+             model_selector = gr.Dropdown(
+                 list(MODEL_OPTIONS.keys()), label="Model Version", value=list(MODEL_OPTIONS.keys())[1]
+             )
+             text_input = gr.Textbox(
+                 lines=5, label="Input Text", placeholder="Enter paragraph for translation...",
+                 interactive=True, show_copy_button=True
+             )
+             direction = gr.Radio(["English → Runyankore", "Runyankore → English"], label="Translation Direction")
+
+             with gr.Row():
+                 translate_btn = gr.Button("🔄 Translate")
+                 clear_btn = gr.Button("🗑️ Clear")
+
+         with gr.Column(scale=1):
+             output_text = gr.Textbox(
+                 lines=5, label="Translated Output", interactive=False, show_copy_button=True
+             )
+
+     translate_btn.click(fn=translate_text, inputs=[text_input, direction, model_selector], outputs=output_text)
+     clear_btn.click(fn=clear_fields, outputs=[text_input, output_text])
+
+     gr.Markdown(
+         "<p style='text-align: center;'>💡 Help us improve! "
+         "<a href='https://forms.gle/WD1sFUD6JHSxrr3u6' target='_blank' style='color: #2980b9; font-weight: bold;'>"
+         "Click here to submit feedback or corrections</a></p>"
+     )
+
+ # Launch app
+ iface.launch()
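For quick sanity checks outside the web UI, the translation helper in this file can also be called directly once the definitions from backup_selectmodel have been executed. This is a minimal sketch only: the sample sentence and the choice of the v2 entry are illustrative, and the direction and version strings must match the literals defined above.

```python
# Minimal usage sketch (assumes the definitions from backup_selectmodel are in scope).
# The direction and version strings must match the UI literals exactly.
version_label = list(MODEL_OPTIONS.keys())[1]  # the "Version v2 (...)" entry
sample = "Hello. How are you today?"           # illustrative input paragraph
print(translate_text(sample, "English → Runyankore", version_label))
```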