# sd15-flow-lune / trainer.py
"""
SD15 Flow-Matching trainer
Author: AbstractPhil
Loads the current-format .pt checkpoint and runs several validations to confirm
the student weights are loaded correctly before training.
Trains a flow-matching objective on SD1.5.
License: MIT
If you use my work, a citation wouldn't hurt.
"""
import os
import json
import datetime
from dataclasses import dataclass, asdict
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
import datasets
from diffusers import UNet2DConditionModel
from huggingface_hub import HfApi, create_repo, hf_hub_download
@dataclass
class TrainConfig:
output_dir: str = "./outputs"
model_repo: str = "AbstractPhil/sd15-flow-matching-try2"
checkpoint_filename: str = "sd15_flowmatch_david_weighted_2_e34.pt"
dataset_name: str = "AbstractPhil/sd15-latent-distillation-500k"
# HuggingFace upload settings
hf_repo_id: str = "AbstractPhil/sd15-flow-lune"
upload_to_hub: bool = True
seed: int = 42
batch_size: int = 16
base_lr: float = 2e-6
shift: float = 2.0
dropout: float = 0.1
max_train_steps: int = 50_000
checkpointing_steps: int = 1000
num_workers: int = 0
    # VAE scaling factor: raw latents are multiplied by this in collate_fn
vae_scale: float = 0.18215
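
# Example override (hypothetical values) for a quick local smoke test:
#   config = TrainConfig(batch_size=4, max_train_steps=100, upload_to_hub=False)
#   train(config)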
def load_student_unet(repo_id: str, filename: str, device="cuda") -> UNet2DConditionModel:
"""Load UNet from .pt checkpoint containing student state_dict"""
# Download checkpoint from HuggingFace
print(f"Downloading checkpoint from {repo_id}/{filename}...")
checkpoint_path = hf_hub_download(
repo_id=repo_id,
filename=filename,
repo_type="model"
)
print(f"✓ Downloaded to: {checkpoint_path}")
checkpoint = torch.load(checkpoint_path, map_location="cpu")
# Initialize UNet with SD1.5 config in fp32
print("Loading SD1.5 UNet architecture...")
unet = UNet2DConditionModel.from_pretrained(
"runwayml/stable-diffusion-v1-5",
subfolder="unet",
torch_dtype=torch.float32
)
# Get original state for comparison
original_state_dict = {k: v.clone() for k, v in unet.state_dict().items()}
# Load student weights and strip "unet." prefix
student_state_dict = checkpoint["student"]
# Strip prefix if present
cleaned_student_dict = {}
for key, value in student_state_dict.items():
if key.startswith("unet."):
cleaned_key = key[5:] # Remove "unet." prefix
cleaned_student_dict[cleaned_key] = value
else:
cleaned_student_dict[key] = value
print(f"\n{'='*70}")
print("WEIGHT VERIFICATION")
print(f"{'='*70}")
# 1. Compare keys
original_keys = set(original_state_dict.keys())
student_keys = set(cleaned_student_dict.keys())
matching_keys = original_keys & student_keys
print(f"Original UNet keys: {len(original_keys)}")
print(f"Student checkpoint keys: {len(student_keys)}")
print(f"Matching keys: {len(matching_keys)}")
# 2. Compare student weights vs original BEFORE loading
total_params = 0
different_params = 0
mean_diff_sum = 0.0
max_diff = 0.0
for key in matching_keys:
if key not in original_state_dict or key not in cleaned_student_dict:
continue
orig = original_state_dict[key]
student = cleaned_student_dict[key].float() # Convert to fp32 for comparison
if orig.shape != student.shape:
print(f"⚠ Shape mismatch for {key}: {orig.shape} vs {student.shape}")
continue
total_params += orig.numel()
# Check if weights are different
diff = (orig - student).abs()
if diff.max() > 1e-6:
different_params += orig.numel()
mean_diff_sum += diff.sum().item()
max_diff = max(max_diff, diff.max().item())
pct_different = (different_params / total_params * 100) if total_params > 0 else 0
avg_diff = mean_diff_sum / different_params if different_params > 0 else 0
print(f"\nStudent vs Original (BEFORE loading):")
print(f" Total parameters: {total_params:,}")
print(f" Parameters different: {different_params:,} ({pct_different:.1f}%)")
print(f" Average difference: {avg_diff:.6f}")
print(f" Max difference: {max_diff:.6f}")
# 3. Load weights
load_result = unet.load_state_dict(cleaned_student_dict, strict=False)
if load_result.missing_keys:
print(f"\n⚠ Missing keys during load: {len(load_result.missing_keys)}")
for key in load_result.missing_keys[:3]:
print(f" - {key}")
if load_result.unexpected_keys:
print(f"⚠ Unexpected keys during load: {len(load_result.unexpected_keys)}")
for key in load_result.unexpected_keys[:3]:
print(f" - {key}")
# 4. Verify weights actually changed after loading
loaded_state_dict = unet.state_dict()
total_params_after = 0
changed_params = 0
mean_diff_after = 0.0
max_diff_after = 0.0
for key in matching_keys:
if key not in original_state_dict or key not in loaded_state_dict:
continue
orig = original_state_dict[key]
loaded = loaded_state_dict[key]
total_params_after += orig.numel()
diff = (orig - loaded).abs()
if diff.max() > 1e-6:
changed_params += orig.numel()
mean_diff_after += diff.sum().item()
max_diff_after = max(max_diff_after, diff.max().item())
pct_changed = (changed_params / total_params_after * 100) if total_params_after > 0 else 0
avg_diff_after = mean_diff_after / changed_params if changed_params > 0 else 0
print(f"\nOriginal vs Loaded (AFTER loading):")
print(f" Parameters changed: {changed_params:,} ({pct_changed:.1f}%)")
print(f" Average difference: {avg_diff_after:.6f}")
print(f" Max difference: {max_diff_after:.6f}")
print(f"\n{'='*70}")
# Verification checks
if pct_different < 50:
print(f"⚠️ WARNING: Student weights only {pct_different:.1f}% different from base!")
print(" This checkpoint may not be trained.")
elif pct_changed < 90:
print(f"⚠️ WARNING: Only {pct_changed:.1f}% of weights changed after loading!")
print(" The load may have failed.")
else:
print(f"✅ Weights loaded successfully!")
print(f" Checkpoint step: {checkpoint.get('gstep', 'unknown')}")
print(f" {pct_different:.1f}% of weights differ from base SD1.5")
print(f"{'='*70}\n")
return unet.to(device)
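
# Example (standalone use, weights only): load the distilled student UNet for
# inspection without training. Repo and filename below are the TrainConfig defaults.
#   unet = load_student_unet(
#       "AbstractPhil/sd15-flow-matching-try2",
#       "sd15_flowmatch_david_weighted_2_e34.pt",
#       device="cpu",
#   )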
def train(config: TrainConfig):
device = "cuda"
torch.backends.cuda.matmul.allow_tf32 = True
torch.manual_seed(config.seed)
torch.cuda.manual_seed(config.seed)
# Setup output directory
date_time = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
real_output_dir = os.path.join(config.output_dir, date_time)
os.makedirs(real_output_dir, exist_ok=True)
t_writer = SummaryWriter(log_dir=real_output_dir, flush_secs=60)
# Initialize HuggingFace API
hf_api = None
if config.upload_to_hub:
try:
hf_api = HfApi()
create_repo(
repo_id=config.hf_repo_id,
repo_type="model",
exist_ok=True,
private=False
)
print(f"✓ HuggingFace repo ready: {config.hf_repo_id}")
except Exception as e:
print(f"⚠ Hub upload disabled: {e}")
config.upload_to_hub = False
# Save config locally and to hub
config_path = os.path.join(real_output_dir, "config.json")
with open(config_path, "w") as f:
json.dump(asdict(config), f, indent=2)
if config.upload_to_hub:
hf_api.upload_file(
path_or_fileobj=config_path,
path_in_repo="config.json",
repo_id=config.hf_repo_id,
repo_type="model"
)
# Load dataset in streaming mode
print(f"\nLoading dataset (streaming): {config.dataset_name}")
train_dataset = datasets.load_dataset(
config.dataset_name,
split="train",
streaming=True,
trust_remote_code=True
)
train_dataset = train_dataset.shuffle(seed=config.seed, buffer_size=1000)
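    # Note: streaming shuffle is approximate - it fills a buffer of `buffer_size`
    # examples and samples from it, so ordering is only locally randomized.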
print(f"✓ Dataset loaded in streaming mode")
def collate_fn(examples):
# Latents are RAW from VAE - need to scale them
latents = torch.stack([torch.tensor(ex["latent"]) for ex in examples])
latents = latents * config.vae_scale # Scale: ~[-6, 6] -> ~[-1, 1]
clip_embeddings = torch.stack([torch.tensor(ex["clip_embedding"]) for ex in examples])
ids = [ex["id"] for ex in examples]
prompts = [ex["prompt"] for ex in examples]
return latents, clip_embeddings, ids, prompts
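    # Expected per-batch shapes (assuming standard SD1.5 latent-distillation data):
    # latents [B, 4, 64, 64], CLIP hidden states [B, 77, 768]; ids and prompts are
    # carried along only for logging.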
train_dataloader = DataLoader(
dataset=train_dataset,
batch_size=config.batch_size,
collate_fn=collate_fn,
num_workers=config.num_workers,
)
    # Sanity-check latent scaling on the first batch (printed before and after GPU transfer)
print("\nVerifying latent scaling on first batch...")
first_batch = next(iter(train_dataloader))
latents_check, _, _, _ = first_batch
print(f"Raw latent range: [{latents_check.min():.3f}, {latents_check.max():.3f}]")
latents_check = latents_check.to(device)
print(f"After GPU transfer: [{latents_check.min():.3f}, {latents_check.max():.3f}]")
print(f"Expected: ~[-1, 1] for properly scaled latents")
del latents_check
# Load pretrained student UNet
print(f"\nLoading model from HuggingFace...")
unet = load_student_unet(config.model_repo, config.checkpoint_filename, device=device)
unet.requires_grad_(True)
unet.enable_gradient_checkpointing()
unet.train()
optimizer = torch.optim.Adam(
unet.parameters(),
lr=config.base_lr * (config.batch_size ** 0.5),
)
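    # Square-root LR scaling: at batch_size=16 the effective learning rate is
    # base_lr * sqrt(16) = 2e-6 * 4 = 8e-6.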
global_step = 0
train_logs = {
"train_step": [],
"train_loss": [],
"train_timestep": [],
"trained_images": []
}
def get_prediction(batch, log_to=None):
latents, encoder_hidden_states, ids, prompts = batch
# Everything in fp32
latents = latents.to(dtype=torch.float32, device=device)
encoder_hidden_states = encoder_hidden_states.to(dtype=torch.float32, device=device)
batch_size = latents.shape[0]
# Apply dropout to conditioning for CFG support
dropout_mask = torch.rand(batch_size, device=device) < config.dropout
encoder_hidden_states = encoder_hidden_states.clone()
encoder_hidden_states[dropout_mask] = 0
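        # Zeroing the text conditioning on ~config.dropout of the samples trains the
        # unconditional branch needed for classifier-free guidance at inference time.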
# Sample timesteps with shift
sigmas = torch.rand(batch_size, device=device)
sigmas = (config.shift * sigmas) / (1 + (config.shift - 1) * sigmas)
timesteps = sigmas * 1000
sigmas = sigmas[:, None, None, None]
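        # Shift mapping: sigma = shift*u / (1 + (shift-1)*u) with u ~ U(0,1). With
        # shift=2 the median sigma rises from 0.5 to ~0.67, biasing training toward
        # noisier timesteps; timesteps are simply sigma scaled to [0, 1000].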
# Flow matching forward process
noise = torch.randn_like(latents)
noisy_latents = noise * sigmas + latents * (1 - sigmas)
target = noise - latents
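        # Interpolant x_sigma = (1 - sigma) * x0 + sigma * noise, so the velocity
        # d x_sigma / d sigma = noise - x0 is the regression target below.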
# Predict velocity
pred = unet(noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0]
loss = F.mse_loss(pred, target, reduction="none")
loss = loss.mean(dim=list(range(1, len(loss.shape))))
if log_to is not None:
for i in range(batch_size):
log_to["train_step"].append(global_step)
log_to["train_loss"].append(loss[i].item())
log_to["train_timestep"].append(timesteps[i].item())
log_to["trained_images"].append({
"step": global_step,
"id": ids[i],
"prompt": prompts[i]
})
return loss.mean()
def plot_logs(log_dict):
plt.figure(figsize=(10, 6))
plt.scatter(
log_dict["train_timestep"],
log_dict["train_loss"],
s=3,
c=log_dict["train_step"],
marker=".",
cmap='cool'
)
plt.xlabel("timestep")
plt.ylabel("loss")
plt.yscale("log")
plt.colorbar(label="step")
def save_checkpoint(step):
checkpoint_path = os.path.join(real_output_dir, f"checkpoint-{step:08}")
os.makedirs(checkpoint_path, exist_ok=True)
# Save UNet weights as diffusers format
unet.save_pretrained(
os.path.join(checkpoint_path, "unet"),
safe_serialization=True
)
# Save complete checkpoint in .pt format
pt_filename = f"sd15_flow_lune_e{step//1000}_s{step}.pt"
pt_path = os.path.join(checkpoint_path, pt_filename)
torch.save({
"cfg": asdict(config),
"student": unet.state_dict(),
"opt": optimizer.state_dict(),
"gstep": step
}, pt_path)
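        # The .pt layout ("cfg"/"student"/"opt"/"gstep") mirrors what load_student_unet
        # expects, so these files can be reloaded with the same loader (weights only;
        # the optimizer state is saved here but not restored by that loader).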
# Save training metadata
metadata = {
"step": step,
"trained_images": train_logs["trained_images"]
}
metadata_path = os.path.join(checkpoint_path, "trained_images.json")
with open(metadata_path, "w") as f:
json.dump(metadata, f, indent=2)
print(f"✓ Checkpoint saved at step {step}")
# Upload to HuggingFace Hub
if config.upload_to_hub and hf_api is not None:
try:
hf_api.upload_file(
path_or_fileobj=pt_path,
path_in_repo=pt_filename,
repo_id=config.hf_repo_id,
repo_type="model"
)
hf_api.upload_folder(
folder_path=os.path.join(checkpoint_path, "unet"),
path_in_repo=f"checkpoint-{step:08}/unet",
repo_id=config.hf_repo_id,
repo_type="model"
)
hf_api.upload_file(
path_or_fileobj=metadata_path,
path_in_repo=f"checkpoint-{step:08}/trained_images.json",
repo_id=config.hf_repo_id,
repo_type="model"
)
print(f"✓ Uploaded to hub: {config.hf_repo_id}")
except Exception as e:
print(f"⚠ Upload failed: {e}")
print("\nStarting training...")
    progress_bar = tqdm(total=config.max_train_steps)
for batch in train_dataloader:
loss = get_prediction(batch, log_to=train_logs)
t_writer.add_scalar("train/loss", loss.detach().item(), global_step)
loss.backward()
grad_norm = torch.nn.utils.clip_grad_norm_(unet.parameters(), 2.0)
t_writer.add_scalar("train/grad_norm", grad_norm.detach().item(), global_step)
optimizer.step()
optimizer.zero_grad()
progress_bar.update(1)
progress_bar.set_postfix({"loss": f"{loss.item():.4f}"})
global_step += 1
if global_step % 100 == 0:
plot_logs(train_logs)
t_writer.add_figure("train_loss", plt.gcf(), global_step)
plt.close()
if global_step % config.checkpointing_steps == 0:
save_checkpoint(global_step)
        if global_step >= config.max_train_steps:
            break

    # Reached max_train_steps or exhausted the streaming dataset: save the final state.
    save_checkpoint(global_step)
    print("\n✅ Training complete!")
    return
if __name__ == "__main__":
config = TrainConfig()
train(config)