import copy
import pdb
import os
import math
from typing import List
import torch
import torch.nn.functional as F
from torch import nn
from torch import Tensor
from util import box_ops
from util.keypoint_ops import keypoint_xyzxyz_to_xyxyzz
from util.misc import (NestedTensor, nested_tensor_from_tensor_list, accuracy,
                       get_world_size, interpolate,
                       is_dist_avail_and_initialized, inverse_sigmoid)
from .backbones import build_backbone
from .matcher import build_matcher
from .transformer import build_transformer
from .utils import PoseProjector, sigmoid_focal_loss, MLP
from .postprocesses import PostProcess_SMPLX, PostProcess_aios
from .postprocesses import PostProcess_SMPLX_Multi as PostProcess_SMPLX
from .postprocesses import PostProcess_SMPLX_Multi_Box
from .postprocesses import PostProcess_SMPLX_Multi_Infer, PostProcess_SMPLX_Multi_Infer_Box
from .criterion_smplx import SetCriterion, SetCriterion_Box
from ..registry import MODULE_BUILD_FUNCS
from detrsmpl.core.conventions.keypoints_mapping import convert_kps
from detrsmpl.models.body_models.builder import build_body_model
from util.human_models import smpl_x
from detrsmpl.core.conventions.keypoints_mapping import get_keypoint_idxs_by_part
import numpy as np
import random
from detrsmpl.utils.geometry import rot6d_to_rotmat
from detrsmpl.utils.transforms import rotmat_to_aa
import cv2
from config.config import cfg
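# NOTE (summary inferred from the code below, not from upstream docs):
# AiOSSMPLX is a DETR-style one-stage model for multi-person expressive human
# pose and shape estimation. Decoding is staged: the first
# num_box_decoder_layers decoder layers localize whole-person boxes, the next
# two layers add body keypoints plus left-hand / right-hand / face boxes, and
# the remaining layers refine hand / face keypoints and regress SMPL-X
# parameters (body/hand/jaw pose, betas, expression, camera). Training uses
# DINO-style denoising (DN) queries with 'group2group' attention masks.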
class AiOSSMPLX(nn.Module):
    def __init__(
            self,
            backbone,
            transformer,
            num_classes,
            num_queries,
            aux_loss=False,
            iter_update=True,
            query_dim=4,
            random_refpoints_xy=False,
            fix_refpoints_hw=-1,
            num_feature_levels=1,
            nheads=8,
            two_stage_type='no',
            dec_pred_class_embed_share=False,
            dec_pred_bbox_embed_share=False,
            dec_pred_pose_embed_share=False,
            two_stage_class_embed_share=True,
            two_stage_bbox_embed_share=True,
            dn_number=100,
            dn_box_noise_scale=0.4,
            dn_label_noise_ratio=0.5,
            dn_batch_gt_fuse=False,
            dn_labelbook_size=100,
            dn_attn_mask_type_list=['group2group'],
            cls_no_bias=False,
            num_group=100,
            num_body_points=17,
            num_hand_points=10,
            num_face_points=10,
            num_box_decoder_layers=2,
            num_hand_face_decoder_layers=4,
            body_model=dict(
                type='smplx',
                keypoint_src='smplx',
                num_expression_coeffs=10,
                keypoint_dst='smplx_137',
                model_path='data/body_models/smplx',
                use_pca=False,
                use_face_contour=True),
            train=True,
            inference=False,
            focal_length=[5000., 5000.],
            camera_3d_size=2.5,
    ):
        super().__init__()
        self.num_queries = num_queries
        self.transformer = transformer
        self.num_classes = num_classes
        self.hidden_dim = hidden_dim = transformer.d_model
        self.num_feature_levels = num_feature_levels
        self.nheads = nheads
        self.label_enc = nn.Embedding(dn_labelbook_size + 1, hidden_dim)
        self.num_body_points = num_body_points
        self.num_hand_points = num_hand_points
        self.num_face_points = num_face_points
        self.num_whole_body_points = num_body_points + 2 * num_hand_points + num_face_points
        self.num_box_decoder_layers = num_box_decoder_layers
        self.num_hand_face_decoder_layers = num_hand_face_decoder_layers
        self.focal_length = focal_length
        self.camera_3d_size = camera_3d_size
        self.inference = inference
        if train:
            self.smpl_convention = 'smplx'
        else:
            self.smpl_convention = 'h36m'
        # setting query dim
        self.query_dim = query_dim
        assert query_dim == 4
        self.random_refpoints_xy = random_refpoints_xy  # False
        self.fix_refpoints_hw = fix_refpoints_hw  # -1
        # for dn training
        self.dn_number = dn_number
        self.dn_box_noise_scale = dn_box_noise_scale
        self.dn_label_noise_ratio = dn_label_noise_ratio
        self.dn_batch_gt_fuse = dn_batch_gt_fuse
        self.dn_labelbook_size = dn_labelbook_size
        self.dn_attn_mask_type_list = dn_attn_mask_type_list
        assert all([
            i in ['match2dn', 'dn2dn', 'group2group']
            for i in dn_attn_mask_type_list
        ])
        assert not dn_batch_gt_fuse
        # build human body
        # if train:
        #     self.body_model = build_body_model(body_model)
        if inference:
            body_model = dict(
                type='smplx',
                keypoint_src='smplx',
                num_expression_coeffs=10,
                num_betas=10,
                keypoint_dst='smplx',
                model_path='data/body_models/smplx',
                use_pca=False,
                use_face_contour=True)
        self.body_model = build_body_model(body_model)
        for param in self.body_model.parameters():
            param.requires_grad = False
        # prepare input projection layers
        if num_feature_levels > 1:
            num_backbone_outs = len(backbone.num_channels)  # 3
            input_proj_list = []
            for _ in range(num_backbone_outs):
                in_channels = backbone.num_channels[_]
                input_proj_list.append(
                    nn.Sequential(
                        nn.Conv2d(in_channels, hidden_dim, kernel_size=1),
                        nn.GroupNorm(32, hidden_dim),
                    ))
            for _ in range(num_feature_levels - num_backbone_outs):
                input_proj_list.append(
                    nn.Sequential(
                        nn.Conv2d(in_channels,
                                  hidden_dim,
                                  kernel_size=3,
                                  stride=2,
                                  padding=1),
                        nn.GroupNorm(32, hidden_dim),
                    ))
                in_channels = hidden_dim
            self.input_proj = nn.ModuleList(input_proj_list)
        else:
            assert two_stage_type == 'no', 'two_stage_type should be no if num_feature_levels=1 !!!'
            self.input_proj = nn.ModuleList([
                nn.Sequential(
                    nn.Conv2d(backbone.num_channels[-1],
                              hidden_dim,
                              kernel_size=1),
                    nn.GroupNorm(32, hidden_dim),
                )
            ])
        self.backbone = backbone
        self.aux_loss = aux_loss
        self.box_pred_damping = box_pred_damping = None
        self.iter_update = iter_update
        assert iter_update, 'Why not iter_update?'
        # prepare pred layers
        self.dec_pred_class_embed_share = dec_pred_class_embed_share  # False
        self.dec_pred_bbox_embed_share = dec_pred_bbox_embed_share  # False
        # 1.1 prepare class & box embed
        _class_embed = nn.Linear(hidden_dim,
                                 num_classes,
                                 bias=(not cls_no_bias))
        if not cls_no_bias:
            prior_prob = 0.01
            bias_value = -math.log((1 - prior_prob) / prior_prob)
            _class_embed.bias.data = torch.ones(self.num_classes) * bias_value
        # 1.2 box embed layer list
        if dec_pred_class_embed_share:
            class_embed_layerlist = [
                _class_embed for i in range(transformer.num_decoder_layers)
            ]
        else:
            class_embed_layerlist = [
                copy.deepcopy(_class_embed)
                for i in range(transformer.num_decoder_layers)
            ]
        ###########################################################################
        # body bbox + l/r hand box + face box
        ###########################################################################
        # 1.1 body bbox embed
        _bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)
        nn.init.constant_(_bbox_embed.layers[-1].weight.data, 0)
        nn.init.constant_(_bbox_embed.layers[-1].bias.data, 0)
        # 1.2 body bbox embed layer list
        self.num_group = num_group
        if dec_pred_bbox_embed_share:
            box_body_embed_layerlist = [
                _bbox_embed for i in range(transformer.num_decoder_layers)
            ]
        else:
            box_body_embed_layerlist = [
                copy.deepcopy(_bbox_embed)
                for i in range(transformer.num_decoder_layers)
            ]
        # 2.1 lhand bbox embed
        _bbox_hand_embed = MLP(hidden_dim, hidden_dim, 2, 3)  # 2-dim xy offset; h/w comes from the separate hw head below
        nn.init.constant_(_bbox_hand_embed.layers[-1].weight.data, 0)
        nn.init.constant_(_bbox_hand_embed.layers[-1].bias.data, 0)
        _bbox_hand_hw_embed = MLP(hidden_dim, hidden_dim, 2, 3)
        nn.init.constant_(_bbox_hand_hw_embed.layers[-1].weight.data, 0)
        nn.init.constant_(_bbox_hand_hw_embed.layers[-1].bias.data, 0)
        # 2.2 lhand bbox embed layer list
        if dec_pred_pose_embed_share:
            box_hand_embed_layerlist = [
                _bbox_hand_embed
                for i in range(transformer.num_decoder_layers -
                               num_box_decoder_layers + 1)
            ]
        else:
            box_hand_embed_layerlist = [
                copy.deepcopy(_bbox_hand_embed)
                for i in range(transformer.num_decoder_layers -
                               num_box_decoder_layers + 1)
            ]
        if dec_pred_pose_embed_share:
            box_hand_hw_embed_layerlist = [
                _bbox_hand_hw_embed
                for i in range(transformer.num_decoder_layers -
                               num_box_decoder_layers)
            ]
        else:
            box_hand_hw_embed_layerlist = [
                copy.deepcopy(_bbox_hand_hw_embed)
                for i in range(transformer.num_decoder_layers -
                               num_box_decoder_layers)
            ]
        # 4.1 face bbox embed
        _bbox_face_embed = MLP(hidden_dim, hidden_dim, 2, 3)
        nn.init.constant_(_bbox_face_embed.layers[-1].weight.data, 0)
        nn.init.constant_(_bbox_face_embed.layers[-1].bias.data, 0)
        _bbox_face_hw_embed = MLP(hidden_dim, hidden_dim, 2, 3)
        nn.init.constant_(_bbox_face_hw_embed.layers[-1].weight.data, 0)
        nn.init.constant_(_bbox_face_hw_embed.layers[-1].bias.data, 0)
        # 4.2 face bbox embed layer list
        if dec_pred_pose_embed_share:
            box_face_embed_layerlist = [
                _bbox_face_embed
                for i in range(transformer.num_decoder_layers -
                               num_box_decoder_layers + 1)
            ]
        else:
            box_face_embed_layerlist = [
                copy.deepcopy(_bbox_face_embed)
                for i in range(transformer.num_decoder_layers -
                               num_box_decoder_layers + 1)
            ]
        if dec_pred_pose_embed_share:
            box_face_hw_embed_layerlist = [
                _bbox_face_hw_embed
                for i in range(transformer.num_decoder_layers -
                               num_box_decoder_layers)
            ]
        else:
            box_face_hw_embed_layerlist = [
                copy.deepcopy(_bbox_face_hw_embed)
                for i in range(transformer.num_decoder_layers -
                               num_box_decoder_layers)
            ]
        ###########################################################################
        # body kp2d + l/r hand kp2d + face kp2d
        ###########################################################################
        ######## body #######
        # 1.1 body kp2d embed
        _pose_embed = MLP(hidden_dim, hidden_dim, 2, 3)
        nn.init.constant_(_pose_embed.layers[-1].weight.data, 0)
        nn.init.constant_(_pose_embed.layers[-1].bias.data, 0)
        # 1.2 body kp2d embed layer list
        if num_body_points == 17:
            if dec_pred_pose_embed_share:
                pose_embed_layerlist = [
                    _pose_embed
                    for i in range(transformer.num_decoder_layers -
                                   num_box_decoder_layers + 1)
                ]
            else:
                pose_embed_layerlist = [
                    copy.deepcopy(_pose_embed)
                    for i in range(transformer.num_decoder_layers -
                                   num_box_decoder_layers + 1)
                ]
        else:
            if dec_pred_pose_embed_share:
                pose_embed_layerlist = [
                    _pose_embed
                    for i in range(transformer.num_decoder_layers -
                                   num_box_decoder_layers)
                ]
            else:
                pose_embed_layerlist = [
                    copy.deepcopy(_pose_embed)
                    for i in range(transformer.num_decoder_layers -
                                   num_box_decoder_layers)
                ]
        # 1.3 body kp bbox embed
        _pose_hw_embed = MLP(hidden_dim, hidden_dim, 2, 3)
        # 1.4 body kp bbox embed layer list
        pose_hw_embed_layerlist = [
            _pose_hw_embed
            for i in range(transformer.num_decoder_layers -
                           num_box_decoder_layers)
        ]
        ######## lhand #######
        # 2.1 lhand kp2d embed
        _pose_hand_embed = MLP(hidden_dim, hidden_dim, 2, 3)
        nn.init.constant_(_pose_hand_embed.layers[-1].weight.data, 0)
        nn.init.constant_(_pose_hand_embed.layers[-1].bias.data, 0)
        # 2.2 lhand kp2d embed layer list
        if dec_pred_pose_embed_share:
            pose_hand_embed_layerlist = [
                _pose_hand_embed
                for i in range(transformer.num_decoder_layers -
                               num_hand_face_decoder_layers + 1)
            ]
        else:
            pose_hand_embed_layerlist = [
                copy.deepcopy(_pose_hand_embed)
                for i in range(transformer.num_decoder_layers -
                               num_hand_face_decoder_layers + 1)
            ]
        # 2.3 lhand kp bbox embed
        _pose_hand_hw_embed = MLP(hidden_dim, hidden_dim, 2, 3)
        # 2.4 lhand kp bbox embed layer list
        pose_hand_hw_embed_layerlist = [
            _pose_hand_hw_embed
            for i in range(transformer.num_decoder_layers -
                           num_hand_face_decoder_layers)
        ]
        ######## face #######
        # 4.1 face kp2d embed
        _pose_face_embed = MLP(hidden_dim, hidden_dim, 2, 3)
        nn.init.constant_(_pose_face_embed.layers[-1].weight.data, 0)
        nn.init.constant_(_pose_face_embed.layers[-1].bias.data, 0)
        # 4.2 face kp2d embed layer list
        if dec_pred_pose_embed_share:
            pose_face_embed_layerlist = [
                _pose_face_embed
                for i in range(transformer.num_decoder_layers -
                               num_hand_face_decoder_layers + 1)
            ]
        else:
            pose_face_embed_layerlist = [
                copy.deepcopy(_pose_face_embed)
                for i in range(transformer.num_decoder_layers -
                               num_hand_face_decoder_layers + 1)
            ]
        # 4.3 face kp bbox embed
        _pose_face_hw_embed = MLP(hidden_dim, hidden_dim, 2, 3)
        # 4.4 face kp bbox embed layer list
        pose_face_hw_embed_layerlist = [
            _pose_face_hw_embed
            for i in range(transformer.num_decoder_layers -
                           num_hand_face_decoder_layers)
        ]
        ###########################################################################
        # smpl pose + betas + kp2d + kp3d + cam
        ###########################################################################
        # 1. smpl pose embed
        if body_model['type'].upper() == 'SMPL':
            self.body_model_joint_num = 24
        elif body_model['type'].upper() == 'SMPLX':
            self.body_model_joint_num = 22
        else:
            raise ValueError(
                f"Only supports SMPL or SMPLX, but got {body_model['type']}")
        # TODO:
        _smpl_pose_embed = MLP(hidden_dim * (self.num_body_points + 4),
                               hidden_dim, self.body_model_joint_num * 6, 3)
        nn.init.constant_(_smpl_pose_embed.layers[-1].weight.data, 0)
        nn.init.constant_(_smpl_pose_embed.layers[-1].bias.data, 0)
        if dec_pred_bbox_embed_share:
            smpl_pose_embed_layerlist = [
                _smpl_pose_embed
                for i in range(transformer.num_decoder_layers -
                               num_box_decoder_layers)
            ]
        else:
            smpl_pose_embed_layerlist = [
                copy.deepcopy(_smpl_pose_embed)
                for i in range(transformer.num_decoder_layers -
                               num_box_decoder_layers)
            ]
        # 2. smpl betas embed
        _smpl_beta_embed = MLP(hidden_dim * (self.num_body_points + 4),
                               hidden_dim, 10, 3)
        nn.init.constant_(_smpl_beta_embed.layers[-1].weight.data, 0)
        nn.init.constant_(_smpl_beta_embed.layers[-1].bias.data, 0)
        if dec_pred_bbox_embed_share:
            smpl_beta_embed_layerlist = [
                _smpl_beta_embed
                for i in range(transformer.num_decoder_layers -
                               num_box_decoder_layers)
            ]
        else:
            smpl_beta_embed_layerlist = [
                copy.deepcopy(_smpl_beta_embed)
                for i in range(transformer.num_decoder_layers -
                               num_box_decoder_layers)
            ]
        # 3. smpl cam embed
        _cam_embed = MLP(hidden_dim * (self.num_body_points + 4), hidden_dim,
                         3, 3)
        nn.init.constant_(_cam_embed.layers[-1].weight.data, 0)
        nn.init.constant_(_cam_embed.layers[-1].bias.data, 0)
        if dec_pred_bbox_embed_share:
            cam_embed_layerlist = [
                _cam_embed for i in range(transformer.num_decoder_layers -
                                          num_box_decoder_layers)
            ]
        else:
            cam_embed_layerlist = [
                copy.deepcopy(_cam_embed)
                for i in range(transformer.num_decoder_layers -
                               num_box_decoder_layers)
            ]
        ###########################################################################
        # smplx body pose + hand pose + expression + betas + kp2d + kp3d + cam
        ###########################################################################
        # 1. smplx body pose embed
        # _smplx_pose_embed = MLP(hidden_dim * (self.num_body_points + 1),
        #                         hidden_dim, 23 * 6, 3)
        # nn.init.constant_(_smplx_pose_embed.layers[-1].weight.data, 0)
        # nn.init.constant_(_smplx_pose_embed.layers[-1].bias.data, 0)
        # if dec_pred_bbox_embed_share:
        #     smplx_pose_embed_layerlist = [
        #         _smplx_pose_embed
        #         for i in range(transformer.num_decoder_layers -
        #                        num_box_decoder_layers + 1)
        #     ]
        # else:
        #     smplx_pose_embed_layerlist = [
        #         copy.deepcopy(_smplx_pose_embed)
        #         for i in range(transformer.num_decoder_layers -
        #                        num_box_decoder_layers + 1)
        #     ]
        # 2. smplx hand pose embed
        _smplx_hand_pose_embed_layer_2_3 = \
            MLP(hidden_dim, hidden_dim, 15 * 6, 3)
        nn.init.constant_(_smplx_hand_pose_embed_layer_2_3.layers[-1].weight.data, 0)
        nn.init.constant_(_smplx_hand_pose_embed_layer_2_3.layers[-1].bias.data, 0)
        _smplx_hand_pose_embed_layer_4_5 = \
            MLP(hidden_dim * (self.num_hand_points + 3), hidden_dim, 15 * 6, 3)
        nn.init.constant_(_smplx_hand_pose_embed_layer_4_5.layers[-1].weight.data, 0)
        nn.init.constant_(_smplx_hand_pose_embed_layer_4_5.layers[-1].bias.data, 0)
        if dec_pred_bbox_embed_share:
            smplx_hand_pose_embed_layerlist = [
                _smplx_hand_pose_embed_layer_2_3
                if i < 2 else _smplx_hand_pose_embed_layer_4_5
                for i in range(transformer.num_decoder_layers -
                               num_box_decoder_layers)
            ]
        else:
            smplx_hand_pose_embed_layerlist = [
                copy.deepcopy(_smplx_hand_pose_embed_layer_2_3)
                if i < 2 else copy.deepcopy(_smplx_hand_pose_embed_layer_4_5)
                for i in range(transformer.num_decoder_layers -
                               num_box_decoder_layers)
            ]
        # 3. smplx face expression
        _smplx_expression_embed_layer_2_3 = \
            MLP(hidden_dim, hidden_dim, 10, 3)
        nn.init.constant_(_smplx_expression_embed_layer_2_3.layers[-1].weight.data, 0)
        nn.init.constant_(_smplx_expression_embed_layer_2_3.layers[-1].bias.data, 0)
        _smplx_expression_embed_layer_4_5 = \
            MLP(hidden_dim * (self.num_hand_points + 2), hidden_dim, 10, 3)
        nn.init.constant_(_smplx_expression_embed_layer_4_5.layers[-1].weight.data, 0)
        nn.init.constant_(_smplx_expression_embed_layer_4_5.layers[-1].bias.data, 0)
        if dec_pred_bbox_embed_share:
            smplx_expression_embed_layerlist = [
                _smplx_expression_embed_layer_2_3
                if i < 2 else _smplx_expression_embed_layer_4_5
                for i in range(transformer.num_decoder_layers -
                               num_box_decoder_layers)
            ]
        else:
            smplx_expression_embed_layerlist = [
                copy.deepcopy(_smplx_expression_embed_layer_2_3)
                if i < 2 else copy.deepcopy(_smplx_expression_embed_layer_4_5)
                for i in range(transformer.num_decoder_layers -
                               num_box_decoder_layers)
            ]
        # _smplx_expression_embed = MLP(hidden_dim * (self.num_face_points + 2),
        #                               hidden_dim, 10, 3)
        # nn.init.constant_(_smplx_expression_embed.layers[-1].weight.data, 0)
        # nn.init.constant_(_smplx_expression_embed.layers[-1].bias.data, 0)
        # if dec_pred_bbox_embed_share:
        #     smplx_expression_embed_layerlist = [
        #         _smplx_expression_embed
        #         for i in range(transformer.num_decoder_layers -
        #                        num_hand_face_decoder_layers)
        #     ]
        # else:
        #     smplx_expression_embed_layerlist = [
        #         copy.deepcopy(_smplx_expression_embed)
        #         for i in range(transformer.num_decoder_layers -
        #                        num_hand_face_decoder_layers)
        #     ]
        # 4. smplx jaw pose embed
        _smplx_jaw_embed_2_3 = MLP(hidden_dim * 1, hidden_dim, 6, 3)
        nn.init.constant_(_smplx_jaw_embed_2_3.layers[-1].weight.data, 0)
        nn.init.constant_(_smplx_jaw_embed_2_3.layers[-1].bias.data, 0)
        _smplx_jaw_embed_4_5 = MLP(hidden_dim * (self.num_face_points + 2),
                                   hidden_dim, 6, 3)
        nn.init.constant_(_smplx_jaw_embed_4_5.layers[-1].weight.data, 0)
        nn.init.constant_(_smplx_jaw_embed_4_5.layers[-1].bias.data, 0)
        if dec_pred_bbox_embed_share:
            smplx_jaw_embed_layerlist = [
                _smplx_jaw_embed_2_3 if i < 2 else _smplx_jaw_embed_4_5
                for i in range(transformer.num_decoder_layers -
                               num_box_decoder_layers)
            ]
        else:
            smplx_jaw_embed_layerlist = [
                copy.deepcopy(_smplx_jaw_embed_2_3)
                if i < 2 else copy.deepcopy(_smplx_jaw_embed_4_5)
                for i in range(transformer.num_decoder_layers -
                               num_box_decoder_layers)
            ]
        ###############
        self.bbox_embed = nn.ModuleList(box_body_embed_layerlist)
        self.class_embed = nn.ModuleList(class_embed_layerlist)
        self.pose_embed = nn.ModuleList(pose_embed_layerlist)
        self.pose_hw_embed = nn.ModuleList(pose_hw_embed_layerlist)
        self.transformer.decoder.bbox_embed = self.bbox_embed
        self.transformer.decoder.pose_embed = self.pose_embed
        self.transformer.decoder.pose_hw_embed = self.pose_hw_embed
        self.transformer.decoder.class_embed = self.class_embed
        # smpl
        self.smpl_pose_embed = nn.ModuleList(smpl_pose_embed_layerlist)
        self.smpl_beta_embed = nn.ModuleList(smpl_beta_embed_layerlist)
        self.smpl_cam_embed = nn.ModuleList(cam_embed_layerlist)
        # self.smpl_cam_f_embed = nn.ModuleList(f_embed_layerlist)
        # self.transformer.decoder.smpl_pose_embed = self.smpl_pose_embed
        # self.transformer.decoder.smpl_beta_embed = self.smpl_beta_embed
        # self.transformer.decoder.smpl_cam_embed = self.smpl_cam_embed
        # smplx lhand kp
        self.bbox_hand_embed = nn.ModuleList(box_hand_embed_layerlist)
        self.bbox_hand_hw_embed = nn.ModuleList(box_hand_hw_embed_layerlist)
        self.pose_hand_embed = nn.ModuleList(pose_hand_embed_layerlist)
        self.pose_hand_hw_embed = nn.ModuleList(pose_hand_hw_embed_layerlist)
        self.transformer.decoder.bbox_hand_embed = self.bbox_hand_embed
        self.transformer.decoder.bbox_hand_hw_embed = self.bbox_hand_hw_embed
        self.transformer.decoder.pose_hand_embed = self.pose_hand_embed
        self.transformer.decoder.pose_hand_hw_embed = self.pose_hand_hw_embed
        # smplx face kp
        self.bbox_face_embed = nn.ModuleList(box_face_embed_layerlist)
        self.bbox_face_hw_embed = nn.ModuleList(box_face_hw_embed_layerlist)
        self.pose_face_embed = nn.ModuleList(pose_face_embed_layerlist)
        self.pose_face_hw_embed = nn.ModuleList(pose_face_hw_embed_layerlist)
        self.transformer.decoder.bbox_face_embed = self.bbox_face_embed
        self.transformer.decoder.bbox_face_hw_embed = self.bbox_face_hw_embed
        self.transformer.decoder.pose_face_embed = self.pose_face_embed
        self.transformer.decoder.pose_face_hw_embed = self.pose_face_hw_embed
        # smplx
        self.smpl_hand_pose_embed = nn.ModuleList(smplx_hand_pose_embed_layerlist)
        # self.smplx_rhand_pose_embed = nn.ModuleList(smplx_rhand_pose_embed_layerlist)
        self.smpl_expr_embed = nn.ModuleList(smplx_expression_embed_layerlist)
        self.smpl_jaw_embed = nn.ModuleList(smplx_jaw_embed_layerlist)
        # self.transformer.decoder.smplx_hand_pose_embed = self.smplx_hand_pose_embed
        # self.transformer.decoder.smplx_rhand_pose_embed = self.smplx_rhand_pose_embed
        # self.transformer.decoder.smpl_expr_embed = self.smplx_expression_embed
        # self.transformer.decoder.smpl_jaw_embed = self.smplx_jaw_embed
        #########
        self.transformer.decoder.num_hand_face_decoder_layers = num_hand_face_decoder_layers
        self.transformer.decoder.num_box_decoder_layers = num_box_decoder_layers
        self.transformer.decoder.num_body_points = num_body_points
        self.transformer.decoder.num_hand_points = num_hand_points
        self.transformer.decoder.num_face_points = num_face_points
        # two stage
        self.two_stage_type = two_stage_type
        assert two_stage_type in [
            'no', 'standard'
        ], 'unknown param {} of two_stage_type'.format(two_stage_type)
        if two_stage_type != 'no':
            if two_stage_bbox_embed_share:
                assert dec_pred_class_embed_share and dec_pred_bbox_embed_share
                self.transformer.enc_out_bbox_embed = _bbox_embed
            else:
                self.transformer.enc_out_bbox_embed = copy.deepcopy(
                    _bbox_embed)
            if two_stage_class_embed_share:
                assert dec_pred_class_embed_share and dec_pred_bbox_embed_share
                self.transformer.enc_out_class_embed = _class_embed
            else:
                self.transformer.enc_out_class_embed = copy.deepcopy(
                    _class_embed)
        self.refpoint_embed = None
        self._reset_parameters()
    def get_camera_trans(self, cam_param, input_body_shape):
        # camera translation
        t_xy = cam_param[:, :2]
        gamma = torch.sigmoid(cam_param[:, 2])  # apply sigmoid to make it positive
        k_value = torch.FloatTensor([
            math.sqrt(self.focal_length[0] * self.focal_length[1] *
                      self.camera_3d_size * self.camera_3d_size /
                      (input_body_shape[0] * input_body_shape[1]))
        ]).cuda().view(-1)
        t_z = k_value * gamma
        cam_trans = torch.cat((t_xy, t_z[:, None]), 1)
        return cam_trans
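    # Worked example for get_camera_trans (illustrative values, not from the
    # source): with the default focal_length = (5000., 5000.),
    # camera_3d_size = 2.5 and input_body_shape = (256, 192),
    #   k_value = sqrt(5000 * 5000 * 2.5 * 2.5 / (256 * 192)) ~= 56.4,
    # so gamma = sigmoid(cam_param[:, 2]) = 0.5 yields a depth t_z ~= 28.2;
    # larger gamma places the person further from the camera.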
    def _reset_parameters(self):
        # init input_proj
        for proj in self.input_proj:
            nn.init.xavier_uniform_(proj[0].weight, gain=1)
            nn.init.constant_(proj[0].bias, 0)
    def prepare_for_dn2(self, targets):
        if not self.training:
            device = targets[0]['boxes'].device
            bs = len(targets)
            num_points = self.num_body_points + 4
            attn_mask2 = torch.zeros(bs,
                                     self.nheads,
                                     self.num_group * num_points,
                                     self.num_group * num_points,
                                     device=device,
                                     dtype=torch.bool)
            group_bbox_kpt = num_points
            group_nobbox_kpt = self.num_body_points
            kpt_index = [
                x for x in range(self.num_group * num_points)
                if x % num_points in [
                    0,
                    self.num_body_points + 1,
                    self.num_body_points + 2,
                    self.num_body_points + 3,
                ]
            ]
            for matchj in range(self.num_group * num_points):
                sj = (matchj // group_bbox_kpt) * group_bbox_kpt
                ej = (matchj // group_bbox_kpt + 1) * group_bbox_kpt
                if sj > 0:
                    attn_mask2[:, :, matchj, :sj] = True
                if ej < self.num_group * num_points:
                    attn_mask2[:, :, matchj, ej:] = True
            for match_x in range(self.num_group * num_points):
                if match_x % group_bbox_kpt in [
                        0,
                        self.num_body_points + 1,
                        self.num_body_points + 2,
                        self.num_body_points + 3,
                ]:
                    attn_mask2[:, :, match_x, kpt_index] = False
            num_points = self.num_whole_body_points + 4
            attn_mask3 = torch.zeros(bs,
                                     self.nheads,
                                     self.num_group * num_points,
                                     self.num_group * num_points,
                                     device=device,
                                     dtype=torch.bool)
            group_bbox_kpt = num_points
            # group_nobbox_kpt = self.num_body_points
            kpt_index = [
                x for x in range(self.num_group * num_points)
                if x % num_points in [
                    0,
                    1 + self.num_body_points,
                    2 + self.num_body_points + self.num_hand_points,
                    3 + self.num_body_points + self.num_hand_points * 2,
                ]
            ]
            for matchj in range(self.num_group * num_points):
                sj = (matchj // group_bbox_kpt) * group_bbox_kpt
                ej = (matchj // group_bbox_kpt + 1) * group_bbox_kpt
                if sj > 0:
                    attn_mask3[:, :, matchj, :sj] = True
                if ej < self.num_group * num_points:
                    attn_mask3[:, :, matchj, ej:] = True
            for match_x in range(self.num_group * num_points):
                if match_x % group_bbox_kpt in [
                        0,
                        1 + self.num_body_points,
                        2 + self.num_body_points + self.num_hand_points,
                        3 + self.num_body_points + self.num_hand_points * 2,
                ]:
                    attn_mask3[:, :, match_x, kpt_index] = False
            # num_points = self.num_whole_body_points + 4
            # device = targets[0]['boxes'].device
            # bs = len(targets)
            # attn_mask_infere = torch.zeros(
            #     bs,
            #     self.nheads,
            #     self.num_group * num_points,
            #     self.num_group * num_points,
            #     device=device,
            #     dtype=torch.bool)
            # group_bbox_kpt = num_points
            # group_nobbox_kpt = self.num_body_points
            # kpt_index = [
            #     x for x in range(self.num_group * num_points)
            #     if x % num_points == 0
            # ]
            # for matchj in range(self.num_group * num_points):
            #     sj = (matchj // group_bbox_kpt) * group_bbox_kpt
            #     ej = (matchj // group_bbox_kpt + 1) * group_bbox_kpt
            #     if sj > 0:
            #         attn_mask_infere[:, :, matchj, :sj] = True
            #     if ej < self.num_group * num_points:
            #         attn_mask_infere[:, :, matchj, ej:] = True
            # for match_x in range(self.num_group * num_points):
            #     if match_x % group_bbox_kpt == 0:
            #         attn_mask_infere[:, :, match_x, kpt_index] = False
            # attn_mask_infere = attn_mask_infere.flatten(0, 1)
            attn_mask2 = attn_mask2.flatten(0, 1)
            attn_mask3 = attn_mask3.flatten(0, 1)
            return None, None, None, attn_mask2, attn_mask3, None
        # targets, dn_scalar, noise_scale = dn_args
        device = targets[0]['boxes'].device
        bs = len(targets)
        dn_number = self.dn_number  # 100
        dn_box_noise_scale = self.dn_box_noise_scale  # 0.4
        dn_label_noise_ratio = self.dn_label_noise_ratio  # 0.5
        # gather gt boxes and labels
        gt_boxes = [t['boxes'] for t in targets]
        gt_labels = [t['labels'] for t in targets]
        gt_keypoints = [t['keypoints'] for t in targets]

        # repeat them
        def get_indices_for_repeat(now_num, target_num, device='cuda'):
            """
            Input:
                - now_num: int
                - target_num: int
            Output:
                - indices: tensor[target_num]
            """
            out_indice = []
            base_indice = torch.arange(now_num).to(device)
            multiplier = target_num // now_num
            out_indice.append(base_indice.repeat(multiplier))
            residue = target_num % now_num
            out_indice.append(base_indice[torch.randint(0,
                                                        now_num, (residue, ),
                                                        device=device)])
            return torch.cat(out_indice)
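        # Example (illustrative): get_indices_for_repeat(3, 8) tiles
        # [0, 1, 2] twice (8 // 3 = 2, giving 6 indices) and appends
        # 8 % 3 = 2 indices drawn uniformly from [0, 3), so every GT is
        # repeated at least floor(8 / 3) times before the random remainder.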
        if self.dn_batch_gt_fuse:
            raise NotImplementedError
            gt_boxes_bsall = torch.cat(gt_boxes)  # num_boxes, 4
            gt_labels_bsall = torch.cat(gt_labels)
            num_gt_bsall = gt_boxes_bsall.shape[0]
            if num_gt_bsall > 0:
                indices = get_indices_for_repeat(num_gt_bsall, dn_number,
                                                 device)
                gt_boxes_expand = gt_boxes_bsall[indices][None].repeat(
                    bs, 1, 1)  # bs, num_dn, 4
                gt_labels_expand = gt_labels_bsall[indices][None].repeat(
                    bs, 1)  # bs, num_dn
            else:
                # all negative samples when no gt boxes
                gt_boxes_expand = torch.rand(bs, dn_number, 4, device=device)
                gt_labels_expand = torch.ones(
                    bs, dn_number, dtype=torch.int64, device=device) * int(
                        self.num_classes)
        else:
            gt_boxes_expand = []
            gt_labels_expand = []
            gt_keypoints_expand = []  # here
            for idx, (gt_boxes_i, gt_labels_i, gt_keypoint_i) in enumerate(
                    zip(gt_boxes, gt_labels, gt_keypoints)):  # idx -> batch id
                num_gt_i = gt_boxes_i.shape[0]  # instance num
                if num_gt_i > 0:
                    indices = get_indices_for_repeat(num_gt_i, dn_number,
                                                     device)
                    gt_boxes_expand_i = gt_boxes_i[indices]  # num_dn, 4
                    gt_labels_expand_i = gt_labels_i[indices]  # add smpl
                    gt_keypoints_expand_i = gt_keypoint_i[indices]
                else:
                    # all negative samples when no gt boxes
                    gt_boxes_expand_i = torch.rand(dn_number, 4, device=device)
                    gt_labels_expand_i = torch.ones(
                        dn_number, dtype=torch.int64, device=device) * int(
                            self.num_classes)
                    gt_keypoints_expand_i = torch.rand(dn_number,
                                                       self.num_body_points *
                                                       3,
                                                       device=device)
                gt_boxes_expand.append(gt_boxes_expand_i)  # add smpl
                gt_labels_expand.append(gt_labels_expand_i)
                gt_keypoints_expand.append(gt_keypoints_expand_i)
            gt_boxes_expand = torch.stack(gt_boxes_expand)
            gt_labels_expand = torch.stack(gt_labels_expand)
            gt_keypoints_expand = torch.stack(gt_keypoints_expand)
        known_boxes_expand = gt_boxes_expand.clone()
        known_labels_expand = gt_labels_expand.clone()
        # add noise
        if dn_label_noise_ratio > 0:
            prob = torch.rand_like(known_labels_expand.float())
            chosen_indice = prob < dn_label_noise_ratio
            new_label = torch.randint_like(
                known_labels_expand[chosen_indice], 0,
                self.dn_labelbook_size)  # randomly put a new one here
            known_labels_expand[chosen_indice] = new_label
        if dn_box_noise_scale > 0:
            diff = torch.zeros_like(known_boxes_expand)
            diff[..., :2] = known_boxes_expand[..., 2:] / 2
            diff[..., 2:] = known_boxes_expand[..., 2:]
            known_boxes_expand += torch.mul(
                (torch.rand_like(known_boxes_expand) * 2 - 1.0),
                diff) * dn_box_noise_scale
            known_boxes_expand = known_boxes_expand.clamp(min=0.0, max=1.0)
        input_query_label = self.label_enc(known_labels_expand)
        input_query_bbox = inverse_sigmoid(known_boxes_expand)
        # prepare mask
        body_mask, body_kps_mask, lhand_mask, lhand_kps_mask, rhand_mask, \
            rhand_kps_mask, face_mask, face_kps_mask = \
            False, False, False, False, False, False, False, False
        if random.random() < 0.2:
            body_mask = True
        if random.random() < 0.5:
            body_kps_mask = True
        if random.random() < 0.2:
            lhand_mask = True
        if random.random() < 0.5:
            lhand_kps_mask = True
        if random.random() < 0.2:
            rhand_mask = True
        if random.random() < 0.5:
            rhand_kps_mask = True
        if random.random() < 0.2:
            face_mask = True
        if random.random() < 0.5:
            face_kps_mask = True
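        # Random part dropout for DN training (flags above): each part-level
        # box query (body / lhand / rhand / face) is masked with probability
        # 0.2 and, independently, a random subset of each part's keypoint
        # queries with probability 0.5; the 'group2group' attention masks
        # below consume these flags.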
        if 'group2group' in self.dn_attn_mask_type_list:
            attn_mask = torch.zeros(bs,
                                    self.nheads,
                                    dn_number + self.num_queries,
                                    dn_number + self.num_queries,
                                    device=device,
                                    dtype=torch.bool)
            attn_mask[:, :, dn_number:, :dn_number] = True
            for idx, (gt_boxes_i,
                      gt_labels_i) in enumerate(zip(gt_boxes,
                                                    gt_labels)):  # for batch
                num_gt_i = gt_boxes_i.shape[0]
                if num_gt_i == 0:
                    continue
                for matchi in range(dn_number):
                    si = (matchi // num_gt_i) * num_gt_i
                    ei = (matchi // num_gt_i + 1) * num_gt_i
                    if si > 0:
                        attn_mask[idx, :, matchi, :si] = True
                    if ei < dn_number:
                        attn_mask[idx, :, matchi, ei:dn_number] = True
            attn_mask = attn_mask.flatten(0, 1)
        if 'group2group' in self.dn_attn_mask_type_list:
            # self.num_body_points = self.num_body_points + 3
            inter_body_mask = []
            if body_mask:
                inter_body_mask.append(0)
            if body_kps_mask:
                indices = sorted(
                    random.sample(range(1, self.num_body_points + 1), k=6))
                inter_body_mask.extend(indices)
            if lhand_mask:
                inter_body_mask.append(self.num_body_points + 1)
            if rhand_mask:
                inter_body_mask.append(self.num_body_points + 2)
            if face_mask:
                inter_body_mask.append(self.num_body_points + 3)
            num_points = self.num_body_points + 4
            attn_mask2 = torch.zeros(bs,
                                     self.nheads,
                                     dn_number + self.num_group * num_points,
                                     dn_number + self.num_group * num_points,
                                     device=device,
                                     dtype=torch.bool)
            attn_mask2[:, :, dn_number:, :dn_number] = True
            group_bbox_kpt = num_points
            # group_nobbox_kpt = self.num_body_points
            kpt_index = [
                x for x in range(self.num_group * num_points)
                if x % num_points in [
                    0, self.num_body_points + 1, self.num_body_points + 2,
                    self.num_body_points + 3
                ]
            ]
            for matchj in range(self.num_group * num_points):
                sj = (matchj // group_bbox_kpt) * group_bbox_kpt
                ej = (matchj // group_bbox_kpt + 1) * group_bbox_kpt
                if sj > 0:
                    attn_mask2[:, :, dn_number:, dn_number:][:, :, matchj, :sj] = True
                if ej < self.num_group * num_points:
                    attn_mask2[:, :, dn_number:, dn_number:][:, :, matchj, ej:] = True
                if (matchj // group_bbox_kpt) == 0:
                    attn_mask2[:, :, dn_number:, dn_number:][:, :, matchj, sj:ej][..., inter_body_mask] = True
            for match_x in range(self.num_group * num_points):
                if match_x % group_bbox_kpt == 0 and body_mask:
                    attn_mask2[:, :, dn_number:, dn_number:][:, :, match_x, ::num_points] = False
                if match_x % group_bbox_kpt == self.num_body_points + 1 and lhand_mask:
                    attn_mask2[:, :, dn_number:, dn_number:][:, :, match_x, 1::num_points] = False
                if match_x % group_bbox_kpt == self.num_body_points + 2 and rhand_mask:
                    attn_mask2[:, :, dn_number:, dn_number:][:, :, match_x, 2::num_points] = False
                if match_x % group_bbox_kpt == self.num_body_points + 3 and face_mask:
                    attn_mask2[:, :, dn_number:, dn_number:][:, :, match_x, 3::num_points] = False
                # if match_x % group_bbox_kpt in [0,
                #                                 self.num_body_points + 1,
                #                                 self.num_body_points + 2,
                #                                 self.num_body_points + 3]:
                #     attn_mask2[:, :, dn_number:, dn_number:][:, :, match_x, kpt_index] = False
            for idx, (gt_boxes_i, gt_labels_i) in enumerate(zip(gt_boxes, gt_labels)):
                num_gt_i = gt_boxes_i.shape[0]
                if num_gt_i == 0:
                    continue
                for matchi in range(dn_number):
                    si = (matchi // num_gt_i) * num_gt_i
                    ei = (matchi // num_gt_i + 1) * num_gt_i
                    if si > 0:
                        attn_mask2[idx, :, matchi, :si] = True
                    if ei < dn_number:
                        attn_mask2[idx, :, matchi, ei:dn_number] = True
            attn_mask2 = attn_mask2.flatten(0, 1)
        if 'group2group' in self.dn_attn_mask_type_list:
            inter_body_mask = []
            if body_mask:
                inter_body_mask.append(0)
            if body_kps_mask:
                indices = sorted(
                    random.sample(range(1, self.num_body_points + 1), k=6))
                inter_body_mask.extend(indices)
            if lhand_mask:
                inter_body_mask.append(self.num_body_points + 1)
            if lhand_kps_mask:
                indices = sorted(
                    random.sample(
                        range(self.num_body_points + 2,
                              self.num_body_points + 8), k=3))
                inter_body_mask.extend(indices)
            if rhand_mask:
                inter_body_mask.append(self.num_body_points + 8)
            if rhand_kps_mask:
                indices = sorted(
                    random.sample(
                        range(self.num_body_points + 9,
                              self.num_body_points + 15), k=3))
                inter_body_mask.extend(indices)
            if face_mask:
                inter_body_mask.append(self.num_body_points + 15)
            if face_kps_mask:
                indices = sorted(
                    random.sample(
                        range(self.num_body_points + 16,
                              self.num_body_points + 22), k=3))
                inter_body_mask.extend(indices)
            # self.num_body_points = self.num_body_points + 3
            num_points = self.num_whole_body_points + 4
            attn_mask3 = torch.zeros(bs,
                                     self.nheads,
                                     dn_number + self.num_group * num_points,
                                     dn_number + self.num_group * num_points,
                                     device=device,
                                     dtype=torch.bool)
            attn_mask3[:, :, dn_number:, :dn_number] = True
            group_bbox_kpt = num_points
            # group_nobbox_kpt = self.num_body_points
            kpt_index = [
                x for x in range(self.num_group * num_points)
                if x % num_points in [
                    0,
                    1 + self.num_body_points,
                    2 + self.num_body_points + self.num_hand_points,
                    3 + self.num_body_points + self.num_hand_points * 2,
                ]
            ]
            for matchj in range(self.num_group * num_points):
                sj = (matchj // group_bbox_kpt) * group_bbox_kpt
                ej = (matchj // group_bbox_kpt + 1) * group_bbox_kpt
                if sj > 0:
                    attn_mask3[:, :, dn_number:, dn_number:][:, :, matchj, :sj] = True
                if ej < self.num_group * num_points:
                    attn_mask3[:, :, dn_number:, dn_number:][:, :, matchj, ej:] = True
                if (matchj // group_bbox_kpt) == 0:
                    attn_mask3[:, :, dn_number:, dn_number:][:, :, matchj, sj:ej][..., inter_body_mask] = True
            for match_x in range(self.num_group * num_points):
                if match_x % group_bbox_kpt == 0 and body_mask:
                    attn_mask3[:, :, dn_number:, dn_number:][:, :, match_x, ::num_points] = False
                if match_x % group_bbox_kpt == 1 + self.num_body_points and lhand_mask:
                    attn_mask3[:, :, dn_number:, dn_number:][:, :, match_x, 1::num_points] = False
                if match_x % group_bbox_kpt == 2 + self.num_body_points + self.num_hand_points and rhand_mask:
                    attn_mask3[:, :, dn_number:, dn_number:][:, :, match_x, 2::num_points] = False
                if match_x % group_bbox_kpt == 3 + self.num_body_points + self.num_hand_points * 2 and face_mask:
                    attn_mask3[:, :, dn_number:, dn_number:][:, :, match_x, 3::num_points] = False
                # if match_x % group_bbox_kpt in [0,
                #                                 1 + self.num_body_points,
                #                                 2 + self.num_body_points + self.num_hand_points,
                #                                 3 + self.num_body_points + self.num_hand_points * 2]:
                #     attn_mask3[:, :, dn_number:, dn_number:][:, :, match_x, kpt_index] = False
            for idx, (gt_boxes_i, gt_labels_i) in enumerate(zip(gt_boxes, gt_labels)):
                num_gt_i = gt_boxes_i.shape[0]
                if num_gt_i == 0:
                    continue
                for matchi in range(dn_number):
                    si = (matchi // num_gt_i) * num_gt_i
                    ei = (matchi // num_gt_i + 1) * num_gt_i
                    if si > 0:
                        attn_mask3[idx, :, matchi, :si] = True
                    if ei < dn_number:
                        attn_mask3[idx, :, matchi, ei:dn_number] = True
            attn_mask3 = attn_mask3.flatten(0, 1)
        mask_dict = {
            'pad_size': dn_number,
            'known_bboxs': gt_boxes_expand,
            'known_labels': gt_labels_expand,
            'known_keypoints': gt_keypoints_expand
        }
        return input_query_label, input_query_bbox, attn_mask, attn_mask2, attn_mask3, mask_dict
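    # prepare_for_dn2 returns, in order: DN content queries, DN anchor boxes
    # in logit space, and three 'group2group' attention masks: attn_mask for
    # the box-decoding layers (dn_number + num_queries tokens), attn_mask2
    # for the body-keypoint layers (num_body_points + 4 tokens per group),
    # attn_mask3 for the whole-body layers (num_whole_body_points + 4 tokens
    # per group), followed by mask_dict for dn_post_process2. In eval mode
    # only attn_mask2 / attn_mask3 are returned (no DN queries).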
    def dn_post_process2(self, outputs_class, outputs_coord,
                         outputs_body_keypoints_list, mask_dict):
        outputs_keypoint = outputs_body_keypoints_list
        if mask_dict and mask_dict['pad_size'] > 0:
            output_known_class = [
                outputs_class_i[:, :mask_dict['pad_size'], :]
                for outputs_class_i in outputs_class
            ]
            output_known_coord = [
                outputs_coord_i[:, :mask_dict['pad_size'], :]
                for outputs_coord_i in outputs_coord
            ]
            outputs_class = [
                outputs_class_i[:, mask_dict['pad_size']:, :]
                for outputs_class_i in outputs_class
            ]
            outputs_coord = [
                outputs_coord_i[:, mask_dict['pad_size']:, :]
                for outputs_coord_i in outputs_coord
            ]
            mask_dict.update({
                'output_known_coord': output_known_coord,
                'output_known_class': output_known_class
            })
        return outputs_class, outputs_coord, outputs_keypoint
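    # The DN queries occupy the first mask_dict['pad_size'] slots of every
    # per-layer prediction; dn_post_process2 moves them into
    # mask_dict['output_known_class' / 'output_known_coord'] for the
    # denoising loss and passes only the matching queries on to the normal
    # detection losses.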
    def forward(self, data_batch: NestedTensor, targets: List = None):
        """The forward expects a NestedTensor, which consists of:
           - samples.tensor: batched images, of shape [batch_size x 3 x H x W]
           - samples.mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels
        It returns a dict with the following elements:
           - "pred_logits": the classification logits (including no-object) for all queries.
                            Shape = [batch_size x num_queries x num_classes]
           - "pred_boxes": the normalized box coordinates for all queries, represented as
                           (center_x, center_y, width, height). These values are normalized in [0, 1],
                           relative to the size of each individual image (disregarding possible padding).
                           See PostProcess for information on how to retrieve the unnormalized bounding box.
           - "aux_outputs": optional, only returned when auxiliary losses are activated. It is a list of
                            dictionaries containing the two above keys for each decoder layer.
        """
        if isinstance(data_batch, dict):
            samples, targets = self.prepare_targets(data_batch)
            # import pdb; pdb.set_trace()
        elif isinstance(data_batch, (list, torch.Tensor)):
            samples = nested_tensor_from_tensor_list(data_batch)
        else:
            samples = data_batch
        # print(samples.data['img'].shape)
        # exit()
        features, poss = self.backbone(samples)
        srcs = []
        masks = []
        for l, feat in enumerate(features):  # len(features) == 3
            src, mask = feat.decompose()
            srcs.append(self.input_proj[l](src))
            masks.append(mask)
            assert mask is not None
        if self.num_feature_levels > len(srcs):
            _len_srcs = len(srcs)
            for l in range(_len_srcs, self.num_feature_levels):
                if l == _len_srcs:
                    src = self.input_proj[l](features[-1].tensors)
                else:
                    src = self.input_proj[l](srcs[-1])
                m = samples.mask
                mask = F.interpolate(m[None].float(),
                                     size=src.shape[-2:]).to(torch.bool)[0]
                pos_l = self.backbone[1](NestedTensor(src, mask)).to(src.dtype)
                srcs.append(src)
                masks.append(mask)
                poss.append(pos_l)
        if self.dn_number > 0 or targets is not None:
            input_query_label, input_query_bbox, attn_mask, attn_mask2, attn_mask3, mask_dict = \
                self.prepare_for_dn2(targets)
        else:
            assert targets is None
            input_query_bbox = input_query_label = attn_mask = attn_mask2 = attn_mask3 = mask_dict = None
        hs, reference, hs_enc, ref_enc, init_box_proposal = self.transformer(
            srcs, masks, input_query_bbox, poss, input_query_label, attn_mask,
            attn_mask2, attn_mask3)
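        # Staged decoding (see the branches below): layers
        # [0, num_box_decoder_layers) predict human boxes only; the next two
        # layers carry body keypoints and hand/face boxes
        # (num_body_points + 4 queries per group); the remaining layers use
        # the full whole-body layout (num_whole_body_points + 4 queries per
        # group) for hand/face keypoints and SMPL-X regression.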
        # update human boxes
        effective_dn_number = self.dn_number if self.training else 0
        outputs_body_bbox_list = []
        outputs_class = []
        for dec_lid, (layer_ref_sig, layer_body_bbox_embed, layer_cls_embed,
                      layer_hs) in enumerate(
                          zip(reference[:-1], self.bbox_embed,
                              self.class_embed, hs)):
            if dec_lid < self.num_box_decoder_layers:
                # human det
                layer_delta_unsig = layer_body_bbox_embed(layer_hs)
                layer_body_box_outputs_unsig = \
                    layer_delta_unsig + inverse_sigmoid(layer_ref_sig)
                layer_body_box_outputs_unsig = layer_body_box_outputs_unsig.sigmoid()
                layer_cls = layer_cls_embed(layer_hs)
                # import mmcv
                # import cv2
                # img = (data_batch['img'][0] * 255).permute(1, 2, 0).int().detach().cpu().numpy()
                # bbox = (box_ops.box_cxcywh_to_xyxy(layer_body_box_outputs_unsig[0][0]).reshape(2, 2).detach().cpu().numpy() * data_batch['img_shape'].cpu().numpy()[0, ::-1]).reshape(1, 4)
                # img = mmcv.imshow_bboxes(img.copy(), bbox, show=False)
                # cv2.imwrite('test.png', img)
                outputs_body_bbox_list.append(layer_body_box_outputs_unsig)
                outputs_class.append(layer_cls)
            elif dec_lid < self.num_box_decoder_layers + 2:
                bs = layer_ref_sig.shape[0]
                # dn body bbox
                layer_hs_body_bbox_dn = layer_hs[:, :effective_dn_number, :]  # dn content query
                reference_before_sigmoid_body_bbox_dn = layer_ref_sig[:, :effective_dn_number, :]  # dn position query
                layer_body_box_delta_unsig_dn = layer_body_bbox_embed(layer_hs_body_bbox_dn)
                layer_body_box_outputs_unsig_dn = layer_body_box_delta_unsig_dn + inverse_sigmoid(
                    reference_before_sigmoid_body_bbox_dn)
                layer_body_box_outputs_unsig_dn = layer_body_box_outputs_unsig_dn.sigmoid()
                # norm body bbox
                layer_hs_body_bbox_norm = layer_hs[:, effective_dn_number:, :][
                    :, 0::(self.num_body_points + 4), :]  # norm content query
                reference_before_sigmoid_body_bbox_norm = layer_ref_sig[:, effective_dn_number:, :][
                    :, 0::(self.num_body_points + 4), :]  # norm position query
                layer_body_box_delta_unsig_norm = layer_body_bbox_embed(layer_hs_body_bbox_norm)
                layer_body_box_outputs_unsig_norm = layer_body_box_delta_unsig_norm + inverse_sigmoid(
                    reference_before_sigmoid_body_bbox_norm)
                layer_body_box_outputs_unsig_norm = layer_body_box_outputs_unsig_norm.sigmoid()
                layer_body_box_outputs_unsig = torch.cat(
                    (layer_body_box_outputs_unsig_dn, layer_body_box_outputs_unsig_norm), dim=1)
                # classification
                layer_cls_dn = layer_cls_embed(layer_hs_body_bbox_dn)
                layer_cls_norm = layer_cls_embed(layer_hs_body_bbox_norm)
                layer_cls = torch.cat((layer_cls_dn, layer_cls_norm), dim=1)
                outputs_class.append(layer_cls)
                outputs_body_bbox_list.append(layer_body_box_outputs_unsig)
            else:
                bs = layer_ref_sig.shape[0]
                # dn body bbox
                layer_hs_body_bbox_dn = layer_hs[:, :effective_dn_number, :]  # dn content query
                reference_before_sigmoid_body_bbox_dn = layer_ref_sig[:, :effective_dn_number, :]  # dn position query
                layer_body_box_delta_unsig_dn = layer_body_bbox_embed(layer_hs_body_bbox_dn)
                layer_body_box_outputs_unsig_dn = layer_body_box_delta_unsig_dn + inverse_sigmoid(
                    reference_before_sigmoid_body_bbox_dn)
                layer_body_box_outputs_unsig_dn = layer_body_box_outputs_unsig_dn.sigmoid()
                # norm body bbox
                layer_hs_body_bbox_norm = layer_hs[:, effective_dn_number:, :][
                    :, 0::(self.num_whole_body_points + 4), :]  # norm content query
                reference_before_sigmoid_body_bbox_norm = layer_ref_sig[:, effective_dn_number:, :][
                    :, 0::(self.num_whole_body_points + 4), :]  # norm position query
                layer_body_box_delta_unsig_norm = layer_body_bbox_embed(layer_hs_body_bbox_norm)
                layer_body_box_outputs_unsig_norm = layer_body_box_delta_unsig_norm + inverse_sigmoid(
                    reference_before_sigmoid_body_bbox_norm)
                layer_body_box_outputs_unsig_norm = layer_body_box_outputs_unsig_norm.sigmoid()
                layer_body_box_outputs_unsig = torch.cat(
                    (layer_body_box_outputs_unsig_dn, layer_body_box_outputs_unsig_norm), dim=1)
                # classification
                layer_cls_dn = layer_cls_embed(layer_hs_body_bbox_dn)
                layer_cls_norm = layer_cls_embed(layer_hs_body_bbox_norm)
                layer_cls = torch.cat((layer_cls_dn, layer_cls_norm), dim=1)
                outputs_class.append(layer_cls)
                outputs_body_bbox_list.append(layer_body_box_outputs_unsig)
                # select the top-100 queries by classification score
                q_value, q_index = torch.topk(layer_cls_norm.max(-1)[0], 100, dim=1)
        # update hand and face boxes
        outputs_lhand_bbox_list = []
        outputs_rhand_bbox_list = []
        outputs_face_bbox_list = []
        # update keypoints boxes
        outputs_body_keypoints_list = []
        outputs_body_keypoints_hw = []
        outputs_lhand_keypoints_list = []
        outputs_lhand_keypoints_hw = []
        outputs_rhand_keypoints_list = []
        outputs_rhand_keypoints_hw = []
        outputs_face_keypoints_list = []
        outputs_face_keypoints_hw = []
        outputs_smpl_pose_list = []
        outputs_smpl_lhand_pose_list = []
        outputs_smpl_rhand_pose_list = []
        outputs_smpl_expr_list = []
        outputs_smpl_jaw_pose_list = []
        outputs_smpl_beta_list = []
        outputs_smpl_cam_list = []
        # outputs_smpl_cam_f_list = []
        outputs_smpl_kp2d_list = []
        outputs_smpl_kp3d_list = []
        outputs_smpl_verts_list = []
        body_kpt_index = [
            x for x in range(self.num_group * (self.num_body_points + 4))
            if x % (self.num_body_points + 4) in range(1, self.num_body_points + 1)
        ]
        body_kpt_index_2 = [
            x for x in range(self.num_group * (self.num_whole_body_points + 4))
            if (x % (self.num_whole_body_points + 4) in range(1, self.num_body_points + 1))
        ]
        lhand_bbox_index = [
            x for x in range(self.num_group * (self.num_body_points + 4))
            if x % (self.num_body_points + 4) != 1
        ]
        lhand_kpt_index = [
            x for x in range(self.num_group * (self.num_whole_body_points + 4))
            if (x % (self.num_whole_body_points + 4) in range(
                self.num_body_points + 2,
                self.num_body_points + self.num_hand_points + 2))
        ]
        rhand_bbox_index = [
            x for x in range(self.num_group * (self.num_body_points + 4))
            if x % (self.num_body_points + 4) != 2
        ]
        rhand_kpt_index = [
            x for x in range(self.num_group * (self.num_whole_body_points + 4))
            if (x % (self.num_whole_body_points + 4) in range(
                self.num_body_points + self.num_hand_points + 3,
                self.num_body_points + self.num_hand_points * 2 + 3))
        ]
        face_bbox_index = [
            x for x in range(self.num_group * (self.num_body_points + 4))
            if x % (self.num_body_points + 4) != 3
        ]
        face_kpt_index = [
            x for x in range(self.num_group * (self.num_whole_body_points + 4))
            if (x % (self.num_whole_body_points + 4) in range(
                self.num_body_points + self.num_hand_points * 2 + 4,
                self.num_body_points + self.num_hand_points * 2 +
                self.num_face_points + 4))
        ]
        # smpl pose
        # body box, kps, lhand box
        body_index = list(range(0, self.num_body_points + 2))
        # rhand box and face box
        body_index.extend([
            self.num_body_points + self.num_hand_points + 2,
            self.num_body_points + 2 * self.num_hand_points + 3
        ])
        smpl_pose_index = [
            x for x in range(self.num_group * (self.num_whole_body_points + 4))
            if (x % (self.num_whole_body_points + 4) in body_index)
        ]
        # smpl lhand
        lhand_index = list(
            range(self.num_body_points + 1,
                  self.num_body_points + self.num_hand_points + 3))
        # body box
        lhand_index.insert(0, 0)
        smpl_lhand_pose_index = [
            x for x in range(self.num_group * (self.num_whole_body_points + 4))
            if (x % (self.num_whole_body_points + 4) in lhand_index)
        ]
        # smpl rhand
        rhand_index = list(
            range(self.num_body_points + self.num_hand_points + 2,
                  self.num_body_points + self.num_hand_points * 2 + 3))
        rhand_index.insert(0, self.num_body_points + 1)
        rhand_index.insert(0, 0)
        smpl_rhand_pose_index = [
            x for x in range(self.num_group * (self.num_whole_body_points + 4))
            if (x % (self.num_whole_body_points + 4) in rhand_index)
        ]
        # smpl face
        face_index = list(
            range(self.num_body_points + self.num_hand_points * 2 + 3,
                  self.num_body_points + self.num_hand_points * 2 +
                  self.num_face_points + 4))
        face_index.insert(0, 0)
        smpl_face_pose_index = [
            x for x in range(self.num_group * (self.num_whole_body_points + 4))
            if (x % (self.num_whole_body_points + 4) in face_index)
        ]
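        # Per-group query layout in the whole-body stage, with the defaults
        # num_body_points=17, num_hand_points=10, num_face_points=10
        # (num_whole_body_points + 4 = 51 queries per group):
        #   0: body box, 1..17: body kps, 18: lhand box, 19..28: lhand kps,
        #   29: rhand box, 30..39: rhand kps, 40: face box, 41..50: face kps.
        # The index lists above gather (combinations of) these columns across
        # all self.num_group groups.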
        for dec_lid, (layer_ref_sig, layer_hs) in enumerate(zip(reference[:-1], hs)):
            if dec_lid < self.num_box_decoder_layers:
                assert isinstance(layer_hs, torch.Tensor)
                bs = layer_hs.shape[0]
                layer_body_kps_res = layer_hs.new_zeros(
                    (bs, self.num_queries,
                     self.num_body_points * 3))  # (bs, num_queries, num_body_points * 3)
                outputs_body_keypoints_list.append(layer_body_kps_res)
                # lhand
                layer_lhand_bbox_res = layer_hs.new_zeros(
                    (bs, self.num_queries, 4))
                outputs_lhand_bbox_list.append(layer_lhand_bbox_res)
                layer_lhand_kps_res = layer_hs.new_zeros(
                    (bs, self.num_queries,
                     self.num_hand_points * 3))
                outputs_lhand_keypoints_list.append(layer_lhand_kps_res)
                # rhand
                layer_rhand_bbox_res = layer_hs.new_zeros(
                    (bs, self.num_queries, 4))
                outputs_rhand_bbox_list.append(layer_rhand_bbox_res)
                layer_rhand_kps_res = layer_hs.new_zeros(
                    (bs, self.num_queries,
                     self.num_hand_points * 3))
                outputs_rhand_keypoints_list.append(layer_rhand_kps_res)
                # face
                layer_face_bbox_res = layer_hs.new_zeros(
                    (bs, self.num_queries, 4))
                outputs_face_bbox_list.append(layer_face_bbox_res)
                layer_face_kps_res = layer_hs.new_zeros(
                    (bs, self.num_queries,
                     self.num_face_points * 3))
                outputs_face_keypoints_list.append(layer_face_kps_res)
                # smpl or smplx
                smpl_pose = layer_hs.new_zeros(
                    (bs, self.num_queries, self.body_model_joint_num * 3))
                smpl_rhand_pose = layer_hs.new_zeros(
                    (bs, self.num_queries, 15 * 3))
                smpl_lhand_pose = layer_hs.new_zeros(
                    (bs, self.num_queries, 15 * 3))
                smpl_expr = layer_hs.new_zeros((bs, self.num_queries, 10))
                smpl_jaw_pose = layer_hs.new_zeros((bs, self.num_queries, 6))
                smpl_beta = layer_hs.new_zeros((bs, self.num_queries, 10))
                smpl_cam = layer_hs.new_zeros((bs, self.num_queries, 3))
                # smpl_cam_f = layer_hs.new_zeros((bs, self.num_queries, 1))
                # smpl_kp2d = layer_hs.new_zeros((bs, self.num_queries, self.num_body_points, 3))
                smpl_kp3d = layer_hs.new_zeros(
                    (bs, self.num_queries, self.num_body_points, 4))
                outputs_smpl_pose_list.append(smpl_pose)
                outputs_smpl_rhand_pose_list.append(smpl_rhand_pose)
                outputs_smpl_lhand_pose_list.append(smpl_lhand_pose)
                outputs_smpl_expr_list.append(smpl_expr)
                outputs_smpl_jaw_pose_list.append(smpl_jaw_pose)
                outputs_smpl_beta_list.append(smpl_beta)
                outputs_smpl_cam_list.append(smpl_cam)
                # outputs_smpl_cam_f_list.append(smpl_cam_f)
                # outputs_smpl_kp2d_list.append(smpl_kp2d)
                outputs_smpl_kp3d_list.append(smpl_kp3d)
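                # For the box-only layers all keypoint / hand / face / SMPL-X
                # outputs are zero placeholders, so every output list keeps
                # one entry per decoder layer for the auxiliary losses.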
| elif dec_lid < self.num_box_decoder_layers +2: | |
| bs = layer_ref_sig.shape[0] | |
| layer_hs_body_kpt = \ | |
| layer_hs[:, effective_dn_number:, :].index_select( | |
| 1, torch.tensor(body_kpt_index, device=layer_hs.device)) | |
| # body kp2d | |
| delta_body_kp_xy_unsig = \ | |
| self.pose_embed[dec_lid - self.num_box_decoder_layers](layer_hs_body_kpt) | |
| layer_ref_sig_body_kpt = \ | |
| layer_ref_sig[:, effective_dn_number:, :].index_select(1, torch.tensor(body_kpt_index, device=layer_hs.device)) | |
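| # Iterative refinement: the head predicts an xy offset in logit space, | |
| # which is added to inverse_sigmoid(reference) and squashed back to | |
| # [0, 1]; the third channel is a constant visibility placeholder. | |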
| layer_outputs_unsig_body_keypoints = delta_body_kp_xy_unsig + inverse_sigmoid( | |
| layer_ref_sig_body_kpt[..., :2]) | |
| vis_xy_unsig = torch.ones_like( | |
| layer_outputs_unsig_body_keypoints, | |
| device=layer_outputs_unsig_body_keypoints.device) | |
| xyv = torch.cat((layer_outputs_unsig_body_keypoints, | |
| vis_xy_unsig[:, :, 0].unsqueeze(-1)), | |
| dim=-1) | |
| xyv = xyv.sigmoid() | |
| layer_res = xyv.reshape( | |
| (bs, self.num_group, self.num_body_points, | |
| 3)).flatten(2, 3) | |
| layer_hw = layer_ref_sig_body_kpt[..., 2:].reshape( | |
| bs, self.num_group, self.num_body_points, 2).flatten(2, 3) | |
| layer_res = keypoint_xyzxyz_to_xyxyzz(layer_res) | |
| outputs_body_keypoints_list.append(layer_res) | |
| outputs_body_keypoints_hw.append(layer_hw) | |
| # lhand bbox | |
| layer_hs_lhand_bbox = \ | |
| layer_hs[:, effective_dn_number:, :][:, (self.num_body_points + 1)::(self.num_body_points + 4), :] | |
| delta_lhand_bbox_xy_unsig = self.bbox_hand_embed[dec_lid - self.num_box_decoder_layers](layer_hs_lhand_bbox) | |
| layer_ref_sig_lhand_bbox = \ | |
| layer_ref_sig[:, effective_dn_number:, :][ | |
| :, (self.num_body_points + 1)::(self.num_body_points + 4), :].clone() | |
| layer_ref_unsig_lhand_bbox = inverse_sigmoid(layer_ref_sig_lhand_bbox) | |
| delta_lhand_bbox_hw_unsig = self.bbox_hand_hw_embed[ | |
| dec_lid - self.num_box_decoder_layers](layer_hs_lhand_bbox) | |
| layer_ref_unsig_lhand_bbox[..., :2] += delta_lhand_bbox_xy_unsig[..., :2] | |
| layer_ref_unsig_lhand_bbox[..., 2:] += delta_lhand_bbox_hw_unsig | |
| layer_ref_sig_lhand_bbox = layer_ref_unsig_lhand_bbox.sigmoid() | |
| outputs_lhand_bbox_list.append(layer_ref_sig_lhand_bbox) | |
| # lhand kps (still placeholders at this stage) | |
| layer_lhand_kps_res = layer_hs.new_zeros( | |
| (bs, self.num_queries, | |
| self.num_hand_points * 3)) # [bs, num_queries, num_hand_points*3] | |
| outputs_lhand_keypoints_list.append(layer_lhand_kps_res) | |
| # rhand bbox | |
| layer_hs_rhand_bbox = \ | |
| layer_hs[:, effective_dn_number:, :][ | |
| :, (self.num_body_points + 2)::(self.num_body_points + 4), :] | |
| delta_rhand_bbox_xy_unsig = self.bbox_hand_embed[ | |
| dec_lid - self.num_box_decoder_layers](layer_hs_rhand_bbox) | |
| layer_ref_sig_rhand_bbox = \ | |
| layer_ref_sig[:, effective_dn_number:, :][ | |
| :, (self.num_body_points + 2)::(self.num_body_points + 4), :].clone() | |
| layer_ref_unsig_rhand_bbox = inverse_sigmoid(layer_ref_sig_rhand_bbox) | |
| delta_rhand_bbox_hw_unsig = self.bbox_hand_hw_embed[ | |
| dec_lid - self.num_box_decoder_layers](layer_hs_rhand_bbox) | |
| layer_ref_unsig_rhand_bbox[..., :2] += delta_rhand_bbox_xy_unsig[..., :2] | |
| layer_ref_unsig_rhand_bbox[..., 2:] += delta_rhand_bbox_hw_unsig | |
| layer_ref_sig_rhand_bbox = layer_ref_unsig_rhand_bbox.sigmoid() | |
| outputs_rhand_bbox_list.append(layer_ref_sig_rhand_bbox) | |
| # rhand kps (still placeholders at this stage) | |
| layer_rhand_kps_res = layer_hs.new_zeros( | |
| (bs, self.num_queries, | |
| self.num_hand_points * 3)) # [bs, num_queries, num_hand_points*3] | |
| outputs_rhand_keypoints_list.append(layer_rhand_kps_res) | |
| # face bbox | |
| layer_hs_face_bbox = \ | |
| layer_hs[:, effective_dn_number:, :][ | |
| :, (self.num_body_points + 3)::(self.num_body_points + 4), :] | |
| delta_face_bbox_xy_unsig = self.bbox_face_embed[ | |
| dec_lid - self.num_box_decoder_layers](layer_hs_face_bbox) | |
| layer_ref_sig_face_bbox = \ | |
| layer_ref_sig[:, effective_dn_number:, :][ | |
| :, (self.num_body_points + 3)::(self.num_body_points + 4), :].clone() | |
| layer_ref_unsig_face_bbox = inverse_sigmoid(layer_ref_sig_face_bbox) | |
| delta_face_bbox_hw_unsig = self.bbox_face_hw_embed[ | |
| dec_lid - self.num_box_decoder_layers](layer_hs_face_bbox) | |
| layer_ref_unsig_face_bbox[..., :2] += delta_face_bbox_xy_unsig[..., :2] | |
| layer_ref_unsig_face_bbox[..., 2:] += delta_face_bbox_hw_unsig | |
| layer_ref_sig_face_bbox = layer_ref_unsig_face_bbox.sigmoid() | |
| outputs_face_bbox_list.append(layer_ref_sig_face_bbox) | |
| # face kps (still placeholders at this stage) | |
| layer_face_kps_res = layer_hs.new_zeros( | |
| (bs, self.num_queries, | |
| self.num_face_points * 3)) # [bs, num_queries, num_face_points*3] | |
| outputs_face_keypoints_list.append(layer_face_kps_res) | |
| # smpl or smplx | |
| bs, _, feat_dim = layer_hs.shape | |
| smpl_feats = layer_hs[:, effective_dn_number:, :].reshape( | |
| bs, -1, feat_dim * (self.num_body_points + 4)) | |
| smpl_lhand_pose_feats = layer_hs[:, effective_dn_number:, :][ | |
| :, (self.num_body_points + 1):: (self.num_body_points + 4), :].reshape( | |
| bs, -1, feat_dim) | |
| smpl_rhand_pose_feats = layer_hs[:, effective_dn_number:, :][ | |
| :, (self.num_body_points + 2):: (self.num_body_points + 4), :].reshape( | |
| bs, -1, feat_dim) | |
| smpl_face_pose_feats = layer_hs[:, effective_dn_number:, :][ | |
| :, (self.num_body_points + 3):: (self.num_body_points + 4), :].reshape( | |
| bs, -1, feat_dim) | |
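| # All rotation heads output the 6D rotation representation | |
| # (Zhou et al., CVPR 2019); rot6d_to_rotmat converts it to 3x3 matrices. | |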
| smpl_pose = self.smpl_pose_embed[ | |
| dec_lid - self.num_box_decoder_layers](smpl_feats) | |
| smpl_pose = rot6d_to_rotmat(smpl_pose.reshape(-1, 6)).reshape( | |
| bs, self.num_group, self.body_model_joint_num, 3, 3) | |
| smpl_lhand_pose = self.smpl_hand_pose_embed[ | |
| dec_lid - self.num_box_decoder_layers](smpl_lhand_pose_feats) | |
| smpl_lhand_pose = rot6d_to_rotmat(smpl_lhand_pose.reshape( | |
| -1, 6)).reshape(bs, self.num_group, 15, 3, 3) | |
| smpl_rhand_pose = self.smpl_hand_pose_embed[ | |
| dec_lid - self.num_box_decoder_layers](smpl_rhand_pose_feats) | |
| smpl_rhand_pose = rot6d_to_rotmat(smpl_rhand_pose.reshape( | |
| -1, 6)).reshape(bs, self.num_group, 15, 3, 3) | |
| smpl_jaw_pose = self.smpl_jaw_embed[ | |
| dec_lid - self.num_box_decoder_layers](smpl_face_pose_feats) | |
| smpl_jaw_pose = rot6d_to_rotmat(smpl_jaw_pose.reshape(-1, 6)).reshape( | |
| bs, self.num_group, 1, 3, 3) | |
| smpl_beta = self.smpl_beta_embed[ | |
| dec_lid - self.num_box_decoder_layers](smpl_feats) | |
| smpl_cam = self.smpl_cam_embed[ | |
| dec_lid - self.num_box_decoder_layers](smpl_feats) | |
| # smpl_cam_f = self.smpl_cam_f_embed[ | |
| # dec_lid - self.num_box_decoder_layers](smpl_feats) | |
| # zero | |
| # smpl_lhand_pose = layer_hs.new_zeros(bs, self.num_group, 15, 3, 3) | |
| # smpl_rhand_pose = layer_hs.new_zeros(bs, self.num_group, 15, 3, 3) | |
| # smpl_expr = layer_hs.new_zeros(bs, self.num_group, 10) | |
| smpl_expr = self.smpl_expr_embed[ | |
| dec_lid - self.num_box_decoder_layers](smpl_face_pose_feats) | |
| # smpl_jaw_pose = layer_hs.new_zeros(bs, self.num_group, 3) | |
| leye_pose = torch.zeros_like(smpl_jaw_pose) | |
| reye_pose = torch.zeros_like(smpl_jaw_pose) | |
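| # Eye poses are fixed to zero. At these intermediate layers the hand | |
| # poses and expression passed to the body model are zeroed as well, so | |
| # the resulting kp3d depend on body pose and shape only. | |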
| if self.body_model is not None: | |
| smpl_pose_ = rotmat_to_aa(smpl_pose) | |
| # smpl_lhand_pose_ = rotmat_to_aa(smpl_lhand_pose) | |
| # smpl_rhand_pose_ = rotmat_to_aa(smpl_rhand_pose) | |
| smpl_lhand_pose_ = layer_hs.new_zeros(bs, self.num_group, 15, 3) | |
| smpl_rhand_pose_ = layer_hs.new_zeros(bs, self.num_group, 15, 3) | |
| smpl_jaw_pose_ = rotmat_to_aa(smpl_jaw_pose) | |
| leye_pose_ = rotmat_to_aa(leye_pose) | |
| reye_pose_ = rotmat_to_aa(reye_pose) | |
| pred_output = self.body_model( | |
| betas=smpl_beta.reshape(-1, 10), | |
| body_pose=smpl_pose_[:, :, 1:].reshape(-1, 21 * 3), | |
| global_orient=smpl_pose_[:, :, 0].reshape( | |
| -1, 3).unsqueeze(1), | |
| left_hand_pose=smpl_lhand_pose_.reshape(-1, 15 * 3), | |
| right_hand_pose=smpl_rhand_pose_.reshape(-1, 15 * 3), | |
| leye_pose=leye_pose_, | |
| reye_pose=reye_pose_, | |
| jaw_pose=smpl_jaw_pose_.reshape(-1, 3), | |
| # expression=smpl_expr.reshape(-1, 10), | |
| expression=layer_hs.new_zeros(bs, self.num_group, 10).reshape(-1, 10) | |
| ) | |
| smpl_kp3d = pred_output['joints'].reshape( | |
| bs, self.num_group, -1, 3) | |
| smpl_verts = pred_output['vertices'].reshape( | |
| bs, self.num_group, -1, 3) | |
| outputs_smpl_pose_list.append(smpl_pose) | |
| outputs_smpl_rhand_pose_list.append(smpl_rhand_pose) | |
| outputs_smpl_lhand_pose_list.append(smpl_lhand_pose) | |
| outputs_smpl_expr_list.append(smpl_expr) | |
| outputs_smpl_jaw_pose_list.append(smpl_jaw_pose) | |
| outputs_smpl_beta_list.append(smpl_beta) | |
| outputs_smpl_cam_list.append(smpl_cam) | |
| # outputs_smpl_cam_f_list.append(smpl_cam_f) | |
| outputs_smpl_kp3d_list.append(smpl_kp3d) | |
| else: | |
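| # Remaining decoder layers: the full whole-body layout (stride | |
| # num_whole_body_points + 4) is active, so hand and face keypoints are | |
| # refined too, and SMPL-X parameters are regressed from part-specific | |
| # groups of query features. | |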
| bs = layer_ref_sig.shape[0] | |
| layer_hs_body_kpt = \ | |
| layer_hs[:, effective_dn_number:, :].index_select( | |
| 1, torch.tensor(body_kpt_index_2, device=layer_hs.device)) | |
| # body kp2d | |
| delta_body_kp_xy_unsig = \ | |
| self.pose_embed[ | |
| dec_lid - self.num_box_decoder_layers](layer_hs_body_kpt) | |
| layer_ref_sig_body_kpt = \ | |
| layer_ref_sig[:, effective_dn_number:, :].index_select( | |
| 1, torch.tensor(body_kpt_index_2, device=layer_hs.device)) | |
| layer_outputs_unsig_body_keypoints = \ | |
| delta_body_kp_xy_unsig + inverse_sigmoid( | |
| layer_ref_sig_body_kpt[..., :2]) | |
| vis_xy_unsig = torch.ones_like( | |
| layer_outputs_unsig_body_keypoints, | |
| device=layer_outputs_unsig_body_keypoints.device) | |
| xyv = torch.cat((layer_outputs_unsig_body_keypoints, | |
| vis_xy_unsig[:, :, 0].unsqueeze(-1)), | |
| dim=-1) | |
| xyv = xyv.sigmoid() | |
| layer_res = xyv.reshape( | |
| (bs, self.num_group, self.num_body_points, | |
| 3)).flatten(2, 3) | |
| layer_hw = layer_ref_sig_body_kpt[..., 2:].reshape( | |
| bs, self.num_group, self.num_body_points, 2).flatten(2, 3) | |
| layer_res = keypoint_xyzxyz_to_xyxyzz(layer_res) | |
| outputs_body_keypoints_list.append(layer_res) | |
| outputs_body_keypoints_hw.append(layer_hw) | |
| # lhand bbox | |
| layer_hs_lhand_bbox = \ | |
| layer_hs[:, effective_dn_number:, :][ | |
| :, (self.num_body_points + 1)::(self.num_whole_body_points + 4), :] | |
| delta_lhand_bbox_xy_unsig = self.bbox_hand_embed[ | |
| dec_lid - self.num_box_decoder_layers](layer_hs_lhand_bbox) | |
| layer_ref_sig_lhand_bbox = \ | |
| layer_ref_sig[:, effective_dn_number:, :][ | |
| :, (self.num_body_points + 1)::(self.num_whole_body_points + 4), :].clone() | |
| layer_ref_unsig_lhand_bbox = inverse_sigmoid(layer_ref_sig_lhand_bbox) | |
| delta_lhand_bbox_hw_unsig = self.bbox_hand_hw_embed[ | |
| dec_lid - self.num_box_decoder_layers](layer_hs_lhand_bbox) | |
| layer_ref_unsig_lhand_bbox[..., :2] += delta_lhand_bbox_xy_unsig[..., :2] | |
| layer_ref_unsig_lhand_bbox[..., 2:] += delta_lhand_bbox_hw_unsig | |
| layer_ref_sig_lhand_bbox = layer_ref_unsig_lhand_bbox.sigmoid() | |
| outputs_lhand_bbox_list.append(layer_ref_sig_lhand_bbox) | |
| # lhand kps | |
| layer_hs_lhand_kps_res = \ | |
| layer_hs[:, effective_dn_number:, :].index_select( | |
| 1, torch.tensor(lhand_kpt_index, device=layer_hs.device)) | |
| delta_lhand_kp_xy_unsig = \ | |
| self.pose_hand_embed[ | |
| dec_lid - self.num_hand_face_decoder_layers](layer_hs_lhand_kps_res) | |
| layer_ref_sig_lhand_kpt = \ | |
| layer_ref_sig[:, effective_dn_number:, :].index_select( | |
| 1, torch.tensor(lhand_kpt_index, device=layer_hs.device)) | |
| layer_outputs_unsig_lhand_keypoints = delta_lhand_kp_xy_unsig + inverse_sigmoid( | |
| layer_ref_sig_lhand_kpt[..., :2]) | |
| lhand_vis_xy_unsig = torch.ones_like( | |
| layer_outputs_unsig_lhand_keypoints, | |
| device=layer_outputs_unsig_lhand_keypoints.device) | |
| lhand_xyv = torch.cat((layer_outputs_unsig_lhand_keypoints, | |
| lhand_vis_xy_unsig[:, :, 0].unsqueeze(-1)), | |
| dim=-1) | |
| lhand_xyv = lhand_xyv.sigmoid() | |
| layer_lhand_kps_res = lhand_xyv.reshape( | |
| (bs, self.num_group, self.num_hand_points, | |
| 3)).flatten(2, 3) | |
| layer_lhand_hw = layer_ref_sig_lhand_kpt[..., 2:].reshape( | |
| bs, self.num_group, self.num_hand_points, 2).flatten(2, 3) | |
| layer_lhand_kps_res = keypoint_xyzxyz_to_xyxyzz(layer_lhand_kps_res) | |
| outputs_lhand_keypoints_list.append(layer_lhand_kps_res) | |
| outputs_lhand_keypoints_hw.append(layer_lhand_hw) | |
| # rhand bbox | |
| layer_hs_rhand_bbox = \ | |
| layer_hs[:, effective_dn_number:, :][ | |
| :, (self.num_body_points + self.num_hand_points + 2)::(self.num_whole_body_points + 4), :] | |
| delta_rhand_bbox_xy_unsig = self.bbox_hand_embed[ | |
| dec_lid - self.num_box_decoder_layers](layer_hs_rhand_bbox) | |
| layer_ref_sig_rhand_bbox = \ | |
| layer_ref_sig[:, effective_dn_number:, :][ | |
| :, (self.num_body_points + self.num_hand_points + 2)::(self.num_whole_body_points + 4), :].clone() | |
| layer_ref_unsig_rhand_bbox = inverse_sigmoid(layer_ref_sig_rhand_bbox) | |
| delta_rhand_bbox_hw_unsig = self.bbox_hand_hw_embed[ | |
| dec_lid - self.num_box_decoder_layers](layer_hs_rhand_bbox) | |
| layer_ref_unsig_rhand_bbox[..., :2] += delta_rhand_bbox_xy_unsig[..., :2] | |
| layer_ref_unsig_rhand_bbox[..., 2:] += delta_rhand_bbox_hw_unsig | |
| layer_ref_sig_rhand_bbox = layer_ref_unsig_rhand_bbox.sigmoid() | |
| outputs_rhand_bbox_list.append(layer_ref_sig_rhand_bbox) | |
| # rhand kps | |
| layer_hs_rhand_kps_res = \ | |
| layer_hs[:, effective_dn_number:, :].index_select( | |
| 1, torch.tensor(rhand_kpt_index, device=layer_hs.device)) | |
| delta_rhand_kp_xy_unsig = \ | |
| self.pose_hand_embed[ | |
| dec_lid - self.num_hand_face_decoder_layers](layer_hs_rhand_kps_res) | |
| layer_ref_sig_rhand_kpt = \ | |
| layer_ref_sig[:, effective_dn_number:, :].index_select( | |
| 1, torch.tensor(rhand_kpt_index, device=layer_hs.device)) | |
| layer_outputs_unsig_rhand_keypoints = delta_rhand_kp_xy_unsig + inverse_sigmoid( | |
| layer_ref_sig_rhand_kpt[..., :2]) | |
| rhand_vis_xy_unsig = torch.ones_like( | |
| layer_outputs_unsig_rhand_keypoints, | |
| device=layer_outputs_unsig_rhand_keypoints.device) | |
| rhand_xyv = torch.cat((layer_outputs_unsig_rhand_keypoints, | |
| rhand_vis_xy_unsig[:, :, 0].unsqueeze(-1)), | |
| dim=-1) | |
| rhand_xyv = rhand_xyv.sigmoid() | |
| layer_rhand_kps_res = rhand_xyv.reshape( | |
| (bs, self.num_group, self.num_hand_points, | |
| 3)).flatten(2, 3) | |
| layer_rhand_hw = layer_ref_sig_rhand_kpt[..., 2:].reshape( | |
| bs, self.num_group, self.num_hand_points, 2).flatten(2, 3) | |
| layer_rhand_kps_res = keypoint_xyzxyz_to_xyxyzz(layer_rhand_kps_res) | |
| outputs_rhand_keypoints_list.append(layer_rhand_kps_res) | |
| outputs_rhand_keypoints_hw.append(layer_rhand_hw) | |
| # face bbox | |
| layer_hs_face_bbox = \ | |
| layer_hs[:, effective_dn_number:, :][ | |
| :, (self.num_body_points + 2 * self.num_hand_points + 3)::(self.num_whole_body_points + 4), :] | |
| delta_face_bbox_xy_unsig = self.bbox_face_embed[ | |
| dec_lid - self.num_box_decoder_layers](layer_hs_face_bbox) | |
| layer_ref_sig_face_bbox = \ | |
| layer_ref_sig[:, effective_dn_number:, :][ | |
| :, (self.num_body_points + 2 * self.num_hand_points + 3)::(self.num_whole_body_points + 4), :].clone() | |
| layer_ref_unsig_face_bbox = inverse_sigmoid(layer_ref_sig_face_bbox) | |
| delta_face_bbox_hw_unsig = self.bbox_face_hw_embed[ | |
| dec_lid - self.num_box_decoder_layers](layer_hs_face_bbox) | |
| layer_ref_unsig_face_bbox[..., :2] += delta_face_bbox_xy_unsig[..., :2] | |
| layer_ref_unsig_face_bbox[..., 2:] += delta_face_bbox_hw_unsig | |
| layer_ref_sig_face_bbox = layer_ref_unsig_face_bbox.sigmoid() | |
| outputs_face_bbox_list.append(layer_ref_sig_face_bbox) | |
| # face kps | |
| layer_hs_face_kps_res = \ | |
| layer_hs[:, effective_dn_number:, :].index_select( | |
| 1, torch.tensor(face_kpt_index, device=layer_hs.device)) | |
| delta_face_kp_xy_unsig = \ | |
| self.pose_face_embed[ | |
| dec_lid - self.num_hand_face_decoder_layers](layer_hs_face_kps_res) | |
| layer_ref_sig_face_kpt = \ | |
| layer_ref_sig[:, effective_dn_number:, :].index_select( | |
| 1, torch.tensor(face_kpt_index, device=layer_hs.device)) | |
| layer_outputs_unsig_face_keypoints = delta_face_kp_xy_unsig + inverse_sigmoid( | |
| layer_ref_sig_face_kpt[..., :2]) | |
| face_vis_xy_unsig = torch.ones_like( | |
| layer_outputs_unsig_face_keypoints, | |
| device=layer_outputs_unsig_face_keypoints.device) | |
| face_xyv = torch.cat((layer_outputs_unsig_face_keypoints, | |
| face_vis_xy_unsig[:, :, 0].unsqueeze(-1)), | |
| dim=-1) | |
| face_xyv = face_xyv.sigmoid() | |
| layer_face_kps_res = face_xyv.reshape( | |
| (bs, self.num_group, self.num_face_points, | |
| 3)).flatten(2, 3) | |
| layer_face_hw = layer_ref_sig_face_kpt[..., 2:].reshape( | |
| bs, self.num_group, self.num_face_points, 2).flatten(2, 3) | |
| layer_face_kps_res = keypoint_xyzxyz_to_xyxyzz(layer_face_kps_res) | |
| outputs_face_keypoints_list.append(layer_face_kps_res) | |
| outputs_face_keypoints_hw.append(layer_face_hw) | |
| bs, _, feat_dim = layer_hs.shape | |
| smpl_body_pose_feats = layer_hs[:, effective_dn_number:, :].index_select( | |
| 1, torch.tensor(smpl_pose_index, device=layer_hs.device) | |
| ).reshape(bs, -1, feat_dim * (self.num_body_points + 4)) | |
| smpl_lhand_pose_feats = layer_hs[:, effective_dn_number:, :].index_select( | |
| 1, torch.tensor(smpl_lhand_pose_index, device=layer_hs.device) | |
| ).reshape(bs, -1, feat_dim * (self.num_hand_points + 3)) | |
| smpl_rhand_pose_feats = layer_hs[:, effective_dn_number:, :].index_select( | |
| 1, torch.tensor(smpl_rhand_pose_index, device=layer_hs.device) | |
| ).reshape(bs, -1, feat_dim * (self.num_hand_points + 3)) | |
| smpl_face_pose_feats = layer_hs[:, effective_dn_number:, :].index_select( | |
| 1, torch.tensor(smpl_face_pose_index, device=layer_hs.device) | |
| ).reshape(bs, -1, feat_dim * (self.num_face_points + 2)) | |
| smpl_pose = self.smpl_pose_embed[ | |
| dec_lid - self.num_box_decoder_layers](smpl_body_pose_feats) | |
| smpl_pose = rot6d_to_rotmat(smpl_pose.reshape(-1, 6)).reshape( | |
| bs, self.num_group, self.body_model_joint_num, 3, 3) | |
| smpl_lhand_pose = self.smpl_hand_pose_embed[ | |
| dec_lid - self.num_box_decoder_layers](smpl_lhand_pose_feats) | |
| smpl_lhand_pose = rot6d_to_rotmat(smpl_lhand_pose.reshape( | |
| -1, 6)).reshape(bs, self.num_group, 15, 3, 3) | |
| smpl_rhand_pose = self.smpl_hand_pose_embed[ | |
| dec_lid - self.num_box_decoder_layers](smpl_rhand_pose_feats) | |
| smpl_rhand_pose = rot6d_to_rotmat(smpl_rhand_pose.reshape( | |
| -1, 6)).reshape(bs, self.num_group, 15, 3, 3) | |
| smpl_expr = self.smpl_expr_embed[ | |
| dec_lid - self.num_box_decoder_layers](smpl_face_pose_feats) | |
| smpl_jaw_pose = self.smpl_jaw_embed[ | |
| dec_lid - self.num_box_decoder_layers](smpl_face_pose_feats) | |
| smpl_jaw_pose = rot6d_to_rotmat(smpl_jaw_pose.reshape(-1, 6)).reshape( | |
| bs, self.num_group, 1, 3, 3) | |
| smpl_beta = self.smpl_beta_embed[ | |
| dec_lid - self.num_box_decoder_layers](smpl_body_pose_feats) | |
| smpl_cam = self.smpl_cam_embed[ | |
| dec_lid - self.num_box_decoder_layers](smpl_body_pose_feats) | |
| # smpl_cam_f = self.smpl_cam_f_embed[ | |
| # dec_lid - self.num_box_decoder_layers](smpl_body_pose_feats) | |
| num_samples = smpl_beta.reshape(-1, 10).shape[0] | |
| device = smpl_beta.device | |
| leye_pose = torch.zeros_like(smpl_jaw_pose) | |
| reye_pose = torch.zeros_like(smpl_jaw_pose) | |
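| # At these final layers the body model receives the full prediction: | |
| # hand poses and facial expression are no longer zeroed. | |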
| if self.body_model is not None: | |
| smpl_pose_ = rotmat_to_aa(smpl_pose) | |
| smpl_lhand_pose_ = rotmat_to_aa(smpl_lhand_pose) | |
| smpl_rhand_pose_ = rotmat_to_aa(smpl_rhand_pose) | |
| smpl_jaw_pose_ = rotmat_to_aa(smpl_jaw_pose) | |
| leye_pose_ = rotmat_to_aa(leye_pose) | |
| reye_pose_ = rotmat_to_aa(reye_pose) | |
| pred_output = self.body_model( | |
| betas=smpl_beta.reshape(-1, 10), | |
| body_pose=smpl_pose_[:, :, 1:].reshape(-1, 21 * 3), | |
| global_orient=smpl_pose_[:, :, 0].reshape( | |
| -1, 3).unsqueeze(1), | |
| left_hand_pose=smpl_lhand_pose_.reshape(-1, 15 * 3), | |
| right_hand_pose=smpl_rhand_pose_.reshape(-1, 15 * 3), | |
| leye_pose=leye_pose_, | |
| reye_pose=reye_pose_, | |
| jaw_pose=smpl_jaw_pose_.reshape(-1, 3), | |
| expression=smpl_expr.reshape(-1, 10), | |
| # expression=layer_hs.new_zeros(bs, self.num_group, 10).reshape(-1, 10), | |
| ) | |
| smpl_kp3d = pred_output['joints'].reshape( | |
| bs, self.num_group, -1, 3) | |
| smpl_verts = pred_output['vertices'].reshape( | |
| bs, self.num_group, -1, 3) | |
| outputs_smpl_pose_list.append(smpl_pose) | |
| outputs_smpl_rhand_pose_list.append(smpl_rhand_pose) | |
| outputs_smpl_lhand_pose_list.append(smpl_lhand_pose) | |
| outputs_smpl_expr_list.append(smpl_expr) | |
| outputs_smpl_jaw_pose_list.append(smpl_jaw_pose) | |
| outputs_smpl_beta_list.append(smpl_beta) | |
| outputs_smpl_cam_list.append(smpl_cam) | |
| # outputs_smpl_cam_f_list.append(smpl_cam_f) | |
| outputs_smpl_kp3d_list.append(smpl_kp3d) | |
| if not self.training: | |
| outputs_smpl_verts_list.append(smpl_verts) | |
| dn_mask_dict = mask_dict | |
| if self.dn_number > 0 and dn_mask_dict is not None: | |
| outputs_class, outputs_body_bbox_list, outputs_body_keypoints_list = self.dn_post_process2( | |
| outputs_class, outputs_body_bbox_list, outputs_body_keypoints_list, | |
| dn_mask_dict) | |
| dn_class_input = dn_mask_dict['known_labels'] | |
| dn_bbox_input = dn_mask_dict['known_bboxs'] | |
| dn_class_pred = dn_mask_dict['output_known_class'] | |
| dn_bbox_pred = dn_mask_dict['output_known_coord'] | |
| for idx, (_out_class, _out_bbox, _out_keypoint) in enumerate( | |
| zip(outputs_class, outputs_body_bbox_list, | |
| outputs_body_keypoints_list)): | |
| assert _out_class.shape[1] == _out_bbox.shape[ | |
| 1] == _out_keypoint.shape[1] | |
| out = { | |
| 'pred_logits': outputs_class[-1], | |
| 'pred_boxes': outputs_body_bbox_list[-1], | |
| 'pred_lhand_boxes': outputs_lhand_bbox_list[-1], | |
| 'pred_rhand_boxes': outputs_rhand_bbox_list[-1], | |
| 'pred_face_boxes': outputs_face_bbox_list[-1], | |
| 'pred_keypoints': outputs_body_keypoints_list[-1], | |
| 'pred_lhand_keypoints': outputs_lhand_keypoints_list[-1], | |
| 'pred_rhand_keypoints': outputs_rhand_keypoints_list[-1], | |
| 'pred_face_keypoints': outputs_face_keypoints_list[-1], | |
| 'pred_smpl_pose': outputs_smpl_pose_list[-1], | |
| 'pred_smpl_rhand_pose': outputs_smpl_rhand_pose_list[-1], | |
| 'pred_smpl_lhand_pose': outputs_smpl_lhand_pose_list[-1], | |
| 'pred_smpl_jaw_pose': outputs_smpl_jaw_pose_list[-1], | |
| 'pred_smpl_expr': outputs_smpl_expr_list[-1], | |
| 'pred_smpl_beta': outputs_smpl_beta_list[-1], # [B, 100, 10] | |
| 'pred_smpl_cam': outputs_smpl_cam_list[-1], | |
| # 'pred_smpl_cam_f': outputs_smpl_cam_f_list[-1], | |
| 'pred_smpl_kp3d': outputs_smpl_kp3d_list[-1] | |
| } | |
| if not self.training: | |
| full_pose = torch.cat((outputs_smpl_pose_list[-1], | |
| outputs_smpl_lhand_pose_list[-1], | |
| outputs_smpl_rhand_pose_list[-1], | |
| outputs_smpl_jaw_pose_list[-1]), dim=2) | |
| bs, num_q, _, _, _ = full_pose.shape | |
| # 53 joints = 22 body (incl. global orient) + 15 per hand + 1 jaw | |
| full_pose = rotmat_to_aa(full_pose).reshape(bs, num_q, 53 * 3) | |
| out = { | |
| 'pred_logits': outputs_class[-1], | |
| 'pred_boxes': outputs_body_bbox_list[-1], | |
| 'pred_lhand_boxes': outputs_lhand_bbox_list[-1], | |
| 'pred_rhand_boxes': outputs_rhand_bbox_list[-1], | |
| 'pred_face_boxes': outputs_face_bbox_list[-1], | |
| 'pred_keypoints': outputs_body_keypoints_list[-1], | |
| 'pred_lhand_keypoints': outputs_lhand_keypoints_list[-1], | |
| 'pred_rhand_keypoints': outputs_rhand_keypoints_list[-1], | |
| 'pred_face_keypoints': outputs_face_keypoints_list[-1], | |
| 'pred_smpl_pose': outputs_smpl_pose_list[-1], | |
| 'pred_smpl_rhand_pose': outputs_smpl_rhand_pose_list[-1], | |
| 'pred_smpl_lhand_pose': outputs_smpl_lhand_pose_list[-1], | |
| 'pred_smpl_jaw_pose': outputs_smpl_jaw_pose_list[-1], | |
| 'pred_smpl_expr': outputs_smpl_expr_list[-1], | |
| 'pred_smpl_beta': outputs_smpl_beta_list[-1], # [B, 100, 10] | |
| 'pred_smpl_cam': outputs_smpl_cam_list[-1], | |
| # 'pred_smpl_cam_f': outputs_smpl_cam_f_list[-1], | |
| 'pred_smpl_kp3d': outputs_smpl_kp3d_list[-1], | |
| 'pred_smpl_verts': outputs_smpl_verts_list[-1], | |
| 'pred_smpl_fullpose': full_pose | |
| } | |
| if self.dn_number > 0 and dn_mask_dict is not None: | |
| out.update({ | |
| 'dn_class_input': dn_class_input, | |
| 'dn_bbox_input': dn_bbox_input, | |
| 'dn_class_pred': dn_class_pred[-1], | |
| 'dn_bbox_pred': dn_bbox_pred[-1], | |
| 'num_tgt': dn_mask_dict['pad_size'] | |
| }) | |
| if self.aux_loss: | |
| out['aux_outputs'] = \ | |
| self._set_aux_loss( | |
| outputs_class, | |
| outputs_body_bbox_list, | |
| outputs_lhand_bbox_list, | |
| outputs_rhand_bbox_list, | |
| outputs_face_bbox_list, | |
| outputs_body_keypoints_list, | |
| outputs_lhand_keypoints_list, | |
| outputs_rhand_keypoints_list, | |
| outputs_face_keypoints_list, | |
| outputs_smpl_pose_list, | |
| outputs_smpl_rhand_pose_list, | |
| outputs_smpl_lhand_pose_list, | |
| outputs_smpl_jaw_pose_list, | |
| outputs_smpl_expr_list, | |
| outputs_smpl_beta_list, | |
| outputs_smpl_cam_list, | |
| # outputs_smpl_cam_f_list, | |
| outputs_smpl_kp3d_list | |
| ) # with key pred_logits, pred_bbox, pred_keypoints | |
| if self.dn_number > 0 and dn_mask_dict is not None: | |
| assert len(dn_class_pred[:-1]) == len( | |
| dn_bbox_pred[:-1]) == len(out['aux_outputs']) | |
| for aux_out, dn_class_pred_i, dn_bbox_pred_i in zip( | |
| out['aux_outputs'], dn_class_pred, dn_bbox_pred): | |
| aux_out.update({ | |
| 'dn_class_input': dn_class_input, | |
| 'dn_bbox_input': dn_bbox_input, | |
| 'dn_class_pred': dn_class_pred_i, | |
| 'dn_bbox_pred': dn_bbox_pred_i, | |
| 'num_tgt': dn_mask_dict['pad_size'] | |
| }) | |
| # for encoder output | |
| if hs_enc is not None: | |
| interm_coord = ref_enc[-1] | |
| interm_class = self.transformer.enc_out_class_embed(hs_enc[-1]) | |
| interm_pose = torch.zeros_like(outputs_body_keypoints_list[0]) | |
| out['interm_outputs'] = { | |
| 'pred_logits': interm_class, | |
| 'pred_boxes': interm_coord, | |
| 'pred_keypoints': interm_pose | |
| } | |
| return out, targets, data_batch | |
| def _set_aux_loss(self, | |
| outputs_class, | |
| outputs_body_coord, | |
| outputs_lhand_coord, | |
| outputs_rhand_coord, | |
| outputs_face_coord, | |
| outputs_body_keypoints, | |
| outputs_lhand_keypoints, | |
| outputs_rhand_keypoints, | |
| outputs_face_keypoints, | |
| outputs_smpl_pose, | |
| outputs_smpl_rhand_pose, | |
| outputs_smpl_lhand_pose, | |
| outputs_smpl_jaw_pose, | |
| outputs_smpl_expr, | |
| outputs_smpl_beta, | |
| outputs_smpl_cam, | |
| # outputs_smpl_cam_f, | |
| outputs_smpl_kp3d): | |
| return [{ | |
| 'pred_logits': a, | |
| 'pred_boxes': b, | |
| 'pred_lhand_boxes': c, | |
| 'pred_rhand_boxes': d, | |
| 'pred_face_boxes': e, | |
| 'pred_keypoints': f, | |
| 'pred_lhand_keypoints': g, | |
| 'pred_rhand_keypoints': h, | |
| 'pred_face_keypoints': i, | |
| 'pred_smpl_pose': j, | |
| 'pred_smpl_rhand_pose': k, | |
| 'pred_smpl_lhand_pose': l, | |
| 'pred_smpl_jaw_pose': m, | |
| 'pred_smpl_expr': n, | |
| 'pred_smpl_beta': o, | |
| 'pred_smpl_cam': p, | |
| # 'pred_smpl_cam_f': q, | |
| 'pred_smpl_kp3d': q | |
| } for a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p, q in zip( | |
| outputs_class[:-1], | |
| outputs_body_coord[:-1], | |
| outputs_lhand_coord[:-1], | |
| outputs_rhand_coord[:-1], | |
| outputs_face_coord[:-1], | |
| outputs_body_keypoints[:-1], | |
| outputs_lhand_keypoints[:-1], | |
| outputs_rhand_keypoints[:-1], | |
| outputs_face_keypoints[:-1], | |
| outputs_smpl_pose[:-1], | |
| outputs_smpl_rhand_pose[:-1], | |
| outputs_smpl_lhand_pose[:-1], | |
| outputs_smpl_jaw_pose[:-1], | |
| outputs_smpl_expr[:-1], | |
| outputs_smpl_beta[:-1], | |
| outputs_smpl_cam[:-1], | |
| outputs_smpl_kp3d[:-1])] | |
| def prepare_targets(self, data_batch): | |
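| # Convert a raw data_batch into DETR-style per-image target dicts | |
| # (boxes and keypoints normalized to [0, 1]) plus a NestedTensor input. | |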
| data_batch_coco = [] | |
| instance_dict = {} | |
| img_list = data_batch['img'].float() | |
| # input_img_h, input_img_w = data_batch['image_metas'][0]['batch_input_shape'] | |
| batch_size, _, input_img_h, input_img_w = img_list.shape | |
| device = img_list.device | |
| masks = torch.ones((batch_size, input_img_h, input_img_w), | |
| dtype=torch.bool, | |
| device=device) | |
| for img_id in range(batch_size): | |
| img_h, img_w = data_batch['img_shape'][img_id] | |
| masks[img_id, :img_h, :img_w] = 0 | |
| if not self.inference: | |
| instance_body_bbox = torch.cat([data_batch['body_bbox_center'][img_id],\ | |
| data_batch['body_bbox_size'][img_id]],dim=-1) | |
| instance_face_bbox = torch.cat([data_batch['face_bbox_center'][img_id],\ | |
| data_batch['face_bbox_size'][img_id]],dim=-1) | |
| instance_lhand_bbox = torch.cat([data_batch['lhand_bbox_center'][img_id],\ | |
| data_batch['lhand_bbox_size'][img_id]],dim=-1) | |
| instance_rhand_bbox = torch.cat([data_batch['rhand_bbox_center'][img_id],\ | |
| data_batch['rhand_bbox_size'][img_id]],dim=-1) | |
| instance_kp2d = data_batch['joint_img'][img_id].clone().float() | |
| instance_kp2d_mask = data_batch['joint_trunc'][img_id].clone().float() | |
| instance_kp2d[:,:,2:] = instance_kp2d_mask | |
| body_kp2d, _ = convert_kps(instance_kp2d, 'smplx_137', 'coco', approximate=True) | |
| lhand_kp2d, _ = convert_kps(instance_kp2d, 'smplx_137', 'smplx_lhand', approximate=True) | |
| rhand_kp2d, _ = convert_kps(instance_kp2d, 'smplx_137', 'smplx_rhand', approximate=True) | |
| face_kp2d, _ = convert_kps(instance_kp2d, 'smplx_137', 'smplx_face', approximate=True) | |
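| # joint_img comes in heatmap coordinates; divide x by output_hm_shape[2] | |
| # (width) and y by output_hm_shape[1] (height) to normalize to [0, 1], | |
| # then flatten each part to [x1, y1, ..., xn, yn, c1, ..., cn]. | |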
| body_kp2d[:,:,0] = body_kp2d[:,:,0]/cfg.output_hm_shape[2] | |
| body_kp2d[:,:,1] = body_kp2d[:,:,1]/cfg.output_hm_shape[1] | |
| body_kp2d = torch.cat([body_kp2d[:,:,:2].flatten(1),body_kp2d[:,:,2]],dim=-1) | |
| lhand_kp2d[:,:,0] = lhand_kp2d[:,:,0]/cfg.output_hm_shape[2] | |
| lhand_kp2d[:,:,1] = lhand_kp2d[:,:,1]/cfg.output_hm_shape[1] | |
| lhand_kp2d = torch.cat([lhand_kp2d[:,:,:2].flatten(1),lhand_kp2d[:,:,2]],dim=-1) | |
| rhand_kp2d[:,:,0] = rhand_kp2d[:,:,0]/cfg.output_hm_shape[2] | |
| rhand_kp2d[:,:,1] = rhand_kp2d[:,:,1]/cfg.output_hm_shape[1] | |
| rhand_kp2d = torch.cat([rhand_kp2d[:,:,:2].flatten(1),rhand_kp2d[:,:,2]],dim=-1) | |
| face_kp2d[:,:,0] = face_kp2d[:,:,0]/cfg.output_hm_shape[2] | |
| face_kp2d[:,:,1] = face_kp2d[:,:,1]/cfg.output_hm_shape[1] | |
| face_kp2d = torch.cat([face_kp2d[:,:,:2].flatten(1),face_kp2d[:,:,2]],dim=-1) | |
| instance_dict = {} | |
| instance_dict['boxes'] = instance_body_bbox.float() | |
| instance_dict['face_boxes'] = instance_face_bbox.float() | |
| instance_dict['lhand_boxes'] = instance_lhand_bbox.float() | |
| instance_dict['rhand_boxes'] = instance_rhand_bbox.float() | |
| instance_dict['keypoints'] = body_kp2d.float() | |
| instance_dict['lhand_keypoints'] = lhand_kp2d.float() | |
| instance_dict['rhand_keypoints'] = rhand_kp2d.float() | |
| instance_dict['face_keypoints'] = face_kp2d.float() | |
| # instance_dict['orig_size'] = data_batch['ori_shape'][img_id] | |
| instance_dict['size'] = data_batch['img_shape'][img_id] # after augmentation | |
| instance_dict['area'] = instance_body_bbox[:, 2] * instance_body_bbox[:, 3] | |
| instance_dict['lhand_area'] = instance_lhand_bbox[:, 2] * instance_lhand_bbox[:, 3] | |
| instance_dict['rhand_area'] = instance_rhand_bbox[:, 2] * instance_rhand_bbox[:, 3] | |
| instance_dict['face_area'] = instance_face_bbox[:, 2] * instance_face_bbox[:, 3] | |
| instance_dict['labels'] = torch.ones(instance_body_bbox.shape[0], | |
| dtype=torch.long, | |
| device=device) | |
| data_batch_coco.append(instance_dict) | |
| else: | |
| instance_body_bbox = torch.cat([data_batch['body_bbox_center'][img_id],\ | |
| data_batch['body_bbox_size'][img_id]],dim=-1) | |
| instance_dict = {} | |
| # instance_dict['orig_size'] = data_batch['ori_shape'][img_id] | |
| instance_dict['size'] = data_batch['img_shape'][img_id] # after augmentation | |
| instance_dict['boxes'] = instance_body_bbox.float() | |
| data_batch_coco.append(instance_dict) | |
| input_img = NestedTensor(img_list, masks) | |
| return input_img, data_batch_coco | |
| def keypoints_to_scaled_bbox_bfh( | |
| self, keypoints, occ=None, | |
| body_scale=1.0, fh_scale=1.0, | |
| convention='smplx'): | |
| '''Obtain scaled bboxes in xyxy format from whole-body keypoints. | |
| Args: | |
| keypoints (np.ndarray): Keypoints of shape (1, n, k) or (n, k), k >= 2. | |
| occ (np.ndarray, optional): Per-keypoint occlusion flags. | |
| body_scale (float): Scale factor for the body box. | |
| fh_scale (float): Scale factor for the face/hand boxes. | |
| convention (str): Keypoint convention used to look up part indices. | |
| Returns: | |
| list[np.ndarray]: One [xmin, ymin, xmax, ymax, conf] array per part | |
| (body, head, left_hand, right_hand). | |
| ''' | |
| bboxs = [] | |
| # supported kps.shape: (1, n, k) or (n, k), k = 2 or 3 | |
| if keypoints.ndim == 3: | |
| keypoints = keypoints[0] | |
| if keypoints.shape[-1] != 2: | |
| keypoints = keypoints[:, :2] | |
| for body_part in ['body', 'head', 'left_hand', 'right_hand']: | |
| if body_part == 'body': | |
| scale = body_scale | |
| kps = keypoints | |
| else: | |
| scale = fh_scale | |
| kp_id = get_keypoint_idxs_by_part(body_part, convention=convention) | |
| kps = keypoints[kp_id] | |
| # The body box always gets conf 1; for the other parts, mark the box | |
| # as invalid (conf 0) when >= 10% of its keypoints are occluded. | |
| if body_part == 'body' or occ is None: | |
| conf = 1 | |
| elif np.sum(occ[kp_id]) / len(kp_id) >= 0.1: | |
| conf = 0 | |
| else: | |
| conf = 1 | |
| xmin, ymin = np.amin(kps, axis=0) | |
| xmax, ymax = np.amax(kps, axis=0) | |
| width = (xmax - xmin) * scale | |
| height = (ymax - ymin) * scale | |
| x_center = 0.5 * (xmax + xmin) | |
| y_center = 0.5 * (ymax + ymin) | |
| xmin = x_center - 0.5 * width | |
| xmax = x_center + 0.5 * width | |
| ymin = y_center - 0.5 * height | |
| ymax = y_center + 0.5 * height | |
| bbox = np.stack([xmin, ymin, xmax, ymax, conf], axis=0).astype(np.float32) | |
| bboxs.append(bbox) | |
| return bboxs | |
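| # Illustrative usage sketch for the method above (hypothetical shapes and | |
| # values, not taken from the training pipeline): given smplx-convention | |
| # keypoints kps of shape (137, 2) and an occlusion mask occ of shape (137,), | |
| # body_box, head_box, lhand_box, rhand_box = \ | |
| # self.keypoints_to_scaled_bbox_bfh(kps, occ, body_scale=1.2, | |
| # fh_scale=1.5, convention='smplx') | |
| # where each returned array is [xmin, ymin, xmax, ymax, conf]. | |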
| def build_aios_smplx(args, cfg): | |
| num_classes = args.num_classes # 2 | |
| device = torch.device(args.device) | |
| backbone = build_backbone(args) | |
| transformer = build_transformer(args) | |
| dn_labelbook_size = args.dn_labelbook_size | |
| dec_pred_class_embed_share = args.dec_pred_class_embed_share | |
| dec_pred_bbox_embed_share = args.dec_pred_bbox_embed_share | |
| if args.eval: | |
| body_model = args.body_model_test | |
| train = False | |
| else: | |
| body_model = args.body_model_train | |
| train = True | |
| model = AiOSSMPLX( | |
| backbone, | |
| transformer, | |
| num_classes=num_classes, # 2 | |
| num_queries=args.num_queries, # 900 | |
| aux_loss=True, | |
| iter_update=True, | |
| query_dim=4, | |
| random_refpoints_xy=args.random_refpoints_xy, # False | |
| fix_refpoints_hw=args.fix_refpoints_hw, # -1 | |
| num_feature_levels=args.num_feature_levels, # 4 | |
| nheads=args.nheads, # 8 | |
| dec_pred_class_embed_share=dec_pred_class_embed_share, # false | |
| dec_pred_bbox_embed_share=dec_pred_bbox_embed_share, # False | |
| # two stage | |
| two_stage_type=args.two_stage_type, | |
| # box_share | |
| two_stage_bbox_embed_share=args.two_stage_bbox_embed_share, # False | |
| two_stage_class_embed_share=args.two_stage_class_embed_share, # False | |
| dn_number=args.dn_number if args.use_dn else 0, # 100 | |
| dn_box_noise_scale=args.dn_box_noise_scale, # 0.4 | |
| dn_label_noise_ratio=args.dn_label_noise_ratio, # 0.5 | |
| dn_batch_gt_fuse=args.dn_batch_gt_fuse, # false | |
| dn_attn_mask_type_list=args.dn_attn_mask_type_list, | |
| dn_labelbook_size=dn_labelbook_size, # 100 | |
| cls_no_bias=args.cls_no_bias, # False | |
| num_group=args.num_group, # 100 | |
| num_body_points=args.num_body_points, # 17 | |
| num_hand_points=args.num_hand_points, # 10 by default | |
| num_face_points=args.num_face_points, # 10 by default | |
| num_box_decoder_layers=args.num_box_decoder_layers, # 2 | |
| num_hand_face_decoder_layers=args.num_hand_face_decoder_layers, | |
| # smpl_convention=convention | |
| body_model=body_model, | |
| train=train, | |
| inference=args.inference) | |
| matcher = build_matcher(args) | |
| # prepare weight dict | |
| weight_dict = { | |
| 'loss_ce': args.cls_loss_coef, # 2 | |
| # bbox | |
| 'loss_body_bbox': args.body_bbox_loss_coef, # 5 | |
| 'loss_rhand_bbox': args.rhand_bbox_loss_coef, # 5 | |
| 'loss_lhand_bbox': args.lhand_bbox_loss_coef, # 5 | |
| 'loss_face_bbox': args.face_bbox_loss_coef, # 5 | |
| # bbox giou | |
| 'loss_body_giou': args.body_giou_loss_coef, # 2 | |
| 'loss_rhand_giou': args.rhand_giou_loss_coef, # 2 | |
| 'loss_lhand_giou': args.lhand_giou_loss_coef, # 2 | |
| 'loss_face_giou': args.face_giou_loss_coef, # 2 | |
| # 2d kp | |
| 'loss_keypoints': args.keypoints_loss_coef, # 10 | |
| 'loss_rhand_keypoints': args.rhand_keypoints_loss_coef, # 10 | |
| 'loss_lhand_keypoints': args.lhand_keypoints_loss_coef, # 10 | |
| 'loss_face_keypoints': args.face_keypoints_loss_coef, # 10 | |
| # 2d kp oks | |
| 'loss_oks': args.oks_loss_coef, # 4 | |
| 'loss_rhand_oks': args.rhand_oks_loss_coef, # 4 | |
| 'loss_lhand_oks': args.lhand_oks_loss_coef, # 4 | |
| 'loss_face_oks': args.face_oks_loss_coef, # 4 | |
| # smpl param | |
| 'loss_smpl_pose_root': args.smpl_pose_loss_root_coef, # 0 | |
| 'loss_smpl_pose_body': args.smpl_pose_loss_body_coef, # 0 | |
| 'loss_smpl_pose_lhand': args.smpl_pose_loss_lhand_coef, # 0 | |
| 'loss_smpl_pose_rhand': args.smpl_pose_loss_rhand_coef, # 0 | |
| 'loss_smpl_pose_jaw': args.smpl_pose_loss_jaw_coef, # 0 | |
| 'loss_smpl_beta': args.smpl_beta_loss_coef, # 0 | |
| 'loss_smpl_expr': args.smpl_expr_loss_coef, | |
| # smpl kp3d ra | |
| 'loss_smpl_body_kp3d_ra': args.smpl_body_kp3d_ra_loss_coef, # 0 | |
| 'loss_smpl_lhand_kp3d_ra': args.smpl_lhand_kp3d_ra_loss_coef, # 0 | |
| 'loss_smpl_rhand_kp3d_ra': args.smpl_rhand_kp3d_ra_loss_coef, # 0 | |
| 'loss_smpl_face_kp3d_ra': args.smpl_face_kp3d_ra_loss_coef, # 0 | |
| # smpl kp3d | |
| 'loss_smpl_body_kp3d': args.smpl_body_kp3d_loss_coef, # 0 | |
| 'loss_smpl_face_kp3d': args.smpl_face_kp3d_loss_coef, # 0 | |
| 'loss_smpl_lhand_kp3d': args.smpl_lhand_kp3d_loss_coef, # 0 | |
| 'loss_smpl_rhand_kp3d': args.smpl_rhand_kp3d_loss_coef, # 0 | |
| # smpl kp2d | |
| 'loss_smpl_body_kp2d': args.smpl_body_kp2d_loss_coef, # 0 | |
| 'loss_smpl_lhand_kp2d': args.smpl_lhand_kp2d_loss_coef, # 0 | |
| 'loss_smpl_rhand_kp2d': args.smpl_rhand_kp2d_loss_coef, # 0 | |
| 'loss_smpl_face_kp2d': args.smpl_face_kp2d_loss_coef, # 0 | |
| # smpl kp2d ba | |
| 'loss_smpl_body_kp2d_ba': args.smpl_body_kp2d_ba_loss_coef, | |
| 'loss_smpl_face_kp2d_ba': args.smpl_face_kp2d_ba_loss_coef, | |
| 'loss_smpl_lhand_kp2d_ba': args.smpl_lhand_kp2d_ba_loss_coef, | |
| 'loss_smpl_rhand_kp2d_ba': args.smpl_rhand_kp2d_ba_loss_coef, | |
| } | |
| clean_weight_dict_wo_dn = copy.deepcopy(weight_dict) | |
| if args.use_dn: | |
| weight_dict.update({ | |
| 'dn_loss_ce': | |
| args.dn_label_coef, # 0.3 | |
| 'dn_loss_bbox': | |
| args.bbox_loss_coef * args.dn_bbox_coef, # 5 * 0.5 | |
| 'dn_loss_giou': | |
| args.giou_loss_coef * args.dn_bbox_coef, # 2 * 0.5 | |
| }) | |
| clean_weight_dict = copy.deepcopy(weight_dict) | |
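| # Per-layer auxiliary losses get an '_{i}' suffix. Hand/face and SMPL | |
| # terms are skipped for the early decoder layers that have not yet | |
| # activated the corresponding queries. | |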
| if args.aux_loss: | |
| aux_weight_dict = {} | |
| for i in range(args.dec_layers - 1): # one aux-loss set per decoder layer except the last | |
| for k, v in clean_weight_dict.items(): | |
| if i < args.num_box_decoder_layers and ('keypoints' in k or 'oks' in k): | |
| continue | |
| if i < args.num_box_decoder_layers and k in [ | |
| 'loss_rhand_bbox', 'loss_lhand_bbox', 'loss_face_bbox', | |
| 'loss_rhand_giou', 'loss_lhand_giou', 'loss_face_giou']: | |
| continue | |
| if i < args.num_hand_face_decoder_layers and k in [ | |
| 'loss_rhand_keypoints', 'loss_lhand_keypoints', | |
| 'loss_face_keypoints', 'loss_rhand_oks', | |
| 'loss_lhand_oks', 'loss_face_oks']: | |
| continue | |
| if i < args.num_box_decoder_layers and 'smpl' in k: | |
| continue | |
| aux_weight_dict.update({k + f'_{i}': v}) | |
| weight_dict.update(aux_weight_dict) | |
| if args.two_stage_type != 'no': | |
| interm_weight_dict = {} | |
| no_interm_box_loss = getattr(args, 'no_interm_box_loss', False) | |
| _coeff_weight_dict = { | |
| 'loss_ce': 1.0, | |
| # bbox | |
| 'loss_body_bbox': 1.0 if not no_interm_box_loss else 0.0, | |
| 'loss_rhand_bbox': 1.0 if not no_interm_box_loss else 0.0, | |
| 'loss_lhand_bbox': 1.0 if not no_interm_box_loss else 0.0, | |
| 'loss_face_bbox': 1.0 if not no_interm_box_loss else 0.0, | |
| # bbox giou | |
| 'loss_body_giou': 1.0 if not no_interm_box_loss else 0.0, | |
| 'loss_rhand_giou': 1.0 if not no_interm_box_loss else 0.0, | |
| 'loss_lhand_giou': 1.0 if not no_interm_box_loss else 0.0, | |
| 'loss_face_giou': 1.0 if not no_interm_box_loss else 0.0, | |
| # 2d kp | |
| 'loss_keypoints': 1.0 if not no_interm_box_loss else 0.0, | |
| 'loss_rhand_keypoints': 1.0 if not no_interm_box_loss else 0.0, | |
| 'loss_lhand_keypoints': 1.0 if not no_interm_box_loss else 0.0, | |
| 'loss_face_keypoints': 1.0 if not no_interm_box_loss else 0.0, | |
| # 2d oks | |
| 'loss_oks': 1.0 if not no_interm_box_loss else 0.0, | |
| 'loss_rhand_oks': 1.0 if not no_interm_box_loss else 0.0, | |
| 'loss_lhand_oks': 1.0 if not no_interm_box_loss else 0.0, | |
| 'loss_face_oks': 1.0 if not no_interm_box_loss else 0.0, | |
| # smpl param | |
| 'loss_smpl_pose_root': 1.0 if not no_interm_box_loss else 0.0, | |
| 'loss_smpl_pose_body': 1.0 if not no_interm_box_loss else 0.0, | |
| 'loss_smpl_pose_lhand': 1.0 if not no_interm_box_loss else 0.0, | |
| 'loss_smpl_pose_rhand': 1.0 if not no_interm_box_loss else 0.0, | |
| 'loss_smpl_pose_jaw': 1.0 if not no_interm_box_loss else 0.0, | |
| 'loss_smpl_beta': 1.0 if not no_interm_box_loss else 0.0, | |
| 'loss_smpl_expr': 1.0 if not no_interm_box_loss else 0.0, | |
| # smpl kp3d ra | |
| 'loss_smpl_body_kp3d_ra': 1.0 if not no_interm_box_loss else 0.0, | |
| 'loss_smpl_lhand_kp3d_ra': 1.0 if not no_interm_box_loss else 0.0, | |
| 'loss_smpl_rhand_kp3d_ra': 1.0 if not no_interm_box_loss else 0.0, | |
| 'loss_smpl_face_kp3d_ra': 1.0 if not no_interm_box_loss else 0.0, | |
| # smpl kp3d | |
| 'loss_smpl_body_kp3d': 1.0 if not no_interm_box_loss else 0.0, | |
| 'loss_smpl_face_kp3d': 1.0 if not no_interm_box_loss else 0.0, | |
| 'loss_smpl_lhand_kp3d': 1.0 if not no_interm_box_loss else 0.0, | |
| 'loss_smpl_rhand_kp3d': 1.0 if not no_interm_box_loss else 0.0, | |
| # smpl kp2d | |
| 'loss_smpl_body_kp2d': 1.0 if not no_interm_box_loss else 0.0, | |
| 'loss_smpl_lhand_kp2d': 1.0 if not no_interm_box_loss else 0.0, | |
| 'loss_smpl_rhand_kp2d': 1.0 if not no_interm_box_loss else 0.0, | |
| 'loss_smpl_face_kp2d': 1.0 if not no_interm_box_loss else 0.0, | |
| # smpl kp2d ba | |
| 'loss_smpl_body_kp2d_ba': 1.0 if not no_interm_box_loss else 0.0, | |
| 'loss_smpl_lhand_kp2d_ba': 1.0 if not no_interm_box_loss else 0.0, | |
| 'loss_smpl_rhand_kp2d_ba': 1.0 if not no_interm_box_loss else 0.0, | |
| 'loss_smpl_face_kp2d_ba': 1.0 if not no_interm_box_loss else 0.0, | |
| } | |
| interm_loss_coef = getattr(args, 'interm_loss_coef', 1.0) # 1 | |
| interm_weight_dict.update({ | |
| k + '_interm': v * interm_loss_coef * _coeff_weight_dict[k] | |
| for k, v in clean_weight_dict_wo_dn.items() if 'keypoints' not in k | |
| }) | |
| weight_dict.update(interm_weight_dict) | |
| # the query-expand outputs reuse the same coefficients (keypoints included) | |
| interm_weight_dict.update({ | |
| k + '_query_expand': v * interm_loss_coef * _coeff_weight_dict[k] | |
| for k, v in clean_weight_dict_wo_dn.items() | |
| }) | |
| weight_dict.update(interm_weight_dict) | |
| losses = cfg.losses | |
| if args.dn_number > 0: | |
| losses += ['dn_label', 'dn_bbox'] | |
| losses += ['matching'] | |
| criterion = SetCriterion( | |
| num_classes, | |
| matcher=matcher, | |
| weight_dict=weight_dict, | |
| focal_alpha=args.focal_alpha, | |
| losses=losses, | |
| num_box_decoder_layers=args.num_box_decoder_layers, | |
| num_hand_face_decoder_layers=args.num_hand_face_decoder_layers, | |
| num_body_points=args.num_body_points, | |
| num_hand_points=args.num_hand_points, | |
| num_face_points=args.num_face_points, | |
| ) | |
| criterion.to(device) | |
| if args.inference: | |
| postprocessors = { | |
| 'bbox': | |
| PostProcess_SMPLX_Multi_Infer( | |
| num_select=args.num_select, | |
| nms_iou_threshold=args.nms_iou_threshold, | |
| num_body_points=args.num_body_points), | |
| } | |
| else: | |
| postprocessors = { | |
| 'bbox': | |
| PostProcess_SMPLX( | |
| num_select=args.num_select, | |
| nms_iou_threshold=args.nms_iou_threshold, | |
| num_body_points=args.num_body_points), | |
| } | |
| postprocessors_aios = { | |
| 'bbox': | |
| PostProcess_aios(num_select=args.num_select, | |
| nms_iou_threshold=args.nms_iou_threshold, | |
| num_body_points=args.num_body_points), | |
| } | |
| # criterion_smpl=build_architecture(cfg['smpl_loss']) | |
| return model, criterion, postprocessors, postprocessors_aios | |
| class AiOSSMPLX_Box(nn.Module): | |
| def __init__( | |
| self, | |
| backbone, | |
| transformer, | |
| num_classes, | |
| num_queries, | |
| aux_loss=False, | |
| iter_update=True, | |
| query_dim=4, | |
| random_refpoints_xy=False, | |
| fix_refpoints_hw=-1, | |
| num_feature_levels=1, | |
| nheads=8, | |
| two_stage_type='no', | |
| dec_pred_class_embed_share=False, | |
| dec_pred_bbox_embed_share=False, | |
| dec_pred_pose_embed_share=False, | |
| two_stage_class_embed_share=True, | |
| two_stage_bbox_embed_share=True, | |
| dn_number=100, | |
| dn_box_noise_scale=0.4, | |
| dn_label_noise_ratio=0.5, | |
| dn_batch_gt_fuse=False, | |
| dn_labelbook_size=100, | |
| dn_attn_mask_type_list=['group2group'], | |
| cls_no_bias=False, | |
| num_group=100, | |
| num_body_points=0, | |
| num_hand_points=0, | |
| num_face_points=0, | |
| num_box_decoder_layers=2, | |
| num_hand_face_decoder_layers=4, | |
| body_model=dict( | |
| type='smplx', | |
| keypoint_src='smplx', | |
| num_expression_coeffs=10, | |
| keypoint_dst='smplx_137', | |
| model_path='data/body_models/smplx', | |
| use_pca=False, | |
| use_face_contour=True), | |
| train=True, | |
| inference=False, | |
| focal_length=[5000., 5000.], | |
| camera_3d_size=2.5 | |
| ): | |
| super().__init__() | |
| self.num_queries = num_queries | |
| self.transformer = transformer | |
| self.num_classes = num_classes | |
| self.hidden_dim = hidden_dim = transformer.d_model | |
| self.num_feature_levels = num_feature_levels | |
| self.nheads = nheads | |
| self.label_enc = nn.Embedding(dn_labelbook_size + 1, hidden_dim) | |
| self.num_body_points = num_body_points | |
| self.num_hand_points = num_hand_points | |
| self.num_face_points = num_face_points | |
| self.num_whole_body_points = num_body_points + 2*num_hand_points + num_face_points | |
| self.num_box_decoder_layers = num_box_decoder_layers | |
| self.num_hand_face_decoder_layers = num_hand_face_decoder_layers | |
| self.focal_length = focal_length | |
| self.camera_3d_size = camera_3d_size | |
| self.inference = inference | |
| if train: | |
| self.smpl_convention = 'smplx' | |
| else: | |
| self.smpl_convention = 'h36m' | |
| # setting query dim | |
| self.query_dim = query_dim | |
| assert query_dim == 4 | |
| self.random_refpoints_xy = random_refpoints_xy # False | |
| self.fix_refpoints_hw = fix_refpoints_hw # -1 | |
| # for dn training | |
| self.dn_number = dn_number | |
| self.dn_box_noise_scale = dn_box_noise_scale | |
| self.dn_label_noise_ratio = dn_label_noise_ratio | |
| self.dn_batch_gt_fuse = dn_batch_gt_fuse | |
| self.dn_labelbook_size = dn_labelbook_size | |
| self.dn_attn_mask_type_list = dn_attn_mask_type_list | |
| assert all([ | |
| i in ['match2dn', 'dn2dn', 'group2group'] | |
| for i in dn_attn_mask_type_list | |
| ]) | |
| assert not dn_batch_gt_fuse | |
| # build human body | |
| # if train: | |
| # self.body_model = build_body_model(body_model) | |
| if inference: | |
| body_model=dict( | |
| type='smplx', | |
| keypoint_src='smplx', | |
| num_expression_coeffs=10, | |
| num_betas=10, | |
| keypoint_dst='smplx', | |
| model_path='data/body_models/smplx', | |
| use_pca=False, | |
| use_face_contour=True) | |
| self.body_model = build_body_model(body_model) | |
| for param in self.body_model.parameters(): | |
| param.requires_grad = False | |
| # prepare input projection layers | |
| if num_feature_levels > 1: | |
| num_backbone_outs = len(backbone.num_channels) # 3 | |
| input_proj_list = [] | |
| for i in range(num_backbone_outs): | |
| in_channels = backbone.num_channels[i] | |
| input_proj_list.append( | |
| nn.Sequential( | |
| nn.Conv2d(in_channels, hidden_dim, kernel_size=1), | |
| nn.GroupNorm(32, hidden_dim), | |
| )) | |
| for _ in range(num_feature_levels - num_backbone_outs): | |
| input_proj_list.append( | |
| nn.Sequential( | |
| nn.Conv2d(in_channels, | |
| hidden_dim, | |
| kernel_size=3, | |
| stride=2, | |
| padding=1), | |
| nn.GroupNorm(32, hidden_dim), | |
| )) | |
| in_channels = hidden_dim | |
| self.input_proj = nn.ModuleList(input_proj_list) | |
| else: | |
| assert two_stage_type == 'no', 'two_stage_type should be no if num_feature_levels=1 !!!' | |
| self.input_proj = nn.ModuleList([ | |
| nn.Sequential( | |
| nn.Conv2d(backbone.num_channels[-1], | |
| hidden_dim, | |
| kernel_size=1), | |
| nn.GroupNorm(32, hidden_dim), | |
| ) | |
| ]) | |
| self.backbone = backbone | |
| self.aux_loss = aux_loss | |
| self.box_pred_damping = box_pred_damping = None | |
| self.iter_update = iter_update | |
| assert iter_update, 'Why not iter_update?' | |
| # prepare pred layers | |
| self.dec_pred_class_embed_share = dec_pred_class_embed_share # false | |
| self.dec_pred_bbox_embed_share = dec_pred_bbox_embed_share # false | |
| # 1.1 prepare class & box embed | |
| _class_embed = nn.Linear(hidden_dim, | |
| num_classes, | |
| bias=(not cls_no_bias)) | |
| if not cls_no_bias: | |
| prior_prob = 0.01 | |
| bias_value = -math.log((1 - prior_prob) / prior_prob) | |
| _class_embed.bias.data = torch.ones(self.num_classes) * bias_value | |
| # 1.2 box embed layer list | |
| if dec_pred_class_embed_share: | |
| class_embed_layerlist = [ | |
| _class_embed for i in range(transformer.num_decoder_layers) | |
| ] | |
| else: | |
| class_embed_layerlist = [ | |
| copy.deepcopy(_class_embed) | |
| for i in range(transformer.num_decoder_layers) | |
| ] | |
| ########################################################################### | |
| # body bbox + l/r hand box + face box | |
| ########################################################################### | |
| # 1.1 body bbox embed | |
| _bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3) | |
| nn.init.constant_(_bbox_embed.layers[-1].weight.data, 0) | |
| nn.init.constant_(_bbox_embed.layers[-1].bias.data, 0) | |
| # 1.2 body bbox embed layer list | |
| self.num_group = num_group | |
| if dec_pred_bbox_embed_share: | |
| box_body_embed_layerlist = [ | |
| _bbox_embed for i in range(transformer.num_decoder_layers) | |
| ] | |
| else: | |
| box_body_embed_layerlist = [ | |
| copy.deepcopy(_bbox_embed) | |
| for i in range(transformer.num_decoder_layers) | |
| ] | |
| # 2.1 lhand bbox embed | |
| _bbox_hand_embed = MLP(hidden_dim, hidden_dim, 2, 3) # predicts the 2-dim box center offset; hw comes from the separate head below | |
| nn.init.constant_(_bbox_hand_embed.layers[-1].weight.data, 0) | |
| nn.init.constant_(_bbox_hand_embed.layers[-1].bias.data, 0) | |
| _bbox_hand_hw_embed = MLP(hidden_dim, hidden_dim, 2, 3) | |
| nn.init.constant_(_bbox_hand_hw_embed.layers[-1].weight.data, 0) | |
| nn.init.constant_(_bbox_hand_hw_embed.layers[-1].bias.data, 0) | |
| # 2.2 lhand bbox embed layer list | |
| if dec_pred_pose_embed_share: | |
| box_hand_embed_layerlist = \ | |
| [_bbox_hand_embed for i in range(transformer.num_decoder_layers - num_box_decoder_layers+1)] | |
| else: | |
| box_hand_embed_layerlist = [ | |
| copy.deepcopy(_bbox_hand_embed) | |
| for i in range(transformer.num_decoder_layers - | |
| num_box_decoder_layers + 1) | |
| ] | |
| if dec_pred_pose_embed_share: | |
| box_hand_hw_embed_layerlist = [ | |
| _bbox_hand_hw_embed for i in range( | |
| transformer.num_decoder_layers - num_box_decoder_layers) | |
| ] | |
| else: | |
| box_hand_hw_embed_layerlist = [ | |
| copy.deepcopy(_bbox_hand_hw_embed) | |
| for i in range(transformer.num_decoder_layers - | |
| num_box_decoder_layers) | |
| ] | |
| # 4.1 face bbox embed | |
| _bbox_face_embed = MLP(hidden_dim, hidden_dim, 2, 3) | |
| nn.init.constant_(_bbox_face_embed.layers[-1].weight.data, 0) | |
| nn.init.constant_(_bbox_face_embed.layers[-1].bias.data, 0) | |
| _bbox_face_hw_embed = MLP(hidden_dim, hidden_dim, 2, 3) | |
| nn.init.constant_(_bbox_face_hw_embed.layers[-1].weight.data, 0) | |
| nn.init.constant_(_bbox_face_hw_embed.layers[-1].bias.data, 0) | |
| # 4.2 face bbox embed layer list | |
| if dec_pred_pose_embed_share: | |
| box_face_embed_layerlist = [ | |
| _bbox_face_embed for i in range( | |
| transformer.num_decoder_layers - num_box_decoder_layers + 1) | |
| ] | |
| else: | |
| box_face_embed_layerlist = [ | |
| copy.deepcopy(_bbox_face_embed) | |
| for i in range(transformer.num_decoder_layers - | |
| num_box_decoder_layers + 1) | |
| ] | |
| if dec_pred_pose_embed_share: | |
| box_face_hw_embed_layerlist = [ | |
| _bbox_face_hw_embed for i in range( | |
| transformer.num_decoder_layers - num_box_decoder_layers)] | |
| else: | |
| box_face_hw_embed_layerlist = [ | |
| copy.deepcopy(_bbox_face_hw_embed) | |
| for i in range(transformer.num_decoder_layers - | |
| num_box_decoder_layers) | |
| ] | |
| # 1. smpl pose embed | |
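| # The body-pose head regresses 6D rotations for every body joint: | |
| # SMPL has 24 joints (incl. global orient); for SMPL-X only the 22 | |
| # body joints are handled here, hands and jaw having separate heads. | |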
| if body_model['type'].upper() == 'SMPL': | |
| self.body_model_joint_num = 24 | |
| elif body_model['type'].upper() == 'SMPLX': | |
| self.body_model_joint_num = 22 | |
| else: | |
| raise ValueError( | |
| f'Only SMPL and SMPLX are supported, but got {body_model["type"]}') | |
| _smpl_pose_embed = MLP(hidden_dim * 4, hidden_dim, self.body_model_joint_num * 6, 3) | |
| nn.init.constant_(_smpl_pose_embed.layers[-1].weight.data, 0) | |
| nn.init.constant_(_smpl_pose_embed.layers[-1].bias.data, 0) | |
| if dec_pred_bbox_embed_share: | |
| smpl_pose_embed_layerlist = [ | |
| _smpl_pose_embed | |
| for i in range(transformer.num_decoder_layers - | |
| num_box_decoder_layers) | |
| ] | |
| else: | |
| smpl_pose_embed_layerlist = [ | |
| copy.deepcopy(_smpl_pose_embed) | |
| for i in range(transformer.num_decoder_layers - | |
| num_box_decoder_layers) | |
| ] | |
| # 2. smpl betas embed | |
| _smpl_beta_embed = MLP(hidden_dim * 4, hidden_dim, 10, 3) | |
| nn.init.constant_(_smpl_beta_embed.layers[-1].weight.data, 0) | |
| nn.init.constant_(_smpl_beta_embed.layers[-1].bias.data, 0) | |
| if dec_pred_bbox_embed_share: | |
| smpl_beta_embed_layerlist = [ | |
| _smpl_beta_embed | |
| for i in range(transformer.num_decoder_layers - | |
| num_box_decoder_layers) | |
| ] | |
| else: | |
| smpl_beta_embed_layerlist = [ | |
| copy.deepcopy(_smpl_beta_embed) | |
| for i in range(transformer.num_decoder_layers - | |
| num_box_decoder_layers) | |
| ] | |
| # 3. smpl cam embed | |
| _cam_embed = MLP(hidden_dim * 4, hidden_dim, 3, 3) | |
| nn.init.constant_(_cam_embed.layers[-1].weight.data, 0) | |
| nn.init.constant_(_cam_embed.layers[-1].bias.data, 0) | |
| if dec_pred_bbox_embed_share: | |
| cam_embed_layerlist = [ | |
| _cam_embed for i in range(transformer.num_decoder_layers - | |
| num_box_decoder_layers) | |
| ] | |
| else: | |
| cam_embed_layerlist = [ | |
| copy.deepcopy(_cam_embed) | |
| for i in range(transformer.num_decoder_layers - | |
| num_box_decoder_layers) | |
| ] | |
| ########################################################################### | |
| # smplx body pose + hand pose + expression + betas + kp2d + kp3d + cam | |
| ########################################################################### | |
| # 2. smplx hand pose embed | |
| _smplx_hand_pose_embed_layer_2_3 = \ | |
| MLP(hidden_dim * 2, hidden_dim, 15 * 6, 3) | |
| nn.init.constant_(_smplx_hand_pose_embed_layer_2_3.layers[-1].weight.data, 0) | |
| nn.init.constant_(_smplx_hand_pose_embed_layer_2_3.layers[-1].bias.data, 0) | |
| _smplx_hand_pose_embed_layer_4_5 = \ | |
| MLP(hidden_dim * 2, hidden_dim, 15 * 6, 3) | |
| nn.init.constant_(_smplx_hand_pose_embed_layer_4_5.layers[-1].weight.data, 0) | |
| nn.init.constant_(_smplx_hand_pose_embed_layer_4_5.layers[-1].bias.data, 0) | |
| if dec_pred_bbox_embed_share: | |
| smplx_hand_pose_embed_layerlist = [ | |
| _smplx_hand_pose_embed_layer_2_3 | |
| if i<2 else _smplx_hand_pose_embed_layer_4_5 | |
| for i in range(transformer.num_decoder_layers - | |
| num_box_decoder_layers) | |
| ] | |
| else: | |
| smplx_hand_pose_embed_layerlist = [ | |
| copy.deepcopy(_smplx_hand_pose_embed_layer_2_3) | |
| if i<2 else copy.deepcopy(_smplx_hand_pose_embed_layer_4_5) | |
| for i in range(transformer.num_decoder_layers - | |
| num_box_decoder_layers) | |
| ] | |
| # 3. smplx face expression | |
| _smplx_expression_embed_layer_2_3 = \ | |
| MLP(hidden_dim*2, hidden_dim, 10, 3) | |
| nn.init.constant_(_smplx_expression_embed_layer_2_3.layers[-1].weight.data, 0) | |
| nn.init.constant_(_smplx_expression_embed_layer_2_3.layers[-1].bias.data, 0) | |
| _smplx_expression_embed_layer_4_5 = \ | |
| MLP(hidden_dim * 2, hidden_dim, 10, 3) | |
| nn.init.constant_(_smplx_expression_embed_layer_4_5.layers[-1].weight.data, 0) | |
| nn.init.constant_(_smplx_expression_embed_layer_4_5.layers[-1].bias.data, 0) | |
| if dec_pred_bbox_embed_share: | |
| smplx_expression_embed_layerlist = [ | |
| _smplx_expression_embed_layer_2_3 | |
| if i<2 else _smplx_expression_embed_layer_4_5 | |
| for i in range(transformer.num_decoder_layers - | |
| num_box_decoder_layers) | |
| ] | |
| else: | |
| smplx_expression_embed_layerlist = [ | |
| copy.deepcopy(_smplx_expression_embed_layer_2_3) | |
| if i<2 else copy.deepcopy(_smplx_expression_embed_layer_4_5) | |
| for i in range(transformer.num_decoder_layers - | |
| num_box_decoder_layers) | |
| ] | |
| # 4. smplx jaw pose embed | |
| _smplx_jaw_embed_2_3 = MLP(hidden_dim * 2, hidden_dim, 6, 3) | |
| nn.init.constant_(_smplx_jaw_embed_2_3.layers[-1].weight.data, 0) | |
| nn.init.constant_(_smplx_jaw_embed_2_3.layers[-1].bias.data, 0) | |
| _smplx_jaw_embed_4_5 = MLP(hidden_dim * 2, hidden_dim, 6, 3) | |
| nn.init.constant_(_smplx_jaw_embed_4_5.layers[-1].weight.data, 0) | |
| nn.init.constant_(_smplx_jaw_embed_4_5.layers[-1].bias.data, 0) | |
| if dec_pred_bbox_embed_share: | |
| smplx_jaw_embed_layerlist = [ | |
| _smplx_jaw_embed_2_3 if i<2 else _smplx_jaw_embed_4_5 | |
| for i in range( | |
| transformer.num_decoder_layers - num_box_decoder_layers) | |
| ] | |
| else: | |
| smplx_jaw_embed_layerlist = [ | |
| copy.deepcopy(_smplx_jaw_embed_2_3) | |
| if i<2 else copy.deepcopy(_smplx_jaw_embed_4_5) | |
| for i in range( | |
| transformer.num_decoder_layers - num_box_decoder_layers) | |
| ] | |
| self.bbox_embed = nn.ModuleList(box_body_embed_layerlist) | |
| self.class_embed = nn.ModuleList(class_embed_layerlist) | |
| self.transformer.decoder.bbox_embed = self.bbox_embed | |
| self.transformer.decoder.class_embed = self.class_embed | |
| # smpl | |
| self.smpl_pose_embed = nn.ModuleList(smpl_pose_embed_layerlist) | |
| self.smpl_beta_embed = nn.ModuleList(smpl_beta_embed_layerlist) | |
| self.smpl_cam_embed = nn.ModuleList(cam_embed_layerlist) | |
| # smplx lhand kp | |
| self.bbox_hand_embed = nn.ModuleList(box_hand_embed_layerlist) | |
| self.bbox_hand_hw_embed = nn.ModuleList(box_hand_hw_embed_layerlist) | |
| self.transformer.decoder.bbox_hand_embed = self.bbox_hand_embed | |
| self.transformer.decoder.bbox_hand_hw_embed = self.bbox_hand_hw_embed | |
| # smplx face kp | |
| self.bbox_face_embed = nn.ModuleList(box_face_embed_layerlist) | |
| self.bbox_face_hw_embed = nn.ModuleList(box_face_hw_embed_layerlist) | |
| self.transformer.decoder.bbox_face_embed = self.bbox_face_embed | |
| self.transformer.decoder.bbox_face_hw_embed = self.bbox_face_hw_embed | |
| # smplx | |
| self.smpl_hand_pose_embed = nn.ModuleList(smplx_hand_pose_embed_layerlist) | |
| self.smpl_expr_embed = nn.ModuleList(smplx_expression_embed_layerlist) | |
| self.smpl_jaw_embed = nn.ModuleList(smplx_jaw_embed_layerlist) | |
| self.transformer.decoder.num_hand_face_decoder_layers = num_hand_face_decoder_layers | |
| self.transformer.decoder.num_box_decoder_layers = num_box_decoder_layers | |
| self.transformer.decoder.num_body_points = num_body_points | |
| self.transformer.decoder.num_hand_points = num_hand_points | |
| self.transformer.decoder.num_face_points = num_face_points | |
| # two stage | |
| self.two_stage_type = two_stage_type | |
| assert two_stage_type in [ | |
| 'no', 'standard' | |
| ], 'unknown param {} of two_stage_type'.format(two_stage_type) | |
| if two_stage_type != 'no': | |
| if two_stage_bbox_embed_share: | |
| assert dec_pred_class_embed_share and dec_pred_bbox_embed_share | |
| self.transformer.enc_out_bbox_embed = _bbox_embed | |
| else: | |
| self.transformer.enc_out_bbox_embed = copy.deepcopy( | |
| _bbox_embed) | |
| if two_stage_class_embed_share: | |
| assert dec_pred_class_embed_share and dec_pred_bbox_embed_share | |
| self.transformer.enc_out_class_embed = _class_embed | |
| else: | |
| self.transformer.enc_out_class_embed = copy.deepcopy( | |
| _class_embed) | |
| self.refpoint_embed = None | |
| self._reset_parameters() | |
| def get_camera_trans(self, cam_param, input_body_shape): | |
| # camera translation | |
| t_xy = cam_param[:, :2] | |
| gamma = torch.sigmoid(cam_param[:, 2]) # apply sigmoid to make it positive | |
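| # k_value follows from the pinhole model: if a person spanning a cube of | |
| # side camera_3d_size fills the input crop, depth ~= f * s / sqrt(H * W). | |
| # e.g. (illustrative numbers) f = 5000, s = 2.5 and a 512 x 384 input give | |
| # k_value ~= 28.2, so t_z = k_value * gamma stays in a plausible metric range. | |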
| k_value = torch.tensor( | |
| [ | |
| math.sqrt( | |
| self.focal_length[0] * self.focal_length[1] * self.camera_3d_size * self.camera_3d_size / | |
| (input_body_shape[0] * input_body_shape[1]) | |
| ) | |
| ], | |
| device=cam_param.device).view(-1) | |
| t_z = k_value * gamma | |
| cam_trans = torch.cat((t_xy, t_z[:, None]), 1) | |
| return cam_trans | |
| def _reset_parameters(self): | |
| # init input_proj | |
| for proj in self.input_proj: | |
| nn.init.xavier_uniform_(proj[0].weight, gain=1) | |
| nn.init.constant_(proj[0].bias, 0) | |
| def prepare_for_dn2(self, targets): | |
| if not self.training: | |
| device = targets[0]['boxes'].device | |
| bs = len(targets) | |
| num_points = 4 | |
| attn_mask2 = torch.zeros( | |
| bs, | |
| self.nheads, | |
| self.num_group * 4, | |
| self.num_group * 4, | |
| device=device, | |
| dtype=torch.bool) | |
| group_bbox_kpt = 4 | |
| # body bbox index | |
| kpt_index = [x for x in range(self.num_group * 4) if x % 4 in [0]] | |
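| # query layout: queries come in groups of 4 per candidate person, where | |
| # index % 4 == 0 is body, 1 is lhand, 2 is rhand, 3 is face; kpt_index thus | |
| # collects the body queries of all groups. attn_mask2 and attn_mask3 below | |
| # are built identically, presumably one per later keypoint decoder stage. | |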
| for matchj in range(self.num_group * 4): | |
| sj = (matchj // group_bbox_kpt) * group_bbox_kpt | |
| ej = (matchj // group_bbox_kpt + 1)*group_bbox_kpt | |
| # each instance's group of queries (body, lhand, rhand, face) attends only within its own group | |
| if sj > 0: | |
| attn_mask2[:, :, matchj, :sj] = True | |
| if ej < self.num_group * 4: | |
| attn_mask2[:, :, matchj, ej:] = True | |
| for match_x in range(self.num_group * 4): | |
| if match_x % group_bbox_kpt in [0, 1, 2, 3]: | |
| # every query may additionally attend to the body queries of all groups | |
| attn_mask2[:,:,match_x, kpt_index]=False | |
| num_points = 4 | |
| attn_mask3 = torch.zeros( | |
| bs, | |
| self.nheads, | |
| self.num_group * 4, | |
| self.num_group * 4, | |
| device=device, | |
| dtype=torch.bool) | |
| group_bbox_kpt = 4 | |
| kpt_index = [x for x in range(self.num_group * 4) if x % 4 in [0]] | |
| for matchj in range(self.num_group * 4): | |
| sj = (matchj // group_bbox_kpt) * group_bbox_kpt | |
| ej = (matchj // group_bbox_kpt + 1)*group_bbox_kpt | |
| # each instance's group of queries (body, lhand, rhand, face) attends only within its own group | |
| if sj > 0: | |
| attn_mask3[:, :, matchj, :sj] = True | |
| if ej < self.num_group * 4: | |
| attn_mask3[:, :, matchj, ej:] = True | |
| for match_x in range(self.num_group * 4): | |
| if match_x % group_bbox_kpt in [0, 1, 2, 3]: | |
| # every query may additionally attend to the body queries of all groups | |
| attn_mask3[:, :, match_x, kpt_index] = False | |
| attn_mask2 = attn_mask2.flatten(0, 1) | |
| attn_mask3 = attn_mask3.flatten(0, 1) | |
| return None, None, None, attn_mask2, attn_mask3, None | |
| # targets, dn_scalar, noise_scale = dn_args | |
| device = targets[0]['boxes'].device | |
| bs = len(targets) | |
| dn_number = self.dn_number # 100 | |
| dn_box_noise_scale = self.dn_box_noise_scale # 0.4 | |
| dn_label_noise_ratio = self.dn_label_noise_ratio # 0.5 | |
| # gather gt boxes and labels | |
| gt_boxes = [t['boxes'] for t in targets] | |
| gt_labels = [t['labels'] for t in targets] | |
| gt_keypoints = [t['keypoints'] for t in targets] | |
| # repeat them | |
| def get_indices_for_repeat(now_num, target_num, device='cuda'): | |
| """ | |
| Input: | |
| - now_num: int | |
| - target_num: int | |
| Output: | |
| - indices: tensor[target_num] | |
| """ | |
| out_indice = [] | |
| base_indice = torch.arange(now_num).to(device) | |
| multiplier = target_num // now_num | |
| out_indice.append(base_indice.repeat(multiplier)) | |
| residue = target_num % now_num | |
| out_indice.append(base_indice[torch.randint(0, | |
| now_num, (residue, ), | |
| device=device)]) | |
| return torch.cat(out_indice) | |
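| # e.g. now_num=3, target_num=8 -> [0, 1, 2, 0, 1, 2, r0, r1], where r0, r1 | |
| # are uniform random draws from {0, 1, 2}: every gt index is repeated | |
| # floor(8 / 3) = 2 times and the 2 leftover slots are filled at random. | |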
| if self.dn_batch_gt_fuse: | |
| raise NotImplementedError | |
| gt_boxes_bsall = torch.cat(gt_boxes) # num_boxes, 4 | |
| gt_labels_bsall = torch.cat(gt_labels) | |
| num_gt_bsall = gt_boxes_bsall.shape[0] | |
| if num_gt_bsall > 0: | |
| indices = get_indices_for_repeat(num_gt_bsall, dn_number, | |
| device) | |
| gt_boxes_expand = gt_boxes_bsall[indices][None].repeat( | |
| bs, 1, 1) # bs, num_dn, 4 | |
| gt_labels_expand = gt_labels_bsall[indices][None].repeat( | |
| bs, 1) # bs, num_dn | |
| else: | |
| # all negative samples when no gt boxes | |
| gt_boxes_expand = torch.rand(bs, dn_number, 4, device=device) | |
| gt_labels_expand = torch.ones( | |
| bs, dn_number, dtype=torch.int64, device=device) * int( | |
| self.num_classes) | |
| else: | |
| gt_boxes_expand = [] | |
| gt_labels_expand = [] | |
| gt_keypoints_expand = [] | |
| for idx, (gt_boxes_i, gt_labels_i, gt_keypoint_i) in enumerate( | |
| zip(gt_boxes, gt_labels, gt_keypoints)): # idx -> batch id | |
| num_gt_i = gt_boxes_i.shape[0] # instance num | |
| if num_gt_i > 0: | |
| indices = get_indices_for_repeat(num_gt_i, dn_number, | |
| device) | |
| gt_boxes_expand_i = gt_boxes_i[indices] # num_dn, 4 | |
| gt_labels_expand_i = gt_labels_i[indices] # add smpl | |
| gt_keypoints_expand_i = gt_keypoint_i[indices] | |
| else: | |
| # all negative samples when no gt boxes | |
| gt_boxes_expand_i = torch.rand(dn_number, 4, device=device) | |
| gt_labels_expand_i = torch.ones( | |
| dn_number, dtype=torch.int64, device=device) * int( | |
| self.num_classes) | |
| gt_keypoints_expand_i = torch.rand(dn_number, | |
| self.num_body_points * | |
| 3, | |
| device=device) | |
| gt_boxes_expand.append(gt_boxes_expand_i) # add smpl | |
| gt_labels_expand.append(gt_labels_expand_i) | |
| gt_keypoints_expand.append(gt_keypoints_expand_i) | |
| gt_boxes_expand = torch.stack(gt_boxes_expand) | |
| gt_labels_expand = torch.stack(gt_labels_expand) | |
| gt_keypoints_expand = torch.stack(gt_keypoints_expand) | |
| known_boxes_expand = gt_boxes_expand.clone() | |
| known_labels_expand = gt_labels_expand.clone() | |
| # add noise | |
| if dn_label_noise_ratio > 0: | |
| prob = torch.rand_like(known_labels_expand.float()) | |
| chosen_indice = prob < dn_label_noise_ratio | |
| new_label = torch.randint_like( | |
| known_labels_expand[chosen_indice], 0, | |
| self.dn_labelbook_size) # randomly put a new one here | |
| known_labels_expand[chosen_indice] = new_label | |
| if dn_box_noise_scale > 0: | |
| diff = torch.zeros_like(known_boxes_expand) | |
| diff[..., :2] = known_boxes_expand[..., 2:] / 2 | |
| diff[..., 2:] = known_boxes_expand[..., 2:] | |
| known_boxes_expand += torch.mul( | |
| (torch.rand_like(known_boxes_expand) * 2 - 1.0), | |
| diff) * dn_box_noise_scale | |
| known_boxes_expand = known_boxes_expand.clamp(min=0.0, max=1.0) | |
| input_query_label = self.label_enc(known_labels_expand) | |
| input_query_bbox = inverse_sigmoid(known_boxes_expand) | |
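| # boxes live in sigmoid (normalized cxcywh) space, so inverse_sigmoid lifts | |
| # the noised boxes to logit space; decoder layers then refine them with | |
| # additive deltas before re-applying sigmoid. | |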
| # prepare mask | |
| if 'group2group' in self.dn_attn_mask_type_list: | |
| attn_mask = torch.zeros(bs, | |
| self.nheads, | |
| dn_number + self.num_queries, | |
| dn_number + self.num_queries, | |
| device=device, | |
| dtype=torch.bool) | |
| attn_mask[:, :, dn_number:, :dn_number] = True | |
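| # layout of the mask: rows/cols [0, dn_number) are denoising queries and | |
| # [dn_number, dn_number + num_queries) are matching queries. The line above | |
| # hides all dn queries from the matching queries; the loop below additionally | |
| # walls off each dn group so it only sees its own noised copy of the gt set. | |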
| for idx, (gt_boxes_i, gt_labels_i) in enumerate( | |
| zip(gt_boxes, gt_labels)): # for batch | |
| num_gt_i = gt_boxes_i.shape[0] | |
| if num_gt_i == 0: | |
| continue | |
| for matchi in range(dn_number): | |
| si = (matchi // num_gt_i) * num_gt_i | |
| ei = (matchi // num_gt_i + 1) * num_gt_i | |
| if si > 0: | |
| attn_mask[idx, :, matchi, :si] = True | |
| if ei < dn_number: | |
| attn_mask[idx, :, matchi, ei:dn_number] = True | |
| attn_mask = attn_mask.flatten(0, 1) | |
| if 'group2group' in self.dn_attn_mask_type_list: | |
| # self.num_body_points = self.num_body_points +3 | |
| num_points = 4 | |
| attn_mask2 = torch.zeros( | |
| bs, | |
| self.nheads, | |
| dn_number + self.num_group * 4, | |
| dn_number + self.num_group * 4, | |
| device=device, | |
| dtype=torch.bool) | |
| attn_mask2[:, :, dn_number:, :dn_number] = True | |
| group_bbox_kpt = 4 | |
| for matchj in range(self.num_group * 4): | |
| sj = (matchj // group_bbox_kpt) * group_bbox_kpt | |
| ej = (matchj // group_bbox_kpt + 1)*group_bbox_kpt | |
| # each instance's queries attend only to their own body/hand/face group | |
| if sj > 0: | |
| attn_mask2[:, :, dn_number:, dn_number:][:, :, matchj, :sj] = True | |
| if ej < self.num_group * 4: | |
| attn_mask2[:, :, dn_number:, dn_number:][:, :, matchj, ej:] = True | |
| # body bbox index | |
| kpt_index = [x for x in range(self.num_group * 4) if x % 4 in [0]] | |
| for match_x in range(self.num_group * 4): | |
| if match_x % group_bbox_kpt in [0, 1, 2, 3]: | |
| # every query may additionally attend to the body queries of all instances | |
| attn_mask2[:, :, dn_number:, dn_number:][:, :, match_x, kpt_index]=False | |
| for idx, (gt_boxes_i, gt_labels_i) in enumerate(zip(gt_boxes, gt_labels)): | |
| num_gt_i = gt_boxes_i.shape[0] | |
| if num_gt_i == 0: | |
| continue | |
| for matchi in range(dn_number): | |
| si = (matchi // num_gt_i) * num_gt_i | |
| ei = (matchi // num_gt_i + 1) * num_gt_i | |
| if si > 0: | |
| attn_mask2[idx, :, matchi, :si] = True | |
| if ei < dn_number: | |
| attn_mask2[idx, :, matchi, ei:dn_number] = True | |
| attn_mask2 = attn_mask2.flatten(0, 1) | |
| if 'group2group' in self.dn_attn_mask_type_list: | |
| num_points = 4 | |
| attn_mask3 = torch.zeros( | |
| bs, | |
| self.nheads, | |
| dn_number + self.num_group * 4, dn_number + self.num_group * 4, | |
| device=device, dtype=torch.bool) | |
| attn_mask3[:, :, dn_number:, :dn_number] = True | |
| group_bbox_kpt = 4 | |
| for matchj in range(self.num_group * 4): | |
| sj = (matchj // group_bbox_kpt) * group_bbox_kpt | |
| ej = (matchj // group_bbox_kpt + 1)*group_bbox_kpt | |
| # each instance's queries attend only to their own body/hand/face group | |
| if sj > 0: | |
| attn_mask3[:, :, dn_number:, dn_number:][:, :, matchj, :sj] = True | |
| if ej < self.num_group * 4: | |
| attn_mask3[:, :, dn_number:, dn_number:][:, :, matchj, ej:] = True | |
| kpt_index = [x for x in range(self.num_group * 4) if x % 4 in [0]] | |
| for match_x in range(self.num_group * 4): | |
| if match_x % group_bbox_kpt in [0, 1, 2, 3]: | |
| # every query may additionally attend to the body queries of all instances | |
| attn_mask3[:, :, dn_number:, dn_number:][:, :, match_x, kpt_index]=False | |
| for idx, (gt_boxes_i, gt_labels_i) in enumerate(zip(gt_boxes, gt_labels)): | |
| num_gt_i = gt_boxes_i.shape[0] | |
| if num_gt_i == 0: | |
| continue | |
| for matchi in range(dn_number): | |
| si = (matchi // num_gt_i) * num_gt_i | |
| ei = (matchi // num_gt_i + 1) * num_gt_i | |
| if si > 0: | |
| attn_mask3[idx, :, matchi, :si] = True | |
| if ei < dn_number: | |
| attn_mask3[idx, :, matchi, ei:dn_number] = True | |
| attn_mask3 = attn_mask3.flatten(0, 1) | |
| mask_dict = { | |
| 'pad_size': dn_number, | |
| 'known_bboxs': gt_boxes_expand, | |
| 'known_labels': gt_labels_expand, | |
| 'known_keypoints': gt_keypoints_expand | |
| } | |
| return input_query_label, input_query_bbox, attn_mask, attn_mask2, attn_mask3, mask_dict | |
| def dn_post_process2(self, outputs_class, outputs_coord, mask_dict): | |
| if mask_dict and mask_dict['pad_size'] > 0: | |
| output_known_class = [ | |
| outputs_class_i[:, :mask_dict['pad_size'], :] | |
| for outputs_class_i in outputs_class | |
| ] | |
| output_known_coord = [ | |
| outputs_coord_i[:, :mask_dict['pad_size'], :] | |
| for outputs_coord_i in outputs_coord | |
| ] | |
| outputs_class = [ | |
| outputs_class_i[:, mask_dict['pad_size']:, :] | |
| for outputs_class_i in outputs_class | |
| ] | |
| outputs_coord = [ | |
| outputs_coord_i[:, mask_dict['pad_size']:, :] | |
| for outputs_coord_i in outputs_coord | |
| ] | |
| mask_dict.update({ | |
| 'output_known_coord': output_known_coord, | |
| 'output_known_class': output_known_class | |
| }) | |
| return outputs_class, outputs_coord | |
| def forward(self, data_batch: NestedTensor, targets: List = None): | |
| """The forward expects a NestedTensor, which consists of: | |
| - samples.tensor: batched images, of shape [batch_size x 3 x H x W] | |
| - samples.mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels | |
| It returns a dict with the following elements: | |
| - "pred_logits": the classification logits (including no-object) for all queries. | |
| Shape= [batch_size x num_queries x num_classes] | |
| - "pred_boxes": The normalized boxes coordinates for all queries, represented as | |
| (center_x, center_y, width, height). These values are normalized in [0, 1], | |
| relative to the size of each individual image (disregarding possible padding). | |
| See PostProcess for information on how to retrieve the unnormalized bounding box. | |
| - "aux_outputs": Optional, only returned when auxilary losses are activated. It is a list of | |
| dictionnaries containing the two above keys for each decoder layer. | |
| """ | |
| if isinstance(data_batch, dict): | |
| samples, targets = self.prepare_targets(data_batch) | |
| # import pdb; pdb.set_trace() | |
| elif isinstance(data_batch, (list, torch.Tensor)): | |
| samples = nested_tensor_from_tensor_list(data_batch) | |
| else: | |
| samples = data_batch | |
| features, poss = self.backbone(samples) | |
| srcs = [] | |
| masks = [] | |
| for l, feat in enumerate(features): # len(features) == 3 | |
| src, mask = feat.decompose() | |
| srcs.append(self.input_proj[l](src)) | |
| masks.append(mask) | |
| assert mask is not None | |
| if self.num_feature_levels > len(srcs): | |
| _len_srcs = len(srcs) | |
| for l in range(_len_srcs, self.num_feature_levels): | |
| if l == _len_srcs: | |
| src = self.input_proj[l](features[-1].tensors) | |
| else: | |
| src = self.input_proj[l](srcs[-1]) | |
| m = samples.mask | |
| mask = F.interpolate(m[None].float(), | |
| size=src.shape[-2:]).to(torch.bool)[0] | |
| pos_l = self.backbone[1](NestedTensor(src, mask)).to(src.dtype) | |
| srcs.append(src) | |
| masks.append(mask) | |
| poss.append(pos_l) | |
| if self.dn_number > 0 or targets is not None: | |
| input_query_label, input_query_bbox, attn_mask, attn_mask2, attn_mask3, mask_dict =\ | |
| self.prepare_for_dn2(targets) | |
| else: | |
| assert targets is None | |
| input_query_bbox = input_query_label = attn_mask = attn_mask2 = attn_mask3 = mask_dict = None | |
| hs, reference, hs_enc, ref_enc, init_box_proposal = self.transformer( | |
| srcs, masks, input_query_bbox, poss, input_query_label, attn_mask, | |
| attn_mask2, attn_mask3) | |
| # update human boxes | |
| effective_dn_number = self.dn_number if self.training else 0 | |
| outputs_body_bbox_list = [] | |
| outputs_class = [] | |
| for dec_lid, (layer_ref_sig, layer_body_bbox_embed, layer_cls_embed, | |
| layer_hs) in enumerate( | |
| zip(reference[:-1], self.bbox_embed, | |
| self.class_embed, hs)): | |
| if dec_lid < self.num_box_decoder_layers: | |
| # human det | |
| layer_delta_unsig = layer_body_bbox_embed(layer_hs) | |
| layer_body_box_outputs_unsig = \ | |
| layer_delta_unsig + inverse_sigmoid(layer_ref_sig) | |
| layer_body_box_outputs_unsig = layer_body_box_outputs_unsig.sigmoid() | |
| layer_cls = layer_cls_embed(layer_hs) | |
| outputs_body_bbox_list.append(layer_body_box_outputs_unsig) | |
| outputs_class.append(layer_cls) | |
| elif dec_lid < self.num_box_decoder_layers + 2: | |
| bs = layer_ref_sig.shape[0] | |
| # dn body bbox | |
| layer_hs_body_bbox_dn = layer_hs[:, :effective_dn_number, :] # dn content query | |
| reference_before_sigmoid_body_bbox_dn = layer_ref_sig[:, :effective_dn_number, :] # dn position query | |
| layer_body_box_delta_unsig_dn = layer_body_bbox_embed(layer_hs_body_bbox_dn) | |
| layer_body_box_outputs_unsig_dn = layer_body_box_delta_unsig_dn + inverse_sigmoid( | |
| reference_before_sigmoid_body_bbox_dn) | |
| layer_body_box_outputs_unsig_dn = layer_body_box_outputs_unsig_dn.sigmoid() | |
| # norm body bbox | |
| layer_hs_body_bbox_norm = layer_hs[:, effective_dn_number:, :][ | |
| :, 0::(self.num_body_points + 4), :] # norm content query | |
| reference_before_sigmoid_body_bbox_norm = layer_ref_sig[:, effective_dn_number:, :][ | |
| :, 0::(self.num_body_points+ 4), :] # norm position query | |
| layer_body_box_delta_unsig_norm = layer_body_bbox_embed(layer_hs_body_bbox_norm) | |
| layer_body_box_outputs_unsig_norm = layer_body_box_delta_unsig_norm + inverse_sigmoid( | |
| reference_before_sigmoid_body_bbox_norm) | |
| layer_body_box_outputs_unsig_norm = layer_body_box_outputs_unsig_norm.sigmoid() | |
| layer_body_box_outputs_unsig = torch.cat( | |
| (layer_body_box_outputs_unsig_dn, layer_body_box_outputs_unsig_norm), dim=1) | |
| # classification | |
| layer_cls_dn = layer_cls_embed(layer_hs_body_bbox_dn) | |
| layer_cls_norm = layer_cls_embed(layer_hs_body_bbox_norm) | |
| layer_cls = torch.cat((layer_cls_dn, layer_cls_norm), dim=1) | |
| outputs_class.append(layer_cls) | |
| outputs_body_bbox_list.append(layer_body_box_outputs_unsig) | |
| else: | |
| bs = layer_ref_sig.shape[0] | |
| # dn body bbox | |
| layer_hs_body_bbox_dn = layer_hs[:, :effective_dn_number, :] # dn content query | |
| reference_before_sigmoid_body_bbox_dn = layer_ref_sig[:, :effective_dn_number, :] # dn position query | |
| layer_body_box_delta_unsig_dn = layer_body_bbox_embed(layer_hs_body_bbox_dn) | |
| layer_body_box_outputs_unsig_dn = layer_body_box_delta_unsig_dn + inverse_sigmoid( | |
| reference_before_sigmoid_body_bbox_dn) | |
| layer_body_box_outputs_unsig_dn = layer_body_box_outputs_unsig_dn.sigmoid() | |
| # norm body bbox | |
| layer_hs_body_bbox_norm = layer_hs[:, effective_dn_number:, :][ | |
| :, 0::(self.num_whole_body_points + 4), :] # norm content query | |
| reference_before_sigmoid_body_bbox_norm = layer_ref_sig[:,effective_dn_number:, :][ | |
| :, 0::(self.num_whole_body_points + 4), :] # norm position query | |
| layer_body_box_delta_unsig_norm = layer_body_bbox_embed(layer_hs_body_bbox_norm) | |
| layer_body_box_outputs_unsig_norm = layer_body_box_delta_unsig_norm + inverse_sigmoid( | |
| reference_before_sigmoid_body_bbox_norm) | |
| layer_body_box_outputs_unsig_norm = layer_body_box_outputs_unsig_norm.sigmoid() | |
| layer_body_box_outputs_unsig = torch.cat( | |
| (layer_body_box_outputs_unsig_dn, layer_body_box_outputs_unsig_norm), dim=1) | |
| # classification | |
| layer_cls_dn = layer_cls_embed(layer_hs_body_bbox_dn) | |
| layer_cls_norm = layer_cls_embed(layer_hs_body_bbox_norm) | |
| layer_cls = torch.cat((layer_cls_dn, layer_cls_norm), dim=1) | |
| outputs_class.append(layer_cls) | |
| outputs_body_bbox_list.append(layer_body_box_outputs_unsig) | |
| # update hand and face boxes | |
| outputs_lhand_bbox_list = [] | |
| outputs_rhand_bbox_list = [] | |
| outputs_face_bbox_list = [] | |
| # update keypoints boxes | |
| outputs_body_keypoints_list = [] | |
| outputs_body_keypoints_hw = [] | |
| outputs_lhand_keypoints_list = [] | |
| outputs_lhand_keypoints_hw = [] | |
| outputs_rhand_keypoints_list = [] | |
| outputs_rhand_keypoints_hw = [] | |
| outputs_face_keypoints_list = [] | |
| outputs_face_keypoints_hw = [] | |
| outputs_smpl_pose_list = [] | |
| outputs_smpl_lhand_pose_list = [] | |
| outputs_smpl_rhand_pose_list = [] | |
| outputs_smpl_expr_list = [] | |
| outputs_smpl_jaw_pose_list = [] | |
| outputs_smpl_beta_list = [] | |
| outputs_smpl_cam_list = [] | |
| outputs_smpl_kp2d_list = [] | |
| outputs_smpl_kp3d_list = [] | |
| outputs_smpl_verts_list = [] | |
| # smpl pose | |
| # body box, kps, lhand box | |
| body_index = [0, 1, 2, 3] | |
| smpl_pose_index = [ | |
| x for x in range(self.num_group * 4) if (x % 4 in body_index)] | |
| # smpl lhand | |
| lhand_index = [0, 1] | |
| smpl_lhand_pose_index = [ | |
| x for x in range(self.num_group * 4) if (x % 4 in lhand_index)] | |
| # smpl rhand | |
| rhand_index = [0, 2] | |
| smpl_rhand_pose_index = [ | |
| x for x in range(self.num_group * 4) if (x % 4 in rhand_index)] | |
| # smpl face | |
| face_index = [0, 3] | |
| smpl_face_pose_index = [ | |
| x for x in range(self.num_group * 4) if (x % 4 in face_index)] | |
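| # feature pooling per SMPL-X head: body pose / betas / cam read all 4 queries | |
| # of a group (feat_dim * 4); each hand pose reads body + that hand's query | |
| # (feat_dim * 2); expression and jaw read body + face (feat_dim * 2). | |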
| for dec_lid, (layer_ref_sig, layer_hs) in enumerate(zip(reference[:-1], hs)): | |
| if dec_lid < self.num_box_decoder_layers: | |
| assert isinstance(layer_hs, torch.Tensor) | |
| bs = layer_hs.shape[0] | |
| layer_body_kps_res = layer_hs.new_zeros( | |
| (bs, self.num_queries, | |
| self.num_body_points * 3)) # [bs, num_queries, num_body_points * 3] | |
| outputs_body_keypoints_list.append(layer_body_kps_res) | |
| # lhand | |
| layer_lhand_bbox_res = layer_hs.new_zeros( | |
| (bs, self.num_queries, 4)) # [bs, num_queries, 4] | |
| outputs_lhand_bbox_list.append(layer_lhand_bbox_res) | |
| layer_lhand_kps_res = layer_hs.new_zeros( | |
| (bs, self.num_queries, | |
| self.num_hand_points * 3)) # [bs, num_queries, num_hand_points * 3] | |
| outputs_lhand_keypoints_list.append(layer_lhand_kps_res) | |
| # rhand | |
| layer_rhand_bbox_res = layer_hs.new_zeros( | |
| (bs, self.num_queries, 4)) # [bs, num_queries, 4] | |
| outputs_rhand_bbox_list.append(layer_rhand_bbox_res) | |
| layer_rhand_kps_res = layer_hs.new_zeros( | |
| (bs, self.num_queries, | |
| self.num_hand_points * 3)) # [bs, num_queries, num_hand_points * 3] | |
| outputs_rhand_keypoints_list.append(layer_rhand_kps_res) | |
| # face | |
| layer_face_bbox_res = layer_hs.new_zeros( | |
| (bs, self.num_queries, 4)) # [bs, num_queries, 4] | |
| outputs_face_bbox_list.append(layer_face_bbox_res) | |
| layer_face_kps_res = layer_hs.new_zeros( | |
| (bs, self.num_queries, | |
| self.num_face_points * 3)) # [bs, num_queries, num_face_points * 3] | |
| outputs_face_keypoints_list.append(layer_face_kps_res) | |
| # smpl or smplx | |
| smpl_pose = layer_hs.new_zeros((bs, self.num_queries, self.body_model_joint_num * 3)) | |
| smpl_rhand_pose = layer_hs.new_zeros( | |
| (bs, self.num_queries, 15 * 3)) | |
| smpl_lhand_pose = layer_hs.new_zeros( | |
| (bs, self.num_queries, 15 * 3)) | |
| smpl_expr = layer_hs.new_zeros((bs, self.num_queries, 10)) | |
| smpl_jaw_pose = layer_hs.new_zeros((bs, self.num_queries, 6)) | |
| smpl_beta = layer_hs.new_zeros((bs, self.num_queries, 10)) | |
| smpl_cam = layer_hs.new_zeros((bs, self.num_queries, 3)) | |
| # smpl_kp2d = layer_hs.new_zeros((bs, self.num_queries, self.num_body_points,3)) | |
| smpl_kp3d = layer_hs.new_zeros( | |
| (bs, self.num_queries, self.num_body_points, 4)) | |
| outputs_smpl_pose_list.append(smpl_pose) | |
| outputs_smpl_rhand_pose_list.append(smpl_rhand_pose) | |
| outputs_smpl_lhand_pose_list.append(smpl_lhand_pose) | |
| outputs_smpl_expr_list.append(smpl_expr) | |
| outputs_smpl_jaw_pose_list.append(smpl_jaw_pose) | |
| outputs_smpl_beta_list.append(smpl_beta) | |
| outputs_smpl_cam_list.append(smpl_cam) | |
| # outputs_smpl_kp2d_list.append(smpl_kp2d) | |
| outputs_smpl_kp3d_list.append(smpl_kp3d) | |
| elif dec_lid < self.num_box_decoder_layers +2: | |
| bs = layer_ref_sig.shape[0] | |
| # lhand bbox | |
| layer_hs_lhand_bbox = \ | |
| layer_hs[:, effective_dn_number:, :][:, 1::4, :] | |
| delta_lhand_bbox_xy_unsig = self.bbox_hand_embed[dec_lid - self.num_box_decoder_layers](layer_hs_lhand_bbox) | |
| layer_ref_sig_lhand_bbox = \ | |
| layer_ref_sig[:,effective_dn_number:, :][:, 1::4, :].clone() | |
| layer_ref_unsig_lhand_bbox = inverse_sigmoid(layer_ref_sig_lhand_bbox) | |
| delta_lhand_bbox_hw_unsig = self.bbox_hand_hw_embed[ | |
| dec_lid-self.num_box_decoder_layers](layer_hs_lhand_bbox) | |
| layer_ref_unsig_lhand_bbox[..., :2] += delta_lhand_bbox_xy_unsig[..., :2] | |
| layer_ref_unsig_lhand_bbox[..., 2:] += delta_lhand_bbox_hw_unsig | |
| layer_ref_sig_lhand_bbox = layer_ref_unsig_lhand_bbox.sigmoid() | |
| outputs_lhand_bbox_list.append(layer_ref_sig_lhand_bbox) | |
| # rhand bbox | |
| layer_hs_rhand_bbox = \ | |
| layer_hs[:, effective_dn_number:, :][:, 2::4, :] | |
| delta_rhand_bbox_xy_unsig = self.bbox_hand_embed[ | |
| dec_lid - self.num_box_decoder_layers](layer_hs_rhand_bbox) | |
| layer_ref_sig_rhand_bbox = \ | |
| layer_ref_sig[:,effective_dn_number:, :][:, 2::4, :].clone() | |
| layer_ref_unsig_rhand_bbox = inverse_sigmoid(layer_ref_sig_rhand_bbox) | |
| delta_rhand_bbox_hw_unsig = self.bbox_hand_hw_embed[ | |
| dec_lid-self.num_box_decoder_layers](layer_hs_rhand_bbox) | |
| layer_ref_unsig_rhand_bbox[..., :2] += delta_rhand_bbox_xy_unsig[..., :2] | |
| layer_ref_unsig_rhand_bbox[..., 2:] += delta_rhand_bbox_hw_unsig | |
| layer_ref_sig_rhand_bbox = layer_ref_unsig_rhand_bbox.sigmoid() | |
| outputs_rhand_bbox_list.append(layer_ref_sig_rhand_bbox) | |
| # face bbox | |
| layer_hs_face_bbox = \ | |
| layer_hs[:, effective_dn_number:, :][:, 3::4, :] | |
| delta_face_bbox_xy_unsig = self.bbox_face_embed[ | |
| dec_lid - self.num_box_decoder_layers](layer_hs_face_bbox) | |
| layer_ref_sig_face_bbox = \ | |
| layer_ref_sig[:,effective_dn_number:, :][:, 3::4, :].clone() | |
| layer_ref_unsig_face_bbox = inverse_sigmoid(layer_ref_sig_face_bbox) | |
| delta_face_bbox_hw_unsig = self.bbox_face_hw_embed[ | |
| dec_lid-self.num_box_decoder_layers](layer_hs_face_bbox) | |
| layer_ref_unsig_face_bbox[..., :2] += delta_face_bbox_xy_unsig[..., :2] | |
| layer_ref_unsig_face_bbox[..., 2:] += delta_face_bbox_hw_unsig | |
| layer_ref_sig_face_bbox = layer_ref_unsig_face_bbox.sigmoid() | |
| outputs_face_bbox_list.append(layer_ref_sig_face_bbox) | |
| # smpl or smplx | |
| bs, _, feat_dim = layer_hs.shape | |
| smpl_feats = layer_hs[:, effective_dn_number:, :].index_select( | |
| 1, torch.tensor(smpl_pose_index, device=layer_hs.device) | |
| ).reshape(bs, -1, feat_dim * 4) | |
| smpl_lhand_pose_feats = \ | |
| layer_hs[:, effective_dn_number:, :].index_select( | |
| 1, torch.tensor(smpl_lhand_pose_index, device=layer_hs.device) | |
| ).reshape(bs, -1, feat_dim * 2) | |
| smpl_rhand_pose_feats = layer_hs[:, effective_dn_number:, :].index_select( | |
| 1, torch.tensor(smpl_rhand_pose_index, device=layer_hs.device) | |
| ).reshape(bs, -1, feat_dim * 2) | |
| smpl_face_pose_feats = layer_hs[:, effective_dn_number:, :].index_select( | |
| 1, torch.tensor(smpl_face_pose_index, device=layer_hs.device) | |
| ).reshape(bs, -1, feat_dim * 2) | |
| smpl_pose = self.smpl_pose_embed[ | |
| dec_lid - self.num_box_decoder_layers](smpl_feats) | |
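| # rot6d_to_rotmat converts the 6D continuous rotation representation into | |
| # 3x3 rotation matrices (Zhou et al., CVPR 2019), which keeps the regression | |
| # target continuous and is standard for SMPL-X-style pose heads. | |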
| smpl_pose = rot6d_to_rotmat(smpl_pose.reshape(-1, 6)).reshape( | |
| bs, self.num_group, self.body_model_joint_num, 3, 3) | |
| smpl_lhand_pose = self.smpl_hand_pose_embed[ | |
| dec_lid - self.num_box_decoder_layers](smpl_lhand_pose_feats) | |
| smpl_lhand_pose = rot6d_to_rotmat(smpl_lhand_pose.reshape( | |
| -1, 6)).reshape(bs, self.num_group, 15, 3, 3) | |
| smpl_rhand_pose = self.smpl_hand_pose_embed[ | |
| dec_lid - self.num_box_decoder_layers](smpl_rhand_pose_feats) | |
| smpl_rhand_pose = rot6d_to_rotmat(smpl_rhand_pose.reshape( | |
| -1, 6)).reshape(bs, self.num_group, 15, 3, 3) | |
| smpl_jaw_pose = self.smpl_jaw_embed[ | |
| dec_lid - self.num_box_decoder_layers](smpl_face_pose_feats) | |
| smpl_jaw_pose = rot6d_to_rotmat(smpl_jaw_pose.reshape(-1, 6)).reshape( | |
| bs, self.num_group, 1, 3, 3) | |
| smpl_beta = self.smpl_beta_embed[ | |
| dec_lid - self.num_box_decoder_layers](smpl_feats) | |
| smpl_cam = self.smpl_cam_embed[ | |
| dec_lid - self.num_box_decoder_layers](smpl_feats) | |
| smpl_expr = self.smpl_expr_embed[ | |
| dec_lid - self.num_box_decoder_layers](smpl_face_pose_feats) | |
| # smpl_jaw_pose = layer_hs.new_zeros(bs, self.num_group, 3) | |
| leye_pose = torch.zeros_like(smpl_jaw_pose) | |
| reye_pose = torch.zeros_like(smpl_jaw_pose) | |
| if self.body_model is not None: | |
| smpl_pose_ = rotmat_to_aa(smpl_pose) | |
| # smpl_lhand_pose_ = rotmat_to_aa(smpl_lhand_pose) | |
| # smpl_rhand_pose_ = rotmat_to_aa(smpl_rhand_pose) | |
| smpl_lhand_pose_ = layer_hs.new_zeros(bs, self.num_group, 15, 3) | |
| smpl_rhand_pose_ = layer_hs.new_zeros(bs, self.num_group, 15, 3) | |
| smpl_jaw_pose_ = rotmat_to_aa(smpl_jaw_pose) | |
| leye_pose_ = rotmat_to_aa(leye_pose) | |
| reye_pose_ = rotmat_to_aa(reye_pose) | |
| pred_output = self.body_model( | |
| betas=smpl_beta.reshape(-1, 10), | |
| body_pose=smpl_pose_[:, :, 1:].reshape(-1, 21 * 3), | |
| global_orient=smpl_pose_[:, :, 0].reshape( | |
| -1, 3).unsqueeze(1), | |
| left_hand_pose=smpl_lhand_pose_.reshape(-1, 15 * 3), | |
| right_hand_pose=smpl_rhand_pose_.reshape(-1, 15 * 3), | |
| leye_pose=leye_pose_, | |
| reye_pose=reye_pose_, | |
| jaw_pose=smpl_jaw_pose_.reshape(-1, 3), | |
| # expression=smpl_expr.reshape(-1, 10), | |
| expression=layer_hs.new_zeros(bs, self.num_group, 10).reshape(-1, 10) | |
| ) | |
| smpl_kp3d = pred_output['joints'].reshape( | |
| bs, self.num_group, -1, 3) | |
| smpl_verts = pred_output['vertices'].reshape( | |
| bs, self.num_group, -1, 3) | |
| # pred_vertices = pred_output['vertices'].reshape(bs, -1, 6890, 3) | |
| outputs_smpl_pose_list.append(smpl_pose) | |
| outputs_smpl_rhand_pose_list.append(smpl_rhand_pose) | |
| outputs_smpl_lhand_pose_list.append(smpl_lhand_pose) | |
| outputs_smpl_expr_list.append(smpl_expr) | |
| outputs_smpl_jaw_pose_list.append(smpl_jaw_pose) | |
| outputs_smpl_beta_list.append(smpl_beta) | |
| outputs_smpl_cam_list.append(smpl_cam) | |
| outputs_smpl_kp3d_list.append(smpl_kp3d) | |
| else: | |
| bs = layer_ref_sig.shape[0] | |
| # lhand bbox | |
| layer_hs_lhand_bbox = \ | |
| layer_hs[:, effective_dn_number:, :][:, 1::4, :] | |
| delta_lhand_bbox_xy_unsig = self.bbox_hand_embed[ | |
| dec_lid - self.num_box_decoder_layers](layer_hs_lhand_bbox) | |
| layer_ref_sig_lhand_bbox = \ | |
| layer_ref_sig[:,effective_dn_number:, :][:, 1::4, :].clone() | |
| layer_ref_unsig_lhand_bbox = inverse_sigmoid(layer_ref_sig_lhand_bbox) | |
| delta_lhand_bbox_hw_unsig = self.bbox_hand_hw_embed[ | |
| dec_lid-self.num_box_decoder_layers](layer_hs_lhand_bbox) | |
| layer_ref_unsig_lhand_bbox[..., :2] += delta_lhand_bbox_xy_unsig[..., :2] | |
| layer_ref_unsig_lhand_bbox[..., 2:] += delta_lhand_bbox_hw_unsig | |
| layer_ref_sig_lhand_bbox = layer_ref_unsig_lhand_bbox.sigmoid() | |
| outputs_lhand_bbox_list.append(layer_ref_sig_lhand_bbox) | |
| # rhand bbox | |
| layer_hs_rhand_bbox = \ | |
| layer_hs[:, effective_dn_number:, :][:, 2::4, :] | |
| delta_rhand_bbox_xy_unsig = self.bbox_hand_embed[ | |
| dec_lid - self.num_box_decoder_layers](layer_hs_rhand_bbox) | |
| layer_ref_sig_rhand_bbox = \ | |
| layer_ref_sig[:,effective_dn_number:, :][:, 2::4, :].clone() | |
| layer_ref_unsig_rhand_bbox = inverse_sigmoid(layer_ref_sig_rhand_bbox) | |
| delta_rhand_bbox_hw_unsig = self.bbox_hand_hw_embed[ | |
| dec_lid-self.num_box_decoder_layers](layer_hs_rhand_bbox) | |
| layer_ref_unsig_rhand_bbox[..., :2] += delta_rhand_bbox_xy_unsig[..., :2] | |
| layer_ref_unsig_rhand_bbox[..., 2:] += delta_rhand_bbox_hw_unsig | |
| layer_ref_sig_rhand_bbox = layer_ref_unsig_rhand_bbox.sigmoid() | |
| outputs_rhand_bbox_list.append(layer_ref_sig_rhand_bbox) | |
| # face bbox | |
| layer_hs_face_bbox = \ | |
| layer_hs[:, effective_dn_number:, :][:, 3::4, :] | |
| delta_face_bbox_xy_unsig = \ | |
| self.bbox_face_embed[dec_lid - self.num_box_decoder_layers](layer_hs_face_bbox) | |
| layer_ref_sig_face_bbox = \ | |
| layer_ref_sig[:,effective_dn_number:, :][:, 3::4, :].clone() | |
| layer_ref_unsig_face_bbox = inverse_sigmoid(layer_ref_sig_face_bbox) | |
| delta_face_bbox_hw_unsig = self.bbox_face_hw_embed[ | |
| dec_lid-self.num_box_decoder_layers](layer_hs_face_bbox) | |
| layer_ref_unsig_face_bbox[..., :2] += delta_face_bbox_xy_unsig[..., :2] | |
| layer_ref_unsig_face_bbox[..., 2:] += delta_face_bbox_hw_unsig | |
| layer_ref_sig_face_bbox = layer_ref_unsig_face_bbox.sigmoid() | |
| outputs_face_bbox_list.append(layer_ref_sig_face_bbox) | |
| bs, _, feat_dim = layer_hs.shape | |
| smpl_body_pose_feats = layer_hs[:, effective_dn_number:, :].index_select( | |
| 1, torch.tensor(smpl_pose_index, device=layer_hs.device) | |
| ).reshape(bs, -1, feat_dim * 4) | |
| smpl_lhand_pose_feats = layer_hs[:, effective_dn_number:, :].index_select( | |
| 1, torch.tensor(smpl_lhand_pose_index, device=layer_hs.device) | |
| ).reshape(bs, -1, feat_dim * 2) | |
| smpl_rhand_pose_feats = layer_hs[:, effective_dn_number:, :].index_select( | |
| 1, torch.tensor(smpl_rhand_pose_index, device=layer_hs.device) | |
| ).reshape(bs, -1, feat_dim * 2) | |
| smpl_face_pose_feats = layer_hs[:, effective_dn_number:, :].index_select( | |
| 1, torch.tensor(smpl_face_pose_index, device=layer_hs.device) | |
| ).reshape(bs, -1, feat_dim * 2) | |
| smpl_pose = self.smpl_pose_embed[ | |
| dec_lid - self.num_box_decoder_layers](smpl_body_pose_feats) | |
| smpl_pose = rot6d_to_rotmat(smpl_pose.reshape(-1, 6)).reshape( | |
| bs, self.num_group, self.body_model_joint_num, 3, 3) | |
| smpl_lhand_pose = self.smpl_hand_pose_embed[ | |
| dec_lid - self.num_box_decoder_layers](smpl_lhand_pose_feats) | |
| smpl_lhand_pose = rot6d_to_rotmat(smpl_lhand_pose.reshape( | |
| -1, 6)).reshape(bs, self.num_group, 15, 3, 3) | |
| smpl_rhand_pose = self.smpl_hand_pose_embed[ | |
| dec_lid - self.num_box_decoder_layers](smpl_rhand_pose_feats) | |
| smpl_rhand_pose = rot6d_to_rotmat(smpl_rhand_pose.reshape( | |
| -1, 6)).reshape(bs, self.num_group, 15, 3, 3) | |
| smpl_expr = self.smpl_expr_embed[ | |
| dec_lid - self.num_box_decoder_layers](smpl_face_pose_feats) | |
| smpl_jaw_pose = self.smpl_jaw_embed[ | |
| dec_lid - self.num_box_decoder_layers](smpl_face_pose_feats) | |
| smpl_jaw_pose = rot6d_to_rotmat(smpl_jaw_pose.reshape(-1, 6)).reshape( | |
| bs, self.num_group, 1, 3, 3) | |
| smpl_beta = self.smpl_beta_embed[ | |
| dec_lid - self.num_box_decoder_layers](smpl_body_pose_feats) | |
| smpl_cam = self.smpl_cam_embed[ | |
| dec_lid - self.num_box_decoder_layers](smpl_body_pose_feats) | |
| num_samples = smpl_beta.reshape(-1, 10).shape[0] | |
| device = smpl_beta.device | |
| leye_pose = torch.zeros_like(smpl_jaw_pose) | |
| reye_pose = torch.zeros_like(smpl_jaw_pose) | |
| if self.body_model is not None: | |
| smpl_pose_ = rotmat_to_aa(smpl_pose) | |
| smpl_lhand_pose_ = rotmat_to_aa(smpl_lhand_pose) | |
| smpl_rhand_pose_ = rotmat_to_aa(smpl_rhand_pose) | |
| smpl_jaw_pose_ = rotmat_to_aa(smpl_jaw_pose) | |
| leye_pose_ = rotmat_to_aa(leye_pose) | |
| reye_pose_ = rotmat_to_aa(reye_pose) | |
| pred_output = self.body_model( | |
| betas=smpl_beta.reshape(-1, 10), | |
| body_pose=smpl_pose_[:, :, 1:].reshape(-1, 21 * 3), | |
| global_orient=smpl_pose_[:, :, 0].reshape( | |
| -1, 3).unsqueeze(1), | |
| left_hand_pose=smpl_lhand_pose_.reshape(-1, 15 * 3), | |
| right_hand_pose=smpl_rhand_pose_.reshape(-1, 15 * 3), | |
| leye_pose=leye_pose_, | |
| reye_pose=reye_pose_, | |
| jaw_pose=smpl_jaw_pose_.reshape(-1, 3), | |
| expression=smpl_expr.reshape(-1, 10), | |
| # expression=layer_hs.new_zeros(bs, self.num_group, 10).reshape(-1, 10), | |
| ) | |
| smpl_kp3d = pred_output['joints'].reshape( | |
| bs, self.num_group, -1, 3) | |
| smpl_verts = pred_output['vertices'].reshape( | |
| bs, self.num_group, -1, 3) | |
| outputs_smpl_pose_list.append(smpl_pose) | |
| outputs_smpl_rhand_pose_list.append(smpl_rhand_pose) | |
| outputs_smpl_lhand_pose_list.append(smpl_lhand_pose) | |
| outputs_smpl_expr_list.append(smpl_expr) | |
| outputs_smpl_jaw_pose_list.append(smpl_jaw_pose) | |
| outputs_smpl_beta_list.append(smpl_beta) | |
| outputs_smpl_cam_list.append(smpl_cam) | |
| outputs_smpl_kp3d_list.append(smpl_kp3d) | |
| if not self.training: | |
| outputs_smpl_verts_list.append(smpl_verts) | |
| dn_mask_dict = mask_dict | |
| if self.dn_number > 0 and dn_mask_dict is not None: | |
| outputs_class, outputs_body_bbox_list = self.dn_post_process2( | |
| outputs_class, outputs_body_bbox_list, dn_mask_dict) | |
| dn_class_input = dn_mask_dict['known_labels'] | |
| dn_bbox_input = dn_mask_dict['known_bboxs'] | |
| dn_class_pred = dn_mask_dict['output_known_class'] | |
| dn_bbox_pred = dn_mask_dict['output_known_coord'] | |
| for idx, (_out_class, _out_bbox) in enumerate(zip(outputs_class, outputs_body_bbox_list)): | |
| assert _out_class.shape[1] == _out_bbox.shape[1] | |
| out = { | |
| 'pred_logits': outputs_class[-1], | |
| 'pred_boxes': outputs_body_bbox_list[-1], | |
| 'pred_lhand_boxes': outputs_lhand_bbox_list[-1], | |
| 'pred_rhand_boxes': outputs_rhand_bbox_list[-1], | |
| 'pred_face_boxes': outputs_face_bbox_list[-1], | |
| 'pred_smpl_pose': outputs_smpl_pose_list[-1], | |
| 'pred_smpl_rhand_pose': outputs_smpl_rhand_pose_list[-1], | |
| 'pred_smpl_lhand_pose': outputs_smpl_lhand_pose_list[-1], | |
| 'pred_smpl_jaw_pose': outputs_smpl_jaw_pose_list[-1], | |
| 'pred_smpl_expr': outputs_smpl_expr_list[-1], | |
| 'pred_smpl_beta': outputs_smpl_beta_list[-1], # [B, 100, 10] | |
| 'pred_smpl_cam': outputs_smpl_cam_list[-1], | |
| 'pred_smpl_kp3d': outputs_smpl_kp3d_list[-1] | |
| } | |
| if not self.training: | |
| full_pose = torch.cat((outputs_smpl_pose_list[-1], | |
| outputs_smpl_lhand_pose_list[-1], | |
| outputs_smpl_rhand_pose_list[-1], | |
| outputs_smpl_jaw_pose_list[-1]),dim=2) | |
| bs, num_q, _, _, _ = full_pose.shape | |
| full_pose = rotmat_to_aa(full_pose).reshape(bs, num_q, 53 * 3) | |
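| # 53 joints = 22 body (incl. global orient) + 15 lhand + 15 rhand + 1 jaw; | |
| # the full pose is flattened to 53 * 3 axis-angle values per query. | |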
| out = { | |
| 'pred_logits': outputs_class[-1], | |
| 'pred_boxes': outputs_body_bbox_list[-1], | |
| 'pred_lhand_boxes': outputs_lhand_bbox_list[-1], | |
| 'pred_rhand_boxes': outputs_rhand_bbox_list[-1], | |
| 'pred_face_boxes': outputs_face_bbox_list[-1], | |
| 'pred_smpl_pose': outputs_smpl_pose_list[-1], | |
| 'pred_smpl_rhand_pose': outputs_smpl_rhand_pose_list[-1], | |
| 'pred_smpl_lhand_pose': outputs_smpl_lhand_pose_list[-1], | |
| 'pred_smpl_jaw_pose': outputs_smpl_jaw_pose_list[-1], | |
| 'pred_smpl_expr': outputs_smpl_expr_list[-1], | |
| 'pred_smpl_beta': outputs_smpl_beta_list[-1], # [B, 100, 10] | |
| 'pred_smpl_cam': outputs_smpl_cam_list[-1], | |
| 'pred_smpl_kp3d': outputs_smpl_kp3d_list[-1], | |
| 'pred_smpl_verts': outputs_smpl_verts_list[-1], | |
| 'pred_smpl_fullpose': full_pose | |
| } | |
| if self.dn_number > 0 and dn_mask_dict is not None: | |
| out.update({ | |
| 'dn_class_input': dn_class_input, | |
| 'dn_bbox_input': dn_bbox_input, | |
| 'dn_class_pred': dn_class_pred[-1], | |
| 'dn_bbox_pred': dn_bbox_pred[-1], | |
| 'num_tgt': dn_mask_dict['pad_size'] | |
| }) | |
| if self.aux_loss: | |
| out['aux_outputs'] = \ | |
| self._set_aux_loss( | |
| outputs_class, | |
| outputs_body_bbox_list, | |
| outputs_lhand_bbox_list, | |
| outputs_rhand_bbox_list, | |
| outputs_face_bbox_list, | |
| outputs_smpl_pose_list, | |
| outputs_smpl_rhand_pose_list, | |
| outputs_smpl_lhand_pose_list, | |
| outputs_smpl_jaw_pose_list, | |
| outputs_smpl_expr_list, | |
| outputs_smpl_beta_list, | |
| outputs_smpl_cam_list, | |
| outputs_smpl_kp3d_list | |
| ) # each aux dict carries pred_logits, the per-part boxes and SMPL-X params | |
| if self.dn_number > 0 and dn_mask_dict is not None: | |
| assert len(dn_class_pred[:-1]) == len( | |
| dn_bbox_pred[:-1]) == len(out['aux_outputs']) | |
| for aux_out, dn_class_pred_i, dn_bbox_pred_i in zip( | |
| out['aux_outputs'], dn_class_pred, dn_bbox_pred): | |
| aux_out.update({ | |
| 'dn_class_input': dn_class_input, | |
| 'dn_bbox_input': dn_bbox_input, | |
| 'dn_class_pred': dn_class_pred_i, | |
| 'dn_bbox_pred': dn_bbox_pred_i, | |
| 'num_tgt': dn_mask_dict['pad_size'] | |
| }) | |
| # for encoder output | |
| if hs_enc is not None: | |
| interm_coord = ref_enc[-1] | |
| interm_class = self.transformer.enc_out_class_embed(hs_enc[-1]) | |
| interm_pose = torch.zeros_like(outputs_body_keypoints_list[0]) | |
| out['interm_outputs'] = { | |
| 'pred_logits': interm_class, | |
| 'pred_boxes': interm_coord, | |
| 'pred_keypoints': interm_pose | |
| } | |
| return out, targets, data_batch | |
| def _set_aux_loss(self, | |
| outputs_class, | |
| outputs_body_coord, | |
| outputs_lhand_coord, | |
| outputs_rhand_coord, | |
| outputs_face_coord, | |
| outputs_smpl_pose, | |
| outputs_smpl_rhand_pose, | |
| outputs_smpl_lhand_pose, | |
| outputs_smpl_jaw_pose, | |
| outputs_smpl_expr, | |
| outputs_smpl_beta, | |
| outputs_smpl_cam, | |
| outputs_smpl_kp3d): | |
| return [{ | |
| 'pred_logits': a, | |
| 'pred_boxes': b, | |
| 'pred_lhand_boxes': c, | |
| 'pred_rhand_boxes': d, | |
| 'pred_face_boxes': e, | |
| 'pred_smpl_pose': j, | |
| 'pred_smpl_rhand_pose': k, | |
| 'pred_smpl_lhand_pose': l, | |
| 'pred_smpl_jaw_pose': m, | |
| 'pred_smpl_expr': n, | |
| 'pred_smpl_beta': o, | |
| 'pred_smpl_cam': p, | |
| 'pred_smpl_kp3d': q | |
| } for a, b, c, d, e, j, k, l, m, n, o, p, q in zip( | |
| outputs_class[:-1], | |
| outputs_body_coord[:-1], | |
| outputs_lhand_coord[:-1], | |
| outputs_rhand_coord[:-1], | |
| outputs_face_coord[:-1], | |
| outputs_smpl_pose[:-1], | |
| outputs_smpl_rhand_pose[:-1], | |
| outputs_smpl_lhand_pose[:-1], | |
| outputs_smpl_jaw_pose[:-1], | |
| outputs_smpl_expr[:-1], | |
| outputs_smpl_beta[:-1], | |
| outputs_smpl_cam[:-1], | |
| outputs_smpl_kp3d[:-1])] | |
| def prepare_targets(self, data_batch): | |
| data_batch_coco = [] | |
| instance_dict = {} | |
| img_list = data_batch['img'].float() | |
| # input_img_h, input_img_w = data_batch['image_metas'][0]['batch_input_shape'] | |
| batch_size, _, input_img_h, input_img_w = img_list.shape | |
| device = img_list.device | |
| masks = torch.ones((batch_size, input_img_h, input_img_w), | |
| dtype=torch.bool, | |
| device=device) | |
| if self.num_body_points == 17: | |
| ed_convention = 'coco' | |
| elif self.num_body_points == 14: | |
| ed_convention = 'crowdpose' | |
| # cv2.imread(data_batch['img_metas'][img_id]['image_path']).shape | |
| for img_id in range(batch_size): | |
| img_h, img_w = data_batch['img_shape'][img_id] | |
| masks[img_id, :img_h, :img_w] = 0 | |
| if not self.inference: | |
| instance_body_bbox = torch.cat([data_batch['body_bbox_center'][img_id],\ | |
| data_batch['body_bbox_size'][img_id]],dim=-1) | |
| instance_face_bbox = torch.cat([data_batch['face_bbox_center'][img_id],\ | |
| data_batch['face_bbox_size'][img_id]],dim=-1) | |
| instance_lhand_bbox = torch.cat([data_batch['lhand_bbox_center'][img_id],\ | |
| data_batch['lhand_bbox_size'][img_id]],dim=-1) | |
| instance_rhand_bbox = torch.cat([data_batch['rhand_bbox_center'][img_id],\ | |
| data_batch['rhand_bbox_size'][img_id]],dim=-1) | |
| instance_kp2d = data_batch['joint_img'][img_id].clone().float() | |
| instance_kp2d_mask = data_batch['joint_trunc'][img_id].clone().float() | |
| instance_kp2d[:,:,2:] = instance_kp2d_mask | |
| body_kp2d, _ = convert_kps(instance_kp2d, 'smplx_137', 'coco', approximate=True) | |
| lhand_kp2d, _ = convert_kps(instance_kp2d, 'smplx_137', 'smplx_lhand', approximate=True) | |
| rhand_kp2d, _ = convert_kps(instance_kp2d, 'smplx_137', 'smplx_rhand', approximate=True) | |
| face_kp2d, _ = convert_kps(instance_kp2d, 'smplx_137', 'smplx_face', approximate=True) | |
| # from util.vis_utils import show_bbox | |
| # show_bbox(img_list[img_id],instance_kp2d.cpu().numpy(),data_batch['bbox_xywh'][img_id].cpu().numpy) | |
| body_kp2d[:,:,0] = body_kp2d[:,:,0]/cfg.output_hm_shape[2] | |
| body_kp2d[:,:,1] = body_kp2d[:,:,1]/cfg.output_hm_shape[1] | |
| body_kp2d = torch.cat([body_kp2d[:,:,:2].flatten(1),body_kp2d[:,:,2]],dim=-1) | |
| lhand_kp2d[:,:,0] = lhand_kp2d[:,:,0]/cfg.output_hm_shape[2] | |
| lhand_kp2d[:,:,1] = lhand_kp2d[:,:,1]/cfg.output_hm_shape[1] | |
| lhand_kp2d = torch.cat([lhand_kp2d[:,:,:2].flatten(1),lhand_kp2d[:,:,2]],dim=-1) | |
| rhand_kp2d[:,:,0] = rhand_kp2d[:,:,0]/cfg.output_hm_shape[2] | |
| rhand_kp2d[:,:,1] = rhand_kp2d[:,:,1]/cfg.output_hm_shape[1] | |
| rhand_kp2d = torch.cat([rhand_kp2d[:,:,:2].flatten(1),rhand_kp2d[:,:,2]],dim=-1) | |
| face_kp2d[:,:,0] = face_kp2d[:,:,0]/cfg.output_hm_shape[2] | |
| face_kp2d[:,:,1] = face_kp2d[:,:,1]/cfg.output_hm_shape[1] | |
| face_kp2d = torch.cat([face_kp2d[:,:,:2].flatten(1),face_kp2d[:,:,2]],dim=-1) | |
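| # keypoints are normalized by the heatmap grid (cfg.output_hm_shape indexed | |
| # as [2] = width and [1] = height here) and flattened per instance to | |
| # [x0, y0, x1, y1, ..., c0, c1, ...]: 2 * K coordinates followed by K confs. | |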
| instance_dict = {} | |
| instance_dict['boxes'] = instance_body_bbox.float() | |
| instance_dict['face_boxes'] = instance_face_bbox.float() | |
| instance_dict['lhand_boxes'] = instance_lhand_bbox.float() | |
| instance_dict['rhand_boxes'] = instance_rhand_bbox.float() | |
| instance_dict['keypoints'] = body_kp2d.float() | |
| instance_dict['lhand_keypoints'] = lhand_kp2d.float() | |
| instance_dict['rhand_keypoints'] = rhand_kp2d.float() | |
| instance_dict['face_keypoints'] = face_kp2d.float() | |
| # instance_dict['orig_size'] = data_batch['ori_shape'][img_id] | |
| instance_dict['size'] = data_batch['img_shape'][img_id] # after augmentation | |
| instance_dict['area'] = instance_body_bbox[:, 2] * instance_body_bbox[:, 3] | |
| instance_dict['lhand_area'] = instance_lhand_bbox[:, 2] * instance_lhand_bbox[:, 3] | |
| instance_dict['rhand_area'] = instance_rhand_bbox[:, 2] * instance_rhand_bbox[:, 3] | |
| instance_dict['face_area'] = instance_face_bbox[:, 2] * instance_face_bbox[:, 3] | |
| instance_dict['labels'] = torch.ones(instance_body_bbox.shape[0], | |
| dtype=torch.long, | |
| device=device) | |
| data_batch_coco.append(instance_dict) | |
| else: | |
| instance_body_bbox = torch.cat([data_batch['body_bbox_center'][img_id],\ | |
| data_batch['body_bbox_size'][img_id]],dim=-1) | |
| instance_dict = {} | |
| # instance_dict['orig_size'] = data_batch['ori_shape'][img_id] | |
| instance_dict['size'] = data_batch['img_shape'][img_id] # after augmentation | |
| instance_dict['boxes'] = instance_body_bbox.float() | |
| data_batch_coco.append(instance_dict) | |
| input_img = NestedTensor(img_list, masks) | |
| return input_img, data_batch_coco | |
| def keypoints_to_scaled_bbox_bfh( | |
| self, keypoints, occ=None, | |
| body_scale=1.0, fh_scale=1.0, | |
| convention='smplx'): | |
| '''Obtain scaled bboxes in xyxy format given keypoints. | |
| Args: | |
| keypoints (np.ndarray): Keypoints of shape (1, n, k) or (n, k), k in {2, 3} | |
| occ (np.ndarray, optional): Per-keypoint occlusion flags | |
| body_scale (float): Scale factor for the body bounding box | |
| fh_scale (float): Scale factor for the face and hand bounding boxes | |
| convention (str): Keypoint convention used to look up part indices | |
| Returns: | |
| bboxs (List[np.ndarray]): One [xmin, ymin, xmax, ymax, conf] box per part | |
| ''' | |
| bboxs = [] | |
| # supported kps.shape: (1, n, k) or (n, k), k = 2 or 3 | |
| if keypoints.ndim == 3: | |
| keypoints = keypoints[0] | |
| if keypoints.shape[-1] != 2: | |
| keypoints = keypoints[:, :2] | |
| for body_part in ['body', 'head', 'left_hand', 'right_hand']: | |
| if body_part == 'body': | |
| scale = body_scale | |
| kps = keypoints | |
| else: | |
| scale = fh_scale | |
| kp_id = get_keypoint_idxs_by_part(body_part, convention=convention) | |
| kps = keypoints[kp_id] | |
| if occ is not None: | |
| occ_p = occ[kp_id] | |
| if np.sum(occ_p) / len(kp_id) >= 0.1: | |
| conf = 0 | |
| # print(f'{body_part} occluded, occlusion: {np.sum(occ_p) / len(kp_id)}, skip') | |
| else: | |
| # print(f'{body_part} good, {np.sum(self_occ_p + occ_p) / len(kp_id)}') | |
| conf = 1 | |
| else: | |
| conf = 1 | |
| if body_part == 'body': | |
| conf = 1 | |
| xmin, ymin = np.amin(kps, axis=0) | |
| xmax, ymax = np.amax(kps, axis=0) | |
| width = (xmax - xmin) * scale | |
| height = (ymax - ymin) * scale | |
| x_center = 0.5 * (xmax + xmin) | |
| y_center = 0.5 * (ymax + ymin) | |
| xmin = x_center - 0.5 * width | |
| xmax = x_center + 0.5 * width | |
| ymin = y_center - 0.5 * height | |
| ymax = y_center + 0.5 * height | |
| bbox = np.stack([xmin, ymin, xmax, ymax, conf], axis=0).astype(np.float32) | |
| bboxs.append(bbox) | |
| return bboxs | |
| def build_aios_smplx_box(args, cfg): | |
| # pdb.set_trace() | |
| num_classes = args.num_classes # 2 | |
| device = torch.device(args.device) | |
| backbone = build_backbone(args) | |
| transformer = build_transformer(args) | |
| dn_labelbook_size = args.dn_labelbook_size | |
| dec_pred_class_embed_share = args.dec_pred_class_embed_share | |
| dec_pred_bbox_embed_share = args.dec_pred_bbox_embed_share | |
| if args.eval: | |
| body_model = args.body_model_test | |
| train = False | |
| else: | |
| body_model = args.body_model_train | |
| train = True | |
| model = AiOSSMPLX_Box( | |
| backbone, | |
| transformer, | |
| num_classes=num_classes, # 2 | |
| num_queries=args.num_queries, # 900 | |
| aux_loss=True, | |
| iter_update=True, | |
| query_dim=4, | |
| random_refpoints_xy=args.random_refpoints_xy, # False | |
| fix_refpoints_hw=args.fix_refpoints_hw, # -1 | |
| num_feature_levels=args.num_feature_levels, # 4 | |
| nheads=args.nheads, # 8 | |
| dec_pred_class_embed_share=dec_pred_class_embed_share, # false | |
| dec_pred_bbox_embed_share=dec_pred_bbox_embed_share, # False | |
| # two stage | |
| two_stage_type=args.two_stage_type, | |
| # box_share | |
| two_stage_bbox_embed_share=args.two_stage_bbox_embed_share, # False | |
| two_stage_class_embed_share=args.two_stage_class_embed_share, # False | |
| dn_number=args.dn_number if args.use_dn else 0, # 100 | |
| dn_box_noise_scale=args.dn_box_noise_scale, # 0.4 | |
| dn_label_noise_ratio=args.dn_label_noise_ratio, # 0.5 | |
| dn_batch_gt_fuse=args.dn_batch_gt_fuse, # false | |
| dn_attn_mask_type_list=args.dn_attn_mask_type_list, | |
| dn_labelbook_size=dn_labelbook_size, # 100 | |
| cls_no_bias=args.cls_no_bias, # False | |
| num_group=args.num_group, # 100 | |
| num_body_points=0, # keypoints are disabled in this box-only variant | |
| num_hand_points=0, | |
| num_face_points=0, | |
| num_box_decoder_layers=args.num_box_decoder_layers, # 2 | |
| num_hand_face_decoder_layers=args.num_hand_face_decoder_layers, | |
| # smpl_convention=convention | |
| body_model=body_model, | |
| train=train, | |
| inference=args.inference) | |
| matcher = build_matcher(args) | |
| # prepare weight dict | |
| weight_dict = { | |
| 'loss_ce': args.cls_loss_coef, # 2 | |
| # bbox | |
| 'loss_body_bbox': args.body_bbox_loss_coef, # 5 | |
| 'loss_rhand_bbox': args.rhand_bbox_loss_coef, # 5 | |
| 'loss_lhand_bbox': args.lhand_bbox_loss_coef, # 5 | |
| 'loss_face_bbox': args.face_bbox_loss_coef, # 5 | |
| # bbox giou | |
| 'loss_body_giou': args.body_giou_loss_coef, # 2 | |
| 'loss_rhand_giou': args.rhand_giou_loss_coef, # 2 | |
| 'loss_lhand_giou': args.lhand_giou_loss_coef, # 2 | |
| 'loss_face_giou': args.face_giou_loss_coef, # 2 | |
| # smpl param | |
| 'loss_smpl_pose_root': args.smpl_pose_loss_root_coef, # 0 | |
| 'loss_smpl_pose_body': args.smpl_pose_loss_body_coef, # 0 | |
| 'loss_smpl_pose_lhand': args.smpl_pose_loss_lhand_coef, # 0 | |
| 'loss_smpl_pose_rhand': args.smpl_pose_loss_rhand_coef, # 0 | |
| 'loss_smpl_pose_jaw': args.smpl_pose_loss_jaw_coef, # 0 | |
| 'loss_smpl_beta': args.smpl_beta_loss_coef, # 0 | |
| 'loss_smpl_expr': args.smpl_expr_loss_coef, | |
| # smpl kp3d ra | |
| 'loss_smpl_body_kp3d_ra': args.smpl_body_kp3d_ra_loss_coef, # 0 | |
| 'loss_smpl_lhand_kp3d_ra': args.smpl_lhand_kp3d_ra_loss_coef, # 0 | |
| 'loss_smpl_rhand_kp3d_ra': args.smpl_rhand_kp3d_ra_loss_coef, # 0 | |
| 'loss_smpl_face_kp3d_ra': args.smpl_face_kp3d_ra_loss_coef, # 0 | |
| # smpl kp3d | |
| 'loss_smpl_body_kp3d': args.smpl_body_kp3d_loss_coef, # 0 | |
| 'loss_smpl_face_kp3d': args.smpl_face_kp3d_loss_coef, # 0 | |
| 'loss_smpl_lhand_kp3d': args.smpl_lhand_kp3d_loss_coef, # 0 | |
| 'loss_smpl_rhand_kp3d': args.smpl_rhand_kp3d_loss_coef, # 0 | |
| # smpl kp2d | |
| 'loss_smpl_body_kp2d': args.smpl_body_kp2d_loss_coef, # 0 | |
| 'loss_smpl_lhand_kp2d': args.smpl_lhand_kp2d_loss_coef, # 0 | |
| 'loss_smpl_rhand_kp2d': args.smpl_rhand_kp2d_loss_coef, # 0 | |
| 'loss_smpl_face_kp2d': args.smpl_face_kp2d_loss_coef, # 0 | |
| } | |
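| # Snapshot the base weights before the DN terms are mixed in; the copy is | |
| # reused for the intermediate (two-stage) loss weights below. | |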
| clean_weight_dict_wo_dn = copy.deepcopy(weight_dict) | |
| if args.use_dn: | |
| weight_dict.update({ | |
| 'dn_loss_ce': | |
| args.dn_label_coef, # 0.3 | |
| 'dn_loss_bbox': | |
| args.bbox_loss_coef * args.dn_bbox_coef, # 5 * 0.5 | |
| 'dn_loss_giou': | |
| args.giou_loss_coef * args.dn_bbox_coef, # 2 * 0.5 | |
| }) | |
| clean_weight_dict = copy.deepcopy(weight_dict) | |
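| # Duplicate each loss weight per auxiliary decoder layer (suffix _{i}), | |
| # skipping terms a layer cannot yet produce (hand/face boxes, keypoints, SMPL). | |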
| if args.aux_loss: | |
| aux_weight_dict = {} | |
| for i in range(args.dec_layers - 1): # layers 0 .. dec_layers - 2; the last layer uses the main weights | |
| for k, v in clean_weight_dict.items(): | |
| if i < args.num_box_decoder_layers and ('keypoints' in k or 'oks' in k): | |
| continue | |
| if i < args.num_box_decoder_layers and k in [ | |
| 'loss_rhand_bbox', 'loss_lhand_bbox', 'loss_face_bbox', | |
| 'loss_rhand_giou', 'loss_lhand_giou', 'loss_face_giou']: | |
| continue | |
| if i < args.num_hand_face_decoder_layers and k in [ | |
| 'loss_rhand_keypoints', 'loss_lhand_keypoints', | |
| 'loss_face_keypoints', 'loss_rhand_oks', | |
| 'loss_lhand_oks', 'loss_face_oks']: | |
| continue | |
| if i < args.num_box_decoder_layers and 'smpl' in k: | |
| continue | |
| aux_weight_dict.update({k + f'_{i}': v}) | |
| weight_dict.update(aux_weight_dict) | |
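| # Two-stage mode adds _interm and _query_expand weights for the encoder | |
| # proposals; box-related terms can be zeroed out via no_interm_box_loss. | |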
| if args.two_stage_type != 'no': | |
| interm_weight_dict = {} | |
| try: | |
| no_interm_box_loss = args.no_interm_box_loss | |
| except AttributeError: # default when the flag is absent from the config | |
| no_interm_box_loss = False | |
| _coeff_weight_dict = { | |
| 'loss_ce': 1.0, | |
| # bbox | |
| 'loss_body_bbox': 1.0 if not no_interm_box_loss else 0.0, | |
| 'loss_rhand_bbox': 1.0 if not no_interm_box_loss else 0.0, | |
| 'loss_lhand_bbox': 1.0 if not no_interm_box_loss else 0.0, | |
| 'loss_face_bbox': 1.0 if not no_interm_box_loss else 0.0, | |
| # bbox giou | |
| 'loss_body_giou': 1.0 if not no_interm_box_loss else 0.0, | |
| 'loss_rhand_giou': 1.0 if not no_interm_box_loss else 0.0, | |
| 'loss_lhand_giou': 1.0 if not no_interm_box_loss else 0.0, | |
| 'loss_face_giou': 1.0 if not no_interm_box_loss else 0.0, | |
| # smpl param | |
| 'loss_smpl_pose_root': 1.0 if not no_interm_box_loss else 0.0, | |
| 'loss_smpl_pose_body': 1.0 if not no_interm_box_loss else 0.0, | |
| 'loss_smpl_pose_lhand': 1.0 if not no_interm_box_loss else 0.0, | |
| 'loss_smpl_pose_rhand': 1.0 if not no_interm_box_loss else 0.0, | |
| 'loss_smpl_pose_jaw': 1.0 if not no_interm_box_loss else 0.0, | |
| 'loss_smpl_beta': 1.0 if not no_interm_box_loss else 0.0, | |
| 'loss_smpl_expr': 1.0 if not no_interm_box_loss else 0.0, | |
| # smpl kp3d ra | |
| 'loss_smpl_body_kp3d_ra': 1.0 if not no_interm_box_loss else 0.0, | |
| 'loss_smpl_lhand_kp3d_ra': 1.0 if not no_interm_box_loss else 0.0, | |
| 'loss_smpl_rhand_kp3d_ra': 1.0 if not no_interm_box_loss else 0.0, | |
| 'loss_smpl_face_kp3d_ra': 1.0 if not no_interm_box_loss else 0.0, | |
| # smpl kp3d | |
| 'loss_smpl_body_kp3d': 1.0 if not no_interm_box_loss else 0.0, | |
| 'loss_smpl_face_kp3d': 1.0 if not no_interm_box_loss else 0.0, | |
| 'loss_smpl_lhand_kp3d': 1.0 if not no_interm_box_loss else 0.0, | |
| 'loss_smpl_rhand_kp3d': 1.0 if not no_interm_box_loss else 0.0, | |
| # smpl kp2d | |
| 'loss_smpl_body_kp2d': 1.0 if not no_interm_box_loss else 0.0, | |
| 'loss_smpl_lhand_kp2d': 1.0 if not no_interm_box_loss else 0.0, | |
| 'loss_smpl_rhand_kp2d': 1.0 if not no_interm_box_loss else 0.0, | |
| 'loss_smpl_face_kp2d': 1.0 if not no_interm_box_loss else 0.0, | |
| } | |
| try: | |
| interm_loss_coef = args.interm_loss_coef # 1 | |
| except AttributeError: # default when the coefficient is not configured | |
| interm_loss_coef = 1.0 | |
| interm_weight_dict.update({ | |
| k + '_interm': v * interm_loss_coef * _coeff_weight_dict[k] | |
| for k, v in clean_weight_dict_wo_dn.items() if 'keypoints' not in k | |
| }) | |
| weight_dict.update(interm_weight_dict) | |
| interm_weight_dict.update({ | |
| k + '_query_expand': v * interm_loss_coef * _coeff_weight_dict[k] | |
| for k, v in clean_weight_dict_wo_dn.items() | |
| }) # second pass adds the query_expand terms on top of the _interm ones | |
| weight_dict.update(interm_weight_dict) | |
| losses = list(cfg.losses) # copy so the += below does not mutate the shared config list | |
| if args.use_dn and args.dn_number > 0: # match the DN gating used for the model and weight dict | |
| losses += ['dn_label', 'dn_bbox'] | |
| losses += ['matching'] | |
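| # Box-only criterion: with all keypoint counts at zero, only classification, | |
| # box/GIoU, SMPL and (optionally) DN losses are active. | |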
| criterion = SetCriterion_Box( | |
| num_classes, | |
| matcher=matcher, | |
| weight_dict=weight_dict, | |
| focal_alpha=args.focal_alpha, | |
| losses=losses, | |
| num_box_decoder_layers=args.num_box_decoder_layers, | |
| num_hand_face_decoder_layers=args.num_hand_face_decoder_layers, | |
| num_body_points=0, | |
| num_hand_points=0, | |
| num_face_points=0, | |
| ) | |
| criterion.to(device) | |
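| # Inference uses the *_Infer postprocessor; evaluation uses the standard one. | |
| # Both decode boxes with top-k selection and optional NMS. | |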
| if args.inference: | |
| postprocessors = { | |
| 'bbox': | |
| PostProcess_SMPLX_Multi_Infer_Box( | |
| num_select=args.num_select, | |
| nms_iou_threshold=args.nms_iou_threshold, | |
| num_body_points=0), | |
| } | |
| else: | |
| postprocessors = { | |
| 'bbox': | |
| PostProcess_SMPLX_Multi_Box( | |
| num_select=args.num_select, | |
| nms_iou_threshold=args.nms_iou_threshold, | |
| num_body_points=0), | |
| } | |
| postprocessors_aios = { | |
| 'bbox': | |
| PostProcess_aios(num_select=args.num_select, | |
| nms_iou_threshold=args.nms_iou_threshold, | |
| num_body_points=0), | |
| } | |
| return model, criterion, postprocessors, postprocessors_aios | |
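| # Minimal usage sketch (hypothetical driver code; `args` and `cfg` are assumed | |
| # to be the parsed experiment config used throughout this file): | |
| # model, criterion, postprocessors, postprocessors_aios = \ | |
| #     build_aios_smplx_box(args, cfg) | |
| # model.to(torch.device(args.device)) | |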