# HunyuanVideo-Foley / scripts / safetensors_info.py
# Uploaded by phazei
# Commit message: Move scripts
# Commit: 7a94e98
import argparse
from safetensors import safe_open
from collections import Counter
import os
import math # math.prod is Python 3.8+
# --- Dtype to Bytes Mapping ---
# Safetensors Dtype strings:
# BOOL, F8_E5M2, F8_E4M3FN, F16, BF16, F32, F64,
# I8, I16, I32, I64, U8, U16, U32, U64,
# F8_E5M2FNUZ, F8_E4M3FNUZ
# Maps an upper-case safetensors dtype name to its storage width in bytes.
DTYPE_TO_BYTES = {
    "BOOL": 1,
    # All float8 flavours occupy a single byte. Both the underscore spelling
    # and the underscore-free alternative naming are accepted, plus the bare
    # "F8_E4M3" form (likely E4M3FN).
    **dict.fromkeys(
        (
            "F8_E5M2",
            "F8E5M2",
            "F8_E4M3FN",
            "F8E4M3FN",
            "F8_E4M3",
            "F8_E5M2FNUZ",
            "F8E5M2FNUZ",
            "F8_E4M3FNUZ",
            "F8E4M3FNUZ",
        ),
        1,
    ),
    # Standard floating-point widths.
    "F16": 2,
    "BF16": 2,
    "F32": 4,
    "F64": 8,
    # Signed (I*) and unsigned (U*) integers: the bit width in the name
    # divided by eight gives the byte width.
    **{
        f"{prefix}{bits}": bits // 8
        for prefix in ("I", "U")
        for bits in (8, 16, 32, 64)
    },
}
def get_bytes_per_element(dtype_str):
    """Look up the per-element byte width for a safetensors dtype name.

    The name is upper-cased before the lookup, so matching is
    case-insensitive.

    Args:
        dtype_str: Dtype name as a string (e.g. ``"F16"`` or ``"bf16"``).

    Returns:
        The width in bytes from ``DTYPE_TO_BYTES``, or ``None`` when the
        dtype is not listed there.
    """
    normalized = dtype_str.upper()
    if normalized in DTYPE_TO_BYTES:
        return DTYPE_TO_BYTES[normalized]
    return None
def calculate_num_elements(shape):
    """Return how many elements a tensor of the given shape holds.

    Args:
        shape: Sequence of dimension sizes. An empty (or otherwise falsy)
            shape is treated as a scalar.

    Returns:
        ``1`` for a scalar, ``0`` if any dimension is zero, otherwise the
        product of all dimension sizes.
    """
    # A zero-length axis means the tensor stores nothing at all.
    if shape and 0 in shape:
        return 0

    # Plain loop rather than math.prod, which needs Python 3.8+.
    # A falsy shape contributes no factors, leaving the scalar count of 1.
    total = 1
    for extent in shape or ():
        total *= extent
    return total
def inspect_safetensors_precision_and_size(filepath):
    """Print a per-tensor dtype/size table and a summary for a safetensors file.

    For every tensor in the file one row is printed with its name, dtype,
    stored size in MB, and the size it would have at 4 bytes per element
    (FP32). A summary follows with the tensor count, the dtype distribution,
    both size totals, and the percentage reduction relative to FP32.

    Tensors whose dtype is not known to ``get_bytes_per_element`` produce a
    warning, are shown with "N/A" sizes, and are left out of both totals.

    Args:
        filepath: Path to the file to inspect. A missing file is reported
            and the function returns early; a path without the
            ``.safetensors`` extension only triggers a warning.

    Returns:
        None. All output goes to stdout, including any error raised while
        reading the file, which is caught and reported rather than propagated.
    """
    if not os.path.exists(filepath):
        print(f"Error: File not found at '{filepath}'")
        return

    has_expected_suffix = filepath.lower().endswith(".safetensors")
    if not has_expected_suffix:
        print(f"Warning: File '{filepath}' does not have a .safetensors extension. Attempting to read anyway.")

    bytes_per_mb = 1024 * 1024
    fp32_width = 4  # FP32 stores each element in 4 bytes
    per_dtype = Counter()
    actual_mb_sum = 0.0
    fp32_mb_sum = 0.0

    try:
        print(f"Inspecting tensors in: {filepath}\n")
        with safe_open(filepath, framework="pt", device="cpu") as f:
            names = list(f.keys())
            if not names:
                print("No tensors found in the file.")
                return

            # The name column is at least as wide as its heading and grows
            # to fit the longest tensor name.
            name_width = max(len("Tensor Name"), *(len(n) for n in names))

            print(
                f"{'Tensor Name':<{name_width}} | "
                f"{'Precision (dtype)':<17} | "
                f"{'Actual Size (MB)':>16} | "
                f"{'FP32 Equiv. (MB)':>18}"
            )
            print(
                f"{'-' * name_width}-|-------------------|------------------|-------------------"
            )

            for tensor_name in names:
                # Only the slice's dtype and shape are queried here.
                entry = f.get_slice(tensor_name)
                dtype_str = entry.get_dtype()
                element_count = calculate_num_elements(entry.get_shape())
                element_width = get_bytes_per_element(dtype_str)

                if element_width is None:
                    print(f"Warning: Unknown dtype '{dtype_str}' for tensor '{tensor_name}'. Cannot calculate size.")
                    actual_col = "N/A"
                    fp32_col = "N/A"
                else:
                    actual_mb = element_count * element_width / bytes_per_mb
                    actual_mb_sum += actual_mb
                    actual_col = f"{actual_mb:.3f}"

                    fp32_mb = element_count * fp32_width / bytes_per_mb
                    fp32_mb_sum += fp32_mb
                    fp32_col = f"{fp32_mb:.3f}"

                print(
                    f"{tensor_name:<{name_width}} | "
                    f"{dtype_str:<17} | "
                    f"{actual_col:>16} | "
                    f"{fp32_col:>18}"
                )
                per_dtype[dtype_str] += 1

            print("\n--- Summary ---")
            print(f"Total tensors found: {len(names)}")

            if not per_dtype:
                print("No dtypes to summarize.")
            else:
                print("Precision distribution:")
                for dtype, count in per_dtype.most_common():
                    print(f" - {dtype:<12}: {count} tensor(s)")

            print(f"\nTotal actual size of all tensors: {actual_mb_sum:.3f} MB")
            print(f"Total theoretical FP32 size of all tensors: {fp32_mb_sum:.3f} MB")

            # Guard the ratio against a zero or near-zero denominator.
            if fp32_mb_sum > 0.00001:
                savings_percentage = (1 - (actual_mb_sum / fp32_mb_sum)) * 100
                print(f"Overall size reduction compared to full FP32: {savings_percentage:.2f}%")
            else:
                print("Overall size reduction cannot be calculated (no FP32 equivalent data or zero size).")
    except Exception as e:
        # Top-level boundary for the CLI: report the failure instead of
        # letting a traceback escape.
        print(f"An error occurred while processing '{filepath}':")
        print(f" {e}")
        print("Please ensure it's a valid .safetensors file and the 'safetensors' (and 'torch') libraries are installed correctly.")
if __name__ == "__main__":
    # Command-line entry point: a single positional argument names the file.
    cli = argparse.ArgumentParser(
        description="Inspect tensor precision (dtype) and size in a .safetensors file.",
    )
    cli.add_argument("filepath", help="Path to the .safetensors file to inspect.")
    inspect_safetensors_precision_and_size(cli.parse_args().filepath)