"""Inspect the precision (dtype) and size of every tensor in a .safetensors file."""

import argparse
import math  # math.prod is Python 3.8+
import os
from collections import Counter

try:
    from safetensors import safe_open
except ImportError:
    # Defer the failure: the entry point reports the missing dependency with a
    # readable message instead of the script dying with an import traceback.
    safe_open = None

# --- Dtype to Bytes Mapping ---
# Safetensors Dtype strings:
# BOOL, F8_E5M2, F8_E4M3FN, F16, BF16, F32, F64,
# I8, I16, I32, I64, U8, U16, U32, U64,
# F8_E5M2FNUZ, F8_E4M3FNUZ
DTYPE_TO_BYTES = {
    "BOOL": 1,
    # Float8 variants
    "F8_E5M2": 1,
    "F8E5M2": 1,  # Common alternative naming
    "F8_E4M3FN": 1,
    "F8E4M3FN": 1,  # Common alternative naming
    "F8_E4M3": 1,  # As seen in user example, likely E4M3FN
    "F8_E5M2FNUZ": 1,
    "F8E5M2FNUZ": 1,  # Common alternative naming
    "F8_E4M3FNUZ": 1,
    "F8E4M3FNUZ": 1,  # Common alternative naming
    # Standard floats
    "F16": 2,
    "BF16": 2,
    "F32": 4,
    "F64": 8,
    # Integers
    "I8": 1,
    "I16": 2,
    "I32": 4,
    "I64": 8,
    # Unsigned Integers
    "U8": 1,
    "U16": 2,
    "U32": 4,
    "U64": 8,
}

# Sizes are reported in binary megabytes (1 MB = 1024 * 1024 bytes).
_BYTES_PER_MB = 1024 * 1024
# FP32 stores each element in 4 bytes; used for the "FP32 Equiv." column.
_FP32_BYTES_PER_ELEMENT = 4


def get_bytes_per_element(dtype_str):
    """Return the number of bytes per element for a safetensors dtype string.

    The lookup is case-insensitive. Returns ``None`` when the dtype is not
    listed in ``DTYPE_TO_BYTES``.
    """
    return DTYPE_TO_BYTES.get(dtype_str.upper())


def calculate_num_elements(shape):
    """Return the total number of elements described by a tensor shape.

    An empty shape (a scalar tensor) counts as one element, and a shape
    containing a zero-sized dimension yields zero.
    """
    if not shape:  # Scalar tensor (shape is ())
        return 1
    return math.prod(shape)


def inspect_safetensors_precision_and_size(filepath):
    """Print a per-tensor precision and size report for a .safetensors file.

    For every tensor the report lists its dtype, its actual size in MB and the
    size it would have if stored as FP32. A summary with the dtype
    distribution, the totals and the overall reduction versus FP32 follows.

    Args:
        filepath: Path to the file to inspect (``str`` or ``os.PathLike``).

    Returns:
        None. All results and errors are printed to standard output; errors
        raised while reading the file are caught and reported, not propagated.
    """
    filepath = os.fspath(filepath)

    if not os.path.exists(filepath):
        print(f"Error: File not found at '{filepath}'")
        return

    if safe_open is None:
        print(
            "Error: The 'safetensors' library is not installed. "
            "Install it with 'pip install safetensors'."
        )
        return

    if not filepath.lower().endswith(".safetensors"):
        print(
            f"Warning: File '{filepath}' does not have a .safetensors extension. "
            "Attempting to read anyway."
        )

    dtype_counts = Counter()
    total_actual_mb = 0.0
    total_fp32_equiv_mb = 0.0

    try:
        print(f"Inspecting tensors in: {filepath}\n")
        with safe_open(filepath, framework="pt", device="cpu") as f:
            tensor_keys = list(f.keys())
            if not tensor_keys:
                print("No tensors found in the file.")
                return

            # The name column is at least as wide as its header.
            max_key_len = max(
                len("Tensor Name"), max(len(k) for k in tensor_keys)
            )

            header = (
                f"{'Tensor Name':<{max_key_len}} | "
                f"{'Precision (dtype)':<17} | "
                f"{'Actual Size (MB)':>16} | "
                f"{'FP32 Equiv. (MB)':>18}"
            )
            print(header)
            print(
                f"{'-' * max_key_len}-|-------------------|------------------|-------------------"
            )

            for key in tensor_keys:
                # Only the slice metadata is queried here.
                tensor_slice = f.get_slice(key)
                dtype_str = tensor_slice.get_dtype()
                shape = tensor_slice.get_shape()
                num_elements = calculate_num_elements(shape)
                bytes_per_el_actual = get_bytes_per_element(dtype_str)

                actual_size_mb_str = "N/A"
                fp32_equiv_size_mb_str = "N/A"

                if bytes_per_el_actual is not None:
                    actual_size_mb_val = (
                        num_elements * bytes_per_el_actual / _BYTES_PER_MB
                    )
                    total_actual_mb += actual_size_mb_val
                    actual_size_mb_str = f"{actual_size_mb_val:.3f}"

                    # Theoretical FP32 size. Only tensors with a known dtype
                    # contribute, so both totals cover the same tensors.
                    fp32_equiv_size_mb_val = (
                        num_elements * _FP32_BYTES_PER_ELEMENT / _BYTES_PER_MB
                    )
                    total_fp32_equiv_mb += fp32_equiv_size_mb_val
                    fp32_equiv_size_mb_str = f"{fp32_equiv_size_mb_val:.3f}"
                else:
                    print(
                        f"Warning: Unknown dtype '{dtype_str}' for tensor '{key}'. "
                        "Cannot calculate size."
                    )

                print(
                    f"{key:<{max_key_len}} | "
                    f"{dtype_str:<17} | "
                    f"{actual_size_mb_str:>16} | "
                    f"{fp32_equiv_size_mb_str:>18}"
                )
                dtype_counts[dtype_str] += 1

        print("\n--- Summary ---")
        print(f"Total tensors found: {len(tensor_keys)}")
        if dtype_counts:
            print("Precision distribution:")
            for dtype, count in dtype_counts.most_common():
                print(f"  - {dtype:<12}: {count} tensor(s)")
        else:
            print("No dtypes to summarize.")

        print(f"\nTotal actual size of all tensors: {total_actual_mb:.3f} MB")
        print(
            f"Total theoretical FP32 size of all tensors: {total_fp32_equiv_mb:.3f} MB"
        )

        if total_fp32_equiv_mb > 0.00001:  # Avoid division by zero or near-zero
            savings_percentage = (
                1 - (total_actual_mb / total_fp32_equiv_mb)
            ) * 100
            print(
                f"Overall size reduction compared to full FP32: {savings_percentage:.2f}%"
            )
        else:
            print(
                "Overall size reduction cannot be calculated (no FP32 equivalent data or zero size)."
            )

    except Exception as e:
        # Top-level boundary of a CLI tool: report the problem, don't crash.
        print(f"An error occurred while processing '{filepath}':")
        print(f"  {e}")
        print(
            "Please ensure it's a valid .safetensors file and the 'safetensors' (and 'torch') libraries are installed correctly."
        )


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Inspect tensor precision (dtype) and size in a .safetensors file."
    )
    parser.add_argument(
        "filepath",
        help="Path to the .safetensors file to inspect.",
    )
    args = parser.parse_args()
    inspect_safetensors_precision_and_size(args.filepath)