import random

import numpy as np
from torch.utils.data.dataset import Dataset


class MultipleDatasets(Dataset):
    """Wraps several datasets and samples among them by a fixed ratio."""

    def __init__(self,
                 dbs,
                 partition,
                 make_same_len=True,
                 total_len=None,
                 verbose=False):
        self.dbs = dbs
        self.db_num = len(self.dbs)
        self.max_db_data_num = max([len(db) for db in dbs])
        self.db_len_cumsum = np.cumsum([len(db) for db in dbs])
        self.make_same_len = make_same_len
        # partition maps a dataset class name to a cumulative sampling
        # probability; sort by value so the thresholds are ascending.
        self.partition = {k: v for k, v in sorted(partition.items(), key=lambda item: item[1])}
        self.dataset = {}
        for db in dbs:
            self.dataset.update({db.__class__.__name__: db})
        if verbose:
            print('datasets:', [len(self.dbs[i]) for i in range(self.db_num)])
            print(f'Sample Ratio: {self.partition}')

    def __len__(self):
        return self.max_db_data_num

    def __getitem__(self, index):
        # Draw a uniform sample and pick the first dataset whose cumulative
        # threshold covers it; wrap the index into that dataset's range.
        p = np.random.rand()
        v = list(self.partition.values())
        k = list(self.partition.keys())
        for i, v_i in enumerate(v):
            if p <= v_i:
                return self.dataset[k[i]][index % len(self.dataset[k[i]])]
        # Guard against floating-point round-off when the last threshold is
        # slightly below 1.0: fall back to the last dataset.
        return self.dataset[k[-1]][index % len(self.dataset[k[-1]])]
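
# A minimal usage sketch for MultipleDatasets, assuming the partition values
# are cumulative sampling probabilities with the last one reaching 1.0. The
# toy classes _ToyA/_ToyB and _demo_partition_sampling are hypothetical names
# used only for illustration; real code would pass project dataset objects.
class _ToyA(Dataset):
    def __len__(self):
        return 4

    def __getitem__(self, index):
        return ('A', index)


class _ToyB(Dataset):
    def __len__(self):
        return 10

    def __getitem__(self, index):
        return ('B', index)


def _demo_partition_sampling():
    # _ToyA is drawn ~30% of the time, _ToyB ~70% (cumulative thresholds).
    ds = MultipleDatasets([_ToyA(), _ToyB()],
                          partition={'_ToyA': 0.3, '_ToyB': 1.0},
                          verbose=True)
    print('len:', len(ds))  # length of the longest wrapped dataset: 10
    for i in range(5):
        print(ds[i])  # dataset chosen at random, index wrapped per dataset
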
class MultipleDatasets_debug(Dataset):
    """Concatenates several datasets, optionally padding them to equal length."""

    def __init__(self, dbs, make_same_len=True, total_len=None, verbose=False):
        self.dbs = dbs
        self.db_num = len(self.dbs)
        self.max_db_data_num = max([len(db) for db in dbs])
        self.db_len_cumsum = np.cumsum([len(db) for db in dbs])
        self.make_same_len = make_same_len
        if total_len == 'auto':
            self.total_len = self.db_len_cumsum[-1]
            self.auto_total_len = True
        else:
            self.total_len = total_len
            self.auto_total_len = False
        if total_len is not None:
            # each dataset gets an equal share of the fixed total length
            self.per_db_len = self.total_len // self.db_num
        if verbose:
            print('datasets:', [len(self.dbs[i]) for i in range(self.db_num)])
            print(f'Auto total length: {self.auto_total_len}, {self.total_len}')

    def __len__(self):
        # all dbs have the same length
        if self.make_same_len:
            if self.total_len is None:
                # match the longest length
                return self.max_db_data_num * self.db_num
            else:
                # each dataset has the same length and total len is fixed
                return self.total_len
        else:
            # each db has a different length; simply concat
            return sum([len(db) for db in self.dbs])

    def __getitem__(self, index):
        if self.make_same_len:
            if self.total_len is None:
                # match the longest length
                db_idx = index // self.max_db_data_num
                data_idx = index % self.max_db_data_num
                if data_idx >= len(self.dbs[db_idx]) * (
                        self.max_db_data_num //
                        len(self.dbs[db_idx])):  # last batch: random sampling
                    data_idx = random.randint(0, len(self.dbs[db_idx]) - 1)
                else:  # before last batch: use modular
                    data_idx = data_idx % len(self.dbs[db_idx])
            else:
                db_idx = index // self.per_db_len
                data_idx = index % self.per_db_len
                if db_idx > (self.db_num - 1):
                    # last batch: randomly choose one dataset
                    db_idx = random.randint(0, self.db_num - 1)
                if len(self.dbs[db_idx]) < self.per_db_len and \
                        data_idx >= len(self.dbs[db_idx]) * (self.per_db_len // len(self.dbs[db_idx])):
                    # last batch: random sampling in this dataset
                    data_idx = random.randint(0, len(self.dbs[db_idx]) - 1)
                else:
                    # before last batch: use modular
                    data_idx = data_idx % len(self.dbs[db_idx])
        else:
            # plain concatenation: locate the dataset via the cumulative sums
            for i in range(self.db_num):
                if index < self.db_len_cumsum[i]:
                    db_idx = i
                    break
            if db_idx == 0:
                data_idx = index
            else:
                data_idx = index - self.db_len_cumsum[db_idx - 1]
        return self.dbs[db_idx][data_idx]
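
# A minimal sketch exercising the three length modes of MultipleDatasets_debug;
# it reuses the hypothetical _ToyA/_ToyB toy datasets defined above, so the
# numbers in the comments follow from their lengths (4 and 10).
def _demo_debug_lengths():
    dbs = [_ToyA(), _ToyB()]
    balanced = MultipleDatasets_debug(dbs, make_same_len=True)
    print(len(balanced))  # 20: every dataset padded to the longest (10) * 2 dbs
    fixed = MultipleDatasets_debug(dbs, make_same_len=True, total_len=12)
    print(len(fixed))  # 12: each dataset contributes total_len // db_num = 6 items
    concat = MultipleDatasets_debug(dbs, make_same_len=False)
    print(len(concat))  # 14: plain concatenation, 4 + 10
    print(concat[0], concat[4])  # ('A', 0) then ('B', 0): cumulative-sum lookup


if __name__ == '__main__':
    _demo_partition_sampling()
    _demo_debug_lengths()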