Spaces:
Sleeping
Sleeping
| import torch | |
| from torch.cuda.amp import autocast | |
| import numpy as np | |
| import time | |
| import os | |
| import yaml | |
| from matplotlib import pyplot as plt | |
| import glob | |
| from collections import OrderedDict | |
| from tqdm import tqdm | |
| import torch.distributed as dist | |
| import pandas as pd | |
| import xgboost as xgb | |
| from sklearn.metrics import accuracy_score, classification_report, roc_auc_score | |
| from torch.nn import ModuleList | |
| # from inr import INR | |
| # from kan import FasterKAN | |
class Trainer(object):
    """
    A class that encapsulates the training loop for a PyTorch model.

    The default batch handlers read ``batch['audio']['array']``,
    ``batch['audio']['fft_mag']``, ``batch['audio']['features']`` and
    ``batch['label']``. Model outputs are treated as binary logits: they are
    passed through a sigmoid and thresholded at 0.5. Subclasses customise the
    per-batch behaviour by overriding ``train_batch`` / ``eval_batch``.
    """

    def __init__(self, model, optimizer, criterion, train_dataloader, device, world_size=1, output_dim=2,
                 scheduler=None, val_dataloader=None, max_iter=np.inf, scaler=None,
                 grad_clip=False, exp_num=None, log_path=None, exp_name=None, plot_every=None,
                 cos_inc=False, range_update=None, accumulation_step=1, wandb_log=False, num_quantiles=1,
                 update_func=lambda x: x):
        """
        Store the training components on the instance.

        Args:
            model: module being trained.
            optimizer: optimizer stepped once per training batch.
            criterion: loss called as ``criterion(y_pred, y)``.
            train_dataloader: iterable of training batches.
            device: device used to load checkpoints when ``to_ddp`` is False.
            output_dim: length of the per-epoch accuracy vector.
            scheduler: optional LR scheduler, stepped once per training batch.
            val_dataloader: iterable of validation batches.
            max_iter: an epoch stops once the batch index exceeds this value.
            exp_num, log_path, exp_name: used to build checkpoint paths.

        The remaining arguments (``world_size``, ``scaler``, ``grad_clip``,
        ``plot_every``, ``cos_inc``, ``range_update``, ``accumulation_step``,
        ``wandb_log``, ``num_quantiles``, ``update_func``) are only stored;
        no method of this class reads them.
        """
        self.model = model
        self.optimizer = optimizer
        self.criterion = criterion
        self.scaler = scaler
        self.grad_clip = grad_clip
        self.cos_inc = cos_inc
        self.output_dim = output_dim
        self.scheduler = scheduler
        self.train_dl = train_dataloader
        self.val_dl = val_dataloader
        self.train_sampler = self.get_sampler_from_dataloader(train_dataloader)
        self.val_sampler = self.get_sampler_from_dataloader(val_dataloader)
        self.max_iter = max_iter
        self.device = device
        self.world_size = world_size
        self.exp_num = exp_num
        self.exp_name = exp_name
        self.log_path = log_path
        self.best_state_dict = None
        self.plot_every = plot_every
        self.logger = None
        self.range_update = range_update
        self.accumulation_step = accumulation_step
        self.wandb = wandb_log
        self.num_quantiles = num_quantiles
        self.update_func = update_func

    def get_sampler_from_dataloader(self, dataloader):
        """
        Return the sampler attached to ``dataloader``, or ``None``.

        Lookup order: a ``DistributedSampler`` directly on ``.sampler``; a
        sampler nested one level down (``.sampler.sampler``); the sampler of
        ``.batch_sampler``.
        """
        if hasattr(dataloader, 'sampler'):
            if isinstance(dataloader.sampler, torch.utils.data.DistributedSampler):
                return dataloader.sampler
            elif hasattr(dataloader.sampler, 'sampler'):
                return dataloader.sampler.sampler
        if hasattr(dataloader, 'batch_sampler') and hasattr(dataloader.batch_sampler, 'sampler'):
            return dataloader.batch_sampler.sampler
        return None

    def _save_checkpoint(self):
        """Write the current weights to disk and keep an independent in-memory copy."""
        import copy

        model_name = f'{self.log_path}/{self.exp_num}/{self.exp_name}.pth'
        print(f"saving model at {model_name}...")
        # torch.save does not create missing parent directories.
        os.makedirs(os.path.dirname(model_name), exist_ok=True)
        state_dict = self.model.state_dict()
        torch.save(state_dict, model_name)
        # state_dict() hands back references to the live parameters; without a
        # copy the "best" weights would silently follow every later update.
        self.best_state_dict = copy.deepcopy(state_dict)

    def fit(self, num_epochs, device, early_stopping=None, only_p=False, best='loss', conf=False):
        """
        Fits the model for the given number of epochs.

        Args:
            num_epochs: maximum number of epochs to run.
            device: device forwarded to the per-epoch methods.
            early_stopping: stop once this many consecutive epochs brought no
                improvement; ``None`` disables early stopping.
            only_p: unused.
            best: ``'loss'`` selects the checkpoint by lowest validation loss,
                any other value by highest mean validation accuracy.
            conf: unused.

        Returns:
            dict with ``num_epochs``, per-batch ``train_loss`` / ``val_loss``
            lists, per-epoch ``train_acc`` / ``val_acc`` lists and ``lrs``.
        """
        min_loss = np.inf
        best_acc = 0
        train_loss, val_loss = [], []
        train_acc, val_acc = [], []
        lrs = []
        epochs_without_improvement = 0
        main_proccess = True  # change in a ddp setting
        print(f"Starting training for {num_epochs} epochs")
        print("is main process: ", main_proccess, flush=True)
        global_time = time.time()
        self.epoch = 0
        for epoch in range(num_epochs):
            self.epoch = epoch
            start_time = time.time()

            t_loss, t_acc = self.train_epoch(device, epoch=epoch)
            t_loss_mean = np.nanmean(t_loss)
            train_loss.extend(t_loss)
            global_train_accuracy, global_train_loss = self.process_loss(t_acc, t_loss_mean)
            if main_proccess:  # Only perform this on the master GPU
                train_acc.append(global_train_accuracy.mean().item())

            v_loss, v_acc = self.eval_epoch(device, epoch=epoch)
            v_loss_mean = np.nanmean(v_loss)
            val_loss.extend(v_loss)
            global_val_accuracy, global_val_loss = self.process_loss(v_acc, v_loss_mean)
            if main_proccess:  # Only perform this on the master GPU
                val_acc.append(global_val_accuracy.mean().item())

                if best == 'loss':
                    improved = global_val_loss < min_loss
                    if improved:
                        min_loss = global_val_loss
                else:
                    current_objective = global_val_accuracy.mean()
                    improved = current_objective > best_acc
                    if improved:
                        best_acc = current_objective

                if improved:
                    self._save_checkpoint()
                    epochs_without_improvement = 0
                else:
                    epochs_without_improvement += 1

                if self.scheduler is None:
                    current_lr = self.optimizer.param_groups[0]['lr']
                else:
                    current_lr = self.scheduler.get_last_lr()[0]
                lrs.append(current_lr)

                print(f'Epoch {epoch}, lr {current_lr}, Train Loss: {global_train_loss:.6f}, Val Loss:'
                      f'{global_val_loss:.6f}, Train Acc: {global_train_accuracy.round(decimals=4).tolist()}, '
                      f'Val Acc: {global_val_accuracy.round(decimals=4).tolist()},'
                      f'Time: {time.time() - start_time:.2f}s, Total Time: {(time.time() - global_time)/3600} hr',
                      flush=True)
                if epoch % 10 == 0:
                    # nvidia-smi writes to stdout itself; the printed value is its exit status.
                    print(os.system('nvidia-smi'))
                if epochs_without_improvement == early_stopping:
                    print('early stopping!', flush=True)
                    break
            # Hard wall-clock budget of 23.83 hours for the whole run.
            if time.time() - global_time > (23.83 * 3600):
                print("time limit reached")
                break
        return {"num_epochs": num_epochs, "train_loss": train_loss,
                "val_loss": val_loss, "train_acc": train_acc, "val_acc": val_acc, "lrs": lrs}

    def process_loss(self, acc, loss_mean):
        """
        Turn an epoch's accuracy and mean loss into tensors.

        When CUDA is available and ``torch.distributed`` is initialised, both
        values are sum-reduced onto rank 0 and divided by the world size;
        otherwise they are returned as plain tensors.

        Returns:
            ``(global_accuracy, global_loss)``.
        """
        # torch.tensor(existing_tensor) emits a copy-construct warning every
        # epoch; clone tensors explicitly and only wrap non-tensor inputs.
        if torch.is_tensor(acc):
            global_accuracy = acc.detach().clone()
        else:
            global_accuracy = torch.tensor(acc)
        global_loss = torch.tensor(loss_mean)

        if torch.cuda.is_available() and torch.distributed.is_initialized():
            global_accuracy = global_accuracy.cuda()
            torch.distributed.reduce(global_accuracy, dst=0, op=torch.distributed.ReduceOp.SUM)
            global_loss = global_loss.cuda()
            torch.distributed.reduce(global_loss, dst=0, op=torch.distributed.ReduceOp.SUM)
            # Divide both loss and accuracy by world size
            world_size = torch.distributed.get_world_size()
            global_loss /= world_size
            global_accuracy /= world_size
        return global_accuracy, global_loss

    def load_best_model(self, to_ddp=True, from_ddp=True):
        """
        Load a ``.pth`` checkpoint into ``self.model`` (non-strict).

        ``{log_path}/exp{exp_num}`` is searched first; if it holds no
        checkpoint, ``{log_path}/{exp_num}`` (the directory ``fit`` saves to)
        is searched as well.

        Args:
            to_ddp: when True the last checkpoint (sorted by name) is loaded
                without ``map_location``; when False the first one is loaded
                onto ``self.device``.
            from_ddp: strip leading ``module.`` prefixes from parameter names.

        Raises:
            FileNotFoundError: if neither directory contains a ``.pth`` file.
        """
        search_dirs = [f'{self.log_path}/exp{self.exp_num}', f'{self.log_path}/{self.exp_num}']
        state_dict_files = []
        for data_dir in search_dirs:
            # Sorted so that "first" and "last" are deterministic.
            state_dict_files = sorted(glob.glob(data_dir + '/*.pth'))
            if state_dict_files:
                break
        if not state_dict_files:
            raise FileNotFoundError(f"no .pth checkpoint found in any of {search_dirs}")

        if to_ddp:
            checkpoint_path = state_dict_files[-1]
            print("loading model from ", checkpoint_path)
            state_dict = torch.load(checkpoint_path)
        else:
            checkpoint_path = state_dict_files[0]
            print("loading model from ", checkpoint_path)
            state_dict = torch.load(checkpoint_path, map_location=self.device)

        if from_ddp:
            print("loading distributed model")
            # Remove "module." from keys
            prefix = 'module.'
            new_state_dict = OrderedDict()
            for key, value in state_dict.items():
                while key.startswith(prefix):
                    key = key[len(prefix):]
                new_state_dict[key] = value
            state_dict = new_state_dict
        self.model.load_state_dict(state_dict, strict=False)

    def check_gradients(self):
        """Print every parameter whose gradient norm is larger than 10."""
        for name, param in self.model.named_parameters():
            if param.grad is not None:
                grad_norm = param.grad.norm().item()
                if grad_norm > 10:
                    print(f"Large gradient in {name}: {grad_norm}")

    @staticmethod
    def _stack_signal_and_fft(batch, device):
        """Return ``(x_fft, y)``: signal and FFT magnitude stacked on dim 1, plus float labels."""
        x = batch['audio']['array'].to(device).float()
        fft = batch['audio']['fft_mag'].to(device).float()
        y = batch['label'].to(device).float()
        x_fft = torch.cat((x.unsqueeze(dim=1), fft.unsqueeze(dim=1)), dim=1)
        return x_fft, y

    @staticmethod
    def _classify(y_pred, y):
        """Return ``(cls_pred, n_correct)`` for logits thresholded at probability 0.5."""
        probs = torch.sigmoid(y_pred)
        cls_pred = (probs > 0.5).float()
        return cls_pred, (cls_pred == y).sum()

    def train_epoch(self, device, epoch):
        """
        Trains the model for one epoch.

        Returns:
            ``(train_loss, accuracy)`` where ``train_loss`` is the list of
            per-batch losses and ``accuracy`` is the accumulated correct
            count divided by the number of samples seen.
        """
        # Only samplers that expose set_epoch are reshuffled; errors raised
        # inside set_epoch itself are no longer swallowed.
        set_epoch = getattr(self.train_sampler, 'set_epoch', None)
        if callable(set_epoch):
            set_epoch(epoch)

        self.model.train()
        train_loss = []
        total = 0
        all_accs = torch.zeros(self.output_dim, device=device)
        pbar = tqdm(self.train_dl)
        for i, batch in enumerate(pbar):
            if self.optimizer is not None:
                self.optimizer.zero_grad()
            loss, acc, y = self.train_batch(batch, i, device)
            train_loss.append(loss.item())
            all_accs = all_accs + acc
            total += len(y)
            pbar.set_description(f"train_acc: {acc}, train_loss: {loss.item()}")
            if i > self.max_iter:
                break
        # Previously printed a counter that was never updated (always 0).
        print("number of train_accs: ", total)
        return train_loss, all_accs / total

    def train_batch(self, batch, batch_idx, device):
        """
        Run one optimisation step on ``batch``.

        Returns:
            ``(loss, acc, y)`` where ``acc`` is the number of correct
            predictions in the batch and ``y`` the float labels on ``device``.
        """
        x_fft, y = self._stack_signal_and_fft(batch, device)
        features = batch['audio']['features']  # passed to the model as-is (not moved to device)
        y_pred = self.model(x_fft, features).squeeze()
        loss = self.criterion(y_pred, y)
        loss.backward()
        self.optimizer.step()
        if self.scheduler is not None:
            # The scheduler advances once per batch, not once per epoch.
            self.scheduler.step()
        _, acc = self._classify(y_pred, y)
        return loss, acc, y

    def eval_epoch(self, device, epoch):
        """
        Evaluates the model for one epoch.

        Returns:
            ``(val_loss, accuracy)`` with the same layout as ``train_epoch``.
        """
        self.model.eval()
        val_loss = []
        total = 0
        all_accs = torch.zeros(self.output_dim, device=device)
        pbar = tqdm(self.val_dl)
        for i, batch in enumerate(pbar):
            loss, acc, y = self.eval_batch(batch, i, device)
            val_loss.append(loss.item())
            all_accs = all_accs + acc
            total += len(y)
            pbar.set_description(f"val_acc: {acc}, val_loss: {loss.item()}")
            if i > self.max_iter:
                break
        return val_loss, all_accs / total

    def eval_batch(self, batch, batch_idx, device):
        """Compute loss and correct-prediction count for ``batch`` without gradients."""
        x_fft, y = self._stack_signal_and_fft(batch, device)
        features = batch['audio']['features']
        with torch.no_grad():
            y_pred = self.model(x_fft, features).squeeze()
            loss = self.criterion(y_pred, y)
        _, acc = self._classify(y_pred, y)
        return loss, acc, y

    def predict(self, test_dataloader, device):
        """
        Returns the predictions of the model on the given dataset.

        Returns:
            ``(predictions, true_labels, accuracy)``: two flat lists of
            per-sample values and the fraction of correct predictions.
        """
        self.model.eval()
        total = 0
        all_accs = 0
        predictions = []
        true_labels = []
        pbar = tqdm(test_dataloader)
        for i, batch in enumerate(pbar):
            x_fft, y = self._stack_signal_and_fft(batch, device)
            with torch.no_grad():
                # NOTE(review): train_batch/eval_batch call the model with
                # (x_fft, features) but this call omits features - confirm
                # that the model's forward accepts a single argument.
                y_pred = self.model(x_fft).squeeze()
                # Evaluated as before; the value is not used.
                self.criterion(y_pred, y)
            cls_pred, acc = self._classify(y_pred, y)
            predictions.extend(cls_pred.cpu().numpy())
            true_labels.extend(y.cpu().numpy())
            all_accs += acc
            total += len(y)
            pbar.set_description("acc: {:.4f}".format(acc))
            if i > self.max_iter:
                break
        return predictions, true_labels, all_accs / total
class INRDatabase:
    """Registry of per-sample INR modules and their Adam optimizers, with save/restore to disk."""

    def __init__(self, save_dir='./inr_database'):
        """Create an empty registry and make sure ``save_dir`` exists."""
        os.makedirs(save_dir, exist_ok=True)
        self.save_dir = save_dir
        self.inrs = {}  # sample_id -> INR module
        self.optimizers = {}  # sample_id -> optimizer bound to that INR

    def get_or_create_inr(self, sample_id, create_fn, device):
        """
        Return the ``(inr, optimizer)`` pair registered under ``sample_id``.

        If the id is unknown, ``create_fn()`` builds a new module which is
        moved to ``device``, paired with a default-configured Adam optimizer
        and registered before being returned.
        """
        if sample_id in self.inrs:
            return self.inrs[sample_id], self.optimizers[sample_id]

        network = create_fn().to(device)
        adam = torch.optim.Adam(network.parameters())
        self.inrs[sample_id] = network
        self.optimizers[sample_id] = adam
        return network, adam

    def set_inr(self, sample_id, inr, optimizer):
        """Register (or overwrite) the module and optimizer stored for ``sample_id``."""
        self.optimizers[sample_id] = optimizer
        self.inrs[sample_id] = inr

    def save_state(self):
        """Write the state dicts of every module and optimizer to ``inr_database.pt``."""
        inr_states = {}
        for sample_id, network in self.inrs.items():
            inr_states[sample_id] = network.state_dict()

        optimizer_states = {}
        for sample_id, optimizer in self.optimizers.items():
            optimizer_states[sample_id] = optimizer.state_dict()

        target = os.path.join(self.save_dir, 'inr_database.pt')
        torch.save({'inrs': inr_states, 'optimizers': optimizer_states}, target)

    def load_state(self, create_fn, device):
        """
        Restore modules and optimizers from ``inr_database.pt`` if it exists.

        Each module is rebuilt with ``create_fn()`` on ``device`` before its
        weights are loaded; each optimizer is a fresh Adam over the restored
        module's parameters with the saved state loaded into it. Nothing
        happens when the file is missing.
        """
        path = os.path.join(self.save_dir, 'inr_database.pt')
        if not os.path.exists(path):
            return

        state = torch.load(path, map_location=device)

        # Modules first: the optimizers below are built over their parameters.
        for sample_id, weights in state['inrs'].items():
            network = create_fn().to(device)
            network.load_state_dict(weights)
            self.inrs[sample_id] = network

        for sample_id, saved in state['optimizers'].items():
            adam = torch.optim.Adam(self.inrs[sample_id].parameters())
            adam.load_state_dict(saved)
            self.optimizers[sample_id] = adam
class INRTrainer(Trainer):
    """
    Trainer that fits one INR per sample and classifies from the INR weights.

    Every sample gets its own INR (kept in an ``INRDatabase``) trained to
    reproduce ``batch['audio']['array']`` from ``batch['audio']['coords']``.
    The weights and biases of the INRs in a batch are stacked and passed to
    ``self.model`` as ``inputs=(weights, biases)``.
    """

    def __init__(self, hidden_features=128, n_layers=3, in_features=1, out_features=1,
                 num_steps=5000, lr=1e-3, inr_criterion=torch.nn.MSELoss(), save_dir='./inr_database', *args, **kwargs):
        """
        Args:
            hidden_features, n_layers, in_features, out_features: forwarded
                to the INR / KAN factories.
            num_steps: stored only; ``train_inr`` uses its own ``num_iters``.
            lr: stored only; the database builds Adam with its default rate.
            inr_criterion: reconstruction loss for the per-sample INRs.
            save_dir: directory of the persistent INR database.
            *args, **kwargs: forwarded to ``Trainer.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.hidden_features = hidden_features
        self.n_layers = n_layers
        self.in_features = in_features
        self.out_features = out_features
        self.num_steps = num_steps
        self.lr = lr
        self.inr_criterion = inr_criterion
        # Initialize INR database
        self.db = INRDatabase(save_dir)
        # Load existing INRs if available
        self.db.load_state(self.create_inr, self.device)

    def create_inr(self):
        """Factory function to create new INR instances."""
        # NOTE(review): the `from inr import INR` import at the top of the
        # file is commented out, so INR must be brought into scope elsewhere.
        return INR(
            hidden_features=self.hidden_features,
            n_layers=self.n_layers,
            in_features=self.in_features,
            out_features=self.out_features
        )

    def create_kan(self):
        """Factory function to create a FasterKAN with ``n_layers`` hidden layers."""
        # NOTE(review): the `from kan import FasterKAN` import is commented out as well.
        return FasterKAN(layers_hidden=[self.in_features] + [self.hidden_features] * (self.n_layers) + [self.out_features],)

    def get_sample_id(self, batch, idx):
        """Extract unique identifier for a sample in the batch.

        Uses ``batch['id'][idx]`` when the batch carries ids (one-element
        tensors are converted to Python scalars, because tensors hash by
        object identity and would never match a stored key). Otherwise a
        64-bit digest of the sample's raw audio bytes is returned.
        Override this method based on your data structure.
        """
        if 'id' in batch:
            sample_id = batch['id'][idx]
            if torch.is_tensor(sample_id) and sample_id.numel() == 1:
                sample_id = sample_id.item()
            return sample_id

        # Built-in hash() of bytes is salted per interpreter process, so ids
        # derived from it would not match the database after a restart.
        import hashlib

        sample_data = batch['audio']['array'][idx]
        raw = sample_data.detach().cpu().numpy().tobytes()
        return int.from_bytes(hashlib.blake2b(raw, digest_size=8).digest(), 'big')

    def train_inr(self, optimizer, model, coords, values, num_iters=10, plot=False):
        """
        Run ``num_iters`` optimisation steps of ``model`` on one sample.

        Args:
            optimizer: optimizer bound to ``model``.
            model: the sample's INR.
            coords: input coordinates (moved to ``self.device``).
            values: target signal compared against the squeezed prediction.
            num_iters: number of gradient steps.
            plot: when True, plot target vs. last prediction.

        Returns:
            ``(model, optimizer)``.
        """
        coords = coords.to(self.device)
        pred_values = None
        loss = None
        for _ in range(num_iters):
            optimizer.zero_grad()
            pred_values = model(coords).float()
            loss = self.inr_criterion(pred_values.squeeze(), values)
            loss.backward()
            optimizer.step()
        # Nothing to plot if no step was taken.
        if plot and pred_values is not None:
            plt.plot(values.cpu().detach().numpy())
            plt.plot(pred_values.cpu().detach().numpy())
            plt.title(loss.item())
            plt.show()
        return model, optimizer

    def train_batch(self, batch, batch_idx, device):
        """Train INRs for each sample in batch, persisting progress.

        After the per-sample INRs are updated, their weights and biases are
        stacked along a new leading batch axis, in the same order as the
        labels, and the classifier takes one optimisation step on them. The
        database is written to disk every 10th batch.

        Returns:
            ``(loss, acc, y)`` with ``acc`` the number of correct predictions.
        """
        coords = batch['audio']['coords'].to(device)  # [B, N, 1]
        values = batch['audio']['array'].to(device)  # [B, N]
        y = batch['label'].to(device).float()
        batch_size = coords.shape[0]

        per_sample_weights = []
        per_sample_biases = []
        for i in range(batch_size):
            sample_id = self.get_sample_id(batch, i)
            inr, optimizer = self.db.get_or_create_inr(sample_id, self.create_inr, device)
            inr, optimizer = self.train_inr(optimizer, inr, coords[i], values[i])
            self.db.set_inr(sample_id, inr, optimizer)

            state_dict = inr.state_dict()
            # Each weight is transposed and given a leading batch axis and a
            # trailing singleton axis; each bias gets the same two extra axes.
            per_sample_weights.append(
                [v.permute(1, 0).unsqueeze(-1).unsqueeze(0).to(device)
                 for name, v in state_dict.items() if "weight" in name]
            )
            per_sample_biases.append(
                [v.unsqueeze(-1).unsqueeze(0).to(device)
                 for name, v in state_dict.items() if "bias" in name]
            )

        # Concatenate once per layer, keeping sample order. The old code
        # prepended each new sample, so row k of the stacked weights belonged
        # to sample B-1-k while the loss compared it with label k.
        batch_weights = tuple(torch.cat(layer, dim=0) for layer in zip(*per_sample_weights))
        batch_biases = tuple(torch.cat(layer, dim=0) for layer in zip(*per_sample_biases))

        y_pred = self.model(inputs=(batch_weights, batch_biases)).squeeze()
        loss_preds = self.criterion(y_pred, y)
        self.optimizer.zero_grad()
        loss_preds.backward()
        self.optimizer.step()

        if batch_idx % 10 == 0:  # Adjust frequency as needed
            self.db.save_state()

        probs = torch.sigmoid(y_pred)
        cls_pred = (probs > 0.5).float()
        acc = (cls_pred == y).sum()
        return loss_preds, acc, y

    def eval_batch(self, batch, batch_idx, device):
        """Evaluate INRs for each sample in batch.

        Returns the mean reconstruction loss of the per-sample INRs, an
        all-zero accuracy vector, and the audio targets. The classifier is
        not run.
        NOTE(review): samples missing from the database get a freshly
        created, untrained INR that is also registered in the database.
        """
        coords = batch['audio']['coords'].to(device)
        values = batch['audio']['array'].to(device)
        batch_size = coords.shape[0]

        # Get INRs for each sample
        batch_inrs = []
        for i in range(batch_size):
            sample_id = self.get_sample_id(batch, i)
            inr, _ = self.db.get_or_create_inr(sample_id, self.create_inr, device)
            batch_inrs.append(inr)

        # Evaluate
        with torch.no_grad():
            all_preds = torch.stack([inr(coords[i]) for i, inr in enumerate(batch_inrs)])
            # Reconstruction is scored with the loss the INRs are trained
            # with (inr_criterion), not the classification criterion.
            batch_losses = torch.stack([
                self.inr_criterion(all_preds[i].squeeze(), values[i])
                for i in range(batch_size)
            ])
            avg_loss = batch_losses.mean().item()

        acc = torch.zeros(self.output_dim, device=device)
        y = values
        return torch.tensor(avg_loss), acc, y
def verify_parallel_gradient_isolation(trainer, batch_size=4, sequence_length=1000):
    """
    Check whether one training step changes the INRs independently.

    A synthetic batch of sinusoids is built, the parameters of the INRs
    returned by ``trainer.create_batch_inrs()`` are snapshotted, and
    ``trainer.train_batch`` is run once. For every parameter of INR ``i``
    that moved, the same-named parameter of every other INR is required to
    be unchanged; otherwise a warning is printed.

    NOTE(review): ``create_batch_inrs`` is not defined by the trainers in
    this file, and the synthetic batch has no ``'label'`` entry - confirm
    which trainer this helper is meant for.

    Returns:
        True if no cross-INR change was detected, False otherwise.
    """
    device = trainer.device

    # Coordinate grid on [0, 1], replicated for every sample: [B, N, 1].
    grid = torch.linspace(0, 1, sequence_length).unsqueeze(-1)
    coords = grid.unsqueeze(0).repeat(batch_size, 1, 1)

    # One sinusoid per index k, with frequency k + 1.
    flat_coords = coords.squeeze(-1)
    waves = []
    for k in range(batch_size):
        waves.append(torch.sin(2 * torch.pi * (k + 1) * flat_coords))
    targets = torch.stack(waves).to(device)

    inrs = trainer.create_batch_inrs()

    # Detached copies of every parameter before the training step.
    snapshots = []
    for inr in inrs:
        snapshots.append({name: param.clone().detach() for name, param in inr.named_parameters()})

    batch = {
        'audio': {
            'coords': coords.to(device),
            'fft_mag': targets.clone(),
            'array': targets.clone()
        }
    }
    trainer.train_batch(batch, 0, device)

    isolation_verified = True
    for i, inr in enumerate(inrs):
        for name, param in inr.named_parameters():
            if torch.allclose(param, snapshots[i][name]):
                continue  # this parameter did not move; nothing to cross-check
            for j, other_inr in enumerate(inrs):
                if j == i:
                    continue
                other_param = dict(other_inr.named_parameters())[name]
                if torch.allclose(other_param, snapshots[j][name]):
                    continue
                isolation_verified = False
                print(f"Warning: Parameter {name} of INR {j} changed when only INR {i} should have changed")
    return isolation_verified
class XGBoostTrainer():
    """
    Train and evaluate an XGBoost binary classifier on tabular audio features.

    Each dataset is flattened into a DataFrame (cached as a CSV under
    ``tasks/utils/dfs``) and wrapped in an ``xgb.DMatrix``.
    """

    def __init__(self, model_args, train_ds, val_ds, test_ds):
        """
        Args:
            model_args: dict of XGBoost parameters; used to build the initial
                ``XGBClassifier`` and merged into the ``xgb.train`` params.
            train_ds, val_ds, test_ds: datasets read by ``create_dataframe``.
        """
        self.train_ds = train_ds
        self.val_ds = val_ds
        self.test_ds = test_ds
        print("creating train dataframe...")
        self.x_train, self.y_train = self.create_dataframe(train_ds, save_name='train')
        print("creating validation dataframe...")
        self.x_val, self.y_val = self.create_dataframe(val_ds, save_name='val')
        print("creating test dataframe...")
        self.x_test, self.y_test = self.create_dataframe(test_ds, save_name='test')
        # Convert the data to DMatrix format
        self.dtrain = xgb.DMatrix(self.x_train, label=self.y_train)
        self.dval = xgb.DMatrix(self.x_val, label=self.y_val)
        self.dtest = xgb.DMatrix(self.x_test, label=self.y_test)
        # Model initialization
        self.model_args = model_args
        self.model = xgb.XGBClassifier(**model_args)

    def create_dataframe(self, ds, save_name='train'):
        """
        Return ``(X, y)`` for ``ds``, using a cached CSV when one exists.

        If ``tasks/utils/dfs/{save_name}.csv`` is missing, ``ds`` is iterated
        and, for every item, each nested feature group in
        ``item['audio']['features']`` contributes one column per entry
        (``{group}_{name}``), taking the first element of the stored value.
        Top-level features that are not dicts are skipped. The resulting
        table is written to the cache path.

        Returns:
            ``X``: DataFrame of features; ``y``: the ``label`` column.
        """
        csv_path = f"tasks/utils/dfs/{save_name}.csv"
        try:
            df = pd.read_csv(csv_path)
        except FileNotFoundError:
            data = []
            # Iterate over the dataset
            pbar = tqdm(enumerate(ds))
            for i, batch in pbar:
                label = batch['label']
                # A one-element tensor would be stored as an object (and
                # serialised as "tensor(...)" in the CSV); keep the scalar.
                if torch.is_tensor(label) and label.numel() == 1:
                    label = label.item()
                features = batch['audio']['features']
                # Flatten the nested dictionary structure
                feature_dict = {'label': label}
                for k, v in features.items():
                    if isinstance(v, dict):
                        for sub_k, sub_v in v.items():
                            # Only the first element of each value is kept.
                            feature_dict[f"{k}_{sub_k}"] = sub_v[0].item()
                data.append(feature_dict)
            # Convert to DataFrame
            df = pd.DataFrame(data)
            print(os.getcwd())
            # to_csv does not create missing parent directories.
            os.makedirs(os.path.dirname(csv_path), exist_ok=True)
            df.to_csv(csv_path, index=False)
        X = df.drop(columns=['label'])
        y = df['label']
        return X, y

    def fit(self):
        """
        Train a booster on the full training matrix with early stopping.

        ``self.model`` is replaced by the booster returned from ``xgb.train``.

        Returns:
            The ``evals_result`` dict filled by ``xgb.train`` for the
            ``'train'`` and ``'eval'`` sets.
        """
        # Entries in model_args override the two defaults below.
        params = {
            'objective': 'binary:logistic',
            'eval_metric': 'logloss',
            **self.model_args
        }
        evals_result = {}
        num_boost_round = 1000  # Set a large number of boosting rounds
        # Watchlist to monitor performance on train and validation data
        watchlist = [(self.dtrain, 'train'), (self.dval, 'eval')]
        # Train the model
        self.model = xgb.train(
            params,
            self.dtrain,
            num_boost_round=num_boost_round,
            evals=watchlist,
            early_stopping_rounds=10,  # Early stopping after 10 rounds with no improvement
            evals_result=evals_result,
            verbose_eval=True  # Show evaluation results for each iteration
        )
        return evals_result

    def train_xgboost_in_batches(self, dataloader, eval_metric="logloss"):
        """
        Train a booster incrementally, one dataloader batch at a time.

        Every batch is flattened into a 2-D feature matrix (all entries
        except ``"label"`` concatenated along dim 1) and training continues
        from the booster produced by the previous batch.

        Args:
            dataloader: iterable of dicts of tensors containing ``"label"``.
            eval_metric: metric evaluated on the validation matrix.

        Returns:
            The ``evals_result`` dict of the last ``xgb.train`` call.
        """
        # xgb.train takes the metric through params, not as a keyword
        # argument; `params` was previously undefined in this method.
        params = {
            'objective': 'binary:logistic',
            **self.model_args,
            'eval_metric': eval_metric,
        }
        evals_result = {}
        booster = None
        for i, batch in enumerate(dataloader):
            # Convert batch data to NumPy arrays
            X_batch = torch.cat([batch[key].view(batch[key].size(0), -1) for key in batch if key != "label"],
                                dim=1).numpy()
            y_batch = batch["label"].numpy()
            # Create DMatrix for XGBoost
            dtrain = xgb.DMatrix(X_batch, label=y_batch)
            # Continue from the previous booster so earlier batches are not discarded.
            booster = xgb.train(
                params,
                dtrain,
                num_boost_round=1000,  # Use a large number of rounds
                evals=[(self.dval, 'eval')],
                early_stopping_rounds=10,
                evals_result=evals_result,
                verbose_eval=False,  # Avoid printing every iteration
                xgb_model=booster,
            )
            self.model = booster
            # Optionally print progress
            if i % 10 == 0:
                print(f"Batch {i + 1}/{len(dataloader)} processed.")
        return evals_result

    def predict(self):
        """
        Print accuracy, ROC AUC and a classification report on the test set.

        Calls ``self.model.predict`` with the test ``DMatrix``, so it expects
        ``self.model`` to be the booster set by ``fit`` or
        ``train_xgboost_in_batches``.
        """
        # Predict probabilities for class 1
        y_prob = self.model.predict(self.dtest, output_margin=False)
        # Convert probabilities to binary labels (0 or 1) using a threshold (e.g., 0.5)
        y_pred = (y_prob >= 0.5).astype(int)
        # Evaluate performance
        accuracy = accuracy_score(self.y_test, y_pred)
        roc_auc = roc_auc_score(self.y_test, y_prob)
        print(f'Accuracy: {accuracy:.4f}')
        print(f'ROC AUC Score: {roc_auc:.4f}')
        print(classification_report(self.y_test, y_pred))

    def plot_results(self, evals_result):
        """Plot train and validation log loss per boosting round from ``fit``'s result."""
        train_logloss = evals_result["train"]["logloss"]
        val_logloss = evals_result["eval"]["logloss"]
        iterations = range(1, len(train_logloss) + 1)
        # Plot
        plt.figure(figsize=(8, 5))
        plt.plot(iterations, train_logloss, label="Train LogLoss", color="blue")
        plt.plot(iterations, val_logloss, label="Validation LogLoss", color="red")
        plt.xlabel("Boosting Round (Iteration)")
        plt.ylabel("Log Loss")
        plt.title("Training and Validation Log Loss over Iterations")
        plt.legend()
        plt.grid()
        plt.show()