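"""Inspect tensor precision (dtype) and size in a PyTorch .pth file.

For every tensor found in the checkpoint, reports its dtype, shape, actual
size in MB, and the size the same tensor would occupy at full FP32 precision,
then summarizes the overall savings.

Usage (the script name below is illustrative):
    python inspect_pth.py path/to/checkpoint.pth
"""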
import argparse
import torch
from collections import Counter
import os
import math
# --- PyTorch Dtype to Bytes Mapping ---
TORCH_DTYPE_TO_BYTES = {
# Boolean
torch.bool: 1,
# Floating point
torch.float16: 2,
torch.half: 2, # alias for float16
torch.bfloat16: 2,
torch.float32: 4,
torch.float: 4, # alias for float32
torch.float64: 8,
torch.double: 8, # alias for float64
# Complex
torch.complex64: 8, # 2 * float32
torch.complex128: 16, # 2 * float64
torch.cfloat: 8, # alias for complex64
torch.cdouble: 16, # alias for complex128
# Signed integers
torch.int8: 1,
torch.int16: 2,
torch.short: 2, # alias for int16
torch.int32: 4,
torch.int: 4, # alias for int32
torch.int64: 8,
torch.long: 8, # alias for int64
    # Unsigned integers (uint16/32/64 are registered conditionally below)
    torch.uint8: 1,
    # Quantized types (approximate sizes)
    torch.qint8: 1,
    torch.quint8: 1,
    torch.qint32: 4,
    torch.quint4x2: 1,  # two 4-bit values packed per byte, so 1 byte/element overestimates
}

# torch.uint16/uint32/uint64 only exist in PyTorch >= 2.3; register them
# conditionally so this script still imports on older versions.
for _name, _nbytes in (("uint16", 2), ("uint32", 4), ("uint64", 8)):
    if hasattr(torch, _name):
        TORCH_DTYPE_TO_BYTES[getattr(torch, _name)] = _nbytes
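# Worked example of the size math used below: a (1024, 1024) float16 tensor has
# 1,048,576 elements * 2 bytes = 2.0 MB actual vs. 4.0 MB at FP32 (50% smaller).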
def get_bytes_per_element(dtype):
    """Returns the number of bytes for a given PyTorch dtype, or None if unknown."""
    if dtype in TORCH_DTYPE_TO_BYTES:
        return TORCH_DTYPE_TO_BYTES[dtype]
    # Newer PyTorch versions expose the element size directly on the dtype;
    # fall back to that for dtypes missing from the table above.
    return getattr(dtype, "itemsize", None)
def get_dtype_name(dtype):
"""Returns a readable string for a PyTorch dtype."""
return str(dtype).replace('torch.', '')
def calculate_num_elements(shape):
    """Calculates the total number of elements from a tensor shape tuple."""
    # math.prod of an empty shape is 1, which is correct for a scalar tensor.
    return math.prod(shape)
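# Examples: calculate_num_elements(()) == 1 (scalar);
# calculate_num_elements((3, 0, 5)) == 0 (a zero-sized dimension empties the tensor).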
def extract_tensors_from_obj(obj, prefix=""):
"""
Recursively extracts tensors from nested dictionaries/objects.
Returns a dictionary of {key: tensor} pairs.
"""
tensors = {}
if isinstance(obj, torch.Tensor):
return {prefix or "tensor": obj}
elif isinstance(obj, dict):
for key, value in obj.items():
new_prefix = f"{prefix}.{key}" if prefix else key
tensors.update(extract_tensors_from_obj(value, new_prefix))
    elif isinstance(obj, (list, tuple)):
        # Checkpoints sometimes nest tensors in lists/tuples (e.g. optimizer state).
        for i, value in enumerate(obj):
            new_prefix = f"{prefix}[{i}]" if prefix else f"[{i}]"
            tensors.update(extract_tensors_from_obj(value, new_prefix))
elif hasattr(obj, 'state_dict') and callable(getattr(obj, 'state_dict')):
# Handle nn.Module objects
state_dict = obj.state_dict()
new_prefix = f"{prefix}.state_dict" if prefix else "state_dict"
tensors.update(extract_tensors_from_obj(state_dict, new_prefix))
elif hasattr(obj, '__dict__'):
# Handle other objects with attributes
for key, value in obj.__dict__.items():
if isinstance(value, torch.Tensor):
new_prefix = f"{prefix}.{key}" if prefix else key
tensors[new_prefix] = value
return tensors
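# Example of the flattened naming this produces: {"model": {"layer.weight": w}}
# yields {"model.layer.weight": w}; a bare top-level tensor gets the key "tensor".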
def inspect_pth_precision_and_size(filepath):
"""
Reads a .pth file, extracts tensors from it,
and reports the precision (dtype), actual size, and theoretical FP32 size.
"""
if not os.path.exists(filepath):
print(f"Error: File not found at '{filepath}'")
return
try:
print(f"Loading PyTorch file: {filepath}")
        # Prefer weights_only=True (available since PyTorch 1.13) so arbitrary
        # pickled code in the file cannot execute during loading. Note it may
        # reject checkpoints that contain non-tensor Python objects.
        try:
            obj = torch.load(filepath, map_location="cpu", weights_only=True)
            print("(Loaded with weights_only=True for security)\n")
        except TypeError:
            # torch.load() on older PyTorch versions has no weights_only parameter.
            obj = torch.load(filepath, map_location="cpu")
            print("(Warning: loaded without weights_only=True - not supported by this PyTorch version)\n")
# Extract all tensors from the loaded object
tensors = extract_tensors_from_obj(obj)
if not tensors:
print("No tensors found in the file.")
return
        dtype_counts = Counter()
        total_actual_mb = 0.0
        total_fp32_equiv_mb = 0.0
max_key_len = max(len("Tensor Name"), max(len(k) for k in tensors.keys()))
header = (
f"{'Tensor Name':<{max_key_len}} | "
f"{'Precision (dtype)':<17} | "
f"{'Shape':<20} | "
f"{'Actual Size (MB)':>16} | "
f"{'FP32 Equiv. (MB)':>18}"
)
print(header)
        # Separator widths mirror the header's column widths (including the
        # padding spaces around each pipe) so the rule lines up with the pipes.
        print(
            f"{'-' * (max_key_len + 1)}|{'-' * 19}|{'-' * 22}|{'-' * 18}|{'-' * 19}"
        )
for key, tensor in tensors.items():
dtype = tensor.dtype
dtype_name = get_dtype_name(dtype)
shape = tuple(tensor.shape)
shape_str = str(shape)
num_elements = tensor.numel()
bytes_per_el_actual = get_bytes_per_element(dtype)
actual_size_mb_str = "N/A"
fp32_equiv_size_mb_str = "N/A"
actual_size_mb_val = 0.0
if bytes_per_el_actual is not None:
actual_bytes = num_elements * bytes_per_el_actual
actual_size_mb_val = actual_bytes / (1024 * 1024)
total_actual_mb += actual_size_mb_val
actual_size_mb_str = f"{actual_size_mb_val:.3f}"
# Theoretical FP32 size (FP32 is 4 bytes per element)
fp32_equiv_bytes = num_elements * 4
fp32_equiv_size_mb_val = fp32_equiv_bytes / (1024 * 1024)
total_fp32_equiv_mb += fp32_equiv_size_mb_val
fp32_equiv_size_mb_str = f"{fp32_equiv_size_mb_val:.3f}"
else:
print(f"Warning: Unknown dtype '{dtype}' for tensor '{key}'. Cannot calculate size.")
            # Truncate the shape string only if it overflows its 20-character column
            if len(shape_str) > 20:
                shape_str = shape_str[:17] + "..."
print(
f"{key:<{max_key_len}} | "
f"{dtype_name:<17} | "
f"{shape_str:<20} | "
f"{actual_size_mb_str:>16} | "
f"{fp32_equiv_size_mb_str:>18}"
)
dtype_counts[dtype_name] += 1
print("\n--- Summary ---")
print(f"Total tensors found: {len(tensors)}")
if dtype_counts:
print("Precision distribution:")
for dtype, count in dtype_counts.most_common():
print(f" - {dtype:<12}: {count} tensor(s)")
else:
print("No dtypes to summarize.")
print(f"\nTotal actual size of all tensors: {total_actual_mb:.3f} MB")
print(f"Total theoretical FP32 size of all tensors: {total_fp32_equiv_mb:.3f} MB")
if total_fp32_equiv_mb > 0.00001: # Avoid division by zero
savings_percentage = (1 - (total_actual_mb / total_fp32_equiv_mb)) * 100
print(f"Overall size reduction compared to full FP32: {savings_percentage:.2f}%")
else:
print("Overall size reduction cannot be calculated (no FP32 equivalent data or zero size).")
        # Additional info about non-tensor content at the top level
        if isinstance(obj, dict):
            tensor_top_level = {k.split('.', 1)[0] for k in tensors}
            non_tensor_keys = [
                f"{key}: {type(value).__name__}"
                for key, value in obj.items()
                if key not in tensor_top_level
            ]
            if non_tensor_keys:
                print("\nNon-tensor content found:")
                for item in non_tensor_keys[:5]:  # Show first 5
                    print(f" - {item}")
                if len(non_tensor_keys) > 5:
                    print(f" ... and {len(non_tensor_keys) - 5} more items")
except Exception as e:
print(f"An error occurred while processing '{filepath}':")
print(f" {e}")
print("Please ensure it's a valid PyTorch .pth file and PyTorch is installed correctly.")
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Inspect tensor precision (dtype) and size in a PyTorch .pth file."
)
parser.add_argument(
"filepath",
help="Path to the .pth file to inspect."
)
args = parser.parse_args()
inspect_pth_precision_and_size(args.filepath)