"""
Mixed-FP8 safetensors converter for Hunyuan-Foley checkpoints.

- Converts selected .weight tensors to FP8 storage (E5M2 by default on pre-Hopper).
- Keeps math in FP16/BF16; this is a storage-only change in the file.
- Honors existing FP8 tensors in the input unless --recode-fp8 is set.
- Skips norms, biases, visual_proj.*, final_layer.* by design.
- Optional --aggressive converts modulation linears too.

USAGE (simple):

    python convert_fp8.py in.safetensors [out.safetensors]  # out is optional

USAGE (flags):

    python convert_fp8.py in.safetensors out.safetensors --fp8 auto --aggressive

Notes:
- "auto" picks FP8_E5M2 on SM < 90 (e.g., an RTX 3090), else FP8_E4M3FN.
- You can force a format: --fp8 e5m2 | e4m3fn
- Dry run: add --dry to print what would change without writing.
"""

import argparse
import re
from pathlib import Path
from typing import Dict, Tuple

import torch
from safetensors.torch import load_file, save_file

# Tensor names containing any of these substrings are never converted:
# norms, biases, and output-critical projections stay in FP16/BF16.
_DENY_SUBSTRINGS = (
    ".bias", ".norm", "q_norm.", "k_norm.",
    "final_layer.", "visual_proj.",
)

# Fully-anchored regexes naming the weight tensors eligible for FP8 storage.
_ALLOW_PATTERNS = tuple(re.compile(p) for p in (
    # Single-stream blocks: large linears, fused QKV, and (gated) modulation.
    r"^single_blocks\.\d+\.linear1\.weight$",
    r"^single_blocks\.\d+\.linear2\.w[123]\.weight$",
    r"^single_blocks\.\d+\.linear_qkv\.weight$",
    r"^single_blocks\.\d+\.modulation\.linear\.weight$",
    # Triple-stream blocks: audio/visual MLPs,
    r"^triple_blocks\.\d+\.audio_mlp\.fc[12]\.weight$",
    r"^triple_blocks\.\d+\.v_cond_mlp\.fc[12]\.weight$",
    # attention QKV/KV and output projections,
    r"^triple_blocks\.\d+\.(audio_self_attn_qkv|v_cond_attn_qkv|text_cross_kv)\.weight$",
    r"^triple_blocks\.\d+\.(audio_self_proj|v_cond_self_proj)\.weight$",
    # and (gated) modulation linears.
    r"^triple_blocks\.\d+\.(audio_mod|v_cond_mod)\.linear\.weight$",
))


def default_out_path(in_path: str, tgt_dtype: torch.dtype) -> str:
    """<in>_fp8_<e5m2|e4m3fn>.safetensors (idempotent if already suffixed)."""
    suffix = "e5m2" if tgt_dtype == torch.float8_e5m2 else "e4m3fn"
    p = Path(in_path)
    stem = re.sub(r"_fp8_e(5m2|4m3fn)$", "", p.stem)  # strip any prior FP8 suffix
    ext = p.suffix or ".safetensors"
    return str(p.with_name(f"{stem}_fp8_{suffix}{ext}"))
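
# Illustrative examples (hypothetical file names):
#   default_out_path("hunyuan_foley.safetensors", torch.float8_e5m2)
#       -> "hunyuan_foley_fp8_e5m2.safetensors"
#   Idempotent: an existing FP8 suffix is stripped before re-suffixing.
#   default_out_path("hunyuan_foley_fp8_e5m2.safetensors", torch.float8_e4m3fn)
#       -> "hunyuan_foley_fp8_e4m3fn.safetensors"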


def pick_fp8_dtype(fp8_mode: str) -> torch.dtype:
    """Pick the target FP8 storage dtype ("auto" probes CUDA capability)."""
    m = fp8_mode.lower()
    if m == "e5m2":
        return torch.float8_e5m2
    if m == "e4m3fn":
        return torch.float8_e4m3fn

    # "auto": E4M3FN on Hopper+ (SM >= 90), else E5M2. No GPU, or any
    # failure while probing, falls back to E5M2.
    try:
        major, _ = torch.cuda.get_device_capability() if torch.cuda.is_available() else (0, 0)
    except Exception:
        major = 0
    return torch.float8_e5m2 if major < 9 else torch.float8_e4m3fn
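
# Why "auto" prefers E5M2 off-Hopper (hedged rationale, not load-bearing):
# E5M2 uses the same 5-bit exponent as FP16, so an FP16 -> E5M2 cast mostly
# just rounds away mantissa bits while keeping FP16's exponent range, which
# is convenient when the GPU has no native FP8 math and weights are upcast
# back to FP16/BF16 at load time. Illustrative calls:
#   pick_fp8_dtype("e5m2")  -> torch.float8_e5m2 (probe skipped)
#   pick_fp8_dtype("auto")  -> depends on torch.cuda.get_device_capability()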


def bytes_of(t: torch.Tensor) -> int:
    """Storage size in bytes (FP8 = 1 byte/element)."""
    if t.dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
        return t.numel() * 1
    return t.numel() * t.element_size()


def human_gb(nbytes: int) -> float:
    return nbytes / (1024 ** 3)
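
# Illustrative sizes: a 1024x1024 FP16 tensor occupies 2 MiB; stored as FP8
# it occupies 1 MiB, i.e. conversion halves FP16/BF16 weight storage.
#   bytes_of(torch.zeros(1024, 1024, dtype=torch.float16))  -> 2097152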


def _is_denied(name: str) -> bool:
    return any(tok in name for tok in _DENY_SUBSTRINGS)


def should_convert_to_fp8(name: str, aggressive: bool) -> bool:
    """Match names for conversion, with modulation linears gated by --aggressive."""
    if not name.endswith(".weight"):
        return False
    if _is_denied(name):
        return False

    for pat in _ALLOW_PATTERNS:
        if pat.search(name):
            # Modulation linears are allow-listed but converted only when
            # --aggressive is set.
            if (
                ".modulation.linear.weight" in name
                or ".audio_mod.linear.weight" in name
                or ".v_cond_mod.linear.weight" in name
            ):
                return aggressive
            return True
    return False
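
# Illustrative decisions (names modeled on the allow-list above):
#   should_convert_to_fp8("single_blocks.0.linear1.weight", False)            -> True
#   should_convert_to_fp8("single_blocks.0.modulation.linear.weight", False)  -> False (gated)
#   should_convert_to_fp8("single_blocks.0.modulation.linear.weight", True)   -> True
#   should_convert_to_fp8("final_layer.linear.weight", True)                  -> False (deny-list)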


def convert_state_dict(
    sd: Dict[str, torch.Tensor],
    fp8_mode: str = "auto",
    aggressive: bool = False,
    recode_fp8: bool = False,
) -> Tuple[Dict[str, torch.Tensor], Dict[str, int]]:
    """
    Convert selected weights to FP8 storage according to the policy.
    Honors existing FP8 unless recode_fp8=True.
    Returns (new_sd, stats) with byte counts.
    """
    tgt_dtype = pick_fp8_dtype(fp8_mode)
    out: Dict[str, torch.Tensor] = {}
    stats = {
        "total_before": 0,
        "total_after": 0,
        "converted_count": 0,
        "kept_fp8_count": 0,
        "skipped_count": 0,
    }

    for name, tensor in sd.items():
        before = bytes_of(tensor)
        stats["total_before"] += before

        # Tensors already stored in FP8: keep as-is unless recoding was requested.
        if tensor.dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
            if recode_fp8:
                out[name] = tensor.to(dtype=tgt_dtype)
                stats["converted_count"] += 1
            else:
                out[name] = tensor
                stats["kept_fp8_count"] += 1
            stats["total_after"] += bytes_of(out[name])
            continue

        # Everything else: convert if the policy matches, otherwise pass through.
        if should_convert_to_fp8(name, aggressive):
            out[name] = tensor.to(dtype=tgt_dtype)
            stats["converted_count"] += 1
        else:
            out[name] = tensor
            stats["skipped_count"] += 1

        stats["total_after"] += bytes_of(out[name])

    return out, stats
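
# Minimal usage sketch (hedged; "ckpt.safetensors" is a hypothetical path):
#   sd = load_file("ckpt.safetensors")
#   new_sd, stats = convert_state_dict(sd, fp8_mode="auto", aggressive=False)
#   print(stats["converted_count"], f"{human_gb(stats['total_after']):.2f} GiB")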


def parse_args() -> argparse.Namespace:
    p = argparse.ArgumentParser(description="Convert selected weights in a safetensors file to FP8 storage.")
    p.add_argument("in_path", help="Input .safetensors")
    p.add_argument("out_path", nargs="?", help="Output .safetensors (optional)")
    p.add_argument("--fp8", choices=["auto", "e5m2", "e4m3fn"], default="auto",
                   help='Target FP8 storage dtype: "auto" (default), "e5m2", or "e4m3fn"')
    p.add_argument("--aggressive", action="store_true",
                   help="Also convert modulation linears (audio_mod/v_cond_mod + single modulation.linear).")
    p.add_argument("--recode-fp8", action="store_true",
                   help="Re-encode existing FP8 tensors to the chosen target dtype.")
    p.add_argument("--dry", action="store_true",
                   help="Dry run: report only; do not write output file.")
    return p.parse_args()


def main():
    args = parse_args()

    print(f"[load] {args.in_path}")
    sd = load_file(args.in_path)

    tgt = pick_fp8_dtype(args.fp8)
    if not args.out_path:
        args.out_path = default_out_path(args.in_path, tgt)
        print(f"[auto-out] {args.out_path}")

    print(f"[policy] fp8_mode={args.fp8} -> {str(tgt).replace('torch.', '')}, "
          f"aggressive={args.aggressive}, recode_fp8={args.recode_fp8}")

    new_sd, stats = convert_state_dict(
        sd,
        fp8_mode=args.fp8,
        aggressive=args.aggressive,
        recode_fp8=args.recode_fp8,
    )

    saved = stats["total_before"] - stats["total_after"]
    print(f"[stats] tensors: {len(sd)}")
    print(f"[stats] converted: {stats['converted_count']} | kept_fp8: {stats['kept_fp8_count']} "
          f"| skipped: {stats['skipped_count']}")
    print(f"[bytes] before={human_gb(stats['total_before']):.3f} GiB | "
          f"after={human_gb(stats['total_after']):.3f} GiB | saved={human_gb(saved):.3f} GiB")

    if args.dry:
        print("[dry] no file written")
        return

    print(f"[save] {args.out_path}")
    save_file(new_sd, args.out_path)
    print("[done]")


if __name__ == "__main__":
    main()
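
# Hedged verification sketch using the safetensors safe_open API (the file
# name below is hypothetical; adjust to your output path):
#
#   from safetensors import safe_open
#   with safe_open("model_fp8_e5m2.safetensors", framework="pt") as f:
#       for k in f.keys():
#           t = f.get_tensor(k)
#           if t.dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
#               print("fp8:", k, tuple(t.shape))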