import torch
import torch.nn as nn
import tensorflow as tf
from tensorflow.keras.models import Model  # type: ignore

from utility import load_model, get_layers


class CNN_Encoder(nn.Module):
    def __init__(self, model_path, model_name, pop_conv_layers, encoder_layers,
                 tags_threshold, tags_embeddings=None, finetune_visual_model=False,
                 num_tags=105):
        super(CNN_Encoder, self).__init__()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        if tags_embeddings is not None:
            # Initialize tag embeddings from the provided pretrained values.
            self.tags_embeddings = nn.Parameter(
                torch.tensor(tags_embeddings, dtype=torch.float32).to(self.device),
                requires_grad=True)
        else:
            # No pretrained embeddings given: fall back to all-ones vectors,
            # one 400-dimensional embedding per tag.
            self.tags_embeddings = nn.Parameter(
                torch.ones((num_tags, 400), dtype=torch.float32).to(self.device),
                requires_grad=True)

        self.tags_threshold = tags_threshold

        # Load the pretrained visual backbone (e.g. CheXNet) as a Keras model and
        # wrap it so it exposes two outputs: the final tag predictions and the
        # feature map taken `pop_conv_layers` layers before the end.
        visual_model = load_model(model_path, model_name)
        self.visual_model = Model(
            inputs=visual_model.input,
            outputs=[visual_model.output,
                     visual_model.layers[-pop_conv_layers - 1].output],
            trainable=finetune_visual_model)

        self.encoder_layers = get_layers(encoder_layers, 'relu')

    def get_visual_features(self, images):
        # PyTorch uses NCHW layout; TensorFlow expects NHWC, so transpose before
        # feeding the Keras backbone.
        images_np = images.cpu().numpy()
        images_tf = tf.convert_to_tensor(images_np)
        images_tf = tf.transpose(images_tf, perm=[0, 2, 3, 1])

        predictions, visual_features = self.visual_model(images_tf)

        # Bridge back to PyTorch via NumPy. Note that this detaches the backbone
        # from the autograd graph: `requires_grad=True` makes these tensors new
        # leaves, but gradients cannot flow back into the TensorFlow model.
        predictions = torch.tensor(predictions.numpy(), device=self.device,
                                   requires_grad=True)
        visual_features = torch.tensor(visual_features.numpy(), device=self.device,
                                       requires_grad=True)

        # (batch, num_tags) -> (batch, num_tags, 1) so the predictions broadcast
        # against the (num_tags, embed_dim) tag embeddings.
        predictions = predictions.view(predictions.size(0), predictions.size(-1), -1)
        # Flatten the spatial grid: (batch, H, W, C) -> (batch, H*W, C).
        visual_features = visual_features.view(visual_features.size(0), -1,
                                               visual_features.size(-1))

        # A non-negative threshold turns soft tag scores into hard 0/1 labels.
        if self.tags_threshold >= 0:
            predictions = (predictions >= self.tags_threshold).float()
        return predictions, visual_features

    def forward(self, images):
        images = images.to(self.device)
        tags_predictions, visual_features = self.get_visual_features(images)

        if tags_predictions is not None:
            # Weight each tag's embedding by its (possibly binarized) prediction.
            # The embeddings parameter already lives on self.device; reassigning
            # a registered nn.Parameter with a plain tensor would raise an error.
            tags_predictions = tags_predictions.to(self.device)
            tags_embed = tags_predictions * self.tags_embeddings
        else:
            tags_embed = None

        # Project both streams through the shared encoder layers.
        for layer in self.encoder_layers:
            visual_features = layer(visual_features)
            if tags_embed is not None:
                tags_embed = layer(tags_embed)
        return visual_features, tags_embed
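

# --- Usage sketch ---
# A minimal smoke test, not part of the original module. All values below are
# illustrative assumptions: the checkpoint path, backbone name, and the format
# of `encoder_layers` depend on the `utility.load_model` / `utility.get_layers`
# helpers, whose exact contracts are not shown here.
if __name__ == "__main__":
    encoder = CNN_Encoder(
        model_path="weights/",    # hypothetical checkpoint location
        model_name="chexnet",     # hypothetical backbone name
        pop_conv_layers=1,        # take features one layer before the output
        encoder_layers=[512],     # hypothetical encoder layer sizes
        tags_threshold=0.5,       # binarize tag scores at 0.5
    )
    # Two RGB images in PyTorch NCHW layout; get_visual_features transposes
    # them to NHWC for the TensorFlow backbone.
    images = torch.rand(2, 3, 224, 224)
    visual_features, tags_embed = encoder(images)
    # visual_features: (batch, H*W, C); tags_embed: (batch, num_tags, 400)
    print(visual_features.shape, tags_embed.shape)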