import tensorflow as tf
from tensorflow.keras.layers import Layer, Dense, Dropout, LayerNormalization


@tf.keras.utils.register_keras_serializable()
class EncoderLayer(Layer):
    """Transformer encoder block: self-attention followed by a position-wise
    feed-forward network, each wrapped in dropout, a residual connection,
    and layer normalization."""

    def __init__(self, d_model, num_heads, dff, dropout_rate, **kwargs):
        super().__init__(**kwargs)

        self.mha = MultiHeadAttention(d_model, num_heads)
        self.ffn = point_wise_feed_forward_network(d_model, dff)

        self.layernorm1 = LayerNormalization(epsilon=1e-6)
        self.layernorm2 = LayerNormalization(epsilon=1e-6)

        self.dropout1 = Dropout(dropout_rate)
        self.dropout2 = Dropout(dropout_rate)

    def call(self, x, training=None, mask=None):
        # Self-attention: queries, keys, and values all come from x.
        attn_output, _ = self.mha(x, x, x, mask)
        attn_output = self.dropout1(attn_output, training=training)
        out1 = self.layernorm1(x + attn_output)  # residual + norm

        # Position-wise feed-forward network.
        ffn_output = self.ffn(out1)
        ffn_output = self.dropout2(ffn_output, training=training)
        out2 = self.layernorm2(out1 + ffn_output)  # residual + norm

        return out2

    def get_config(self):
        config = super().get_config()
        config.update({
            'd_model': self.mha.d_model,
            'num_heads': self.mha.num_heads,
            'dff': self.ffn.layers[0].units,
            'dropout_rate': self.dropout1.rate
        })
        return config


@tf.keras.utils.register_keras_serializable()
class DecoderLayer(Layer):
    """Transformer decoder block: masked self-attention, cross-attention
    over the encoder output, then a position-wise feed-forward network,
    each with dropout, a residual connection, and layer normalization."""

    def __init__(self, d_model, num_heads, dff, dropout_rate, **kwargs):
        super().__init__(**kwargs)

        self.mha1 = MultiHeadAttention(d_model, num_heads)
        self.mha2 = MultiHeadAttention(d_model, num_heads)
        self.ffn = point_wise_feed_forward_network(d_model, dff)

        self.layernorm1 = LayerNormalization(epsilon=1e-6)
        self.layernorm2 = LayerNormalization(epsilon=1e-6)
        self.layernorm3 = LayerNormalization(epsilon=1e-6)

        self.dropout1 = Dropout(dropout_rate)
        self.dropout2 = Dropout(dropout_rate)
        self.dropout3 = Dropout(dropout_rate)

    def call(self, x, enc_output, training=None,
             look_ahead_mask=None, padding_mask=None):
        # Masked self-attention: the look-ahead mask prevents each position
        # from attending to later positions in the target sequence.
        attn1, _ = self.mha1(x, x, x, look_ahead_mask)
        attn1 = self.dropout1(attn1, training=training)
        out1 = self.layernorm1(x + attn1)

        # Cross-attention: queries come from the decoder (out1), keys and
        # values from the encoder output.
        attn2, _ = self.mha2(enc_output, enc_output, out1, padding_mask)
        attn2 = self.dropout2(attn2, training=training)
        out2 = self.layernorm2(out1 + attn2)

        # Position-wise feed-forward network.
        ffn_output = self.ffn(out2)
        ffn_output = self.dropout3(ffn_output, training=training)
        out3 = self.layernorm3(out2 + ffn_output)

        return out3

    def get_config(self):
        config = super().get_config()
        config.update({
            'd_model': self.mha1.d_model,
            'num_heads': self.mha1.num_heads,
            'dff': self.ffn.layers[0].units,
            'dropout_rate': self.dropout1.rate
        })
        return config


@tf.keras.utils.register_keras_serializable()
class MultiHeadAttention(Layer):
    """Multi-head scaled dot-product attention with learned linear
    projections for queries, keys, values, and the final output."""

    def __init__(self, d_model, num_heads, **kwargs):
        super().__init__(**kwargs)
        self.num_heads = num_heads
        self.d_model = d_model

        assert d_model % num_heads == 0, 'd_model must be divisible by num_heads'
        self.depth = d_model // num_heads  # per-head dimension

        self.wq = Dense(d_model)
        self.wk = Dense(d_model)
        self.wv = Dense(d_model)
        self.dense = Dense(d_model)

    def split_heads(self, x, batch_size):
        """Reshape (batch, seq_len, d_model) to (batch, num_heads, seq_len, depth)."""
        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
        return tf.transpose(x, perm=[0, 2, 1, 3])

    def call(self, v, k, q, mask=None):
        batch_size = tf.shape(q)[0]

        # Linear projections, then split the model dimension into heads.
        q = self.wq(q)
        k = self.wk(k)
        v = self.wv(v)

        q = self.split_heads(q, batch_size)
        k = self.split_heads(k, batch_size)
        v = self.split_heads(v, batch_size)

        scaled_attention, attention_weights = self.scaled_dot_product_attention(
            q, k, v, mask)

        # Merge the heads back into a single (batch, seq_len, d_model) tensor.
        scaled_attention = tf.transpose(scaled_attention, perm=[0, 2, 1, 3])
        concat_attention = tf.reshape(scaled_attention,
                                      (batch_size, -1, self.d_model))

        output = self.dense(concat_attention)
        return output, attention_weights

    def scaled_dot_product_attention(self, q, k, v, mask):
        """Compute softmax(QK^T / sqrt(d_k)) V, applying the mask if given."""
        matmul_qk = tf.matmul(q, k, transpose_b=True)

        # Scale by sqrt(d_k) so the logits stay in a range where the
        # softmax keeps useful gradients.
        dk = tf.cast(tf.shape(k)[-1], tf.float32)
        scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)

        # Masked positions receive a large negative logit, so the softmax
        # assigns them (near-)zero weight.
        if mask is not None:
            scaled_attention_logits += (mask * -1e9)

        attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)
        output = tf.matmul(attention_weights, v)
        return output, attention_weights

    def get_config(self):
        config = super().get_config()
        config.update({
            'd_model': self.d_model,
            'num_heads': self.num_heads
        })
        return config


def point_wise_feed_forward_network(d_model, dff):
    """Two dense layers applied independently at every position:
    expand to dff with ReLU, then project back to d_model."""
    return tf.keras.Sequential([
        Dense(dff, activation='relu'),
        Dense(d_model)
    ])
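
# --- Usage sketch (illustrative addition, not from the original source) ---
# A minimal smoke test of the layers above, assuming hypothetical
# hyperparameters (d_model=128, num_heads=8, dff=512, dropout_rate=0.1)
# and random inputs shaped (batch, seq_len, d_model). In a full Transformer
# these layers would be stacked and fed embeddings plus positional encodings.
if __name__ == '__main__':
    batch, enc_len, dec_len, d_model = 2, 10, 7, 128

    enc_layer = EncoderLayer(d_model=d_model, num_heads=8, dff=512,
                             dropout_rate=0.1)
    dec_layer = DecoderLayer(d_model=d_model, num_heads=8, dff=512,
                             dropout_rate=0.1)

    enc_in = tf.random.uniform((batch, enc_len, d_model))
    dec_in = tf.random.uniform((batch, dec_len, d_model))

    enc_out = enc_layer(enc_in, training=False)           # (2, 10, 128)
    dec_out = dec_layer(dec_in, enc_out, training=False)  # (2, 7, 128)

    print(enc_out.shape, dec_out.shape)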