from tokenizers.models import WordLevel
from tokenizers import Tokenizer
from tokenizers.pre_tokenizers import Split
from tokenizers import Regex
from tokenizers.processors import TemplateProcessing
from transformers import BatchEncoding
import torch


class ChembertaTokenizer:
    """Word-level SMILES tokenizer with a transformers-like interface,
    built on the `tokenizers` library."""

    def __init__(self, vocab_file):
        self.tokenizer = Tokenizer(
            WordLevel.from_file(
                vocab_file,
                unk_token='[UNK]'
            ))
        # Pre-tokenize SMILES/reaction strings: bracketed atoms, Cl, Br, '>>',
        # backslashes, with a lazy wildcard fallback for everything else
        self.tokenizer.pre_tokenizer = Split(
            pattern=Regex(r"\[(.*?)\]|Cl|Br|>>|\\|.*?"),
            behavior='isolated'
        )
        # Tokenize special-token strings occurring in raw input like ordinary
        # text instead of matching them as added tokens
        self.tokenizer.encode_special_tokens = True
        self.special_token_ids = {
            self.tokenizer.token_to_id('[CLS]'),
            self.tokenizer.token_to_id('[SEP]'),
            self.tokenizer.token_to_id('[PAD]'),
            self.tokenizer.token_to_id('[UNK]')
        }
        # Add [CLS]/[SEP] around single sequences and sequence pairs
        self.tokenizer.post_processor = TemplateProcessing(
            single='[CLS] $A [SEP]',
            pair='[CLS] $A [SEP] $B:1 [SEP]:1',
            special_tokens=[
                ('[CLS]', self.tokenizer.token_to_id('[CLS]')),
                ('[SEP]', self.tokenizer.token_to_id('[SEP]'))
            ]
        )

    def encode(self, inputs, padding=None, truncation=False, max_length=None,
               return_tensors=None):
        """Encode a string or a list of strings into a `BatchEncoding`."""
        # Configure padding/truncation for this call
        if padding:
            self.tokenizer.enable_padding(
                pad_id=self.tokenizer.token_to_id('[PAD]'),
                pad_token='[PAD]',
                length=max_length
            )
        else:
            self.tokenizer.no_padding()
        if truncation:
            self.tokenizer.enable_truncation(max_length=max_length)
        else:
            self.tokenizer.no_truncation()

        # BatchEncoding converts to torch tensors only when 'pt' is requested
        tensor_type = 'pt' if return_tensors == 'pt' else None

        # Handle batch or single input (a single string becomes a batch of size 1)
        if isinstance(inputs, list):
            enc = self.tokenizer.encode_batch(inputs)
        else:
            enc = [self.tokenizer.encode(inputs)]
        data = {
            "input_ids": [e.ids for e in enc],
            "attention_mask": [e.attention_mask for e in enc]
        }
        return BatchEncoding(data=data, encoding=enc, tensor_type=tensor_type)

    def __call__(self, inputs, padding=None, truncation=False, max_length=None,
                 return_tensors=None):
        return self.encode(inputs, padding=padding, truncation=truncation,
                           max_length=max_length, return_tensors=return_tensors)

    def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
        """Map token ids back to token strings (batch, sequence, or single id)."""
        def _decode_sequence(seq):
            if skip_special_tokens:
                seq = [idx for idx in seq if idx not in self.special_token_ids]
            return [self.tokenizer.id_to_token(idx) for idx in seq]

        # 1) batch: list of lists or torch tensor
        if isinstance(ids, torch.Tensor):
            ids = ids.tolist()
            if len(ids) == 1:
                ids = ids[0]
        if isinstance(ids, list) and len(ids) > 0 and isinstance(ids[0], list):
            return [_decode_sequence(seq) for seq in ids]
        # 2) single sequence: list of ints
        if isinstance(ids, list):
            return _decode_sequence(ids)
        # 3) single int
        if isinstance(ids, int):
            return self.tokenizer.id_to_token(ids)

    def decode(self, ids, skip_special_tokens=False):
        """Join token ids back into a string (batch, sequence, or single id)."""
        def _decode_sequence(seq):
            if skip_special_tokens:
                seq = [idx for idx in seq if idx not in self.special_token_ids]
            return ''.join(self.tokenizer.id_to_token(idx) for idx in seq)

        # 1) batch: list of lists or torch tensor
        if isinstance(ids, torch.Tensor):
            ids = ids.tolist()
            if len(ids) == 1:
                ids = ids[0]
        if isinstance(ids, list) and len(ids) > 0 and isinstance(ids[0], list):
            return [_decode_sequence(seq) for seq in ids]
        # 2) single sequence: list of ints
        if isinstance(ids, list):
            return _decode_sequence(ids)
        # 3) single int
        if isinstance(ids, int):
            return self.tokenizer.id_to_token(ids)
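

# --- Usage sketch (illustrative, not part of the class) ---------------------
# Assumptions: "vocab.json" is a hypothetical WordLevel vocab file containing
# the [CLS], [SEP], [PAD] and [UNK] entries; the SMILES strings are arbitrary
# examples.
if __name__ == "__main__":
    tokenizer = ChembertaTokenizer("vocab.json")

    # Batch encoding, padded/truncated to a fixed length, returned as tensors
    batch = tokenizer(
        ["CCO", "c1ccccc1"],
        padding=True,
        truncation=True,
        max_length=64,
        return_tensors="pt",
    )
    print(batch["input_ids"].shape)   # torch.Size([2, 64]) with fixed-length padding

    # Round-trip back to strings, dropping [CLS]/[SEP]/[PAD]
    print(tokenizer.decode(batch["input_ids"], skip_special_tokens=True))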