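# MKG Analogy: Gradio demo for multimodal analogical reasoning over a knowledge graph.
# Given an example (head, tail) pair and a question entity, the MKGformer backbone
# (UnimoForMaskedLM) predicts the missing answer entity at a [MASK] position.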
import gradio as gr
import torch
from torch import nn
from huggingface_hub import hf_hub_download
from transformers import BertModel, BertTokenizer, CLIPModel, BertConfig, CLIPConfig, CLIPProcessor

from modeling_unimo import UnimoForMaskedLM
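# Helpers for reading the MarKG / MARS dataset files: tab-separated "id<TAB>text"
# mappings and plain-text lists with one entity/relation id per line.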
def load_dict_text(path):
    with open(path, 'r') as f:
        load_data = {}
        lines = f.readlines()
        for line in lines:
            key, value = line.split('\t')
            load_data[key] = value.replace('\n', '')
    return load_data


def load_text(path):
    with open(path, 'r') as f:
        lines = f.readlines()
        load_data = []
        for line in lines:
            load_data.append(line.strip().replace('\n', ''))
    return load_data
class MKGformerModel(nn.Module):
    # Thin wrapper around the multimodal UnimoForMaskedLM encoder
    # (vision config from CLIP, text config from BERT).
    def __init__(self, vision_config, text_config):
        super().__init__()
        self.model = UnimoForMaskedLM(vision_config, text_config)

    def forward(self, batch):
        return self.model(**batch, return_dict=True)
# tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# entity and relation
ent2text = load_dict_text('./dataset/MarKG/entity2text.txt')
rel2text = load_dict_text('./dataset/MarKG/relation2text.txt')
analogy_entities = load_text('./dataset/MARS/analogy_entities.txt')
analogy_relations = load_text('./dataset/MARS/analogy_relations.txt')
ent2description = load_dict_text('./dataset/MarKG/entity2textlong.txt')

text2ent = {text: ent for ent, text in ent2text.items()}

ent2token = {ent: f"[ENTITY_{i}]" for i, ent in enumerate(ent2description)}
rel2token = {rel: f"[RELATION_{i}]" for i, rel in enumerate(rel2text)}
analogy_ent2token = {ent: f"[ENTITY_{i}]" for i, ent in enumerate(ent2description) if ent in analogy_entities}
analogy_rel2token = {rel: f"[RELATION_{i}]" for i, rel in enumerate(rel2text) if rel in analogy_relations}

entity_list = list(ent2token.values())
relation_list = list(rel2token.values())
analogy_ent_list = list(analogy_ent2token.values())
analogy_rel_list = list(analogy_rel2token.values())
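# Each KG entity/relation gets a reserved placeholder token ([ENTITY_i] / [RELATION_i]).
# Registering the placeholders as additional special tokens makes each one a single vocabulary id,
# so the MLM head can score entities directly; the id ranges and the analogy subsets are recorded
# to restrict predictions to valid analogy answers.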
num_added_tokens = tokenizer.add_special_tokens({'additional_special_tokens': entity_list})
num_added_tokens = tokenizer.add_special_tokens({'additional_special_tokens': relation_list})

vocab = tokenizer.get_added_vocab()  # dict: word: idx
relation_id_st = vocab[relation_list[0]]
relation_id_ed = vocab[relation_list[-1]] + 1
entity_id_st = vocab[entity_list[0]]
entity_id_ed = vocab[entity_list[-1]] + 1

# analogy entities and relations
analogy_entity_ids = [vocab[ent] for ent in analogy_ent_list]
analogy_relation_ids = [vocab[rel] for rel in analogy_rel_list]

num_added_tokens = tokenizer.add_special_tokens({'additional_special_tokens': ["[R]"]})
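# Download the fine-tuned checkpoint from 'flow3rdown/mkgformer_mart_ft' and rebuild the model on CPU.
# The token embeddings are resized to match the tokenizer extended with the entity/relation tokens
# before the checkpoint's "state_dict" is loaded.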
# model
checkpoint_path = hf_hub_download(repo_id='flow3rdown/mkgformer_mart_ft', filename="mkgformer_mart_ft", repo_type='model')

clip_config = CLIPConfig.from_pretrained('openai/clip-vit-base-patch32').vision_config
clip_config.device = 'cpu'
bert_config = BertConfig.from_pretrained('bert-base-uncased')

mkgformer = MKGformerModel(clip_config, bert_config)
mkgformer.model.resize_token_embeddings(len(tokenizer))
mkgformer.load_state_dict(torch.load(checkpoint_path, map_location='cpu')["state_dict"])

# processor
processor = CLIPProcessor.from_pretrained('openai/clip-vit-base-patch32')
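# The three inference helpers below share one prompt recipe: the example pair and the question
# are encoded as two [SEP]-joined segments, "<head> [R] <tail>" and "<question> [R] [MASK]";
# the attention mask is expanded to a full (seq_len x seq_len) matrix so the example segment
# cannot attend into the question segment; images fill the model's two pixel_values slots; and
# the answer is the analogy entity with the highest logit at the [MASK] position.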
def single_inference_iit(head_img, head_id, tail_img, tail_id, question_txt, question_id):
    # (I, I) -> (T, ?)
    ques_ent_text = ent2description[question_id]
    inputs = tokenizer(
        tokenizer.sep_token.join([analogy_ent2token[head_id] + " ", "[R] ", analogy_ent2token[tail_id] + " "]),
        tokenizer.sep_token.join([analogy_ent2token[question_id] + " " + ques_ent_text, "[R] ", "[MASK]"]),
        truncation="longest_first", max_length=128, padding="longest", return_tensors='pt', add_special_tokens=True)
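    # Locate the [SEP] positions, expand the 1-D attention mask to (seq_len x seq_len), and block
    # the example-pair segment (before the 3rd [SEP]) from attending to the question segment.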
    sep_idx = [[i for i, ids in enumerate(input_ids) if ids == tokenizer.sep_token_id] for input_ids in inputs['input_ids']]
    inputs['sep_idx'] = torch.tensor(sep_idx)
    inputs['attention_mask'] = inputs['attention_mask'].unsqueeze(1).expand([inputs['input_ids'].size(0), inputs['input_ids'].size(1), inputs['input_ids'].size(1)]).clone()
    for i, idx in enumerate(sep_idx):
        inputs['attention_mask'][i, :idx[2], idx[2]:] = 0

    # image
    pixel_values = processor(images=[head_img, tail_img], return_tensors='pt')['pixel_values'].squeeze()
    inputs['pixel_values'] = pixel_values.unsqueeze(0)

    input_ids = inputs['input_ids']
    model_output = mkgformer.model(**inputs, return_dict=True)
    logits = model_output[0].logits
    bsz = input_ids.shape[0]
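    # Read out the prediction: take the logits at the [MASK] position, keep only the analogy entity
    # ids, and map the argmax back to a readable name via analogy_ent2token and ent2text.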
    _, mask_idx = (input_ids == tokenizer.mask_token_id).nonzero(as_tuple=True)  # bsz
    mask_logits = logits[torch.arange(bsz), mask_idx][:, analogy_entity_ids]     # bsz, 1, entity

    answer = ent2text[list(analogy_ent2token.keys())[mask_logits.argmax().item()]]

    return answer
def single_inference_tti(head_txt, head_id, tail_txt, tail_id, question_img, question_id):
    # (T, T) -> (I, ?)
    head_ent_text, tail_ent_text = ent2description[head_id], ent2description[tail_id]
    inputs = tokenizer(
        tokenizer.sep_token.join([analogy_ent2token[head_id] + " " + head_ent_text, "[R] ", analogy_ent2token[tail_id] + " " + tail_ent_text]),
        tokenizer.sep_token.join([analogy_ent2token[question_id] + " ", "[R] ", "[MASK]"]),
        truncation="longest_first", max_length=128, padding="longest", return_tensors='pt', add_special_tokens=True)
    sep_idx = [[i for i, ids in enumerate(input_ids) if ids == tokenizer.sep_token_id] for input_ids in inputs['input_ids']]
    inputs['sep_idx'] = torch.tensor(sep_idx)
    inputs['attention_mask'] = inputs['attention_mask'].unsqueeze(1).expand([inputs['input_ids'].size(0), inputs['input_ids'].size(1), inputs['input_ids'].size(1)]).clone()
    for i, idx in enumerate(sep_idx):
        inputs['attention_mask'][i, :idx[2], idx[2]:] = 0

    # image
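    # Only the question side has an image here, so the second image slot is padded with zeros.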
    pixel_values = processor(images=question_img, return_tensors='pt')['pixel_values'].unsqueeze(1)
    pixel_values = torch.cat((pixel_values, torch.zeros_like(pixel_values)), dim=1)
    inputs['pixel_values'] = pixel_values

    input_ids = inputs['input_ids']
    model_output = mkgformer.model(**inputs, return_dict=True)
    logits = model_output[0].logits
    bsz = input_ids.shape[0]

    _, mask_idx = (input_ids == tokenizer.mask_token_id).nonzero(as_tuple=True)  # bsz
    mask_logits = logits[torch.arange(bsz), mask_idx][:, analogy_entity_ids]     # bsz, 1, entity

    answer = ent2text[list(analogy_ent2token.keys())[mask_logits.argmax().item()]]

    return answer
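# Blended setting: the example pair mixes modalities (head given as an image, tail as text plus
# description), and the head and question images fill the two image slots.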
def blended_inference_iti(head_img, head_id, tail_txt, tail_id, question_img, question_id):
    # (I, T) -> (I, ?)
    head_ent_text, tail_ent_text = ent2description[head_id], ent2description[tail_id]
    inputs = tokenizer(
        tokenizer.sep_token.join([analogy_ent2token[head_id], "[R] ", analogy_ent2token[tail_id] + " " + tail_ent_text]),
        tokenizer.sep_token.join([analogy_ent2token[question_id] + " ", "[R] ", "[MASK]"]),
        truncation="longest_first", max_length=128, padding="longest", return_tensors='pt', add_special_tokens=True)
    sep_idx = [[i for i, ids in enumerate(input_ids) if ids == tokenizer.sep_token_id] for input_ids in inputs['input_ids']]
    inputs['sep_idx'] = torch.tensor(sep_idx)
    inputs['attention_mask'] = inputs['attention_mask'].unsqueeze(1).expand([inputs['input_ids'].size(0), inputs['input_ids'].size(1), inputs['input_ids'].size(1)]).clone()
    for i, idx in enumerate(sep_idx):
        inputs['attention_mask'][i, :idx[2], idx[2]:] = 0

    # image
    pixel_values = processor(images=[head_img, question_img], return_tensors='pt')['pixel_values'].squeeze()
    inputs['pixel_values'] = pixel_values.unsqueeze(0)

    input_ids = inputs['input_ids']
    model_output = mkgformer.model(**inputs, return_dict=True)
    logits = model_output[0].logits
    bsz = input_ids.shape[0]

    _, mask_idx = (input_ids == tokenizer.mask_token_id).nonzero(as_tuple=True)  # bsz
    mask_logits = logits[torch.arange(bsz), mask_idx][:, analogy_entity_ids]     # bsz, 1, entity

    answer = ent2text[list(analogy_ent2token.keys())[mask_logits.argmax().item()]]

    return answer
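# Gradio tab builders: each lays out the inputs for one analogy setting and wires the Submit
# button to the matching inference function. Entity ids are MarKG identifiers (Wikidata-style
# QIDs such as 'Q10884'); the bundled examples expect images under ./examples/.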
def single_tab_iit():
    with gr.Column():
        gr.Markdown(""" $(I_h, I_t) : (T_q, ?)$
        """)
        with gr.Row():
            with gr.Column():
                head_image = gr.Image(type='pil', label="Head Image")
                head_ent = gr.Textbox(lines=1, label="Head Entity")
            with gr.Column():
                tail_image = gr.Image(type='pil', label="Tail Image")
                tail_ent = gr.Textbox(lines=1, label="Tail Entity")
            with gr.Column():
                question_text = gr.Textbox(lines=1, label="Question Name")
                question_ent = gr.Textbox(lines=1, label="Question Entity")

        submit_btn = gr.Button("Submit")
        output_text = gr.Textbox(label="Output")

        submit_btn.click(fn=single_inference_iit,
                         inputs=[head_image, head_ent, tail_image, tail_ent, question_text, question_ent],
                         outputs=[output_text])

        examples = [['examples/tree.jpg', 'Q10884', 'examples/forest.jpg', 'Q4421', "Anhui", 'Q40956']]
        ex = gr.Examples(
            examples=examples,
            fn=single_inference_iit,
            inputs=[head_image, head_ent, tail_image, tail_ent, question_text, question_ent],
            outputs=[output_text],
            cache_examples=False,
            run_on_click=False,
        )
def single_tab_tti():
    with gr.Column():
        gr.Markdown(""" $(T_h, T_t) : (I_q, ?)$
        """)
        with gr.Row():
            with gr.Column():
                head_text = gr.Textbox(lines=1, label="Head Name")
                head_ent = gr.Textbox(lines=1, label="Head Entity")
            with gr.Column():
                tail_text = gr.Textbox(lines=1, label="Tail Name")
                tail_ent = gr.Textbox(lines=1, label="Tail Entity")
            with gr.Column():
                question_image = gr.Image(type='pil', label="Question Image")
                question_ent = gr.Textbox(lines=1, label="Question Entity")

        submit_btn = gr.Button("Submit")
        output_text = gr.Textbox(label="Output")

        submit_btn.click(fn=single_inference_tti,
                         inputs=[head_text, head_ent, tail_text, tail_ent, question_image, question_ent],
                         outputs=[output_text])

        examples = [['scrap', 'Q3217573', 'watch', 'Q178794', 'examples/root.jpg', 'Q111029']]
        ex = gr.Examples(
            examples=examples,
            fn=single_inference_tti,
            inputs=[head_text, head_ent, tail_text, tail_ent, question_image, question_ent],
            outputs=[output_text],
            cache_examples=False,
            run_on_click=False,
        )
def blended_tab_iti():
    with gr.Column():
        gr.Markdown(""" $(I_h, T_t) : (I_q, ?)$
        """)
        with gr.Row():
            with gr.Column():
                head_image = gr.Image(type='pil', label="Head Image")
                head_ent = gr.Textbox(lines=1, label="Head Entity")
            with gr.Column():
                tail_txt = gr.Textbox(lines=1, label="Tail Name")
                tail_ent = gr.Textbox(lines=1, label="Tail Entity")
            with gr.Column():
                question_image = gr.Image(type='pil', label="Question Image")
                question_ent = gr.Textbox(lines=1, label="Question Entity")

        submit_btn = gr.Button("Submit")
        output_text = gr.Textbox(label="Output")

        submit_btn.click(fn=blended_inference_iti,
                         inputs=[head_image, head_ent, tail_txt, tail_ent, question_image, question_ent],
                         outputs=[output_text])

        examples = [['examples/watermelon.jpg', 'Q38645', 'fruit', 'Q3314483', 'examples/coffee.jpeg', 'Q8486']]
        ex = gr.Examples(
            examples=examples,
            fn=blended_inference_iti,
            inputs=[head_image, head_ent, tail_txt, tail_ent, question_image, question_ent],
            outputs=[output_text],
            cache_examples=False,
            run_on_click=False,
        )
TITLE = """MKG Analogy"""

with gr.Blocks() as block:
    with gr.Column(elem_id="col-container"):
        gr.HTML(TITLE)

        with gr.Tab("Single Analogical Reasoning"):
            single_tab_iit()
            single_tab_tti()
        with gr.Tab("Blended Analogical Reasoning"):
            blended_tab_iti()

        # gr.HTML(ARTICLE)

block.queue(max_size=64).launch()