|
|
import gradio as gr |
|
|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
|
import torch |
|
|
import os |
|
|
|
|
|
|
|
|
# Route the transformers cache to /tmp/cache.
# NOTE(review): this assignment runs after `transformers` has already been
# imported above; confirm the library still honours the variable at this point.
os.environ.update({'TRANSFORMERS_CACHE': '/tmp/cache'})
|
|
|
|
|
class HunyuanTranslator:
    """Japanese-to-English translator built on a Hugging Face causal LM.

    The tokenizer and model are loaded eagerly by ``__init__`` (through
    ``_load_model``), so constructing an instance raises if loading fails.
    """

    def __init__(self, model_name: str = "tencent/Hunyuan-MT-7B-fp8"):
        """Store the model id and load the tokenizer and model immediately.

        Args:
            model_name: Hub id (or path) passed to ``from_pretrained``.

        Raises:
            RuntimeError: if the tokenizer or model cannot be loaded.
        """
        self.model_name = model_name
        self.tokenizer = None
        self.model = None
        self._load_model()

    def _load_model(self):
        """Load the pre-quantized FP8 model and its tokenizer.

        Raises:
            RuntimeError: if anything goes wrong while loading; the original
                exception is chained as ``__cause__`` so the traceback keeps
                the real failure.
        """
        print("Loading Hunyuan-MT FP8 model...")

        try:
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                cache_dir='/tmp/cache',
                trust_remote_code=True,
            )

            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                device_map="auto",
                torch_dtype=torch.float16,
                trust_remote_code=True,
                cache_dir='/tmp/cache',
            )

            print("FP8 model loaded successfully!")
            print(f"Model device: {self.model.device}")
            print(f"Model dtype: {next(self.model.parameters()).dtype}")

        except Exception as e:
            print(f"Error loading model: {e}")
            # RuntimeError is still caught by callers handling ``Exception``;
            # ``from e`` preserves the underlying cause.
            raise RuntimeError(f"Could not load the Hunyuan-MT model: {e}") from e

    @staticmethod
    def _clean_translation(generated_text: str) -> str:
        """Reduce raw model continuation text to a single clean translation line.

        Keeps only the first line (anything after it is treated as the model
        running on past the answer) and removes one pair of double quotes
        that wraps the whole line. Quotes inside the sentence are preserved.
        """
        first_line = generated_text.strip().split('\n')[0].strip()

        wrapped = (
            len(first_line) >= 2
            and first_line[0] == '"'
            and first_line[-1] == '"'
            and '"' not in first_line[1:-1]
        )
        if wrapped:
            first_line = first_line[1:-1].strip()

        return first_line

    def translate_ja_to_en(self, input_text: str) -> str:
        """Translate Japanese to English using FP8 model.

        Args:
            input_text: Japanese text; surrounding whitespace is ignored.

        Returns:
            The translation, or a human-readable message when the input is
            empty, longer than 1500 characters, produces no output, or an
            error occurs during generation (errors are reported as text,
            never raised).
        """
        if not input_text or not input_text.strip():
            return "Please enter some Japanese text to translate."

        # Measure the text that is actually sent to the model, so padding
        # whitespace alone cannot trip the length limit.
        input_text = input_text.strip()
        if len(input_text) > 1500:
            return "Input too long. Please keep under 1500 characters for this demo."

        try:
            prompt = (
                "Translate the following Japanese text to English. "
                "Provide only the translation without any additional "
                "explanations or notes.\n"
                "\n"
                f"Japanese: {input_text}\n"
                "\n"
                "English:"
            )

            # NOTE(review): with truncation enabled, a prompt longer than 1024
            # tokens is cut by the tokenizer; confirm the 1500-character limit
            # keeps the trailing "English:" cue inside that budget.
            inputs = self.tokenizer(
                prompt,
                return_tensors="pt",
                truncation=True,
                max_length=1024,
            )
            inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
            prompt_token_count = inputs["input_ids"].shape[1]

            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=512,
                    temperature=0.7,
                    do_sample=True,
                    top_p=0.9,
                    repetition_penalty=1.1,
                    pad_token_id=self.tokenizer.eos_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                    num_return_sequences=1,
                )

            # The returned sequence starts with the prompt tokens. Drop them by
            # position instead of string-matching the decoded prompt, which
            # breaks when decoding does not reproduce the prompt verbatim.
            new_tokens = outputs[0][prompt_token_count:]
            generated_text = self.tokenizer.decode(new_tokens, skip_special_tokens=True)

            translation = self._clean_translation(generated_text)
            return translation if translation else "No translation generated. Please try again."

        except Exception as e:
            return f"Error during translation: {str(e)}"
|
|
|
|
|
def create_translation_interface():
    """Create the Gradio interface for Japanese to English translation.

    Tries to construct a ``HunyuanTranslator``. If that fails, the UI is
    still built, but every translation request returns a message describing
    the initialization failure.

    Returns:
        The ``gr.Blocks`` app (not yet launched).
    """
    try:
        translator = HunyuanTranslator()
    except Exception as e:
        print(f"Failed to initialize translator: {e}")

        # Python unbinds ``e`` when the except block ends, so a closure that
        # reads ``e`` later would raise NameError. Copy the message now.
        init_error = str(e)

        def translate_function(input_text):
            return f"Model initialization failed: {init_error}\n\nPlease check the Space logs for details."
    else:
        def translate_function(input_text):
            return translator.translate_ja_to_en(input_text)

    custom_css = """
    .gradio-container {
        max-width: 900px !important;
        margin: 0 auto;
    }
    .container {
        max-width: 900px;
        margin: auto;
        padding: 20px;
    }
    .japanese-text {
        font-family: "Hiragino Sans", "Yu Gothic", "Meiryo", sans-serif;
    }
    .translation-box {
        border-left: 3px solid #4CAF50;
        padding-left: 15px;
    }
    """

    with gr.Blocks(
        title="Japanese to English Translation - Hunyuan-MT FP8",
        theme=gr.themes.Soft(),
        css=custom_css
    ) as demo:

        gr.Markdown(
            """
            # 🇯🇵 → 🇺🇸 Japanese to English Translation
            **Model:** `tencent/Hunyuan-MT-7B-fp8` • **Technology:** FP8 Quantization

            *Fast, high-quality Japanese to English translation using optimized FP8 model*
            """
        )

        with gr.Row(equal_height=False):
            # Left column: source text plus the action buttons.
            with gr.Column(scale=1):
                gr.Markdown("### 📥 Japanese Input")
                input_text = gr.Textbox(
                    label="",
                    placeholder="日本語のテキストを入力してください...\n(Enter Japanese text here)",
                    lines=6,
                    max_lines=8,
                    show_copy_button=True,
                    elem_classes=["japanese-text"]
                )

                with gr.Row():
                    translate_btn = gr.Button(
                        "🚀 Translate",
                        variant="primary",
                        size="lg",
                        scale=2
                    )
                    clear_btn = gr.Button(
                        "🗑️ Clear",
                        variant="secondary",
                        size="lg",
                        scale=1
                    )

            # Right column: translated text.
            with gr.Column(scale=1):
                gr.Markdown("### 📤 English Translation")
                output_text = gr.Textbox(
                    label="",
                    placeholder="Translation will appear here...",
                    lines=6,
                    max_lines=8,
                    show_copy_button=True,
                    elem_classes=["translation-box"]
                )

        gr.Markdown("### 💡 Example Translations")
        gr.Examples(
            examples=[
                ["こんにちは、元気ですか?"],
                ["今日は良い天気ですね。"],
                ["機械学習と人工知能は現代技術の重要な分野です。"],
                ["このレストランの料理はとても美味しいです。"],
                ["明日の会議は午後二時から始まります。"],
                ["日本の文化は非常に興味深いと思います。"],
                ["新しいプロジェクトの提案書を作成しました。"],
                ["電車の遅延により、到着が30分ほど遅れます。"]
            ],
            inputs=input_text,
            outputs=output_text,
            fn=translate_function,
            cache_examples=False,
            label="Click any example to try:"
        )

        # Button click and Enter-in-textbox both run the same translation.
        translate_btn.click(
            fn=translate_function,
            inputs=input_text,
            outputs=output_text,
            api_name="translate"
        )

        # Clear resets both text boxes to empty strings.
        clear_btn.click(
            fn=lambda: ("", ""),
            inputs=[],
            outputs=[input_text, output_text]
        )

        input_text.submit(
            fn=translate_function,
            inputs=input_text,
            outputs=output_text
        )

        gr.Markdown(
            """
            ---
            ### 🛠️ Technical Information

            **Model Details:**
            - **Base Model**: Hunyuan-MT 7B
            - **Quantization**: FP8 (8-bit floating point)
            - **Memory Usage**: ~3-4GB
            - **Specialization**: Japanese ↔ English translation

            **Optimization Features:**
            - ✅ FP8 quantization for faster inference
            - ✅ GPU acceleration support
            - ✅ Efficient memory usage

            **Usage Tips:**
            - Keep inputs under 1500 characters for best results
            - Translation takes 5-15 seconds typically
            - Model works best with complete sentences
            - Handles technical and casual Japanese well
            """
        )

    return demo
|
|
|
|
|
|
|
|
def get_space_app():
    """Build and return the Gradio app; entry point intended for Spaces."""
    space_app = create_translation_interface()
    return space_app
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: build the UI, then serve it with fixed options.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
        "show_error": True,
    }
    demo = create_translation_interface()
    demo.launch(**launch_options)