"""Inspect the precision (dtype) and size of tensors stored in a PyTorch .pth file."""

import argparse
import math
import os
from collections import Counter

import torch


# Bytes per element for each supported PyTorch dtype.
TORCH_DTYPE_TO_BYTES = {
    torch.bool: 1,
    # Floating point
    torch.float16: 2,
    torch.half: 2,
    torch.bfloat16: 2,
    torch.float32: 4,
    torch.float: 4,
    torch.float64: 8,
    torch.double: 8,
    # Complex
    torch.complex64: 8,
    torch.complex128: 16,
    torch.cfloat: 8,
    torch.cdouble: 16,
    # Signed integers
    torch.int8: 1,
    torch.int16: 2,
    torch.short: 2,
    torch.int32: 4,
    torch.int: 4,
    torch.int64: 8,
    torch.long: 8,
    # Unsigned integers
    torch.uint8: 1,
    torch.uint16: 2,
    torch.uint32: 4,
    torch.uint64: 8,
    # Quantized
    torch.qint8: 1,
    torch.quint8: 1,
    torch.qint32: 4,
    torch.quint4x2: 1,
}


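# Note (aside, not required by the script): for an in-memory tensor,
# torch.Tensor.element_size() reports the same per-element byte count, e.g.
#   torch.zeros(1, dtype=torch.bfloat16).element_size() == 2
# The explicit table above lets the script flag dtypes it does not recognize.

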
def get_bytes_per_element(dtype):
    """Returns the number of bytes for a given PyTorch dtype."""
    return TORCH_DTYPE_TO_BYTES.get(dtype, None)


def get_dtype_name(dtype):
    """Returns a readable string for a PyTorch dtype."""
    return str(dtype).replace('torch.', '')


def calculate_num_elements(shape):
    """Calculates the total number of elements from a tensor shape tuple."""
    if not shape:
        return 1
    if 0 in shape:
        return 0
    num_elements = 1
    for dim_size in shape:
        num_elements *= dim_size
    return num_elements


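# Illustrative size arithmetic (not executed): for a (1024, 1024) float16 tensor,
#   num_elements = 1024 * 1024 = 1_048_576
#   actual size  = num_elements * get_bytes_per_element(torch.float16) = 2_097_152 bytes ~ 2.000 MB
#   FP32 equiv.  = num_elements * 4 = 4_194_304 bytes ~ 4.000 MB
# This is the same per-tensor arithmetic used in inspect_pth_precision_and_size() below.

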
def extract_tensors_from_obj(obj, prefix=""):
    """
    Recursively extracts tensors from nested dictionaries/objects.
    Returns a dictionary of {key: tensor} pairs.
    """
    tensors = {}

    if isinstance(obj, torch.Tensor):
        return {prefix or "tensor": obj}

    elif isinstance(obj, dict):
        for key, value in obj.items():
            new_prefix = f"{prefix}.{key}" if prefix else key
            tensors.update(extract_tensors_from_obj(value, new_prefix))

    elif hasattr(obj, 'state_dict') and callable(getattr(obj, 'state_dict')):
        # Objects such as nn.Module expose their tensors via state_dict().
        state_dict = obj.state_dict()
        new_prefix = f"{prefix}.state_dict" if prefix else "state_dict"
        tensors.update(extract_tensors_from_obj(state_dict, new_prefix))

    elif hasattr(obj, '__dict__'):
        # Fall back to scanning instance attributes for tensors.
        for key, value in obj.__dict__.items():
            if isinstance(value, torch.Tensor):
                new_prefix = f"{prefix}.{key}" if prefix else key
                tensors[new_prefix] = value

    return tensors


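# Illustrative example (REPL-style sketch, not executed): given a nested checkpoint
# dict, keys are flattened with dots and non-tensor leaves are skipped:
#   >>> ckpt = {"model": {"weight": torch.zeros(2, 3)}, "epoch": 5}
#   >>> extract_tensors_from_obj(ckpt)
#   {'model.weight': tensor([[0., 0., 0.], [0., 0., 0.]])}

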
def inspect_pth_precision_and_size(filepath):
    """
    Reads a .pth file, extracts the tensors it contains, and reports each
    tensor's precision (dtype), actual size, and theoretical FP32 size.
    """
    if not os.path.exists(filepath):
        print(f"Error: File not found at '{filepath}'")
        return

    try:
        print(f"Loading PyTorch file: {filepath}")

        # Prefer weights_only=True, which restricts unpickling to tensors and
        # plain containers; fall back for PyTorch versions without the flag.
        try:
            obj = torch.load(filepath, map_location="cpu", weights_only=True)
            print("(Loaded with weights_only=True for security)\n")
        except TypeError:
            obj = torch.load(filepath, map_location="cpu")
            print("(Warning: Loaded without weights_only=True - older PyTorch version)\n")

        tensors = extract_tensors_from_obj(obj)

        if not tensors:
            print("No tensors found in the file.")
            return

        dtype_counts = Counter()
        total_actual_mb = 0.0
        total_fp32_equiv_mb = 0.0

        max_key_len = max(len("Tensor Name"), max(len(k) for k in tensors.keys()))

        header = (
            f"{'Tensor Name':<{max_key_len}} | "
            f"{'Precision (dtype)':<17} | "
            f"{'Shape':<20} | "
            f"{'Actual Size (MB)':>16} | "
            f"{'FP32 Equiv. (MB)':>18}"
        )
        print(header)
        print(f"{'-' * max_key_len}-|-{'-' * 17}-|-{'-' * 20}-|-{'-' * 16}-|-{'-' * 18}")

        for key, tensor in tensors.items():
            dtype = tensor.dtype
            dtype_name = get_dtype_name(dtype)
            shape = tuple(tensor.shape)
            shape_str = str(shape)

            num_elements = tensor.numel()
            bytes_per_el_actual = get_bytes_per_element(dtype)

            actual_size_mb_str = "N/A"
            fp32_equiv_size_mb_str = "N/A"
            actual_size_mb_val = 0.0

            if bytes_per_el_actual is not None:
                actual_bytes = num_elements * bytes_per_el_actual
                actual_size_mb_val = actual_bytes / (1024 * 1024)
                total_actual_mb += actual_size_mb_val
                actual_size_mb_str = f"{actual_size_mb_val:.3f}"

                # Theoretical size if every element were stored as 4-byte FP32.
                fp32_equiv_bytes = num_elements * 4
                fp32_equiv_size_mb_val = fp32_equiv_bytes / (1024 * 1024)
                total_fp32_equiv_mb += fp32_equiv_size_mb_val
                fp32_equiv_size_mb_str = f"{fp32_equiv_size_mb_val:.3f}"
            else:
                print(f"Warning: Unknown dtype '{dtype}' for tensor '{key}'. Cannot calculate size.")

            # Truncate long shape strings so the table columns stay aligned.
            if len(shape_str) > 18:
                shape_str = shape_str[:15] + "..."

            print(
                f"{key:<{max_key_len}} | "
                f"{dtype_name:<17} | "
                f"{shape_str:<20} | "
                f"{actual_size_mb_str:>16} | "
                f"{fp32_equiv_size_mb_str:>18}"
            )
            dtype_counts[dtype_name] += 1

print("\n--- Summary ---") |
|
|
print(f"Total tensors found: {len(tensors)}") |
|
|
if dtype_counts: |
|
|
print("Precision distribution:") |
|
|
for dtype, count in dtype_counts.most_common(): |
|
|
print(f" - {dtype:<12}: {count} tensor(s)") |
|
|
else: |
|
|
print("No dtypes to summarize.") |
|
|
|
|
|
print(f"\nTotal actual size of all tensors: {total_actual_mb:.3f} MB") |
|
|
print(f"Total theoretical FP32 size of all tensors: {total_fp32_equiv_mb:.3f} MB") |
|
|
|
|
|
if total_fp32_equiv_mb > 0.00001: |
|
|
savings_percentage = (1 - (total_actual_mb / total_fp32_equiv_mb)) * 100 |
|
|
print(f"Overall size reduction compared to full FP32: {savings_percentage:.2f}%") |
|
|
else: |
|
|
print("Overall size reduction cannot be calculated (no FP32 equivalent data or zero size).") |
|
|
|
|
|
|
|
|
        # Report top-level checkpoint keys that yielded no tensors (e.g. metadata).
        non_tensor_keys = []
        if isinstance(obj, dict):
            tensor_top_level_keys = {k.split('.')[0] for k in tensors.keys()}
            for key, value in obj.items():
                if key not in tensor_top_level_keys:
                    non_tensor_keys.append(f"{key}: {type(value).__name__}")

        if non_tensor_keys:
            print("\nNon-tensor content found:")
            for item in non_tensor_keys[:5]:
                print(f"  - {item}")
            if len(non_tensor_keys) > 5:
                print(f"  ... and {len(non_tensor_keys) - 5} more items")

    except Exception as e:
        print(f"An error occurred while processing '{filepath}':")
        print(f"  {e}")
        print("Please ensure it's a valid PyTorch .pth file and PyTorch is installed correctly.")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Inspect tensor precision (dtype) and size in a PyTorch .pth file."
    )
    parser.add_argument(
        "filepath",
        help="Path to the .pth file to inspect."
    )
    args = parser.parse_args()

    inspect_pth_precision_and_size(args.filepath)
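
# Example invocation (hypothetical script and checkpoint names):
#   python inspect_pth.py checkpoints/model_fp16.pth
# The script prints a per-tensor table (name, dtype, shape, actual MB, FP32-equivalent MB)
# followed by a summary of the dtype distribution and the overall size reduction vs. FP32.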