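"""
Fine-tuning script for Qwen/Qwen2.5-0.5B-Instruct on a chat-style dataset.

Conversations are formatted with the Qwen chat template, the loss is masked so that
only the final assistant response is supervised, and training runs through the
Hugging Face Trainer.
"""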
import re
import numpy as np
import torch
from datasets import load_dataset, Dataset
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
TrainingArguments,
Trainer,
DataCollatorForLanguageModeling,
)
from huggingface_hub import login
##########
# CONFIG #
##########
MODEL_NAME = "Qwen/Qwen2.5-0.5B-Instruct"
DATASET = "dataset/repo"
OUTPUT_MODEL = "model/repo"
# Training hyperparams
NUM_EPOCHS = 3
PER_DEVICE_BATCH = 4
GRADIENT_ACCUMULATION = 4
LEARNING_RATE = 2e-5
WEIGHT_DECAY = 0.01
WARMUP_STEPS = 100
BF16 = True
TORCH_COMPILE = False
#########
# LOGIN #
#########
login("<YOUR_HF_TOKEN>")
##################
# LOAD TOKENIZER #
##################
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
tokenizer.padding_side = "right"
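# Defensive fallback (an assumption for robustness; Qwen2.5 tokenizers normally ship
# with a pad token already set): fall back to EOS so batch padding in the collator works.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token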
################
# LOAD DATASET #
################
raw_ds = load_dataset(DATASET, "default", split="train")
raw_ds = raw_ds.shuffle(seed=42)
# Apply Qwen chat template
formatted_texts = [
tokenizer.apply_chat_template(
conv,
tokenize=False,
add_generation_prompt=False
)
for conv in raw_ds["text"]
]
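# Each formatted string now carries the full Qwen chat markup
# (<|im_start|>role ... <|im_end|> blocks) that the collator below relies on.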
# Build simple dataset
ds = Dataset.from_dict({"text": formatted_texts})
########################
# CUSTOM DATA COLLATOR #
########################
class Qwen25DataCollator(DataCollatorForLanguageModeling):
def __init__(self, tokenizer, mlm=False):
super().__init__(tokenizer=tokenizer, mlm=mlm)
# get token ids robustly (some tokenizers might return [] for encode if token missing)
try:
self.im_start_token = tokenizer.encode("<|im_start|>", add_special_tokens=False)[0]
except Exception:
self.im_start_token = None
try:
self.im_end_token = tokenizer.encode("<|im_end|>", add_special_tokens=False)[0]
except Exception:
self.im_end_token = None
# "assistant" token sequence (may be multiple tokens)
try:
self.assistant_text = tokenizer.encode("assistant", add_special_tokens=False)
except Exception:
self.assistant_text = []
# Provide both __call__ and torch_call for compatibility
def __call__(self, features):
return self.torch_call(features)
def torch_call(self, examples):
"""
examples: list of dicts returned by tokenization (each example contains 'input_ids', 'attention_mask', etc.)
We'll leverage the parent to create initial batch and then mask labels for assistant responses only.
"""
batch = super().torch_call(examples) # returns input_ids, attention_mask, labels (for MLM)
input_ids = batch["input_ids"]
labels = batch["labels"]
# If special tokens are not present, return default batch unchanged
if self.im_start_token is None or self.im_end_token is None or len(self.assistant_text) == 0:
return batch
# Iterate examples in batch to mask labels: only assistant response tokens should be supervised
for i, ids in enumerate(input_ids):
# Find positions of <|im_start|> and <|im_end|>
im_start_positions = torch.where(ids == self.im_start_token)[0]
im_end_positions = torch.where(ids == self.im_end_token)[0]
if im_start_positions.numel() == 0 or im_end_positions.numel() == 0:
# no recognized chat markers: leave labels as-is (or continue)
continue
last_assistant_start = None
# Find last im_start that is followed by "assistant"
for start_pos in im_start_positions:
# check if tokens following start_pos match "assistant"
as_len = len(self.assistant_text)
candidate_end = start_pos + 1 + as_len
if candidate_end <= len(ids):
segment = ids[start_pos + 1:start_pos + 1 + as_len]
if torch.equal(segment, torch.tensor(self.assistant_text, device=ids.device)):
last_assistant_start = int(start_pos)
if last_assistant_start is None:
continue
# Find first im_end after last_assistant_start
assistant_end_positions = im_end_positions[im_end_positions > last_assistant_start]
if assistant_end_positions.numel() == 0:
continue
assistant_end = int(assistant_end_positions[0])
# Response text is between (last_assistant_start + 1 + len("assistant")) and assistant_end - 1 (inclusive),
# but because template may include a newline or an extra token, we set response_start carefully.
response_start = last_assistant_start + 1 + len(self.assistant_text)
            # If the chat template inserts a newline token right after "assistant", skip it
            # (conservative: we only skip such a token when it is actually present)
            newline_ids = self.tokenizer.encode("\n", add_special_tokens=False)
            if newline_ids and response_start < len(ids) and ids[response_start] == newline_ids[0]:
                response_start += 1
# Apply masking:
# Set everything before response_start to -100 (ignored), preserve response tokens, set rest to -100
labels[i, :] = -100
if response_start < len(ids):
# labels slice up to assistant_end inclusive
end_idx = min(assistant_end + 1, ids.shape[0])
labels[i, response_start:end_idx] = ids[response_start:end_idx]
# assign modified labels back
batch["labels"] = labels
return batch
collator = Qwen25DataCollator(tokenizer=tokenizer, mlm=False)
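# Optional sanity check (illustrative sketch; assumes the first formatted example
# contains an assistant turn): run a single example through the collator and report
# how many tokens remain supervised after masking.
_probe = tokenizer(ds["text"][0], truncation=True, max_length=2048, padding=False)
_probe_batch = collator([{"input_ids": _probe["input_ids"], "attention_mask": _probe["attention_mask"]}])
_supervised = int((_probe_batch["labels"] != -100).sum())
print(f"Collator sanity check: {_supervised}/{len(_probe['input_ids'])} tokens supervised")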
###############################################
# ANALYZE DATASET LENGTHS TO SET `max_length` #
###############################################
# We analyze the dataset to optimize the choice of `max_length`
print("Analyzing dataset to determine max_length (sample up to 1000)...")
assistant_lengths = []
full_lengths = []
sample_limit = min(1000, len(ds))
for example in ds["text"][:sample_limit]:
full_tokens = tokenizer(example, truncation=False, add_special_tokens=True)
full_lengths.append(len(full_tokens["input_ids"]))
# extract the last assistant response via regex pattern
pattern = r"<\|im_start\|>assistant\n(.*?)<\|im_end\|>"
matches = re.findall(pattern, example, re.DOTALL)
if matches:
last_response = matches[-1]
resp_tokens = tokenizer(last_response, truncation=False, add_special_tokens=False)
assistant_lengths.append(len(resp_tokens["input_ids"]))
# Basic statistics (guard for empty lists)
def safe_stat(arr):
if len(arr) == 0:
return 0.0, 0.0, 0.0, 0.0
return np.mean(arr), np.median(arr), np.percentile(arr, 95), np.percentile(arr, 99)
mean_ass, med_ass, p95_ass, p99_ass = safe_stat(assistant_lengths)
mean_full, _, p95_full, _ = safe_stat(full_lengths)
print(f"Assistant response mean={mean_ass:.1f}, median={med_ass:.1f}, 95%={p95_ass:.1f}, 99%={p99_ass:.1f}")
print(f"Full conversation mean={mean_full:.1f}, 95%={p95_full:.1f}")
# Round up to nearest power of two but don't exceed tokenizer.model_max_length
def next_power_of_2(x):
if x <= 1:
return 1
return 2 ** int(np.ceil(np.log2(x)))
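# Example: a 95th-percentile length of 700 tokens rounds up to 1024 (2**10).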
target_length = int(min(p95_full if p95_full > 0 else tokenizer.model_max_length, tokenizer.model_max_length))
MAX_LENGTH = next_power_of_2(target_length)
if MAX_LENGTH > tokenizer.model_max_length:
MAX_LENGTH = tokenizer.model_max_length
print(f"Using MAX_LENGTH = {MAX_LENGTH}")
####################
# TOKENIZE DATASET #
####################
def tokenize_function(examples):
return tokenizer(examples["text"], truncation=True, max_length=MAX_LENGTH, padding=False)
tokenized_ds = ds.map(tokenize_function, batched=True, remove_columns=ds.column_names)
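# Padding is deferred to the data collator, which pads each batch dynamically at training time.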
##############
# LOAD MODEL #
##############
# Load model
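# Note: attn_implementation="flash_attention_2" requires the flash-attn package and a
# compatible GPU; drop the argument (or use "sdpa") if flash-attn is not installed.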
model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
torch_dtype=torch.bfloat16 if BF16 else None,
device_map="auto",
attn_implementation="flash_attention_2",
use_cache=False,
)
try:
from liger_kernel.transformers import apply_liger_kernel_to_qwen2
try:
        apply_liger_kernel_to_qwen2(model=model)
except TypeError:
apply_liger_kernel_to_qwen2()
print("Liger Kernel applied successfully for Qwen2 optimization")
except Exception:
print("Liger Kernel not available or failed to apply; continuing without it.")
print(f"Model loaded. Parameters: {model.num_parameters() / 1e9:.3f}B")
######################
# TRAINING ARGUMENTS #
######################
training_args = TrainingArguments(
output_dir="./qwen_rephraser_checkpoints",
num_train_epochs=NUM_EPOCHS,
per_device_train_batch_size=PER_DEVICE_BATCH,
gradient_accumulation_steps=GRADIENT_ACCUMULATION,
learning_rate=LEARNING_RATE,
weight_decay=WEIGHT_DECAY,
warmup_steps=WARMUP_STEPS,
lr_scheduler_type="cosine",
logging_steps=10,
save_steps=500,
save_total_limit=2,
bf16=BF16,
optim="adamw_torch_fused",
gradient_checkpointing=True,
report_to="none",
push_to_hub=False, # we'll push manually at the end
hub_model_id=OUTPUT_MODEL,
hub_private_repo=True,
dataloader_num_workers=4,
dataloader_pin_memory=True,
ddp_find_unused_parameters=False,
torch_compile=TORCH_COMPILE,
)
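# On a single device, each optimizer step sees PER_DEVICE_BATCH * GRADIENT_ACCUMULATION examples.
print(f"Effective batch size per optimizer step: {PER_DEVICE_BATCH * GRADIENT_ACCUMULATION}")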
###########
# TRAINER #
###########
trainer = Trainer(
model=model,
args=training_args,
train_dataset=tokenized_ds,
data_collator=collator,
)
#########
# TRAIN #
#########
print("Starting training...")
trainer.train()
####################
# SAVE FINAL MODEL #
####################
print("Saving model to ./final_model ...")
model.config.use_cache = True
trainer.save_model("./final_model")
tokenizer.save_pretrained("./final_model")
##################
# PUSHING TO HUB #
##################
try:
print(f"Pushing model and tokenizer to the hub as {OUTPUT_MODEL} (private)...")
model.push_to_hub(OUTPUT_MODEL, private=True)
tokenizer.push_to_hub(OUTPUT_MODEL, private=True)
print("Push completed.")
except Exception as e:
print("Warning: push_to_hub failed:", e)
print("Training complete!")