import logging
import os
import time

import cv2
import numpy as np
import torch
import yaml
from matplotlib import colors
from matplotlib import pyplot as plt
from torch import Tensor, nn
from torch.utils.data import ConcatDataset


class CharsetMapper(object):
    """A simple class to map ids into strings.

    It works only when the character set is a 1:1 mapping between individual
    characters and individual ids.
    """

    def __init__(self,
                 filename='',
                 max_length=30,
                 null_char=u'\u2591'):
        """Creates a lookup table.

        Args:
            filename: Path to charset file which maps characters to ids.
            max_length: The max length of ids and string.
            null_char: A unicode character used to replace the '<null>' character.
                The default value is a light shade block '░'.
        """
        self.null_char = null_char
        self.max_length = max_length

        self.label_to_char = self._read_charset(filename)
        self.char_to_label = dict(map(reversed, self.label_to_char.items()))
        self.num_classes = len(self.label_to_char)

    def _read_charset(self, filename):
        """Reads a charset definition from a tab separated text file.

        Args:
            filename: a path to the charset file.

        Returns:
            a dictionary with keys equal to character codes and values - unicode
            characters.
        """
        import re
        pattern = re.compile(r'(\d+)\t(.+)')
        charset = {}
        self.null_label = 0
        charset[self.null_label] = self.null_char
        with open(filename, 'r') as f:
            for i, line in enumerate(f):
                m = pattern.match(line)
                assert m, f'Incorrect charset file. line #{i}: {line}'
                label = int(m.group(1)) + 1
                char = m.group(2)
                charset[label] = char
        return charset

    def trim(self, text):
        assert isinstance(text, str)
        return text.replace(self.null_char, '')

    def get_text(self, labels, length=None, padding=True, trim=False):
        """Returns a string corresponding to a sequence of character ids."""
        length = length if length else self.max_length
        labels = [l.item() if isinstance(l, Tensor) else int(l) for l in labels]
        if padding:
            labels = labels + [self.null_label] * (length - len(labels))
        text = ''.join([self.label_to_char[label] for label in labels])
        if trim:
            text = self.trim(text)
        return text

    def get_labels(self, text, length=None, padding=True, case_sensitive=False):
        """Returns the labels of the corresponding text."""
        length = length if length else self.max_length
        if padding:
            text = text + self.null_char * (length - len(text))
        if not case_sensitive:
            text = text.lower()
        labels = [self.char_to_label[char] for char in text]
        return labels

    def pad_labels(self, labels, length=None):
        length = length if length else self.max_length
        return labels + [self.null_label] * (length - len(labels))

    @property
    def digits(self):
        return '0123456789'

    @property
    def digit_labels(self):
        return self.get_labels(self.digits, padding=False)

    @property
    def alphabets(self):
        all_chars = list(self.char_to_label.keys())
        valid_chars = []
        for c in all_chars:
            if c in 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ':
                valid_chars.append(c)
        return ''.join(valid_chars)

    @property
    def alphabet_labels(self):
        return self.get_labels(self.alphabets, padding=False)
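
# Usage sketch for CharsetMapper (illustrative only; the charset path and its
# contents below are assumptions, not files shipped with this module). The
# charset file is tab separated, one "<id>\t<char>" pair per line; label 0 is
# reserved for the null character:
#
#     charset = CharsetMapper('data/charset_36.txt', max_length=25)
#     labels = charset.get_labels('hello')        # padded with null labels to 25
#     text = charset.get_text(labels, trim=True)  # -> 'hello'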


class Timer(object):
    """A simple timer."""

    def __init__(self):
        self.data_time = 0.
        self.data_diff = 0.
        self.data_total_time = 0.
        self.data_call = 0
        self.running_time = 0.
        self.running_diff = 0.
        self.running_total_time = 0.
        self.running_call = 0

    def tic(self):
        self.start_time = time.time()
        self.running_time = self.start_time

    def toc_data(self):
        self.data_time = time.time()
        self.data_diff = self.data_time - self.running_time
        self.data_total_time += self.data_diff
        self.data_call += 1

    def toc_running(self):
        self.running_time = time.time()
        self.running_diff = self.running_time - self.data_time
        self.running_total_time += self.running_diff
        self.running_call += 1

    def total_time(self):
        return self.data_total_time + self.running_total_time

    def average_time(self):
        return self.average_data_time() + self.average_running_time()

    def average_data_time(self):
        return self.data_total_time / (self.data_call or 1)

    def average_running_time(self):
        return self.running_total_time / (self.running_call or 1)
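
# Usage sketch for Timer (illustrative only; 'data_iter' and 'model' are
# placeholders). tic() marks the start of an iteration, toc_data() is called
# once the batch is ready, and toc_running() after the model has run, so data
# loading time and computation time are accumulated separately:
#
#     timer = Timer()
#     timer.tic()
#     batch = next(data_iter)   # tic() -> toc_data() counts as data time
#     timer.toc_data()
#     output = model(batch)     # toc_data() -> toc_running() counts as running time
#     timer.toc_running()
#     print(timer.average_data_time(), timer.average_running_time())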


class Logger(object):
    _handle = None
    _root = None

    @staticmethod
    def init(output_dir, name, phase):
        format = '[%(asctime)s %(filename)s:%(lineno)d %(levelname)s {}] ' \
                 '%(message)s'.format(name)
        logging.basicConfig(level=logging.INFO, format=format)

        os.makedirs(output_dir, exist_ok=True)
        config_path = os.path.join(output_dir, f'{phase}.txt')
        Logger._handle = logging.FileHandler(config_path)
        Logger._root = logging.getLogger()

    @staticmethod
    def enable_file():
        if Logger._handle is None or Logger._root is None:
            raise Exception('Invoke Logger.init() first!')
        Logger._root.addHandler(Logger._handle)

    @staticmethod
    def disable_file():
        if Logger._handle is None or Logger._root is None:
            raise Exception('Invoke Logger.init() first!')
        Logger._root.removeHandler(Logger._handle)
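
# Usage sketch for Logger (illustrative only; the directory and names are
# placeholders). init() configures the root logger and prepares a file handler
# writing to '<output_dir>/<phase>.txt'; enable_file()/disable_file() attach and
# detach that handler:
#
#     Logger.init('workdir/demo', name='demo', phase='train')
#     Logger.enable_file()
#     logging.info('goes to the console and to workdir/demo/train.txt')
#     Logger.disable_file()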


class Config(object):

    def __init__(self, config_path, host=True):
        def __dict2attr(d, prefix=''):
            for k, v in d.items():
                if isinstance(v, dict):
                    __dict2attr(v, f'{prefix}{k}_')
                else:
                    if k == 'phase':
                        assert v in ['train', 'test']
                    if k == 'stage':
                        assert v in ['pretrain-vision', 'pretrain-language',
                                     'train-semi-super', 'train-super']
                    self.__setattr__(f'{prefix}{k}', v)

        assert os.path.exists(config_path), f'{config_path} does not exist!'
        with open(config_path) as file:
            config_dict = yaml.load(file, Loader=yaml.FullLoader)
        with open('configs/template.yaml') as file:
            default_config_dict = yaml.load(file, Loader=yaml.FullLoader)
        __dict2attr(default_config_dict)
        __dict2attr(config_dict)
        self.global_workdir = os.path.join(self.global_workdir, self.global_name)

    def __getattr__(self, item):
        attr = self.__dict__.get(item)
        if attr is None:
            attr = dict()
            prefix = f'{item}_'
            for k, v in self.__dict__.items():
                if k.startswith(prefix):
                    n = k.replace(prefix, '')
                    attr[n] = v
            return attr if len(attr) > 0 else None
        else:
            return attr

    def __repr__(self):
        s = 'ModelConfig(\n'
        for i, (k, v) in enumerate(sorted(vars(self).items())):
            s += f'\t({i}): {k} = {v}\n'
        s += ')'
        return s
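
# Usage sketch for Config (illustrative only; the YAML path and keys are
# assumptions). Nested YAML keys are flattened into underscore-joined attributes
# (e.g. dataset: {train: {batch_size: 384}} becomes dataset_train_batch_size),
# and __getattr__ regroups a prefix back into a dict:
#
#     config = Config('configs/train.yaml')
#     print(config.global_name, config.global_workdir)
#     print(config.dataset_train)   # dict built from all 'dataset_train_*' attributes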


def blend_mask(image, mask, alpha=0.5, cmap='jet', color='b', color_alpha=1.0):
    # normalize mask
    mask = (mask - mask.min()) / (mask.max() - mask.min() + np.finfo(float).eps)
    if mask.shape != image.shape:
        mask = cv2.resize(mask, (image.shape[1], image.shape[0]))
    # get color map
    color_map = plt.get_cmap(cmap)
    mask = color_map(mask)[:, :, :3]
    # convert float to uint8
    mask = (mask * 255).astype(dtype=np.uint8)

    # set the basic color
    basic_color = np.array(colors.to_rgb(color)) * 255
    basic_color = np.tile(basic_color, [image.shape[0], image.shape[1], 1])
    basic_color = basic_color.astype(dtype=np.uint8)
    # blend with basic color
    blended_img = cv2.addWeighted(image, color_alpha, basic_color, 1 - color_alpha, 0)
    # blend with mask
    blended_img = cv2.addWeighted(blended_img, alpha, mask, 1 - alpha, 0)
    return blended_img
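
# Usage sketch for blend_mask (illustrative only; the file names are placeholders).
# 'image' is an HxWx3 uint8 array and 'mask' any 2D float array such as an
# attention map; the mask is normalized, resized, colorized with the matplotlib
# colormap and alpha-blended onto the (optionally color-tinted) image:
#
#     image = cv2.imread('demo.png')                    # HxWx3, uint8
#     mask = np.random.rand(8, 32).astype(np.float32)   # stand-in attention map
#     vis = blend_mask(image, mask, alpha=0.5, cmap='jet')
#     cv2.imwrite('demo_blend.png', vis)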


def onehot(label, depth, device=None):
    """
    Args:
        label: shape (n1, n2, ..., )
        depth: a scalar
    Returns:
        onehot: (n1, n2, ..., depth)
    """
    if not isinstance(label, torch.Tensor):
        label = torch.tensor(label, device=device)
    onehot = torch.zeros(label.size() + torch.Size([depth]), device=device)
    onehot = onehot.scatter_(-1, label.unsqueeze(-1), 1)

    return onehot
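
# Usage sketch for onehot (illustrative only): the trailing dimension of the
# result indexes the classes, so a (batch, length) tensor of labels becomes a
# (batch, length, depth) float tensor of one-hot vectors:
#
#     labels = torch.tensor([[1, 0, 2]])   # shape (1, 3)
#     vectors = onehot(labels, depth=4)    # shape (1, 3, 4)
#     assert vectors.sum(dim=-1).eq(1).all()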


class MyDataParallel(nn.DataParallel):

    def gather(self, outputs, target_device):
        r"""
        Gathers tensors from different GPUs on a specified device
        (-1 means the CPU).
        """
        def gather_map(outputs):
            out = outputs[0]
            if isinstance(out, (str, int, float)):
                return out
            if isinstance(out, list) and isinstance(out[0], str):
                return [o for out in outputs for o in out]
            if isinstance(out, torch.Tensor):
                return torch.nn.parallel._functions.Gather.apply(target_device, self.dim, *outputs)
            if out is None:
                return None
            if isinstance(out, dict):
                if not all((len(out) == len(d) for d in outputs)):
                    raise ValueError('All dicts must have the same number of keys')
                return type(out)(((k, gather_map([d[k] for d in outputs]))
                                  for k in out))
            return type(out)(map(gather_map, zip(*outputs)))

        # Recursive function calls like this create reference cycles.
        # Setting the function to None clears the refcycle.
        try:
            res = gather_map(outputs)
        finally:
            gather_map = None
        return res


class MyConcatDataset(ConcatDataset):
    def __getattr__(self, k):
        return getattr(self.datasets[0], k)
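
# Usage sketch for the two wrappers above (illustrative only; 'model' and the
# sub-datasets are placeholders). MyDataParallel extends nn.DataParallel.gather
# so that non-tensor outputs (strings, lists of strings, None, nested dicts)
# survive the gather step, and MyConcatDataset forwards unknown attribute
# lookups to its first sub-dataset:
#
#     model = MyDataParallel(model, device_ids=[0, 1])
#     dataset = MyConcatDataset([train_set_a, train_set_b])
#     charset = dataset.charset   # taken from train_set_a, assuming it defines one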