#!/usr/bin/env python3 """ Mixed-FP8 safetensors converter for Hunyuan-Foley checkpoints. - Converts selected .weight tensors to FP8 storage (E5M2 by default on pre-Hopper). - Keeps math in FP16/BF16; this is a storage-only change in the file. - Honors existing FP8 tensors in the input unless --recode-fp8 is set. - Skips norms, biases, visual_proj.*, final_layer.* by design. - Optional --aggressive converts modulation linears too. USAGE (simple): python convert_fp8.py in.safetensors [out.safetensors] # out is optional USAGE (flags): python convert_fp8.py in.safetensors out.safetensors --fp8 auto --aggressive Notes: - “auto” picks FP8_E5M2 on SM < 90 (e.g., 3090), else FP8_E4M3FN. - You can force a format: --fp8 e5m2 | e4m3fn - Dry run: add --dry to print what would change without writing. """ import argparse import re from typing import Dict, Tuple from pathlib import Path import torch from safetensors.torch import load_file, save_file # --------------------------- Policy (names) --------------------------- # Skip norms/bias and sensitive endpoints explicitly _DENY_SUBSTRINGS = ( ".bias", ".norm", "q_norm.", "k_norm.", "final_layer.", "visual_proj.", ) # Allowed patterns target this architecture’s large linears _ALLOW_PATTERNS = tuple(re.compile(p) for p in ( # Single-stream blocks r"^single_blocks\.\d+\.linear1\.weight$", r"^single_blocks\.\d+\.linear2\.w[123]\.weight$", # w1/w2/w3 r"^single_blocks\.\d+\.linear_qkv\.weight$", r"^single_blocks\.\d+\.modulation\.linear\.weight$", # gated by --aggressive # Triple-stream blocks: MLPs (dominant size) r"^triple_blocks\.\d+\.audio_mlp\.fc[12]\.weight$", r"^triple_blocks\.\d+\.v_cond_mlp\.fc[12]\.weight$", # Triple-stream blocks: attention projections r"^triple_blocks\.\d+\.(audio_self_attn_qkv|v_cond_attn_qkv|text_cross_kv)\.weight$", r"^triple_blocks\.\d+\.(audio_self_proj|v_cond_self_proj)\.weight$", # r"^triple_blocks\.\d+\.(audio_cross_q|v_cond_cross_q)\.weight$", # r"^triple_blocks\.\d+\.(audio_cross_proj|v_cond_cross_proj)\.weight$", # Triple-stream blocks: modulation linears (gated) r"^triple_blocks\.\d+\.(audio_mod|v_cond_mod)\.linear\.weight$", )) # --------------------------- Helpers --------------------------- def default_out_path(in_path: str, tgt_dtype: torch.dtype) -> str: """_fp8_.safetensors (idempotent if already suffixed).""" suffix = "e5m2" if tgt_dtype == torch.float8_e5m2 else "e4m3fn" p = Path(in_path) stem = re.sub(r"_fp8_e(5m2|4m3fn)$", "", p.stem) # strip prior suffix ext = p.suffix or ".safetensors" return str(p.with_name(f"{stem}_fp8_{suffix}{ext}")) def pick_fp8_dtype(fp8_mode: str) -> torch.dtype: """Pick target FP8 dtype.""" m = fp8_mode.lower() if m == "e5m2": return torch.float8_e5m2 if m == "e4m3fn": return torch.float8_e4m3fn # auto try: major, _ = torch.cuda.get_device_capability() if torch.cuda.is_available() else (0, 0) except Exception: major = 0 return torch.float8_e5m2 if major < 9 else torch.float8_e4m3fn def bytes_of(t: torch.Tensor) -> int: """Size in bytes (FP8=1 byte/elt).""" if t.dtype in (torch.float8_e5m2, torch.float8_e4m3fn): return t.numel() * 1 return t.numel() * t.element_size() def human_gb(nbytes: int) -> float: return nbytes / (1024 ** 3) def _is_denied(name: str) -> bool: return any(tok in name for tok in _DENY_SUBSTRINGS) def should_convert_to_fp8(name: str, aggressive: bool) -> bool: """Match names for conversion, with modulation linears gated by --aggressive.""" if not name.endswith(".weight"): return False if _is_denied(name): return False for pat in _ALLOW_PATTERNS: if pat.search(name): # Gate modulation linears (single/triple) behind --aggressive if ( ".modulation.linear.weight" in name or ".audio_mod.linear.weight" in name or ".v_cond_mod.linear.weight" in name ): return aggressive return True return False # --------------------------- Core --------------------------- def convert_state_dict( sd: Dict[str, torch.Tensor], fp8_mode: str = "auto", aggressive: bool = False, recode_fp8: bool = False, ) -> Tuple[Dict[str, torch.Tensor], Dict[str, int]]: """ Convert selected weights to FP8 storage according to the policy. Honors existing FP8 unless recode_fp8=True. Returns (new_sd, stats) with byte counts. """ tgt_dtype = pick_fp8_dtype(fp8_mode) out: Dict[str, torch.Tensor] = {} stats = { "total_before": 0, "total_after": 0, "converted_count": 0, "kept_fp8_count": 0, "skipped_count": 0, } for name, tensor in sd.items(): before = bytes_of(tensor) stats["total_before"] += before # Respect existing FP8 unless asked to recode if tensor.dtype in (torch.float8_e5m2, torch.float8_e4m3fn): if recode_fp8: out[name] = tensor.to(dtype=tgt_dtype) stats["converted_count"] += 1 else: out[name] = tensor stats["kept_fp8_count"] += 1 stats["total_after"] += bytes_of(out[name]) continue # Decide conversion if should_convert_to_fp8(name, aggressive): out[name] = tensor.to(dtype=tgt_dtype) stats["converted_count"] += 1 else: out[name] = tensor stats["skipped_count"] += 1 stats["total_after"] += bytes_of(out[name]) return out, stats # --------------------------- CLI --------------------------- def parse_args() -> argparse.Namespace: p = argparse.ArgumentParser(description="Convert selected weights in a safetensors file to FP8 storage.") p.add_argument("in_path", help="Input .safetensors") p.add_argument("out_path", nargs="?", help="Output .safetensors (optional)") p.add_argument("--fp8", choices=["auto", "e5m2", "e4m3fn"], default="auto", help='Target FP8 storage dtype: "auto" (default), "e5m2", or "e4m3fn"') p.add_argument("--aggressive", action="store_true", help="Also convert modulation linears (audio_mod/v_cond_mod + single modulation.linear).") p.add_argument("--recode-fp8", action="store_true", help="Re-encode existing FP8 tensors to the chosen target dtype.") p.add_argument("--dry", action="store_true", help="Dry run: report only; do not write output file.") return p.parse_args() def main(): args = parse_args() print(f"[load] {args.in_path}") sd = load_file(args.in_path) tgt = pick_fp8_dtype(args.fp8) if not args.out_path: args.out_path = default_out_path(args.in_path, tgt) print(f"[auto-out] {args.out_path}") print(f"[policy] fp8_mode={args.fp8} -> {str(tgt).replace('torch.','')}, " f"aggressive={args.aggressive}, recode_fp8={args.recode_fp8}") new_sd, stats = convert_state_dict( sd, fp8_mode=args.fp8, aggressive=args.aggressive, recode_fp8=args.recode_fp8, ) saved = stats["total_before"] - stats["total_after"] print(f"[stats] tensors: {len(sd)}") print(f"[stats] converted: {stats['converted_count']} | kept_fp8: {stats['kept_fp8_count']} " f"| skipped: {stats['skipped_count']}") print(f"[bytes] before={human_gb(stats['total_before']):.3f} GiB | " f"after={human_gb(stats['total_after']):.3f} GiB | saved={human_gb(saved):.3f} GiB") if args.dry: print("[dry] no file written") return print(f"[save] {args.out_path}") save_file(new_sd, args.out_path) print("[done]") if __name__ == "__main__": main()