import argparse
import os
from collections import Counter

import torch

# --- PyTorch Dtype to Bytes Mapping ---
TORCH_DTYPE_TO_BYTES = {
    # Boolean
    torch.bool: 1,
    # Floating point
    torch.float16: 2,
    torch.half: 2,      # alias for float16
    torch.bfloat16: 2,
    torch.float32: 4,
    torch.float: 4,     # alias for float32
    torch.float64: 8,
    torch.double: 8,    # alias for float64
    # Complex
    torch.complex64: 8,    # 2 * float32
    torch.complex128: 16,  # 2 * float64
    torch.cfloat: 8,       # alias for complex64
    torch.cdouble: 16,     # alias for complex128
    # Signed integers
    torch.int8: 1,
    torch.int16: 2,
    torch.short: 2,  # alias for int16
    torch.int32: 4,
    torch.int: 4,    # alias for int32
    torch.int64: 8,
    torch.long: 8,   # alias for int64
    # Unsigned integers
    torch.uint8: 1,
    torch.uint16: 2,
    torch.uint32: 4,
    torch.uint64: 8,
    # Quantized types (approximate sizes)
    torch.qint8: 1,
    torch.quint8: 1,
    torch.qint32: 4,
    torch.quint4x2: 1,  # two 4-bit values packed per byte
}


def get_bytes_per_element(dtype):
    """Returns the number of bytes per element for a given PyTorch dtype, or None if unknown."""
    return TORCH_DTYPE_TO_BYTES.get(dtype, None)


def get_dtype_name(dtype):
    """Returns a readable string for a PyTorch dtype."""
    return str(dtype).replace('torch.', '')


def calculate_num_elements(shape):
    """Calculates the total number of elements from a tensor shape tuple."""
    if not shape:  # Scalar tensor (shape is ())
        return 1
    if 0 in shape:  # If any dimension is 0, total elements is 0
        return 0
    num_elements = 1
    for dim_size in shape:
        num_elements *= dim_size
    return num_elements


def extract_tensors_from_obj(obj, prefix=""):
    """
    Recursively extracts tensors from nested dictionaries/objects.
    Returns a dictionary of {key: tensor} pairs.
    """
    tensors = {}
    if isinstance(obj, torch.Tensor):
        return {prefix or "tensor": obj}
    elif isinstance(obj, dict):
        for key, value in obj.items():
            new_prefix = f"{prefix}.{key}" if prefix else key
            tensors.update(extract_tensors_from_obj(value, new_prefix))
    elif hasattr(obj, 'state_dict') and callable(getattr(obj, 'state_dict')):
        # Handle nn.Module objects
        state_dict = obj.state_dict()
        new_prefix = f"{prefix}.state_dict" if prefix else "state_dict"
        tensors.update(extract_tensors_from_obj(state_dict, new_prefix))
    elif hasattr(obj, '__dict__'):
        # Handle other objects with tensor attributes
        for key, value in obj.__dict__.items():
            if isinstance(value, torch.Tensor):
                new_prefix = f"{prefix}.{key}" if prefix else key
                tensors[new_prefix] = value
    return tensors


def inspect_pth_precision_and_size(filepath):
    """
    Reads a .pth file, extracts tensors from it, and reports the precision
    (dtype), actual size, and theoretical FP32 size of each one.
    """
    if not os.path.exists(filepath):
        print(f"Error: File not found at '{filepath}'")
        return

    try:
        print(f"Loading PyTorch file: {filepath}")
        # Load with weights_only=True for security if PyTorch >= 2.0
        try:
            obj = torch.load(filepath, map_location="cpu", weights_only=True)
            print("(Loaded with weights_only=True for security)\n")
        except TypeError:
            # Fallback for older PyTorch versions that do not accept weights_only
            obj = torch.load(filepath, map_location="cpu")
            print("(Warning: Loaded without weights_only=True - older PyTorch version)\n")

        # Extract all tensors from the loaded object
        tensors = extract_tensors_from_obj(obj)

        if not tensors:
            print("No tensors found in the file.")
            return

        dtype_counts = Counter()
        total_actual_mb = 0.0
        total_fp32_equiv_mb = 0.0

        max_key_len = max(len("Tensor Name"), max(len(k) for k in tensors.keys()))

        header = (
            f"{'Tensor Name':<{max_key_len}} | "
            f"{'Precision (dtype)':<17} | "
            f"{'Shape':<20} | "
            f"{'Actual Size (MB)':>16} | "
            f"{'FP32 Equiv. (MB)':>18}"
        )
        print(header)
        print(
            f"{'-' * max_key_len}-|{'-' * 19}|{'-' * 22}|{'-' * 18}|{'-' * 19}"
        )

        for key, tensor in tensors.items():
            dtype = tensor.dtype
            dtype_name = get_dtype_name(dtype)
            shape_str = str(tuple(tensor.shape))
            num_elements = tensor.numel()

            bytes_per_el_actual = get_bytes_per_element(dtype)

            actual_size_mb_str = "N/A"
            fp32_equiv_size_mb_str = "N/A"

            if bytes_per_el_actual is not None:
                actual_bytes = num_elements * bytes_per_el_actual
                actual_size_mb_val = actual_bytes / (1024 * 1024)
                total_actual_mb += actual_size_mb_val
                actual_size_mb_str = f"{actual_size_mb_val:.3f}"

                # Theoretical FP32 size (FP32 is 4 bytes per element)
                fp32_equiv_bytes = num_elements * 4
                fp32_equiv_size_mb_val = fp32_equiv_bytes / (1024 * 1024)
                total_fp32_equiv_mb += fp32_equiv_size_mb_val
                fp32_equiv_size_mb_str = f"{fp32_equiv_size_mb_val:.3f}"
            else:
                print(f"Warning: Unknown dtype '{dtype}' for tensor '{key}'. Cannot calculate size.")

            # Truncate shape string if too long
            if len(shape_str) > 18:
                shape_str = shape_str[:15] + "..."

            print(
                f"{key:<{max_key_len}} | "
                f"{dtype_name:<17} | "
                f"{shape_str:<20} | "
                f"{actual_size_mb_str:>16} | "
                f"{fp32_equiv_size_mb_str:>18}"
            )
            dtype_counts[dtype_name] += 1

        print("\n--- Summary ---")
        print(f"Total tensors found: {len(tensors)}")
        if dtype_counts:
            print("Precision distribution:")
            for dtype_name, count in dtype_counts.most_common():
                print(f"  - {dtype_name:<12}: {count} tensor(s)")
        else:
            print("No dtypes to summarize.")

        print(f"\nTotal actual size of all tensors: {total_actual_mb:.3f} MB")
        print(f"Total theoretical FP32 size of all tensors: {total_fp32_equiv_mb:.3f} MB")

        if total_fp32_equiv_mb > 0.00001:  # Avoid division by zero
            savings_percentage = (1 - (total_actual_mb / total_fp32_equiv_mb)) * 100
            print(f"Overall size reduction compared to full FP32: {savings_percentage:.2f}%")
        else:
            print("Overall size reduction cannot be calculated (no FP32 equivalent data or zero size).")

        # Additional info about non-tensor content at the top level of the file
        non_tensor_keys = []
        if isinstance(obj, dict):
            tensor_top_level_keys = {k.split('.')[0] for k in tensors.keys()}  # Simplified check
            for key, value in obj.items():
                if key not in tensor_top_level_keys:
                    non_tensor_keys.append(f"{key}: {type(value).__name__}")

        if non_tensor_keys:
            print("\nNon-tensor content found:")
            for item in non_tensor_keys[:5]:  # Show first 5
                print(f"  - {item}")
            if len(non_tensor_keys) > 5:
                print(f"  ... and {len(non_tensor_keys) - 5} more items")

    except Exception as e:
        print(f"An error occurred while processing '{filepath}':")
        print(f"  {e}")
        print("Please ensure it's a valid PyTorch .pth file and PyTorch is installed correctly.")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Inspect tensor precision (dtype) and size in a PyTorch .pth file."
    )
    parser.add_argument(
        "filepath",
        help="Path to the .pth file to inspect."
    )
    args = parser.parse_args()
    inspect_pth_precision_and_size(args.filepath)
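
# Example invocation. The script and checkpoint filenames below are
# placeholders, not names defined anywhere above; save this file under any
# name and point it at a .pth checkpoint on disk:
#
#   python inspect_pth_precision.py path/to/checkpoint.pth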