import gradio as gr
import tensorflow as tf

from translator import Translator
from utils import tokenizer_utils
from utils.preprocessing import input_processing, output_processing
from models.transformer import Transformer
from models.encoder import Encoder
from models.decoder import Decoder
from models.layers import EncoderLayer, DecoderLayer, MultiHeadAttention, point_wise_feed_forward_network
from models.utils import masked_loss, masked_accuracy


def load_model_and_tokenizers(model_path="ckpts/en_vi_translation.keras"):
    """
    Load the pre-trained model and tokenizers.

    Args:
        model_path (str): Path to the pre-trained model file.

    Returns:
        model: Loaded TensorFlow model.
        en_tokenizer: English tokenizer.
        vi_tokenizer: Vietnamese tokenizer.
    """
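    # tf.keras can only rebuild the saved Transformer if every custom layer,
    # loss, and metric referenced in the checkpoint is registered here
    # (this assumes the .keras file was saved from these same classes).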
    custom_objects = {
        "Transformer": Transformer,
        "Encoder": Encoder,
        "Decoder": Decoder,
        "EncoderLayer": EncoderLayer,
        "DecoderLayer": DecoderLayer,
        "MultiHeadAttention": MultiHeadAttention,
        "point_wise_feed_forward_network": point_wise_feed_forward_network,
        "masked_loss": masked_loss,
        "masked_accuracy": masked_accuracy,
    }

    try:
        model = tf.keras.models.load_model(model_path, custom_objects=custom_objects)
        print("Model loaded successfully.")
    except Exception as e:
        raise RuntimeError(f"Failed to load model: {e}") from e

    try:
        en_tokenizer, vi_tokenizer = tokenizer_utils.load_tokenizers()
        print("Tokenizers loaded successfully.")
    except Exception as e:
        raise RuntimeError(f"Failed to load tokenizers: {e}") from e

    return model, en_tokenizer, vi_tokenizer


def translate_sentence(sentence, model, en_tokenizer, vi_tokenizer):
    """
    Translate a single English sentence to Vietnamese.

    Args:
        sentence (str): English sentence to translate.
        model: Pre-trained translation model.
        en_tokenizer: English tokenizer.
        vi_tokenizer: Vietnamese tokenizer.

    Returns:
        str: Translated Vietnamese sentence.
    """
    if not sentence.strip():
        return "Please provide a valid sentence."

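    # Translation pipeline: pre-process the raw text, run the Translator
    # (assumed to handle tokenization and autoregressive decoding), then
    # post-process the decoded output back into a plain Vietnamese string.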
    translator = Translator(en_tokenizer, vi_tokenizer, model)

    processed_sentence = input_processing(sentence)
    translated_text = translator(processed_sentence)
    translated_text = output_processing(translated_text)

    return translated_text


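# Load the model and tokenizers once at import time so every Gradio
# request reuses them instead of reloading the checkpoint per call.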
try:
    model, en_tokenizer, vi_tokenizer = load_model_and_tokenizers()
except Exception as e:
    raise RuntimeError(f"Initialization failed: {e}") from e


def gradio_translate(sentence):
    """
    Gradio-compatible translation function.

    Args:
        sentence (str): Input English sentence.

    Returns:
        str: Translated Vietnamese sentence.
    """
    return translate_sentence(sentence, model, en_tokenizer, vi_tokenizer)


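# Minimal Gradio UI: a single English text box in, the Vietnamese
# translation out, with a few example sentences.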
iface = gr.Interface(
    fn=gradio_translate,
    inputs=gr.Textbox(
        label="Enter English Sentence",
        placeholder="Type an English sentence to translate to Vietnamese...",
        lines=2
    ),
    outputs=gr.Textbox(
        label="Translated Vietnamese Sentence"
    ),
    title="English to Vietnamese Translation Transformer using TensorFlow",
    description=(
        "English to Vietnamese translation using a Transformer written from scratch in TensorFlow. "
        "Enter an English sentence to translate it to Vietnamese. "
        "Example: 'Hello, world!'"
    ),
    examples=[
        [
            "For at least six centuries, residents along a lake in the mountains of central Japan "
            "have marked the depth of winter by celebrating the return of a natural phenomenon "
            "once revered as the trail of a wandering god."
        ],
        ["Hello, world!"],
        ["The sun is shining."]
    ]
)


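# Launch the app locally; pass share=True to launch() if a temporary
# public URL is needed.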
if __name__ == "__main__":
    iface.launch()