HunyuanVideo-Foley / scripts /convert_safetensors_to_fp8.py
phazei's picture
Move scripts
7a94e98
#!/usr/bin/env python3
"""
Mixed-FP8 safetensors converter for Hunyuan-Foley checkpoints.
- Converts selected .weight tensors to FP8 storage (E5M2 by default on pre-Hopper).
- Keeps math in FP16/BF16; this is a storage-only change in the file.
- Honors existing FP8 tensors in the input unless --recode-fp8 is set.
- Skips norms, biases, visual_proj.*, final_layer.* by design.
- Optional --aggressive converts modulation linears too.
USAGE (simple):
python convert_safetensors_to_fp8.py in.safetensors [out.safetensors] # out is optional
USAGE (flags):
python convert_safetensors_to_fp8.py in.safetensors out.safetensors --fp8 auto --aggressive
Notes:
- “auto” picks FP8_E5M2 on SM < 90 (e.g., 3090), else FP8_E4M3FN.
- You can force a format: --fp8 e5m2 | e4m3fn
- Dry run: add --dry to print what would change without writing.
"""
import argparse
import re
from typing import Dict, Tuple
from pathlib import Path
import torch
from safetensors.torch import load_file, save_file
# --------------------------- Policy (names) ---------------------------
# Tensor names containing any of these fragments are never stored as FP8:
# biases, norm layers, and the sensitive endpoint modules.
_DENY_SUBSTRINGS = (
    # Biases
    ".bias",
    # Normalization layers (generic and Q/K norms)
    ".norm",
    "q_norm.",
    "k_norm.",
    # Sensitive endpoints
    "final_layer.",
    "visual_proj.",
)
# Compiled allow-list: fully anchored regexes naming the large linear weights
# of this architecture that are eligible for FP8 storage.
_ALLOW_PATTERNS = tuple(map(re.compile, [
    # --- Single-stream blocks ---
    r"^single_blocks\.\d+\.linear1\.weight$",
    r"^single_blocks\.\d+\.linear2\.w[123]\.weight$",  # w1 / w2 / w3
    r"^single_blocks\.\d+\.linear_qkv\.weight$",
    # Modulation linear: only converted when --aggressive is given
    r"^single_blocks\.\d+\.modulation\.linear\.weight$",

    # --- Triple-stream blocks: MLPs (dominant size) ---
    r"^triple_blocks\.\d+\.audio_mlp\.fc[12]\.weight$",
    r"^triple_blocks\.\d+\.v_cond_mlp\.fc[12]\.weight$",

    # --- Triple-stream blocks: attention projections ---
    r"^triple_blocks\.\d+\.(audio_self_attn_qkv|v_cond_attn_qkv|text_cross_kv)\.weight$",
    r"^triple_blocks\.\d+\.(audio_self_proj|v_cond_self_proj)\.weight$",
    # Cross-attention q/proj patterns are currently disabled (kept for reference):
    # r"^triple_blocks\.\d+\.(audio_cross_q|v_cond_cross_q)\.weight$",
    # r"^triple_blocks\.\d+\.(audio_cross_proj|v_cond_cross_proj)\.weight$",

    # --- Triple-stream blocks: modulation linears (only with --aggressive) ---
    r"^triple_blocks\.\d+\.(audio_mod|v_cond_mod)\.linear\.weight$",
]))
# --------------------------- Helpers ---------------------------
def default_out_path(in_path: str, tgt_dtype: torch.dtype) -> str:
"""<in>_fp8_<e5m2|e4m3fn>.safetensors (idempotent if already suffixed)."""
suffix = "e5m2" if tgt_dtype == torch.float8_e5m2 else "e4m3fn"
p = Path(in_path)
stem = re.sub(r"_fp8_e(5m2|4m3fn)$", "", p.stem) # strip prior suffix
ext = p.suffix or ".safetensors"
return str(p.with_name(f"{stem}_fp8_{suffix}{ext}"))
def pick_fp8_dtype(fp8_mode: str) -> torch.dtype:
"""Pick target FP8 dtype."""
m = fp8_mode.lower()
if m == "e5m2":
return torch.float8_e5m2
if m == "e4m3fn":
return torch.float8_e4m3fn
# auto
try:
major, _ = torch.cuda.get_device_capability() if torch.cuda.is_available() else (0, 0)
except Exception:
major = 0
return torch.float8_e5m2 if major < 9 else torch.float8_e4m3fn
def bytes_of(t: torch.Tensor) -> int:
"""Size in bytes (FP8=1 byte/elt)."""
if t.dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
return t.numel() * 1
return t.numel() * t.element_size()
def human_gb(nbytes: int) -> float:
    """Convert a byte count to GiB (units of 2**30 bytes)."""
    bytes_per_gib = 1 << 30
    return nbytes / bytes_per_gib
def _is_denied(name: str) -> bool:
    """Return True if ``name`` contains any fragment from ``_DENY_SUBSTRINGS``."""
    for fragment in _DENY_SUBSTRINGS:
        if fragment in name:
            return True
    return False
def should_convert_to_fp8(name: str, aggressive: bool) -> bool:
    """Decide from its name alone whether a tensor should be stored as FP8.

    A tensor qualifies when its name ends with ``.weight``, is not on the
    deny list, and matches one of ``_ALLOW_PATTERNS``. Modulation linears
    (single ``modulation.linear`` and triple ``audio_mod`` / ``v_cond_mod``)
    additionally require ``aggressive``; for those the function returns the
    ``aggressive`` argument itself.
    """
    if not name.endswith(".weight") or _is_denied(name):
        return False
    if not any(rx.search(name) for rx in _ALLOW_PATTERNS):
        return False

    # Modulation linears are on the allow-list but gated behind --aggressive.
    gated_markers = (
        ".modulation.linear.weight",
        ".audio_mod.linear.weight",
        ".v_cond_mod.linear.weight",
    )
    if any(marker in name for marker in gated_markers):
        return aggressive
    return True
# --------------------------- Core ---------------------------
def convert_state_dict(
    sd: Dict[str, torch.Tensor],
    fp8_mode: str = "auto",
    aggressive: bool = False,
    recode_fp8: bool = False,
) -> Tuple[Dict[str, torch.Tensor], Dict[str, int]]:
    """
    Convert selected weights to FP8 storage according to the naming policy.

    Args:
        sd: Mapping of tensor name to tensor. It is not modified.
        fp8_mode: Target format selector, resolved by ``pick_fp8_dtype``.
        aggressive: Also convert the modulation linears
            (see ``should_convert_to_fp8``).
        recode_fp8: Re-encode tensors that are already FP8 but in a different
            FP8 format than the target. Existing FP8 tensors are otherwise
            passed through untouched.

    Returns:
        ``(new_sd, stats)``. ``new_sd`` has the same keys in the same order;
        tensors that are not converted are the same objects as in ``sd``.
        ``stats`` holds ``total_before`` / ``total_after`` (bytes) and the
        ``converted_count`` / ``kept_fp8_count`` / ``skipped_count`` tallies,
        which together add up to ``len(sd)``.
    """
    tgt_dtype = pick_fp8_dtype(fp8_mode)
    fp8_dtypes = (torch.float8_e5m2, torch.float8_e4m3fn)

    out: Dict[str, torch.Tensor] = {}
    stats = {
        "total_before": 0,
        "total_after": 0,
        "converted_count": 0,
        "kept_fp8_count": 0,
        "skipped_count": 0,
    }

    for name, tensor in sd.items():
        stats["total_before"] += bytes_of(tensor)

        if tensor.dtype in fp8_dtypes:
            # Already FP8. Re-encode only on request AND only when the format
            # actually differs: casting to the same dtype is a no-op, so such
            # tensors are reported as kept rather than converted.
            if recode_fp8 and tensor.dtype != tgt_dtype:
                result = tensor.to(dtype=tgt_dtype)
                stats["converted_count"] += 1
            else:
                result = tensor
                stats["kept_fp8_count"] += 1
        elif should_convert_to_fp8(name, aggressive):
            # Policy match: store this weight in the target FP8 format.
            result = tensor.to(dtype=tgt_dtype)
            stats["converted_count"] += 1
        else:
            # Not selected by the policy: pass through unchanged.
            result = tensor
            stats["skipped_count"] += 1

        out[name] = result
        stats["total_after"] += bytes_of(result)

    return out, stats
# --------------------------- CLI ---------------------------
def parse_args() -> argparse.Namespace:
    """Parse the converter's command line from ``sys.argv``.

    Positionals: ``in_path`` (required) and ``out_path`` (optional).
    Options: ``--fp8 {auto,e5m2,e4m3fn}`` (default ``auto``) and the boolean
    switches ``--aggressive``, ``--recode-fp8`` and ``--dry``.
    """
    parser = argparse.ArgumentParser(
        description="Convert selected weights in a safetensors file to FP8 storage.",
    )

    parser.add_argument("in_path", help="Input .safetensors")
    parser.add_argument("out_path", nargs="?", help="Output .safetensors (optional)")
    parser.add_argument(
        "--fp8",
        default="auto",
        choices=["auto", "e5m2", "e4m3fn"],
        help='Target FP8 storage dtype: "auto" (default), "e5m2", or "e4m3fn"',
    )

    # Boolean switches, registered in the order they appear in --help.
    switches = (
        ("--aggressive",
         "Also convert modulation linears (audio_mod/v_cond_mod + single modulation.linear)."),
        ("--recode-fp8",
         "Re-encode existing FP8 tensors to the chosen target dtype."),
        ("--dry",
         "Dry run: report only; do not write output file."),
    )
    for flag, help_text in switches:
        parser.add_argument(flag, action="store_true", help=help_text)

    return parser.parse_args()
def main():
    """CLI entry point: load a checkpoint, convert it, report, and save.

    Steps: parse arguments, load the input state dict, resolve the target FP8
    dtype and (if not given) the default output path, run
    ``convert_state_dict``, print tensor counts and byte totals, then write
    the result unless ``--dry`` was passed.

    Raises:
        SystemExit: when a write is requested and the output path resolves to
            the same file as the input.
    """
    args = parse_args()
    print(f"[load] {args.in_path}")
    sd = load_file(args.in_path)

    tgt = pick_fp8_dtype(args.fp8)
    if not args.out_path:
        args.out_path = default_out_path(args.in_path, tgt)
        print(f"[auto-out] {args.out_path}")

    # Refuse to write over the input. The default output name equals the input
    # name whenever the input already carries the target "_fp8_<fmt>" suffix,
    # which would silently replace the source checkpoint.
    # NOTE(review): the loader may also still have the input file mapped while
    # saving - confirm against the safetensors documentation.
    if not args.dry and Path(args.out_path).resolve() == Path(args.in_path).resolve():
        raise SystemExit(
            f"[error] output path is the same file as the input ({args.out_path}); "
            "pass a different out_path or use --dry"
        )

    print(f"[policy] fp8_mode={args.fp8} -> {str(tgt).replace('torch.','')}, "
          f"aggressive={args.aggressive}, recode_fp8={args.recode_fp8}")

    new_sd, stats = convert_state_dict(
        sd,
        fp8_mode=args.fp8,
        aggressive=args.aggressive,
        recode_fp8=args.recode_fp8,
    )

    saved = stats["total_before"] - stats["total_after"]
    print(f"[stats] tensors: {len(sd)}")
    print(f"[stats] converted: {stats['converted_count']} | kept_fp8: {stats['kept_fp8_count']} "
          f"| skipped: {stats['skipped_count']}")
    print(f"[bytes] before={human_gb(stats['total_before']):.3f} GiB | "
          f"after={human_gb(stats['total_after']):.3f} GiB | saved={human_gb(saved):.3f} GiB")

    if args.dry:
        print("[dry] no file written")
        return

    print(f"[save] {args.out_path}")
    save_file(new_sd, args.out_path)
    print("[done]")


# Run the converter only when executed as a script, not when imported.
if __name__ == "__main__":
    main()