# CDGPT2-Deployment / CNN_encoder.py
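"""Hybrid TensorFlow/PyTorch visual encoder.

A pretrained Keras visual backbone (CheXNet) produces tag predictions and a
convolutional feature map; both are converted to PyTorch tensors and passed
through trainable encoder layers, alongside learnable tag embeddings.
"""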
import torch
import torch.nn as nn
import tensorflow as tf
from tensorflow.keras.models import Model  # type: ignore

from utility import load_model, get_layers
class CNN_Encoder(nn.Module):
    def __init__(self, model_path, model_name, pop_conv_layers, encoder_layers,
                 tags_threshold, tags_embeddings=None, finetune_visual_model=False,
                 num_tags=105):
        super(CNN_Encoder, self).__init__()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        if tags_embeddings is not None:
            # Initialize the tag embeddings from the given matrix and move them to the device.
            self.tags_embeddings = nn.Parameter(
                torch.tensor(tags_embeddings, dtype=torch.float32).to(self.device),
                requires_grad=True)
        else:
            # No embeddings given: initialize one 400-dim embedding of ones per tag.
            self.tags_embeddings = nn.Parameter(
                torch.ones((num_tags, 400), dtype=torch.float32).to(self.device),
                requires_grad=True)
        self.tags_threshold = tags_threshold
        # Load the pretrained visual model (CheXNet) and wrap it in a Keras Model
        # exposing two outputs: the model's final predictions and the feature map
        # of the layer `pop_conv_layers` positions before the end, i.e. the
        # activations kept after "popping" the last `pop_conv_layers` layers.
        visual_model = load_model(model_path, model_name)
        self.visual_model = Model(
            inputs=visual_model.input,
            outputs=[visual_model.output,
                     visual_model.layers[-pop_conv_layers - 1].output],
            trainable=finetune_visual_model)
        self.encoder_layers = get_layers(encoder_layers, 'relu')
    def get_visual_features(self, images):
        # Hand the PyTorch batch to the TF model: move to CPU, convert to a TF
        # tensor, and reorder from NCHW (PyTorch) to NHWC (TensorFlow).
        images_np = images.cpu().numpy()
        images_tf = tf.convert_to_tensor(images_np)
        images_tf = tf.transpose(images_tf, perm=[0, 2, 3, 1])
        predictions, visual_features = self.visual_model(images_tf)
        # Bring the TF outputs back into PyTorch on the target device. Crossing
        # the TF/PyTorch boundary breaks the autograd graph: requires_grad=True
        # only makes these tensors leaves on the PyTorch side, so gradients do
        # not propagate back into the TF visual model.
        predictions = torch.tensor(predictions.numpy(), device=self.device,
                                   requires_grad=True)
        visual_features = torch.tensor(visual_features.numpy(), device=self.device,
                                       requires_grad=True)
        # Flatten: predictions -> (batch, num_tags, 1); feature map ->
        # (batch, spatial_positions, channels).
        predictions = predictions.view(predictions.size(0), predictions.size(-1), -1)
        visual_features = visual_features.view(visual_features.size(0), -1,
                                               visual_features.size(-1))
        if self.tags_threshold >= 0:
            # Binarize the tag predictions against the threshold.
            predictions = (predictions >= self.tags_threshold).float()
        return predictions, visual_features
    def forward(self, images):
        images = images.to(self.device)
        tags_predictions, visual_features = self.get_visual_features(images)
        if tags_predictions is not None:
            tags_predictions = tags_predictions.to(self.device)
            # tags_embeddings is already a Parameter on self.device (set in
            # __init__); weight each tag embedding by its prediction score.
            tags_embed = tags_predictions * self.tags_embeddings
        else:
            tags_embed = None
        # Project both streams through the encoder layers built by get_layers.
        for layer in self.encoder_layers:
            visual_features = layer(visual_features)
            if tags_embed is not None:
                tags_embed = layer(tags_embed)
        return visual_features, tags_embed
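

# A minimal usage sketch (not part of the original module): it shows how the
# encoder is meant to be driven with an NCHW image batch. The checkpoint path,
# model name, and encoder-layer spec below are placeholders, assumed to match
# whatever `load_model` and `get_layers` in utility.py expect.
if __name__ == "__main__":
    encoder = CNN_Encoder(
        model_path="pretrained_visual_model",  # placeholder checkpoint location
        model_name="fine_tuned_chexnet",       # placeholder model name
        pop_conv_layers=1,                     # take features one layer before the output
        encoder_layers=[0.4],                  # placeholder spec for get_layers
        tags_threshold=-1,                     # negative => keep raw prediction scores
    )
    batch = torch.randn(2, 3, 224, 224)        # dummy NCHW batch of two RGB images
    visual_features, tags_embed = encoder(batch)
    print(visual_features.shape, tags_embed.shape)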