import os
import math
import wandb
import random
import logging
import inspect
import argparse
import datetime
import subprocess
from pathlib import Path
from tqdm.auto import tqdm
from einops import rearrange
from omegaconf import OmegaConf
from safetensors import safe_open
from typing import Dict, Optional, Tuple

import torch
import torchvision
import torch.nn.functional as F
import torch.distributed as dist
from torch.optim.swa_utils import AveragedModel
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP

import diffusers
from diffusers import AutoencoderKL, DDIMScheduler
from diffusers.models import UNet2DConditionModel
from diffusers.pipelines import StableDiffusionPipeline
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version
from diffusers.utils.import_utils import is_xformers_available

from animatediff.models.resnet import InflatedConv3d

import transformers
from transformers import CLIPTextModel, CLIPTokenizer

from animatediff.data.dataset_web import WebVid10M
from animatediff.models.unet import UNet3DConditionModel
from animatediff.pipelines.pipeline_animation import AnimationPipeline
from animatediff.pipelines.validation_pipeline import ValidationPipeline
from animatediff.utils.util import save_videos_grid, zero_rank_print, prepare_mask_coef, prepare_mask_coef_by_score


def init_dist(launcher="slurm", backend='nccl', port=29500, **kwargs):
    """Initializes distributed environment."""
    if launcher == 'pytorch':
        rank = int(os.environ['RANK'])
        num_gpus = torch.cuda.device_count()
        local_rank = rank % num_gpus
        torch.cuda.set_device(local_rank)
        dist.init_process_group(backend=backend, **kwargs)

    elif launcher == 'slurm':
        proc_id = int(os.environ['SLURM_PROCID'])
        ntasks = int(os.environ['SLURM_NTASKS'])
        node_list = os.environ['SLURM_NODELIST']
        num_gpus = torch.cuda.device_count()
        local_rank = proc_id % num_gpus
        torch.cuda.set_device(local_rank)
        addr = subprocess.getoutput(
            f'scontrol show hostname {node_list} | head -n1')
        os.environ['MASTER_ADDR'] = addr
        os.environ['WORLD_SIZE'] = str(ntasks)
        os.environ['RANK'] = str(proc_id)
        port = os.environ.get('PORT', port)
        os.environ['MASTER_PORT'] = str(port)
        dist.init_process_group(backend=backend)
        zero_rank_print(f"proc_id: {proc_id}; local_rank: {local_rank}; ntasks: {ntasks}; node_list: {node_list}; num_gpus: {num_gpus}; addr: {addr}; port: {port}")

    else:
        raise NotImplementedError(f'Not implemented launcher type: `{launcher}`!')

    return local_rank
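

# NOTE: every keyword argument of main() below is populated from the OmegaConf YAML
# passed via --config (see the __main__ block at the bottom), so hyperparameters are
# changed per experiment in the config file rather than in code.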
def main(
    image_finetune: bool,
    name: str,
    use_wandb: bool,
    launcher: str,
    output_dir: str,
    pretrained_model_path: str,
    train_data: Dict,
    validation_data: Dict,
    cfg_random_null_text: bool = True,
    cfg_random_null_text_ratio: float = 0.1,
    unet_checkpoint_path: str = "",
    unet_additional_kwargs: Dict = {},
    ema_decay: float = 0.9999,
    noise_scheduler_kwargs = None,
    max_train_epoch: int = -1,
    max_train_steps: int = 100,
    validation_steps: int = 100,
    validation_steps_tuple: Tuple = (-1,),
    learning_rate: float = 3e-5,
    scale_lr: bool = False,
    lr_warmup_steps: int = 0,
    lr_scheduler: str = "constant",
    trainable_modules: Tuple[str] = (None, ),
    num_workers: int = 32,
    train_batch_size: int = 1,
    adam_beta1: float = 0.9,
    adam_beta2: float = 0.999,
    adam_weight_decay: float = 1e-2,
    adam_epsilon: float = 1e-08,
    max_grad_norm: float = 1.0,
    gradient_accumulation_steps: int = 32,
    gradient_checkpointing: bool = False,
    checkpointing_epochs: int = 5,
    checkpointing_steps: int = -1,
    mixed_precision_training: bool = True,
    enable_xformers_memory_efficient_attention: bool = True,
    statistic: list = [1, 40],
    global_seed: int = 42,
    is_debug: bool = False,
    mask_frame: list = [0],
    pretrained_motion_module_path: str = '',
    pretrained_sd_path: str = '',
    mask_sim_range: list = [0.2, 1.0],
):
    check_min_version("0.10.0.dev0")

    # Initialize distributed training
    local_rank = init_dist(launcher=launcher)
    global_rank = dist.get_rank()
    num_processes = dist.get_world_size()
    is_main_process = global_rank == 0

    seed = global_seed + global_rank
    torch.manual_seed(seed)

    # Logging folder
    folder_name = "debug" if is_debug else name + datetime.datetime.now().strftime("-%Y-%m-%dT%H-%M-%S")
    output_dir = os.path.join(output_dir, folder_name)
    if is_debug and os.path.exists(output_dir):
        os.system(f"rm -rf {output_dir}")

    *_, config = inspect.getargvalues(inspect.currentframe())

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
        filemode='a',
        filename='train_v2_2.log',
    )

    if is_main_process and (not is_debug) and use_wandb:
        run = wandb.init(project="image2video", name=folder_name, config=config)

    # Handle the output folder creation
    if is_main_process:
        os.makedirs(output_dir, exist_ok=True)
        os.makedirs(f"{output_dir}/samples", exist_ok=True)
        os.makedirs(f"{output_dir}/sanity_check", exist_ok=True)
        os.makedirs(f"{output_dir}/checkpoints", exist_ok=True)
        OmegaConf.save(config, os.path.join(output_dir, 'config.yaml'))

    # Load scheduler, tokenizer and models.
    noise_scheduler = DDIMScheduler(**OmegaConf.to_container(noise_scheduler_kwargs))

    vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae")
    tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder")
    if not image_finetune:
        unet = UNet3DConditionModel.from_pretrained_2d(
            pretrained_model_path, subfolder="unet",
            unet_additional_kwargs=OmegaConf.to_container(unet_additional_kwargs)
        )
    else:
        unet = UNet2DConditionModel.from_pretrained(pretrained_model_path, subfolder="unet")

    # Load pretrained unet weights
    if unet_checkpoint_path != "":
        zero_rank_print(f"from checkpoint: {unet_checkpoint_path}")
        unet_checkpoint_path = torch.load(unet_checkpoint_path, map_location="cpu")
        if "global_step" in unet_checkpoint_path:
            zero_rank_print(f"global_step: {unet_checkpoint_path['global_step']}")
        state_dict = unet_checkpoint_path["state_dict"] if "state_dict" in unet_checkpoint_path else unet_checkpoint_path

        m, u = unet.load_state_dict(state_dict, strict=False)
        zero_rank_print(f"missing keys: {len(m)}, unexpected keys: {len(u)}")
        # assert len(u) == 0
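
    # Inflate conv_in from 4 to 9 input channels so the UNet can additionally take the
    # mask (1 channel) and the masked-image latents (4 channels) alongside the noisy
    # latents; the 5 extra input channels are zero-initialised, so training starts
    # from the original Stable Diffusion behaviour.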
    old_weights = unet.conv_in.weight
    old_bias = unet.conv_in.bias
    new_conv1 = InflatedConv3d(
        9, old_weights.shape[0],
        kernel_size=unet.conv_in.kernel_size,
        stride=unet.conv_in.stride,
        padding=unet.conv_in.padding,
        bias=True if old_bias is not None else False,
    )
    param = torch.zeros((320, 5, 3, 3), requires_grad=True)
    new_conv1.weight = torch.nn.Parameter(torch.cat((old_weights, param), dim=1))
    if old_bias is not None:
        new_conv1.bias = old_bias
    unet.conv_in = new_conv1
    unet.config["in_channels"] = 9

    # Load webvid-pretrained sd
    '''webvid_sd_ckpt = torch.load(pretrained_sd_path)
    unet.load_state_dict(webvid_sd_ckpt, strict=False)
    vae.load_state_dict(webvid_sd_ckpt, strict=False)
    print('Webvid_pretrained sd loaded')'''

    # Freeze vae and text_encoder
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)

    # Set unet trainable parameters
    unet.requires_grad_(False)
    for name, param in unet.named_parameters():
        for trainable_module_name in trainable_modules:
            if trainable_module_name in name:
                logging.info(f'{name} is trainable \n')
                # print(f'{name} is trainable')
                param.requires_grad = True
                break

    # Load pre-trained motion module
    unet_state_dict = unet.state_dict().keys()
    pretrained_motion_module = torch.load(pretrained_motion_module_path)
    for name, param in pretrained_motion_module.items():
        if name in unet_state_dict:
            unet.state_dict()[name].copy_(param)
            # print(f"{name} weight replace")

    trainable_params = list(filter(lambda p: p.requires_grad, unet.parameters()))
    optimizer = torch.optim.AdamW(
        trainable_params,
        lr=learning_rate,
        betas=(adam_beta1, adam_beta2),
        weight_decay=adam_weight_decay,
        eps=adam_epsilon,
    )

    if is_main_process:
        zero_rank_print(f"trainable params number: {len(trainable_params)}")
        zero_rank_print(f"trainable params scale: {sum(p.numel() for p in trainable_params) / 1e6:.3f} M")

    # Enable xformers
    if enable_xformers_memory_efficient_attention:
        if is_xformers_available():
            unet.enable_xformers_memory_efficient_attention()
        else:
            raise ValueError("xformers is not available. Make sure it is installed correctly")

    # Enable gradient checkpointing
    if gradient_checkpointing:
        unet.enable_gradient_checkpointing()

    # Move models to GPU
    vae.to(local_rank)
    text_encoder.to(local_rank)

    # Get the training dataset
    train_dataset = WebVid10M(**train_data, is_image=image_finetune)
    distributed_sampler = DistributedSampler(
        train_dataset,
        num_replicas=num_processes,
        rank=global_rank,
        shuffle=True,
        seed=global_seed,
    )
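
    # NOTE: shuffling is handled by the DistributedSampler above (re-seeded each epoch
    # via set_epoch in the training loop), so the DataLoader below uses shuffle=False.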
    # DataLoaders creation:
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=train_batch_size,
        shuffle=False,
        sampler=distributed_sampler,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=True,
    )

    # Get the training iteration
    if max_train_steps == -1:
        assert max_train_epoch != -1
        max_train_steps = max_train_epoch * len(train_dataloader)

    if checkpointing_steps == -1:
        assert checkpointing_epochs != -1
        checkpointing_steps = checkpointing_epochs * len(train_dataloader)

    if scale_lr:
        learning_rate = (learning_rate * gradient_accumulation_steps * train_batch_size * num_processes)

    # Scheduler
    lr_scheduler = get_scheduler(
        lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=lr_warmup_steps * gradient_accumulation_steps,
        num_training_steps=max_train_steps * gradient_accumulation_steps,
    )

    # Validation pipeline
    if not image_finetune:
        validation_pipeline = ValidationPipeline(
            unet=unet, vae=vae, tokenizer=tokenizer, text_encoder=text_encoder, scheduler=noise_scheduler,
        ).to(local_rank)
    else:
        validation_pipeline = ValidationPipeline(
            unet=unet, vae=vae, tokenizer=tokenizer, text_encoder=text_encoder, scheduler=noise_scheduler,
        ).to(local_rank)
    validation_pipeline.enable_vae_slicing()

    # DDP wrapper
    unet.to(local_rank)
    unet = DDP(unet, device_ids=[local_rank], output_device=local_rank)

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
    # Afterwards we recalculate our number of training epochs
    num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)

    # Train!
    total_batch_size = train_batch_size * num_processes * gradient_accumulation_steps

    if is_main_process:
        logging.info("***** Running training *****")
        logging.info(f" Num examples = {len(train_dataset)}")
        logging.info(f" Num Epochs = {num_train_epochs}")
        logging.info(f" Instantaneous batch size per device = {train_batch_size}")
        logging.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
        logging.info(f" Gradient Accumulation steps = {gradient_accumulation_steps}")
        logging.info(f" Total optimization steps = {max_train_steps}")

    global_step = 0
    first_epoch = 0

    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(global_step, max_train_steps), disable=not is_main_process)
    progress_bar.set_description("Steps")

    # Support mixed-precision training
    scaler = torch.cuda.amp.GradScaler() if mixed_precision_training else None

    motion_module_trainable = False

    for epoch in range(first_epoch, num_train_epochs):
        train_dataloader.sampler.set_epoch(epoch)
        unet.train()

        for step, batch in enumerate(train_dataloader):
            if cfg_random_null_text:
                batch['text'] = [name if random.random() > cfg_random_null_text_ratio else "" for name in batch['text']]

            # Data batch sanity check
            if epoch == first_epoch and step == 0:
                pixel_values, texts = batch['pixel_values'].cpu(), batch['text']

            ### >>>> Training >>>> ###

            # Convert videos to latent space, sampling from video
            pixel_values = batch["pixel_values"].to(local_rank)
            video_length = pixel_values.shape[1]

            # scores (b f), cond_frames (b f)
            scores = batch['score']
            scores = torch.stack([score for score in scores])
            cond_frames = batch['cond_frames']

            with torch.no_grad():
                if not image_finetune:
                    pixel_values = rearrange(pixel_values, "b f c h w -> (b f) c h w")
                    latents = vae.encode(pixel_values).latent_dist
                    latents = latents.sample()
                    latents = rearrange(latents, "(b f) c h w -> b c f h w", f=video_length)
                else:
                    latents = vae.encode(pixel_values).latent_dist
                    latents = latents.sample()

                latents = latents * 0.18215
                pixel_values = rearrange(pixel_values, "(b f) c h w -> b f c h w", f=video_length)
                pixel_values = pixel_values / 2. + 0.5
                pixel_values *= 255

            # Create mask and masked_image_latent
            # b c f h w
            mask = torch.zeros((latents.shape[0], 1, latents.shape[2], latents.shape[3], latents.shape[4]))
            masked_image = torch.zeros_like(latents)
            '''rand_mask = random.random()
            if rand_mask > 0.2:
                rand_frame = random.randint(0, video_length - 1)
                mask[:,:,rand_frame,:,:] = 1
                for f in range(video_length):
                    masked_image[:,:,f,:,:] = latents[:,:,rand_frame,:,:].clone()
            else:
                masked_image = torch.zeros_like(latents)
                mask = torch.zeros((latents.shape[0], 1, latents.shape[2], latents.shape[3], latents.shape[4]))'''
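
            # With probability 0.8, build a per-frame mask coefficient from the motion
            # score (prepare_mask_coef_by_score) and repeat each sample's conditioning
            # frame latent across all frames; otherwise leave mask and masked_image at
            # zero, which drops the image condition for that batch.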
            is_cond = random.random()
            rand_size = latents.shape[0]
            if is_cond > 0.2:
                for rs in range(rand_size):
                    # rand_frame = random.randint(0, video_length - 1)
                    video_shape = [pixel_values.shape[0], pixel_values.shape[1]]
                    mask_coef = prepare_mask_coef_by_score(video_shape, cond_frame_idx=cond_frames,
                                                           statistic=statistic, score=torch.tensor(scores).unsqueeze(0))
                    # mask_coef = prepare_mask_coef(video_length, rand_frame, mask_sim_range)
                    # mask[:,:,rand_frame,:,:] = 1
                    for f in range(video_length):
                        mask[rs, :, f, :, :] = mask_coef[rs, f]
                        masked_image[rs, :, f, :, :] = latents[rs, :, cond_frames[rs], :, :].clone()
            else:
                masked_image = torch.zeros_like(latents)
                mask = torch.zeros((latents.shape[0], 1, latents.shape[2], latents.shape[3], latents.shape[4]))
            '''mask[:,:,0,:,:] = 1
            for f in range(video_length):
                masked_image[:,:,f,:,:] = latents[:,:,0,:,:].clone()'''

            # Sample noise that we'll add to the latents
            noise = torch.randn_like(latents)
            bsz = latents.shape[0]

            # Sample a random timestep for each video
            timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
            timesteps = timesteps.long()

            # Add noise to the latents according to the noise magnitude at each timestep
            # (this is the forward diffusion process)
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

            # Get the text embedding for conditioning
            with torch.no_grad():
                prompt_ids = tokenizer(
                    batch['text'], max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
                ).input_ids.to(latents.device)
                encoder_hidden_states = text_encoder(prompt_ids)[0]

            # Get the target for loss depending on the prediction type
            if noise_scheduler.config.prediction_type == "epsilon":
                target = noise
            elif noise_scheduler.config.prediction_type == "v_prediction":
                raise NotImplementedError
            else:
                raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

            # Predict the noise residual and compute loss
            # Mixed-precision training
            with torch.cuda.amp.autocast(enabled=mixed_precision_training):
                model_pred = unet(noisy_latents, mask, masked_image, timesteps, encoder_hidden_states).sample
                loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
                loss = loss / gradient_accumulation_steps
            '''if (step + 1) % gradient_accumulation_steps == 0:
                optimizer.zero_grad()'''
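
            # Gradient accumulation: the loss above is already divided by
            # gradient_accumulation_steps and gradients from each micro-batch are summed
            # by backward(); the optimizer (and the GradScaler, when AMP is enabled)
            # only steps once every gradient_accumulation_steps micro-batches.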
            # Backpropagate, accumulate gradient
            if mixed_precision_training:
                scaler.scale(loss).backward()
                """ >>> gradient clipping >>> """
                if (step + 1) % gradient_accumulation_steps == 0:
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(unet.parameters(), max_grad_norm)
                # Calculate the gradient norm
                if (step + 1) % gradient_accumulation_steps == 0:
                    if isinstance(unet.parameters(), torch.Tensor):
                        params = [unet.parameters()]
                        grads = [p.grad for p in params if p.grad is not None]
                    else:
                        grads = [p.grad for p in unet.parameters() if p.grad is not None]
                    total_norm = torch.norm(torch.stack([torch.norm(g.detach(), 2.0) for g in grads]), 2.0)
                """ <<< gradient clipping <<< """
                if (step + 1) % gradient_accumulation_steps == 0:
                    scaler.step(optimizer)
                    scaler.update()
            else:
                loss.backward()
                """ >>> gradient clipping >>> """
                torch.nn.utils.clip_grad_norm_(unet.parameters(), max_grad_norm)
                # Calculate the gradient norm
                if (step + 1) % gradient_accumulation_steps == 0:
                    if isinstance(unet.parameters(), torch.Tensor):
                        params = [unet.parameters()]
                        grads = [p.grad for p in params if p.grad is not None]
                    else:
                        grads = [p.grad for p in unet.parameters() if p.grad is not None]
                    total_norm = torch.norm(torch.stack([torch.norm(g.detach(), 2.0) for g in grads]), 2.0)
                """ <<< gradient clipping <<< """
                if (step + 1) % gradient_accumulation_steps == 0:
                    optimizer.step()

            if (step + 1) % gradient_accumulation_steps == 0:
                lr_scheduler.step()
                optimizer.zero_grad()
                progress_bar.update(1 * gradient_accumulation_steps)
                global_step += 1

            # Set motion module trainable TODO: Debug
            '''if (motion_module_trainable == False) and (step > motion_module_trainable_step) and ((step + 1) % gradient_accumulation_steps == 0):
                for name, param in unet.named_parameters():
                    if 'motion_modules.' in name:
                        logging.info(f'{name} is trainable \n')
                        #print(f'{name} is trainable')
                        param.requires_grad = True
                zero_rank_print('motion module is trainable now!')
                motion_module_trainable = True'''

            ### <<<< Training <<<< ###

            # Wandb logging
            if is_main_process and (not is_debug) and use_wandb and ((step + 1) % gradient_accumulation_steps == 0):
                wandb.log({"gradient_norm": total_norm.item()}, step=global_step)

            # Periodically run validation and save a checkpoint
            if is_main_process and (global_step % validation_steps == 0 or global_step in validation_steps_tuple):
                samples = []

                generator = torch.Generator(device=latents.device)
                generator.manual_seed(global_seed)

                height = train_data.sample_size[0] if not isinstance(train_data.sample_size, int) else train_data.sample_size
                width = train_data.sample_size[1] if not isinstance(train_data.sample_size, int) else train_data.sample_size

                prompts = validation_data.prompts[:2] if global_step < 1000 and (not image_finetune) else validation_data.prompts

                for idx, prompt in enumerate(prompts):
                    use_image = False
                    if not image_finetune:
                        if idx < 2:
                            use_image = idx + 1
                        else:
                            use_image = False
                        sample = validation_pipeline(
                            prompt,
                            use_image=use_image,
                            generator=generator,
                            video_length=train_data.sample_n_frames,
                            height=512,
                            width=512,
                            **validation_data,
                        ).videos
                        save_videos_grid(sample, f"{output_dir}/samples/sample-{global_step}/{idx}.gif")
                        samples.append(sample)
                    else:
                        sample = validation_pipeline(
                            prompt,
                            generator=generator,
                            height=height,
                            width=width,
                            num_inference_steps=validation_data.get("num_inference_steps", 25),
                            guidance_scale=validation_data.get("guidance_scale", 8.),
                        ).images[0]
                        sample = torchvision.transforms.functional.to_tensor(sample)
                        samples.append(sample)

                if not image_finetune:
                    samples = torch.concat(samples)
                    save_path = f"{output_dir}/samples/sample-{global_step}.gif"
                    save_videos_grid(samples, save_path)
                else:
                    samples = torch.stack(samples)
                    save_path = f"{output_dir}/samples/sample-{global_step}.png"
                    torchvision.utils.save_image(samples, save_path, nrow=4)

                logging.info(f"Saved samples to {save_path}")
                save_path = os.path.join(output_dir, "checkpoints")
                state_dict = {
                    "epoch": epoch,
                    "global_step": global_step,
                    "state_dict": unet.state_dict(),
                }
                inpaint_ckpt = state_dict['state_dict']
                trained_ckpt = {}
                for key, value in inpaint_ckpt.items():
                    new_key = key.replace('module.', '')
                    trained_ckpt[new_key] = value

                if step == len(train_dataloader) - 1:
                    torch.save(trained_ckpt, os.path.join(save_path, f"checkpoint-epoch-{epoch+1}.ckpt"))
                else:
                    torch.save(trained_ckpt, os.path.join(save_path, f"checkpoint{step+1}.ckpt"))
                logging.info(f"Saved state to {save_path} (global_step: {global_step})")

            logging.info(f"(global_step: {global_step}) loss: {loss.detach().item()}")

            logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)

            if global_step >= max_train_steps:
                break

    dist.destroy_process_group()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True)
    parser.add_argument("--launcher", type=str, choices=["pytorch", "slurm"], default="slurm")
    parser.add_argument("--wandb", action="store_true", default=True)
    args = parser.parse_args()

    name = Path(args.config).stem
    config = OmegaConf.load(args.config)

    main(name=name, launcher=args.launcher, use_wandb=args.wandb, **config)
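
# Example launch (illustrative only; the script name, config path and GPU count below
# are placeholders, not values defined by this repository):
#   torchrun --nproc_per_node=8 train.py --config configs/training/example.yaml --launcher pytorch
#   srun --ntasks-per-node=8 python train.py --config configs/training/example.yaml --launcher slurm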