### IMPORTS
import tensorflow as tf
import numpy as np
import einops
import tqdm
import collections
import re
import string
import pickle
print("import complete")
#=========================================================================================================================
### UTILITY FUNCTIONS
#=========================================================================================================================
IMAGE_SHAPE = (224, 224, 3)

def custom_standardization(s):
    s = tf.strings.lower(s)
    s = tf.strings.regex_replace(s, f'[{re.escape(string.punctuation)}]', '')
    s = tf.strings.join(['[START]', s, '[END]'], separator=' ')
    return s
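# Quick illustration (a sketch with a made-up caption): standardization lowercases the text,
# strips punctuation, and wraps it in the [START]/[END] markers expected by the tokenizer.
print(custom_standardization("A man rides a wave!"))  # tf.Tensor(b'[START] a man rides a wave [END]', ...)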
def load_image(image_path):
    img = tf.io.read_file(image_path)
    img = tf.io.decode_jpeg(img, channels=3)
    img = tf.image.resize(img, IMAGE_SHAPE[:-1])
    return img

def load_image_obj(img):
    img = tf.image.resize(img, IMAGE_SHAPE[:-1])
    return img

def masked_loss(labels, preds):
    # Cross-entropy averaged only over real tokens: padding (label 0) and
    # numerically exploded positions are masked out.
    loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels, preds)
    mask = (labels != 0) & (loss < 1e8)
    mask = tf.cast(mask, loss.dtype)
    loss = loss * mask
    loss = tf.reduce_sum(loss) / tf.reduce_sum(mask)
    return loss

def masked_acc(labels, preds):
    # Token-level accuracy, ignoring padding positions (label 0).
    mask = tf.cast(labels != 0, tf.float32)
    preds = tf.argmax(preds, axis=-1)
    labels = tf.cast(labels, tf.int64)
    match = tf.cast(preds == labels, mask.dtype)
    acc = tf.reduce_sum(match * mask) / tf.reduce_sum(mask)
    return acc

print("utility complete")
#=========================================================================================================================
### MODEL CLASS
#=========================================================================================================================
mobilenet = tf.keras.applications.MobileNetV3Small(
    input_shape=IMAGE_SHAPE,
    include_top=False,
    include_preprocessing=True)
mobilenet.trainable = False

class SeqEmbedding(tf.keras.layers.Layer):
    def __init__(self, vocab_size, max_length, depth):
        super().__init__()
        self.pos_embedding = tf.keras.layers.Embedding(input_dim=max_length, output_dim=depth)
        self.token_embedding = tf.keras.layers.Embedding(
            input_dim=vocab_size,
            output_dim=depth,
            mask_zero=True)
        self.add = tf.keras.layers.Add()

    def call(self, seq):
        seq = self.token_embedding(seq)   # (batch, seq, depth)
        x = tf.range(tf.shape(seq)[1])    # (seq)
        x = x[tf.newaxis, :]              # (1, seq)
        x = self.pos_embedding(x)         # (1, seq, depth)
        return self.add([seq, x])
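# Shape sketch (illustration only, with made-up sizes): token and position embeddings are
# summed, so a (batch, seq) batch of token ids becomes a (batch, seq, depth) tensor.
_emb = SeqEmbedding(vocab_size=100, max_length=50, depth=8)
print(_emb(tf.constant([[2, 5, 9]])).shape)  # -> (1, 3, 8)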
class CausalSelfAttention(tf.keras.layers.Layer):
    def __init__(self, **kwargs):
        super().__init__()
        self.mha = tf.keras.layers.MultiHeadAttention(**kwargs)
        # Use Add instead of + so the keras mask propagates through.
        self.add = tf.keras.layers.Add()
        self.layernorm = tf.keras.layers.LayerNormalization()

    def call(self, x):
        attn = self.mha(query=x, value=x,
                        use_causal_mask=True)
        x = self.add([x, attn])
        return self.layernorm(x)

class CrossAttention(tf.keras.layers.Layer):
    def __init__(self, **kwargs):
        super().__init__()
        self.mha = tf.keras.layers.MultiHeadAttention(**kwargs)
        self.add = tf.keras.layers.Add()
        self.layernorm = tf.keras.layers.LayerNormalization()

    def call(self, x, y, **kwargs):
        attn, attention_scores = self.mha(
            query=x, value=y,
            return_attention_scores=True)
        self.last_attention_scores = attention_scores
        x = self.add([x, attn])
        return self.layernorm(x)

class FeedForward(tf.keras.layers.Layer):
    def __init__(self, units, dropout_rate=0.1):
        super().__init__()
        self.seq = tf.keras.Sequential([
            tf.keras.layers.Dense(units=2*units, activation='relu'),
            tf.keras.layers.Dense(units=units),
            tf.keras.layers.Dropout(rate=dropout_rate),
        ])
        self.layernorm = tf.keras.layers.LayerNormalization()

    def call(self, x):
        x = x + self.seq(x)
        return self.layernorm(x)

class DecoderLayer(tf.keras.layers.Layer):
    def __init__(self, units, num_heads=1, dropout_rate=0.1):
        super().__init__()
        self.self_attention = CausalSelfAttention(num_heads=num_heads,
                                                  key_dim=units,
                                                  dropout=dropout_rate)
        self.cross_attention = CrossAttention(num_heads=num_heads,
                                              key_dim=units,
                                              dropout=dropout_rate)
        self.ff = FeedForward(units=units, dropout_rate=dropout_rate)

    def call(self, inputs, training=False):
        in_seq, out_seq = inputs
        # Text input
        out_seq = self.self_attention(out_seq)
        out_seq = self.cross_attention(out_seq, in_seq)
        self.last_attention_scores = self.cross_attention.last_attention_scores
        out_seq = self.ff(out_seq)
        return out_seq
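# Shape sketch (illustration only, with made-up sizes): a decoder layer attends over the
# flattened image features and returns a tensor with the same shape as the text input.
_dec = DecoderLayer(units=8, num_heads=1)
_img = tf.random.normal([1, 49, 16])  # stand-in for (batch, h*w, channels) image features
_txt = tf.random.normal([1, 5, 8])    # stand-in for (batch, seq, depth) embedded tokens
print(_dec((_img, _txt)).shape)       # -> (1, 5, 8)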
class TokenOutput(tf.keras.layers.Layer):
    def __init__(self, tokenizer, banned_tokens=('', '[UNK]', '[START]'), bias=None, **kwargs):
        super().__init__()
        self.dense = tf.keras.layers.Dense(
            units=tokenizer.vocabulary_size(), **kwargs)
        self.tokenizer = tokenizer
        self.banned_tokens = banned_tokens
        self.bias = bias

    def adapt(self, ds):
        # Count token frequencies over the dataset so the output bias can be
        # initialized to the log of the marginal token distribution.
        counts = collections.Counter()
        vocab_dict = {name: id
                      for id, name in enumerate(self.tokenizer.get_vocabulary())}
        for tokens in tqdm.tqdm(ds):
            counts.update(tokens.numpy().flatten())
        counts_arr = np.zeros(shape=(self.tokenizer.vocabulary_size(),))
        counts_arr[np.array(list(counts.keys()), dtype=np.int32)] = list(counts.values())
        counts_arr = counts_arr[:]
        for token in self.banned_tokens:
            counts_arr[vocab_dict[token]] = 0
        total = counts_arr.sum()
        p = counts_arr / total
        p[counts_arr == 0] = 1.0
        log_p = np.log(p)  # log(1) == 0
        entropy = -(log_p * p).sum()
        print()
        print(f"Uniform entropy: {np.log(self.tokenizer.vocabulary_size()):0.2f}")
        print(f"Marginal entropy: {entropy:0.2f}")
        self.bias = log_p
        # Banned and unseen tokens get a large negative bias so they are never predicted.
        self.bias[counts_arr == 0] = -1e9

    def call(self, x):
        x = self.dense(x)
        return x + self.bias

    def get_config(self):
        config = super(TokenOutput, self).get_config()
        config.update({
            "tokenizer": self.tokenizer,
            "banned_tokens": self.banned_tokens,
            "bias": self.bias,
            "dense": self.dense
        })
        return config
class Captioner(tf.keras.Model):
    @classmethod
    def add_method(cls, fun):
        setattr(cls, fun.__name__, fun)
        return fun

    def __init__(self, tokenizer, feature_extractor, output_layer, num_layers=1,
                 units=256, max_length=50, num_heads=1, dropout_rate=0.1):
        super().__init__()
        self.feature_extractor = feature_extractor
        self.tokenizer = tokenizer
        self.word_to_index = tf.keras.layers.StringLookup(
            mask_token="",
            vocabulary=tokenizer.get_vocabulary())
        self.index_to_word = tf.keras.layers.StringLookup(
            mask_token="",
            vocabulary=tokenizer.get_vocabulary(),
            invert=True)
        self.seq_embedding = SeqEmbedding(
            vocab_size=tokenizer.vocabulary_size(),
            depth=units,
            max_length=max_length)
        self.decoder_layers = [
            DecoderLayer(units, num_heads=num_heads, dropout_rate=dropout_rate)
            for n in range(num_layers)]
        self.output_layer = output_layer

    def call(self, inputs):
        image, txt = inputs
        if image.shape[-1] == 3:
            # Apply the feature-extractor, if you get an RGB image.
            image = self.feature_extractor(image)
        # Flatten the feature map
        image = einops.rearrange(image, 'b h w c -> b (h w) c')
        if txt.dtype == tf.string:
            # Apply the tokenizer if you get string inputs.
            txt = self.tokenizer(txt)
        txt = self.seq_embedding(txt)
        # Look at the image
        for dec_layer in self.decoder_layers:
            txt = dec_layer(inputs=(image, txt))
        txt = self.output_layer(txt)
        return txt

    def simple_gen(self, image, temperature=1):
        initial = self.word_to_index([['[START]']])  # (batch, sequence)
        img_features = self.feature_extractor(image[tf.newaxis, ...])
        tokens = initial  # (batch, sequence)
        for n in range(50):
            preds = self((img_features, tokens)).numpy()  # (batch, sequence, vocab)
            preds = preds[:, -1, :]  # (batch, vocab)
            if temperature == 0:
                next = tf.argmax(preds, axis=-1)[:, tf.newaxis]  # (batch, 1)
            else:
                next = tf.random.categorical(preds/temperature, num_samples=1)  # (batch, 1)
            tokens = tf.concat([tokens, next], axis=1)  # (batch, sequence)
            if next[0] == self.word_to_index('[END]'):
                break
        words = self.index_to_word(tokens[0, 1:-1])
        result = tf.strings.reduce_join(words, axis=-1, separator=' ')
        return result.numpy().decode()

    # def get_config(self):
    #     config = super().get_config()
    #     config.update({"feature_extractor": self.feature_extractor,
    #                    "tokenizer": self.tokenizer,
    #                    "word_to_index": self.word_to_index,
    #                    "index_to_word": self.index_to_word,
    #                    "outputlayer": self.output_layer,
    #                    "seq_embedding": self.seq_embedding,
    #                    "decoder_layers": self.decoder_layers
    #                    })
    #     return config

    # def build_from_config(self, config):
    #     return super().build_from_config(config)

# model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
#               loss=masked_loss,
#               metrics=[masked_acc])

print("model complete")
#=========================================================================================================================
### LOAD FUNCTION
#=========================================================================================================================
def build():
    # Rebuild the tokenizer from its pickled config and weights.
    filename = "model/tokenizer.pkl"
    token_meta = pickle.load(open(filename, 'rb'))
    tokenizer = tf.keras.layers.TextVectorization.from_config(token_meta["config"])
    tokenizer.set_weights(token_meta['weights'])
    print(tokenizer("build sentence"))
    word_to_index = tf.keras.layers.StringLookup(
        mask_token="",
        vocabulary=tokenizer.get_vocabulary())
    index_to_word = tf.keras.layers.StringLookup(
        mask_token="",
        vocabulary=tokenizer.get_vocabulary(),
        invert=True)
    output_layer = TokenOutput(tokenizer, banned_tokens=('', '[UNK]', '[START]'))
    filename = "model/output_layer.pkl"
    bias = pickle.load(open(filename, 'rb'))
    output_layer.bias = bias
    load_model = Captioner(tokenizer, feature_extractor=mobilenet, output_layer=output_layer,
                           units=256, dropout_rate=0.5, num_layers=2, num_heads=2)
    load_model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
                       loss=masked_loss,
                       metrics=[masked_acc])
    # Run one generation pass on a sample image so all model variables are created
    # before the saved weights are loaded.
    image_url = 'https://tensorflow.org/images/surf.jpg'
    image_path = tf.keras.utils.get_file('surf.jpg', origin=image_url)
    image = load_image(image_path)
    load_model.simple_gen(image)
    path = "model/captioner_weights"
    load_model.load_weights(path)
    return load_model

# loaded_model = build()
print("loaded")
#=========================================================================================================================
### TEST RUN
#=========================================================================================================================
image_url = 'https://tensorflow.org/images/surf.jpg'
image_path = tf.keras.utils.get_file('surf.jpg', origin=image_url)
image = load_image(image_path)
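# Minimal end-to-end check (a sketch; assumes the pickled tokenizer, output-layer bias and
# captioner weights under model/ are present). Uncomment to build the captioner and print a
# greedy (temperature=0) caption for the downloaded surf image.
# loaded_model = build()
# print(loaded_model.simple_gen(image, temperature=0))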