#!/usr/bin/env python3
"""
Mixed-FP8 safetensors converter for Hunyuan-Foley checkpoints.

- Converts selected .weight tensors to FP8 storage (E5M2 by default on pre-Hopper).
- Keeps math in FP16/BF16; this is a storage-only change in the file.
- Honors existing FP8 tensors in the input unless --recode-fp8 is set.
- Skips norms, biases, visual_proj.*, final_layer.* by design.
- Optional --aggressive converts modulation linears too.

USAGE (simple):
  python convert_fp8.py in.safetensors [out.safetensors]   # out is optional

USAGE (flags):
  python convert_fp8.py in.safetensors out.safetensors --fp8 auto --aggressive

Notes:
- “auto” picks FP8_E5M2 on SM < 90 (e.g., 3090), else FP8_E4M3FN.
- You can force a format: --fp8 e5m2 | e4m3fn
- Dry run: add --dry to print what would change without writing.
"""

import argparse
import re
from typing import Dict, Tuple
from pathlib import Path

import torch
from safetensors.torch import load_file, save_file


# --------------------------- Policy (names) ---------------------------

# Skip norms/bias and sensitive endpoints explicitly
_DENY_SUBSTRINGS = (
    ".bias", ".norm", "q_norm.", "k_norm.",
    "final_layer.", "visual_proj.",
)

# Allowed patterns target this architecture’s large linears
_ALLOW_PATTERNS = tuple(re.compile(p) for p in (
    # Single-stream blocks
    r"^single_blocks\.\d+\.linear1\.weight$",
    r"^single_blocks\.\d+\.linear2\.w[123]\.weight$",           # w1/w2/w3
    r"^single_blocks\.\d+\.linear_qkv\.weight$",
    r"^single_blocks\.\d+\.modulation\.linear\.weight$",        # gated by --aggressive

    # Triple-stream blocks: MLPs (dominant size)
    r"^triple_blocks\.\d+\.audio_mlp\.fc[12]\.weight$",
    r"^triple_blocks\.\d+\.v_cond_mlp\.fc[12]\.weight$",

    # Triple-stream blocks: attention projections
    r"^triple_blocks\.\d+\.(audio_self_attn_qkv|v_cond_attn_qkv|text_cross_kv)\.weight$",
    r"^triple_blocks\.\d+\.(audio_self_proj|v_cond_self_proj)\.weight$",

    # r"^triple_blocks\.\d+\.(audio_cross_q|v_cond_cross_q)\.weight$",
    # r"^triple_blocks\.\d+\.(audio_cross_proj|v_cond_cross_proj)\.weight$",

    # Triple-stream blocks: modulation linears (gated)
    r"^triple_blocks\.\d+\.(audio_mod|v_cond_mod)\.linear\.weight$",
))
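
# Illustrative policy examples (names assume this checkpoint's layout):
#   single_blocks.0.linear1.weight            -> convert
#   single_blocks.0.modulation.linear.weight  -> convert only with --aggressive
#   triple_blocks.3.audio_mlp.fc1.weight      -> convert
#   triple_blocks.3.audio_mlp.fc1.bias        -> skip (".bias" is denied)
#   final_layer.linear.weight                 -> skip ("final_layer." is denied)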


# --------------------------- Helpers ---------------------------

def default_out_path(in_path: str, tgt_dtype: torch.dtype) -> str:
    """<in>_fp8_<e5m2|e4m3fn>.safetensors (idempotent if already suffixed)."""
    suffix = "e5m2" if tgt_dtype == torch.float8_e5m2 else "e4m3fn"
    p = Path(in_path)
    stem = re.sub(r"_fp8_e(5m2|4m3fn)$", "", p.stem)  # strip prior suffix
    ext = p.suffix or ".safetensors"
    return str(p.with_name(f"{stem}_fp8_{suffix}{ext}"))
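
# e.g. "model.safetensors" -> "model_fp8_e5m2.safetensors" (E5M2 target), and
# re-running on the renamed file yields the same name, since any prior
# "_fp8_e5m2"/"_fp8_e4m3fn" suffix is stripped first.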


def pick_fp8_dtype(fp8_mode: str) -> torch.dtype:
    """Pick target FP8 dtype."""
    m = fp8_mode.lower()
    if m == "e5m2":
        return torch.float8_e5m2
    if m == "e4m3fn":
        return torch.float8_e4m3fn
    # auto
    try:
        major, _ = torch.cuda.get_device_capability() if torch.cuda.is_available() else (0, 0)
    except Exception:
        major = 0
    return torch.float8_e5m2 if major < 9 else torch.float8_e4m3fn
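
# e.g. an RTX 3090 reports SM 8.6 -> float8_e5m2; an H100 reports SM 9.0 ->
# float8_e4m3fn; with no CUDA device, major stays 0 and E5M2 is chosen.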


def bytes_of(t: torch.Tensor) -> int:
    """Size in bytes (FP8=1 byte/elt)."""
    if t.dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
        return t.numel() * 1
    return t.numel() * t.element_size()


def human_gb(nbytes: int) -> float:
    return nbytes / (1024 ** 3)


def _is_denied(name: str) -> bool:
    return any(tok in name for tok in _DENY_SUBSTRINGS)


def should_convert_to_fp8(name: str, aggressive: bool) -> bool:
    """Match names for conversion, with modulation linears gated by --aggressive."""
    if not name.endswith(".weight"):
        return False
    if _is_denied(name):
        return False

    for pat in _ALLOW_PATTERNS:
        if pat.search(name):
            # Gate modulation linears (single/triple) behind --aggressive
            if (
                ".modulation.linear.weight" in name
                or ".audio_mod.linear.weight" in name
                or ".v_cond_mod.linear.weight" in name
            ):
                return aggressive
            return True
    return False
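
# Quick sanity checks (illustrative; these mirror the policy table above):
#   should_convert_to_fp8("single_blocks.0.linear_qkv.weight", aggressive=False)        # True
#   should_convert_to_fp8("triple_blocks.1.audio_mod.linear.weight", aggressive=False)  # False
#   should_convert_to_fp8("triple_blocks.1.audio_mod.linear.weight", aggressive=True)   # True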


# --------------------------- Core ---------------------------

def convert_state_dict(
    sd: Dict[str, torch.Tensor],
    fp8_mode: str = "auto",
    aggressive: bool = False,
    recode_fp8: bool = False,
) -> Tuple[Dict[str, torch.Tensor], Dict[str, int]]:
    """
    Convert selected weights to FP8 storage according to the policy.
    Honors existing FP8 unless recode_fp8=True.
    Returns (new_sd, stats) with byte counts.
    """
    tgt_dtype = pick_fp8_dtype(fp8_mode)
    out: Dict[str, torch.Tensor] = {}
    stats = {
        "total_before": 0,
        "total_after": 0,
        "converted_count": 0,
        "kept_fp8_count": 0,
        "skipped_count": 0,
    }

    for name, tensor in sd.items():
        before = bytes_of(tensor)
        stats["total_before"] += before

        # Respect existing FP8 unless asked to recode to a different target
        if tensor.dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
            if recode_fp8 and tensor.dtype != tgt_dtype:
                # Round-trip through FP32: direct FP8->FP8 casts are not
                # supported on every PyTorch version.
                out[name] = tensor.to(torch.float32).to(dtype=tgt_dtype)
                stats["converted_count"] += 1
            else:
                out[name] = tensor
                stats["kept_fp8_count"] += 1
            stats["total_after"] += bytes_of(out[name])
            continue

        # Decide conversion
        if should_convert_to_fp8(name, aggressive):
            out[name] = tensor.to(dtype=tgt_dtype)
            stats["converted_count"] += 1
        else:
            out[name] = tensor
            stats["skipped_count"] += 1

        stats["total_after"] += bytes_of(out[name])

    return out, stats
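
# Minimal in-memory sketch (hypothetical tensor; shape for illustration only):
#
#   sd = {"single_blocks.0.linear1.weight": torch.randn(1024, 1024, dtype=torch.bfloat16)}
#   new_sd, stats = convert_state_dict(sd, fp8_mode="e5m2")
#   assert new_sd["single_blocks.0.linear1.weight"].dtype == torch.float8_e5m2
#   assert stats["total_after"] == stats["total_before"] // 2  # bf16 (2 B) -> fp8 (1 B)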


# --------------------------- CLI ---------------------------

def parse_args() -> argparse.Namespace:
    p = argparse.ArgumentParser(description="Convert selected weights in a safetensors file to FP8 storage.")
    p.add_argument("in_path",  help="Input .safetensors")
    p.add_argument("out_path", nargs="?", help="Output .safetensors (optional)")
    p.add_argument("--fp8", choices=["auto", "e5m2", "e4m3fn"], default="auto",
                   help='Target FP8 storage dtype: "auto" (default), "e5m2", or "e4m3fn"')
    p.add_argument("--aggressive", action="store_true",
                   help="Also convert modulation linears (audio_mod/v_cond_mod + single modulation.linear).")
    p.add_argument("--recode-fp8", action="store_true",
                   help="Re-encode existing FP8 tensors to the chosen target dtype.")
    p.add_argument("--dry", action="store_true",
                   help="Dry run: report only; do not write output file.")
    return p.parse_args()


def main():
    args = parse_args()

    print(f"[load] {args.in_path}")
    sd = load_file(args.in_path)

    tgt = pick_fp8_dtype(args.fp8)
    if not args.out_path:
        args.out_path = default_out_path(args.in_path, tgt)
        print(f"[auto-out] {args.out_path}")

    print(f"[policy] fp8_mode={args.fp8} -> {str(tgt).replace('torch.','')}, "
          f"aggressive={args.aggressive}, recode_fp8={args.recode_fp8}")

    new_sd, stats = convert_state_dict(
        sd,
        fp8_mode=args.fp8,
        aggressive=args.aggressive,
        recode_fp8=args.recode_fp8,
    )

    saved = stats["total_before"] - stats["total_after"]
    print(f"[stats] tensors: {len(sd)}")
    print(f"[stats] converted: {stats['converted_count']} | kept_fp8: {stats['kept_fp8_count']} "
          f"| skipped: {stats['skipped_count']}")
    print(f"[bytes] before={human_gb(stats['total_before']):.3f} GiB | "
          f"after={human_gb(stats['total_after']):.3f} GiB | saved={human_gb(saved):.3f} GiB")

    if args.dry:
        print("[dry] no file written")
        return

    print(f"[save] {args.out_path}")
    save_file(new_sd, args.out_path)
    print("[done]")


if __name__ == "__main__":
    main()