import tensorflow as tf
from tensorflow.keras.layers import Dense

from .encoder import Encoder
from .decoder import Decoder


@tf.keras.utils.register_keras_serializable()
class Transformer(tf.keras.Model):
    """Encoder-decoder Transformer with a final linear projection over the target vocabulary."""

    def __init__(self, num_layers, d_model, num_heads, dff, input_vocab_size,
                 target_vocab_size, max_tokens, dropout_rate=0.1, **kwargs):
        super().__init__(**kwargs)
        self.num_layers = num_layers
        self.d_model = d_model
        self.num_heads = num_heads
        self.dff = dff
        self.input_vocab_size = input_vocab_size
        self.target_vocab_size = target_vocab_size
        self.max_tokens = max_tokens
        self.dropout_rate = dropout_rate

        self.encoder = Encoder(num_layers, d_model, num_heads, dff,
                               input_vocab_size, max_tokens, dropout_rate)
        self.decoder = Decoder(num_layers, d_model, num_heads, dff,
                               target_vocab_size, max_tokens, dropout_rate)
        self.final_layer = Dense(target_vocab_size)

    def call(self, inputs, training=None):
        # inputs is an (encoder_tokens, decoder_tokens) pair of int tensors,
        # each shaped (batch_size, seq_len), with 0 used as the padding id.
        enc_input, dec_input = inputs

        # Padding mask hides pad tokens in encoder self-attention; the same
        # mask is reused for the decoder's cross-attention over encoder output.
        enc_padding_mask = self.create_padding_mask(enc_input)
        dec_padding_mask = self.create_padding_mask(enc_input)

        # Causal (look-ahead) mask keeps each decoder position from attending
        # to later positions during self-attention.
        look_ahead_mask = self.create_look_ahead_mask(tf.shape(dec_input)[1])

        enc_output = self.encoder(enc_input, training=training, mask=enc_padding_mask)
        dec_output = self.decoder(dec_input, enc_output, training=training,
                                  look_ahead_mask=look_ahead_mask,
                                  padding_mask=dec_padding_mask)

        # Project decoder output to logits over the target vocabulary:
        # (batch_size, target_seq_len, target_vocab_size).
        final_output = self.final_layer(dec_output)
        return final_output

    def create_padding_mask(self, seq):
        # 1.0 where the token id is 0 (padding); shaped so it broadcasts to
        # (batch_size, num_heads, seq_len_q, seq_len_k) inside attention.
        mask = tf.cast(tf.math.equal(seq, 0), tf.float32)
        return mask[:, tf.newaxis, tf.newaxis, :]

    def create_look_ahead_mask(self, size):
        # Strictly upper-triangular matrix of 1s: 1.0 marks future positions
        # that must not be attended to.
        mask = 1 - tf.linalg.band_part(tf.ones((size, size)), -1, 0)
        return mask

    def get_config(self):
        config = super().get_config()
        config.update({
            'num_layers': self.num_layers,
            'd_model': self.d_model,
            'num_heads': self.num_heads,
            'dff': self.dff,
            'input_vocab_size': self.input_vocab_size,
            'target_vocab_size': self.target_vocab_size,
            'max_tokens': self.max_tokens,
            'dropout_rate': self.dropout_rate,
        })
        return config
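

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the library's API): it assumes
    # the sibling Encoder/Decoder modules follow the signatures used above, and
    # the hyperparameter values here are illustrative only. Because of the
    # relative imports, run it as a module, e.g. `python -m <package>.transformer`.
    transformer = Transformer(
        num_layers=2, d_model=64, num_heads=4, dff=128,
        input_vocab_size=1000, target_vocab_size=1200,
        max_tokens=40, dropout_rate=0.1,
    )

    # Random token ids; 0 is reserved for padding, so sample from [1, vocab).
    enc_tokens = tf.random.uniform((8, 30), minval=1, maxval=1000, dtype=tf.int64)
    dec_tokens = tf.random.uniform((8, 25), minval=1, maxval=1200, dtype=tf.int64)

    logits = transformer((enc_tokens, dec_tokens), training=False)
    print(logits.shape)  # expected: (8, 25, 1200)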