|
|
import gradio as gr |
|
|
import torch |
|
|
import os |
|
|
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM |
|
|
import spaces |
|
|
|
|
|
|
|
|
# Hugging Face access token; required so gated/private model repos can be
# downloaded. Fail fast at startup when it is missing.
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN is None:
    raise ValueError("HF_TOKEN environment variable is not set. Please set it before running the script.")
|
|
|
|
|
|
|
|
# Prefer the GPU when CUDA is available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Tiny probe tensor used only to report which device is actually in use.
zero = torch.Tensor([0]).to(device)
print(f"Device being used: {zero.device}")
|
|
|
|
|
|
|
|
# Hugging Face Hub repo IDs, one per translation direction.
# Per the UI text below, both are built on UBC-NLP/AraT5v2-base-1024.
MSA_TO_SYRIAN_MODEL = "Omartificial-Intelligence-Space/Shami-MT"
SYRIAN_TO_MSA_MODEL = "Omartificial-Intelligence-Space/SHAMI-MT-2MSA"
|
|
|
|
|
|
|
|
# Load both direction-specific tokenizer/model pairs at startup so requests
# only pay generation cost. FIX: pass the validated HF_TOKEN to
# from_pretrained — the original checked the token at startup but never used
# it, so gated/private repos would still fail to download.
print("Loading MSA to Syrian model...")
msa_to_syrian_tokenizer = AutoTokenizer.from_pretrained(MSA_TO_SYRIAN_MODEL, token=HF_TOKEN)
msa_to_syrian_model = AutoModelForSeq2SeqLM.from_pretrained(MSA_TO_SYRIAN_MODEL, token=HF_TOKEN).to(device)
msa_to_syrian_model.eval()  # inference only: disable dropout etc.

print("Loading Syrian to MSA model...")
syrian_to_msa_tokenizer = AutoTokenizer.from_pretrained(SYRIAN_TO_MSA_MODEL, token=HF_TOKEN)
syrian_to_msa_model = AutoModelForSeq2SeqLM.from_pretrained(SYRIAN_TO_MSA_MODEL, token=HF_TOKEN).to(device)
syrian_to_msa_model.eval()

print("Models loaded successfully!")
|
|
|
|
|
@spaces.GPU(duration=120)
def translate_msa_to_syrian(text):
    """Translate Modern Standard Arabic text to Syrian dialect.

    Args:
        text: Input MSA string. Gradio can pass None for an empty textbox.

    Returns:
        The translated string, "" for empty/whitespace-only input, or an
        error-message string (the UI shows it in the output box instead of
        crashing).
    """
    # FIX: guard against None as well as whitespace-only input —
    # the original called text.strip() and would raise on None.
    if not text or not text.strip():
        return ""

    try:
        input_ids = msa_to_syrian_tokenizer(text, return_tensors="pt").input_ids.to(device)
        # inference_mode avoids building an autograd graph during generation.
        with torch.inference_mode():
            outputs = msa_to_syrian_model.generate(input_ids, max_length=128, num_beams=5, early_stopping=True)
        translated_text = msa_to_syrian_tokenizer.decode(outputs[0], skip_special_tokens=True)
        return translated_text
    except Exception as e:
        # Surface the failure to the user rather than breaking the UI.
        return f"Translation error: {str(e)}"
|
|
|
|
|
@spaces.GPU(duration=120)
def translate_syrian_to_msa(text):
    """Translate Syrian dialect text to Modern Standard Arabic.

    Args:
        text: Input Syrian-dialect string. Gradio can pass None for an
            empty textbox.

    Returns:
        The translated string, "" for empty/whitespace-only input, or an
        error-message string (the UI shows it in the output box instead of
        crashing).
    """
    # FIX: guard against None as well as whitespace-only input —
    # the original called text.strip() and would raise on None.
    if not text or not text.strip():
        return ""

    try:
        input_ids = syrian_to_msa_tokenizer(text, return_tensors="pt").input_ids.to(device)
        # inference_mode avoids building an autograd graph during generation.
        with torch.inference_mode():
            outputs = syrian_to_msa_model.generate(input_ids, max_length=128, num_beams=5, early_stopping=True)
        translated_text = syrian_to_msa_tokenizer.decode(outputs[0], skip_special_tokens=True)
        return translated_text
    except Exception as e:
        # Surface the failure to the user rather than breaking the UI.
        return f"Translation error: {str(e)}"
|
|
|
|
|
def bidirectional_translate(text, direction):
    """Route *text* to the model matching the dropdown *direction* label.

    Returns the translation, or a prompt string when the label is not one
    of the two known directions.
    """
    if direction == "MSA โ Syrian":
        return translate_msa_to_syrian(text)
    if direction == "Syrian โ MSA":
        return translate_syrian_to_msa(text)
    return "Please select a translation direction"
|
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Gradio UI: two-column layout — static model information on the left, the
# translation controls (direction picker, input, button, output) on the right.
# NOTE(review): several emoji/arrow glyphs and the Arabic example strings
# below look mojibake-damaged in this file — confirm against the originally
# intended text before shipping.
# ---------------------------------------------------------------------------
with gr.Blocks(title="SHAMI-MT: Bidirectional Syria Arabic Dialect MT Framework") as demo:

    # Page banner (static HTML).
    gr.HTML("""
    <div style="text-align: center; margin-bottom: 2rem;">
        <h1>๐ SHAMI-MT: Bidirectional Arabic Translation</h1>
        <p>Translate between Modern Standard Arabic (MSA) and Syrian Dialect</p>
        <p><strong>Built on AraT5v2-base-1024 architecture</strong></p>
    </div>
    """)

    with gr.Row():
        with gr.Column(scale=1):
            # Static sidebar describing the underlying models.
            gr.HTML("""
            <div style="background: #f8f9fa; padding: 1rem; border-radius: 8px; margin: 1rem 0;">
                <h3>๐ Model Information</h3>
                <ul>
                    <li><strong>Model Type:</strong> Sequence-to-Sequence Translation</li>
                    <li><strong>Base Model:</strong> UBC-NLP/AraT5v2-base-1024</li>
                    <li><strong>Languages:</strong> Arabic (MSA โ Syrian Dialect)</li>
                    <li><strong>Device:</strong> GPU/CPU Auto-detection</li>
                </ul>
            </div>
            """)

        with gr.Column(scale=2):
            # These labels must match the string comparisons inside
            # bidirectional_translate() exactly.
            direction = gr.Dropdown(
                choices=["MSA โ Syrian", "Syrian โ MSA"],
                value="MSA โ Syrian",
                label="Translation Direction"
            )

            input_text = gr.Textbox(
                label="Input Text",
                placeholder="Enter Arabic text here...",
                lines=5
            )

            translate_btn = gr.Button("๐ Translate", variant="primary")

            output_text = gr.Textbox(
                label="Translation",
                lines=5
            )

    # Wire the button to the dispatcher.
    translate_btn.click(
        fn=bidirectional_translate,
        inputs=[input_text, direction],
        outputs=output_text
    )

    # Clickable sample inputs for each direction.
    gr.Examples(
        examples=[
            ["ุฃูุง ูุง ุฃุนุฑู ุฅุฐุง ูุงู ุณูุชู ูู ู ู ุงูุญุถูุฑ ุงูููู ุฃู ูุง.", "MSA โ Syrian"],
            ["ููู ุญุงููุ", "MSA โ Syrian"],
            ["ู ุง ุจุนุฑู ุฅุฐุง ุฑุญ ููุฏุฑ ูุฌู ุงูููู ููุง ูุฃ.", "Syrian โ MSA"],
            ["ุดููููุ", "Syrian โ MSA"]
        ],
        inputs=[input_text, direction],
        outputs=output_text,
        fn=bidirectional_translate
    )
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Bind to all interfaces on the conventional HF Spaces port.
    # NOTE(review): share=True is typically unnecessary (and ignored) when
    # running inside a Hugging Face Space — confirm it is only intended for
    # local runs.
    demo.launch(server_name="0.0.0.0", server_port=7860, share=True)