gaia-eval-l1-20 / verifier.py
kengboon
Add verifier
8a7cb66
raw
history blame
1.66 kB
import json
# Default ground-truth file used when Verifier() is constructed without a path.
# This is a relative path, so it is resolved against the current working directory.
DATA_PATH: str = "groundtruths/gaia-2023-validation.jsonl"
class Verifier:
def __init__(self, data_path: str | None=None):
if data_path is None:
data_path = DATA_PATH
self.data_path = data_path
self.data: dict = self.load_data()
self.correct_cnt = 0
self.total_cnt = 0
def load_data(self) -> dict:
data = {}
with open(self.data_path, "r") as f:
for line in f:
record = json.loads(line)
data[record["task_id"]] = record
return data
def verify(self, task_id: str, answer: str) -> None | tuple[bool, str]:
record = self.data.get(task_id)
if not record:
print(f"Task ID {task_id} not found.")
return None
self.total_cnt += 1
if record["Final answer"] == answer:
self.correct_cnt += 1
return True, record["Final answer"]
return False, record["Final answer"]
def get_answer(self, task_id: str) -> str:
record = self.data.get(task_id)
if not record:
print(f"Task ID {task_id} not found.")
return ""
return record["Final answer"]
def get_accuracy(self) -> float:
if self.total_cnt == 0:
return 0.0
return self.correct_cnt / self.total_cnt
def get_output(self) -> str:
return f"**Correct:** {self.correct_cnt}/{self.total_cnt} ({self.get_accuracy():.2%})"
if __name__ == "__main__":
    # Smoke check: load the default ground-truth file, then list every task ID.
    verifier = Verifier()
    print(f"Loaded {len(verifier.data)} records from {verifier.data_path}")
    # Iterating the dict yields its keys, i.e. the task IDs (not full records).
    for task_id in verifier.data:
        print(task_id)