# Last commit by OmniAICreator: "Update app.py" (fa5a4cf, verified)
import os
import numpy as np
import torch
import soundfile as sf
import librosa
import gradio as gr
import spaces # For ZeroGPU
from huggingface_hub import hf_hub_download
from safetensors import safe_open
from xcodec2.configuration_bigcodec import BigCodecConfig
from xcodec2.modeling_xcodec2 import XCodec2Model
# ====== Settings ======
# Only the fine-tuned 44.1 kHz checkpoint is used; the repo can be overridden
# through the FT_REPO environment variable.
FT_REPO = os.environ.get("FT_REPO", "NandemoGHS/Anime-XCodec2-44.1kHz-v2")
# Sample rate the codec encoder is fed with (XCodec2 expects 16 kHz input).
TARGET_SR = 16_000
# Fallback cap on processed input length, in seconds.
MAX_SECONDS_DEFAULT = 30
def _ensure_models():
    """Build the fine-tuned XCodec2 model on CPU once and cache it in ``_model_ft``.

    The checkpoint is downloaded from ``FT_REPO`` and loaded through
    ``safetensors``. Once the global is populated, further calls return
    immediately, so the same model object is reused across requests.
    """
    global _model_ft
    if _model_ft is not None:
        return

    weights_file = hf_hub_download(repo_id=FT_REPO, filename="model.safetensors")
    # Every ".beta" fragment in a key is rewritten to ".bias"; presumably this
    # aligns legacy checkpoint names with the parameter names XCodec2Model
    # expects -- TODO confirm against the xcodec2 reference loader.
    with safe_open(weights_file, framework="pt", device="cpu") as reader:
        state_dict = {
            name.replace(".beta", ".bias"): reader.get_tensor(name)
            for name in reader.keys()
        }

    config = BigCodecConfig.from_pretrained(FT_REPO)
    _model_ft = XCodec2Model.from_pretrained(None, config=config, state_dict=state_dict)
    _model_ft.eval().to("cpu")
# ====== Globals ======
# Cached fine-tuned model; filled in by _ensure_models() and reused across requests.
_model_ft: "XCodec2Model | None" = None
# Loaded eagerly at import time (app startup) onto CPU, so the first request
# does not pay for the download/initialization. run() moves the model to the
# inference device when it executes.
_ensure_models()
def _load_audio(filepath: str, max_seconds: int):
    """
    Load an audio file, downmix it to mono, resample to 16 kHz and trim it.

    Decoding is attempted with ``soundfile`` first; if that fails for any
    reason, ``librosa`` is used as a fallback decoder.

    Args:
        filepath: Path to an audio file (wav/flac/ogg/mp3).
        max_seconds: Keep only the first ``max_seconds`` seconds. ``None`` or a
            non-positive value falls back to ``MAX_SECONDS_DEFAULT``.

    Returns:
        ``(wav_tensor, sr)``: a float32 ``torch.Tensor`` of shape (1, T) and its
        sample rate, which is always ``TARGET_SR``.

    Raises:
        ValueError: If the decoded audio contains no samples.
    """
    # The two decoders use opposite layouts for multichannel audio:
    # soundfile returns (frames, channels), librosa returns (channels, frames).
    # Track which axis holds the channels instead of guessing from the shape,
    # so e.g. 6-channel soundfile input is averaged over channels, not frames.
    try:
        wav, sr = sf.read(filepath, dtype="float32", always_2d=False)
        channel_axis = 1
    except Exception:
        # Deliberate best-effort fallback for anything soundfile cannot decode.
        wav, sr = librosa.load(filepath, sr=None, mono=False)
        channel_axis = 0
    wav = np.asarray(wav, dtype=np.float32)

    # Mono
    if wav.ndim == 2:
        wav = wav.mean(axis=channel_axis)
    elif wav.ndim > 2:
        # Not expected from either decoder; keep axis 0 and average the rest.
        wav = np.mean(wav, axis=tuple(range(1, wav.ndim)))

    # Fail with a clear message instead of an opaque error from np.max/resample.
    if wav.size == 0:
        raise ValueError(f"No audio samples could be decoded from {filepath!r}.")

    # Resample to 16 kHz
    if sr != TARGET_SR:
        wav = librosa.resample(wav, orig_sr=sr, target_sr=TARGET_SR)
        sr = TARGET_SR

    # Length cap: keep only the beginning of the clip.
    if max_seconds is None or max_seconds <= 0:
        max_seconds = MAX_SECONDS_DEFAULT
    wav = wav[: int(sr * max_seconds)]

    # Light safety normalization: only attenuate when the signal would clip.
    peak = np.max(np.abs(wav))
    if peak > 1.0:
        wav = wav / (peak + 1e-8)

    wav_tensor = torch.from_numpy(wav).float().unsqueeze(0)  # (1, T)
    return wav_tensor, sr
def _codes_to_tensor(codes, device):
"""
Normalize the output of xcodec2.encode_code to a tensor with shape (1, 1, N).
Handles version differences where the return type/shape may vary.
"""
if isinstance(codes, torch.Tensor):
return codes.to(device)
try:
t = torch.as_tensor(codes[0][0], device=device)
return t.unsqueeze(0).unsqueeze(0) if t.ndim == 1 else t
except Exception:
return torch.as_tensor(codes, device=device)
def _reconstruct(model: XCodec2Model, waveform: torch.Tensor, device: str) -> np.ndarray:
    """
    Round-trip a waveform through the codec: encode to discrete codes, then decode.

    Args:
        model: XCodec2 model exposing ``encode_code`` / ``decode_code``; the
            caller is expected to have moved it to ``device`` already.
        waveform: Input audio tensor (``_load_audio`` produces shape (1, T)).
        device: Device the input waveform and the codes are placed on.

    Returns:
        The decoded audio as a float32 NumPy array with singleton dims squeezed
        out and values clipped to [-1, 1].
    """
    with torch.inference_mode():
        codes = _codes_to_tensor(
            model.encode_code(input_waveform=waveform.to(device)),
            device=device,
        )
        decoded = model.decode_code(codes)  # (1, 1, T')
        samples = decoded.squeeze().detach().cpu().numpy().astype(np.float32)
        return np.clip(samples, -1.0, 1.0)
@spaces.GPU(duration=60)  # ZeroGPU: reserve GPU only during this function call
def run(audio_path, max_seconds):
    """
    Gradio click handler: reconstruct the uploaded clip with the FT codec.

    Args:
        audio_path: Filesystem path from the ``gr.Audio(type="filepath")``
            input, or ``None`` when nothing was uploaded.
        max_seconds: Maximum number of input seconds to process.

    Returns:
        ``(sample_rate, waveform)`` tuple, the format ``gr.Audio`` outputs accept.

    Raises:
        gr.Error: If no file was uploaded or the file cannot be decoded.
    """
    if audio_path is None:
        raise gr.Error("Please upload an audio file.")

    _ensure_models()

    # Unreadable, corrupt or empty uploads are user errors: report them through
    # gr.Error with the underlying reason rather than as an unhandled exception.
    try:
        waveform, _ = _load_audio(audio_path, max_seconds)
    except Exception as err:
        raise gr.Error(f"Could not read the uploaded audio: {err}") from err

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # nn.Module.to() moves the cached model in place; it stays on `device` afterwards.
    model = _model_ft.to(device)
    recon_ft = _reconstruct(model, waveform, device)

    # The 44.1 kHz FT model decodes at 44.1 kHz; the output rate is hard-coded.
    return (44100, recon_ft)
# ====== UI ======
# Markdown header rendered at the top of the single-model demo page.
DESCRIPTION = "\n".join(
    [
        "",
        "# Anime-XCodec2-44.1kHz-v2 Reconstruction Demo",
        "This demo reconstructs audio using the **44.1 kHz fine-tuned (NandemoGHS/Anime-XCodec2-44.1kHz-v2)** model.",
        "- Supported inputs: wav / flac / ogg / mp3",
        "- Input is automatically converted to **16 kHz** (as required by XCodec2).",
        "- ZeroGPU ready. If no GPU is available, it falls back to CPU (slower).",
        "",
    ]
)
# Page layout: input controls on the left, the reconstruction on the right.
with gr.Blocks(theme=gr.themes.Soft(), css="footer {visibility: hidden}") as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        with gr.Column(scale=1):
            inp = gr.Audio(
                sources=["upload"],
                type="filepath",  # run() receives a path on disk
                label="Upload an audio file",
                waveform_options={"show_controls": True},
            )
            max_sec = gr.Slider(
                3,
                60,
                value=MAX_SECONDS_DEFAULT,
                step=1,
                label="Max length (seconds)",
                info="If the input is longer, only the first N seconds will be processed.",
            )
            run_btn = gr.Button("Run", variant="primary")
            # A blank line separates the entries: a single "\n" is only a soft
            # break in Markdown and would render both on the same line.
            gr.Markdown(
                f"**44.1 kHz model**: `{FT_REPO}`\n\n"
                "**Inference device**: auto (GPU on ZeroGPU)"
            )
        with gr.Column(scale=1):
            # Single audio output; the label names the repo actually loaded,
            # since FT_REPO can be overridden via the environment.
            out_ft = gr.Audio(
                label=f"44.1kHz reconstruction ({FT_REPO})",
                show_download_button=True,
                format="wav",
            )
    # One click runs the reconstruction and fills the single output.
    run_btn.click(run, inputs=[inp, max_sec], outputs=[out_ft])
# Script entry point: enable request queuing (at most 8 pending events) and
# start the Gradio server.
if __name__ == "__main__":
    app = demo.queue(max_size=8)
    app.launch()