Spaces:
Configuration error
Configuration error
| import torch.nn as nn | |
| from lib.config import cfg | |
| import torch | |
| from lib.networks.renderer import volume_renderer | |
| from lib.train import make_optimizer | |
class NetworkWrapper(nn.Module):
    """Training wrapper: render a batch with ``net`` and compute the image loss.

    ``net`` is wrapped in a ``volume_renderer.Renderer``. :meth:`forward` returns
    the renderer output together with the total loss and logging statistics.
    """

    def __init__(self, net):
        super(NetworkWrapper, self).__init__()
        self.net = net
        self.renderer = volume_renderer.Renderer(self.net)
        # Mean squared error between a prediction and its target.
        self.img2mse = lambda x, y: torch.mean((x - y) ** 2)
        # Not used by forward(); kept as a public attribute in case callers or
        # subclasses rely on it.
        self.acc_crit = torch.nn.functional.smooth_l1_loss

    def forward(self, batch):
        """Render ``batch`` and compute the photometric reconstruction loss.

        Args:
            batch: dict passed whole to ``self.renderer.render``; this method
                itself reads ``'mask_at_box'`` and ``'rgb'``. ``mask_at_box`` is
                used to index both the predictions and ``rgb`` (presumably a
                boolean per-ray mask of rays hitting the bounding box - confirm
                against the dataset).

        Returns:
            Tuple ``(ret, loss, scalar_stats, image_stats)``:
              * ``ret`` - the dict returned by the renderer;
              * ``loss`` - ``img_loss`` plus ``img_loss0`` when ``'rgb0'`` is
                present in ``ret``;
              * ``scalar_stats`` - dict with ``'img_loss'``, optionally
                ``'img_loss0'``, and ``'loss'``;
              * ``image_stats`` - always an empty dict.
        """
        ret = self.renderer.render(batch)
        scalar_stats = {}

        mask = batch['mask_at_box']
        # Both losses are supervised on the same subset of rays.
        target = batch['rgb'][mask]

        img_loss = self.img2mse(ret['rgb_map'][mask], target)
        scalar_stats['img_loss'] = img_loss
        loss = img_loss

        if 'rgb0' in ret:
            # Secondary prediction (presumably a coarse-stage output - confirm).
            # Apply the same mask as the primary loss. Previously this term
            # was computed over every ray, including those the mask excludes.
            img_loss0 = self.img2mse(ret['rgb0'][mask], target)
            scalar_stats['img_loss0'] = img_loss0
            # Out-of-place add so the stored img_loss tensor is never mutated.
            loss = loss + img_loss0

        scalar_stats['loss'] = loss
        image_stats = {}
        return ret, loss, scalar_stats, image_stats