Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -244,8 +244,8 @@ def train_model(
|
|
| 244 |
train_dataset, test_dataset = dataset["train"], dataset["test"]
|
| 245 |
|
| 246 |
# ===== ⚡ FAST mode: use small subset =====
|
| 247 |
-
train_dataset = train_dataset.select(range(min(
|
| 248 |
-
test_dataset = test_dataset.select(range(min(
|
| 249 |
log_message(output_log, f"⚡ Using {len(train_dataset)} train / {len(test_dataset)} test samples")
|
| 250 |
|
| 251 |
# ===== Format samples =====
|
|
@@ -263,7 +263,21 @@ def train_model(
|
|
| 263 |
train_dataset = train_dataset.map(format_example)
|
| 264 |
test_dataset = test_dataset.map(format_example)
|
| 265 |
|
| 266 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 267 |
tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
|
| 268 |
if tokenizer.pad_token is None:
|
| 269 |
tokenizer.pad_token = tokenizer.eos_token
|
|
@@ -271,25 +285,29 @@ def train_model(
|
|
| 271 |
model = AutoModelForCausalLM.from_pretrained(
|
| 272 |
base_model,
|
| 273 |
trust_remote_code=True,
|
| 274 |
-
torch_dtype=
|
| 275 |
-
device_map="auto" if device == "cuda" else None,
|
| 276 |
low_cpu_mem_usage=True,
|
| 277 |
)
|
| 278 |
-
|
|
|
|
|
|
|
|
|
|
| 279 |
|
| 280 |
-
# ===== LoRA
|
|
|
|
| 281 |
lora_config = LoraConfig(
|
| 282 |
task_type=TaskType.CAUSAL_LM,
|
| 283 |
-
r=
|
| 284 |
-
lora_alpha=
|
| 285 |
lora_dropout=0.1,
|
| 286 |
target_modules=["q_proj", "v_proj"],
|
| 287 |
bias="none",
|
| 288 |
)
|
| 289 |
model = get_peft_model(model, lora_config)
|
| 290 |
-
|
|
|
|
| 291 |
|
| 292 |
-
# ===== Tokenization =====
|
| 293 |
def tokenize_fn(examples):
|
| 294 |
tokenized = tokenizer(
|
| 295 |
examples["text"],
|
|
@@ -302,24 +320,22 @@ def train_model(
|
|
| 302 |
|
| 303 |
train_dataset = train_dataset.map(tokenize_fn, batched=True)
|
| 304 |
test_dataset = test_dataset.map(tokenize_fn, batched=True)
|
|
|
|
| 305 |
|
| 306 |
-
# ===== Training
|
| 307 |
-
output_dir = "./
|
| 308 |
-
os.makedirs(output_dir, exist_ok=True)
|
| 309 |
-
|
| 310 |
training_args = TrainingArguments(
|
| 311 |
output_dir=output_dir,
|
| 312 |
num_train_epochs=num_epochs,
|
| 313 |
per_device_train_batch_size=batch_size,
|
| 314 |
gradient_accumulation_steps=2,
|
| 315 |
-
warmup_steps=
|
| 316 |
logging_steps=5,
|
| 317 |
-
save_strategy="
|
| 318 |
-
fp16=
|
| 319 |
-
bf16=(dtype == torch.bfloat16),
|
| 320 |
-
learning_rate=learning_rate,
|
| 321 |
-
report_to="none",
|
| 322 |
optim="adamw_torch",
|
|
|
|
|
|
|
| 323 |
)
|
| 324 |
|
| 325 |
trainer = Trainer(
|
|
@@ -331,29 +347,22 @@ def train_model(
|
|
| 331 |
)
|
| 332 |
|
| 333 |
# ===== Train =====
|
| 334 |
-
log_message(output_log, "\n
|
| 335 |
trainer.train()
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
log_message(output_log, "\n💾 Saving fast fine-tuned model...")
|
| 339 |
-
model.save_pretrained(output_dir)
|
| 340 |
tokenizer.save_pretrained(output_dir)
|
| 341 |
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
folder_path=output_dir,
|
| 346 |
-
repo_type="model",
|
| 347 |
-
commit_message="Quick test fine-tune upload",
|
| 348 |
-
)
|
| 349 |
|
| 350 |
-
log_message(output_log, "
|
| 351 |
|
| 352 |
except Exception as e:
|
| 353 |
-
log_message(output_log, f"❌ Error: {e}")
|
| 354 |
|
| 355 |
return "\n".join(output_log)
|
| 356 |
-
|
| 357 |
# ==== Gradio Interface ====
|
| 358 |
def create_interface():
|
| 359 |
with gr.Blocks(title="PromptWizard — Qwen Trainer") as demo:
|
|
|
|
| 244 |
train_dataset, test_dataset = dataset["train"], dataset["test"]
|
| 245 |
|
| 246 |
# ===== ⚡ FAST mode: use small subset =====
|
| 247 |
+
train_dataset = train_dataset.select(range(min(1000, len(train_dataset))))
|
| 248 |
+
test_dataset = test_dataset.select(range(min(200, len(test_dataset))))
|
| 249 |
log_message(output_log, f"⚡ Using {len(train_dataset)} train / {len(test_dataset)} test samples")
|
| 250 |
|
| 251 |
# ===== Format samples =====
|
|
|
|
| 263 |
train_dataset = train_dataset.map(format_example)
|
| 264 |
test_dataset = test_dataset.map(format_example)
|
| 265 |
|
| 266 |
+
# ===== Format examples dynamically =====
|
| 267 |
+
def format_example(item):
|
| 268 |
+
text_content = item.get("text") or item.get("content") or str(item.get("path", "")) or " ".join(str(v) for v in item.values())
|
| 269 |
+
# Use shorter, clean system prompt + user content for better loss
|
| 270 |
+
prompt = (
|
| 271 |
+
f"<|system|>\nYou are an expert AI assistant.\n<|user|>\n{text_content}\n<|assistant|>\n"
|
| 272 |
+
)
|
| 273 |
+
return {"text": prompt}
|
| 274 |
+
|
| 275 |
+
train_dataset = train_dataset.map(format_example)
|
| 276 |
+
test_dataset = test_dataset.map(format_example)
|
| 277 |
+
log_message(output_log, f"✅ Formatted {len(train_dataset)} train + {len(test_dataset)} test examples")
|
| 278 |
+
|
| 279 |
+
# ===== Load model & tokenizer =====
|
| 280 |
+
log_message(output_log, f"\n🤖 Loading model: {base_model}")
|
| 281 |
tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
|
| 282 |
if tokenizer.pad_token is None:
|
| 283 |
tokenizer.pad_token = tokenizer.eos_token
|
|
|
|
| 285 |
model = AutoModelForCausalLM.from_pretrained(
|
| 286 |
base_model,
|
| 287 |
trust_remote_code=True,
|
| 288 |
+
torch_dtype=torch.float16 if device == "cuda" else torch.float32,
|
|
|
|
| 289 |
low_cpu_mem_usage=True,
|
| 290 |
)
|
| 291 |
+
if device == "cuda":
|
| 292 |
+
model = model.to(device)
|
| 293 |
+
log_message(output_log, "✅ Model and tokenizer loaded successfully")
|
| 294 |
+
log_message(output_log, f"Tokenizer vocab size: {tokenizer.vocab_size}")
|
| 295 |
|
| 296 |
+
# ===== LoRA configuration =====
|
| 297 |
+
log_message(output_log, "\n⚙️ Configuring LoRA for efficient fine-tuning...")
|
| 298 |
lora_config = LoraConfig(
|
| 299 |
task_type=TaskType.CAUSAL_LM,
|
| 300 |
+
r=8,
|
| 301 |
+
lora_alpha=16,
|
| 302 |
lora_dropout=0.1,
|
| 303 |
target_modules=["q_proj", "v_proj"],
|
| 304 |
bias="none",
|
| 305 |
)
|
| 306 |
model = get_peft_model(model, lora_config)
|
| 307 |
+
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
| 308 |
+
log_message(output_log, f"Trainable params after LoRA: {trainable_params:,}")
|
| 309 |
|
| 310 |
+
# ===== Tokenization + labels =====
|
| 311 |
def tokenize_fn(examples):
|
| 312 |
tokenized = tokenizer(
|
| 313 |
examples["text"],
|
|
|
|
| 320 |
|
| 321 |
train_dataset = train_dataset.map(tokenize_fn, batched=True)
|
| 322 |
test_dataset = test_dataset.map(tokenize_fn, batched=True)
|
| 323 |
+
log_message(output_log, "✅ Tokenization + labels done")
|
| 324 |
|
| 325 |
+
# ===== Training arguments =====
|
| 326 |
+
output_dir = "./qwen-gita-lora"
|
|
|
|
|
|
|
| 327 |
training_args = TrainingArguments(
|
| 328 |
output_dir=output_dir,
|
| 329 |
num_train_epochs=num_epochs,
|
| 330 |
per_device_train_batch_size=batch_size,
|
| 331 |
gradient_accumulation_steps=2,
|
| 332 |
+
warmup_steps=10,
|
| 333 |
logging_steps=5,
|
| 334 |
+
save_strategy="epoch",
|
| 335 |
+
fp16=device == "cuda",
|
|
|
|
|
|
|
|
|
|
| 336 |
optim="adamw_torch",
|
| 337 |
+
learning_rate=learning_rate,
|
| 338 |
+
max_steps=500, # Limit for demo is 100
|
| 339 |
)
|
| 340 |
|
| 341 |
trainer = Trainer(
|
|
|
|
| 347 |
)
|
| 348 |
|
| 349 |
# ===== Train =====
|
| 350 |
+
log_message(output_log, "\n🚀 Starting training...")
|
| 351 |
trainer.train()
|
| 352 |
+
log_message(output_log, "\n💾 Saving trained model locally...")
|
| 353 |
+
trainer.save_model(output_dir)
|
|
|
|
|
|
|
| 354 |
tokenizer.save_pretrained(output_dir)
|
| 355 |
|
| 356 |
+
# ===== Async upload =====
|
| 357 |
+
log_message(output_log, f"\n☁️ Initiating async upload to {hf_repo}")
|
| 358 |
+
start_async_upload(output_dir, hf_repo, output_log)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 359 |
|
| 360 |
+
log_message(output_log, "✅ Training complete & async upload started!")
|
| 361 |
|
| 362 |
except Exception as e:
|
| 363 |
+
log_message(output_log, f"\n❌ Error during training: {e}")
|
| 364 |
|
| 365 |
return "\n".join(output_log)
|
|
|
|
| 366 |
# ==== Gradio Interface ====
|
| 367 |
def create_interface():
|
| 368 |
with gr.Blocks(title="PromptWizard — Qwen Trainer") as demo:
|