| """ | |
| Copyright (c) 2022, salesforce.com, inc. | |
| All rights reserved. | |
| SPDX-License-Identifier: BSD-3-Clause | |
| For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause | |
| """ | |
| import logging | |
| import json | |
| import os | |
| import minigpt4.common.dist_utils as dist_utils | |
| from minigpt4.common.registry import registry | |
| from minigpt4.common.vqa_tools.vqa import VQA | |
| from minigpt4.common.vqa_tools.vqa_eval import VQAEval | |
| from minigpt4.tasks.base_task import BaseTask | |


class VQATask(BaseTask):
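    """
    Open-ended VQA task: collects COCO-format question/annotation files from the
    datasets, runs answer prediction at validation time, and can score results
    with the official VQA evaluation tools.
    """
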
    def __init__(
        self,
        num_beams,
        max_len,
        min_len,
        evaluate,
        num_ans_candidates,
        inference_method="rank",
        prompt="",
    ):
        super().__init__()

        self.num_beams = num_beams
        self.max_len = max_len
        self.min_len = min_len
        self.evaluate = evaluate
        self.inference_method = inference_method
        self.num_ans_candidates = num_ans_candidates
        self.prompt = prompt

        self.answer_list = None

        self.ques_files = dict()
        self.anno_files = dict()
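
    # Factory classmethod: builds the task from cfg.run_cfg, falling back to the
    # defaults below when a setting is missing.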
    @classmethod
    def setup_task(cls, cfg):
        run_cfg = cfg.run_cfg

        num_beams = run_cfg.get("num_beams", 3)
        max_len = run_cfg.get("max_len", 10)
        min_len = run_cfg.get("min_len", 1)

        evaluate = run_cfg.get("evaluate", False)

        inference_method = run_cfg.get("inference_method", "rank")
        num_ans_candidates = run_cfg.get("num_ans_candidates", 128)
        prompt = run_cfg.get("prompt", "")

        return cls(
            num_beams=num_beams,
            max_len=max_len,
            min_len=min_len,
            evaluate=evaluate,
            num_ans_candidates=num_ans_candidates,
            inference_method=inference_method,
            prompt=prompt,
        )
    def build_datasets(self, cfg):
        datasets = super().build_datasets(cfg)

        # get question file, annotation file and answer list in COCO format
        for dataset in datasets.values():
            for split in dataset:
                if (
                    hasattr(dataset[split], "coco_fmt_qust_file")
                    and dataset[split].coco_fmt_qust_file is not None
                ):
                    self.ques_files[split] = dataset[split].coco_fmt_qust_file
                    self.anno_files[split] = dataset[split].coco_fmt_anno_file

                try:
                    self.answer_list = dataset[split].answer_list
                except AttributeError:
                    # if answer_list is not provided, then set it to None
                    pass

        if len(self.ques_files) > 0:
            assert len(self.ques_files) == len(
                self.anno_files
            ), "Only support one split for evaluation."

        return datasets
    def valid_step(self, model, samples):
        answers = model.predict_answers(
            samples=samples,
            answer_list=self.answer_list,
            inference_method=self.inference_method,
            num_beams=self.num_beams,
            max_len=self.max_len,
            min_len=self.min_len,
            num_ans_candidates=self.num_ans_candidates,
            prompt=self.prompt,
        )
        pred_qa_pairs = []

        question_id = samples["question_id"]
        for answer, ques_id in zip(answers, question_id):
            ques_id = int(ques_id.item())
            pred_qa_pairs.append({"question_id": ques_id, "answer": answer})

        return pred_qa_pairs

    def after_evaluation(self, val_result, split_name, result_dir):
        result_file = self.save_result(
            val_result,
            result_dir=result_dir,  # registry.get_path("result_dir"),
            filename=split_name,
            remove_duplicate="question_id",
        )

        # metrics = self._report_metrics(result_file=result_file, split=split_name)
        # return metrics
    def _report_metrics(self, result_file, split):
        """
        Use official VQA evaluation script to report metrics.
        """
        metrics = {}

        if split in self.ques_files and split in self.anno_files:
            vqa = VQA(self.anno_files[split], self.ques_files[split])
            vqa_result = vqa.loadRes(
                resFile=result_file, quesFile=self.ques_files[split]
            )

            # create vqaEval object by taking vqa and vqaRes
            # n is precision of accuracy (number of places after decimal), default is 2
            vqa_scorer = VQAEval(vqa, vqa_result, n=2)
            logging.info("Start VQA evaluation.")
            vqa_scorer.evaluate()

            # print accuracies
            overall_acc = vqa_scorer.accuracy["overall"]
            metrics["agg_metrics"] = overall_acc

            logging.info("Overall Accuracy is: %.02f\n" % overall_acc)
            logging.info("Per Answer Type Accuracy is the following:")

            for ans_type in vqa_scorer.accuracy["perAnswerType"]:
                logging.info(
                    "%s : %.02f"
                    % (ans_type, vqa_scorer.accuracy["perAnswerType"][ans_type])
                )
                metrics[ans_type] = vqa_scorer.accuracy["perAnswerType"][ans_type]

            with open(
                os.path.join(registry.get_path("output_dir"), "evaluate.txt"), "a"
            ) as f:
                f.write(json.dumps(metrics) + "\n")

        return metrics


class GQATask(VQATask):
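    """
    GQA variant of the VQA task: keeps the ground-truth answer alongside each
    prediction and scores exact-match accuracy (normalizing generated answers first).
    """
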
    def valid_step(self, model, samples):
        answers = model.predict_answers(
            samples=samples,
            answer_list=self.answer_list,
            inference_method=self.inference_method,
            num_beams=self.num_beams,
            max_len=self.max_len,
            min_len=self.min_len,
            num_ans_candidates=self.num_ans_candidates,
            prompt=self.prompt,
        )
        pred_qa_pairs = []

        question_id = samples["question_id"]
        gt_answers = samples["answer"]

        for answer, ques_id, gt_answer in zip(answers, question_id, gt_answers):
            ques_id = int(ques_id.item())
            pred_qa_pairs.append(
                {"question_id": ques_id, "pred_ans": answer, "gt_ans": gt_answer}
            )

        return pred_qa_pairs
    def _report_metrics(self, result_file, split):
        """
        TODO: add other evaluation metrics for GQA
        """
        results = json.load(open(result_file, "r"))
        acc = []
        vqa_tool = VQAEval()

        for res in results:
            if res["gt_ans"] is None:
                # prepare test results for leaderboard evaluation
                self._save_result_leaderboard(results)
                return

            gt_ans = res["gt_ans"]
            pred = res["pred_ans"]
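
            # Generated answers are free-form text, so strip punctuation and
            # normalize digits/articles before the exact-match comparison below.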
            if self.inference_method == "generate":
                pred = vqa_tool.processPunctuation(pred)
                pred = vqa_tool.processDigitArticle(pred)

            vqa_acc = 1 if pred == gt_ans else 0
            acc.append(vqa_acc)

        accuracy = sum(acc) / len(acc) * 100
        metrics = {"agg_metrics": accuracy, "acc": accuracy}

        with open(
            os.path.join(registry.get_path("output_dir"), "evaluate.txt"), "a"
        ) as f:
            f.write(json.dumps(metrics) + "\n")

        logging.info(metrics)

        return metrics


class ScienceQATask(GQATask):
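    """
    Multiple-choice ScienceQA: the model scores the candidate answers with
    predict_class instead of generating free-form text; metric reporting is
    inherited from GQATask.
    """
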
    def valid_step(self, model, samples):
        answers = model.predict_class(
            samples=samples,
            answer_list=self.answer_list,
            inference_method=self.inference_method,
            num_beams=self.num_beams,
            max_len=self.max_len,
            min_len=self.min_len,
            num_ans_candidates=self.num_ans_candidates,
            prompt=self.prompt,
        )
        pred_qa_pairs = []

        question_id = samples["question_id"]
        gt_answers = samples["answer"]

        for answer, ques_id, gt_answer in zip(answers, question_id, gt_answers):
            ques_id = int(ques_id.item())
            pred_qa_pairs.append(
                {"question_id": ques_id, "pred_ans": answer, "gt_ans": gt_answer}
            )

        return pred_qa_pairs


class AOKVQATask(VQATask):
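    """
    A-OKVQA direct-answer task: each sample carries a list of annotator answers,
    accuracy follows the official A-OKVQA soft-match metric, and test splits
    without ground truth are exported for leaderboard submission.
    """
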
    def valid_step(self, model, samples):
        answers = model.predict_answers(
            samples=samples,
            answer_list=self.answer_list,
            inference_method=self.inference_method,
            num_beams=self.num_beams,
            max_len=self.max_len,
            min_len=self.min_len,
            num_ans_candidates=self.num_ans_candidates,
        )
        pred_qa_pairs = []

        question_id = samples["question_id"]
        gt_answers = samples["direct_answers"]

        for pred_answer, ques_id, gt_answer in zip(answers, question_id, gt_answers):
            pred_qa_pairs.append(
                {"question_id": ques_id, "pred_ans": pred_answer, "gt_ans": gt_answer}
            )

        return pred_qa_pairs
    def _report_metrics(self, result_file, split):
        """
        Implementing accuracy computation for AOKVQA, see
        https://github.com/allenai/aokvqa/blob/main/evaluation/eval_predictions.py#L45 for details.
        """
        # TODO add evaluation for multi-choice

        results = json.load(open(result_file, "r"))
        acc = []

        for res in results:
            if res["gt_ans"] is None:
                # prepare test results for leaderboard evaluation
                self._save_result_leaderboard(results)
                return

            pred = res["pred_ans"]
            gt_ans = res["gt_ans"]
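
            # Soft accuracy as in the referenced A-OKVQA script: a prediction
            # scores 1.0 if it matches at least 3 of the annotator answers,
            # otherwise num_match / 3.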
            num_match = sum([pred == gt for gt in gt_ans])
            vqa_acc = min(1.0, num_match / 3.0)
            acc.append(vqa_acc)

        accuracy = sum(acc) / len(acc) * 100
        metrics = {"agg_metrics": accuracy, "acc": accuracy}

        with open(
            os.path.join(registry.get_path("output_dir"), "evaluate.txt"), "a"
        ) as f:
            f.write(json.dumps(metrics) + "\n")

        logging.info(metrics)

        return metrics
    def _save_result_leaderboard(self, results):
        """
        Saving the results in the format required for leaderboard evaluation.

        [TODO] add support for multi-choice.
        """
        result_leaderboard = dict()
        for res in results:
            result_leaderboard[res["question_id"]] = {
                "direct_answer": res["pred_ans"],
                "multiple_choice": "",
            }

        result_file = registry.get_path("result_dir") + "_leaderboard.json"

        with open(result_file, "w") as f:
            json.dump(result_leaderboard, f)

        logging.info(f"Saved results for leaderboard evaluation at {result_file}")