from typing import Tuple
from pathlib import Path
from tqdm import tqdm
from accelerate import Accelerator
from accelerate.utils import set_seed
import numpy as np
import matplotlib.pyplot as plt
import argparse
import json
import wandb
from datasets import load_dataset
import torch
from torch.utils.data import DataLoader
from models.unet import UNet
from dataset import SegmentationDataset, collate_fn
from utils import get_transform, mask_transform, EMA
from get_loss import get_composite_criterion
from models.vit import ViTSegmentation
from models.dino import DINOSegmentationModel
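
# The local imports (models.*, dataset, utils, get_loss) assume this script
# lives at the root of the training repo next to those modules.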

# Fixed 18-color palette ('tab20' resampled to 18 entries) so class colors stay
# consistent across all logged images. plt.get_cmap is used here because
# matplotlib.cm.get_cmap was removed in matplotlib 3.9.
color_map = plt.get_cmap('tab20', 18)
fixed_colors = np.array([color_map(i)[:3] for i in range(18)]) * 255
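

# Convert an (H, W) integer label mask to an (H, W, 3) uint8 RGB image using
# the fixed palette above, keeping class colors stable across W&B image logs.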
def mask_to_color(mask: np.ndarray) -> np.ndarray:
    h, w = mask.shape
    color_mask = np.zeros((h, w, 3), dtype=np.uint8)
    for class_idx in range(18):
        color_mask[mask == class_idx] = fixed_colors[class_idx]
    return color_mask
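
# Note: the per-class loop above could equivalently be replaced by a single
# fancy-indexing lookup, fixed_colors.astype(np.uint8)[mask], assuming every
# value in `mask` is a valid class index below 18.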


def create_combined_image(
    x: torch.Tensor,
    y: torch.Tensor,
    y_pred: torch.Tensor,
    mean: list[float] = [0.485, 0.456, 0.406],
    std: list[float] = [0.229, 0.224, 0.225]
) -> np.ndarray:
    # Stack three rows per batch: input image, ground-truth mask, predicted mask.
    batch_size, _, height, width = x.shape
    combined_height = height * 3
    combined_width = width * batch_size
    combined_image = np.zeros((combined_height, combined_width, 3), dtype=np.uint8)
    for i in range(batch_size):
        # Undo the ImageNet normalization before converting back to uint8.
        image = x[i].cpu().permute(1, 2, 0).numpy()
        image = (image * std + mean).clip(0, 1)
        image = (image * 255).astype(np.uint8)
        true_mask = y[i].cpu().numpy()
        true_mask_color = mask_to_color(true_mask)
        pred_mask = y_pred[i].cpu().numpy()
        pred_mask_color = mask_to_color(pred_mask)
        combined_image[:height, i * width:(i + 1) * width, :] = image
        combined_image[height:2 * height, i * width:(i + 1) * width, :] = true_mask_color
        combined_image[2 * height:, i * width:(i + 1) * width, :] = pred_mask_color
    return combined_image
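

# One-hot expand predictions and targets per class, then reduce TP/FP/FN/TN
# counts over the batch and spatial dims. Macro scores average per-class values
# (precision/recall only over classes present in the target); micro scores pool
# the counts across classes first.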
def compute_metrics(y_pred: torch.Tensor, y: torch.Tensor, num_classes: int = 18) -> Tuple[float, float, float, float, float, float]:
    pred_mask = y_pred.unsqueeze(-1) == torch.arange(num_classes, device=y_pred.device).reshape(1, 1, 1, -1)
    target_mask = y.unsqueeze(-1) == torch.arange(num_classes, device=y.device).reshape(1, 1, 1, -1)
    class_present = (target_mask.sum(dim=(0, 1, 2)) > 0).float()
    tp = (pred_mask & target_mask).sum(dim=(0, 1, 2)).float()
    fp = (pred_mask & ~target_mask).sum(dim=(0, 1, 2)).float()
    fn = (~pred_mask & target_mask).sum(dim=(0, 1, 2)).float()
    tn = (~pred_mask & ~target_mask).sum(dim=(0, 1, 2)).float()
    overall_tp = tp.sum()
    overall_fp = fp.sum()
    overall_fn = fn.sum()
    precision = tp / (tp + fp).clamp(min=1e-8)
    recall = tp / (tp + fn).clamp(min=1e-8)
    accuracy = (tp + tn) / (tp + tn + fp + fn)
    macro_precision = ((precision * class_present).sum() / class_present.sum().clamp(min=1e-8)).item()
    macro_recall = ((recall * class_present).sum() / class_present.sum().clamp(min=1e-8)).item()
    macro_accuracy = accuracy.mean().item()
    micro_precision = (overall_tp / (overall_tp + overall_fp).clamp(min=1e-8)).item()
    micro_recall = (overall_tp / (overall_tp + overall_fn).clamp(min=1e-8)).item()
    global_accuracy = ((y_pred == y).sum() / y.numel()).item()
    return macro_precision, macro_recall, macro_accuracy, micro_precision, micro_recall, global_accuracy


def parse_args():
    parser = argparse.ArgumentParser(description="Train a model on human parsing dataset")
    parser.add_argument("--data-path", type=str, default="mattmdjaga/human_parsing_dataset", help="Path to the data")
    parser.add_argument("--batch-size", type=int, default=32, help="Batch size for training and testing")
    # argparse's type=bool treats any non-empty string as True, so expose an
    # explicit --pin-memory / --no-pin-memory pair instead.
    parser.add_argument("--pin-memory", action=argparse.BooleanOptionalAction, default=True, help="Pin memory for DataLoader")
    parser.add_argument("--num-workers", type=int, default=0, help="Number of workers for DataLoader")
    parser.add_argument("--num-epochs", type=int, default=15, help="Number of training epochs")
    parser.add_argument("--optimizer", type=str, default="AdamW", help="Optimizer type")
    parser.add_argument("--learning-rate", type=float, default=1e-4, help="Learning rate for the optimizer")
    parser.add_argument("--max-norm", type=float, default=1.0, help="Maximum gradient norm for clipping")
    parser.add_argument("--logs-dir", type=str, default="dino-logs", help="Directory for saving logs")
    parser.add_argument("--model", type=str, default="dino", choices=["unet", "vit", "dino"], help="Model class name")
    parser.add_argument("--losses-path", type=str, default="losses_config.json", help="Path to the losses config")
    parser.add_argument("--mixed-precision", type=str, default="fp16", choices=["fp16", "bf16", "fp8", "no"], help="Mixed-precision mode")
    parser.add_argument("--gradient-accumulation-steps", type=int, default=2, help="Number of gradient accumulation steps")
    parser.add_argument("--project-name", type=str, default="human_parsing_segmentation_ttk", help="WandB project name")
    parser.add_argument("--save-frequency", type=int, default=4, help="Frequency (in epochs) of saving model weights")
    parser.add_argument("--log-steps", type=int, default=400, help="Number of steps between logging training images and metrics")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    return parser.parse_args()
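

# Example invocation (assuming this script is saved as train.py):
#   accelerate launch train.py --model dino --batch-size 16 --mixed-precision fp16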
def main() -> None:
    args = parse_args()
    set_seed(args.seed)
    accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps, mixed_precision=args.mixed_precision)
    with open(args.losses_path, "r") as fp:
        losses_config = json.load(fp)
    with accelerator.main_process_first():
        logs_dir = Path(args.logs_dir)
        logs_dir.mkdir(exist_ok=True)
    # Only the main process talks to W&B; other ranks would otherwise open
    # duplicate runs.
    if accelerator.is_main_process:
        wandb.init(project=args.project_name, dir=logs_dir)
        wandb.save(args.losses_path)
    optimizer_class = getattr(torch.optim, args.optimizer)
    if args.model == "unet":
        model = UNet().to(accelerator.device)
        optimizer = optimizer_class(model.parameters(), lr=args.learning_rate)
    elif args.model == "vit":
        model = ViTSegmentation().to(accelerator.device)
        optimizer = optimizer_class(model.parameters(), lr=args.learning_rate)
    elif args.model == "dino":
        model = DINOSegmentationModel().to(accelerator.device)
        optimizer = optimizer_class(model.segmentation_head.parameters(), lr=args.learning_rate)
    else:
        raise NotImplementedError("Incorrect model name")
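    # Note: for the "dino" model above only the segmentation head receives
    # gradients; the pretrained backbone is assumed to be frozen inside
    # DINOSegmentationModel.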
    transform = get_transform(model.mean, model.std)
    dataset = load_dataset(args.data_path, split="train")
    train_dataset = SegmentationDataset(dataset, train=True, transform=transform, target_transform=mask_transform)
    valid_dataset = SegmentationDataset(dataset, train=False, transform=transform, target_transform=mask_transform)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, pin_memory=args.pin_memory, collate_fn=collate_fn)
    valid_loader = DataLoader(valid_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers, pin_memory=args.pin_memory, collate_fn=collate_fn)
    criterion = get_composite_criterion(losses_config)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.num_epochs * len(train_loader))
    model, optimizer, train_loader, lr_scheduler = accelerator.prepare(model, optimizer, train_loader, lr_scheduler)
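    # Note: valid_loader is not passed through accelerator.prepare, so in
    # multi-GPU runs each process iterates the full validation set and the
    # validation metrics below are computed per process.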
    best_accuracy = 0
    print(f"params: {sum(p.numel() for p in model.parameters()) / 1e6:.2f} M")
    print(f"trainable params: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.2f} M")
    train_loss_ema, train_macro_precision_ema, train_macro_recall_ema, train_macro_accuracy_ema, train_micro_precision_ema, train_micro_recall_ema, train_global_accuracy_ema = (
        EMA(), EMA(), EMA(), EMA(), EMA(), EMA(), EMA(),
    )
    for epoch in range(1, args.num_epochs + 1):
        model.train()
        pbar = tqdm(train_loader, desc=f"Train epoch {epoch}/{args.num_epochs}")
        for index, (x, y) in enumerate(pbar):
            x, y = x.to(accelerator.device), y.squeeze(1).to(accelerator.device)
            with accelerator.accumulate(model):
                with accelerator.autocast():
                    output = model(x)
                    loss = criterion(output, y)
                accelerator.backward(loss)
                train_loss = loss.item()
                grad_norm = None
                _, y_pred = output.max(dim=1)
                train_macro_precision, train_macro_recall, train_macro_accuracy, train_micro_precision, train_micro_recall, train_global_accuracy = compute_metrics(y_pred, y)
                if accelerator.sync_gradients:
                    grad_norm = accelerator.clip_grad_norm_(model.parameters(), args.max_norm).item()
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()
            if (index + 1) % args.log_steps == 0 and accelerator.is_main_process:
                images_to_log = []
                combined_image = create_combined_image(x, y, y_pred)
                images_to_log.append(wandb.Image(combined_image, caption=f"Combined Image (Train, Epoch {epoch}, Batch {index})"))
                wandb.log({"train_samples": images_to_log})
            pbar.set_postfix({
                "loss": train_loss_ema(train_loss),
                "macro_precision": train_macro_precision_ema(train_macro_precision),
                "macro_recall": train_macro_recall_ema(train_macro_recall),
                "macro_accuracy": train_macro_accuracy_ema(train_macro_accuracy),
                "micro_precision": train_micro_precision_ema(train_micro_precision),
                "micro_recall": train_micro_recall_ema(train_micro_recall),
                "global_accuracy": train_global_accuracy_ema(train_global_accuracy),
            })
            log_data = {
                "train/epoch": epoch,
                "train/loss": train_loss,
                "train/macro_accuracy": train_macro_accuracy,
                "train/learning_rate": optimizer.param_groups[0]["lr"],
                "train/macro_precision": train_macro_precision,
                "train/macro_recall": train_macro_recall,
                "train/micro_precision": train_micro_precision,
                "train/micro_recall": train_micro_recall,
                "train/global_accuracy": train_global_accuracy,
            }
            if grad_norm is not None:
                log_data["train/grad_norm"] = grad_norm
            if accelerator.is_main_process:
                wandb.log(log_data)
        accelerator.wait_for_everyone()
        model.eval()
        valid_loss, valid_macro_accuracies, valid_macro_precisions, valid_macro_recalls, valid_global_accuracies, valid_micro_precisions, valid_micro_recalls = 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0
        with torch.inference_mode():
            pbar = tqdm(valid_loader, desc=f"Val epoch {epoch}/{args.num_epochs}")
            # Iterate over pbar (not valid_loader) so the progress bar advances.
            for index, (x, y) in enumerate(pbar):
                x, y = x.to(accelerator.device), y.squeeze(1).to(accelerator.device)
                output = model(x)
                _, y_pred = output.max(dim=1)
                if (index + 1) % args.log_steps == 0 and accelerator.is_main_process:
                    images_to_log = []
                    combined_image = create_combined_image(x, y, y_pred)
                    images_to_log.append(wandb.Image(combined_image, caption=f"Combined Image (Validation, Epoch {epoch})"))
                    wandb.log({"valid_samples": images_to_log})
                valid_macro_precision, valid_macro_recall, valid_macro_accuracy, valid_micro_precision, valid_micro_recall, valid_global_accuracy = compute_metrics(y_pred, y)
                valid_macro_precisions += valid_macro_precision
                valid_macro_recalls += valid_macro_recall
                valid_macro_accuracies += valid_macro_accuracy
                valid_micro_precisions += valid_micro_precision
                valid_micro_recalls += valid_micro_recall
                valid_global_accuracies += valid_global_accuracy
                loss = criterion(output, y)
                valid_loss += loss.item()
        # Batch-level averages: every batch is weighted equally.
        valid_loss = valid_loss / len(valid_loader)
        valid_macro_accuracies = valid_macro_accuracies / len(valid_loader)
        valid_macro_precisions = valid_macro_precisions / len(valid_loader)
        valid_macro_recalls = valid_macro_recalls / len(valid_loader)
        valid_global_accuracies = valid_global_accuracies / len(valid_loader)
        valid_micro_precisions = valid_micro_precisions / len(valid_loader)
        valid_micro_recalls = valid_micro_recalls / len(valid_loader)
        accelerator.print(
            f"loss: {valid_loss:.3f}, valid_macro_precision: {valid_macro_precisions:.3f}, "
            f"valid_macro_recall: {valid_macro_recalls:.3f}, valid_macro_accuracy: {valid_macro_accuracies:.3f}, "
            f"valid_micro_precision: {valid_micro_precisions:.3f}, valid_micro_recall: {valid_micro_recalls:.3f}, "
            f"valid_global_accuracy: {valid_global_accuracies:.3f}"
        )
        if accelerator.is_main_process:
            wandb.log(
                {
                    "val/epoch": epoch,
                    "val/loss": valid_loss,
                    "val/macro_accuracy": valid_macro_accuracies,
                    "val/macro_precision": valid_macro_precisions,
                    "val/macro_recall": valid_macro_recalls,
                    "val/global_accuracy": valid_global_accuracies,
                    "val/micro_precision": valid_micro_precisions,
                    "val/micro_recall": valid_micro_recalls,
                }
            )
        # Unwrap in case accelerate wrapped the model (e.g. DDP), so attribute
        # access and state_dict keys stay consistent when saving.
        unwrapped_model = accelerator.unwrap_model(model)
        if valid_global_accuracies > best_accuracy:
            best_accuracy = valid_global_accuracies
            if args.model in ["dino", "vit"]:
                accelerator.save(unwrapped_model.segmentation_head.state_dict(), logs_dir / "checkpoint-best.pth")
            else:
                accelerator.save(unwrapped_model.state_dict(), logs_dir / "checkpoint-best.pth")
            accelerator.print(f"new best_accuracy {best_accuracy}, {epoch=}")
        if epoch % args.save_frequency == 0:
            if args.model in ["dino", "vit"]:
                accelerator.save(unwrapped_model.segmentation_head.state_dict(), logs_dir / f"checkpoint-{epoch:09}.pth")
            else:
                accelerator.save(unwrapped_model.state_dict(), logs_dir / f"checkpoint-{epoch:09}.pth")
        accelerator.wait_for_everyone()
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        wandb.finish()


if __name__ == "__main__":
    main()