krogoldAI committed (verified)
Commit 196b7f6 · Parent(s): 6b15628

Upload Fine-tuning.py

Files changed (1):
  1. Code/Fine-tuning.py (+307 -0)

Code/Fine-tuning.py (new file, 307 lines):
import re
import numpy as np
import torch
from datasets import load_dataset, Dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling,
)
from huggingface_hub import login

##########
# CONFIG #
##########

MODEL_NAME = "Qwen/Qwen2.5-0.5B-Instruct"
DATASET = "dataset/repo"
OUTPUT_MODEL = "model/repo"

# Training hyperparams
NUM_EPOCHS = 3
PER_DEVICE_BATCH = 4
GRADIENT_ACCUMULATION = 4
LEARNING_RATE = 2e-5
WEIGHT_DECAY = 0.01
WARMUP_STEPS = 100
BF16 = True
TORCH_COMPILE = False

#########
# LOGIN #
#########

login("<YOUR_HF_TOKEN>")
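# Note: replace the placeholder above with a real token. To avoid hard-coding
# credentials, one option is to read the token from an environment variable and
# pass it in, e.g. login(token=os.environ["HF_TOKEN"]) (requires `import os`).
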
##################
# LOAD TOKENIZER #
##################

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
tokenizer.padding_side = "right"
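# The collator below pads batches, so the tokenizer needs a pad token. Qwen2.5
# tokenizers normally ship with one already set; if yours does not, a common
# fallback is: tokenizer.pad_token = tokenizer.eos_token
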
################
# LOAD DATASET #
################

raw_ds = load_dataset(DATASET, "default", split="train")
raw_ds = raw_ds.shuffle(seed=42)

# Apply Qwen chat template
formatted_texts = [
    tokenizer.apply_chat_template(
        conv,
        tokenize=False,
        add_generation_prompt=False
    )
    for conv in raw_ds["text"]
]

# Build simple dataset
ds = Dataset.from_dict({"text": formatted_texts})
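# Assumption: each entry of raw_ds["text"] is a chat in the format expected by
# apply_chat_template, i.e. a list of {"role": ..., "content": ...} messages.
# For Qwen2.5 the rendered string then looks roughly like:
#   <|im_start|>system\n...<|im_end|>\n
#   <|im_start|>user\n...<|im_end|>\n
#   <|im_start|>assistant\n...<|im_end|>\n
# which is the layout the custom collator and the regex analysis below rely on.
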
########################
# CUSTOM DATA COLLATOR #
########################

class Qwen25DataCollator(DataCollatorForLanguageModeling):
    def __init__(self, tokenizer, mlm=False):
        super().__init__(tokenizer=tokenizer, mlm=mlm)
        # Get token ids robustly (some tokenizers might return [] from encode if a token is missing)
        try:
            self.im_start_token = tokenizer.encode("<|im_start|>", add_special_tokens=False)[0]
        except Exception:
            self.im_start_token = None
        try:
            self.im_end_token = tokenizer.encode("<|im_end|>", add_special_tokens=False)[0]
        except Exception:
            self.im_end_token = None

        # "assistant" token sequence (may be multiple tokens)
        try:
            self.assistant_text = tokenizer.encode("assistant", add_special_tokens=False)
        except Exception:
            self.assistant_text = []

    # Provide both __call__ and torch_call for compatibility
    def __call__(self, features):
        return self.torch_call(features)

    def torch_call(self, examples):
        """
        examples: list of dicts returned by tokenization (each example contains 'input_ids', 'attention_mask', etc.)
        We leverage the parent to build the initial batch, then mask labels so that only assistant responses are supervised.
        """
        batch = super().torch_call(examples)  # input_ids, attention_mask, labels (labels == input_ids with padding masked, since mlm=False)
        input_ids = batch["input_ids"]
        labels = batch["labels"]

        # If the special tokens are not present, return the default batch unchanged
        if self.im_start_token is None or self.im_end_token is None or len(self.assistant_text) == 0:
            return batch

        # Iterate over examples in the batch to mask labels: only assistant response tokens should be supervised
        for i, ids in enumerate(input_ids):
            # Find positions of <|im_start|> and <|im_end|>
            im_start_positions = torch.where(ids == self.im_start_token)[0]
            im_end_positions = torch.where(ids == self.im_end_token)[0]

            if im_start_positions.numel() == 0 or im_end_positions.numel() == 0:
                # No recognized chat markers: leave labels as-is
                continue

            last_assistant_start = None
            # Find the last <|im_start|> that is followed by "assistant"
            for start_pos in im_start_positions:
                # Check whether the tokens following start_pos match "assistant"
                as_len = len(self.assistant_text)
                candidate_end = start_pos + 1 + as_len
                if candidate_end <= len(ids):
                    segment = ids[start_pos + 1:start_pos + 1 + as_len]
                    if torch.equal(segment, torch.tensor(self.assistant_text, device=ids.device)):
                        last_assistant_start = int(start_pos)

            if last_assistant_start is None:
                continue

            # Find the first <|im_end|> after last_assistant_start
            assistant_end_positions = im_end_positions[im_end_positions > last_assistant_start]
            if assistant_end_positions.numel() == 0:
                continue

            assistant_end = int(assistant_end_positions[0])

            # The response lies between (last_assistant_start + 1 + len("assistant")) and assistant_end (inclusive),
            # but because the template may include a newline or an extra token, we set response_start carefully.
            response_start = last_assistant_start + 1 + len(self.assistant_text)
            # If a newline/separator token is present right after "assistant", skip it
            # (this is conservative: we do not assume an extra token, but we handle it if present)
            newline_ids = self.tokenizer.encode("\n", add_special_tokens=False)
            if newline_ids and response_start < len(ids) and ids[response_start] == newline_ids[0]:
                response_start += 1

            # Apply masking:
            # set everything before response_start to -100 (ignored), then restore the response tokens
            labels[i, :] = -100
            if response_start < len(ids):
                # Label slice runs up to assistant_end inclusive
                end_idx = min(assistant_end + 1, ids.shape[0])
                labels[i, response_start:end_idx] = ids[response_start:end_idx]

        # Assign modified labels back
        batch["labels"] = labels
        return batch

collator = Qwen25DataCollator(tokenizer=tokenizer, mlm=False)
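# In short: for every sequence the collator keeps labels only for the tokens of
# the last assistant turn (from just after "<|im_start|>assistant" up to and
# including its "<|im_end|>"); every other position is set to -100 so it is
# ignored by the cross-entropy loss. Sequences without recognizable chat
# markers fall through with the default causal-LM labels.
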
###############################################
# ANALYZE DATASET LENGTHS TO SET `max_length` #
###############################################

# We analyze the dataset to optimize the choice of `max_length`
print("Analyzing dataset to determine max_length (sample up to 1000)...")
assistant_lengths = []
full_lengths = []

sample_limit = min(1000, len(ds))
for example in ds["text"][:sample_limit]:
    full_tokens = tokenizer(example, truncation=False, add_special_tokens=True)
    full_lengths.append(len(full_tokens["input_ids"]))

    # Extract the last assistant response via a regex pattern
    pattern = r"<\|im_start\|>assistant\n(.*?)<\|im_end\|>"
    matches = re.findall(pattern, example, re.DOTALL)
    if matches:
        last_response = matches[-1]
        resp_tokens = tokenizer(last_response, truncation=False, add_special_tokens=False)
        assistant_lengths.append(len(resp_tokens["input_ids"]))

# Basic statistics (guard against empty lists)
def safe_stat(arr):
    if len(arr) == 0:
        return 0.0, 0.0, 0.0, 0.0
    return np.mean(arr), np.median(arr), np.percentile(arr, 95), np.percentile(arr, 99)

mean_ass, med_ass, p95_ass, p99_ass = safe_stat(assistant_lengths)
mean_full, _, p95_full, _ = safe_stat(full_lengths)

print(f"Assistant response mean={mean_ass:.1f}, median={med_ass:.1f}, 95%={p95_ass:.1f}, 99%={p99_ass:.1f}")
print(f"Full conversation mean={mean_full:.1f}, 95%={p95_full:.1f}")

# Round up to the nearest power of two, but don't exceed tokenizer.model_max_length
def next_power_of_2(x):
    if x <= 1:
        return 1
    return 2 ** int(np.ceil(np.log2(x)))

target_length = int(min(p95_full if p95_full > 0 else tokenizer.model_max_length, tokenizer.model_max_length))
MAX_LENGTH = next_power_of_2(target_length)
if MAX_LENGTH > tokenizer.model_max_length:
    MAX_LENGTH = tokenizer.model_max_length

print(f"Using MAX_LENGTH = {MAX_LENGTH}")
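# Example: if the 95th-percentile conversation length is 1500 tokens,
# next_power_of_2(1500) returns 2048, so MAX_LENGTH would be 2048
# (unless that exceeds tokenizer.model_max_length).
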
####################
# TOKENIZE DATASET #
####################

def tokenize_function(examples):
    return tokenizer(examples["text"], truncation=True, max_length=MAX_LENGTH, padding=False)

tokenized_ds = ds.map(tokenize_function, batched=True, remove_columns=ds.column_names)
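# Optional sanity check (illustrative, not required for training): run the
# collator on a single tokenized example and decode only the supervised
# positions; the output should be just the last assistant response.
# sample_batch = collator([tokenized_ds[0]])
# mask = sample_batch["labels"][0] != -100
# print(tokenizer.decode(sample_batch["input_ids"][0][mask]))
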
##############
# LOAD MODEL #
##############

# Load model
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.bfloat16 if BF16 else None,
    device_map="auto",
    attn_implementation="flash_attention_2",
    use_cache=False,
)
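# Note: attn_implementation="flash_attention_2" requires the flash-attn package
# and a supported GPU; if it is unavailable, dropping the argument (or using
# attn_implementation="sdpa") is a reasonable fallback.
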
try:
    from liger_kernel.transformers import apply_liger_kernel_to_qwen2
    try:
        # Pass the model explicitly by keyword so the already-loaded instance is patched
        apply_liger_kernel_to_qwen2(model=model)
    except TypeError:
        apply_liger_kernel_to_qwen2()
    print("Liger Kernel applied successfully for Qwen2 optimization")
except Exception:
    print("Liger Kernel not available or failed to apply; continuing without it.")

print(f"Model loaded. Parameters: {model.num_parameters() / 1e9:.3f}B")
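# Broadly speaking, Liger Kernel swaps in fused Triton kernels (RMSNorm, RoPE,
# SwiGLU, fused cross-entropy) for Qwen2-style models, which mainly reduces
# memory use and improves throughput; the script degrades gracefully without it.
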
######################
# TRAINING ARGUMENTS #
######################

training_args = TrainingArguments(
    output_dir="./qwen_rephraser_checkpoints",
    num_train_epochs=NUM_EPOCHS,
    per_device_train_batch_size=PER_DEVICE_BATCH,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION,
    learning_rate=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
    warmup_steps=WARMUP_STEPS,
    lr_scheduler_type="cosine",
    logging_steps=10,
    save_steps=500,
    save_total_limit=2,
    bf16=BF16,
    optim="adamw_torch_fused",
    gradient_checkpointing=True,
    report_to="none",
    push_to_hub=False,  # we'll push manually at the end
    hub_model_id=OUTPUT_MODEL,
    hub_private_repo=True,
    dataloader_num_workers=4,
    dataloader_pin_memory=True,
    ddp_find_unused_parameters=False,
    torch_compile=TORCH_COMPILE,
)
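# With PER_DEVICE_BATCH = 4 and GRADIENT_ACCUMULATION = 4 the effective batch
# size is 16 sequences per device per optimizer step (times the number of GPUs
# when running multi-GPU).
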
###########
# TRAINER #
###########

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_ds,
    data_collator=collator,
)

#########
# TRAIN #
#########

print("Starting training...")
trainer.train()

####################
# SAVE FINAL MODEL #
####################

print("Saving model to ./final_model ...")
model.config.use_cache = True
trainer.save_model("./final_model")
tokenizer.save_pretrained("./final_model")

##################
# PUSHING TO HUB #
##################

try:
    print(f"Pushing model and tokenizer to the hub as {OUTPUT_MODEL} (private)...")
    model.push_to_hub(OUTPUT_MODEL, private=True)
    tokenizer.push_to_hub(OUTPUT_MODEL, private=True)
    print("Push completed.")
except Exception as e:
    print("Warning: push_to_hub failed:", e)

print("Training complete!")