import gradio as gr
import tensorflow as tf
from translator import Translator
from utils import tokenizer_utils
from utils.preprocessing import input_processing, output_processing
from models.transformer import Transformer
from models.encoder import Encoder
from models.decoder import Decoder
from models.layers import EncoderLayer, DecoderLayer, MultiHeadAttention, point_wise_feed_forward_network
from models.utils import masked_loss, masked_accuracy
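

# The custom layers, model classes, and masked metric functions above must be
# importable so tf.keras.models.load_model can reconstruct the saved Transformer
# via the custom_objects mapping defined below.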
def load_model_and_tokenizers(model_path="ckpts/en_vi_translation.keras"):
    """
    Load the pre-trained model and tokenizers.

    Args:
        model_path (str): Path to the pre-trained model file.

    Returns:
        model: Loaded TensorFlow model.
        en_tokenizer: English tokenizer.
        vi_tokenizer: Vietnamese tokenizer.
    """
    # Define custom objects for the model
    custom_objects = {
        "Transformer": Transformer,
        "Encoder": Encoder,
        "Decoder": Decoder,
        "EncoderLayer": EncoderLayer,
        "DecoderLayer": DecoderLayer,
        "MultiHeadAttention": MultiHeadAttention,
        "point_wise_feed_forward_network": point_wise_feed_forward_network,
        "masked_loss": masked_loss,
        "masked_accuracy": masked_accuracy,
    }
# Load the model
try:
model = tf.keras.models.load_model(model_path, custom_objects=custom_objects)
print("Model loaded successfully.")
except Exception as e:
raise Exception(f"Failed to load model: {str(e)}")
# Load tokenizers
try:
en_tokenizer, vi_tokenizer = tokenizer_utils.load_tokenizers()
print("Tokenizers loaded successfully.")
except Exception as e:
raise Exception(f"Failed to load tokenizers: {str(e)}")
return model, en_tokenizer, vi_tokenizer


def translate_sentence(sentence, model, en_tokenizer, vi_tokenizer):
    """
    Translate a single English sentence to Vietnamese.

    Args:
        sentence (str): English sentence to translate.
        model: Pre-trained translation model.
        en_tokenizer: English tokenizer.
        vi_tokenizer: Vietnamese tokenizer.

    Returns:
        str: Translated Vietnamese sentence.
    """
    if not sentence.strip():
        return "Please provide a valid sentence."

    # Initialize translator
    translator = Translator(en_tokenizer, vi_tokenizer, model)

    # Process and translate
    processed_sentence = input_processing(sentence)
    translated_text = translator(processed_sentence)
    translated_text = output_processing(translated_text)
    return translated_text
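
# Example (direct call, assuming the model and tokenizers loaded below):
#   translate_sentence("Hello, world!", model, en_tokenizer, vi_tokenizer)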


# Load model and tokenizers once at startup
try:
    model, en_tokenizer, vi_tokenizer = load_model_and_tokenizers()
except Exception as e:
    raise RuntimeError(f"Initialization failed: {e}") from e


# Define Gradio interface
def gradio_translate(sentence):
    """
    Gradio-compatible translation function.

    Args:
        sentence (str): Input English sentence.

    Returns:
        str: Translated Vietnamese sentence.
    """
    return translate_sentence(sentence, model, en_tokenizer, vi_tokenizer)


# Create Gradio interface
iface = gr.Interface(
    fn=gradio_translate,
    inputs=gr.Textbox(
        label="Enter English Sentence",
        placeholder="Type an English sentence to translate to Vietnamese...",
        lines=2,
    ),
    outputs=gr.Textbox(
        label="Translated Vietnamese Sentence"
    ),
    title="English to Vietnamese Translation Transformer using TensorFlow",
    description=(
        "English to Vietnamese Translation Transformer built with TensorFlow from scratch. "
        "Enter an English sentence to translate it to Vietnamese. "
        "Example: 'Hello, world!'"
    ),
    examples=[
        [
            "For at least six centuries, residents along a lake in the mountains of central Japan "
            "have marked the depth of winter by celebrating the return of a natural phenomenon "
            "once revered as the trail of a wandering god."
        ],
        ["Hello, world!"],
        ["The sun is shining."],
    ],
)
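
# Note: launch() with default arguments is typically sufficient on Hugging Face Spaces;
# pass share=True to expose a temporary public URL when running the app locally.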
# Launch the app
if __name__ == "__main__":
    iface.launch()