import os, pickle, random, json, base64, io

import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from glob import glob
from tqdm import tqdm, trange
from concurrent import futures
from concurrent.futures import ThreadPoolExecutor, as_completed
from collections import defaultdict
from pathlib import Path
from typing import Any, Callable, Optional, Union
from warnings import warn

from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from peft import LoraConfig, get_peft_model
from huggingface_hub import PyTorchModelHubMixin
from transformers import is_wandb_available

from modeling_flux_base import DefaultDDPOFluxPipeline
from ddpo_flux_config import DDPOFluxConfig

if is_wandb_available():
    import wandb

logger = get_logger(__name__)

class DDPOTrainer_edit(PyTorchModelHubMixin):
    def __init__(
        self,
        config: DDPOFluxConfig,
        reward_function: Callable[[], tuple[str, Any]],
        prompt_function: Callable[[], tuple[str, Any]],
        edit_pipeline: DefaultDDPOFluxPipeline,
        image_samples_hook: Optional[Callable[[Any, Any, Any], Any]] = None,
    ):
        if image_samples_hook is None:
            warn("No image_samples_hook provided; no images will be logged")

        self.prompt_fn = prompt_function
        self.reward_fn = reward_function
        self.config = config
        self.image_samples_callback = image_samples_hook

        accelerator_project_config = ProjectConfiguration(**self.config.project_kwargs)
        self.project_dir = accelerator_project_config.project_dir

        if self.config.resume_from:
            if self.config.resume_from == "latest":
                # pick the most recent "checkpoint_lora-<epoch>" directory in the project dir
                dirs = os.listdir(self.project_dir)
                dirs = [d for d in dirs if d.startswith("checkpoint_lora")]
                dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
                if len(dirs) == 0:
                    print(f"Checkpoint '{self.config.resume_from}' does not exist. Starting a new training run.")
                    self.config.resume_from = ""
                else:
                    path = dirs[-1]
            else:
                path = os.path.basename(self.config.resume_from)
            if self.config.resume_from:
                self.config.resume_from = os.path.join(self.project_dir, path)
                accelerator_project_config.iteration = int(path.split("-")[1]) + 1

        # number of timesteps within each trajectory to train on
        self.num_train_timesteps = int(self.config.sample_num_steps * self.config.train_timestep_fraction - 1)
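        # Illustrative arithmetic (example numbers, not config defaults): with sample_num_steps = 28 and
        # train_timestep_fraction = 1.0 this trains on 28 * 1.0 - 1 = 27 timesteps, i.e. the final sampled
        # transition is dropped, matching the [:, :-1] / [:, :-2] trajectory slicing in _generate_samples.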
        self.accelerator = Accelerator(
            log_with=self.config.log_with,
            mixed_precision=self.config.mixed_precision,
            project_config=accelerator_project_config,
            # We always accumulate gradients across timesteps; we want config.train_gradient_accumulation_steps
            # to be the number of *samples* we accumulate across, so we multiply by the number of training
            # timesteps to get the total number of optimizer steps to accumulate across.
            gradient_accumulation_steps=self.config.train_gradient_accumulation_steps * self.num_train_timesteps,
            **self.config.accelerator_kwargs,
        )
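        # Worked example (illustrative numbers): with train_gradient_accumulation_steps = 2 samples and
        # num_train_timesteps = 27, the accelerator accumulates 2 * 27 = 54 backward passes before each
        # optimizer step, i.e. one parameter update per two full sampled trajectories.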
        is_using_tensorboard = config.log_with is not None and config.log_with == "tensorboard"
        if self.accelerator.is_main_process:
            self.accelerator.init_trackers(
                self.config.tracker_project_name,
                config=dict(ddpo_trainer_config=config.to_dict()) if not is_using_tensorboard else config.to_dict(),
                init_kwargs=self.config.tracker_kwargs,
            )

        is_okay, message = self._config_check()
        if not is_okay:
            raise ValueError(message)

        logger.info(f"\n{config}")

        set_seed(self.config.seed, device_specific=True)

        self.edit_pipeline = edit_pipeline
        self.edit_pipeline.set_progress_bar_config(
            position=1,
            disable=not self.accelerator.is_local_main_process,
            leave=False,
            desc="Timestep",
            dynamic_ncols=True,
        )

        # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and
        # non-lora transformer) to half precision, as these weights are only used for inference;
        # keeping them in full precision is not required.
        if self.accelerator.mixed_precision == "fp16":
            inference_dtype = torch.float16
        elif self.accelerator.mixed_precision == "bf16":
            inference_dtype = torch.bfloat16
        else:
            inference_dtype = torch.float32

        self.edit_pipeline.vae.to(self.accelerator.device, dtype=inference_dtype)
        self.edit_pipeline.text_encoder.to(self.accelerator.device, dtype=inference_dtype)
        self.edit_pipeline.text_encoder_2.to(self.accelerator.device, dtype=inference_dtype)

        lora_config = LoraConfig(
            r=self.config.lora_rank,
            lora_alpha=self.config.lora_alpha,
            init_lora_weights="gaussian",
            target_modules=["to_k", "to_q", "to_v", "to_out.0"],
        )
        self.edit_pipeline.flux_pipeline.transformer.requires_grad_(False)
        self.edit_pipeline.flux_pipeline.transformer = get_peft_model(
            self.edit_pipeline.flux_pipeline.transformer, lora_config
        )
        trainable_params = [p for p in self.edit_pipeline.flux_pipeline.transformer.parameters() if p.requires_grad]
        total_params = sum(p.numel() for p in trainable_params)

        self.optimizer = torch.optim.AdamW(
            trainable_params,
            lr=self.config.train_learning_rate,
            betas=(self.config.train_adam_beta1, self.config.train_adam_beta2),
            weight_decay=self.config.train_adam_weight_decay,
            eps=self.config.train_adam_epsilon,
        )
        (
            self.negative_prompt_embeds,
            self.negative_pooled_prompt_embeds,
            self.negative_text_ids,
        ) = self.edit_pipeline.flux_pipeline.encode_prompt(
            prompt=[""] if self.config.negative_prompts is None else self.config.negative_prompts,
            prompt_2=[""],
            device=self.accelerator.device,
        )

        # NOTE: for some reason, autocast is necessary for non-lora training, but for lora training it
        # isn't necessary and it uses more memory
        self.autocast = self.edit_pipeline.autocast or self.accelerator.autocast

        if self.config.resume_from:
            print(f"Resuming from {self.config.resume_from}")
            logger.info(f"Resuming from {self.config.resume_from}")
            self.edit_pipeline.flux_pipeline.transformer.load_adapter(
                self.config.resume_from, adapter_name="default", is_trainable=True
            )
            self.edit_pipeline.flux_pipeline.transformer.train()
            self.first_epoch = accelerator_project_config.iteration
        else:
            self.first_epoch = 0

        self.edit_pipeline.flux_pipeline.transformer, self.optimizer = self.accelerator.prepare(
            self.edit_pipeline.flux_pipeline.transformer, self.optimizer
        )
        self.trainable_layers = list(
            filter(lambda p: p.requires_grad, self.edit_pipeline.flux_pipeline.transformer.parameters())
        )
        self.executor = futures.ThreadPoolExecutor(max_workers=self.config.max_workers)

    def compute_rewards(self, prompt_image_pairs):
        all_rewards = []
        all_meta_data = []
        for img, prompt, raw_img, img_path in prompt_image_pairs:
            data_pair_vllm = []
            for idx in range(len(img)):
                data_pair_vllm.append((raw_img[idx][0], raw_img[idx][1], prompt[idx], img[idx]))
            # rewards = self.executor.map(lambda x: self.reward_fn(*x), data_pair_vllm)
            # -------- submit + as_completed --------
            fut_to_idx = {
                self.executor.submit(self.reward_fn, *pair): idx
                for idx, pair in enumerate(data_pair_vllm)
            }
            # Collect results in original order
            rewards = [None] * len(data_pair_vllm)
            for fut in futures.as_completed(fut_to_idx):
                idx = fut_to_idx[fut]
                rewards[idx] = fut.result()
            rewards_ = [torch.as_tensor(reward, device=self.accelerator.device) for reward, reward_metadata in rewards]
            rewards_ = torch.stack(rewards_)
            all_rewards.append(rewards_)
            all_meta_data.append(img_path)
        return all_rewards, all_meta_data
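    # Assumed reward_fn contract (illustrative, inferred from the call above, not enforced anywhere):
    #   reward_fn(raw_img_a, raw_img_b, prompt, edited_image) -> (reward: float, metadata: Any)
    # Calls are dispatched to the thread pool and re-ordered by index so rewards line up with samples.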

    def step(self, epoch: int, global_step: int):
        """
        Perform a single step of training.

        Args:
            epoch (int): The current epoch.
            global_step (int): The current global step.

        Side Effects:
            - Model weights are updated.
            - Logs the statistics to the accelerator trackers.
            - If `self.image_samples_callback` is not None, it will be called with the prompt_image_pairs,
              global_step, and the accelerator tracker.

        Returns:
            tuple[int, float]: The updated global step and the mean reward over the gathered samples.
        """
        samples, prompt_image_data = self._generate_samples(
            iterations=self.config.sample_num_batches_per_epoch,
            batch_size=self.config.sample_batch_size,
        )

        # collate samples into a dict where each entry has shape (num_batches_per_epoch * sample_batch_size, ...)
        local_rank = self.accelerator.local_process_index
        samples = {k: torch.cat([s[k] for s in samples]) for k in samples[0].keys()}

        rewards, rewards_metadata = self.compute_rewards(prompt_image_data)
        for i, image_data in enumerate(prompt_image_data):
            image_data.extend([rewards[i], rewards_metadata[i]])

        if self.image_samples_callback is not None and self.accelerator.is_main_process:
            self.image_samples_callback(prompt_image_data, global_step, self.accelerator.trackers[0])

        rewards = torch.cat(rewards)
        rewards = self.accelerator.gather(rewards).cpu().numpy()
        if self.accelerator.is_main_process:
            print(rewards.mean())

        self.accelerator.log(
            {
                "reward": rewards,
                "epoch": epoch,
                "reward_mean": rewards.mean(),
                "reward_std": rewards.std(),
            },
            step=global_step,
        )

        # normalize the gathered rewards into advantages
        advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-8)

        # ungather advantages; keep only the entries corresponding to the samples on this process
        samples["advantages"] = (
            torch.as_tensor(advantages)
            .reshape(self.accelerator.num_processes, -1)[self.accelerator.process_index]
            .to(self.accelerator.device)
        )
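        # Illustrative shape walk-through (example numbers): with 2 processes and 8 local samples,
        # `gather` produced 16 rewards; reshape(2, -1) gives a (2, 8) grid and indexing by
        # process_index recovers the 8 advantages that belong to this process, in their original order.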
| del samples["prompt_ids"] | |
| del samples["text_ids"] | |
| del samples["latent_ids"] | |
| del samples["negative_text_ids"] | |
| total_batch_size, num_timesteps = samples["timesteps"].shape | |
| self.accelerator.wait_for_everyone() | |
| for inner_epoch in range(self.config.train_num_inner_epochs): | |
| # shuffle samples along batch dimension | |
| perm = torch.randperm(total_batch_size, device=self.accelerator.device) | |
| samples = {k: v[perm] for k, v in samples.items()} | |
| # shuffle along time dimension independently for each sample | |
| # still trying to understand the code below | |
| perms = torch.stack( | |
| [torch.randperm(num_timesteps, device=self.accelerator.device) for _ in range(total_batch_size)] | |
| ) | |
| for key in ["timesteps", "latents", "next_latents", "log_probs"]: | |
| samples[key] = samples[key][ | |
| torch.arange(total_batch_size, device=self.accelerator.device)[:, None], | |
| perms, | |
| ] | |
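            # How the fancy indexing works (illustrative): arange(total_batch_size)[:, None] has shape
            # (B, 1) and perms has shape (B, T); broadcasting pairs row i with its own permutation
            # perms[i], so each sample is reordered along time without mixing rows.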
            original_keys = samples.keys()
            original_values = samples.values()
            # rebatch because the user-defined train_batch_size may differ from sample_batch_size
            reshaped_values = [v.reshape(-1, self.config.train_batch_size, *v.shape[1:]) for v in original_values]
            # Transpose the list of reshaped values
            transposed_values = zip(*reshaped_values)
            # Create a new dictionary for each row of transposed values
            samples_batched = [dict(zip(original_keys, row_values)) for row_values in transposed_values]
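            # Shape sketch (example numbers): if total_batch_size = 8 and train_batch_size = 2, each
            # value of shape (8, ...) becomes (4, 2, ...) and zip(*...) yields 4 mini-batch dicts whose
            # tensors all share the leading train_batch_size dimension.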
            self.edit_pipeline.transformer.train()
            global_step = self._train_batched_samples(inner_epoch, epoch, global_step, samples_batched)

            # ensure an optimization step was performed by the end of the inner epoch
            if not self.accelerator.sync_gradients:
                raise ValueError(
                    "Optimization step should have been performed by this point. Please check calculated gradient accumulation settings."
                )
            if self.accelerator.sync_gradients:
                if self.accelerator.is_main_process:
                    print("Save checkpoint on epoch", epoch)
                save_model = self.edit_pipeline.flux_pipeline.transformer
                unwrapped_model = self.accelerator.unwrap_model(save_model)
                unwrapped_model.save_pretrained(
                    f"{self.project_dir}/checkpoint_lora-{epoch}",
                    is_main_process=self.accelerator.is_main_process,
                    save_function=self.accelerator.save,
                    state_dict=self.accelerator.get_state_dict(save_model),
                )
                self.accelerator.wait_for_everyone()

        return global_step, rewards.mean()

    def calculate_loss(
        self,
        latents,
        image_latents,
        timestep,
        next_latents,
        log_probs,
        advantages,
        pooled_prompt_embeds,
        prompt_embeds,
        negative_pooled_prompt_embeds,
        negative_prompt_embeds,
    ):
        """
        Calculate the loss for a batch of an unpacked sample.

        Args:
            latents (torch.Tensor):
                The latents sampled from the diffusion model, shape: [batch_size, num_channels_latents, height, width]
            image_latents (torch.Tensor):
                The latents of the conditioning (input) image, concatenated to `latents` along the channel dimension
            timestep (torch.Tensor):
                The timesteps sampled from the diffusion model, shape: [batch_size]
            next_latents (torch.Tensor):
                The next latents sampled from the diffusion model, shape: [batch_size, num_channels_latents, height, width]
            log_probs (torch.Tensor):
                The log probabilities of the latents, shape: [batch_size]
            advantages (torch.Tensor):
                The advantages of the latents, shape: [batch_size]
            pooled_prompt_embeds / prompt_embeds (torch.Tensor):
                The (pooled) embeddings of the prompts
            negative_pooled_prompt_embeds / negative_prompt_embeds (torch.Tensor):
                The (pooled) embeddings of the negative prompts, only used when `train_cfg` is True

        Returns:
            loss (torch.Tensor), approx_kl (torch.Tensor), clipfrac (torch.Tensor) (all of these are of shape (1,))
        """
        torch.autograd.set_detect_anomaly(True)
        with self.autocast():
            latent_model_input = torch.cat([latents, image_latents], dim=1)
            latent_model_input = latent_model_input.detach()
            pooled_prompt_embeds = pooled_prompt_embeds.detach()
            prompt_embeds = prompt_embeds.detach()

            guidance = torch.full(
                [1], self.config.sample_guidance, device=self.edit_pipeline.transformer.device, dtype=torch.bfloat16
            )
            guidance = guidance.expand(latent_model_input.shape[0])

            noise_pred = self.edit_pipeline.transformer(
                hidden_states=latent_model_input,
                timestep=timestep.detach() / 1000,
                guidance=guidance.detach(),
                pooled_projections=pooled_prompt_embeds,
                encoder_hidden_states=prompt_embeds,
                txt_ids=self.text_ids.detach(),
                img_ids=self.latent_ids.detach(),
                return_dict=False,
            )[0]
            noise_pred = noise_pred[:, : latents.size(1)]

            if self.config.train_cfg:
                neg_noise_pred = self.edit_pipeline.transformer(
                    hidden_states=latent_model_input,
                    timestep=timestep / 1000,
                    # the transformer expects a guidance tensor, so we reuse the one built above
                    guidance=guidance.detach(),
                    pooled_projections=negative_pooled_prompt_embeds,
                    encoder_hidden_states=negative_prompt_embeds,
                    txt_ids=self.negative_text_ids,
                    img_ids=self.latent_ids,
                    return_dict=False,
                )[0]
                neg_noise_pred = neg_noise_pred[:, : latents.size(1)]
                # True classifier-free guidance: steer the positive prediction away from the negative one.
                # NOTE: `true_cfg_scale` was undefined in the original code; we assume here that the
                # sampling guidance scale is reused for it.
                true_cfg_scale = self.config.sample_guidance
                noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)

            # compute the log prob of next_latents given latents under the current model
            scheduler_step_output = self.edit_pipeline.scheduler.step(
                noise_pred,
                timestep.detach(),
                latents.detach(),
                prev_sample=next_latents.detach(),
                return_dict=True,
                init_step=True,
            )
            log_prob = scheduler_step_output.log_probs

        advantages = torch.clamp(
            advantages,
            -self.config.train_adv_clip_max,
            self.config.train_adv_clip_max,
        )
        ratio = torch.exp(log_prob - log_probs)
        loss = self.loss(advantages, self.config.train_clip_range, ratio)
        approx_kl = 0.5 * torch.mean((log_prob - log_probs) ** 2)
        clipfrac = torch.mean((torch.abs(ratio - 1.0) > self.config.train_clip_range).float())
        return loss, approx_kl, clipfrac

    def loss(
        self,
        advantages: torch.Tensor,
        clip_range: float,
        ratio: torch.Tensor,
    ):
        unclipped_loss = -advantages * ratio
        clipped_loss = -advantages * torch.clamp(
            ratio,
            1.0 - clip_range,
            1.0 + clip_range,
        )
        return torch.mean(torch.maximum(unclipped_loss, clipped_loss))
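    # Worked example of the clipped surrogate (illustrative numbers): with advantage = 1.0,
    # clip_range = 0.2 and ratio = 1.5, the unclipped term is -1.5 and the clipped term is -1.2;
    # torch.maximum picks -1.2, so the update cannot profit from pushing the ratio past 1 + clip_range.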

    def _generate_samples(self, iterations, batch_size):
        """
        Generate samples from the model.

        Args:
            iterations (int): Number of iterations to generate samples for
            batch_size (int): Batch size to use for sampling

        Returns:
            samples (list[dict[str, torch.Tensor]]), prompt_image_pairs (list[list[Any]])
        """
        samples = []
        prompt_image_pairs = []
        self.edit_pipeline.transformer.eval()

        sample_neg_prompt_embeds = self.negative_prompt_embeds.repeat(batch_size, 1, 1)
        sample_neg_pooled_prompt_embeds = self.negative_pooled_prompt_embeds.repeat(batch_size, 1)
        sample_neg_text_ids = self.negative_text_ids

        for iters in range(iterations):
            prompts, raw_images, img_paths = map(list, zip(*[self.prompt_fn('multi') for _ in range(batch_size)]))
            if len(raw_images) == batch_size:
                # transpose from per-sample lists of conditioning images to per-image lists of samples
                raw_images = list(map(list, zip(*raw_images)))

            (
                prompt_embeds,
                pooled_prompt_embeds,
                text_ids,
            ) = self.edit_pipeline.flux_pipeline.encode_prompt(
                prompt=prompts,
                prompt_2=prompts,
                device=self.accelerator.device,
            )
            prompt_ids = self.edit_pipeline.tokenizer(
                prompts,
                padding="max_length",
                max_length=self.edit_pipeline.flux_pipeline.tokenizer_max_length,
                truncation=True,
                return_tensors="pt",
            ).input_ids.to(self.accelerator.device)

            generator = torch.Generator(device=self.accelerator.device)
            generator.seed()

            with self.autocast():
                with torch.no_grad():
                    edit_output = self.edit_pipeline(
                        image=raw_images,
                        height=self.config.height,
                        width=self.config.width,
                        prompt_embeds=prompt_embeds,
                        pooled_prompt_embeds=pooled_prompt_embeds,
                        negative_prompt_embeds=sample_neg_prompt_embeds,
                        negative_pooled_prompt_embeds=sample_neg_pooled_prompt_embeds,
                        num_inference_steps=self.config.sample_num_steps,
                        guidance_scale=self.config.sample_guidance,
                        generator=generator,
                        output_type="pt",
                        max_area=self.config.max_size**2,
                    )

            images = edit_output.images
            latents = edit_output.latents
            log_probs = edit_output.log_probs
            timesteps = edit_output.timesteps
            latent_ids = edit_output.latent_ids
            image_latents = edit_output.image_latents

            latents = torch.stack(latents, dim=1)      # (batch_size, num_steps + 1, ...)
            log_probs = torch.stack(log_probs, dim=1)  # (batch_size, num_steps, 1)
            timesteps = torch.stack(timesteps, dim=1)

            samples.append(
                {
                    "prompt_ids": prompt_ids.float(),
                    "timesteps": timesteps[:, :-1],
                    "latents": latents[:, :-2],        # each entry is the latent before timestep t
                    "next_latents": latents[:, 1:-1],  # each entry is the latent after timestep t
                    "log_probs": log_probs[:, :-1],
                    "pooled_prompt_embeds": pooled_prompt_embeds,
                    "prompt_embeds": prompt_embeds,
                    "negative_prompt_embeds": sample_neg_prompt_embeds,
                    "negative_pooled_prompt_embeds": sample_neg_pooled_prompt_embeds,
                    "text_ids": text_ids,
                    "latent_ids": latent_ids,
                    "negative_text_ids": sample_neg_text_ids,
                    "image_latents": image_latents,
                }
            )
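            # Trajectory alignment sketch (illustrative numbers): with num_steps = 28 the pipeline
            # returns 29 latents x_0..x_28; latents[:, :-2] keeps x_0..x_26 and next_latents keeps
            # x_1..x_27, so entry t pairs the state before and after transition t, and the final
            # transition is dropped to match num_train_timesteps.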
            raw_images = [list(x) for x in zip(*raw_images)]
            prompt_image_pairs.append([images, prompts, raw_images, img_paths])

        local_rank = self.accelerator.local_process_index
        self.text_ids = samples[0]['text_ids']
        self.latent_ids = samples[0]['latent_ids']
        self.negative_text_ids = samples[0]['negative_text_ids']
        return samples, prompt_image_pairs

    def _train_batched_samples(self, inner_epoch, epoch, global_step, batched_samples):
        """
        Train on a batch of samples. Main training segment.

        Args:
            inner_epoch (int): The current inner epoch
            epoch (int): The current epoch
            global_step (int): The current global step
            batched_samples (list[dict[str, torch.Tensor]]): The batched samples to train on

        Side Effects:
            - Model weights are updated.
            - Logs the statistics to the accelerator trackers.

        Returns:
            global_step (int): The updated global step
        """
        info = defaultdict(list)
        for _i, sample in enumerate(batched_samples):
            for j in trange(self.num_train_timesteps):
                with self.accelerator.accumulate(self.edit_pipeline.transformer):
                    loss, approx_kl, clipfrac = self.calculate_loss(
                        sample["latents"][:, j],
                        sample["image_latents"],
                        sample["timesteps"][:, j],
                        sample["next_latents"][:, j],
                        sample["log_probs"][:, j],
                        sample["advantages"],
                        sample["pooled_prompt_embeds"],
                        sample["prompt_embeds"],
                        sample["negative_pooled_prompt_embeds"],
                        sample["negative_prompt_embeds"],
                    )
                    info["approx_kl"].append(approx_kl)
                    info["clipfrac"].append(clipfrac)
                    info["loss"].append(loss)

                    self.accelerator.backward(loss)
                    if self.accelerator.sync_gradients:
                        self.accelerator.clip_grad_norm_(
                            self.trainable_layers.parameters()
                            if not isinstance(self.trainable_layers, list)
                            else self.trainable_layers,
                            self.config.train_max_grad_norm,
                        )
                    self.optimizer.step()
                    self.optimizer.zero_grad()

                # Checks if the accelerator has performed an optimization step behind the scenes
                if self.accelerator.sync_gradients:
                    # log training-related statistics
                    info = {k: torch.mean(torch.stack(v)) for k, v in info.items()}
                    info = self.accelerator.reduce(info, reduction="mean")
                    info.update({"epoch": epoch, "inner_epoch": inner_epoch})
                    self.accelerator.log(info, step=global_step)
                    global_step += 1
                    info = defaultdict(list)
        return global_step

    def _config_check(self) -> tuple[bool, str]:
        samples_per_epoch = (
            self.config.sample_batch_size * self.accelerator.num_processes * self.config.sample_num_batches_per_epoch
        )
        total_train_batch_size = (
            self.config.train_batch_size
            * self.accelerator.num_processes
            * self.config.train_gradient_accumulation_steps
        )

        if not self.config.sample_batch_size >= self.config.train_batch_size:
            return (
                False,
                f"Sample batch size ({self.config.sample_batch_size}) must be greater than or equal to the train batch size ({self.config.train_batch_size})",
            )
        if not self.config.sample_batch_size % self.config.train_batch_size == 0:
            return (
                False,
                f"Sample batch size ({self.config.sample_batch_size}) must be divisible by the train batch size ({self.config.train_batch_size})",
            )
        if not samples_per_epoch % total_train_batch_size == 0:
            return (
                False,
                f"Number of samples per epoch ({samples_per_epoch}) must be divisible by the total train batch size ({total_train_batch_size})",
            )
        return True, ""
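    # Example of a valid configuration (illustrative numbers only): sample_batch_size = 8,
    # train_batch_size = 2, 2 processes, sample_num_batches_per_epoch = 4 and
    # train_gradient_accumulation_steps = 4 give samples_per_epoch = 64 and
    # total_train_batch_size = 16, so all three checks above pass.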

    def train(self, epochs: Optional[int] = None):
        """
        Train the model for a given number of epochs (defaults to `config.num_epochs`) and
        return the per-epoch mean rewards.
        """
        global_step = 0
        rewards_curve = []
        if epochs is None:
            epochs = self.config.num_epochs
        for epoch in range(self.first_epoch, epochs):
            global_step, reward_mean = self.step(epoch, global_step)
            rewards_curve.append(reward_mean)
        return rewards_curve
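

# Minimal usage sketch (assumptions, not part of the original trainer): the prompt_fn / reward_fn
# signatures below are inferred from how this trainer calls them; the DDPOFluxConfig fields and the
# DefaultDDPOFluxPipeline constructor are placeholders that should be adapted to the actual modules.
if __name__ == "__main__":
    def example_prompt_fn(mode):
        # Returns (prompt, [raw_image_a, raw_image_b], image_path); the images and path are placeholders.
        img = Image.new("RGB", (512, 512))
        return "make the sky look like a sunset", [img, img], "example.png"

    def example_reward_fn(raw_img_a, raw_img_b, prompt, edited_image):
        # Returns (reward, metadata); a real reward model (e.g. a VLM-based scorer) would go here.
        return random.random(), {}

    config = DDPOFluxConfig()                   # assumed to be default-constructible
    pipeline = DefaultDDPOFluxPipeline(config)  # assumed constructor, see modeling_flux_base

    trainer = DDPOTrainer_edit(
        config=config,
        reward_function=example_reward_fn,
        prompt_function=example_prompt_fn,
        edit_pipeline=pipeline,
    )
    rewards_curve = trainer.train()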