import os
import torch
import torch.nn
from torch.utils.data import DataLoader
from utils.data import FFTDataset, SplitDataset, AudioINRDataset
from datasets import load_dataset
from utils.train import Trainer, INRTrainer
from utils.models import MultiGraph, ImplicitEncoder
from omegaconf import OmegaConf
# from .utils.models import CNNKan, KanEncoder
from utils.inr import INR
from utils.data_utils import *
from huggingface_hub import login
import yaml
import datetime
import json
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from scipy.signal import savgol_filter as savgol
from utils.kan import FasterKAN
from utils.relational_transformer import RelationalTransformer
from collections import OrderedDict
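# `Container` is pulled in via the star import from utils.data_utils. A minimal
# sketch of what it is assumed to be -- a plain attribute wrapper over a config
# dict with a get_dict() accessor, matching how it is used below:
#
#     class Container:
#         def __init__(self, **kwargs):
#             self.__dict__.update(kwargs)
#         def get_dict(self):
#             return dict(self.__dict__)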
def plot_results(dims, i, data, losses, pred_values):
    """Overlay the (smoothed) target signal and the INR reconstruction on one axis."""
    data = savgol(data.cpu().detach().numpy(), window_length=250, polyorder=1)
    pred_values = pred_values.transpose(-1, -2).unflatten(-1, data.shape[-2:]).squeeze(0).cpu().detach().numpy()
    # Min-max normalize both signals so they are visually comparable.
    pred_values = (pred_values - np.min(pred_values)) / (np.max(pred_values) - np.min(pred_values))
    data = (data - np.min(data)) / (np.max(data) - np.min(data))
    plt.plot(data.squeeze())
    plt.plot(pred_values.squeeze())
    # axes[0].set_title('Original')
    # axes[1].set_title('Reconstruction')
    plt.show()
    # plt.plot(np.arange(len(losses)), losses)
    # plt.xlabel('Iteration')
    # plt.ylabel('Reconstruction MSE Error')
    # plt.show()
# local_rank = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
current_date = datetime.date.today().strftime("%Y-%m-%d")
datetime_dir = f"frugal_{current_date}"
args_dir = 'utils/config.yaml'
# Load the YAML config once and pull out the per-component sections.
with open(args_dir, 'r') as f:
    config = yaml.safe_load(f)
data_args = Container(**config['Data'])
exp_num = data_args.exp_num
model_name = data_args.model_name
rt_args = Container(**config['RelationalTransformer'])
cnn_args = Container(**config['CNNEncoder_f'])
conformer_args = Container(**config['Conformer'])
kan_args = Container(**config['KAN_INR'])
inr_args = Container(**config['INR'])
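# utils/config.yaml is not part of this file; based on the keys accessed here
# it is assumed to look roughly like this (values are placeholders):
#
#     Data:
#       exp_num: 1
#       model_name: rt
#       batch_size: 32
#       log_dir: ../../logs
#       num_epochs: 100
#     RelationalTransformer:
#       d_node: 64
#       ...
#     CNNEncoder_f: { ... }
#     Conformer: { ... }
#     KAN_INR: { ... }
#     INR:
#       in_features: 1
#       hidden_features: 128
#       n_layers: 4
#       out_features: 1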
os.makedirs(f"{data_args.log_dir}/{datetime_dir}", exist_ok=True)
| with open("../../logs/token.txt", "r") as f: | |
| api_key = f.read() | |
| # local_rank, world_size, gpus_per_node = setup() | |
| local_rank = torch.device('cuda' if torch.cuda.is_available() else 'cpu') | |
| login(api_key) | |
dataset = load_dataset("rfcx/frugalai", streaming=True)
train_ds = SplitDataset(AudioINRDataset(FFTDataset(dataset["train"])), is_train=True)
train_dl = DataLoader(train_ds, batch_size=data_args.batch_size)
val_ds = SplitDataset(AudioINRDataset(FFTDataset(dataset["train"])), is_train=False)
val_dl = DataLoader(val_ds, batch_size=data_args.batch_size)
test_ds = AudioINRDataset(FFTDataset(dataset["test"]))
test_dl = DataLoader(test_ds, batch_size=data_args.batch_size)
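# Train and val both wrap the same streaming "train" split: SplitDataset is
# assumed to carve a deterministic train/val partition out of that stream,
# with is_train selecting the side. FFTDataset is assumed to attach FFT
# magnitude/phase features to each audio sample, and AudioINRDataset the
# coordinate grid used to fit an INR to each sample.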
# for i, batch in enumerate(train_ds):
#     fft_phase, fft_mag, audio = batch['audio']['fft_phase'], batch['audio']['fft_mag'], batch['audio']['array']
#     label = batch['label']
#     fig, axes = plt.subplots(nrows=1, ncols=3)
#     axes = axes.flatten()
#     axes[0].plot(fft_phase)
#     axes[1].plot(fft_mag)
#     axes[2].plot(audio)
#     fig.suptitle(label)
#     plt.tight_layout()
#     plt.show()
#     if i > 20:
#         break

# model = DualEncoder(model_args, model_args_f, conformer_args)
# model = FasterKAN([18000,64,64,16,1])
# model = INR(in_features=1)
# model.kan.speed()
# model = KanEncoder(kan_args.get_dict())
# model = model.to(local_rank)

# state_dict = torch.load(data_args.checkpoint_path, map_location=torch.device('cpu'))
# new_state_dict = OrderedDict()
# for key, value in state_dict.items():
#     if key.startswith('module.'):
#         key = key[7:]
#     new_state_dict[key] = value
# missing, unexpected = model.load_state_dict(new_state_dict)
# model = DDP(model, device_ids=[local_rank], output_device=local_rank)

# num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
# print(f"Number of parameters: {num_params}")
#
# optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# total_steps = int(data_args.num_epochs) * 1000
# scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
#                                                        T_max=total_steps,
#                                                        eta_min=float((5e-4) / 10))
loss_fn = torch.nn.BCEWithLogitsLoss()
inr_criterion = torch.nn.MSELoss()
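# Debug pass over the stream: for each sample a fresh INR is overfit to that
# sample's FFT magnitude for 200 Adam steps, i.e. the network learns the map
# coord -> fft_mag(coord). The fitted weights are the per-sample representation
# that the weight-space encoder further below is meant to consume.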
for i, batch in enumerate(train_ds):
    coords, fft, audio = batch['audio']['coords'], batch['audio']['fft_mag'], batch['audio']['array']
    coords = coords.to(local_rank)
    fft = fft.to(local_rank)
    audio = audio.to(local_rank)
    # values = torch.cat((audio.unsqueeze(-1), fft.unsqueeze(-1)), dim=-1)
    model = INR(hidden_features=128, n_layers=4,
                in_features=1,
                out_features=1).to(local_rank)
    # model = FasterKAN(layers_hidden=[1,16,16,1]).to(local_rank)
    optimizer = torch.optim.Adam([{'params': model.parameters()}], lr=1e-3)
    pbar = tqdm(range(200))
    losses = []
    print(coords.shape)
    for t in pbar:
        optimizer.zero_grad()
        pred_values = model(coords).float()
        loss = inr_criterion(pred_values, fft)
        loss.backward()
        optimizer.step()
        pbar.set_description(f'loss: {loss.item():.6f}')
        losses.append(loss.item())
    state_dict = model.state_dict()
    torch.save(state_dict, 'test')
    # print(f'Sample {i+offset} label {label} saved in {inr_path}')
    plot_results(1, i, fft, losses, pred_values)
    exit()
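# NOTE: the exit() above stops the script after fitting a single sample's INR
# (a debugging shortcut); everything below -- the weight-space classifier and
# its training loop -- only runs once that early exit is removed.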
# missing, unexpected = model.load_state_dict(torch.load(model_args.checkpoint_path))
# print(f"Missing keys: {missing}")
# print(f"Unexpected keys: {unexpected}")

layer_layout = [inr_args.in_features] + [inr_args.hidden_features for _ in range(inr_args.n_layers)] + [inr_args.out_features]
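# With the INR config assumed above (in_features=1, hidden_features=128,
# n_layers=4, out_features=1) this evaluates to [1, 128, 128, 128, 128, 1]:
# one width per INR layer, which fixes the node layout of the weight-space graph.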
graph_constructor = OmegaConf.create(
    {
        "_target_": "utils.graph_constructor.GraphConstructor",
        "_recursive_": False,
        "_convert_": "all",
        "d_in": 1,
        "d_edge_in": 1,
        "zero_out_bias": False,
        "zero_out_weights": False,
        "sin_emb": True,
        "sin_emb_dim": rt_args.d_node,
        "use_pos_embed": False,
        "input_layers": 1,
        "inp_factor": 1,
        "num_probe_features": 0,
        "inr_model": None,
        "stats": None,
        "sparsify": False,
        "sym_edges": False,
    }
)
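# utils.graph_constructor.GraphConstructor is instantiated from this config by
# the transformer (hence the Hydra-style _target_ key). It is assumed to turn
# each INR's weights into a graph -- neurons as nodes, weights as edges, sized
# by layer_layout -- with sinusoidal embeddings of dimension rt_args.d_node.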
rt_model = RelationalTransformer(layer_layout=layer_layout, graph_constructor=graph_constructor,
                                 **rt_args.get_dict()).to(local_rank)
rt_model.proj_out = torch.nn.Identity()
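# Replacing proj_out with Identity drops the transformer's own output head, so
# rt_model returns its latent features; MultiGraph and ImplicitEncoder (from
# utils.models, assumed here) wrap those features for the downstream classifier.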
multi_graph = MultiGraph(rt_model, cnn_args)
implicit_net = INR(**inr_args.get_dict())
model = ImplicitEncoder(implicit_net, multi_graph).to(local_rank)

num_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Number of parameters: {num_parameters}")

optimizer = torch.optim.Adam([{'params': model.parameters()}], lr=1e-3)
trainer = Trainer(model=model, optimizer=optimizer,
                  criterion=loss_fn, output_dim=1, scaler=None,
                  scheduler=None, train_dataloader=train_dl,
                  val_dataloader=val_dl, device=local_rank,
                  exp_num=datetime_dir, log_path=data_args.log_dir,
                  range_update=None,
                  accumulation_step=1, max_iter=100,
                  exp_name=f"frugal_kan_{exp_num}")
fit_res = trainer.fit(num_epochs=100, device=local_rank,
                      early_stopping=10, only_p=False, best='loss', conf=True)
output_filename = f'{data_args.log_dir}/{datetime_dir}/{model_name}_frugal_{exp_num}.json'
with open(output_filename, "w") as f:
    json.dump(fit_res, f, indent=2)
preds, acc = trainer.predict(test_dl, local_rank)
print(f"Accuracy: {acc}")