"""Tool-calling agent built on LangGraph with a tool-free fallback path."""

import os
import time

from langchain_core.messages import (
    AIMessage,
    HumanMessage,
    SystemMessage,
    ToolMessage,
)
from langchain_google_genai import ChatGoogleGenerativeAI
from langgraph.graph import END, START, MessagesState, StateGraph
from langgraph.prebuilt import ToolNode, tools_condition

from tools import my_tools

# Load the optional system prompt. This is best-effort: a missing or
# unreadable file leaves the prompt empty instead of failing the import.
try:
    if os.path.exists("system_prompt.txt"):
        with open("system_prompt.txt", "r", encoding="utf-8") as f:
            system_prompt = f.read()
    else:
        system_prompt = ""
except (OSError, UnicodeError) as e:
    print(f"Error reading system prompt file: {e}")
    system_prompt = ""

sys_msg = SystemMessage(system_prompt)


def _content_to_text(content) -> str:
    """Flatten message content into a plain string.

    Message content may be a plain string or a list whose items are
    strings or dicts. String items and the ``"text"`` value of dict items
    are concatenated; anything else is skipped.
    """
    if isinstance(content, str):
        return content
    pieces = []
    for part in content or []:
        if isinstance(part, str):
            pieces.append(part)
        elif isinstance(part, dict) and isinstance(part.get("text"), str):
            pieces.append(part["text"])
    return "".join(pieces)


def _initial_prompt(messages: list) -> list:
    """Return the messages that precede the first AI reply.

    For a conversation seeded with a system message and a question this is
    the first two messages; for a conversation seeded with a single merged
    human message it is just that message.
    """
    prompt = []
    for msg in messages:
        if isinstance(msg, AIMessage):
            break
        prompt.append(msg)
    return prompt


class MyMessagesState(MessagesState):
    """Graph state: the message history plus a tool-call counter."""

    # Number of completed tool rounds. MyAgent.ask() seeds it with 0 and the
    # graph nodes fall back to 0 when the key is absent.
    tool_call_cnt: int = 0


class MyAgent:
    """Question-answering agent that may call tools a bounded number of times."""

    def __init__(
        self,
        model_name: str = "gemini-2.0-flash",
        tools: list | None = None,
        max_tool_call: int | None = 5,
    ):
        """Build the agent graph.

        Args:
            model_name: Model identifier passed to ChatGoogleGenerativeAI.
            tools: Tools to bind to the model; ``my_tools`` is used when
                this is ``None`` or empty.
            max_tool_call: Maximum number of tool rounds before the agent
                falls back to answering without tools. ``None`` disables
                the limit.
        """
        self.graph = self._build_graph(model_name, tools, max_tool_call)

    def _build_graph(
        self,
        model_name: str,
        tools: list | None,
        max_tool_call: int | None,
    ):
        """Assemble and compile the state graph.

        Topology: START -> assistant_with_tools -> (tools ->
        increment_tool_call_cnt -> assistant_with_tools | END |
        assistant_without_tools -> END).

        Returns:
            The compiled graph.
        """
        self.model_name = model_name

        # Delays (seconds between LLM calls) to avoid exceeding rate limits
        if "gemini-2.5" in model_name:
            self.delay = 6
        elif "gemini-2.0" in model_name:
            self.delay = 4
        else:
            self.delay = 0
        self.start_time = None

        llm = ChatGoogleGenerativeAI(model=model_name)
        if not tools:
            tools = my_tools
        llm_with_tools = llm.bind_tools(tools)

        def _delay_wait():
            """Sleep so consecutive LLM calls start at least `delay` apart."""
            # A monotonic clock keeps the elapsed time immune to wall-clock
            # adjustments.
            if self.start_time is not None:
                elapsed = time.monotonic() - self.start_time
                if elapsed < self.delay:
                    time.sleep(self.delay - elapsed)
            self.start_time = time.monotonic()

        def assistant_with_tools(state: MyMessagesState) -> dict:
            """Call the tool-enabled model on the full history."""
            print("assistant_with_tools")
            _delay_wait()
            # Only the new message is returned; the counter is unchanged.
            return {"messages": [llm_with_tools.invoke(state["messages"])]}

        def assistant_without_tools(state: MyMessagesState) -> dict:
            """Call the plain model on the original prompt only."""
            print("assistant_without_tools")
            _delay_wait()
            # Replay only what preceded the first AI reply so the fallback
            # never sees a failed or tool-calling turn.
            prompt = _initial_prompt(state["messages"])
            return {"messages": [llm.invoke(prompt)]}

        def increment_tool_call_cnt(state: MyMessagesState) -> dict:
            """Count one completed tool round."""
            print("increment_tool_call_cnt")
            return {"tool_call_cnt": state.get("tool_call_cnt", 0) + 1}

        def check_tool_call_status(state: MyMessagesState):
            """Pick the node to run after assistant_with_tools.

            Returns "assistant_without_tools" when the model apologised,
            the preceding tool call errored, or the tool-call limit was
            reached; otherwise returns whatever ``tools_condition`` picks.
            """
            messages = state.get("messages") or []

            # If tool call error, fall back to LLM without tools instead of
            # return failure message
            if messages:
                last_msg = messages[-1]
                if isinstance(last_msg, AIMessage):
                    reply_text = _content_to_text(last_msg.content)
                    if "sorry" in reply_text.lower():
                        return "assistant_without_tools"
                if len(messages) >= 2:
                    tool_msg = messages[-2]
                    if (
                        isinstance(tool_msg, ToolMessage)
                        and tool_msg.status == "error"
                    ):
                        return "assistant_without_tools"

            target_node = tools_condition(state)
            # If tool calling count exceeds limit, fallback to LLM without
            # tools
            if (
                target_node == "tools"
                and max_tool_call is not None
                and state.get("tool_call_cnt", 0) >= max_tool_call
            ):
                return "assistant_without_tools"
            return target_node

        builder = StateGraph(MyMessagesState)

        # Add nodes
        builder.add_node("assistant_with_tools", assistant_with_tools)
        builder.add_node("assistant_without_tools", assistant_without_tools)
        builder.add_node("tools", ToolNode(tools))
        builder.add_node("increment_tool_call_cnt", increment_tool_call_cnt)

        # Add edges
        builder.add_edge(START, "assistant_with_tools")
        builder.add_conditional_edges(
            "assistant_with_tools", check_tool_call_status
        )
        builder.add_edge("tools", "increment_tool_call_cnt")
        builder.add_edge("increment_tool_call_cnt", "assistant_with_tools")
        builder.add_edge("assistant_without_tools", END)

        return builder.compile()

    def ask(self, question: str, final_answer_only: bool = True) -> str:
        """Run the graph on a question and return the model's reply.

        Args:
            question: The user's question.
            final_answer_only: When true, the reply is passed through
                ``format_output``; otherwise the stripped reply is
                returned as is.
        """
        if "gemma" in self.model_name:
            # The system prompt is merged into the human message for these
            # models instead of being sent as a separate system message.
            messages = [HumanMessage(f"{system_prompt} {question}")]
        else:
            messages = [sys_msg, HumanMessage(question)]

        result = self.graph.invoke({"messages": messages, "tool_call_cnt": 0})
        response = _content_to_text(result["messages"][-1].content).strip()
        return self.format_output(response) if final_answer_only else response

    def format_output(self, output: str) -> str:
        """Extract the final answer from a model reply.

        Returns the text after the last "FINAL ANSWER: " marker if present,
        else the text following the first "**" up to the next "**" (or to
        the end of the reply), else the reply unchanged.
        """
        # Find the text behind "FINAL ANSWER: "
        if "FINAL ANSWER: " in output:
            return output.split("FINAL ANSWER: ")[-1].strip()
        # Find the first text between ** and ** (bold markdown)
        if "**" in output:
            return output.split("**")[1].strip()
        return output

    def __call__(self, question: str) -> str:
        """Answer a question; shorthand for ``ask(question)``."""
        return self.ask(question)


if __name__ == "__main__":
    from main import main as main_fn

    main_fn()