import os
import shutil
import argparse
import emage.mertic
from moviepy.tools import verbose_print
from omegaconf import OmegaConf
import random
import numpy as np
import json
import librosa
from datetime import datetime
import importlib

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.utils.data import DataLoader
from torch.nn.parallel import DistributedDataParallel as DDP
# from torch.utils.tensorboard import SummaryWriter
import wandb
from diffusers.optimization import get_scheduler
from tqdm import tqdm
import smplx
from moviepy.editor import VideoFileClip, AudioFileClip, ImageSequenceClip
import igraph

import emage
import utils.rotation_conversions as rc
from create_graph import path_visualization, graph_pruning, get_motion_reps_tensor


def search_path(graph, audio_low_np, audio_high_np, top_k=1, loop_penalty=0.1, search_mode="both"):
    T = audio_low_np.shape[0]  # Total time steps

    # Initialize the beam with start nodes (nodes with no previous node)
    start_nodes = [v for v in graph.vs if v['previous'] is None or v['previous'] == -1]
    beam = []
    for node in start_nodes:
        motion_low = node['motion_low']    # Shape: [C]
        motion_high = node['motion_high']  # Shape: [C]
        # cost = np.linalg.norm(audio_low_np[0] - motion_low) + np.linalg.norm(audio_high_np - motion_high)
        if search_mode == "both":
            cost = 2 - (np.dot(audio_low_np[0], motion_low.T) + np.dot(audio_high_np[0], motion_high.T))
        elif search_mode == "high_level":
            cost = 1 - np.dot(audio_high_np[0], motion_high.T)
        elif search_mode == "low_level":
            cost = 1 - np.dot(audio_low_np[0], motion_low.T)
        sequence = [node]
        beam.append((cost, sequence))

    # Keep only the top_k initial nodes
    beam.sort(key=lambda x: x[0])
    beam = beam[:top_k]

    # Beam search over time steps
    for t in range(1, T):
        new_beam = []
        for cost, seq in beam:
            last_node = seq[-1]
            neighbor_indices = graph.neighbors(last_node.index, mode='OUT')
            if not neighbor_indices:
                continue  # No outgoing edges from the last node
            for idx in neighbor_indices:
                neighbor = graph.vs[idx]
                # Check for loops
                if neighbor in seq:
                    # Apply loop penalty
                    loop_cost = cost + loop_penalty
                else:
                    loop_cost = cost
                motion_low = neighbor['motion_low']    # Shape: [C]
                motion_high = neighbor['motion_high']  # Shape: [C]
                # cost_increment = np.linalg.norm(audio_low_np[t] - motion_low) + np.linalg.norm(audio_high_np[t] - motion_high)
                if search_mode == "both":
                    cost_increment = 2 - (np.dot(audio_low_np[t], motion_low.T) + np.dot(audio_high_np[t], motion_high.T))
                elif search_mode == "high_level":
                    cost_increment = 1 - np.dot(audio_high_np[t], motion_high.T)
                elif search_mode == "low_level":
                    cost_increment = 1 - np.dot(audio_low_np[t], motion_low.T)
                new_cost = loop_cost + cost_increment
                new_seq = seq + [neighbor]
                new_beam.append((new_cost, new_seq))
        if not new_beam:
            break  # Cannot extend any further
        # Keep only the top_k sequences
        new_beam.sort(key=lambda x: x[0])
        beam = new_beam[:top_k]

    # Extract paths and continuity information
    path_list = []
    is_continue_list = []
    for cost, seq in beam:
        path_list.append(seq)
        print("Cost: ", cost, "path", [node.index for node in seq])
        is_continue = []
        for i in range(len(seq) - 1):
            edge_id = graph.get_eid(seq[i].index, seq[i + 1].index)
            is_cont = graph.es[edge_id]['is_continue']
            is_continue.append(is_cont)
        is_continue_list.append(is_continue)
    return path_list, is_continue_list

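
# search_path_dp is a dynamic-programming alternative to the beam search above:
# for every time step it keeps, per graph node, the cheapest way of reaching that
# node, penalizing revisits (loop_penalty, exponential in the visit count) and
# transitions over non-continuous edges (continue_penalty), then backtracks from
# the cheapest node at the final step to recover a single optimal path.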
def search_path_dp(graph, audio_low_np, audio_high_np, loop_penalty=0.01, top_k=1, search_mode="both", continue_penalty=0.01):
    T = audio_low_np.shape[0]  # Total time steps
    N = len(graph.vs)  # Total number of nodes in the graph

    # Initialize DP tables
    min_cost = [{} for _ in range(T)]  # min_cost[t][node.index] = (cost, predecessor_index, non_continue_count)
    visited_nodes = [{} for _ in range(T)]  # visited_nodes[t][node.index] = dict of node visit counts

    # Initialize the first time step
    start_nodes = [v for v in graph.vs if v['previous'] is None or v['previous'] == -1]
    for node in start_nodes:
        motion_low = node['motion_low']    # Shape: [C]
        motion_high = node['motion_high']  # Shape: [C]
        # Cost using cosine similarity
        if search_mode == "both":
            cost = 2 - (np.dot(audio_low_np[0], motion_low.T) + np.dot(audio_high_np[0], motion_high.T))
        elif search_mode == "high_level":
            cost = 1 - np.dot(audio_high_np[0], motion_high.T)
        elif search_mode == "low_level":
            cost = 1 - np.dot(audio_low_np[0], motion_low.T)
        min_cost[0][node.index] = (cost, None, 0)  # Initialize with no predecessor and 0 non-continue count
        visited_nodes[0][node.index] = {node.index: 1}  # Initialize visit count as a dictionary

    # DP over time steps
    for t in range(1, T):
        for node in graph.vs:
            node_index = node.index
            min_cost_t = float('inf')
            best_predecessor = None
            best_visited = None
            best_non_continue_count = 0
            # Incoming edges to the current node
            incoming_edges = graph.es.select(_to=node_index)
            for edge in incoming_edges:
                prev_node_index = edge.source
                prev_node = graph.vs[prev_node_index]
                if prev_node_index in min_cost[t-1]:
                    prev_cost, _, prev_non_continue_count = min_cost[t-1][prev_node_index]
                    prev_visited = visited_nodes[t-1][prev_node_index]
                    # Loop punishment
                    if node_index in prev_visited:
                        loop_time = prev_visited[node_index]  # Get the count of previous visits
                        loop_cost = prev_cost + loop_penalty * np.exp(loop_time)  # Apply exponential penalty
                        new_visited = prev_visited.copy()
                        new_visited[node_index] = loop_time + 1  # Increment visit count
                    else:
                        loop_cost = prev_cost
                        new_visited = prev_visited.copy()
                        new_visited[node_index] = 1  # Initialize visit count for the new node

                    motion_low = node['motion_low']    # Shape: [C]
                    motion_high = node['motion_high']  # Shape: [C]
                    if search_mode == "both":
                        cost_increment = 2 - (np.dot(audio_low_np[t], motion_low.T) + np.dot(audio_high_np[t], motion_high.T))
                    elif search_mode == "high_level":
                        cost_increment = 1 - np.dot(audio_high_np[t], motion_high.T)
                    elif search_mode == "low_level":
                        cost_increment = 1 - np.dot(audio_low_np[t], motion_low.T)

                    # Check if the edge is "is_continue"
                    edge_id = edge.index
                    is_continue = graph.es[edge_id]['is_continue']
                    if not is_continue:
                        non_continue_count = prev_non_continue_count + 1  # Increment the count of non-continue edges
                    else:
                        non_continue_count = prev_non_continue_count
                    # Apply a penalty proportional to the number of non-continuous edges used so far
                    continue_penalty_cost = continue_penalty * non_continue_count

                    total_cost = loop_cost + cost_increment + continue_penalty_cost
                    if total_cost < min_cost_t:
                        min_cost_t = total_cost
                        best_predecessor = prev_node_index
                        best_visited = new_visited
                        best_non_continue_count = non_continue_count

            if best_predecessor is not None:
                min_cost[t][node_index] = (min_cost_t, best_predecessor, best_non_continue_count)
                visited_nodes[t][node_index] = best_visited  # Store the new visit count dictionary

    # Find the node with the minimal cost at the last time step
    final_min_cost = float('inf')
    final_node_index = None
    for node_index, (cost, _, _) in min_cost[T-1].items():
        if cost < final_min_cost:
            final_min_cost = cost
            final_node_index = node_index

    if final_node_index is None:
        print("No valid path found.")
        return [], []

    # Backtrack to reconstruct the optimal path
    optimal_path_indices = []
    current_node_index = final_node_index
    for t in range(T-1, -1, -1):
        optimal_path_indices.append(current_node_index)
        _, predecessor, _ = min_cost[t][current_node_index]
        current_node_index = predecessor if predecessor is not None else None
    optimal_path_indices = optimal_path_indices[::-1]  # Reverse to get correct order
    optimal_path = [graph.vs[idx] for idx in optimal_path_indices]

    # Extract continuity information
    is_continue = []
    for i in range(len(optimal_path) - 1):
        edge_id = graph.get_eid(optimal_path[i].index, optimal_path[i + 1].index)
        is_cont = graph.es[edge_id]['is_continue']
        is_continue.append(is_cont)

    print("Optimal Cost: ", final_min_cost, "Path: ", optimal_path_indices)
    return [optimal_path], [is_continue]


# from torch.cuda.amp import autocast, GradScaler
# from torch.nn.utils import clip_grad_norm_
# # Initialize GradScaler
# scaler = GradScaler()

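# train_val_fn runs one training or validation step: it feeds the cached motion and
# audio representations through the joint embedding model, sums the low-level (local
# InfoNCE) and high-level (clip-level InfoNCE) losses, and reports the corresponding
# retrieval accuracies alongside them.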
def train_val_fn(batch, model, device, mode="train", optimizer=None, lr_scheduler=None, max_grad_norm=1.0, **kwargs):
    if mode == "train":
        model.train()
        torch.set_grad_enabled(True)
        optimizer.zero_grad()
    else:
        model.eval()
        torch.set_grad_enabled(False)

    cached_rep15d = batch["cached_rep15d"].to(device)
    cached_audio_low = batch["cached_audio_low"].to(device)
    cached_audio_high = batch["cached_audio_high"].to(device)
    bert_time_aligned = batch["bert_time_aligned"].to(device)
    cached_audio_high = torch.cat([cached_audio_high, bert_time_aligned], dim=-1)
    audio_tensor = batch["audio_tensor"].to(device)

    # with autocast():  # Mixed precision context
    model_out = model(cached_rep15d=cached_rep15d, cached_audio_low=cached_audio_low, cached_audio_high=cached_audio_high, in_audio=audio_tensor)
    audio_lower = model_out["audio_low"]
    motion_lower = model_out["motion_low"]
    audio_higher_cls = model_out["audio_cls"]
    motion_higher_cls = model_out["motion_cls"]
    high_loss = model_out["high_level_loss"]
    low_infonce, low_acc = model_out["low_level_loss"]

    loss_dict = {
        "low_infonce": low_infonce,
        "high_infonce": high_loss,
    }
    loss = sum(loss_dict.values())
    loss_dict["loss"] = loss
    loss_dict["low_acc"] = low_acc
    loss_dict["acc"] = compute_average_precision(audio_higher_cls, motion_higher_cls)

    if mode == "train":
        # Use GradScaler for backward pass
        # scaler.scale(loss).backward()
        # Clip gradients to the maximum norm
        # scaler.unscale_(optimizer)  # Unscale gradients before clipping
        # clip_grad_norm_(model.parameters(), max_grad_norm)
        # Step the optimizer
        # scaler.step(optimizer)
        # scaler.update()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
    return loss_dict

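
# test_fn rebuilds the retrieval graph with the current model: it re-encodes every
# node's motion into low- and high-level features, prunes the graph, searches a
# motion path for each test audio clip, renders the retrieved motion, and scores the
# top-1 retrievals with FGD, L1 diversity, beat consistency (BC), and SRGR.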
def test_fn(model, device, smplx_model, iteration, fgd_fn, srgr_fn, bc_fn, l1div_fn, candidate_json_path, test_path, cfg, **kwargs):
    torch.set_grad_enabled(False)
    pool_path = "./datasets/oliver_test/show-oliver-test.pkl"
    graph = igraph.Graph.Read_Pickle(fname=pool_path)
    save_dir = os.path.join(test_path, f"retrieved_motions_{iteration}")
    os.makedirs(save_dir, exist_ok=True)

    actual_model = model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model
    actual_model.eval()

    with open(candidate_json_path, 'r') as f:
        candidate_data = json.load(f)

    all_motions = {}
    for i, node in enumerate(graph.vs):
        if all_motions.get(node["name"]) is None:
            all_motions[node["name"]] = [node["axis_angle"].reshape(-1)]
        else:
            all_motions[node["name"]].append(node["axis_angle"].reshape(-1))
    for k, v in all_motions.items():
        all_motions[k] = np.stack(v)  # T, J*3

    window_size = cfg.data.pose_length
    motion_high_all = []
    motion_low_all = []
    for k, v in all_motions.items():
        motion_tensor = torch.from_numpy(v).float().to(device).unsqueeze(0)
        _, t, _ = motion_tensor.shape
        num_chunks = t // window_size
        motion_high_list = []
        motion_low_list = []

        for i in range(num_chunks):
            start_idx = i * window_size
            end_idx = start_idx + window_size
            motion_slice = motion_tensor[:, start_idx:end_idx, :]
            motion_features = actual_model.get_motion_features(motion_slice)
            motion_high = motion_features["motion_high_weight"].cpu().numpy()
            motion_low = motion_features["motion_low"].cpu().numpy()
            motion_high_list.append(motion_high[0])
            motion_low_list.append(motion_low[0])

        remain_length = t % window_size
        if remain_length > 0:
            start_idx = t - window_size
            motion_slice = motion_tensor[:, start_idx:, :]
            motion_features = actual_model.get_motion_features(motion_slice)
            motion_high = motion_features["motion_high_weight"].cpu().numpy()
            motion_low = motion_features["motion_low"].cpu().numpy()
            motion_high_list.append(motion_high[0][-remain_length:])
            motion_low_list.append(motion_low[0][-remain_length:])

        motion_high_all.append(np.concatenate(motion_high_list, axis=0))
        motion_low_all.append(np.concatenate(motion_low_list, axis=0))

    motion_high_all = np.concatenate(motion_high_all, axis=0)
    motion_low_all = np.concatenate(motion_low_all, axis=0)
    # print(motion_high_all.shape, motion_low_all.shape)
    motion_low_all = motion_low_all / np.linalg.norm(motion_low_all, axis=1, keepdims=True)
    motion_high_all = motion_high_all / np.linalg.norm(motion_high_all, axis=1, keepdims=True)
    assert motion_high_all.shape[0] == len(graph.vs)
    assert motion_low_all.shape[0] == len(graph.vs)

    for i, node in enumerate(graph.vs):
        node["motion_high"] = motion_high_all[i]
        node["motion_low"] = motion_low_all[i]
    graph = graph_pruning(graph)

    for idx, pair in enumerate(tqdm(candidate_data, desc="Testing")):
        gt_motion = np.load(pair["motion_path"] + ".npz", allow_pickle=True)["poses"]
        target_length = gt_motion.shape[0]
        audio_path = pair["audio_path"] + ".wav"
        audio_waveform, sr = librosa.load(audio_path)
        audio_waveform = librosa.resample(audio_waveform, orig_sr=sr, target_sr=cfg.data.audio_sr)
        audio_tensor = torch.from_numpy(audio_waveform).float().to(device).unsqueeze(0)

        window_size = int(cfg.data.audio_sr * (cfg.data.pose_length / 30))
        _, t = audio_tensor.shape
        num_chunks = t // window_size
        audio_low_list = []
        audio_high_list = []

        for i in range(num_chunks):
            start_idx = i * window_size
            end_idx = start_idx + window_size
            # print(start_idx, end_idx, window_size)
            audio_slice = audio_tensor[:, start_idx:end_idx]
            model_out_candidates = actual_model.get_audio_features(audio_slice)
            audio_low = model_out_candidates["audio_low"]
            audio_high = model_out_candidates["audio_high_weight"]
            audio_low = F.normalize(audio_low, dim=2)[0].cpu().numpy()
            audio_high = F.normalize(audio_high, dim=2)[0].cpu().numpy()
            audio_low_list.append(audio_low)
            audio_high_list.append(audio_high)
            # print(audio_low.shape, audio_high.shape)

        remain_length = t % window_size
        if remain_length > 0:
            start_idx = t - window_size
            audio_slice = audio_tensor[:, start_idx:]
            model_out_candidates = actual_model.get_audio_features(audio_slice)
            audio_low = model_out_candidates["audio_low"]
            audio_high = model_out_candidates["audio_high_weight"]
            gap = target_length - np.concatenate(audio_low_list, axis=0).shape[0]
            audio_low = F.normalize(audio_low, dim=2)[0][-gap:].cpu().numpy()
            audio_high = F.normalize(audio_high, dim=2)[0][-gap:].cpu().numpy()
            # print(audio_low.shape, audio_high.shape)
            audio_low_list.append(audio_low)
            audio_high_list.append(audio_high)
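
        # Stitch the per-window features back into full-length [T, C] sequences and
        # search the motion graph for the path that best matches this audio clip.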
        audio_low_all = np.concatenate(audio_low_list, axis=0)
        audio_high_all = np.concatenate(audio_high_list, axis=0)
        # search the path with audio low features [T, c] and audio high features [T, c]
        path_list, is_continue_list = search_path(graph, audio_low_all, audio_high_all, top_k=1, search_mode="high_level")

        res_motion = []
        counter = 0
        for path, is_continue in zip(path_list, is_continue_list):
            res_motion_current = path_visualization(
                graph, path, is_continue, os.path.join(save_dir, f"audio_{idx}_retri_{counter}.mp4"),
                audio_path=audio_path, return_motion=True, verbose_continue=True
            )
            res_motion.append(res_motion_current)
            np.savez(os.path.join(save_dir, f"audio_{idx}_retri_{counter}.npz"), motion=res_motion_current)
            counter += 1

    metrics = {}
    counts = {"top1": 0, "top3": 0, "top10": 0}
    fgd_fn.reset()
    l1div_fn.reset()
    bc_fn.reset()
    srgr_fn.reset()

    for idx, pair in enumerate(tqdm(candidate_data, desc="Evaluating")):
        gt_motion = np.load(pair["motion_path"] + ".npz", allow_pickle=True)["poses"]
        audio_path = pair["audio_path"] + ".wav"
        gt_motion_tensor = torch.from_numpy(gt_motion).float().to(device).unsqueeze(0)
        bs, n, _ = gt_motion_tensor.size()
        audio_waveform, sr = librosa.load(audio_path, sr=None)
        audio_waveform = librosa.resample(audio_waveform, orig_sr=sr, target_sr=cfg.data.audio_sr)
        audio_tensor = torch.from_numpy(audio_waveform).float().to(device).unsqueeze(0)

        top1_path = os.path.join(save_dir, f"audio_{idx}_retri_0.npz")
        top1_motion = np.load(top1_path, allow_pickle=True)["motion"]  # T, 165
        top1_motion_tensor = torch.from_numpy(top1_motion).float().to(device).unsqueeze(0)  # Add bs -> 1, T, 165

        gt_vertex = smplx_model(
            betas=torch.zeros(bs*n, 300).to(device),
            transl=torch.zeros(bs*n, 3).to(device),
            expression=torch.zeros(bs*n, 100).to(device),
            jaw_pose=torch.zeros(bs*n, 3).to(device),
            global_orient=torch.zeros(bs*n, 3).to(device),
            body_pose=gt_motion_tensor.reshape(bs*n, 55*3)[:, 3:21*3+3],
            left_hand_pose=gt_motion_tensor.reshape(bs*n, 55*3)[:, 25*3:40*3],
            right_hand_pose=gt_motion_tensor.reshape(bs*n, 55*3)[:, 40*3:55*3],
            return_joints=True,
            leye_pose=torch.zeros(bs*n, 3).to(device),
            reye_pose=torch.zeros(bs*n, 3).to(device),
        )["joints"].detach().cpu().numpy().reshape(bs, n, 127*3)[0, :, :55*3]
        top1_vertex = smplx_model(
            betas=torch.zeros(bs*n, 300).to(device),
            transl=torch.zeros(bs*n, 3).to(device),
            expression=torch.zeros(bs*n, 100).to(device),
            jaw_pose=torch.zeros(bs*n, 3).to(device),
            global_orient=torch.zeros(bs*n, 3).to(device),
            body_pose=top1_motion_tensor.reshape(bs*n, 55*3)[:, 3:21*3+3],
            left_hand_pose=top1_motion_tensor.reshape(bs*n, 55*3)[:, 25*3:40*3],
            right_hand_pose=top1_motion_tensor.reshape(bs*n, 55*3)[:, 40*3:55*3],
            return_joints=True,
            leye_pose=torch.zeros(bs*n, 3).to(device),
            reye_pose=torch.zeros(bs*n, 3).to(device),
        )["joints"].detach().cpu().numpy().reshape(bs, n, 127*3)[0, :, :55*3]

        l1div_fn.run(top1_vertex)
        # print(audio_waveform.shape, top1_vertex.shape)
        onset_bt = bc_fn.load_audio(audio_waveform, t_start=None, without_file=True, sr_audio=cfg.data.audio_sr)
        beat_vel = bc_fn.load_pose(top1_vertex, 0, n, pose_fps=30, without_file=True)
        # print(n)
        # print(onset_bt)
        # print(beat_vel)
        bc_fn.calculate_align(onset_bt, beat_vel, 30)
        srgr_fn.run(gt_vertex, top1_vertex)
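
        # Convert axis-angle poses to a 6D rotation representation and trim to a
        # multiple of 32 frames before accumulating FGD statistics.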
        gt_motion_tensor = rc.axis_angle_to_matrix(gt_motion_tensor.reshape(1, n, 55, 3))
        gt_motion_tensor = rc.matrix_to_rotation_6d(gt_motion_tensor).reshape(1, n, 55*6)
        top1_motion_tensor = rc.axis_angle_to_matrix(top1_motion_tensor.reshape(1, n, 55, 3))
        top1_motion_tensor = rc.matrix_to_rotation_6d(top1_motion_tensor).reshape(1, n, 55*6)
        remain = n % 32
        if remain != 0:
            gt_motion_tensor = gt_motion_tensor[:, :n-remain]
            top1_motion_tensor = top1_motion_tensor[:, :n-remain]
        # print(gt_motion_tensor.shape, top1_motion_tensor.shape)
        fgd_fn.update(gt_motion_tensor, top1_motion_tensor)

    metrics["fgd_top1"] = fgd_fn.compute()
    metrics["l1_top1"] = l1div_fn.avg()
    metrics["bc_top1"] = bc_fn.avg()
    metrics["srgr_top1"] = srgr_fn.avg()

    print(f"Test Metrics at Iteration {iteration}:")
    for key, value in metrics.items():
        print(f"{key}: {value:.6f}")
    return metrics


def compute_average_precision(feature1, feature2):
    # Normalize the features
    feature1 = F.normalize(feature1, dim=1)
    feature2 = F.normalize(feature2, dim=1)
    # Compute the similarity matrix
    similarity_matrix = torch.matmul(feature1, feature2.t())
    # Get the top-1 predicted indices for each feature in feature1
    top1_indices = torch.argmax(similarity_matrix, dim=1)
    # Generate ground truth labels (diagonal indices)
    batch_size = feature1.size(0)
    ground_truth = torch.arange(batch_size, device=feature1.device)
    # Compute the accuracy (True if the top-1 index matches the ground truth)
    correct_predictions = (top1_indices == ground_truth).float()
    # Mean top-1 retrieval accuracy over the batch
    average_precision = correct_predictions.mean()
    return average_precision


class CosineSimilarityLoss(nn.Module):
    def __init__(self):
        super(CosineSimilarityLoss, self).__init__()
        self.cosine_similarity = nn.CosineSimilarity(dim=2)

    def forward(self, output1, output2):
        # Calculate cosine similarity
        cosine_sim = self.cosine_similarity(output1, output2)
        # Loss is 1 minus the average cosine similarity
        return 1 - cosine_sim.mean()


class InfoNCELossCross(nn.Module):
    def __init__(self, temperature=0.1):
        super(InfoNCELossCross, self).__init__()
        self.temperature = temperature
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, feature1, feature2):
        """
        Args:
            feature1: tensor of shape (batch_size, feature_dim)
            feature2: tensor of shape (batch_size, feature_dim)
        where each corresponding index in feature1 and feature2 is a positive pair,
        and all other combinations are negative pairs.
        """
        batch_size = feature1.size(0)
        # Normalize feature vectors
        feature1 = F.normalize(feature1, dim=1)
        feature2 = F.normalize(feature2, dim=1)
        # Compute similarity matrix between feature1 and feature2
        similarity_matrix = torch.matmul(feature1, feature2.t()) / self.temperature
        # Labels for each element in feature1 are the indices of their matching pairs in feature2
        labels = torch.arange(batch_size, device=feature1.device)
        # Cross entropy loss for each positive pair with all corresponding negatives
        loss = self.criterion(similarity_matrix, labels)
        return loss


class LocalContrastiveLoss(nn.Module):
    def __init__(self, temperature=0.1):
        super(LocalContrastiveLoss, self).__init__()
        self.temperature = temperature

    def forward(self, motion_feature, audio_feature, learned_temp=None):
        if learned_temp is not None:
            temperature = learned_temp
        else:
            temperature = self.temperature

        batch_size, T, _ = motion_feature.size()
        assert len(motion_feature.shape) == 3
        motion_feature = F.normalize(motion_feature, dim=2)
        audio_feature = F.normalize(audio_feature, dim=2)

        motion_to_audio_loss = 0
        audio_to_motion_loss = 0
        motion_to_audio_correct = 0
        audio_to_motion_correct = 0

        # First pass: motion to audio
        for t in range(T):
            motion_feature_t = motion_feature[:, t, :]  # (bs, c)
            # Positive pair range for motion
            start = max(0, t - 4)
            end = min(T, t + 4)
            positive_audio_feature = audio_feature[:, start:end, :]  # (bs, pos_range, c)
            # Negative pair range for motion
            left_end = start
            left_start = max(0, left_end - 4 * 3)
            right_start = end
            right_end = min(T, right_start + 4 * 3)
            negative_audio_feature = torch.cat(
                [audio_feature[:, left_start:left_end, :], audio_feature[:, right_start:right_end, :]], dim=1
            )  # (bs, neg_range, c)
            # Concatenate positive and negative samples
            combined_audio_feature = torch.cat([positive_audio_feature, negative_audio_feature], dim=1)  # (bs, pos_range + neg_range, c)
            # Compute similarity scores
            logits = torch.matmul(motion_feature_t.unsqueeze(1), combined_audio_feature.transpose(1, 2)) / temperature  # (bs, 1, pos_range + neg_range)
            logits = logits.squeeze(1)  # (bs, pos_range + neg_range)
            # Compute InfoNCE loss
            positive_scores = logits[:, :positive_audio_feature.size(1)]
            loss_t = -positive_scores.logsumexp(dim=1) + torch.logsumexp(logits, dim=1)
            motion_to_audio_loss += loss_t.mean()
            # Compute accuracy
            max_indices = torch.argmax(logits, dim=1)
            correct_mask = (max_indices < positive_audio_feature.size(1)).float()  # Check if indices are within the range of positive samples
            motion_to_audio_correct += correct_mask.sum()

        # Second pass: audio to motion
        for t in range(T):
            audio_feature_t = audio_feature[:, t, :]  # (bs, c)
            # Positive pair range for audio
            start = max(0, t - 4)
            end = min(T, t + 4)
            positive_motion_feature = motion_feature[:, start:end, :]  # (bs, pos_range, c)
            # Negative pair range for audio
            left_end = start
            left_start = max(0, left_end - 4 * 3)
            right_start = end
            right_end = min(T, right_start + 4 * 3)
            negative_motion_feature = torch.cat(
                [motion_feature[:, left_start:left_end, :], motion_feature[:, right_start:right_end, :]], dim=1
            )  # (bs, neg_range, c)
            # Concatenate positive and negative samples
            combined_motion_feature = torch.cat([positive_motion_feature, negative_motion_feature], dim=1)  # (bs, pos_range + neg_range, c)
            # Compute similarity scores
            logits = torch.matmul(audio_feature_t.unsqueeze(1), combined_motion_feature.transpose(1, 2)) / temperature  # (bs, 1, pos_range + neg_range)
            logits = logits.squeeze(1)  # (bs, pos_range + neg_range)
            # Compute InfoNCE loss
            positive_scores = logits[:, :positive_motion_feature.size(1)]
            loss_t = -positive_scores.logsumexp(dim=1) + torch.logsumexp(logits, dim=1)
            audio_to_motion_loss += loss_t.mean()
            # Compute accuracy
            max_indices = torch.argmax(logits, dim=1)
            correct_mask = (max_indices < positive_motion_feature.size(1)).float()  # Check if indices are within the range of positive samples
            audio_to_motion_correct += correct_mask.sum()

        # Average the two losses
        final_loss = (motion_to_audio_loss + audio_to_motion_loss) / (2 * T)
        # Compute final accuracy
        total_correct = (motion_to_audio_correct + audio_to_motion_correct) / (2 * T * batch_size)
        return final_loss, total_correct


class InfoNCELoss(nn.Module):
    def __init__(self, temperature=0.1):
        super(InfoNCELoss, self).__init__()
        self.temperature = temperature

    def forward(self, feature1, feature2, learned_temp=None):
        batch_size = feature1.size(0)
        assert len(feature1.shape) == 2
        if learned_temp is not None:
            temperature = learned_temp
        else:
            temperature = self.temperature
        # Normalize feature vectors
        feature1 = F.normalize(feature1, dim=1)
        feature2 = F.normalize(feature2, dim=1)
        # Compute similarity matrix between feature1 and feature2
        similarity_matrix = torch.matmul(feature1, feature2.t()) / temperature
        # Extract positive similarities (diagonal elements)
        positive_similarities = torch.diag(similarity_matrix)
        # Compute the denominator using logsumexp for numerical stability
        denominator = torch.logsumexp(similarity_matrix, dim=1)
        # Compute the InfoNCE loss
        loss = -(positive_similarities - denominator).mean()
        return loss


def main(cfg):
    if "LOCAL_RANK" in os.environ:
        local_rank = int(os.environ["LOCAL_RANK"])
    else:
        local_rank = 0
    torch.cuda.set_device(local_rank)
    device = torch.device("cuda", local_rank)
    torch.distributed.init_process_group(backend="nccl")

    seed_everything(cfg.seed)
    experiment_ckpt_dir = experiment_log_dir = os.path.join(cfg.output_dir, cfg.exp_name)

    smplx_model = smplx.create(
        "./emage/smplx_models/",
        model_type='smplx',
        gender='NEUTRAL_2020',
        use_face_contour=False,
        num_betas=300,
        num_expression_coeffs=100,
        ext='npz',
        use_pca=False,
    ).to(device).eval()

    model = init_class(cfg.model.name_pyfile, cfg.model.class_name, cfg).cuda()
    for param in model.parameters():
        param.requires_grad = True
    # freeze wav2vec2
    for param in model.audio_encoder.parameters():
        param.requires_grad = False
    model.smplx_model = smplx_model
    model.get_motion_reps = get_motion_reps_tensor
    model.high_level_loss_fn = InfoNCELoss()
    model.low_level_loss_fn = LocalContrastiveLoss()
    model = DDP(
        model,
        device_ids=[local_rank],
        output_device=local_rank,
        find_unused_parameters=True,
    )
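
    # Only parameters that still require gradients (everything except the frozen
    # wav2vec2 audio encoder) are handed to the optimizer; AdamW is swapped for the
    # bitsandbytes 8-bit variant when cfg.solver.use_8bit_adam is set.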
    if cfg.solver.use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError(
                "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
            )
        optimizer_cls = bnb.optim.AdamW8bit
        print("using 8 bit")
    else:
        optimizer_cls = torch.optim.AdamW

    optimizer = optimizer_cls(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=cfg.solver.learning_rate,
        betas=(cfg.solver.adam_beta1, cfg.solver.adam_beta2),
        weight_decay=cfg.solver.adam_weight_decay,
        eps=cfg.solver.adam_epsilon,
    )
    lr_scheduler = get_scheduler(
        cfg.solver.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=cfg.solver.lr_warmup_steps * cfg.solver.gradient_accumulation_steps,
        num_training_steps=cfg.solver.max_train_steps * cfg.solver.gradient_accumulation_steps,
    )

    loss_cosine = CosineSimilarityLoss().to(device)
    loss_mse = nn.MSELoss().to(device)
    loss_l1 = nn.L1Loss().to(device)
    loss_infonce = InfoNCELossCross().to(device)
    loss_fn_dict = {
        "loss_cosine": loss_cosine,
        "loss_mse": loss_mse,
        "loss_l1": loss_l1,
        "loss_infonce": loss_infonce,
    }

    fgd_fn = emage.mertic.FGD(download_path="./emage/")
    srgr_fn = emage.mertic.SRGR(threshold=0.3, joints=55, joint_dim=3)
    bc_fn = emage.mertic.BC(download_path="./emage/", sigma=0.5, order=7)
    l1div_fn = emage.mertic.L1div()

    train_dataset = init_class(cfg.data.name_pyfile, cfg.data.class_name, cfg, split='train')
    test_dataset = init_class(cfg.data.name_pyfile, cfg.data.class_name, cfg, split='test')
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    train_loader = DataLoader(train_dataset, batch_size=cfg.data.train_bs, sampler=train_sampler, drop_last=True, num_workers=4)
    test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset)
    test_loader = DataLoader(test_dataset, batch_size=256, sampler=test_sampler, drop_last=False, num_workers=4)

    if local_rank == 0:
        run_time = datetime.now().strftime("%Y%m%d-%H%M")
        wandb.init(
            project=cfg.wandb_project,
            name=cfg.exp_name + "_" + run_time,
            entity=cfg.wandb_entity,
            dir=cfg.wandb_log_dir,
            config=OmegaConf.to_container(cfg)  # Pass config directly during initialization
        )
    else:
        writer = None

    num_epochs = cfg.solver.max_train_steps // len(train_loader) + 1
    iteration = 0
    val_best = {}
    test_best = {}

    # checkpoint_path = "/content/drive/MyDrive/005_Weights/baseline_high_env0/checkpoint_3800/ckpt.pth"
    # checkpoint = torch.load(checkpoint_path)
    # state_dict = checkpoint['model_state_dict']
    # # new_state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
    # model.load_state_dict(state_dict, strict=False)
    # optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    # lr_scheduler.load_state_dict(checkpoint['lr_scheduler_state_dict'])
    # iteration = checkpoint["iteration"]

    for epoch in range(num_epochs):
        train_sampler.set_epoch(epoch)
        for i, batch in enumerate(train_loader):
            loss_dict = train_val_fn(
                batch, model, device, mode="train",
                optimizer=optimizer, lr_scheduler=lr_scheduler, loss_fn_dict=loss_fn_dict
            )
            if local_rank == 0 and iteration % cfg.log_period == 0:
                for key, value in loss_dict.items():
                    # writer.add_scalar(f"train/{key}", value, iteration)
                    wandb.log({f"train/{key}": value}, step=iteration)
                loss_message = ", ".join([f"{k}: {v:.6f}" for k, v in loss_dict.items()])
                print(f"Epoch {epoch} [{i}/{len(train_loader)}] - {loss_message}")
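
            # Periodic quick validation on rank 0: average losses over the first 10
            # test batches, track the best value per metric, checkpoint whenever an
            # accuracy metric improves, and keep only the three most recent
            # iteration checkpoints.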
            if local_rank == 0 and iteration % cfg.validation.val_loss_steps == 0:
                val_loss_dict = {}
                val_batches = 0
                for batch in tqdm(test_loader):
                    loss_dict = train_val_fn(
                        batch, model, device, mode="val",
                        optimizer=optimizer, lr_scheduler=lr_scheduler, loss_fn_dict=loss_fn_dict
                    )
                    for k, v in loss_dict.items():
                        if k not in val_loss_dict:
                            val_loss_dict[k] = 0
                        val_loss_dict[k] += v.item()  # Convert to float for accumulation
                    val_batches += 1
                    if val_batches == 10:
                        break

                val_loss_mean_dict = {k: v / val_batches for k, v in val_loss_dict.items()}
                for k, v in val_loss_mean_dict.items():
                    if k not in val_best or v > val_best[k]["value"]:
                        val_best[k] = {"value": v, "iteration": iteration}
                        if "acc" in k:
                            checkpoint_path = os.path.join(experiment_ckpt_dir, f"ckpt_{k}")
                            os.makedirs(checkpoint_path, exist_ok=True)
                            torch.save({
                                'iteration': iteration,
                                'model_state_dict': model.state_dict(),
                                'optimizer_state_dict': optimizer.state_dict(),
                                'lr_scheduler_state_dict': lr_scheduler.state_dict(),
                            }, os.path.join(checkpoint_path, "ckpt.pth"))
                    print(f"Val [{iteration}] - {k}: {v:.6f} (best: {val_best[k]['value']:.6f} at {val_best[k]['iteration']})")
                    # writer.add_scalar(f"val/{k}", v, iteration)
                    wandb.log({f"val/{k}": v}, step=iteration)

                checkpoint_path = os.path.join(experiment_ckpt_dir, f"checkpoint_{iteration}")
                os.makedirs(checkpoint_path, exist_ok=True)
                torch.save({
                    'iteration': iteration,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'lr_scheduler_state_dict': lr_scheduler.state_dict(),
                }, os.path.join(checkpoint_path, "ckpt.pth"))
                checkpoints = [
                    d for d in os.listdir(experiment_ckpt_dir)
                    if os.path.isdir(os.path.join(experiment_ckpt_dir, d)) and d.startswith("checkpoint_")
                ]
                checkpoints.sort(key=lambda x: int(x.split("_")[1]))
                if len(checkpoints) > 3:
                    for ckpt_to_delete in checkpoints[:-3]:
                        shutil.rmtree(os.path.join(experiment_ckpt_dir, ckpt_to_delete))

            # if local_rank == 0 and iteration % cfg.validation.validation_steps == 0:
            #     test_path = os.path.join(experiment_ckpt_dir, f"test_{iteration}")
            #     os.makedirs(test_path, exist_ok=True)
            #     test_mertic_dict = test_fn(model, device, smplx_model, iteration, fgd_fn, srgr_fn, bc_fn, l1div_fn, cfg.data.test_meta_paths, test_path, cfg)
            #     for k, v in test_mertic_dict.items():
            #         if k not in test_best or v < test_best[k]["value"]:
            #             test_best[k] = {"value": v, "iteration": iteration}
            #         print(f"Test [{iteration}] - {k}: {v:.6f} (best: {test_best[k]['value']:.6f} at {test_best[k]['iteration']})")
            #         # writer.add_scalar(f"test/{k}", v, iteration)
            #         wandb.log({f"test/{k}": v}, step=iteration)
            #     video_for_log = []
            #     video_res_path = os.path.join(test_path, f"retrieved_motions_{iteration}")
            #     for mp4_file in os.listdir(video_res_path):
            #         if mp4_file.endswith(".mp4"):
            #             # print(mp4_file)
            #             file_path = os.path.join(video_res_path, mp4_file)
            #             log_video = wandb.Video(file_path, caption=f"{iteration:06d}-{mp4_file}", format="mp4")
            #             video_for_log.append(log_video)
            #     wandb.log(
            #         {"test/videos": video_for_log},
            #         step=iteration
            #     )
            #     visualize_fn(test_path)

            iteration += 1

    if local_rank == 0:
        wandb.finish()
    torch.distributed.destroy_process_group()


def init_class(module_name, class_name, config, **kwargs):
    module = importlib.import_module(module_name)
    model_class = getattr(module, class_name)
    instance = model_class(config, **kwargs)
    return instance


def seed_everything(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

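
# visualize_fn is only referenced from the commented-out test block above: for each
# test clip it looks up the top-1 retrieved video from the saved top-10 indices and
# re-muxes it with the corresponding audio track via moviepy.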
def visualize_fn(test_path, **kwargs):
    with open(test_path, 'r') as f:
        test_json = json.load(f)  # load top10_indices from json
    selected_video_path_list = []
    # load video list from json
    with open(test_path, 'r') as f:
        video_list = json.load(f)["video_candidates"]
    for idx, data in enumerate(test_json.items()):
        top10_indices_path = os.path.join(test_path, f"audio_{idx}_retri_top10.json")
        with open(top10_indices_path, 'r') as f:
            top10_indices = json.load(f)["top10_indices"]
        selected_video_path_list.append(video_list[top10_indices[0]])
        # moviepy load and add audio
        video = VideoFileClip(video_list[top10_indices[0]])
        audio = AudioFileClip(data["audio_path"])
        video = video.set_audio(audio)
        video.write_videofile(f"audio_{idx}_retri_top1.mp4")
        video.close()


def prepare_all():
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default="./configs/train/stage2.yaml")
    parser.add_argument("--debug", action="store_true", help="Enable debugging mode")
    parser.add_argument('overrides', nargs=argparse.REMAINDER)
    args = parser.parse_args()

    if args.config.endswith(".yaml"):
        config = OmegaConf.load(args.config)
        # config.wandb_project = args.config.split("-")[1]
        config.exp_name = args.config.split("/")[-1][:-5]
    else:
        raise ValueError("Unsupported config file format. Only .yaml files are allowed.")

    if args.debug:
        config.wandb_project = "debug"
        config.exp_name = "debug"
        config.solver.max_train_steps = 4

    if args.overrides:
        for arg in args.overrides:
            key, value = arg.split('=')
            try:
                value = eval(value)
            except:
                pass
            if key in config:
                config[key] = value
            else:
                raise ValueError(f"Key {key} not found in config.")

    os.environ["WANDB_API_KEY"] = config.wandb_key
    save_dir = os.path.join(config.output_dir, config.exp_name)
    os.makedirs(save_dir, exist_ok=True)
    os.makedirs(os.path.join(save_dir, 'sanity_check'), exist_ok=True)
    config_path = os.path.join(save_dir, 'sanity_check', f'{config.exp_name}.yaml')
    with open(config_path, 'w') as f:
        OmegaConf.save(config, f)

    current_dir = os.path.dirname(os.path.abspath(__file__))
    sanity_check_dir = os.path.join(save_dir, 'sanity_check')
    for root, dirs, files in os.walk(current_dir):
        for file in files:
            if file.endswith(".py"):
                full_file_path = os.path.join(root, file)
                relative_path = os.path.relpath(full_file_path, current_dir)
                dest_path = os.path.join(sanity_check_dir, relative_path)
                os.makedirs(os.path.dirname(dest_path), exist_ok=True)
                shutil.copy(full_file_path, dest_path)

    return config


if __name__ == "__main__":
    config = prepare_all()
    main(config)