"""
This script defines the MIPHEI-ViT architecture for image-to-image translation.
Some modules in this file are adapted from: https://github.com/hustvl/ViTMatte/
"""
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
from timm.models import VisionTransformer, SwinTransformer
from timm.models import load_state_dict_from_hf
class Basic_Conv3x3(nn.Module):
    """
    Basic convolution block made of Conv3x3, BatchNorm2d and ReLU layers.
    https://github.com/hustvl/ViTMatte/blob/main/modeling/decoder/detail_capture.py#L5
    """
    def __init__(
        self,
        in_chans,
        out_chans,
        stride=2,
        padding=1,
    ):
        super().__init__()
        self.conv = nn.Conv2d(in_chans, out_chans, 3, stride, padding, bias=False)
        self.bn = nn.BatchNorm2d(out_chans)
        self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x
class ConvStream(nn.Module):
    """
    Simple ConvStream containing a series of basic conv3x3 layers to extract detail features.
    """
    def __init__(
        self,
        in_chans=4,
        out_chans=[48, 96, 192],
    ):
        super().__init__()
        self.convs = nn.ModuleList()

        self.conv_chans = out_chans.copy()
        self.conv_chans.insert(0, in_chans)

        for i in range(len(self.conv_chans) - 1):
            in_chan_ = self.conv_chans[i]
            out_chan_ = self.conv_chans[i + 1]
            self.convs.append(
                Basic_Conv3x3(in_chan_, out_chan_)
            )

    def forward(self, x):
        out_dict = {'D0': x}
        for i in range(len(self.convs)):
            x = self.convs[i](x)
            name_ = 'D' + str(i + 1)
            out_dict[name_] = x
        return out_dict
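
# Sketch (not part of the original module): a quick shape check for ConvStream. With the
# default stride-2 Basic_Conv3x3 blocks, each stage halves the spatial resolution, so an
# RGB input yields detail maps from D0 (full resolution) down to D3 (1/8 resolution).
# The helper name and the 256x256 input size are illustrative assumptions.
def _demo_convstream_shapes():
    x = torch.randn(1, 3, 256, 256)
    stream = ConvStream(in_chans=3)
    feats = stream(x)
    for name, feat in feats.items():
        print(name, tuple(feat.shape))  # D0: 256, D1: 128, D2: 64, D3: 32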
class SegmentationHead(nn.Sequential):
    # https://github.com/qubvel-org/segmentation_models.pytorch/blob/main/segmentation_models_pytorch/base/heads.py#L5
    def __init__(
        self, in_channels, out_channels, kernel_size=3, activation=None, use_attention=False,
    ):
        if use_attention:
            attention = AttentionBlock(in_channels)
        else:
            attention = nn.Identity()
        conv2d = nn.Conv2d(
            in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2
        )
        # Fall back to an identity activation so nn.Sequential never receives None
        activation = activation if activation is not None else nn.Identity()
        super().__init__(attention, conv2d, activation)
class AttentionBlock(nn.Module):
    """
    Attention gate.

    Parameters:
    -----------
    in_chns : int
        Number of input channels.

    Forward Input:
    --------------
    x : torch.Tensor
        Input tensor of shape [B, C, H, W].

    Returns:
    --------
    torch.Tensor
        Reweighted tensor of the same shape as input.
    """
    def __init__(self, in_chns):
        super(AttentionBlock, self).__init__()
        # Attention generation
        self.psi = nn.Sequential(
            nn.Conv2d(in_chns, in_chns // 2, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(in_chns // 2),
            nn.ReLU(),
            nn.Conv2d(in_chns // 2, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.Sigmoid()
        )

    def forward(self, x):
        # Compute a single-channel attention map and use it to reweight the input
        g = self.psi(x)
        return x * g
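
# Sketch (not from the original code): the attention gate preserves the input shape; the
# sigmoid map only rescales activations. Hypothetical demo helper for illustration.
def _demo_attention_block_shapes():
    block = AttentionBlock(in_chns=32)
    x = torch.randn(2, 32, 64, 64)
    out = block(x)
    assert out.shape == x.shape  # reweighting keeps [B, C, H, W]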
class Fusion_Block(nn.Module):
    """
    Simple fusion block to fuse features from ConvStream and the plain Vision Transformer.
    """
    def __init__(
        self,
        in_chans,
        out_chans,
    ):
        super().__init__()
        self.conv = Basic_Conv3x3(in_chans, out_chans, stride=1, padding=1)

    def forward(self, x, D):
        F_up = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)  # nearest could be an alternative
        out = torch.cat([D, F_up], dim=1)
        out = self.conv(out)
        return out
class MIPHEIViT(nn.Module):
    """
    U-Net-style architecture inspired by ViTMatte, using a Vision Transformer (ViT or Swin)
    as encoder and a convolutional decoder. Designed for dense image prediction tasks,
    such as image-to-image translation.

    Parameters:
    -----------
    encoder : nn.Module
        A ViT- or Swin-based encoder that outputs spatial feature maps.
    decoder : nn.Module
        A decoder module that maps encoder features (and optionally the original image)
        to the output prediction.

    Example:
    --------
    model = MIPHEIViT(encoder=Encoder(vit), decoder=UNetDecoder())
    output = model(input_tensor)
    """
    def __init__(self,
                 encoder,
                 decoder,
                 ):
        super(MIPHEIViT, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.initialize()

    def forward(self, x):
        features = self.encoder(x)
        outputs = self.decoder(features, x)
        return outputs

    def initialize(self):
        pass

    @classmethod
    def from_pretrained_hf(cls, repo_path=None, repo_id=None):
        from safetensors.torch import load_file
        import json

        if repo_path:
            weights_path = os.path.join(repo_path, "model.safetensors")
            config_path = os.path.join(repo_path, "config_hf.json")
        else:
            from huggingface_hub import hf_hub_download
            weights_path = hf_hub_download(repo_id=repo_id, filename="model.safetensors")
            config_path = hf_hub_download(repo_id=repo_id, filename="config_hf.json")

        # Load config values
        with open(config_path, "r") as f:
            config = json.load(f)

        img_size = config["img_size"]
        nc_out = len(config["targ_channel_names"])
        use_attention = config["use_attention"]
        hoptimus_hf_id = config["hoptimus_hf_id"]

        vit = get_hoptimus0_hf(hoptimus_hf_id)
        vit.set_input_size(img_size=(img_size, img_size))
        encoder = Encoder(vit)
        decoder = Detail_Capture(emb_chans=encoder.embed_dim, out_chans=nc_out,
                                 use_attention=use_attention, activation=nn.Tanh())
        model = cls(encoder=encoder, decoder=decoder)

        state_dict = load_file(weights_path)
        state_dict = merge_lora_weights(model, state_dict)
        load_info = model.load_state_dict(state_dict, strict=False)
        validate_load_info(load_info)
        model.eval()
        return model

    def set_input_size(self, img_size):
        if any((s & (s - 1)) != 0 or s == 0 for s in img_size):
            raise ValueError("Both height and width in img_size must be powers of 2")
        if any(s < 128 for s in img_size):
            raise ValueError("Height and width must be greater than or equal to 128")
        self.encoder.vit.set_input_size(img_size=img_size)
        self.encoder.grid_size = self.encoder.vit.patch_embed.grid_size
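
# Sketch (not part of the original module): how loading a published checkpoint is expected
# to look, based on from_pretrained_hf above. The repo id below is a placeholder, not a real
# repository name; running this also requires huggingface_hub and safetensors.
def _demo_load_pretrained():
    model = MIPHEIViT.from_pretrained_hf(repo_id="<user>/<miphei-vit-repo>")  # hypothetical id
    model.set_input_size((256, 256))  # optional: power-of-2 tile size >= 128, as enforced above
    image = torch.randn(1, 3, 256, 256)  # random RGB tensor, just for a shape check
    with torch.no_grad():
        prediction = model(image)  # [1, n_target_channels, 256, 256], Tanh-bounded outputs
    return prediction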
class Encoder(nn.Module):
    """
    Wraps a Vision Transformer (ViT or Swin) to produce feature maps compatible
    with U-Net-like architectures. It reshapes and resizes transformer outputs
    into spatial feature maps.

    Parameters:
    -----------
    vit : VisionTransformer or SwinTransformer
        A pretrained transformer model from `timm` that outputs patch embeddings.
    """
    def __init__(self, vit):
        super().__init__()
        if not isinstance(vit, (VisionTransformer, SwinTransformer)):
            raise ValueError(f"Expected a VisionTransformer or SwinTransformer, got {type(vit)}")
        self.vit = vit
        self.is_swint = isinstance(vit, SwinTransformer)
        self.grid_size = self.vit.patch_embed.grid_size
        if self.is_swint:
            self.num_prefix_tokens = 0
            self.embed_dim = self.vit.embed_dim * 2 ** (self.vit.num_layers - 1)
        else:
            self.num_prefix_tokens = self.vit.num_prefix_tokens
            self.embed_dim = self.vit.embed_dim

        patch_size = self.vit.patch_embed.patch_size
        img_size = self.vit.patch_embed.img_size
        assert img_size[0] % 16 == 0
        assert img_size[1] % 16 == 0
        if self.is_swint:
            self.scale_factor = (2., 2.)
        else:
            if patch_size != (16, 16):
                # Resize the token grid so the decoder always receives a 1/16-resolution map
                target_grid_size = (img_size[0] / 16, img_size[1] / 16)
                self.scale_factor = (target_grid_size[0] / self.grid_size[0], target_grid_size[1] / self.grid_size[1])
            else:
                self.scale_factor = None

    def forward(self, x):
        features = self.vit(x)
        if self.is_swint:
            features = features.permute(0, 3, 1, 2)
        else:
            # Drop class/register tokens, then reshape tokens into a [B, C, H/patch, W/patch] map
            features = features[:, self.num_prefix_tokens:]
            features = features.permute(0, 2, 1)
            features = features.view((-1, self.embed_dim, *self.grid_size))
        if self.scale_factor is not None:
            features = F.interpolate(features, scale_factor=self.scale_factor, mode="bicubic")
        return features
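
# Sketch (not from the original code): wrapping a small timm ViT as a stand-in backbone to
# check the encoder output shape without downloading anything. The real model uses the
# H-optimus-0 backbone loaded by get_hoptimus0_hf below; vit_small_patch16_224 is only an
# assumption for this demo.
def _demo_encoder_shapes():
    vit = timm.create_model(
        "vit_small_patch16_224", pretrained=False, num_classes=0, global_pool="")
    encoder = Encoder(vit)
    x = torch.randn(1, 3, 224, 224)
    feats = encoder(x)
    # patch16 backbone: tokens reshape to a [B, embed_dim, 14, 14] map (224 / 16 = 14)
    assert feats.shape == (1, encoder.embed_dim, 14, 14)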
class Detail_Capture(nn.Module):
    """
    Simple and Lightweight Detail Capture Module for ViT Matting.
    """
    def __init__(
        self,
        emb_chans,
        in_chans=3,
        out_chans=1,
        convstream_out=[48, 96, 192],
        fusion_out=[256, 128, 64, 32],
        use_attention=True,
        activation=torch.nn.Identity()
    ):
        super().__init__()
        assert len(fusion_out) == len(convstream_out) + 1

        self.convstream = ConvStream(in_chans=in_chans)
        self.conv_chans = self.convstream.conv_chans
        self.num_heads = out_chans

        self.fusion_blks = nn.ModuleList()
        self.fus_channs = fusion_out.copy()
        self.fus_channs.insert(0, emb_chans)
        for i in range(len(self.fus_channs) - 1):
            self.fusion_blks.append(
                Fusion_Block(
                    in_chans=self.fus_channs[i] + self.conv_chans[-(i + 1)],
                    out_chans=self.fus_channs[i + 1],
                )
            )

        for idx in range(self.num_heads):
            setattr(self, f'segmentation_head_{idx}', SegmentationHead(
                in_channels=fusion_out[-1],
                out_channels=1,
                activation=activation,
                kernel_size=3,
                use_attention=use_attention
            ))

    def forward(self, features, images):
        detail_features = self.convstream(images)
        for i in range(len(self.fusion_blks)):
            d_name_ = 'D' + str(len(self.fusion_blks) - i - 1)
            features = self.fusion_blks[i](features, detail_features[d_name_])

        # One single-channel head per output channel, concatenated along the channel dimension
        outputs = []
        for idx_head in range(self.num_heads):
            segmentation_head = getattr(self, f'segmentation_head_{idx_head}')
            output = segmentation_head(features)
            outputs.append(output)
        outputs = torch.cat(outputs, dim=1)
        return outputs
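
# Sketch (not part of the original module): the decoder expects a 1/16-resolution feature
# map plus the original image, and each of the four fusion blocks upsamples by 2x. The
# emb_chans=384, out_chans=5 and 256x256 input below are illustrative assumptions.
def _demo_detail_capture_shapes():
    decoder = Detail_Capture(emb_chans=384, in_chans=3, out_chans=5, activation=nn.Tanh())
    images = torch.randn(1, 3, 256, 256)
    features = torch.randn(1, 384, 16, 16)  # 256 / 16 = 16 tokens per side
    out = decoder(features, images)
    assert out.shape == (1, 5, 256, 256)  # one channel per head, Tanh-bounded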
def merge_lora_weights(model, state_dict, alpha=1.0, block_prefix="encoder.vit.blocks"):
    """
    Merges LoRA weights into the base attention Q and V projection weights for each transformer block.
    We keep LoRA weights in the model.safetensors to avoid having the original foundation model weights in the repo.

    Parameters:
    -----------
    model : torch.nn.Module
        The model containing the transformer blocks to modify (e.g., ViT backbone).
    state_dict : dict
        The state_dict containing LoRA matrices with keys formatted as
        '{block_prefix}.{idx}.attn.qkv.lora_q.A', etc.
        This dict is modified in-place to remove LoRA weights after merging.
    alpha : float, optional
        Scaling factor for the LoRA update. Defaults to 1.0.
    block_prefix : str, optional
        Prefix to locate transformer blocks in the model. Defaults to "encoder.vit.blocks".

    Returns:
    --------
    dict
        The modified state_dict with LoRA weights removed after merging.
    """
    with torch.no_grad():
        for idx in range(len(model.encoder.vit.blocks)):
            prefix = f"{block_prefix}.{idx}.attn.qkv"

            # Extract LoRA matrices
            A_q = state_dict.pop(f"{prefix}.lora_q.A")
            B_q = state_dict.pop(f"{prefix}.lora_q.B")
            A_v = state_dict.pop(f"{prefix}.lora_v.A")
            B_v = state_dict.pop(f"{prefix}.lora_v.B")

            # Compute low-rank updates (transposed to match weight shape)
            delta_q = (alpha * A_q @ B_q).T
            delta_v = (alpha * A_v @ B_v).T

            # Get original QKV weight matrix (shape: [3*dim, dim])
            W = model.get_parameter(f"{prefix}.weight")
            dim = delta_q.shape[0]
            assert W.shape[0] == 3 * dim, f"Unexpected QKV shape: {W.shape}"

            # Apply LoRA deltas to Q and V projections
            W[:dim, :] += delta_q  # Q projection
            W[2 * dim:, :] += delta_v  # V projection
    return state_dict
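
# Sketch (not from the original code): a minimal numerical check of the merge rule used
# above, for a single projection. Assuming A has shape (dim, rank) and B has shape
# (rank, dim) as the checkpoint keys imply, adding (alpha * A @ B).T to the (dim, dim)
# weight is equivalent to applying the LoRA path alpha * x @ A @ B at inference time.
def _demo_lora_merge_equivalence():
    torch.manual_seed(0)
    dim, rank, alpha = 8, 2, 1.0
    W = torch.randn(dim, dim)   # base projection weight, nn.Linear convention (out, in)
    A = torch.randn(dim, rank)
    B = torch.randn(rank, dim)
    x = torch.randn(4, dim)
    merged = W + (alpha * A @ B).T
    y_merged = x @ merged.T
    y_lora = x @ W.T + alpha * (x @ A @ B)
    assert torch.allclose(y_merged, y_lora, atol=1e-5)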
def get_hoptimus0_hf(repo_id):
    """Load the H-optimus-0 foundation model backbone from a Hugging Face repo id.
    """
    model = timm.create_model(
        "vit_giant_patch14_reg4_dinov2", img_size=224,
        drop_path_rate=0., num_classes=0,
        global_pool="", pretrained=False, init_values=1e-5,
        dynamic_img_size=False)
    state_dict = load_state_dict_from_hf(repo_id, weights_only=True)
    model.load_state_dict(state_dict)
    return model
def validate_load_info(load_info):
    """
    Validates the result of model.load_state_dict(..., strict=False).

    Raises:
        ValueError if unexpected keys are found,
        or if missing keys are not related to the allowed encoder modules.
    """
    # 1. Raise if any unexpected keys
    if load_info.unexpected_keys:
        raise ValueError(f"Unexpected keys in state_dict: {load_info.unexpected_keys}")

    # 2. Raise if any missing keys are not part of allowed encoder modules
    for key in load_info.missing_keys:
        if ".lora" in key:
            raise ValueError(f"Missing LoRA checkpoint in state_dict: {key}")
        elif not any(part in key for part in ["encoder.vit.", "encoder.model."]):
            raise ValueError(f"Missing key in state_dict: {key}")
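
# Sketch (not part of the original module): end-to-end shape check of the full model with a
# small timm ViT as a stand-in backbone, so nothing is downloaded. Published weights would
# instead be loaded through MIPHEIViT.from_pretrained_hf with the H-optimus-0 backbone.
if __name__ == "__main__":
    vit = timm.create_model(
        "vit_small_patch16_224", pretrained=False, num_classes=0, global_pool="")
    encoder = Encoder(vit)
    decoder = Detail_Capture(emb_chans=encoder.embed_dim, out_chans=3, activation=nn.Tanh())
    model = MIPHEIViT(encoder=encoder, decoder=decoder).eval()
    with torch.no_grad():
        out = model(torch.randn(1, 3, 224, 224))
    print(tuple(out.shape))  # (1, 3, 224, 224)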