# Last commit by: MegaTronX
# Commit message: Update app.py
# Commit: 3356f4d (verified)
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import os
# Point the transformers download cache at a writable location on Spaces.
# NOTE(review): `transformers` is already imported above this statement; if the
# library resolves this variable at import time, setting it here may not change
# its default cache. The explicit cache_dir='/tmp/cache' arguments passed to
# from_pretrained are what pin downloads there in practice — confirm.
_SPACES_CACHE_DIR = '/tmp/cache'
os.environ['TRANSFORMERS_CACHE'] = _SPACES_CACHE_DIR
class HunyuanTranslator:
    """Japanese-to-English translator backed by a Hunyuan-MT checkpoint.

    The tokenizer and model are loaded eagerly in ``__init__`` (via
    ``_load_model``), so constructing an instance downloads/loads weights.
    """

    # Directory passed to ``from_pretrained`` as ``cache_dir``.
    CACHE_DIR = '/tmp/cache'
    # Hard cap on the stripped input length accepted by translate_ja_to_en.
    MAX_INPUT_CHARS = 1500
    # Upper bound on tokens generated for a single translation request.
    MAX_NEW_TOKENS = 512

    def __init__(self, model_name: str = "tencent/Hunyuan-MT-7B-fp8"):
        """Store the model id and load the tokenizer and model.

        Args:
            model_name: Hugging Face Hub id (or local path) of the checkpoint.

        Raises:
            Exception: if the tokenizer or model cannot be loaded.
        """
        self.model_name = model_name
        self.tokenizer = None
        self.model = None
        self._load_model()

    def _load_model(self):
        """Load the pre-quantized FP8 model and its tokenizer.

        Raises:
            Exception: wrapping (and chained to) the underlying load error.
        """
        print("Loading Hunyuan-MT FP8 model...")
        try:
            # Load tokenizer first. trust_remote_code lets the Hub repo run
            # its own tokenizer/model code.
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                cache_dir=self.CACHE_DIR,
                trust_remote_code=True,
            )
            # For Compressed Tensors models, use the standard from_pretrained;
            # the quantization is handled by the model files.
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                device_map="auto",
                torch_dtype=torch.float16,  # Use fp16 as base dtype
                trust_remote_code=True,
                cache_dir=self.CACHE_DIR,
            )
            print("FP8 model loaded successfully!")
            print(f"Model device: {self.model.device}")
            print(f"Model dtype: {next(self.model.parameters()).dtype}")
        except Exception as e:
            print(f"Error loading model: {e}")
            # Chain the cause so the original traceback survives in the logs.
            raise Exception(f"Could not load the Hunyuan-MT model: {str(e)}") from e

    @staticmethod
    def _clean_translation(raw_text: str, source_text: str) -> str:
        """Reduce decoded generated text to the translation itself.

        Args:
            raw_text: Decoded text of the newly generated tokens only.
            source_text: The stripped Japanese input; its number of non-empty
                lines decides how many output lines are kept.

        Returns:
            The cleaned translation (possibly an empty string).
        """
        text = raw_text.strip()
        # Defensive: drop an echoed "English:" label if the model emits one.
        if text.startswith("English:"):
            text = text[len("English:"):].strip()
        # Keep one output line per non-empty source line. This trims trailing
        # commentary without cutting multi-line inputs down to one line.
        wanted = max(1, sum(1 for line in source_text.splitlines() if line.strip()))
        kept = [line.strip() for line in text.splitlines() if line.strip()][:wanted]
        text = "\n".join(kept)
        # Strip a single pair of wrapping double quotes. Quotes that belong to
        # the sentence itself are preserved rather than deleted.
        if len(text) >= 2 and text[0] == '"' and text[-1] == '"':
            text = text[1:-1].strip()
        return text

    def translate_ja_to_en(self, input_text: str) -> str:
        """Translate Japanese text to English.

        Args:
            input_text: Japanese source text; surrounding whitespace is ignored.

        Returns:
            The English translation. When the input is empty or too long, or
            generation fails, a human-readable message is returned instead.
            This method does not raise; its output is shown directly in the UI.
        """
        if not input_text or not input_text.strip():
            return "Please enter some Japanese text to translate."
        input_text = input_text.strip()
        # Limit input length for Spaces (checked after stripping whitespace).
        if len(input_text) > self.MAX_INPUT_CHARS:
            return f"Input too long. Please keep under {self.MAX_INPUT_CHARS} characters for this demo."
        try:
            prompt = (
                "Translate the following Japanese text to English. Provide only the translation without any additional explanations or notes.\n"
                f"Japanese: {input_text}\n"
                "English:"
            )
            # No truncation: the character cap above already bounds the prompt.
            # Right-side truncation would silently drop the end of the source
            # text and the trailing "English:" cue.
            inputs = self.tokenizer(prompt, return_tensors="pt")
            inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
            prompt_length = inputs["input_ids"].shape[-1]

            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=self.MAX_NEW_TOKENS,
                    temperature=0.7,
                    do_sample=True,
                    top_p=0.9,
                    repetition_penalty=1.1,
                    pad_token_id=self.tokenizer.eos_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                    num_return_sequences=1,
                )

            # generate() returns prompt + continuation. Decode only the new
            # tokens instead of string-matching the prompt after a lossy
            # decode round-trip.
            generated = self.tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
            translation = self._clean_translation(generated, input_text)
            return translation if translation else "No translation generated. Please try again."
        except Exception as e:
            return f"Error during translation: {str(e)}"
def _make_translate_function():
    """Return the callback wired to every translate event in the UI.

    Tries to construct a HunyuanTranslator. On success the callback forwards
    to ``translate_ja_to_en``. On failure it returns a fixed error message,
    so the UI still loads and explains what went wrong.
    """
    try:
        translator = HunyuanTranslator()
    except Exception as e:
        print(f"Failed to initialize translator: {e}")
        # Format the message now. Python unbinds the `as e` name when the
        # except clause ends, so a callback that looked `e` up lazily would
        # raise NameError on its first call instead of returning this text.
        failure_message = (
            f"Model initialization failed: {str(e)}\n\n"
            "Please check the Space logs for details."
        )

        def report_failure(input_text):
            return failure_message

        return report_failure

    def translate_function(input_text):
        return translator.translate_ja_to_en(input_text)

    return translate_function


def create_translation_interface():
    """Create the Gradio interface for Japanese to English translation.

    The translator is loaded once via ``_make_translate_function`` and wired
    to the Translate button, the Enter key in the input box, and the examples.

    Returns:
        gr.Blocks: the assembled demo. The caller is responsible for launch().
    """
    translate_function = _make_translate_function()

    # Custom CSS for better appearance
    custom_css = """
    .gradio-container {
        max-width: 900px !important;
        margin: 0 auto;
    }
    .container {
        max-width: 900px;
        margin: auto;
        padding: 20px;
    }
    .japanese-text {
        font-family: "Hiragino Sans", "Yu Gothic", "Meiryo", sans-serif;
    }
    .translation-box {
        border-left: 3px solid #4CAF50;
        padding-left: 15px;
    }
    """

    with gr.Blocks(
        title="Japanese to English Translation - Hunyuan-MT FP8",
        theme=gr.themes.Soft(),
        css=custom_css,
    ) as demo:
        # gr.Markdown dedents its value, so the indentation below is not rendered.
        gr.Markdown(
            """
            # 🇯🇵 → 🇺🇸 Japanese to English Translation

            **Model:** `tencent/Hunyuan-MT-7B-fp8` • **Technology:** FP8 Quantization

            *Fast, high-quality Japanese to English translation using optimized FP8 model*
            """
        )

        with gr.Row(equal_height=False):
            with gr.Column(scale=1):
                gr.Markdown("### 📥 Japanese Input")
                input_text = gr.Textbox(
                    label="",
                    placeholder="日本語のテキストを入力してください...\n(Enter Japanese text here)",
                    lines=6,
                    max_lines=8,
                    show_copy_button=True,
                    elem_classes=["japanese-text"],
                )
                with gr.Row():
                    translate_btn = gr.Button(
                        "🚀 Translate",
                        variant="primary",
                        size="lg",
                        scale=2,
                    )
                    clear_btn = gr.Button(
                        "🗑️ Clear",
                        variant="secondary",
                        size="lg",
                        scale=1,
                    )
            with gr.Column(scale=1):
                gr.Markdown("### 📤 English Translation")
                output_text = gr.Textbox(
                    label="",
                    placeholder="Translation will appear here...",
                    lines=6,
                    max_lines=8,
                    show_copy_button=True,
                    elem_classes=["translation-box"],
                )

        # Examples section (registered on the Blocks context; no handle needed)
        gr.Markdown("### 💡 Example Translations")
        gr.Examples(
            examples=[
                ["こんにちは、元気ですか?"],
                ["今日は良い天気ですね。"],
                ["機械学習と人工知能は現代技術の重要な分野です。"],
                ["このレストランの料理はとても美味しいです。"],
                ["明日の会議は午後二時から始まります。"],
                ["日本の文化は非常に興味深いと思います。"],
                ["新しいプロジェクトの提案書を作成しました。"],
                ["電車の遅延により、到着が30分ほど遅れます。"],
            ],
            inputs=input_text,
            outputs=output_text,
            fn=translate_function,
            cache_examples=False,
            label="Click any example to try:",
        )

        # Connect the buttons
        translate_btn.click(
            fn=translate_function,
            inputs=input_text,
            outputs=output_text,
            api_name="translate",
        )
        clear_btn.click(
            fn=lambda: ("", ""),
            inputs=[],
            outputs=[input_text, output_text],
        )
        # Also allow Enter key to trigger translation
        input_text.submit(
            fn=translate_function,
            inputs=input_text,
            outputs=output_text,
        )

        # Technical details
        gr.Markdown(
            """
            ---
            ### 🛠️ Technical Information

            **Model Details:**

            - **Base Model**: Hunyuan-MT 7B
            - **Quantization**: FP8 (8-bit floating point)
            - **Memory Usage**: ~3-4GB
            - **Specialization**: Japanese ↔ English translation

            **Optimization Features:**

            - ✅ FP8 quantization for faster inference
            - ✅ GPU acceleration support
            - ✅ Efficient memory usage

            **Usage Tips:**

            - Keep inputs under 1500 characters for best results
            - Translation takes 5-15 seconds typically
            - Model works best with complete sentences
            - Handles technical and casual Japanese well
            """
        )

    return demo
# Entry point for Hugging Face Spaces: hands back the app without launching it.
def get_space_app():
    """Build the Gradio demo and return it, unlaunched.

    Returns:
        Whatever create_translation_interface() produces (the Blocks demo).
    """
    app = create_translation_interface()
    return app
# Run a local server when this file is executed as a script.
if __name__ == "__main__":
    launch_options = {
        "server_name": "0.0.0.0",  # bind on all network interfaces
        "server_port": 7860,
        "share": False,            # no public gradio.live tunnel
        "show_error": True,        # surface callback errors in the UI
    }
    demo = create_translation_interface()
    demo.launch(**launch_options)