import os

import numpy as np
import torch
import soundfile as sf
import librosa
import gradio as gr
import spaces  # For ZeroGPU
from huggingface_hub import hf_hub_download
from safetensors import safe_open
from xcodec2.configuration_bigcodec import BigCodecConfig
from xcodec2.modeling_xcodec2 import XCodec2Model
# ====== Settings ======
# Use only the FT (44.1 kHz) version
FT_REPO = os.getenv("FT_REPO", "NandemoGHS/Anime-XCodec2-44.1kHz-v2")
TARGET_SR = 16000         # XCodec2 expects 16 kHz input
MAX_SECONDS_DEFAULT = 30  # Default max duration (seconds)
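# Sketch (assumptions: the file is run as app.py, and the repo name below is a
# placeholder): because FT_REPO reads from the environment, the demo can be
# pointed at another checkpoint without code changes, e.g.
#   FT_REPO=your-org/your-xcodec2-checkpoint python app.py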
# ====== Globals (loaded once on CPU; moved to GPU only during inference) ======
_model_ft = None


def _ensure_models():
    """Load the FT model to CPU once, and reuse it across requests."""
    global _model_ft
    if _model_ft is None:
        ckpt_path = hf_hub_download(repo_id=FT_REPO, filename="model.safetensors")
        ckpt = {}
        with safe_open(ckpt_path, framework="pt", device="cpu") as f:
            for k in f.keys():
                # The checkpoint stores some bias tensors under ".beta";
                # rename them to the ".bias" keys the model class expects.
                ckpt[k.replace(".beta", ".bias")] = f.get_tensor(k)
        codec_config = BigCodecConfig.from_pretrained(FT_REPO)
        _model_ft = XCodec2Model.from_pretrained(
            None, config=codec_config, state_dict=ckpt
        )
        _model_ft.eval().to("cpu")


# Warm up at import time so the first request does not pay the download cost.
_ensure_models()
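# --- Sketch: the ".beta" -> ".bias" rename above, shown on a dummy state dict
# (assumptions: DEBUG_RENAME is a hypothetical env flag; the keys are made up).
if os.getenv("DEBUG_RENAME"):
    raw = {"decoder.norm.beta": torch.zeros(4), "decoder.conv.weight": torch.ones(4)}
    fixed = {k.replace(".beta", ".bias"): v for k, v in raw.items()}
    assert set(fixed) == {"decoder.norm.bias", "decoder.conv.weight"}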
def _load_audio(filepath: str, max_seconds: int):
    """
    Load audio (wav/flac/ogg/mp3), convert to mono, resample to 16 kHz,
    trim to the given max length (from the beginning), and return a
    torch.Tensor of shape (1, T).
    """
    # Try soundfile first, then fall back to librosa
    try:
        wav, sr = sf.read(filepath, dtype="float32", always_2d=False)
    except Exception:
        wav, sr = librosa.load(filepath, sr=None, mono=False)
    wav = np.asarray(wav, dtype=np.float32)

    # Mono: soundfile returns (frames, channels), librosa (channels, frames).
    # The channel axis is the smaller one for any realistic clip, so average
    # over it regardless of which library produced the array.
    if wav.ndim == 2:
        wav = wav.mean(axis=int(np.argmin(wav.shape)))
    elif wav.ndim > 2:  # defensive: collapse any extra axes onto the first
        wav = np.mean(wav, axis=tuple(range(1, wav.ndim)))

    # Resample to 16 kHz
    if sr != TARGET_SR:
        wav = librosa.resample(wav, orig_sr=sr, target_sr=TARGET_SR)
        sr = TARGET_SR

    # Length cap
    if max_seconds is None or max_seconds <= 0:
        max_seconds = MAX_SECONDS_DEFAULT
    max_len = int(sr * max_seconds)
    if wav.shape[0] > max_len:
        wav = wav[:max_len]

    # Light safety normalization
    peak = np.max(np.abs(wav))
    if peak > 1.0:
        wav = wav / (peak + 1e-8)

    wav_tensor = torch.from_numpy(wav).float().unsqueeze(0)  # (1, T)
    return wav_tensor, sr
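# --- Sketch: optional smoke test for _load_audio (assumptions: DEBUG_LOAD_AUDIO
# is a hypothetical env flag and the temp dir is writable). Writes a 2-second
# 44.1 kHz stereo sine, then checks we get a mono (1, T) tensor at 16 kHz.
if os.getenv("DEBUG_LOAD_AUDIO"):
    import tempfile

    t = np.linspace(0.0, 2.0, int(44100 * 2.0), endpoint=False)
    stereo = np.stack([np.sin(2 * np.pi * 440.0 * t)] * 2, axis=1).astype(np.float32)
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        sf.write(tmp.name, stereo, 44100)
    wav_t, out_sr = _load_audio(tmp.name, max_seconds=30)
    assert out_sr == TARGET_SR and wav_t.ndim == 2 and wav_t.shape[0] == 1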
def _codes_to_tensor(codes, device):
    """
    Normalize the output of xcodec2.encode_code to a tensor of shape (1, 1, N).
    Handles version differences where the return type/shape may vary.
    """
    if isinstance(codes, torch.Tensor):
        return codes.to(device)
    try:
        t = torch.as_tensor(codes[0][0], device=device)
        return t.unsqueeze(0).unsqueeze(0) if t.ndim == 1 else t
    except Exception:
        return torch.as_tensor(codes, device=device)
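# --- Sketch: the shapes _codes_to_tensor normalizes, on dummy values
# (assumption: DEBUG_CODES is a hypothetical env flag; CPU only).
if os.getenv("DEBUG_CODES"):
    # Some xcodec2 versions return nested sequences like [[codes]]; a 1-D
    # inner tensor should come back as (1, 1, N).
    assert _codes_to_tensor([[torch.arange(5)]], device="cpu").shape == (1, 1, 5)
    # A plain tensor passes through unchanged (apart from the device move).
    assert _codes_to_tensor(torch.zeros(1, 1, 5), device="cpu").shape == (1, 1, 5)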
def _reconstruct(model: XCodec2Model, waveform: torch.Tensor, device: str) -> np.ndarray:
    """Encode→decode with XCodec2 to get a reconstructed waveform (np.float32, clipped to [-1, 1])."""
    with torch.inference_mode():
        wave = waveform.to(device)
        codes = model.encode_code(input_waveform=wave)
        codes_t = _codes_to_tensor(codes, device=device)
        recon = model.decode_code(codes_t)  # (1, 1, T')
    recon_np = recon.squeeze().detach().cpu().numpy().astype(np.float32)
    recon_np = np.clip(recon_np, -1.0, 1.0)
    return recon_np
# ZeroGPU: the @spaces.GPU decorator reserves a GPU only for the duration of
# this call and releases it afterwards.
@spaces.GPU
def run(audio_path, max_seconds):
    if audio_path is None:
        raise gr.Error("Please upload an audio file.")
    _ensure_models()
    waveform, sr = _load_audio(audio_path, max_seconds)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Fine-tuned model: move to the GPU (if any) for this request
    ft = _model_ft.to(device)
    recon_ft = _reconstruct(ft, waveform, device)

    # Gradio's Audio component expects (sample_rate, np.ndarray);
    # the fine-tuned model decodes at 44.1 kHz.
    return (44100, recon_ft)
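# --- Sketch (assumption: not part of the original app): for inputs near the
# 60-second slider cap, the decorator also accepts a duration hint (seconds)
# so the ZeroGPU reservation window matches the workload, e.g.:
#
#   @spaces.GPU(duration=120)
#   def run(audio_path, max_seconds): ...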
# ====== UI ======
DESCRIPTION = """
# Anime-XCodec2-44.1kHz-v2 Reconstruction Demo
This demo reconstructs audio using the **44.1 kHz fine-tuned (NandemoGHS/Anime-XCodec2-44.1kHz-v2)** model.
- Supported inputs: wav / flac / ogg / mp3
- Input is automatically converted to **16 kHz** (as required by XCodec2).
- ZeroGPU ready. If no GPU is available, it falls back to CPU (slower).
"""
with gr.Blocks(theme=gr.themes.Soft(), css="footer {visibility: hidden}") as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        with gr.Column(scale=1):
            inp = gr.Audio(
                sources=["upload"],
                type="filepath",
                label="Upload an audio file",
                waveform_options={"show_controls": True},
            )
            max_sec = gr.Slider(
                3, 60, value=MAX_SECONDS_DEFAULT, step=1,
                label="Max length (seconds)",
                info="If the input is longer, only the first N seconds will be processed.",
            )
            run_btn = gr.Button("Run", variant="primary")
            gr.Markdown(
                f"**44.1 kHz model**: `{FT_REPO}`\n\n"
                "**Inference device**: auto (GPU on ZeroGPU)"
            )
        with gr.Column(scale=1):
            out_ft = gr.Audio(
                label="44.1 kHz reconstruction (NandemoGHS/Anime-XCodec2-44.1kHz-v2)",
                show_download_button=True,
                format="wav",
            )

    run_btn.click(run, inputs=[inp, max_sec], outputs=[out_ft])
# In Spaces, an explicit launch call is optional
if __name__ == "__main__":
    demo.queue(max_size=8).launch()
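# --- Sketch (assumptions: gradio_client is installed, the app is serving on
# the default local port, and "sample.wav" is a placeholder file; "/run" is
# Gradio's default endpoint name derived from the function name):
#
#   from gradio_client import Client, handle_file
#   client = Client("http://127.0.0.1:7860/")
#   result = client.predict(handle_file("sample.wav"), 10, api_name="/run")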