import torch
import torch.nn as nn
import tensorflow as tf
from tensorflow.keras.models import Model  # type: ignore

from utility import load_model, get_layers


class CNN_Encoder(nn.Module):
    def __init__(self, model_path, model_name, pop_conv_layers, encoder_layers,
                 tags_threshold, tags_embeddings=None, finetune_visual_model=False,
                 num_tags=105):
        super(CNN_Encoder, self).__init__()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        if tags_embeddings is not None:
            # Initialize embeddings from the provided values and move them to the device
            self.tags_embeddings = nn.Parameter(
                torch.tensor(tags_embeddings, dtype=torch.float32).to(self.device),
                requires_grad=True)
        else:
            # Initialize embeddings with ones and move them to the device
            self.tags_embeddings = nn.Parameter(
                torch.ones((num_tags, 400), dtype=torch.float32).to(self.device),
                requires_grad=True)
        self.tags_threshold = tags_threshold
        # Load the pretrained visual model (CheXNet) and wrap it in a Keras Model
        # with two outputs: the original classifier output (the tag predictions)
        # and the output of the layer `pop_conv_layers` positions from the end
        # (the convolutional feature map). `finetune_visual_model` controls
        # whether the backbone's weights remain trainable.
        visual_model = load_model(model_path, model_name)
        self.visual_model = Model(inputs=visual_model.input,
                                  outputs=[visual_model.output,
                                           visual_model.layers[-pop_conv_layers - 1].output],
                                  trainable=finetune_visual_model)
        self.encoder_layers = get_layers(encoder_layers, 'relu')
    def get_visual_features(self, images):
        # PyTorch tensors are NCHW; Keras expects NHWC, so move to CPU,
        # convert, and transpose before calling the TF model.
        images_np = images.cpu().numpy()
        images_tf = tf.convert_to_tensor(images_np)
        images_tf = tf.transpose(images_tf, perm=[0, 2, 3, 1])
        predictions, visual_features = self.visual_model(images_tf)
        # Copy the TF outputs into fresh PyTorch tensors. This crosses the
        # framework boundary, so no gradients flow back into the TF model;
        # requires_grad=True only makes the copies leaves of the PyTorch graph.
        predictions = torch.tensor(predictions.numpy(), device=self.device, requires_grad=True)
        visual_features = torch.tensor(visual_features.numpy(), device=self.device, requires_grad=True)
        # Reshape to (batch, num_tags, 1) and (batch, spatial_positions, channels).
        predictions = predictions.view(predictions.size(0), predictions.size(-1), -1)
        visual_features = visual_features.view(visual_features.size(0), -1, visual_features.size(-1))
        if self.tags_threshold >= 0:
            # Binarize the tag scores: 1.0 where a score clears the threshold, else 0.0.
            predictions = (predictions >= self.tags_threshold).float()
        return predictions, visual_features
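
    # A minimal sketch of the thresholding above, assuming the backbone emits
    # per-tag sigmoid scores. With tags_threshold = 0.5 and predictions of
    # shape (batch=1, num_tags=3, 1):
    #   torch.tensor([[[0.7], [0.2], [0.9]]]) >= 0.5  ->  [[[1.], [0.], [1.]]]
    # so in forward() each tag either contributes its full embedding or is
    # zeroed out; a negative tags_threshold keeps the raw scores as soft weights.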
    def forward(self, images):
        images = images.to(self.device)
        tags_predictions, visual_features = self.get_visual_features(images)
        if tags_predictions is not None:
            tags_predictions = tags_predictions.to(self.device)
            # Weight each tag's embedding by its (possibly binarized) prediction;
            # (batch, num_tags, 1) * (num_tags, 400) broadcasts to (batch, num_tags, 400).
            tags_embed = tags_predictions * self.tags_embeddings
        else:
            tags_embed = None
        # Pass both streams through the shared encoder projection layers.
        for layer in self.encoder_layers:
            visual_features = layer(visual_features)
            if tags_embed is not None:
                tags_embed = layer(tags_embed)
        return visual_features, tags_embed
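
# A minimal usage sketch (not part of the original module). The paths, the
# encoder_layers value, and the 224x224 input size are illustrative
# assumptions; the exact formats expected by utility.load_model and
# utility.get_layers depend on that module.
if __name__ == "__main__":
    encoder = CNN_Encoder(
        model_path="pretrained_visual_model",  # hypothetical checkpoint path
        model_name="fine_tuned_chexnet",       # hypothetical model name
        pop_conv_layers=1,                     # take features one layer before the output
        encoder_layers=[],                     # assumed: empty config => no extra projection layers
        tags_threshold=-1,                     # negative => keep raw scores as soft tag weights
        finetune_visual_model=False,
    )
    dummy_images = torch.rand(2, 3, 224, 224)  # NCHW batch, as forward() expects
    visual_features, tags_embed = encoder(dummy_images)
    print(visual_features.shape, tags_embed.shape)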