import torch
import torch.nn as nn
import numpy as np
from utility import load_model, get_layers
import tensorflow as tf
from tensorflow.keras.models import Model # type: ignore
import sys


class CNN_Encoder(nn.Module):
    def __init__(self, model_path, model_name, pop_conv_layers, encoder_layers, tags_threshold,
                 tags_embeddings=None, finetune_visual_model=False, num_tags=105):
        super(CNN_Encoder, self).__init__()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        if tags_embeddings is not None:
            # Initialize embeddings and move them to the device
            self.tags_embeddings = nn.Parameter(torch.tensor(tags_embeddings, dtype=torch.float32).to(self.device), requires_grad=True)
        else:
            # Initialize embeddings with ones and move them to the device
            self.tags_embeddings = nn.Parameter(torch.ones((num_tags, 400), dtype=torch.float32).to(self.device), requires_grad=True)


        self.tags_threshold = tags_threshold

        # Load the pretrained visual backbone (e.g. CheXNet) and wrap it in a
        # Keras functional Model with two outputs:
        #   1. the original model output (the tag predictions), and
        #   2. the feature map of the layer `pop_conv_layers + 1` from the end,
        #      i.e. the activations that remain after "popping" the last
        #      `pop_conv_layers` layers.
        visual_model = load_model(model_path, model_name)
        self.visual_model = Model(inputs=visual_model.input,
                                  outputs=[visual_model.output,
                                           visual_model.layers[-pop_conv_layers - 1].output],
                                  trainable=finetune_visual_model)

        self.encoder_layers = get_layers(encoder_layers, 'relu')

    def get_visual_features(self, images):
        # Bridge from PyTorch to TensorFlow: move to host memory and
        # convert the layout from NCHW to NHWC, which Keras models expect.
        images_np = images.cpu().numpy()
        images_tf = tf.convert_to_tensor(images_np)
        images_tf = tf.transpose(images_tf, perm=[0, 2, 3, 1])
        predictions, visual_features = self.visual_model(images_tf)
        # Bridge back to PyTorch. Going through .numpy() detaches the graph,
        # so gradients do not propagate into the TF visual model itself.
        predictions = torch.tensor(predictions.numpy(), device=self.device, requires_grad=True)
        visual_features = torch.tensor(visual_features.numpy(), device=self.device, requires_grad=True)

        # (batch, num_tags) -> (batch, num_tags, 1); flatten the spatial
        # dimensions of the feature map to (batch, positions, channels).
        predictions = predictions.view(predictions.size(0), predictions.size(-1), -1)
        visual_features = visual_features.view(visual_features.size(0), -1, visual_features.size(-1))

        # A non-negative threshold binarizes the tag predictions.
        if self.tags_threshold >= 0:
            predictions = (predictions >= self.tags_threshold).float()

        return predictions, visual_features

    def forward(self, images):
        images = images.to(self.device)
        tags_predictions, visual_features = self.get_visual_features(images)
        if tags_predictions is not None:
            tags_predictions = tags_predictions.to(self.device)
            # Gate the learnable tag embeddings with the (possibly binarized)
            # predictions: (batch, num_tags, 1) * (num_tags, 400) broadcasts
            # to (batch, num_tags, 400).
            tags_embed = tags_predictions * self.tags_embeddings
        else:
            tags_embed = None

        # Pass both streams through the shared encoder layers.
        for layer in self.encoder_layers:
            visual_features = layer(visual_features)
            if tags_embed is not None:
                tags_embed = layer(tags_embed)

        return visual_features, tags_embed
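

# --- Minimal usage sketch (illustrative only) --------------------------------
# A rough example of how the encoder might be instantiated and called. The
# model path, model name, layer sizes, and threshold below are assumptions
# for demonstration, not values taken from this repository; it also assumes
# the `load_model` and `get_layers` helpers from `utility` are available.
if __name__ == "__main__":
    encoder = CNN_Encoder(model_path="pretrained/chexnet.h5",  # assumed path
                          model_name="chexnet",                # assumed name
                          pop_conv_layers=1,
                          encoder_layers=[512],                # assumed sizes
                          tags_threshold=0.5)                  # assumed threshold
    dummy_images = torch.randn(2, 3, 224, 224)  # NCHW batch of 2 RGB images
    visual_features, tags_embed = encoder(dummy_images)
    print(visual_features.shape, tags_embed.shape)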