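# ChembertaTokenizer: a lightweight word-level SMILES tokenizer built on the
# `tokenizers` library, exposing a small subset of the Hugging Face tokenizer
# interface (__call__/encode, convert_ids_to_tokens, decode) and returning
# BatchEncoding objects.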
import torch
from tokenizers import Regex, Tokenizer
from tokenizers.models import WordLevel
from tokenizers.pre_tokenizers import Split
from tokenizers.processors import TemplateProcessing
from transformers import BatchEncoding

class ChembertaTokenizer:
    def __init__(self, vocab_file):
        # Word-level vocabulary; unknown tokens map to [UNK]
        self.tokenizer = Tokenizer(
            WordLevel.from_file(
                vocab_file,
                unk_token='[UNK]'
            ))
        # Split SMILES into bracketed atoms, two-letter halogens (Cl, Br),
        # reaction arrows (>>), backslashes, or any remaining characters
        self.tokenizer.pre_tokenizer = Split(
            pattern=Regex(r"\[(.*?)\]|Cl|Br|>>|\\|.*?"),
            behavior='isolated'
        )
        # Treat special tokens found in the raw input text like regular text when encoding
        self.tokenizer.encode_special_tokens = True
        self.special_token_ids = {
            self.tokenizer.token_to_id('[CLS]'),
            self.tokenizer.token_to_id('[SEP]'),
            self.tokenizer.token_to_id('[PAD]'),
            self.tokenizer.token_to_id('[UNK]')
        }
        # Wrap single sequences as [CLS] $A [SEP]; pairs get a second [SEP] and type id 1
        self.tokenizer.post_processor = TemplateProcessing(
            single='[CLS] $A [SEP]',
            pair='[CLS] $A [SEP] $B:1 [SEP]:1',
            special_tokens=[
                ('[CLS]', self.tokenizer.token_to_id('[CLS]')),
                ('[SEP]', self.tokenizer.token_to_id('[SEP]'))
            ]
        )

    def encode(self, inputs, padding=None, truncation=False,
               max_length=None, return_tensors=None):
        # Configure padding/truncation for this call
        if padding:
            self.tokenizer.enable_padding(pad_id=self.tokenizer.token_to_id('[PAD]'),
                                          pad_token='[PAD]', length=max_length)
        else:
            self.tokenizer.no_padding()
        if truncation:
            self.tokenizer.enable_truncation(max_length=max_length)
        else:
            self.tokenizer.no_truncation()
        tensor_type = 'pt' if return_tensors == 'pt' else None
        # Handle batch or single input (a single sequence becomes a batch of size 1)
        if isinstance(inputs, list):
            enc = self.tokenizer.encode_batch(inputs)
        else:
            enc = [self.tokenizer.encode(inputs)]
        data = {
            "input_ids": [e.ids for e in enc],
            "attention_mask": [e.attention_mask for e in enc]
        }
        return BatchEncoding(data=data, encoding=enc, tensor_type=tensor_type)

    def __call__(self, inputs, padding=None, truncation=False,
                 max_length=None, return_tensors=None):
        return self.encode(inputs, padding=padding, truncation=truncation,
                           max_length=max_length, return_tensors=return_tensors)

    def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
        def _convert_sequence(seq):
            if skip_special_tokens:
                seq = [idx for idx in seq if idx not in self.special_token_ids]
            return [self.tokenizer.id_to_token(idx) for idx in seq]
        # 1) batch: list of lists or torch tensor
        if isinstance(ids, torch.Tensor):
            ids = ids.tolist()
            if len(ids) == 1:
                ids = ids[0]
        if isinstance(ids, list) and len(ids) > 0 and isinstance(ids[0], list):
            return [_convert_sequence(seq) for seq in ids]
        # 2) single sequence: list of ints
        if isinstance(ids, list):
            return _convert_sequence(ids)
        # 3) single int
        if isinstance(ids, int):
            return self.tokenizer.id_to_token(ids)

    def decode(self, ids, skip_special_tokens=False):
        def _decode_sequence(seq):
            if skip_special_tokens:
                seq = [idx for idx in seq if idx not in self.special_token_ids]
            return ''.join(self.tokenizer.id_to_token(idx) for idx in seq)
        # 1) batch: list of lists or torch tensor
        if isinstance(ids, torch.Tensor):
            ids = ids.tolist()
            if len(ids) == 1:
                ids = ids[0]
        if isinstance(ids, list) and len(ids) > 0 and isinstance(ids[0], list):
            return [_decode_sequence(seq) for seq in ids]
        # 2) single sequence: list of ints
        if isinstance(ids, list):
            return _decode_sequence(ids)
        # 3) single int
        if isinstance(ids, int):
            return self.tokenizer.id_to_token(ids)
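

# Usage sketch. The vocab path and SMILES strings below are illustrative; the
# class expects a word-level vocab JSON (token -> id) that contains the [CLS],
# [SEP], [PAD] and [UNK] special tokens.
if __name__ == '__main__':
    tokenizer = ChembertaTokenizer('vocab.json')  # hypothetical vocab file
    batch = tokenizer(['CCO', 'c1ccccc1O'], padding=True,
                      truncation=True, max_length=16, return_tensors='pt')
    print(batch['input_ids'].shape)                                # torch.Size([2, 16])
    print(tokenizer.convert_ids_to_tokens(batch['input_ids'][0]))  # token strings for the first sequence
    print(tokenizer.decode(batch['input_ids'][0], skip_special_tokens=True))  # SMILES rebuilt without special tokens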