"""Batch inference script for EAR_VAE: loads a pretrained checkpoint and runs
every audio file found under --indir through the model, writing the
reconstructions to --outdir as 44.1 kHz WAV files."""

import argparse
import json
from pathlib import Path

import torch
import torchaudio
from tqdm import tqdm

from model.ear_vae import EAR_VAE

def main(args):
    indir = args.indir
    model_path = args.model_path
    outdir = args.outdir
    device = args.device
    config_path = args.config

    print(f"Input directory: {indir}")
    print(f"Model path: {model_path}")
    print(f"Output directory: {outdir}")
    print(f"Device: {device}")
    print(f"Config path: {config_path}")

    input_path = Path(indir)
    output_path_dir = Path(outdir)
    output_path_dir.mkdir(parents=True, exist_ok=True)

    with open(config_path, 'r') as f:
        vae_gan_model_config = json.load(f)

    print("Loading model...")
    model = EAR_VAE(model_config=vae_gan_model_config).to(device)

    state = torch.load(model_path, map_location="cpu")
    model.load_state_dict(state)
    model.eval()
    print("Model loaded successfully.")

    # rglob("*") also matches directories and non-audio files, so filter by
    # extension; extend this set to cover any other formats ffmpeg can decode.
    audio_exts = {".wav", ".flac", ".mp3", ".ogg", ".m4a"}
    audios = [p for p in input_path.rglob("*") if p.suffix.lower() in audio_exts]
    print(f"Found {len(audios)} audio files to process.")

    with torch.no_grad():
        for audio_path in tqdm(audios, desc="Processing audio files"):
            try:
                gt_y, sr = torchaudio.load(audio_path, backend="ffmpeg")

                # Defensive: torchaudio.load returns (channels, time), but
                # guard against 1-D input anyway.
                if len(gt_y.shape) == 1:
                    gt_y = gt_y.unsqueeze(0)

                # Move to the target device *before* resampling, so the
                # waveform and the resampler live on the same device.
                gt_y = gt_y.to(device, torch.float32)

                if sr != 44100:
                    resampler = torchaudio.transforms.Resample(sr, 44100).to(device)
                    gt_y = resampler(gt_y)

                # The model expects stereo input; duplicate a mono channel.
                if gt_y.shape[0] == 1:
                    gt_y = torch.cat([gt_y, gt_y], dim=0)

                # Add a batch dimension: (channels, time) -> (1, channels, time).
                gt_y = gt_y.unsqueeze(0)

                fake_audio = model.inference(gt_y)

                output_filename = f"{Path(audio_path).stem}_{Path(model_path).stem}.wav"
                output_path = output_path_dir / output_filename

                fake_audio_processed = fake_audio.squeeze(0).cpu()
                torchaudio.save(str(output_path), fake_audio_processed, sample_rate=44100, backend="ffmpeg")
            except Exception as e:
                print(f"Error processing {audio_path}: {e}")
                continue

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Run VAE-GAN audio inference.")
    parser.add_argument('--indir', type=str, default='./data', help='Input directory for audio files.')
    parser.add_argument('--model_path', type=str, default='./pretrained_weight/ear_vae_44k.pyt', help='Path to the pretrained model weights.')
    parser.add_argument('--outdir', type=str, default='./results', help='Output directory for generated audio files.')
    parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu', help='Device to run the model on (e.g., "cuda:0" or "cpu").')
    parser.add_argument('--config', type=str, default='./config/model_config.json', help='Path to the model config file.')

    args = parser.parse_args()
    main(args)
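
# Example invocation (assuming this file is saved as infer.py; the flags below
# simply restate the argparse defaults, so adjust the paths to your setup):
#
#   python infer.py --indir ./data --outdir ./results \
#       --model_path ./pretrained_weight/ear_vae_44k.pyt \
#       --config ./config/model_config.json --device cuda:0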