#!/usr/bin/env python3
"""
Mixed-FP8 safetensors converter for Hunyuan-Foley checkpoints.
- Converts selected .weight tensors to FP8 storage (E5M2 by default on pre-Hopper).
- Keeps math in FP16/BF16; this is a storage-only change in the file.
- Honors existing FP8 tensors in the input unless --recode-fp8 is set.
- Skips norms, biases, visual_proj.*, final_layer.* by design.
- Optional --aggressive converts modulation linears too.
USAGE (simple):
    python convert_fp8.py in.safetensors [out.safetensors]   # out is optional
USAGE (flags):
    python convert_fp8.py in.safetensors out.safetensors --fp8 auto --aggressive
Notes:
- “auto” picks FP8_E5M2 on SM < 90 (e.g., 3090), else FP8_E4M3FN.
- You can force a format: --fp8 e5m2 | e4m3fn
- Dry run: add --dry to print what would change without writing.
"""
import argparse
import re
from typing import Dict, Tuple
from pathlib import Path
import torch
from safetensors.torch import load_file, save_file

# --------------------------- Policy (names) ---------------------------
# Skip norms/bias and sensitive endpoints explicitly
_DENY_SUBSTRINGS = (
    ".bias", ".norm", "q_norm.", "k_norm.",
    "final_layer.", "visual_proj.",
)

# Allowed patterns target this architecture’s large linears
_ALLOW_PATTERNS = tuple(re.compile(p) for p in (
    # Single-stream blocks
    r"^single_blocks\.\d+\.linear1\.weight$",
    r"^single_blocks\.\d+\.linear2\.w[123]\.weight$",  # w1/w2/w3
    r"^single_blocks\.\d+\.linear_qkv\.weight$",
    r"^single_blocks\.\d+\.modulation\.linear\.weight$",  # gated by --aggressive
    # Triple-stream blocks: MLPs (dominant size)
    r"^triple_blocks\.\d+\.audio_mlp\.fc[12]\.weight$",
    r"^triple_blocks\.\d+\.v_cond_mlp\.fc[12]\.weight$",
    # Triple-stream blocks: attention projections
    r"^triple_blocks\.\d+\.(audio_self_attn_qkv|v_cond_attn_qkv|text_cross_kv)\.weight$",
    r"^triple_blocks\.\d+\.(audio_self_proj|v_cond_self_proj)\.weight$",
    # r"^triple_blocks\.\d+\.(audio_cross_q|v_cond_cross_q)\.weight$",
    # r"^triple_blocks\.\d+\.(audio_cross_proj|v_cond_cross_proj)\.weight$",
    # Triple-stream blocks: modulation linears (gated)
    r"^triple_blocks\.\d+\.(audio_mod|v_cond_mod)\.linear\.weight$",
))
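
# Anything not matched above (or caught by the deny list) keeps its original dtype.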

# --------------------------- Helpers ---------------------------
def default_out_path(in_path: str, tgt_dtype: torch.dtype) -> str:
    """<in>_fp8_<e5m2|e4m3fn>.safetensors (idempotent if already suffixed)."""
    suffix = "e5m2" if tgt_dtype == torch.float8_e5m2 else "e4m3fn"
    p = Path(in_path)
    stem = re.sub(r"_fp8_e(5m2|4m3fn)$", "", p.stem)  # strip prior suffix
    ext = p.suffix or ".safetensors"
    return str(p.with_name(f"{stem}_fp8_{suffix}{ext}"))
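
# Example (illustrative filename, not from a real checkpoint): with an E5M2 target,
# "hunyuan_foley.safetensors" becomes "hunyuan_foley_fp8_e5m2.safetensors";
# re-running on that output keeps a single suffix.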

def pick_fp8_dtype(fp8_mode: str) -> torch.dtype:
    """Pick target FP8 dtype."""
    m = fp8_mode.lower()
    if m == "e5m2":
        return torch.float8_e5m2
    if m == "e4m3fn":
        return torch.float8_e4m3fn
    # auto
    try:
        major, _ = torch.cuda.get_device_capability() if torch.cuda.is_available() else (0, 0)
    except Exception:
        major = 0
    return torch.float8_e5m2 if major < 9 else torch.float8_e4m3fn
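
# Note: E5M2 (5 exponent / 2 mantissa bits) favors dynamic range, while E4M3FN (4 / 3)
# favors precision. "auto" assumes pre-Hopper GPUs (SM < 90, e.g. a 3090) are better
# served by E5M2, and picks E4M3FN on Hopper-class parts with native FP8 support.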

def bytes_of(t: torch.Tensor) -> int:
    """Size in bytes (FP8=1 byte/elt)."""
    if t.dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
        return t.numel() * 1
    return t.numel() * t.element_size()

def human_gb(nbytes: int) -> float:
    return nbytes / (1024 ** 3)

def _is_denied(name: str) -> bool:
    return any(tok in name for tok in _DENY_SUBSTRINGS)

def should_convert_to_fp8(name: str, aggressive: bool) -> bool:
    """Match names for conversion, with modulation linears gated by --aggressive."""
    if not name.endswith(".weight"):
        return False
    if _is_denied(name):
        return False
    for pat in _ALLOW_PATTERNS:
        if pat.search(name):
            # Gate modulation linears (single/triple) behind --aggressive
            if (
                ".modulation.linear.weight" in name
                or ".audio_mod.linear.weight" in name
                or ".v_cond_mod.linear.weight" in name
            ):
                return aggressive
            return True
    return False
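
# Illustrative checks against the policy above (names follow the regex patterns, not a
# real checkpoint dump):
#   should_convert_to_fp8("single_blocks.0.linear1.weight", aggressive=False)           -> True
#   should_convert_to_fp8("triple_blocks.3.audio_mod.linear.weight", aggressive=False)  -> False (gated)
#   should_convert_to_fp8("triple_blocks.3.audio_mod.linear.weight", aggressive=True)   -> True
#   should_convert_to_fp8("single_blocks.0.norm.weight", aggressive=True)               -> False (denied)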

# --------------------------- Core ---------------------------
def convert_state_dict(
    sd: Dict[str, torch.Tensor],
    fp8_mode: str = "auto",
    aggressive: bool = False,
    recode_fp8: bool = False,
) -> Tuple[Dict[str, torch.Tensor], Dict[str, int]]:
    """
    Convert selected weights to FP8 storage according to the policy.
    Honors existing FP8 unless recode_fp8=True.
    Returns (new_sd, stats) with byte counts.
    """
    tgt_dtype = pick_fp8_dtype(fp8_mode)
    out: Dict[str, torch.Tensor] = {}
    stats = {
        "total_before": 0,
        "total_after": 0,
        "converted_count": 0,
        "kept_fp8_count": 0,
        "skipped_count": 0,
    }
    for name, tensor in sd.items():
        before = bytes_of(tensor)
        stats["total_before"] += before
        # Respect existing FP8 unless asked to recode
        if tensor.dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
            if recode_fp8:
                out[name] = tensor.to(dtype=tgt_dtype)
                stats["converted_count"] += 1
            else:
                out[name] = tensor
                stats["kept_fp8_count"] += 1
            stats["total_after"] += bytes_of(out[name])
            continue
        # Decide conversion
        if should_convert_to_fp8(name, aggressive):
            out[name] = tensor.to(dtype=tgt_dtype)
            stats["converted_count"] += 1
        else:
            out[name] = tensor
            stats["skipped_count"] += 1
        stats["total_after"] += bytes_of(out[name])
    return out, stats
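
# This is a storage-only change: a consumer is expected to upcast FP8 weights to the
# compute dtype before use. A minimal, illustrative sketch (key name assumed, not taken
# from a real checkpoint):
#   sd = load_file("model_fp8_e5m2.safetensors")
#   w = sd["single_blocks.0.linear1.weight"].to(torch.bfloat16)  # upcast before matmul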

# --------------------------- CLI ---------------------------
def parse_args() -> argparse.Namespace:
    p = argparse.ArgumentParser(description="Convert selected weights in a safetensors file to FP8 storage.")
    p.add_argument("in_path", help="Input .safetensors")
    p.add_argument("out_path", nargs="?", help="Output .safetensors (optional)")
    p.add_argument("--fp8", choices=["auto", "e5m2", "e4m3fn"], default="auto",
                   help='Target FP8 storage dtype: "auto" (default), "e5m2", or "e4m3fn"')
    p.add_argument("--aggressive", action="store_true",
                   help="Also convert modulation linears (audio_mod/v_cond_mod + single modulation.linear).")
    p.add_argument("--recode-fp8", action="store_true",
                   help="Re-encode existing FP8 tensors to the chosen target dtype.")
    p.add_argument("--dry", action="store_true",
                   help="Dry run: report only; do not write output file.")
    return p.parse_args()

def main():
    args = parse_args()
    print(f"[load] {args.in_path}")
    sd = load_file(args.in_path)
    tgt = pick_fp8_dtype(args.fp8)
    if not args.out_path:
        args.out_path = default_out_path(args.in_path, tgt)
        print(f"[auto-out] {args.out_path}")
    print(f"[policy] fp8_mode={args.fp8} -> {str(tgt).replace('torch.','')}, "
          f"aggressive={args.aggressive}, recode_fp8={args.recode_fp8}")
    new_sd, stats = convert_state_dict(
        sd,
        fp8_mode=args.fp8,
        aggressive=args.aggressive,
        recode_fp8=args.recode_fp8,
    )
    saved = stats["total_before"] - stats["total_after"]
    print(f"[stats] tensors: {len(sd)}")
    print(f"[stats] converted: {stats['converted_count']} | kept_fp8: {stats['kept_fp8_count']} "
          f"| skipped: {stats['skipped_count']}")
    print(f"[bytes] before={human_gb(stats['total_before']):.3f} GiB | "
          f"after={human_gb(stats['total_after']):.3f} GiB | saved={human_gb(saved):.3f} GiB")
    if args.dry:
        print("[dry] no file written")
        return
    print(f"[save] {args.out_path}")
    save_file(new_sd, args.out_path)
    print("[done]")

if __name__ == "__main__":
    main()