The UVR-MDX-NET-Inst_HQ_4.pt file was obtained with the following code:

import argparse

import onnx
import torch
from onnx2torch import convert


def try_forward(m, shape):
    """Smoke-test ``m`` with a single forward pass on random input.

    A float32 tensor of the requested ``shape`` is drawn from a standard
    normal distribution and fed to ``m`` with autograd disabled. The output
    is discarded; only whether the call completes matters.

    Args:
        m: Callable model (e.g. a ``torch.nn.Module``).
        shape: Sequence of ints giving the input tensor's shape.

    Returns:
        ``True`` whenever the forward pass completes.

    Raises:
        Whatever ``m`` raises for an unsupported input; nothing is caught here.
    """
    with torch.no_grad():
        m(torch.randn(*shape, dtype=torch.float32))
    return True


def _find_working_shape(model, candidates):
    """Return the first shape in ``candidates`` for which ``model`` runs.

    Every candidate is probed in order with ``try_forward``. Failures are
    recorded rather than silently dropped, so that when no shape works the
    raised error explains why each one was rejected.

    Args:
        model: The converted torch model to probe.
        candidates: Iterable of shapes (sequences of ints), tried in order.

    Returns:
        The first working shape, as a tuple.

    Raises:
        RuntimeError: If no candidate succeeds. The message lists each
            shape with its error, and the last error is chained as the cause.
    """
    failures = []
    last_err = None
    for shape in candidates:
        try:
            try_forward(model, shape)
        except Exception as err:  # any failure simply disqualifies this shape
            failures.append(f"{tuple(shape)}: {type(err).__name__}: {err}")
            last_err = err
            continue
        return tuple(shape)

    details = "\n  ".join(failures) if failures else "(no candidate shapes given)"
    raise RuntimeError(
        "Could not find a working input shape for this ONNX model.\n  " + details
    ) from last_err


def main(onnx_path, out_prefix, candidates=None):
    """Convert an ONNX model into a TorchScript ``.pt`` file.

    The ONNX graph is loaded and converted to a torch module with onnx2torch,
    then put in eval mode. The module is probed with random inputs until one
    candidate shape runs. It is exported with ``torch.jit.script``; if
    scripting fails, ``torch.jit.trace`` is used on the working shape instead.
    The result is saved as ``<out_prefix>.pt``.

    Args:
        onnx_path: Path to the source ``.onnx`` file.
        out_prefix: Output path without extension; ``.pt`` is appended.
        candidates: Optional iterable of input shapes to probe, in order.
            Defaults to (1, 4, 2560, 256), (1, 4, 2560, 320),
            (1, 4, 3072, 256), (1, 4, 3072, 320).

    Returns:
        The path of the saved ``.pt`` file.

    Raises:
        RuntimeError: If none of the candidate shapes gives a successful
            forward pass.
    """
    model_onnx = onnx.load(onnx_path)
    model_torch = convert(model_onnx).eval()

    if candidates is None:
        candidates = [
            (1, 4, 2560, 256),
            (1, 4, 2560, 320),
            (1, 4, 3072, 256),
            (1, 4, 3072, 320),
        ]
    ok_shape = _find_working_shape(model_torch, candidates)
    print(f"Using input shape {ok_shape}")

    try:
        exported = torch.jit.script(model_torch)
        print("Scripted model")
    except Exception as err:
        # Tracing records only the operations executed for this one example
        # input, so the traced graph is specialized to the path taken for
        # ok_shape. Report why scripting failed instead of hiding it.
        print(f"Scripting failed ({type(err).__name__}: {err}); falling back to tracing")
        example = torch.randn(*ok_shape, dtype=torch.float32)
        exported = torch.jit.trace(model_torch, example, strict=False)
        print("Traced model")

    out_pt = f"{out_prefix}.pt"
    torch.jit.save(exported, out_pt)
    print(f"Saved {out_pt}")
    return out_pt


if __name__ == "__main__":
    # Command-line entry point: --onnx is mandatory; --out-prefix picks the
    # output path, to which main() appends ".pt".
    parser = argparse.ArgumentParser()
    parser.add_argument("--onnx", required=True)
    parser.add_argument("--out-prefix", default="UVR-MDX-NET-Inst_HQ_4")
    cli = parser.parse_args()
    main(onnx_path=cli.onnx, out_prefix=cli.out_prefix)
Downloads last month

-

Downloads are not tracked for this model. How to track
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support