import os
import re

import numpy as np
import torch
import wandb
from datasets import load_dataset, Dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling,
)
from huggingface_hub import login, hf_hub_download

os.environ["TOKENIZERS_PARALLELISM"] = "false"
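
# Model, dataset, and training configuration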
MODEL_NAME = "krogoldAI/QueryRefiner-0.5B-v0.1-GRPO-SFT1"
DATASET = "krogoldAI/rag-query-analysis"
OUTPUT_MODEL = "model/repo"

NUM_EPOCHS = 3
PER_DEVICE_BATCH = 4
GRADIENT_ACCUMULATION = 4
LEARNING_RATE = 2e-5
WEIGHT_DECAY = 0.01
WARMUP_STEPS = 50
BF16 = True
TORCH_COMPILE = False
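
# Authenticate with the Hugging Face Hub (replace the placeholder with your token)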
login("<MY_HF_TOKEN>")
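
# Load the tokenizer; right-padding keeps labels aligned with input_ids during training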
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
tokenizer.padding_side = "right"
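
# Load and shuffle the raw conversation dataset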
raw_ds = load_dataset(DATASET, "default", split="train")
raw_ds = raw_ds.shuffle(seed=42)
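
# Keep only examples with no critical issue, not flagged unusable, and scoring 5/5 on every quality metric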
metrics = [
    "ambiguity_assessment",
    "domain_accuracy",
    "follows_guidelines",
    "intent_accuracy",
    "intent_preservation",
    "rephrasing_quality",
]


def check(metadata):
    if metadata.get("critical_issue") is not None or metadata.get("usable") is False:
        return False
    for metric in metrics:
        if metadata.get(metric) != 5:
            return False
    return True


filtered_ds = raw_ds.filter(lambda x: check(x["metadata"]))
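
# Render each conversation with the model's chat template and strip Markdown code fences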
def format_and_clean(batch):
    out = []
    for conv in batch["text"]:
        s = tokenizer.apply_chat_template(conv, tokenize=False, add_generation_prompt=False)
        s = re.sub(r"```[a-zA-Z0-9]*\n?", "", s).replace("```", "").strip()
        out.append(s)
    return {"text": out}


ds = filtered_ds.map(format_and_clean, batched=True, batch_size=128)
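
# Custom collator: compute the loss only on the final assistant response of each conversation,
# masking everything else (system prompt, user turns, chat-template markup) with -100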
class Qwen25DataCollator(DataCollatorForLanguageModeling):
    def __init__(self, tokenizer, mlm=False):
        super().__init__(tokenizer=tokenizer, mlm=mlm)

        try:
            self.im_start_token = tokenizer.encode("<|im_start|>", add_special_tokens=False)[0]
        except Exception:
            self.im_start_token = None
        try:
            self.im_end_token = tokenizer.encode("<|im_end|>", add_special_tokens=False)[0]
        except Exception:
            self.im_end_token = None

        try:
            self.assistant_text = tokenizer.encode("assistant", add_special_tokens=False)
        except Exception:
            self.assistant_text = []

    def __call__(self, features):
        return self.torch_call(features)

    def torch_call(self, examples):
        """
        examples: list of tokenized examples (each containing 'input_ids', 'attention_mask', etc.).
        Build the initial batch with the parent collator, then rewrite the labels so that only the
        last assistant response contributes to the loss.
        """
        batch = super().torch_call(examples)
        input_ids = batch["input_ids"]
        labels = batch["labels"]

        # If the ChatML markers could not be resolved, fall back to standard causal LM labels
        if self.im_start_token is None or self.im_end_token is None or len(self.assistant_text) == 0:
            return batch

        for i, ids in enumerate(input_ids):
            im_start_positions = torch.where(ids == self.im_start_token)[0]
            im_end_positions = torch.where(ids == self.im_end_token)[0]

            if im_start_positions.numel() == 0 or im_end_positions.numel() == 0:
                continue

            # Locate the last "<|im_start|>assistant" header in the sequence
            last_assistant_start = None
            for start_pos in im_start_positions:
                as_len = len(self.assistant_text)
                candidate_end = start_pos + 1 + as_len
                if candidate_end <= len(ids):
                    segment = ids[start_pos + 1:start_pos + 1 + as_len]
                    if torch.equal(segment, torch.tensor(self.assistant_text, device=ids.device)):
                        last_assistant_start = int(start_pos)

            if last_assistant_start is None:
                continue

            # The first <|im_end|> after that header closes the assistant response
            assistant_end_positions = im_end_positions[im_end_positions > last_assistant_start]
            if assistant_end_positions.numel() == 0:
                continue
            assistant_end = int(assistant_end_positions[0])

            response_start = last_assistant_start + 1 + len(self.assistant_text)

            # Skip the newline that follows the "assistant" role tag
            if response_start < len(ids) and ids[response_start] == self.tokenizer.encode("\n", add_special_tokens=False)[0]:
                response_start += 1

            # Mask everything, then restore labels for the assistant response (including <|im_end|>)
            labels[i, :] = -100
            if response_start < len(ids):
                end_idx = min(assistant_end + 1, ids.shape[0])
                labels[i, response_start:end_idx] = ids[response_start:end_idx]

        batch["labels"] = labels
        return batch


collator = Qwen25DataCollator(tokenizer=tokenizer, mlm=False)
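
# Optional sanity check (a quick sketch, assuming `ds` is non-empty): run one formatted
# example through the collator and count how many label positions stay unmasked.
# sample = tokenizer(ds["text"][0], truncation=True)
# masked_batch = collator([sample])
# print((masked_batch["labels"][0] != -100).sum().item(), "supervised tokens")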
print("Analyzing dataset to determine max_length (sample up to 1000)...")
assistant_lengths = []
full_lengths = []

sample_limit = min(1000, len(ds))
for example in ds["text"][:sample_limit]:
    full_tokens = tokenizer(example, truncation=False, add_special_tokens=True)
    full_lengths.append(len(full_tokens["input_ids"]))

    pattern = r"<\|im_start\|>assistant\n(.*?)<\|im_end\|>"
    matches = re.findall(pattern, example, re.DOTALL)
    if matches:
        last_response = matches[-1]
        resp_tokens = tokenizer(last_response, truncation=False, add_special_tokens=False)
        assistant_lengths.append(len(resp_tokens["input_ids"]))


def safe_stat(arr):
    if len(arr) == 0:
        return 0.0, 0.0, 0.0, 0.0
    return np.mean(arr), np.median(arr), np.percentile(arr, 95), np.percentile(arr, 99)


mean_ass, med_ass, p95_ass, p99_ass = safe_stat(assistant_lengths)
mean_full, _, p95_full, _ = safe_stat(full_lengths)

print(f"Assistant response mean={mean_ass:.1f}, median={med_ass:.1f}, 95%={p95_ass:.1f}, 99%={p99_ass:.1f}")
print(f"Full conversation mean={mean_full:.1f}, 95%={p95_full:.1f}")


def next_power_of_2(x):
    if x <= 1:
        return 1
    return 2 ** int(np.ceil(np.log2(x)))


target_length = int(min(p95_full if p95_full > 0 else tokenizer.model_max_length, tokenizer.model_max_length))
MAX_LENGTH = next_power_of_2(target_length)
if MAX_LENGTH > tokenizer.model_max_length:
    MAX_LENGTH = tokenizer.model_max_length

print(f"Using MAX_LENGTH = {MAX_LENGTH}")
def tokenize_function(examples):
    return tokenizer(examples["text"], truncation=True, max_length=MAX_LENGTH, padding=False)


tokenized_ds = ds.map(tokenize_function, batched=True, remove_columns=ds.column_names)
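
# Load the base model in bfloat16; attn_implementation="flash_attention_2" assumes the flash-attn package is installed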
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.bfloat16 if BF16 else None,
    device_map="auto",
    attn_implementation="flash_attention_2",
    use_cache=False,
)
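
# Best effort: patch Qwen2 layers with Liger Kernel for lower memory use and faster training.
# Some liger-kernel releases accept a `model=` keyword to patch an already-instantiated model in place;
# the TypeError fallback covers versions that do not.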
try:
    from liger_kernel.transformers import apply_liger_kernel_to_qwen2

    try:
        apply_liger_kernel_to_qwen2(model=model)
    except TypeError:
        apply_liger_kernel_to_qwen2()
    print("Liger Kernel applied successfully for Qwen2 optimization")
except Exception:
    print("Liger Kernel not available or failed to apply; continuing without it.")

print(f"Model loaded. Parameters: {model.num_parameters() / 1e9:.3f}B")
wandb.login(key="<MY_WANDB_KEY>")
run = wandb.init(project="QueryRefiner-GRPO-Phase2", job_type="training", anonymous="allow")
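
# Training hyperparameters; effective batch size = PER_DEVICE_BATCH * GRADIENT_ACCUMULATION per device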
training_args = TrainingArguments(
    output_dir="./qwen_rephraser_checkpoints",
    num_train_epochs=NUM_EPOCHS,
    per_device_train_batch_size=PER_DEVICE_BATCH,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION,
    learning_rate=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
    warmup_steps=WARMUP_STEPS,
    lr_scheduler_type="cosine",
    logging_steps=10,
    save_steps=500,
    save_total_limit=2,
    bf16=BF16,
    optim="adamw_torch_fused",
    gradient_checkpointing=True,
    report_to="wandb",
    push_to_hub=False,
    hub_model_id=OUTPUT_MODEL,
    hub_private_repo=True,
    dataloader_num_workers=4,
    dataloader_pin_memory=True,
    ddp_find_unused_parameters=False,
    torch_compile=TORCH_COMPILE,
)
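
# The custom collator handles padding and label masking, so no tokenizer is passed to the Trainer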
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_ds,
    data_collator=collator,
)
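
# Run supervised fine-tuning; checkpoints are written to ./qwen_rephraser_checkpoints every 500 steps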
print("Starting training...")
trainer.train()
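
# Re-enable the KV cache (disabled above for gradient checkpointing) before saving for inference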
print("Saving model to ./final_model ...")
model.config.use_cache = True
trainer.save_model("./final_model")
tokenizer.save_pretrained("./final_model")
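
# Push to the Hub as a private repo; wrapped in try/except so a failed upload doesn't lose the local save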
try:
    print(f"Pushing model and tokenizer to the hub as {OUTPUT_MODEL} (private)...")
    model.push_to_hub(OUTPUT_MODEL, private=True)
    tokenizer.push_to_hub(OUTPUT_MODEL, private=True)
    print("Push completed.")
except Exception as e:
    print("Warning: push_to_hub failed:", e)

print("Training complete!")