File size: 1,923 Bytes
df08b89
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import tensorflow as tf
import numpy as np
from models.transformer import Transformer
from models.encoder import Encoder
from models.decoder import Decoder
from models.layers import EncoderLayer, DecoderLayer, MultiHeadAttention, point_wise_feed_forward_network
from models.utils import masked_loss, masked_accuracy

class Translator(tf.Module):
    """Greedy (argmax) translator from English text to Vietnamese text.

    The source sentence is tokenized with ``en_tokenizer`` and zero-padded
    (or truncated) to exactly ``max_tokens`` ids to form the encoder input.
    The target sequence starts from the ``<sos>`` id. It is extended one
    token at a time by taking the argmax of the transformer's output at the
    last position. Decoding stops when the ``<eos>`` id is produced or after
    ``max_length`` steps.

    NOTE(review): the tokenizers are assumed to expose the Keras-style
    ``texts_to_sequences`` / ``sequences_to_texts`` API used below, and
    ``transformer([encoder_input, decoder_input], training=False)`` is assumed
    to return logits shaped (batch, target_len, vocab). Confirm both against
    models.transformer.
    """

    def __init__(self, en_tokenizer, vi_tokenizer, transformer, max_tokens=512):
        """Store the tokenizers and the model.

        Args:
            en_tokenizer: Tokenizer for the English (source) side.
            vi_tokenizer: Tokenizer for the Vietnamese (target) side; it is
                queried for the ``"<sos>"`` and ``"<eos>"`` markers.
            transformer: Seq2seq model called as
                ``transformer([encoder_input, decoder_input], training=False)``.
            max_tokens: Fixed encoder input length. Also the default cap on
                the number of decoding steps.
        """
        # tf.Module.__init__ must run so the module's name/name_scope and
        # attribute tracking are initialised.
        super().__init__()
        self.tokenizers_en = en_tokenizer
        self.tokenizers_vi = vi_tokenizer
        self.transformer = transformer
        self.max_tokens = max_tokens

    def _encode_source(self, sentence):
        """Tokenize *sentence* into a (1, max_tokens) int tensor, zero-padded or truncated."""
        ids = self.tokenizers_en.texts_to_sequences([sentence])[0]
        # Truncate over-long input. Otherwise the padding length below would
        # be negative; the original np.zeros-based padding raised there.
        ids = list(ids[: self.max_tokens])
        padded = ids + [0] * (self.max_tokens - len(ids))
        return tf.convert_to_tensor([padded])

    def __call__(self, sentence, max_length=None):
        """Translate one English sentence using greedy decoding.

        Args:
            sentence: Source-language text.
            max_length: Maximum number of tokens to generate. Defaults to
                ``self.max_tokens``.

        Returns:
            The string produced by ``vi_tokenizer.sequences_to_texts`` for the
            decoded id sequence. That sequence begins with the ``<sos>`` id
            and ends with the ``<eos>`` id when one was generated. Their words
            may therefore appear in the returned text.

        Raises:
            ValueError: If the target tokenizer does not map ``"<sos>"`` to
                exactly one id.
        """
        if max_length is None:
            max_length = self.max_tokens

        encoder_input = self._encode_source(sentence)

        start_ids = self.tokenizers_vi.texts_to_sequences(["<sos>"])[0]
        end_ids = self.tokenizers_vi.texts_to_sequences(["<eos>"])[0]
        if len(start_ids) != 1:
            raise ValueError(
                "target tokenizer must map '<sos>' to exactly one id, "
                f"got {start_ids!r}"
            )
        # If '<eos>' is unknown to the tokenizer, decoding runs for the full
        # max_length steps.
        end_id = end_ids[0] if end_ids else None

        # Each entry is a length-1 vector: one id for the single batch row.
        # stack() gives (steps, 1); the transpose makes the (1, steps)
        # decoder input.
        output_array = tf.TensorArray(dtype=tf.int32, size=0, dynamic_size=True)
        output_array = output_array.write(0, start_ids)

        # Plain Python loop: this method is eager-only (it calls .numpy()
        # below), so tf.range adds nothing here.
        for step in range(max_length):
            decoder_input = tf.transpose(output_array.stack())
            predictions = self.transformer(
                [encoder_input, decoder_input], training=False
            )
            # Only the prediction for the last position is needed.
            predicted_id = tf.argmax(
                predictions[:, -1:, :], axis=-1, output_type=tf.int32
            )
            output_array = output_array.write(step + 1, predicted_id[0])

            if end_id is not None and int(predicted_id[0, 0]) == end_id:
                break

        output = tf.transpose(output_array.stack())
        return self.tokenizers_vi.sequences_to_texts(output.numpy().tolist())[0]