"""Inspect the precision (dtype) and size of tensors stored in a .safetensors file."""

import argparse
import os
from collections import Counter

from safetensors import safe_open

DTYPE_TO_BYTES = {
    "BOOL": 1,
    # 8-bit float formats (both underscore and non-underscore spellings)
    "F8_E5M2": 1,
    "F8E5M2": 1,
    "F8_E4M3FN": 1,
    "F8E4M3FN": 1,
    "F8_E4M3": 1,
    "F8_E5M2FNUZ": 1,
    "F8E5M2FNUZ": 1,
    "F8_E4M3FNUZ": 1,
    "F8E4M3FNUZ": 1,
    # 16-, 32-, and 64-bit floats
    "F16": 2,
    "BF16": 2,
    "F32": 4,
    "F64": 8,
    # signed integers
    "I8": 1,
    "I16": 2,
    "I32": 4,
    "I64": 8,
    # unsigned integers
    "U8": 1,
    "U16": 2,
    "U32": 4,
    "U64": 8,
}


def get_bytes_per_element(dtype_str):
    """Returns the number of bytes for a given safetensors dtype string."""
    return DTYPE_TO_BYTES.get(dtype_str.upper(), None)
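
# Example (illustrative): get_bytes_per_element("bf16") == 2 thanks to the
# .upper() normalization, while an unrecognized dtype string returns None and
# must be handled by the caller.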


def calculate_num_elements(shape):
    """Calculates the total number of elements from a tensor shape tuple."""
    if not shape:
        # An empty shape denotes a scalar, which holds exactly one element.
        return 1
    if 0 in shape:
        # Any zero-sized dimension means the tensor holds no elements at all.
        return 0

    num_elements = 1
    for dim_size in shape:
        num_elements *= dim_size
    return num_elements
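
# Example (illustrative): calculate_num_elements([1024, 1024]) == 1_048_576,
# calculate_num_elements([]) == 1, and calculate_num_elements([8, 0, 4]) == 0.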


def inspect_safetensors_precision_and_size(filepath):
    """
    Reads a .safetensors file, iterates through its tensors,
    and reports the precision (dtype), actual size, and theoretical FP32 size.
    """
    if not os.path.exists(filepath):
        print(f"Error: File not found at '{filepath}'")
        return

    if not filepath.lower().endswith(".safetensors"):
        print(f"Warning: File '{filepath}' does not have a .safetensors extension. Attempting to read anyway.")

    dtype_counts = Counter()
    total_actual_mb = 0.0
    total_fp32_equiv_mb = 0.0

    try:
        print(f"Inspecting tensors in: {filepath}\n")
        with safe_open(filepath, framework="pt", device="cpu") as f:
            tensor_keys = list(f.keys())
            if not tensor_keys:
                print("No tensors found in the file.")
                return

            # Size the first column to fit the longest tensor name.
            max_key_len = max(len("Tensor Name"), max(len(k) for k in tensor_keys))

            header = (
                f"{'Tensor Name':<{max_key_len}} | "
                f"{'Precision (dtype)':<17} | "
                f"{'Actual Size (MB)':>16} | "
                f"{'FP32 Equiv. (MB)':>18}"
            )
            print(header)
            print(
                f"{'-' * max_key_len}-|-------------------|------------------|-------------------"
            )

            for key in tensor_keys:
                # get_slice exposes dtype and shape without loading the tensor data.
                tensor_slice = f.get_slice(key)
                dtype_str = tensor_slice.get_dtype()
                shape = tensor_slice.get_shape()

                num_elements = calculate_num_elements(shape)
                bytes_per_el_actual = get_bytes_per_element(dtype_str)

                actual_size_mb_str = "N/A"
                fp32_equiv_size_mb_str = "N/A"

                if bytes_per_el_actual is not None:
                    # Actual size at the stored precision, converted to MB (1024 * 1024 bytes).
                    actual_bytes = num_elements * bytes_per_el_actual
                    actual_size_mb_val = actual_bytes / (1024 * 1024)
                    total_actual_mb += actual_size_mb_val
                    actual_size_mb_str = f"{actual_size_mb_val:.3f}"

                    # Size the same tensor would occupy at 4 bytes per element (FP32).
                    fp32_equiv_bytes = num_elements * 4
                    fp32_equiv_size_mb_val = fp32_equiv_bytes / (1024 * 1024)
                    total_fp32_equiv_mb += fp32_equiv_size_mb_val
                    fp32_equiv_size_mb_str = f"{fp32_equiv_size_mb_val:.3f}"
                else:
                    print(f"Warning: Unknown dtype '{dtype_str}' for tensor '{key}'. Cannot calculate size.")

                print(
                    f"{key:<{max_key_len}} | "
                    f"{dtype_str:<17} | "
                    f"{actual_size_mb_str:>16} | "
                    f"{fp32_equiv_size_mb_str:>18}"
                )
                dtype_counts[dtype_str] += 1

        print("\n--- Summary ---")
        print(f"Total tensors found: {len(tensor_keys)}")
        if dtype_counts:
            print("Precision distribution:")
            for dtype, count in dtype_counts.most_common():
                print(f"  - {dtype:<12}: {count} tensor(s)")
        else:
            print("No dtypes to summarize.")

        print(f"\nTotal actual size of all tensors: {total_actual_mb:.3f} MB")
        print(f"Total theoretical FP32 size of all tensors: {total_fp32_equiv_mb:.3f} MB")

        if total_fp32_equiv_mb > 0.00001:
            savings_percentage = (1 - (total_actual_mb / total_fp32_equiv_mb)) * 100
            print(f"Overall size reduction compared to full FP32: {savings_percentage:.2f}%")
        else:
            print("Overall size reduction cannot be calculated (no FP32 equivalent data or zero size).")

    except Exception as e:
        print(f"An error occurred while processing '{filepath}':")
        print(f"  {e}")
        print("Please ensure it's a valid .safetensors file and the 'safetensors' (and 'torch') libraries are installed correctly.")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Inspect tensor precision (dtype) and size in a .safetensors file."
    )
    parser.add_argument(
        "filepath",
        help="Path to the .safetensors file to inspect."
    )
    args = parser.parse_args()

    inspect_safetensors_precision_and_size(args.filepath)
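
# Example invocation (illustrative; the script and checkpoint file names below are assumptions):
#   python inspect_safetensors.py model.safetensors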