import os
import sys

current_file = os.path.abspath(__file__)
current_dir = os.path.dirname(current_file)

songformer_path = os.path.join(current_dir, "src", "SongFormer")
if os.path.exists(songformer_path):
    os.chdir(songformer_path)
else:
    print(f"The target working directory does not exist: {songformer_path}")

working_dir = os.getcwd()

third_party_path = os.path.join(current_dir, "src", "third_party")
if os.path.exists(third_party_path):
    sys.path.insert(0, third_party_path)
sys.path.insert(0, working_dir)

musicfm_paths = [
    os.path.join(current_dir, "src"),
    os.path.join(current_dir, "third_party"),
    os.path.join(current_dir, "src", "SongFormer"),
]
for path in musicfm_paths:
    if os.path.exists(path):
        sys.path.insert(0, path)

# monkey patch to fix issues in msaf
import scipy
import numpy as np

scipy.inf = np.inf

import gradio as gr
import torch
import librosa
import json
import math
import importlib
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from pathlib import Path
from argparse import Namespace
from omegaconf import OmegaConf
from ema_pytorch import EMA
from muq import MuQ
from musicfm.model.musicfm_25hz import MusicFM25Hz
from postprocessing.functional import postprocess_functional_structure
from dataset.label2id import DATASET_ID_ALLOWED_LABEL_IDS, DATASET_LABEL_TO_DATASET_ID
from utils.fetch_pretrained import download_all
import spaces

# Constants
MUSICFM_HOME_PATH = os.path.join("ckpts", "MusicFM")
BEFORE_DOWNSAMPLING_FRAME_RATES = 25
AFTER_DOWNSAMPLING_FRAME_RATES = 8.333
DATASET_LABEL = "SongForm-HX-8Class"
DATASET_IDS = [5]
TIME_DUR = 420
INPUT_SAMPLING_RATE = 24000

# Global model variables
muq_model = None
musicfm_model = None
msa_model = None
device = None
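
# Frame-rate bookkeeping used by process_audio below (a worked-out reading of the
# constants above; the concrete numbers are illustrative arithmetic, not values
# taken from elsewhere): audio is analyzed in TIME_DUR = 420 s windows at
# INPUT_SAMPLING_RATE = 24 kHz, i.e. 420 * 24000 = 10,080,000 samples per window.
# The MuQ/MusicFM front-ends emit about BEFORE_DOWNSAMPLING_FRAME_RATES = 25
# feature frames per second, and the MSA model downsamples roughly 3x to
# AFTER_DOWNSAMPLING_FRAME_RATES ≈ 8.333 frames per second, so one window maps to
# about ceil(420 * 8.333) = 3500 output frames.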


def load_checkpoint(checkpoint_path, device=None):
    """Load checkpoint from path"""
    if device is None:
        device = "cpu"
    if checkpoint_path.endswith(".pt"):
        checkpoint = torch.load(checkpoint_path, map_location=device)
    elif checkpoint_path.endswith(".safetensors"):
        from safetensors.torch import load_file

        checkpoint = {"model_ema": load_file(checkpoint_path, device=device)}
    else:
        raise ValueError("Unsupported checkpoint format. Use .pt or .safetensors")
    return checkpoint


def initialize_models(model_name: str, checkpoint: str, config_path: str):
    """Initialize all models"""
    global muq_model, musicfm_model, msa_model, device

    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load MuQ
    muq_model = MuQ.from_pretrained("OpenMuQ/MuQ-large-msd-iter")
    muq_model = muq_model.to(device).eval()

    # Load MusicFM
    musicfm_model = MusicFM25Hz(
        is_flash=False,
        stat_path=os.path.join(MUSICFM_HOME_PATH, "msd_stats.json"),
        model_path=os.path.join(MUSICFM_HOME_PATH, "pretrained_msd.pt"),
    )
    musicfm_model = musicfm_model.to(device).eval()

    # Load MSA model
    module = importlib.import_module("models." + str(model_name))
    Model = getattr(module, "Model")
    hp = OmegaConf.load(os.path.join("configs", config_path))
    msa_model = Model(hp)

    ckpt = load_checkpoint(checkpoint_path=os.path.join("ckpts", checkpoint))
    if ckpt.get("model_ema", None) is not None:
        model_ema = EMA(msa_model, include_online_model=False)
        model_ema.load_state_dict(ckpt["model_ema"])
        msa_model.load_state_dict(model_ema.ema_model.state_dict())
    else:
        msa_model.load_state_dict(ckpt["model"])

    msa_model.to(device).eval()
    return hp


@spaces.GPU()
def process_audio(audio_path, win_size=420, hop_size=420, num_classes=128):
    """Process audio file and return structure analysis results"""
    global muq_model, musicfm_model, msa_model, device

    if muq_model is None:
        # Lazy fallback: initialize with the same defaults as __main__
        hp = initialize_models(
            model_name="SongFormer",
            checkpoint="SongFormer.safetensors",
            config_path="SongFormer.yaml",
        )
    else:
        hp = OmegaConf.load(os.path.join("configs", "SongFormer.yaml"))

    # Load audio
    wav, sr = librosa.load(audio_path, sr=INPUT_SAMPLING_RATE)
    audio = torch.tensor(wav).to(device)

    # Prepare output
    total_len = (
        (audio.shape[0] // INPUT_SAMPLING_RATE) // TIME_DUR * TIME_DUR
    ) + TIME_DUR
    total_frames = math.ceil(total_len * AFTER_DOWNSAMPLING_FRAME_RATES)

    logits = {
        "function_logits": np.zeros([total_frames, num_classes]),
        "boundary_logits": np.zeros([total_frames]),
    }
    logits_num = {
        "function_logits": np.zeros([total_frames, num_classes]),
        "boundary_logits": np.zeros([total_frames]),
    }

    # Prepare label masks
    dataset_id2label_mask = {}
    for key, allowed_ids in DATASET_ID_ALLOWED_LABEL_IDS.items():
        dataset_id2label_mask[key] = np.ones(num_classes, dtype=bool)
        dataset_id2label_mask[key][allowed_ids] = False

    lens = 0
    i = 0

    with torch.no_grad():
        while True:
            start_idx = i * INPUT_SAMPLING_RATE
            end_idx = min((i + win_size) * INPUT_SAMPLING_RATE, audio.shape[-1])

            if start_idx >= audio.shape[-1]:
                break
            if end_idx - start_idx <= 1024:
                # Too little audio left for a feature frame; stop instead of
                # `continue`, which would loop forever without advancing `i`.
                break

            audio_seg = audio[start_idx:end_idx]

            # Get embeddings
            muq_output = muq_model(audio_seg.unsqueeze(0), output_hidden_states=True)
            muq_embd_420s = muq_output["hidden_states"][10]
            del muq_output
            torch.cuda.empty_cache()

            _, musicfm_hidden_states = musicfm_model.get_predictions(
                audio_seg.unsqueeze(0)
            )
            musicfm_embd_420s = musicfm_hidden_states[10]
            del musicfm_hidden_states
            torch.cuda.empty_cache()

            # Process 30-second segments
            wraped_muq_embd_30s = []
            wraped_musicfm_embd_30s = []

            for idx_30s in range(i, i + hop_size, 30):
                start_idx_30s = idx_30s * INPUT_SAMPLING_RATE
                end_idx_30s = min(
                    (idx_30s + 30) * INPUT_SAMPLING_RATE,
                    audio.shape[-1],
                    (i + hop_size) * INPUT_SAMPLING_RATE,
                )
                if start_idx_30s >= audio.shape[-1]:
                    break
                if end_idx_30s - start_idx_30s <= 1024:
                    continue

                wraped_muq_embd_30s.append(
                    muq_model(
                        audio[start_idx_30s:end_idx_30s].unsqueeze(0),
                        output_hidden_states=True,
                    )["hidden_states"][10]
                )
                torch.cuda.empty_cache()

                wraped_musicfm_embd_30s.append(
                    musicfm_model.get_predictions(
                        audio[start_idx_30s:end_idx_30s].unsqueeze(0)
                    )[1][10]
                )
                torch.cuda.empty_cache()

            if wraped_muq_embd_30s:
                wraped_muq_embd_30s = torch.concatenate(wraped_muq_embd_30s, dim=1)
                wraped_musicfm_embd_30s = torch.concatenate(
                    wraped_musicfm_embd_30s, dim=1
                )

            all_embds = [
                wraped_musicfm_embd_30s,
                wraped_muq_embd_30s,
                musicfm_embd_420s,
                muq_embd_420s,
            ]

            # Align embedding lengths
            if len(all_embds) > 1:
                embd_lens = [x.shape[1] for x in all_embds]
                min_embd_len = min(embd_lens)
                for idx in range(len(all_embds)):
                    all_embds[idx] = all_embds[idx][:, :min_embd_len, :]

            embd = torch.concatenate(all_embds, axis=-1)
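
            # Descriptive note: `embd` now stacks four hidden-state sequences along
            # the feature dimension -- MusicFM and MuQ run on 30 s chunks, plus
            # MusicFM and MuQ run on the full window -- all truncated above to a
            # common frame count before concatenation.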

            # Inference
            dataset_ids = torch.Tensor(DATASET_IDS).to(device, dtype=torch.long)
            msa_info, chunk_logits = msa_model.infer(
                input_embeddings=embd,
                dataset_ids=dataset_ids,
                label_id_masks=torch.Tensor(
                    dataset_id2label_mask[
                        DATASET_LABEL_TO_DATASET_ID[DATASET_LABEL]
                    ]
                )
                .to(device, dtype=bool)
                .unsqueeze(0)
                .unsqueeze(0),
                with_logits=True,
            )

            # Accumulate logits
            start_frame = int(i * AFTER_DOWNSAMPLING_FRAME_RATES)
            end_frame = start_frame + min(
                math.ceil(hop_size * AFTER_DOWNSAMPLING_FRAME_RATES),
                chunk_logits["boundary_logits"][0].shape[0],
            )

            logits["function_logits"][start_frame:end_frame, :] += (
                chunk_logits["function_logits"][0].detach().cpu().numpy()
            )
            # Accumulate (rather than overwrite) so the averaging below stays
            # correct if windows ever overlap.
            logits["boundary_logits"][start_frame:end_frame] += (
                chunk_logits["boundary_logits"][0].detach().cpu().numpy()
            )
            logits_num["function_logits"][start_frame:end_frame, :] += 1
            logits_num["boundary_logits"][start_frame:end_frame] += 1

            lens += end_frame - start_frame
            i += hop_size

    # Average logits
    logits["function_logits"] /= np.maximum(logits_num["function_logits"], 1)
    logits["boundary_logits"] /= np.maximum(logits_num["boundary_logits"], 1)

    logits["function_logits"] = torch.from_numpy(
        logits["function_logits"][:lens]
    ).unsqueeze(0)
    logits["boundary_logits"] = torch.from_numpy(
        logits["boundary_logits"][:lens]
    ).unsqueeze(0)

    # Post-process
    msa_infer_output = postprocess_functional_structure(logits, hp)
    return logits, msa_infer_output


def format_as_segments(msa_output):
    """Format as list of segments"""
    segments = []
    for idx in range(len(msa_output) - 1):
        segments.append(
            {
                "start": str(round(msa_output[idx][0], 2)),
                "end": str(round(msa_output[idx + 1][0], 2)),
                "label": msa_output[idx][1],
            }
        )
    return segments


def format_as_msa(msa_output):
    """Format as MSA text: one '<time> <label>' line per entry"""
    lines = []
    for time, label in msa_output:
        lines.append(f"{time:.2f} {label}")
    return "\n".join(lines)


def format_as_json(segments):
    """Format as JSON"""
    return json.dumps(segments, indent=2, ensure_ascii=False)


def create_visualization(
    logits, msa_output, label_num=8, frame_rates=AFTER_DOWNSAMPLING_FRAME_RATES
):
    """Create visualization plot"""
    # Assume ID_TO_LABEL mapping exists
    try:
        from dataset.label2id import ID_TO_LABEL
    except ImportError:
        ID_TO_LABEL = {i: f"Class_{i}" for i in range(128)}

    function_vals = logits["function_logits"].squeeze().cpu().numpy()
    boundary_vals = logits["boundary_logits"].squeeze().cpu().numpy()

    top_classes = np.argsort(function_vals.mean(axis=0))[-label_num:]

    T = function_vals.shape[0]
    time_axis = np.arange(T) / frame_rates

    fig, ax = plt.subplots(2, 1, figsize=(15, 8), sharex=True)

    # Plot function logits
    for cls in top_classes:
        ax[1].plot(
            time_axis,
            function_vals[:, cls],
            label=f"{ID_TO_LABEL.get(cls, f'Class_{cls}')}",
        )
    ax[1].set_title("Top 8 Function Logits by Mean Activation")
    ax[1].set_xlabel("Time (seconds)")
    ax[1].set_ylabel("Logit")
    ax[1].xaxis.set_major_locator(ticker.MultipleLocator(20))
    ax[1].xaxis.set_minor_locator(ticker.MultipleLocator(5))
    ax[1].xaxis.set_major_formatter(ticker.FormatStrFormatter("%.1f"))
    ax[1].legend()
    ax[1].grid(True)

    # Plot boundary logits
    ax[0].plot(time_axis, boundary_vals, label="Boundary Logit", color="orange")
    ax[0].set_title("Boundary Logits")
    ax[0].set_ylabel("Logit")
    ax[0].legend()
    ax[0].grid(True)

    # Add vertical lines for markers
    for t_sec, label in msa_output:
        for a in ax:
            a.axvline(x=t_sec, color="red", linestyle="--", linewidth=0.8, alpha=0.7)
        if label != "end":
            ax[1].text(
                t_sec + 0.3,
                ax[1].get_ylim()[1] * 0.85,
                label,
                rotation=90,
                fontsize=8,
                color="red",
            )

    plt.suptitle("Music Structure Analysis - Logits Overview", fontsize=16)
    plt.tight_layout()

    return fig
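

# Illustrative sketch of the data flowing through the helpers above and the
# post-processing below (the concrete times and labels are made up for
# illustration; only the (time, label) tuple structure and the trailing "end"
# marker are implied by the code):
#
#   msa_output = [(0.0, "intro"), (15.32, "verse"), (47.80, "chorus"), (182.45, "end")]
#   format_as_msa(msa_output)      -> "0.00 intro\n15.32 verse\n47.80 chorus\n182.45 end"
#   format_as_segments(msa_output) -> [{"start": "0.0", "end": "15.32", "label": "intro"}, ...]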


def rule_post_processing(msa_list):
    if len(msa_list) <= 2:
        return msa_list

    result = msa_list.copy()

    # Absorb a first segment shorter than 1 s into the segment that follows it
    # (keep its start time, take the next label).
    while len(result) > 2:
        first_duration = result[1][0] - result[0][0]
        if first_duration < 1.0 and len(result) > 2:
            result[0] = (result[0][0], result[1][1])
            result = [result[0]] + result[2:]
        else:
            break

    # Drop a final segment shorter than 1 s (it is absorbed by the previous one).
    while len(result) > 2:
        last_label_duration = result[-1][0] - result[-2][0]
        if last_label_duration < 1.0:
            result = result[:-2] + [result[-1]]
        else:
            break

    # Merge a duplicated label at the start if the second boundary falls within
    # the first 10 s.
    while len(result) > 2:
        if result[0][1] == result[1][1] and result[1][0] <= 10.0:
            result = [(result[0][0], result[0][1])] + result[2:]
        else:
            break

    # Merge a duplicated label at the end if the last segment lasts 10 s or less.
    while len(result) > 2:
        last_duration = result[-1][0] - result[-2][0]
        if result[-2][1] == result[-3][1] and last_duration <= 10.0:
            result = result[:-2] + [result[-1]]
        else:
            break

    return result


def process_and_analyze(audio_file):
    """Main processing function"""

    def format_time(t: float) -> str:
        minutes = int(t // 60)
        seconds = t % 60
        return f"{minutes:02d}:{seconds:06.3f}"  # this mm:ss.mmm format is correct

    if audio_file is None:
        return None, "", "", None

    try:
        # Process audio
        logits, msa_output = process_audio(audio_file)

        # Apply rule-based post-processing; if it is not wanted, use the CLI
        # inference instead.
        msa_output = rule_post_processing(msa_output)

        # Format outputs
        segments = format_as_segments(msa_output)
        msa_format = format_as_msa(msa_output)
        json_format = format_as_json(segments)

        # Create table data
        table_data = [
            [
                f"{float(seg['start']):.2f} ({format_time(float(seg['start']))})",
                f"{float(seg['end']):.2f} ({format_time(float(seg['end']))})",
                seg["label"],
            ]
            for seg in segments
        ]

        # Create visualization
        fig = create_visualization(logits, msa_output)

        return table_data, json_format, msa_format, fig
    except Exception as e:
        import traceback

        error_msg = f"Error: {str(e)}\n{traceback.format_exc()}"
        print(error_msg)  # print the full error to the console
        return None, "", error_msg, None
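

# Minimal sketch of running the analysis without the Gradio UI (assumes the same
# checkpoint/config names used in __main__ below and one of the bundled example
# files; kept as a comment so importing this module stays side-effect free):
#
#   download_all(use_mirror=False)
#   initialize_models(
#       model_name="SongFormer",
#       checkpoint="SongFormer.safetensors",
#       config_path="SongFormer.yaml",
#   )
#   logits, msa = process_audio("examples/BC_5cd6a6.mp3")
#   print(format_as_msa(rule_post_processing(msa)))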
""") # Model title gr.HTML("""
SongFormer: Scaling Music Structure Analysis with Heterogeneous Supervision
""") # Links gr.HTML(""" """) # Main input area with gr.Row(): with gr.Column(scale=3): audio_input = gr.Audio( label="Upload Audio File", type="filepath", elem_id="audio-input" ) with gr.Column(scale=1): gr.Markdown("### 📌 Examples") gr.Examples( examples=[ # Add your example audio file paths ["examples/BC_5cd6a6.mp3"], ["examples/BC_282ece.mp3"], ["examples/BHX_0158_letitrock.wav"], ["examples/BHX_0374_drunkonyou.wav"], ], inputs=[audio_input], label="Click to load example", ) # Analyze button with gr.Row(): analyze_btn = gr.Button( "🚀 Analyze Music Structure", variant="primary", scale=1 ) # Results display area with gr.Row(): with gr.Column(scale=13): segments_table = gr.Dataframe( headers=["Start / s (m:s.ms)", "End / s (m:s.ms)", "Label"], label="Detected Music Segments", interactive=False, elem_id="result-table", ) with gr.Column(scale=8): with gr.Row(): with gr.Accordion("📄 JSON Output", open=False): json_output = gr.Textbox( label="JSON Format", lines=15, max_lines=20, interactive=False, show_copy_button=True, ) with gr.Row(): with gr.Accordion("📋 MSA Text Output", open=False): msa_output = gr.Textbox( label="MSA Format", lines=15, max_lines=20, interactive=False, show_copy_button=True, ) # Visualization plot with gr.Row(): plot_output = gr.Plot(label="Activation Curves Visualization") gr.HTML("""
""") # Set event handlers analyze_btn.click( fn=process_and_analyze, inputs=[audio_input], outputs=[segments_table, json_output, msa_output, plot_output], ) if __name__ == "__main__": # Download pretrained models if not exist download_all(use_mirror=False) # Initialize models print("Initializing models...") initialize_models( model_name="SongFormer", checkpoint="SongFormer.safetensors", config_path="SongFormer.yaml", ) print("Models loaded successfully!") # Launch interface demo.launch()