import os
import pickle
import re
import traceback
from collections.abc import Callable
from typing import List, Tuple, Union

from datasets import Dataset
from tqdm.auto import tqdm
from transformers import AutoTokenizer
from transformers.pipelines.pt_utils import KeyDataset

URL_REGEX = r"\b(https?://\S+)\b"
EMAIL_REGEX = r"([a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+)"
TAG_REGEX = r"<[^>]+>"
HANDLE_REGEX = r"[^a-zA-Z](@\w+)"
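
# The patterns above match URLs, e-mail addresses, HTML-style tags and @handles.
# Translator swaps each match for an "eeee[i]" placeholder before translation
# and restores it afterwards, so the model cannot alter these spans.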

class Translator:
    """Wrap a translation pipeline: split long Chinese text into token-bounded
    chunks, protect URLs/e-mails/tags/handles with placeholders, translate in
    batches, then reassemble the translated chunks into the original layout."""

    def __init__(
        self,
        pipe: Callable,
        max_length: int = 500,
        batch_size: int = 16,
        save_every_step=100,
        text_key="text",
        save_filename=None,
        replace_chinese_puncts=False,
        verbose=False,
    ):
        self.pipe = pipe
        self.max_length = max_length
        self.batch_size = batch_size
        self.save_every_step = save_every_step
        self.save_filename = save_filename
        self.text_key = text_key
        self.replace_chinese_puncts = replace_chinese_puncts
        self.verbose = verbose
        # fall back to the model's own max_length if none was given
        if max_length is None and hasattr(pipe.model.config, "max_length"):
            self.max_length = pipe.model.config.max_length

    def _is_chinese(self, text: str) -> bool:
        # True if the text contains at least one CJK ideograph
        # (including extension blocks and compatibility ideographs).
        return (
            re.search(
                r"[\u4e00-\u9fff\u3400-\u4dbf\U00020000-\U0002a6df\U0002a700-\U0002ebef\U00030000-\U000323af\ufa0e\ufa0f\ufa11\ufa13\ufa14\ufa1f\ufa21\ufa23\ufa24\ufa27\ufa28\ufa29\u3006\u3007][\ufe00-\ufe0f\U000e0100-\U000e01ef]?",
                text,
            )
            is not None
        )

    def _split_sentences(self, text: str) -> List[str]:
        """Split text into chunks of at most `max_length` tokens, preferring to
        cut right after sentence-final punctuation."""
        tokens = self.pipe.tokenizer(text, add_special_tokens=False)
        token_size = len(tokens.input_ids)
        if token_size <= self.max_length:
            return [text]
        delimiters = set("。!?;…!?;")
        sent_list = []
        sent = text
        while token_size > self.max_length:
            orig_sent_len = token_size
            # walk backwards from the end of `sent` to find a delimiter to cut at
            for i in range(token_size - 2, 0, -1):
                token = tokens.token_to_chars(0, i)
                char = sent[token.start : token.end]
                if char in delimiters:
                    split_char_index = token.end
                    next_sent = sent[split_char_index:]
                    # don't split off a single trailing character (e.g. a closing bracket)
                    if len(next_sent) == 1:
                        continue
                    sent_list = [next_sent] + sent_list
                    sent = sent[0:split_char_index]
                    break
            tokens = self.pipe.tokenizer(sent, add_special_tokens=False)
            token_size = len(tokens.input_ids)
            # no delimiter found, leave the sentence as it is
            if token_size == orig_sent_len:
                sent_list = [sent] + sent_list
                sent = ""
                break
        if len(sent) > 0:
            sent_list = [sent] + sent_list
        return sent_list

    def _preprocess(self, text: str) -> Tuple[List[str], str, List[int], List[str]]:
        # extract entities and hide them behind placeholders
        tags = re.findall(TAG_REGEX, text)
        handles = re.findall(HANDLE_REGEX, text)
        urls = re.findall(URL_REGEX, text)
        emails = re.findall(EMAIL_REGEX, text)
        entities = urls + emails + tags + handles
        # TODO: escape entity placeholders
        for i, entity in enumerate(entities):
            text = text.replace(entity, "eeee[%d]" % i, 1)
        lines = text.split("\n")
        sentences = []
        num_tokens = []
        # build a str.format template so translated chunks can be slotted back in place
        template = text.replace("{", "{{").replace("}", "}}")
        chunk_index = 0
        for line in lines:
            sentence = line.strip()
            if len(sentence) > 0 and self._is_chinese(sentence):
                chunks = self._split_sentences(sentence)
                for chunk in chunks:
                    sentences.append(chunk)
                    tokens = self.pipe.tokenizer(chunk, add_special_tokens=False)
                    num_tokens.append(len(tokens.input_ids))
                    chunk = chunk.replace("{", "{{").replace("}", "}}")
                    template = template.replace(chunk, "{%d}" % chunk_index, 1)
                    chunk_index += 1
        return sentences, template, num_tokens, entities

    def _postprocess(
        self,
        template: str,
        src_sentences: List[str],
        translations: List[str],
        entities: List[str],
    ) -> str:
        processed = []
        # runs of latin letters, digits and common full-/half-width punctuation
        alphanumeric_regex = re.compile(
            r"([a-zA-Z0-9+'\",,(\()\)::;;“”。·\.\??\!!‘’$\[\]<>/]+)"
        )

        def hash_text(text: List[str]) -> str:
            text = "|".join(text)
            # normalize full-width punctuation to half-width
            puncts_map = str.maketrans(",;:()。?!“”‘’", ",;:().?!\"\"''")
            text = text.translate(puncts_map)
            return text.lower()

        for i, p in enumerate(translations):
            src_sentence = src_sentences[i]
            if self.replace_chinese_puncts:
                p = re.sub(",", ",", p)  # replace full-width commas
                p = re.sub(";", ";", p)  # replace full-width semicolons
                p = re.sub(":", ":", p)  # replace full-width colons
                p = re.sub(r"\(", "(", p)  # replace full-width round brackets
                p = re.sub(r"\)", ")", p)
            # full-width commas between digits become thousands separators
            p = re.sub(r"([\d]),([\d])", r"\1,\2", p)
            src_matches = re.findall(alphanumeric_regex, src_sentence)
            tgt_matches = re.findall(alphanumeric_regex, p)
            # counts differ or nothing matched: keep the translation as is
            if (
                len(src_matches) != len(tgt_matches)
                or len(src_matches) == 0
                or len(tgt_matches) == 0
            ):
                processed.append(p)
                continue
            # normalize full-width to half-width and lower case
            src_hashes = hash_text(src_matches)
            translated_hashes = hash_text(tgt_matches)
            if src_hashes != translated_hashes:
                # restore source spans (numbers, latin words, ...) the model altered
                for j in range(len(src_matches)):
                    if src_matches[j] != tgt_matches[j]:
                        p = p.replace(tgt_matches[j], src_matches[j], 1)
            processed.append(p)
        output = template.format(*processed)
        # restore entities
        for i, entity in enumerate(entities):
            output = output.replace("eeee[%d]" % i, entity, 1)
        # TODO: unescape entity placeholders
        # fix repeated punctuation
        output = re.sub(r"([「」()『』《》。,:])\1+", r"\1", output)
        # fix quotes: convert curly double quotes to corner brackets
        if "“" in output:
            output = re.sub("“", "「", output)
        if "”" in output:
            output = re.sub("”", "」", output)
        return output

    def _save(self, translations):
        with open(self.save_filename, "wb") as f:
            pickle.dump(translations, f)

    def __call__(self, inputs: Union[str, List[str], Dataset]) -> List[str]:
        templates = []
        sentences = []
        num_tokens = []
        sentence_indices = []
        outputs = []
        translations = []
        entities_list = []
        resume_from_file = None
        if isinstance(inputs, Dataset):
            ds = inputs
        else:
            if isinstance(inputs, str):
                inputs = [inputs]
            ds = Dataset.from_list([{"text": text} for text in inputs])
        for i, text_input in tqdm(
            enumerate(ds), total=len(ds), desc="Preprocessing", disable=not self.verbose
        ):
            chunks, template, chunk_num_tokens, entities = self._preprocess(
                text_input["text"]
            )
            templates.append(template)
            sentence_indices.append([])
            entities_list.append(entities)
            for j, chunk in enumerate(chunks):
                sentences.append(chunk)
                sentence_indices[-1].append(len(sentences) - 1)
                num_tokens.append(chunk_num_tokens[j])
        if self.save_filename:
            resume_from_file = (
                self.save_filename if os.path.isfile(self.save_filename) else None
            )
        if resume_from_file is not None:
            with open(resume_from_file, "rb") as f:
                translations = pickle.load(f)
        if self.verbose:
            print("translated:", len(translations))
            print("to translate:", len(sentences) - len(translations))
            if resume_from_file is not None:
                print(
                    "Resuming from {} ({} records)".format(
                        resume_from_file, len(translations)
                    )
                )
        ds = Dataset.from_list(
            [{"text": text} for text in sentences[len(translations) :]]
        )
        max_token_length = max(num_tokens, default=0)
        if self.verbose:
            print("Max Length:", max_token_length)
        total_records = len(ds)
        if total_records > 0:
            step = 0
            with tqdm(
                disable=not self.verbose, desc="Translating", total=total_records
            ) as pbar:
                for out in self.pipe(
                    KeyDataset(ds, self.text_key),
                    batch_size=self.batch_size,
                    max_length=self.max_length,
                ):
                    translations.append(out[0])
                    # checkpoint the partial results every `save_every_step` steps
                    if (
                        step != 0
                        and self.save_filename is not None
                        and step % self.save_every_step == 0
                    ):
                        self._save(translations)
                    step += 1
                    pbar.update(1)
        if self.save_filename is not None and total_records > 0:
            self._save(translations)
        for i, template in tqdm(
            enumerate(templates),
            total=len(templates),
            desc="Postprocessing",
            disable=not self.verbose,
        ):
            try:
                src_sentences = [sentences[index] for index in sentence_indices[i]]
                tgt_sentences = [
                    translations[index]["translation_text"]
                    for index in sentence_indices[i]
                ]
                output = self._postprocess(
                    template, src_sentences, tgt_sentences, entities_list[i]
                )
                outputs.append(output)
            except Exception as error:
                print(error)
                print(template)
                traceback.print_exc()
                # print(template, sentence_indices[i], len(translations))
        return outputs
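
# A minimal usage sketch with a real transformers pipeline (not executed here;
# the checkpoint name is the one FakePipe's tokenizer loads below, and any
# translation checkpoint accepted by `pipeline("translation", ...)` should work):
#
#   from transformers import pipeline
#
#   pipe = pipeline("translation", model="indiejoseph/bart-translation-zh-yue")
#   translator = Translator(pipe, max_length=500, batch_size=16,
#                           save_filename="translations.pkl")
#   print(translator(["你好,世界!"])[0])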


class Object(object):
    pass


class FakePipe(object):
    """Deterministic stand-in for a transformers translation pipeline, used by
    the self-test below."""

    def __init__(self, max_length: int = 500):
        self.model = Object()
        self.model.config = Object()
        self.model.config.max_length = max_length
        self.tokenizer = AutoTokenizer.from_pretrained(
            "indiejoseph/bart-translation-zh-yue"
        )

    def __call__(self, text: List[str], batch_size: int, max_length: int):
        for i in range(len(text)):
            sentence = text[i]
            # extract entities the same way Translator._preprocess does
            tags = re.findall(TAG_REGEX, sentence)
            handles = re.findall(HANDLE_REGEX, sentence)
            urls = re.findall(URL_REGEX, sentence)
            emails = re.findall(EMAIL_REGEX, sentence)
            entities = urls + emails + tags + handles
            for j, entity in enumerate(entities):
                sentence = sentence.replace(entity, "eeee[%d]" % j, 1)
            if "123" in sentence:
                # deliberately a no-op: numbers should pass through untouched
                yield [{"translation_text": sentence.replace("123", "123")}]
                continue
            if "abc" in sentence:
                yield [{"translation_text": sentence.replace("abc", "ABC")}]
                continue
            if "Acetaminophen" in sentence:
                yield [
                    {
                        "translation_text": sentence.replace(
                            "Acetaminophen", "ACEtaminidien"
                        )
                    }
                ]
                continue
            yield [{"translation_text": sentence}]


if __name__ == "__main__":
    fake_pipe = FakePipe(60)
    translator = Translator(fake_pipe, max_length=60, batch_size=2, verbose=True)
    text1 = "对于编写聊天机器人的脚本,你可以采用不同的方法,包括使用基于规则的系统、自然语言处理(NLP)技术和机器学习模型。下面是一个简单的例子,展示如何使用基于规则的方法来构建一个简单的聊天机器人:"
    text2 = """对于编写聊天机器人的脚本,你可以采用不同的方法,包括使用基于规则的系统、自然语言处理(NLP)技术和机器学习模型。下面是一个简单的例子,展示如何使用基于规则的方法来构建一个简单的聊天机器人:
```
# 设置用于匹配输入的关键字,并定义相应的回答数据字典。
keywords = {'你好': '你好!很高兴见到你。',
            '再见': '再见!有机会再聊。',
            '你叫什么': '我是一个聊天机器人。',
            '你是谁': '我是一个基于人工智能技术制作的聊天机器人。'}
# 定义用于处理用户输入的函数。
def chatbot(input_text):
    # 遍历关键字数据字典,匹配用户的输入。
    for key in keywords:
        if key in input_text:
            # 如果匹配到了关键字,返回相应的回答。
            return keywords[key]
    # 如果没有找到匹配的关键字,返回默认回答。
    return "对不起,我不知道你在说什么。"
# 运行聊天机器人。
while True:
    # 获取用户输入。
    user_input = input('用户: ')
    # 如果用户输入“再见”,退出程序。
    if user_input == '再见':
        break
    # 处理用户输入,并打印回答。
    print('机器人: ' + chatbot(user_input))
```
这是一个非常简单的例子。对于实用的聊天机器人,可能需要使用更复杂的 NLP 技术和机器学习模型,以更好地理解和回答用户的问题。"""
    text3 = "布洛芬(Ibuprofen)同撲熱息痛(Acetaminophen)係兩種常見嘅非處方藥,用於緩解疼痛、發燒同關節痛。"
    text4 = "123 “abc” def's http://www.google.com [email protected] @abc 網址:http://localhost/abc下載"
    text5 = "新力公司董事長盛田昭夫、自民黨國會議員石原慎太郎等人撰寫嘅《日本可以說「不」》、《日本還要說「不」》、《日本堅決說「不」》三本書中話道:「無啦啦挑起戰爭嘅好戰日本人,製造南京大屠殺嘅殘暴嘅日本人,呢d就係人地對日本人嘅兩個誤解,都係‘敲打日本’嘅兩個根由,我地必須採取措施消除佢。」"
    outputs = translator([text1, text2, text3, text4, text5])
    # for i, line in enumerate(outputs[1].split("\n")):
    #     input_text = text2.split("\n")[i]
    #     if line != input_text:
    #         print(line, text2.split("\n")[i])
    assert outputs[0] == text1
    assert outputs[1] == text2.replace("“", "「").replace("”", "」")
    assert outputs[2] == text3
    assert outputs[3] == text4.replace("“", "「").replace("”", "」")
    assert outputs[4] == text5
    # edge case: a long sentence with no usable delimiter is kept as a single chunk
    assert (
        len(
            translator._split_sentences(
                "新力公司董事長盛田昭夫、自民黨國會議員石原慎太郎等人撰寫嘅《日本可以說「不」》、《日本還要說「不」》、《日本堅決說「不」》三本書中話道:「無啦啦挑起戰爭嘅好戰日本人,製造南京大屠殺嘅殘暴嘅日本人,呢d就係人地對日本人嘅兩個誤解,都係‘敲打日本’嘅兩個根由,我地必須採取措施消除佢。」"
            )
        )
        == 1
    )
    translator = Translator(fake_pipe, max_length=5, batch_size=2, verbose=True)
    assert (
        len(
            translator._split_sentences("====。====。====。====。====。====。====。====。====。")
        )
        == 9
    )