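"""Inspect the precision (dtype) and size of tensors stored in a PyTorch .pth file.

Prints one row per tensor (name, dtype, shape, actual size, and FP32-equivalent
size) plus a summary of the overall size reduction versus full FP32.

Usage sketch (``inspect_pth.py`` stands for whatever name this script is saved
under):

    python inspect_pth.py path/to/checkpoint.pth
"""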
import argparse
import os
from collections import Counter

import torch
# --- PyTorch Dtype to Bytes Mapping ---
TORCH_DTYPE_TO_BYTES = {
# Boolean
torch.bool: 1,
# Floating point
torch.float16: 2,
torch.half: 2, # alias for float16
torch.bfloat16: 2,
torch.float32: 4,
torch.float: 4, # alias for float32
torch.float64: 8,
torch.double: 8, # alias for float64
# Complex
torch.complex64: 8, # 2 * float32
torch.complex128: 16, # 2 * float64
torch.cfloat: 8, # alias for complex64
torch.cdouble: 16, # alias for complex128
# Signed integers
torch.int8: 1,
torch.int16: 2,
torch.short: 2, # alias for int16
torch.int32: 4,
torch.int: 4, # alias for int32
torch.int64: 8,
torch.long: 8, # alias for int64
    # Unsigned integers (uint16/uint32/uint64 only exist in newer PyTorch
    # releases, so guard them to keep older installs importable)
    torch.uint8: 1,
    **{getattr(torch, name): size
       for name, size in (("uint16", 2), ("uint32", 4), ("uint64", 8))
       if hasattr(torch, name)},
# Quantized types (approximate sizes)
torch.qint8: 1,
torch.quint8: 1,
torch.qint32: 4,
    torch.quint4x2: 1,  # two 4-bit values packed per byte, so 1 byte/element is an upper bound
}
def get_bytes_per_element(dtype):
"""Returns the number of bytes for a given PyTorch dtype."""
return TORCH_DTYPE_TO_BYTES.get(dtype, None)
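# Example: get_bytes_per_element(torch.bfloat16) -> 2; unknown dtypes return
# None and show up as "N/A" in the report below. For a live tensor,
# torch.Tensor.element_size() gives the same per-element byte count.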
def get_dtype_name(dtype):
"""Returns a readable string for a PyTorch dtype."""
return str(dtype).replace('torch.', '')
def calculate_num_elements(shape):
"""Calculates the total number of elements from a tensor shape tuple."""
if not shape: # Scalar tensor (shape is ())
return 1
if 0 in shape: # If any dimension is 0, total elements is 0
return 0
num_elements = 1
for dim_size in shape:
num_elements *= dim_size
return num_elements
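# Note: calculate_num_elements mirrors torch.Tensor.numel() (or math.prod over
# a shape tuple); the reporting loop below calls tensor.numel() directly, so
# this helper is mainly useful when only a bare shape is available, e.g.
# calculate_num_elements((3, 224, 224)) -> 150528.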
def extract_tensors_from_obj(obj, prefix=""):
"""
Recursively extracts tensors from nested dictionaries/objects.
Returns a dictionary of {key: tensor} pairs.
"""
tensors = {}
if isinstance(obj, torch.Tensor):
return {prefix or "tensor": obj}
elif isinstance(obj, dict):
for key, value in obj.items():
new_prefix = f"{prefix}.{key}" if prefix else key
tensors.update(extract_tensors_from_obj(value, new_prefix))
elif hasattr(obj, 'state_dict') and callable(getattr(obj, 'state_dict')):
# Handle nn.Module objects
state_dict = obj.state_dict()
new_prefix = f"{prefix}.state_dict" if prefix else "state_dict"
tensors.update(extract_tensors_from_obj(state_dict, new_prefix))
elif hasattr(obj, '__dict__'):
# Handle other objects with attributes
for key, value in obj.__dict__.items():
if isinstance(value, torch.Tensor):
new_prefix = f"{prefix}.{key}" if prefix else key
tensors[new_prefix] = value
return tensors
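# Example of the flattening (hypothetical layout): loading a checkpoint shaped
# like {"model": {"encoder.weight": t}} yields {"model.encoder.weight": t}, so
# nested checkpoints and plain state_dicts both end up as one flat mapping with
# one entry per tensor.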
def inspect_pth_precision_and_size(filepath):
"""
Reads a .pth file, extracts tensors from it,
and reports the precision (dtype), actual size, and theoretical FP32 size.
"""
if not os.path.exists(filepath):
print(f"Error: File not found at '{filepath}'")
return
try:
print(f"Loading PyTorch file: {filepath}")
        # Prefer weights_only=True (supported by recent PyTorch releases) so that
        # loading cannot execute arbitrary pickled code
try:
obj = torch.load(filepath, map_location="cpu", weights_only=True)
print("(Loaded with weights_only=True for security)\n")
except TypeError:
# Fallback for older PyTorch versions
obj = torch.load(filepath, map_location="cpu")
print("(Warning: Loaded without weights_only=True - older PyTorch version)\n")
# Extract all tensors from the loaded object
tensors = extract_tensors_from_obj(obj)
if not tensors:
print("No tensors found in the file.")
return
tensor_info_list = []
dtype_counts = Counter()
total_actual_mb = 0.0
total_fp32_equiv_mb = 0.0
max_key_len = max(len("Tensor Name"), max(len(k) for k in tensors.keys()))
header = (
f"{'Tensor Name':<{max_key_len}} | "
f"{'Precision (dtype)':<17} | "
f"{'Shape':<20} | "
f"{'Actual Size (MB)':>16} | "
f"{'FP32 Equiv. (MB)':>18}"
)
print(header)
        print(
            f"{'-' * (max_key_len + 1)}|{'-' * 19}|{'-' * 22}|{'-' * 18}|{'-' * 19}"
        )
for key, tensor in tensors.items():
dtype = tensor.dtype
dtype_name = get_dtype_name(dtype)
shape = tuple(tensor.shape)
shape_str = str(shape)
num_elements = tensor.numel()
bytes_per_el_actual = get_bytes_per_element(dtype)
actual_size_mb_str = "N/A"
fp32_equiv_size_mb_str = "N/A"
actual_size_mb_val = 0.0
if bytes_per_el_actual is not None:
actual_bytes = num_elements * bytes_per_el_actual
actual_size_mb_val = actual_bytes / (1024 * 1024)
total_actual_mb += actual_size_mb_val
actual_size_mb_str = f"{actual_size_mb_val:.3f}"
# Theoretical FP32 size (FP32 is 4 bytes per element)
fp32_equiv_bytes = num_elements * 4
fp32_equiv_size_mb_val = fp32_equiv_bytes / (1024 * 1024)
total_fp32_equiv_mb += fp32_equiv_size_mb_val
fp32_equiv_size_mb_str = f"{fp32_equiv_size_mb_val:.3f}"
else:
print(f"Warning: Unknown dtype '{dtype}' for tensor '{key}'. Cannot calculate size.")
# Truncate shape string if too long
if len(shape_str) > 18:
shape_str = shape_str[:15] + "..."
print(
f"{key:<{max_key_len}} | "
f"{dtype_name:<17} | "
f"{shape_str:<20} | "
f"{actual_size_mb_str:>16} | "
f"{fp32_equiv_size_mb_str:>18}"
)
dtype_counts[dtype_name] += 1
print("\n--- Summary ---")
print(f"Total tensors found: {len(tensors)}")
if dtype_counts:
print("Precision distribution:")
for dtype, count in dtype_counts.most_common():
print(f" - {dtype:<12}: {count} tensor(s)")
else:
print("No dtypes to summarize.")
print(f"\nTotal actual size of all tensors: {total_actual_mb:.3f} MB")
print(f"Total theoretical FP32 size of all tensors: {total_fp32_equiv_mb:.3f} MB")
if total_fp32_equiv_mb > 0.00001: # Avoid division by zero
savings_percentage = (1 - (total_actual_mb / total_fp32_equiv_mb)) * 100
print(f"Overall size reduction compared to full FP32: {savings_percentage:.2f}%")
else:
print("Overall size reduction cannot be calculated (no FP32 equivalent data or zero size).")
# Additional info about non-tensor content
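        # (training checkpoints often carry extras such as 'epoch', 'step', or
        # scheduler state alongside the weights)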
non_tensor_keys = []
        if isinstance(obj, dict):
            # Compare against the top-level components of the flattened tensor keys
            tensor_top_level_keys = {k.split('.')[0] for k in tensors.keys()}
            for key, value in obj.items():
                if key not in tensor_top_level_keys:
                    non_tensor_keys.append(f"{key}: {type(value).__name__}")
if non_tensor_keys:
print(f"\nNon-tensor content found:")
for item in non_tensor_keys[:5]: # Show first 5
print(f" - {item}")
if len(non_tensor_keys) > 5:
print(f" ... and {len(non_tensor_keys) - 5} more items")
except Exception as e:
print(f"An error occurred while processing '{filepath}':")
print(f" {e}")
print("Please ensure it's a valid PyTorch .pth file and PyTorch is installed correctly.")
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Inspect tensor precision (dtype) and size in a PyTorch .pth file."
)
parser.add_argument(
"filepath",
help="Path to the .pth file to inspect."
)
args = parser.parse_args()
    inspect_pth_precision_and_size(args.filepath)