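"""Inspect tensor precision (dtype) and size in a PyTorch .pth file.

For every tensor found in the checkpoint, the script reports its dtype, shape,
actual size in MB, and the size it would occupy if stored as FP32, followed by
a summary of the precision mix and the overall size savings.

Usage (the script filename below is illustrative; use whatever name this file
is saved under):
    python inspect_pth.py path/to/checkpoint.pth
"""
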
import argparse
import os
from collections import Counter

import torch

# --- PyTorch Dtype to Bytes Mapping ---
TORCH_DTYPE_TO_BYTES = {
    # Boolean
    torch.bool: 1,
    # Floating point
    torch.float16: 2,
    torch.half: 2,  # alias for float16
    torch.bfloat16: 2,
    torch.float32: 4,
    torch.float: 4,  # alias for float32
    torch.float64: 8,
    torch.double: 8,  # alias for float64
    # Complex
    torch.complex64: 8,   # 2 * float32
    torch.complex128: 16, # 2 * float64
    torch.cfloat: 8,      # alias for complex64
    torch.cdouble: 16,    # alias for complex128
    # Signed integers
    torch.int8: 1,
    torch.int16: 2,
    torch.short: 2,  # alias for int16
    torch.int32: 4,
    torch.int: 4,    # alias for int32
    torch.int64: 8,
    torch.long: 8,   # alias for int64
    # Unsigned integers (wider unsigned dtypes are registered below, if available)
    torch.uint8: 1,
    # Quantized types (approximate sizes)
    torch.qint8: 1,
    torch.quint8: 1,
    torch.qint32: 4,
    torch.quint4x2: 1,  # 4-bit packed
}
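
# torch.uint16/uint32/uint64 exist only in newer PyTorch releases, so register
# them conditionally to keep the script importable on older installs.
for _uint_name, _uint_size in (("uint16", 2), ("uint32", 4), ("uint64", 8)):
    if hasattr(torch, _uint_name):
        TORCH_DTYPE_TO_BYTES[getattr(torch, _uint_name)] = _uint_size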

def get_bytes_per_element(dtype):
    """Returns the number of bytes for a given PyTorch dtype."""
    return TORCH_DTYPE_TO_BYTES.get(dtype, None)

def get_dtype_name(dtype):
    """Returns a readable string for a PyTorch dtype."""
    return str(dtype).replace('torch.', '')

def calculate_num_elements(shape):
    """Calculates the total number of elements from a tensor shape tuple."""
    if not shape:  # Scalar tensor (shape is ())
        return 1
    if 0 in shape:  # If any dimension is 0, total elements is 0
        return 0
    num_elements = 1
    for dim_size in shape:
        num_elements *= dim_size
    return num_elements

def extract_tensors_from_obj(obj, prefix=""):
    """
    Recursively extracts tensors from nested dictionaries/objects.
    Returns a dictionary of {key: tensor} pairs.
    """
    tensors = {}
    
    if isinstance(obj, torch.Tensor):
        return {prefix or "tensor": obj}
    
    elif isinstance(obj, dict):
        for key, value in obj.items():
            new_prefix = f"{prefix}.{key}" if prefix else key
            tensors.update(extract_tensors_from_obj(value, new_prefix))
    
    elif hasattr(obj, 'state_dict') and callable(getattr(obj, 'state_dict')):
        # Handle nn.Module objects
        state_dict = obj.state_dict()
        new_prefix = f"{prefix}.state_dict" if prefix else "state_dict"
        tensors.update(extract_tensors_from_obj(state_dict, new_prefix))
    
    elif hasattr(obj, '__dict__'):
        # Handle other objects with attributes
        for key, value in obj.__dict__.items():
            if isinstance(value, torch.Tensor):
                new_prefix = f"{prefix}.{key}" if prefix else key
                tensors[new_prefix] = value
    
    return tensors

def inspect_pth_precision_and_size(filepath):
    """
    Reads a .pth file, extracts tensors from it,
    and reports the precision (dtype), actual size, and theoretical FP32 size.
    """
    if not os.path.exists(filepath):
        print(f"Error: File not found at '{filepath}'")
        return

    try:
        print(f"Loading PyTorch file: {filepath}")
        
        # Prefer weights_only=True for safer loading (not supported by very old PyTorch versions)
        try:
            obj = torch.load(filepath, map_location="cpu", weights_only=True)
            print("(Loaded with weights_only=True for security)\n")
        except TypeError:
            # Fallback for older PyTorch versions
            obj = torch.load(filepath, map_location="cpu")
            print("(Warning: Loaded without weights_only=True - older PyTorch version)\n")
        
        # Extract all tensors from the loaded object
        tensors = extract_tensors_from_obj(obj)
        
        if not tensors:
            print("No tensors found in the file.")
            return

        tensor_info_list = []
        dtype_counts = Counter()
        total_actual_mb = 0.0
        total_fp32_equiv_mb = 0.0

        max_key_len = max(len("Tensor Name"), max(len(k) for k in tensors.keys()))

        header = (
            f"{'Tensor Name':<{max_key_len}} | "
            f"{'Precision (dtype)':<17} | "
            f"{'Shape':<20} | "
            f"{'Actual Size (MB)':>16} | "
            f"{'FP32 Equiv. (MB)':>18}"
        )
        print(header)
        # Separator sized to match the header's column widths
        print(
            f"{'-' * (max_key_len + 1)}|{'-' * 19}|{'-' * 22}|{'-' * 18}|{'-' * 19}"
        )

        for key, tensor in tensors.items():
            dtype = tensor.dtype
            dtype_name = get_dtype_name(dtype)
            shape = tuple(tensor.shape)
            shape_str = str(shape)
            
            num_elements = tensor.numel()
            bytes_per_el_actual = get_bytes_per_element(dtype)

            actual_size_mb_str = "N/A"
            fp32_equiv_size_mb_str = "N/A"
            actual_size_mb_val = 0.0

            if bytes_per_el_actual is not None:
                actual_bytes = num_elements * bytes_per_el_actual
                actual_size_mb_val = actual_bytes / (1024 * 1024)
                total_actual_mb += actual_size_mb_val
                actual_size_mb_str = f"{actual_size_mb_val:.3f}"

                # Theoretical FP32 size (FP32 is 4 bytes per element)
                fp32_equiv_bytes = num_elements * 4
                fp32_equiv_size_mb_val = fp32_equiv_bytes / (1024 * 1024)
                total_fp32_equiv_mb += fp32_equiv_size_mb_val
                fp32_equiv_size_mb_str = f"{fp32_equiv_size_mb_val:.3f}"
            else:
                print(f"Warning: Unknown dtype '{dtype}' for tensor '{key}'. Cannot calculate size.")

            # Truncate the shape string if it overflows its 20-character column
            if len(shape_str) > 20:
                shape_str = shape_str[:17] + "..."

            print(
                f"{key:<{max_key_len}} | "
                f"{dtype_name:<17} | "
                f"{shape_str:<20} | "
                f"{actual_size_mb_str:>16} | "
                f"{fp32_equiv_size_mb_str:>18}"
            )
            dtype_counts[dtype_name] += 1

        print("\n--- Summary ---")
        print(f"Total tensors found: {len(tensors)}")
        if dtype_counts:
            print("Precision distribution:")
            for dtype, count in dtype_counts.most_common():
                print(f"  - {dtype:<12}: {count} tensor(s)")
        else:
            print("No dtypes to summarize.")

        print(f"\nTotal actual size of all tensors: {total_actual_mb:.3f} MB")
        print(f"Total theoretical FP32 size of all tensors: {total_fp32_equiv_mb:.3f} MB")

        if total_fp32_equiv_mb > 0.00001:  # Avoid division by zero
            savings_percentage = (1 - (total_actual_mb / total_fp32_equiv_mb)) * 100
            print(f"Overall size reduction compared to full FP32: {savings_percentage:.2f}%")
        else:
            print("Overall size reduction cannot be calculated (no FP32 equivalent data or zero size).")

        # Additional info about non-tensor content
        non_tensor_keys = []
        if isinstance(obj, dict):
            # Top-level keys that contributed at least one tensor (simplified check)
            tensor_top_level_keys = {k.split('.')[0] for k in tensors}
            for key, value in obj.items():
                if key not in tensor_top_level_keys:
                    non_tensor_keys.append(f"{key}: {type(value).__name__}")
        
        if non_tensor_keys:
            print(f"\nNon-tensor content found:")
            for item in non_tensor_keys[:5]:  # Show first 5
                print(f"  - {item}")
            if len(non_tensor_keys) > 5:
                print(f"  ... and {len(non_tensor_keys) - 5} more items")

    except Exception as e:
        print(f"An error occurred while processing '{filepath}':")
        print(f"  {e}")
        print("Please ensure it's a valid PyTorch .pth file and PyTorch is installed correctly.")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Inspect tensor precision (dtype) and size in a PyTorch .pth file."
    )
    parser.add_argument(
        "filepath",
        help="Path to the .pth file to inspect."
    )
    args = parser.parse_args()

    inspect_pth_precision_and_size(args.filepath)