# DLRNA-BERTa / chemberta.py
from tokenizers.models import WordLevel
from tokenizers import Tokenizer
from tokenizers.pre_tokenizers import Split
from tokenizers import Regex
from tokenizers.processors import TemplateProcessing
from transformers import BatchEncoding
import torch

class ChembertaTokenizer:
    def __init__(self, vocab_file):
        self.tokenizer = Tokenizer(
            WordLevel.from_file(
                vocab_file,
                unk_token='[UNK]'
            ))
        # Pre-tokenize SMILES: bracketed atoms, the two-letter halogens Cl/Br,
        # the reaction arrow '>>' and backslashes are kept as single pieces;
        # everything else is effectively split character by character.
        self.tokenizer.pre_tokenizer = Split(
            pattern=Regex(r"\[(.*?)\]|Cl|Br|>>|\\|.*?"),
            behavior='isolated'
        )

        # Split special-token strings that appear in the raw input like ordinary
        # text instead of matching them as special tokens.
        self.tokenizer.encode_special_tokens = True

        self.special_token_ids = {
            self.tokenizer.token_to_id('[CLS]'),
            self.tokenizer.token_to_id('[SEP]'),
            self.tokenizer.token_to_id('[PAD]'),
            self.tokenizer.token_to_id('[UNK]')
        }

        # Wrap every encoded sequence (or sequence pair) with [CLS]/[SEP].
        self.tokenizer.post_processor = TemplateProcessing(
            single='[CLS] $A [SEP]',
            pair='[CLS] $A [SEP] $B:1 [SEP]:1',
            special_tokens=[
                ('[CLS]', self.tokenizer.token_to_id('[CLS]')),
                ('[SEP]', self.tokenizer.token_to_id('[SEP]'))
            ]
        )
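    # Illustrative sketch (not from the original file): assuming the vocabulary
    # contains '[CLS]', '[SEP]' and the SMILES pieces below, the pipeline above
    # turns "CCBr" into roughly
    #   pre-tokenization : 'C', 'C', 'Br'
    #   post-processing  : '[CLS]', 'C', 'C', 'Br', '[SEP]'
    # before each token is mapped to its id from vocab_file.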
    def encode(self, inputs, padding=None, truncation=False,
               max_length=None, return_tensors=None):
        # Configure padding/truncation for this call
        if padding:
            self.tokenizer.enable_padding(pad_id=self.tokenizer.token_to_id('[PAD]'),
                                          pad_token='[PAD]', length=max_length)
        else:
            self.tokenizer.no_padding()

        if truncation:
            self.tokenizer.enable_truncation(max_length=max_length)
        else:
            self.tokenizer.no_truncation()

        tensor_type = 'pt' if return_tensors == 'pt' else None
        # Handle batch or single input; a single sequence is wrapped into a
        # batch of size 1 so the output is always batched.
        if isinstance(inputs, list):
            enc = self.tokenizer.encode_batch(inputs)
        else:
            enc = [self.tokenizer.encode(inputs)]
        data = {
            "input_ids": [e.ids for e in enc],
            "attention_mask": [e.attention_mask for e in enc]
        }
        return BatchEncoding(data=data, encoding=enc, tensor_type=tensor_type)
    def __call__(self, inputs, padding=None, truncation=False,
                 max_length=None, return_tensors=None):
        return self.encode(inputs, padding=padding, truncation=truncation,
                           max_length=max_length, return_tensors=return_tensors)
    def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
        def _decode_sequence(seq):
            if skip_special_tokens:
                seq = [idx for idx in seq if idx not in self.special_token_ids]
            return [self.tokenizer.id_to_token(idx) for idx in seq]

        # 1) batch: list of lists or torch tensor (a batch of size 1 is unwrapped)
        if isinstance(ids, torch.Tensor):
            ids = ids.tolist()
            if len(ids) == 1:
                ids = ids[0]
        if isinstance(ids, list) and len(ids) > 0 and isinstance(ids[0], list):
            return [_decode_sequence(seq) for seq in ids]
        # 2) single sequence: list of ints or torch tensor
        if isinstance(ids, list):
            return _decode_sequence(ids)
        # 3) single int
        if isinstance(ids, int):
            return self.tokenizer.id_to_token(ids)
    def decode(self, ids, skip_special_tokens=False):
        def _decode_sequence(seq):
            if skip_special_tokens:
                seq = [idx for idx in seq if idx not in self.special_token_ids]
            return ''.join(self.tokenizer.id_to_token(idx) for idx in seq)

        # 1) batch: list of lists or torch tensor (a batch of size 1 is unwrapped)
        if isinstance(ids, torch.Tensor):
            ids = ids.tolist()
            if len(ids) == 1:
                ids = ids[0]
        if isinstance(ids, list) and len(ids) > 0 and isinstance(ids[0], list):
            return [_decode_sequence(seq) for seq in ids]
        # 2) single sequence: list of ints or torch tensor
        if isinstance(ids, list):
            return _decode_sequence(ids)
        # 3) single int
        if isinstance(ids, int):
            return self.tokenizer.id_to_token(ids)
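

# Minimal usage sketch (not part of the original file). It assumes a WordLevel
# vocabulary file ('vocab.json' is a hypothetical path) that contains the
# [CLS]/[SEP]/[PAD]/[UNK] tokens expected by the class above.
if __name__ == '__main__':
    tokenizer = ChembertaTokenizer('vocab.json')

    smiles = ['CCBr', 'c1ccccc1']
    batch = tokenizer(smiles, padding=True, truncation=True,
                      max_length=128, return_tensors='pt')
    print(batch['input_ids'].shape)       # torch.Size([2, 128])
    print(batch['attention_mask'][0])

    # Round-trip back to text; skip_special_tokens drops [CLS]/[SEP]/[PAD].
    print(tokenizer.decode(batch['input_ids'][0], skip_special_tokens=True))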