|
|
import tensorflow as tf
|
|
|
from .encoder import Encoder
|
|
|
from .decoder import Decoder
|
|
|
from tensorflow.keras.layers import Dense
|
|
|
|
|
|
@tf.keras.utils.register_keras_serializable()
class Transformer(tf.keras.Model):
    """Encoder-decoder Transformer with a final linear projection to logits.

    Wires together an ``Encoder``, a ``Decoder`` and a ``Dense`` output
    layer, and builds the attention masks both halves need. Token id 0 is
    treated as padding throughout (see ``create_padding_mask``).
    """

    def __init__(self, num_layers, d_model, num_heads, dff, input_vocab_size,
                 target_vocab_size, max_tokens, dropout_rate=0.1, **kwargs):
        """Build the encoder, decoder and output projection.

        Args:
            num_layers: Number of encoder/decoder layers.
            d_model: Model (embedding) dimensionality.
            num_heads: Number of attention heads.
            dff: Hidden size of the point-wise feed-forward networks.
            input_vocab_size: Source vocabulary size.
            target_vocab_size: Target vocabulary size (also the logit width).
            max_tokens: Maximum sequence length (positional-encoding span).
            dropout_rate: Dropout probability used in encoder and decoder.
            **kwargs: Forwarded to ``tf.keras.Model``.
        """
        super().__init__(**kwargs)

        # Hyperparameters are stored on the instance so get_config() can
        # round-trip the model through Keras serialization.
        self.num_layers = num_layers
        self.d_model = d_model
        self.num_heads = num_heads
        self.dff = dff
        self.input_vocab_size = input_vocab_size
        self.target_vocab_size = target_vocab_size
        self.max_tokens = max_tokens
        self.dropout_rate = dropout_rate

        self.encoder = Encoder(num_layers, d_model, num_heads, dff,
                               input_vocab_size, max_tokens, dropout_rate)
        self.decoder = Decoder(num_layers, d_model, num_heads, dff,
                               target_vocab_size, max_tokens, dropout_rate)
        # Projects decoder output to per-token target-vocabulary logits.
        self.final_layer = Dense(target_vocab_size)

    def call(self, inputs, training=None):
        """Run a forward pass.

        Args:
            inputs: Tuple ``(enc_input, dec_input)`` of integer token-id
                tensors, each of shape ``(batch, seq_len)``; id 0 is padding.
            training: Forwarded to encoder/decoder (controls dropout).

        Returns:
            Logits of shape ``(batch, target_seq_len, target_vocab_size)``.
        """
        enc_input, dec_input = inputs

        enc_padding_mask = self.create_padding_mask(enc_input)
        # Cross-attention in the decoder attends over encoder output, so it
        # masks the *encoder* input's padding positions.
        dec_padding_mask = self.create_padding_mask(enc_input)

        # BUG FIX: decoder self-attention must hide both future positions
        # AND the decoder input's own padding tokens. Previously only the
        # causal mask was passed, so padded target positions were attended.
        # Element-wise max keeps a position masked if either mask masks it.
        look_ahead_mask = self.create_look_ahead_mask(tf.shape(dec_input)[1])
        dec_target_padding_mask = self.create_padding_mask(dec_input)
        look_ahead_mask = tf.maximum(dec_target_padding_mask, look_ahead_mask)

        enc_output = self.encoder(enc_input, training=training,
                                  mask=enc_padding_mask)
        dec_output = self.decoder(dec_input, enc_output, training=training,
                                  look_ahead_mask=look_ahead_mask,
                                  padding_mask=dec_padding_mask)

        final_output = self.final_layer(dec_output)
        return final_output

    def create_padding_mask(self, seq):
        """Return a mask that is 1.0 at padding positions (token id 0).

        Args:
            seq: Integer tensor of shape ``(batch, seq_len)``.

        Returns:
            Float tensor of shape ``(batch, 1, 1, seq_len)``; the extra axes
            broadcast over attention heads and query positions.
        """
        mask = tf.cast(tf.math.equal(seq, 0), tf.float32)
        return mask[:, tf.newaxis, tf.newaxis, :]

    def create_look_ahead_mask(self, size):
        """Return a causal mask that is 1.0 strictly above the diagonal.

        Position i may attend to positions <= i; future positions are
        masked (value 1.0 means "masked", matching create_padding_mask).

        Args:
            size: Sequence length.

        Returns:
            Float tensor of shape ``(size, size)``.
        """
        # band_part(..., -1, 0) keeps the lower triangle; 1 - that masks
        # the strict upper triangle (the "future").
        mask = 1 - tf.linalg.band_part(tf.ones((size, size)), -1, 0)
        return mask

    def get_config(self):
        """Return the constructor arguments for Keras (de)serialization."""
        config = super().get_config()
        config.update({
            'num_layers': self.num_layers,
            'd_model': self.d_model,
            'num_heads': self.num_heads,
            'dff': self.dff,
            'input_vocab_size': self.input_vocab_size,
            'target_vocab_size': self.target_vocab_size,
            'max_tokens': self.max_tokens,
            'dropout_rate': self.dropout_rate,
        })
        return config