ayushshah commited on
Commit
1dc7b3d
·
1 Parent(s): c7a3f00

Create model.py

Browse files
Files changed (1) hide show
  1. model.py +124 -0
model.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.init as init
4
+ import torchvision.models as models
5
+ from torchvision.models import ResNet34_Weights
6
+
7
+
8
+ class ResNetEncoder(nn.Module):
9
+ def __init__(self, freeze=True):
10
+ super().__init__()
11
+ resnet = models.resnet34(weights=ResNet34_Weights.IMAGENET1K_V1)
12
+
13
+ self.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
14
+ with torch.no_grad():
15
+ self.conv1.weight[:] = resnet.conv1.weight.mean(dim=1, keepdim=True)
16
+
17
+ self.bn1 = resnet.bn1
18
+ self.relu = resnet.relu
19
+ self.maxpool = resnet.maxpool
20
+ self.layer1 = resnet.layer1
21
+ self.layer2 = resnet.layer2
22
+ self.layer3 = resnet.layer3
23
+ self.layer4 = resnet.layer4
24
+
25
+ if freeze:
26
+ for param in self.parameters():
27
+ param.requires_grad = False
28
+
29
+ def forward(self, x):
30
+ # x = (x - 0.449) / 0.226
31
+ x = self.conv1(x)
32
+ x = self.bn1(x)
33
+ x = self.relu(x)
34
+ x1 = self.maxpool(x)
35
+ x2 = self.layer1(x1)
36
+ del x1
37
+
38
+ x3 = self.layer2(x2)
39
+ x4 = self.layer3(x3)
40
+ x5 = self.layer4(x4)
41
+
42
+ return x, x2, x3, x4, x5
43
+
44
+
45
+ def icnr(tensor, scale=2, init_func=init.kaiming_normal_):
46
+ ni, nf, h, w = tensor.shape
47
+ ni2 = int(ni / (scale ** 2))
48
+ k = init_func(torch.zeros([ni2, nf, h, w]))
49
+ k = k.repeat_interleave(scale ** 2, 0)
50
+ with torch.no_grad():
51
+ tensor.copy_(k)
52
+
53
+
54
+ class PixelShuffleICNR(nn.Module):
55
+ def __init__(self, in_channels, out_channels, scale=2):
56
+ super().__init__()
57
+ self.conv = nn.Conv2d(in_channels, out_channels * (scale ** 2), kernel_size=3, padding=1)
58
+ icnr(self.conv.weight, scale=scale)
59
+ self.pixel_shuffle = nn.PixelShuffle(scale)
60
+ self.bn = nn.BatchNorm2d(out_channels)
61
+ self.relu = nn.ReLU(inplace=True)
62
+
63
+ def forward(self, x):
64
+ x = self.conv(x)
65
+ x = self.pixel_shuffle(x)
66
+ x = self.bn(x)
67
+ x = self.relu(x)
68
+ return x
69
+
70
+
71
+ class DecoderBlock(nn.Module):
72
+ def __init__(self, in_channels, skip_channels, out_channels):
73
+ super().__init__()
74
+
75
+ self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
76
+ self.conv = nn.Sequential(
77
+ nn.Conv2d(in_channels + skip_channels, out_channels, kernel_size=3, padding=1),
78
+ nn.BatchNorm2d(out_channels),
79
+ nn.ReLU(inplace=True),
80
+ nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
81
+ nn.BatchNorm2d(out_channels),
82
+ nn.ReLU(inplace=True),
83
+ )
84
+
85
+ def forward(self, x, skip):
86
+ x = self.upsample(x)
87
+ if skip is not None:
88
+ x = torch.cat([x, skip], dim=1)
89
+ return self.conv(x)
90
+
91
+
92
+ class Decoder(nn.Module):
93
+ def __init__(self):
94
+ super().__init__()
95
+ self.dec4 = DecoderBlock(512, 256, 256)
96
+ self.dec3 = DecoderBlock(256, 128, 128)
97
+ self.dec2 = DecoderBlock(128, 64, 64)
98
+ self.dec1 = DecoderBlock(64, 64, 64)
99
+ self.pixel_shuffle = PixelShuffleICNR(64, 16, scale=2)
100
+ self.final = nn.Conv2d(16, 2, kernel_size=3, padding=1)
101
+
102
+ def forward(self, x5, x4, x3, x2, x1):
103
+ d4 = self.dec4(x5, x4)
104
+ d3 = self.dec3(d4, x3)
105
+ del d4, x4, x3
106
+ d2 = self.dec2(d3, x2)
107
+ del d3, x2
108
+ d1 = self.dec1(d2, x1)
109
+ del d2, x1
110
+ out = self.pixel_shuffle(d1)
111
+ del d1
112
+ out = self.final(out)
113
+ return torch.tanh(out)
114
+
115
+
116
+ class UNet(nn.Module):
117
+ def __init__(self):
118
+ super().__init__()
119
+ self.encoder = ResNetEncoder()
120
+ self.decoder = Decoder()
121
+
122
+ def forward(self, x):
123
+ x, x2, x3, x4, x5 = self.encoder(x)
124
+ return self.decoder(x5, x4, x3, x2, x)