hunterbown committed
Commit 4b16fff · verified · Parent: 0f382f5

Upload scripts/eval_bpt.py with huggingface_hub
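For context, an upload like this is typically produced with a short huggingface_hub call along the following lines; the repo id below is a placeholder and is not taken from this page:

    from huggingface_hub import HfApi

    api = HfApi()
    api.upload_file(
        path_or_fileobj="scripts/eval_bpt.py",   # local file to push
        path_in_repo="scripts/eval_bpt.py",      # destination path inside the repo
        repo_id="<user>/<repo>",                 # placeholder repo id
        commit_message="Upload scripts/eval_bpt.py with huggingface_hub",
    )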

Files changed (1)
  1. scripts/eval_bpt.py +232 -0
scripts/eval_bpt.py ADDED
@@ -0,0 +1,232 @@
+ #!/usr/bin/env python3
+ """Evaluate base model vs SCU adapter on BPT and perplexity."""
+
+ import os
+ import sys
+ import argparse
+ import math
+ import json
+ import random
+ import statistics as stats
+ from pathlib import Path
+
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
+ from peft import PeftModel
+
+ # Add parent dir to path
+ sys.path.append(str(Path(__file__).parent.parent))
+ from scu import data
+
+
+ def bpt_for_texts(model, tokenizer, texts, max_len=512, device=None):
+     """Calculate BPT for each text.
+
+     Returns list of BPT values (one per text).
+     """
+     model.eval()
+     bpts = []
+
+     for text in texts:
+         # Tokenize
+         enc = tokenizer(
+             text,
+             return_tensors="pt",
+             truncation=True,
+             max_length=max_len,
+             padding=False
+         )
+
+         # Move to device
+         enc = {k: v.to(device or model.device) for k, v in enc.items()}
+
+         # Labels are same as inputs
+         labels = enc["input_ids"].clone()
+
+         # Forward pass
+         with torch.no_grad():
+             outputs = model(**enc, labels=labels)
+             # Convert from nats to bits
+             bpt = outputs.loss.item() / math.log(2)
+             bpts.append(bpt)
+
+     return bpts
+
+
+ def bootstrap_ci(delta_list, iters=10000, seed=42):
+     """Bootstrap confidence interval for mean difference.
+
+     Returns (lower_95, mean, upper_95)
+     """
+     random.seed(seed)
+     means = []
+     n = len(delta_list)
+
+     for _ in range(iters):
+         # Resample with replacement
+         sample = [delta_list[random.randrange(n)] for _ in range(n)]
+         means.append(stats.mean(sample))
+
+     means.sort()
+     lower = means[int(0.025 * iters)]
+     upper = means[int(0.975 * iters)]
+     mean_val = stats.mean(delta_list)
+
+     return lower, mean_val, upper
+
+
+ def main(args):
+     # Suppress tokenizer warnings
+     os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+     # Setup device and dtype
+     if torch.cuda.is_available():
+         device = "cuda"
+         dtype = torch.float16
+         use_4bit = not args.no_4bit
+     elif torch.backends.mps.is_available():
+         device = "mps"
+         dtype = torch.float32
+         use_4bit = False
+     else:
+         device = "cpu"
+         dtype = torch.float32
+         use_4bit = False
+         print("WARNING: Using CPU - evaluation will be slow")
+
+     # Quantization config
+     quantization_config = None
+     if use_4bit and device == "cuda":
+         quantization_config = BitsAndBytesConfig(
+             load_in_4bit=True,
+             bnb_4bit_compute_dtype=dtype,
+             bnb_4bit_quant_type="nf4",
+             bnb_4bit_use_double_quant=True
+         )
+
+     # Load base model
+     print(f"Loading base model: {args.base_model}")
+     base_model = AutoModelForCausalLM.from_pretrained(
+         args.base_model,
+         quantization_config=quantization_config,
+         torch_dtype=dtype,
+         device_map="auto" if device != "cpu" else None,
+         trust_remote_code=True
+     )
+
+     tokenizer = AutoTokenizer.from_pretrained(args.base_model)
+     if tokenizer.pad_token is None:
+         tokenizer.pad_token = tokenizer.eos_token
+
+     # Load validation texts
+     print(f"Loading validation texts from {args.texts}")
+     val_texts = data.load_texts_from_file(args.texts, max_texts=args.max_texts)
+     print(f"Loaded {len(val_texts)} texts")
+
+     # Evaluate base model
+     print("\nEvaluating base model...")
+     base_bpts = bpt_for_texts(base_model, tokenizer, val_texts, max_len=args.max_length, device=device)
+     base_mean_bpt = stats.mean(base_bpts)
+     base_perplexity = 2 ** base_mean_bpt
+
+     # Load adapter model if provided
+     if args.adapter_path:
+         print(f"\nLoading SCU adapter from {args.adapter_path}")
+         scu_model = PeftModel.from_pretrained(base_model, args.adapter_path)
+         scu_model.eval()
+
+         # Evaluate SCU model
+         print("Evaluating SCU model...")
+         scu_bpts = bpt_for_texts(scu_model, tokenizer, val_texts, max_len=args.max_length, device=device)
+         scu_mean_bpt = stats.mean(scu_bpts)
+         scu_perplexity = 2 ** scu_mean_bpt
+
+         # Calculate differences
+         delta_bpts = [b - s for b, s in zip(base_bpts, scu_bpts)]
+         delta_mean = stats.mean(delta_bpts)
+
+         # Bootstrap CI
+         if args.bootstrap:
+             print("\nCalculating bootstrap confidence interval...")
+             ci_lower, ci_mean, ci_upper = bootstrap_ci(delta_bpts, iters=args.bootstrap_iters)
+         else:
+             ci_lower = ci_mean = ci_upper = delta_mean
+
+         # Print results
+         print("\n" + "="*60)
+         print("EVALUATION RESULTS")
+         print("="*60)
+         print(f"Base Model: {base_mean_bpt:.3f} BPT (ppl {base_perplexity:.2f})")
+         print(f"SCU Model: {scu_mean_bpt:.3f} BPT (ppl {scu_perplexity:.2f})")
+         print(f"Improvement: {delta_mean:.3f} BPT ({100*delta_mean/base_mean_bpt:.1f}%)")
+         print(f"Perplexity: -{100*(1 - scu_perplexity/base_perplexity):.1f}%")
+
+         if args.bootstrap:
+             print(f"\nBootstrap 95% CI: [{ci_lower:.3f}, {ci_upper:.3f}]")
+             if ci_lower > 0:
+                 print("✓ CI excludes zero - improvement is statistically significant")
+             else:
+                 print("✗ CI includes zero - improvement not statistically significant")
+
+         # Save results if requested
+         if args.output:
+             results = {
+                 'base_model': args.base_model,
+                 'adapter_path': args.adapter_path,
+                 'num_texts': len(val_texts),
+                 'base_bpt': base_mean_bpt,
+                 'scu_bpt': scu_mean_bpt,
+                 'delta_bpt': delta_mean,
+                 'delta_bpt_percent': 100 * delta_mean / base_mean_bpt,
+                 'base_perplexity': base_perplexity,
+                 'scu_perplexity': scu_perplexity,
+                 'perplexity_reduction': 100 * (1 - scu_perplexity/base_perplexity),
+                 'ci_lower': ci_lower,
+                 'ci_mean': ci_mean,
+                 'ci_upper': ci_upper,
+                 'individual_base_bpts': base_bpts,
+                 'individual_scu_bpts': scu_bpts
+             }
+
+             output_path = Path(args.output)
+             output_path.parent.mkdir(parents=True, exist_ok=True)
+
+             with open(output_path, 'w') as f:
+                 json.dump(results, f, indent=2)
+
+             print(f"\nResults saved to {args.output}")
+
+     else:
+         # Base model only
+         print("\n" + "="*60)
+         print("BASE MODEL RESULTS")
+         print("="*60)
+         print(f"BPT: {base_mean_bpt:.3f}")
+         print(f"Perplexity: {base_perplexity:.2f}")
+         print(f"Texts: {len(val_texts)}")
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser(description="Evaluate BPT and perplexity")
+
+     parser.add_argument("--base_model", default="meta-llama/Llama-3.2-1B",
+                         help="Base model name")
+     parser.add_argument("--adapter_path", default=None,
+                         help="Path to SCU adapter (optional)")
+     parser.add_argument("--texts", default="data/val.txt",
+                         help="Validation texts file")
+     parser.add_argument("--max_texts", type=int, default=None,
+                         help="Maximum texts to evaluate")
+     parser.add_argument("--max_length", type=int, default=512,
+                         help="Maximum sequence length")
+     parser.add_argument("--no_4bit", action="store_true",
+                         help="Disable 4-bit quantization")
+     parser.add_argument("--bootstrap", action="store_true",
+                         help="Calculate bootstrap CI")
+     parser.add_argument("--bootstrap_iters", type=int, default=10000,
+                         help="Bootstrap iterations")
+     parser.add_argument("--output", default=None,
+                         help="Output JSON file for results")
+
+     args = parser.parse_args()
+     main(args)
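For reference, a typical invocation of the script added above looks like the following; the adapter and output paths are illustrative, only the flags and defaults come from the argparse definition:

    python scripts/eval_bpt.py \
        --base_model meta-llama/Llama-3.2-1B \
        --adapter_path <path/to/scu-adapter> \
        --texts data/val.txt \
        --bootstrap \
        --output <results.json>

Without --adapter_path the script reports base-model BPT and perplexity only; with it, it also reports the per-text BPT delta and, if --bootstrap is set, a 95% bootstrap confidence interval on that delta.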