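"""Translate English sentences to Vietnamese using a pre-trained Transformer model.

Loads a saved Keras model together with the English and Vietnamese tokenizers,
translates the sentences supplied via --sentence (or a built-in example), and
then offers an interactive prompt for further translations.
"""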
import argparse

import tensorflow as tf

from translator import Translator
from utils import tokenizer_utils
from utils.preprocessing import input_processing, output_processing
from models.transformer import Transformer
from models.encoder import Encoder
from models.decoder import Decoder
from models.layers import (
    EncoderLayer,
    DecoderLayer,
    MultiHeadAttention,
    point_wise_feed_forward_network,
)
from models.utils import masked_loss, masked_accuracy
def main(sentences: list, model: tf.keras.Model, en_tokenizer, vi_tokenizer) -> None:
    """Translate input English sentences to Vietnamese using a pre-trained model.

    Args:
        sentences (list): List of English sentences to translate.
        model (tf.keras.Model): The pre-trained translation model.
        en_tokenizer: English tokenizer.
        vi_tokenizer: Vietnamese tokenizer.
    """
    # Initialize the translator with tokenizers and the model
    translator = Translator(en_tokenizer, vi_tokenizer, model)

    # Process and translate each sentence
    for sentence in sentences:
        processed_sentence = input_processing(sentence)
        translated_text = translator(processed_sentence)
        translated_text = output_processing(translated_text)

        # Display the input and translated text
        print("Input:", processed_sentence)
        print("Translated:", translated_text)
        print("-" * 50)
if __name__ == "__main__":
    # Set up argument parser
    parser = argparse.ArgumentParser(
        description="Translate English sentences to Vietnamese using a pre-trained transformer model.",
        epilog="Example: python translate.py --sentence 'Hello, world!' --sentence 'The sun is shining.'"
    )
    parser.add_argument(
        "--sentence",
        type=str,
        nargs="*",
        default=[
            (
                "For at least six centuries, residents along a lake in the mountains of central Japan "
                "have marked the depth of winter by celebrating the return of a natural phenomenon "
                "once revered as the trail of a wandering god."
            )
        ],
        help="One or more English sentences to translate (default: provided example sentence)"
    )
    parser.add_argument(
        "--model_path",
        type=str,
        default="saved_models/en_vi_translation.keras",
        help="Path to the pre-trained model file (default: saved_models/en_vi_translation.keras)"
    )

    # Parse arguments
    args = parser.parse_args()
    # Custom layers, models, and metrics that Keras needs in order to
    # deserialize the saved Transformer model
    custom_objects = {
        "Transformer": Transformer,
        "Encoder": Encoder,
        "Decoder": Decoder,
        "EncoderLayer": EncoderLayer,
        "DecoderLayer": DecoderLayer,
        "MultiHeadAttention": MultiHeadAttention,
        "point_wise_feed_forward_network": point_wise_feed_forward_network,
        "masked_loss": masked_loss,
        "masked_accuracy": masked_accuracy,
    }

    # Load the pre-trained model once
    print("Loading model from:", args.model_path)
    loaded_model = tf.keras.models.load_model(
        args.model_path, custom_objects=custom_objects
    )
    print("Model loaded successfully.")

    # Load English and Vietnamese tokenizers once
    en_tokenizer, vi_tokenizer = tokenizer_utils.load_tokenizers()

    # Run the translation for all provided sentences
    main(
        sentences=args.sentence,
        model=loaded_model,
        en_tokenizer=en_tokenizer,
        vi_tokenizer=vi_tokenizer,
    )
    # Interactive loop for additional translations
    while True:
        choice = input("Would you like to translate another sentence? (Y/n): ").strip().lower()
        if choice in ['no', 'n', 'quit', 'q']:
            print("Exiting the program.")
            break
        # An empty answer counts as "yes", matching the capitalized default in the (Y/n) prompt
        elif choice in ['yes', 'y', '']:
            new_sentence = input("Enter an English sentence to translate: ").strip()
            if new_sentence:
                main(
                    sentences=[new_sentence],
                    model=loaded_model,
                    en_tokenizer=en_tokenizer,
                    vi_tokenizer=vi_tokenizer,
                )
            else:
                print("No sentence provided. Please try again.")
        else:
            print("Invalid input. Please enter 'y' or 'n'.")