|
|
""" |
|
|
SD15 Flow-Matching trainer |
|
|
Author: AbstractPhil |
|
|
|
|
|
Loads a checkpoint in the current .pt format and runs several validation checks to confirm the weights were loaded correctly before training. |
|
|
|
|
|
Trains flow matching for sd15. |
|
|
|
|
|
License: MIT |
|
|
If you use my work, a citation wouldn't hurt. |
|
|
|
|
|
""" |
|
|
|
|
|
import os |
|
|
import json |
|
|
import datetime |
|
|
from dataclasses import dataclass, asdict |
|
|
from tqdm.auto import tqdm |
|
|
import matplotlib.pyplot as plt |
|
|
|
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
from torch.utils.tensorboard import SummaryWriter |
|
|
from torch.utils.data import DataLoader |
|
|
|
|
|
import datasets |
|
|
from diffusers import UNet2DConditionModel |
|
|
from huggingface_hub import HfApi, create_repo, hf_hub_download |
|
|
|
|
|
|
|
|
@dataclass
class TrainConfig:
    """Settings for one SD1.5 flow-matching training run.

    Every field has a default, so ``TrainConfig()`` is a complete, runnable
    configuration. The whole object is serialised with ``asdict`` into the
    run's ``config.json`` and into every ``.pt`` checkpoint.
    """

    # ---- Where things are read from and written to ----
    # Root output folder; each run writes into a timestamped subdirectory.
    output_dir: str = "./outputs"
    # Hub repo and file name of the starting checkpoint.
    model_repo: str = "AbstractPhil/sd15-flow-matching-try2"
    checkpoint_filename: str = "sd15_flowmatch_david_weighted_2_e34.pt"
    # Hub dataset that is streamed during training.
    dataset_name: str = "AbstractPhil/sd15-latent-distillation-500k"

    # ---- Hub publishing ----
    # Destination model repo for config and checkpoints.
    hf_repo_id: str = "AbstractPhil/sd15-flow-lune"
    # When False nothing is uploaded; also switched off if repo setup fails.
    upload_to_hub: bool = True

    # ---- Optimisation ----
    # Seed for torch RNGs and the dataset shuffle buffer.
    seed: int = 42
    batch_size: int = 16
    # Multiplied by sqrt(batch_size) to obtain the optimizer learning rate.
    base_lr: float = 2e-6
    # Shift applied to the uniformly sampled sigma:
    # shift * s / (1 + (shift - 1) * s).
    shift: float = 2.0
    # Per-sample probability of zeroing the text conditioning.
    dropout: float = 0.1

    # ---- Schedule ----
    max_train_steps: int = 50_000
    # A checkpoint is written every this many optimizer steps.
    checkpointing_steps: int = 1000
    # Passed straight to the DataLoader.
    num_workers: int = 0

    # ---- Latent scaling ----
    # Factor the dataset latents are multiplied by when a batch is collated.
    vae_scale: float = 0.18215
|
|
|
|
|
|
|
|
def _state_dict_diff_stats(reference, candidate, keys, tol=1e-6):
    """Summarise how far ``candidate`` tensors are from ``reference`` tensors.

    A tensor counts as "different" as a whole (all of its elements are added
    to the differing count) when its largest absolute element-wise difference
    exceeds ``tol``.

    Args:
        reference: Mapping of key -> tensor used as the baseline.
        candidate: Mapping of key -> tensor compared against the baseline;
            values are cast to float32 before comparison.
        keys: Iterable of keys present in both mappings.
        tol: Absolute tolerance above which a tensor counts as different.

    Returns:
        Tuple ``(total, differing, abs_sum, max_abs)``: elements compared,
        elements belonging to differing tensors, summed absolute difference
        over differing tensors, and the largest absolute difference seen.
    """
    total = 0
    differing = 0
    abs_sum = 0.0
    max_abs = 0.0

    for key in keys:
        ref = reference[key]
        cand = candidate[key].float()

        if ref.shape != cand.shape:
            print(f"⚠ Shape mismatch for {key}: {ref.shape} vs {cand.shape}")
            continue

        total += ref.numel()
        if ref.numel() == 0:
            # max() on an empty tensor raises; there is nothing to compare.
            continue

        diff = (ref - cand).abs()
        peak = diff.max()
        if peak > tol:
            differing += ref.numel()
            abs_sum += diff.sum().item()
            max_abs = max(max_abs, peak.item())

    return total, differing, abs_sum, max_abs


def load_student_unet(repo_id: str, filename: str, device="cuda") -> UNet2DConditionModel:
    """Load an SD1.5 UNet and overwrite its weights from a ``.pt`` checkpoint.

    The checkpoint is downloaded from the Hub and must be a mapping with a
    ``"student"`` state dict (keys may carry a ``"unet."`` prefix, which is
    stripped). The base ``runwayml/stable-diffusion-v1-5`` UNet supplies the
    architecture. Weight statistics are printed before and after loading so
    an untrained or silently unloaded checkpoint is visible in the log.

    Args:
        repo_id: Hub model repo that hosts the checkpoint.
        filename: Checkpoint file name inside ``repo_id``.
        device: Device the returned UNet is moved to.

    Returns:
        The UNet with the checkpoint weights loaded, on ``device``.

    Raises:
        KeyError: If the checkpoint has no ``"student"`` entry.
    """
    print(f"Downloading checkpoint from {repo_id}/{filename}...")
    checkpoint_path = hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        repo_type="model",
    )
    print(f"✓ Downloaded to: {checkpoint_path}")

    # NOTE(security): torch.load unpickles the file. Depending on the installed
    # torch version this may execute arbitrary code from the checkpoint, so
    # only point this at repositories you trust.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")

    # Fail before the (slow) base-model load if the checkpoint layout is wrong.
    try:
        student_state_dict = checkpoint["student"]
    except KeyError as err:
        raise KeyError(
            f"Checkpoint {filename!r} has no 'student' state dict; "
            f"available keys: {list(checkpoint)}"
        ) from err

    print("Loading SD1.5 UNet architecture...")
    unet = UNet2DConditionModel.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        subfolder="unet",
        torch_dtype=torch.float32,
    )

    # Snapshot of the base weights; state_dict() tensors alias the live
    # parameters, so they must be cloned before the checkpoint is loaded.
    original_state_dict = {k: v.clone() for k, v in unet.state_dict().items()}

    # Strip the optional "unet." wrapper prefix from checkpoint keys.
    prefix = "unet."
    cleaned_student_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in student_state_dict.items()
    }

    print(f"\n{'='*70}")
    print("WEIGHT VERIFICATION")
    print(f"{'='*70}")

    original_keys = set(original_state_dict)
    student_keys = set(cleaned_student_dict)
    # Sorted so warnings and float accumulation order are reproducible.
    matching_keys = sorted(original_keys & student_keys)

    print(f"Original UNet keys: {len(original_keys)}")
    print(f"Student checkpoint keys: {len(student_keys)}")
    print(f"Matching keys: {len(matching_keys)}")

    # Pass 1: checkpoint weights versus the untouched base weights.
    total_params, different_params, mean_diff_sum, max_diff = _state_dict_diff_stats(
        original_state_dict, cleaned_student_dict, matching_keys
    )
    pct_different = (different_params / total_params * 100) if total_params > 0 else 0
    avg_diff = mean_diff_sum / different_params if different_params > 0 else 0

    print("\nStudent vs Original (BEFORE loading):")
    print(f" Total parameters: {total_params:,}")
    print(f" Parameters different: {different_params:,} ({pct_different:.1f}%)")
    print(f" Average difference: {avg_diff:.6f}")
    print(f" Max difference: {max_diff:.6f}")

    load_result = unet.load_state_dict(cleaned_student_dict, strict=False)

    if load_result.missing_keys:
        print(f"\n⚠ Missing keys during load: {len(load_result.missing_keys)}")
        for key in load_result.missing_keys[:3]:
            print(f" - {key}")

    if load_result.unexpected_keys:
        print(f"⚠ Unexpected keys during load: {len(load_result.unexpected_keys)}")
        for key in load_result.unexpected_keys[:3]:
            print(f" - {key}")

    # Pass 2: weights now inside the model versus the base snapshot.
    total_params_after, changed_params, mean_diff_after, max_diff_after = (
        _state_dict_diff_stats(original_state_dict, unet.state_dict(), matching_keys)
    )
    pct_changed = (changed_params / total_params_after * 100) if total_params_after > 0 else 0
    avg_diff_after = mean_diff_after / changed_params if changed_params > 0 else 0

    print("\nOriginal vs Loaded (AFTER loading):")
    print(f" Parameters changed: {changed_params:,} ({pct_changed:.1f}%)")
    print(f" Average difference: {avg_diff_after:.6f}")
    print(f" Max difference: {max_diff_after:.6f}")

    print(f"\n{'='*70}")

    # Heuristic verdict; these are warnings only, loading is never aborted.
    if pct_different < 50:
        print(f"⚠️ WARNING: Student weights only {pct_different:.1f}% different from base!")
        print(" This checkpoint may not be trained.")
    elif pct_changed < 90:
        print(f"⚠️ WARNING: Only {pct_changed:.1f}% of weights changed after loading!")
        print(" The load may have failed.")
    else:
        print("✅ Weights loaded successfully!")
        print(f" Checkpoint step: {checkpoint.get('gstep', 'unknown')}")
        print(f" {pct_different:.1f}% of weights differ from base SD1.5")

    print(f"{'='*70}\n")

    return unet.to(device)
|
|
|
|
|
|
|
|
def train(config: TrainConfig):
    """Fine-tune the student UNet with a flow-matching objective.

    Creates a timestamped run directory under ``config.output_dir``, streams
    the latent dataset, loads the UNet via ``load_student_unet`` and trains
    with Adam until ``config.max_train_steps`` optimizer steps are done. The
    data stream is re-iterated (with a new shuffle epoch) whenever it runs
    out before the step budget is reached. Loss, gradient norm and a
    loss-vs-timestep scatter plot go to TensorBoard. Checkpoints are written
    every ``config.checkpointing_steps`` steps and once more at the end, and
    are uploaded to the Hub when enabled.

    Args:
        config: Run settings. ``config.upload_to_hub`` is set to ``False`` in
            place if the Hub repository cannot be prepared.

    Returns:
        None.
    """
    device = "cuda"
    torch.backends.cuda.matmul.allow_tf32 = True

    torch.manual_seed(config.seed)
    torch.cuda.manual_seed(config.seed)

    # Each run gets its own timestamped directory for logs and checkpoints.
    date_time = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    real_output_dir = os.path.join(config.output_dir, date_time)
    os.makedirs(real_output_dir, exist_ok=True)
    t_writer = SummaryWriter(log_dir=real_output_dir, flush_secs=60)

    # Hub setup is best-effort: any failure disables uploads for this run.
    hf_api = None
    if config.upload_to_hub:
        try:
            hf_api = HfApi()
            create_repo(
                repo_id=config.hf_repo_id,
                repo_type="model",
                exist_ok=True,
                private=False,
            )
            print(f"✓ HuggingFace repo ready: {config.hf_repo_id}")
        except Exception as e:
            print(f"⚠ Hub upload disabled: {e}")
            config.upload_to_hub = False

    config_path = os.path.join(real_output_dir, "config.json")
    with open(config_path, "w") as f:
        json.dump(asdict(config), f, indent=2)

    if config.upload_to_hub and hf_api is not None:
        # Guarded like the checkpoint uploads so a network hiccup cannot
        # abort the run before training starts.
        try:
            hf_api.upload_file(
                path_or_fileobj=config_path,
                path_in_repo="config.json",
                repo_id=config.hf_repo_id,
                repo_type="model",
            )
        except Exception as e:
            print(f"⚠ Upload failed: {e}")

    print(f"\nLoading dataset (streaming): {config.dataset_name}")
    # NOTE(security): trust_remote_code=True lets the dataset repository run
    # its own loading script; only use dataset repos you trust.
    train_dataset = datasets.load_dataset(
        config.dataset_name,
        split="train",
        streaming=True,
        trust_remote_code=True,
    )
    train_dataset = train_dataset.shuffle(seed=config.seed, buffer_size=1000)
    print("✓ Dataset loaded in streaming mode")

    def collate_fn(examples):
        """Stack one batch and scale the latents by ``config.vae_scale``."""
        latents = torch.stack([torch.tensor(ex["latent"]) for ex in examples])
        latents = latents * config.vae_scale

        clip_embeddings = torch.stack(
            [torch.tensor(ex["clip_embedding"]) for ex in examples]
        )
        ids = [ex["id"] for ex in examples]
        prompts = [ex["prompt"] for ex in examples]

        return latents, clip_embeddings, ids, prompts

    train_dataloader = DataLoader(
        dataset=train_dataset,
        batch_size=config.batch_size,
        collate_fn=collate_fn,
        num_workers=config.num_workers,
    )

    # Sanity check on one batch. The values printed here have already been
    # multiplied by vae_scale inside collate_fn.
    print("\nVerifying latent scaling on first batch...")
    first_batch = next(iter(train_dataloader))
    latents_check, _, _, _ = first_batch
    print(f"Raw latent range: [{latents_check.min():.3f}, {latents_check.max():.3f}]")
    latents_check = latents_check.to(device)
    print(f"After GPU transfer: [{latents_check.min():.3f}, {latents_check.max():.3f}]")
    print("Expected: ~[-1, 1] for properly scaled latents")
    del latents_check

    print("\nLoading model from HuggingFace...")
    unet = load_student_unet(config.model_repo, config.checkpoint_filename, device=device)
    unet.requires_grad_(True)
    unet.enable_gradient_checkpointing()
    unet.train()

    optimizer = torch.optim.Adam(
        unet.parameters(),
        # Square-root scaling of the base learning rate with batch size.
        lr=config.base_lr * (config.batch_size ** 0.5),
    )

    global_step = 0
    train_logs = {
        "train_step": [],
        "train_loss": [],
        "train_timestep": [],
        "trained_images": [],
    }

    def get_prediction(batch, log_to=None):
        """Run one forward pass and return the batch-mean flow-matching loss.

        When ``log_to`` is given, per-sample loss, timestep, id and prompt
        are appended to it under the current ``global_step``.
        """
        latents, encoder_hidden_states, ids, prompts = batch

        latents = latents.to(dtype=torch.float32, device=device)
        encoder_hidden_states = encoder_hidden_states.to(dtype=torch.float32, device=device)

        batch_size = latents.shape[0]

        # Conditioning dropout: zero the text embedding of random samples.
        dropout_mask = torch.rand(batch_size, device=device) < config.dropout
        encoder_hidden_states = encoder_hidden_states.clone()
        encoder_hidden_states[dropout_mask] = 0

        # Uniform sigma in [0, 1), warped by the configured shift.
        sigmas = torch.rand(batch_size, device=device)
        sigmas = (config.shift * sigmas) / (1 + (config.shift - 1) * sigmas)
        timesteps = sigmas * 1000
        sigmas = sigmas[:, None, None, None]

        # Linear interpolation between data and noise; the regression target
        # is the difference noise - latents.
        noise = torch.randn_like(latents)
        noisy_latents = noise * sigmas + latents * (1 - sigmas)
        target = noise - latents

        pred = unet(noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0]

        # Per-sample loss: mean over every non-batch dimension.
        loss = F.mse_loss(pred, target, reduction="none")
        loss = loss.mean(dim=list(range(1, len(loss.shape))))

        if log_to is not None:
            # One host transfer per tensor instead of one .item() per sample.
            sample_losses = loss.detach().tolist()
            sample_timesteps = timesteps.tolist()
            for sample_loss, sample_t, sample_id, prompt in zip(
                sample_losses, sample_timesteps, ids, prompts
            ):
                log_to["train_step"].append(global_step)
                log_to["train_loss"].append(sample_loss)
                log_to["train_timestep"].append(sample_t)
                log_to["trained_images"].append({
                    "step": global_step,
                    "id": sample_id,
                    "prompt": prompt,
                })

        return loss.mean()

    def plot_logs(log_dict):
        """Draw a loss-vs-timestep scatter (coloured by step) on a new figure."""
        plt.figure(figsize=(10, 6))
        plt.scatter(
            log_dict["train_timestep"],
            log_dict["train_loss"],
            s=3,
            c=log_dict["train_step"],
            marker=".",
            cmap='cool',
        )
        plt.xlabel("timestep")
        plt.ylabel("loss")
        plt.yscale("log")
        plt.colorbar(label="step")

    def save_checkpoint(step):
        """Write UNet, ``.pt`` bundle and sample log for ``step``; upload if enabled."""
        checkpoint_path = os.path.join(real_output_dir, f"checkpoint-{step:08}")
        os.makedirs(checkpoint_path, exist_ok=True)

        # Diffusers-format weights (safetensors).
        unet.save_pretrained(
            os.path.join(checkpoint_path, "unet"),
            safe_serialization=True,
        )

        # Single-file bundle using the same keys load_student_unet reads
        # ("student", "gstep").
        pt_filename = f"sd15_flow_lune_e{step//1000}_s{step}.pt"
        pt_path = os.path.join(checkpoint_path, pt_filename)

        torch.save({
            "cfg": asdict(config),
            "student": unet.state_dict(),
            "opt": optimizer.state_dict(),
            "gstep": step,
        }, pt_path)

        metadata = {
            "step": step,
            "trained_images": train_logs["trained_images"],
        }
        metadata_path = os.path.join(checkpoint_path, "trained_images.json")
        with open(metadata_path, "w") as f:
            json.dump(metadata, f, indent=2)

        print(f"✓ Checkpoint saved at step {step}")

        if config.upload_to_hub and hf_api is not None:
            # Uploads are best-effort: a failure must not stop training.
            try:
                hf_api.upload_file(
                    path_or_fileobj=pt_path,
                    path_in_repo=pt_filename,
                    repo_id=config.hf_repo_id,
                    repo_type="model",
                )

                hf_api.upload_folder(
                    folder_path=os.path.join(checkpoint_path, "unet"),
                    path_in_repo=f"checkpoint-{step:08}/unet",
                    repo_id=config.hf_repo_id,
                    repo_type="model",
                )

                hf_api.upload_file(
                    path_or_fileobj=metadata_path,
                    path_in_repo=f"checkpoint-{step:08}/trained_images.json",
                    repo_id=config.hf_repo_id,
                    repo_type="model",
                )

                print(f"✓ Uploaded to hub: {config.hf_repo_id}")
            except Exception as e:
                print(f"⚠ Upload failed: {e}")

    print("\nStarting training...")
    progress_bar = tqdm(total=config.max_train_steps)
    last_saved_step = None
    epoch = 0

    try:
        # The stream can end before max_train_steps is reached, so keep
        # re-iterating it until the step budget is used up.
        while global_step < config.max_train_steps:
            if hasattr(train_dataset, "set_epoch"):
                # Changes the effective shuffle seed so passes differ.
                train_dataset.set_epoch(epoch)
            step_at_epoch_start = global_step

            for batch in train_dataloader:
                loss = get_prediction(batch, log_to=train_logs)
                loss_value = loss.detach().item()
                t_writer.add_scalar("train/loss", loss_value, global_step)

                loss.backward()

                grad_norm = torch.nn.utils.clip_grad_norm_(unet.parameters(), 2.0)
                t_writer.add_scalar("train/grad_norm", grad_norm.detach().item(), global_step)

                optimizer.step()
                optimizer.zero_grad()

                progress_bar.update(1)
                progress_bar.set_postfix({"loss": f"{loss_value:.4f}"})
                global_step += 1

                if global_step % 100 == 0:
                    plot_logs(train_logs)
                    t_writer.add_figure("train_loss", plt.gcf(), global_step)
                    plt.close()

                if (
                    config.checkpointing_steps > 0
                    and global_step % config.checkpointing_steps == 0
                ):
                    save_checkpoint(global_step)
                    last_saved_step = global_step

                if global_step >= config.max_train_steps:
                    break

            if global_step == step_at_epoch_start:
                # An empty pass would otherwise loop forever.
                print("⚠ Dataset yielded no batches; stopping training early.")
                break
            epoch += 1

        # Always keep the final weights, but do not save the same step twice.
        if global_step > 0 and last_saved_step != global_step:
            save_checkpoint(global_step)

        if global_step >= config.max_train_steps:
            print("\n✅ Training complete!")
    finally:
        # Flush pending TensorBoard events and release the progress bar even
        # when training is interrupted.
        progress_bar.close()
        t_writer.close()
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: launch a run with the default TrainConfig values.
    config = TrainConfig()
    train(config=config)