# gaia-eval-l1-20/agent.py
import os
import time
from langchain_core.messages import SystemMessage, HumanMessage, ToolMessage, AIMessage
from langchain_google_genai import ChatGoogleGenerativeAI
from langgraph.graph import StateGraph, MessagesState, START, END
from langgraph.prebuilt import ToolNode, tools_condition
from tools import my_tools

# Load the system prompt from disk, falling back to an empty prompt.
try:
    if os.path.exists("system_prompt.txt"):
        with open("system_prompt.txt", "r") as f:
            system_prompt = f.read()
    else:
        system_prompt = ""
except Exception as e:
    print(f"Error reading system prompt file: {e}")
    system_prompt = ""

sys_msg = SystemMessage(system_prompt)

class MyMessagesState(MessagesState):
    # Counts how many times the tool node has run, so tool-calling loops can be capped
    tool_call_cnt: int = 0

class MyAgent:
    def __init__(
        self,
        model_name: str = "gemini-2.0-flash",
        tools: list | None = None,
        max_tool_call: int | None = 5
    ):
        self.graph = self._build_graph(model_name, tools, max_tool_call)

    def _build_graph(
        self,
        model_name: str,
        tools: list | None,
        max_tool_call: int | None
    ):
        self.model_name = model_name
        # Minimum delay (seconds) between LLM calls to avoid exceeding rate limits
        if "gemini-2.5" in model_name:
            self.delay = 6
        elif "gemini-2.0" in model_name:
            self.delay = 4
        else:
            self.delay = 0
        self.start_time = None

        llm = ChatGoogleGenerativeAI(model=model_name)
        if not tools:
            tools = my_tools
        llm_with_tools = llm.bind_tools(tools)

        def _delay_wait():
            # Sleep until at least self.delay seconds have passed since the previous call
            if self.start_time is not None:
                elapsed = time.time() - self.start_time
                if elapsed < self.delay:
                    time.sleep(self.delay - elapsed)
            self.start_time = time.time()
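        # Worked example of the throttle: with self.delay = 4, if the previous
        # call started at t = 0 and the next node fires at t = 1.5, _delay_wait()
        # sleeps for 2.5 s, keeping consecutive LLM calls at least 4 s apart
        # (at most ~15 requests per minute).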

        def assistant_with_tools(state: MyMessagesState) -> MyMessagesState:
            print("assistant_with_tools")
            _delay_wait()
            return {
                "messages": [llm_with_tools.invoke(state["messages"])],
                "tool_call_cnt": state["tool_call_cnt"]
            }

        def assistant_without_tools(state: MyMessagesState) -> MyMessagesState:
            print("assistant_without_tools")
            _delay_wait()
            # Re-answer using only the first messages (system prompt + question),
            # discarding any accumulated tool-call history
            return {
                "messages": [llm.invoke(state["messages"][:2])],
                "tool_call_cnt": state["tool_call_cnt"]
            }

        def increment_tool_call_cnt(state: MyMessagesState) -> MyMessagesState:
            print("increment_tool_call_cnt")
            return {
                "messages": state["messages"],
                "tool_call_cnt": state["tool_call_cnt"] + 1
            }

        def check_tool_call_status(state: MyMessagesState):
            # Initialize the tool-calling counter defensively
            if "tool_call_cnt" not in state:
                state["tool_call_cnt"] = 0
            # On a tool-call error, fall back to the LLM without tools instead of
            # returning a failure message
            if "messages" in state:
                last_msg = state["messages"][-1]
                # Heuristic: an apology in the reply usually signals a failed tool call
                if isinstance(last_msg, AIMessage) and "sorry" in last_msg.content.lower():
                    return "assistant_without_tools"
                if len(state["messages"]) >= 2:
                    tool_msg = state["messages"][-2]
                    if isinstance(tool_msg, ToolMessage) and tool_msg.status == "error":
                        return "assistant_without_tools"
            target_node = tools_condition(state)
            if target_node == "tools":
                # If the tool-calling count exceeds the limit, fall back to the
                # LLM without tools
                if (max_tool_call is not None) and state["tool_call_cnt"] >= max_tool_call:
                    return "assistant_without_tools"
            return target_node
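
        # Graph topology built below:
        #   START -> assistant_with_tools
        #   assistant_with_tools --(check_tool_call_status)--> tools | assistant_without_tools | END
        #   tools -> increment_tool_call_cnt -> assistant_with_tools
        #   assistant_without_tools -> END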
        builder = StateGraph(MyMessagesState)
        # Add nodes
        builder.add_node("assistant_with_tools", assistant_with_tools)
        builder.add_node("assistant_without_tools", assistant_without_tools)
        builder.add_node("tools", ToolNode(tools))
        builder.add_node("increment_tool_call_cnt", increment_tool_call_cnt)
        # Add edges
        builder.add_edge(START, "assistant_with_tools")
        builder.add_conditional_edges("assistant_with_tools", check_tool_call_status)
        builder.add_edge("tools", "increment_tool_call_cnt")
        builder.add_edge("increment_tool_call_cnt", "assistant_with_tools")
        builder.add_edge("assistant_without_tools", END)
        return builder.compile()

    def ask(self, question: str, final_answer_only: bool = True) -> str:
        if "gemma" in self.model_name:
            # Gemma models do not accept a separate system message, so fold the
            # system prompt into the question
            messages = [
                HumanMessage(f"{system_prompt} {question}")
            ]
        else:
            messages = [
                sys_msg,
                HumanMessage(question)
            ]
        result = self.graph.invoke({"messages": messages, "tool_call_cnt": 0})
        response = result["messages"][-1].content.strip()
        return self.format_output(response) if final_answer_only else response

    def format_output(self, output: str) -> str:
        # Return the text after "FINAL ANSWER: "
        if "FINAL ANSWER: " in output:
            return output.split("FINAL ANSWER: ")[-1].strip()
        # Otherwise, return the first text between ** and ** (bold markdown)
        elif "**" in output:
            return output.split("**")[1].strip()
        else:
            return output
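
    # format_output behavior on a few illustrative (hypothetical) outputs:
    #   "FINAL ANSWER: Paris"       -> "Paris"
    #   "The capital is **Paris**." -> "Paris"
    #   "Paris"                     -> "Paris" (returned unchanged)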

    def __call__(self, question: str) -> str:
        return self.ask(question)
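
# Example usage (a minimal sketch: assumes GOOGLE_API_KEY is set in the
# environment and that tools.my_tools supplies the default tool list):
#   agent = MyAgent(model_name="gemini-2.0-flash", max_tool_call=3)
#   print(agent("What is the capital of France?"))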

if __name__ == "__main__":
    from main import main as main_fn
    main_fn()