Spaces:
Running
Running
| import os, time | |
| from langchain_core.messages import SystemMessage, HumanMessage, ToolMessage, AIMessage | |
| from langchain_google_genai import ChatGoogleGenerativeAI | |
| from langgraph.graph import StateGraph, MessagesState, START, END | |
| from langgraph.prebuilt import ToolNode, tools_condition | |
| from tools import my_tools | |
# Load the shared system prompt from the working directory. A missing file
# silently yields an empty prompt; any other read/decode failure is reported
# and also falls back to an empty prompt so the agent can still start.
try:
    # Explicit UTF-8: the platform default encoding (e.g. cp1252) may not be
    # able to decode a UTF-8 prompt file.
    with open("system_prompt.txt", "r", encoding="utf-8") as f:
        system_prompt = f.read()
except FileNotFoundError:
    system_prompt = ""
except (OSError, UnicodeDecodeError) as e:
    print(f"Error reading system prompt file: {e}")
    system_prompt = ""
# System message built once from the loaded prompt text.
sys_msg = SystemMessage(system_prompt)
class MyMessagesState(MessagesState):
    """Agent graph state: the inherited ``messages`` history plus a tool-call counter."""

    # Number of completed tool rounds. It is incremented by the graph's
    # ``increment_tool_call_cnt`` node and compared against ``max_tool_call``
    # when routing.
    # NOTE(review): if MessagesState is a TypedDict (as in LangGraph), this
    # ``= 0`` default is not applied to state dicts at runtime. Callers should
    # pass ``tool_call_cnt`` explicitly (MyAgent.ask does) — confirm before
    # relying on the default.
    tool_call_cnt: int=0
| class MyAgent: | |
| def __init__( | |
| self, | |
| model_name: str="gemini-2.0-flash", | |
| tools: list | None=None, | |
| max_tool_call: int | None=5 | |
| ): | |
| self.graph = self._build_graph(model_name, tools, max_tool_call) | |
| def _build_graph( | |
| self, | |
| model_name: str, | |
| tools: list | None, | |
| max_tool_call: int | None | |
| ): | |
| self.model_name = model_name | |
| # Delays to avoid exceeding rate limits | |
| if "gemini-2.5" in model_name: | |
| self.delay = 6 | |
| elif "gemini-2.0" in model_name: | |
| self.delay = 4 | |
| else: | |
| self.delay = 0 | |
| self.start_time = None | |
| llm = ChatGoogleGenerativeAI(model=model_name) | |
| if not tools: | |
| tools = my_tools | |
| llm_with_tools = llm.bind_tools(tools) | |
| def _delay_wait(): | |
| if self.start_time is not None: | |
| elapsed = time.time() - self.start_time | |
| if elapsed < self.delay: | |
| time.sleep(self.delay - elapsed) | |
| self.start_time = time.time() | |
| def assistant_with_tools(state: MyMessagesState) -> MyMessagesState: | |
| print("assistant_with_tools") | |
| _delay_wait() | |
| return { | |
| "messages": [llm_with_tools.invoke(state["messages"])], | |
| "tool_call_cnt": state["tool_call_cnt"] | |
| } | |
| def assistant_without_tools(state: MyMessagesState) -> MyMessagesState: | |
| print("assistant_without_tools") | |
| _delay_wait() | |
| return { | |
| "messages": [llm.invoke(state["messages"][:2])], | |
| "tool_call_cnt": state["tool_call_cnt"] | |
| } | |
| def increment_tool_call_cnt(state: MyMessagesState) -> MyMessagesState: | |
| print("increment_tool_call_cnt") | |
| return { | |
| "messages": state["messages"], | |
| "tool_call_cnt": state["tool_call_cnt"] + 1 | |
| } | |
| def check_tool_call_status(state: MyMessagesState): | |
| # Initialize tool calling counter | |
| if "tool_call_cnt" not in state: | |
| state["tool_call_cnt"] = 0 | |
| # If tool call error, fall back to LLM without tools instead of return failure message | |
| if "messages" in state: | |
| last_msg = state["messages"][-1] | |
| if isinstance(last_msg, AIMessage) and "sorry" in last_msg.content.lower(): | |
| return "assistant_without_tools" | |
| tool_msg = state["messages"][-2] | |
| if isinstance(tool_msg, ToolMessage) and tool_msg.status == "error": | |
| return "assistant_without_tools" | |
| target_node = tools_condition(state) | |
| if (target_node == "tools"): | |
| # If tool calling count exceeds limit, fallback to LLM without tools | |
| if (max_tool_call is not None) and state["tool_call_cnt"] >= max_tool_call: | |
| return "assistant_without_tools" | |
| return target_node | |
| builder = StateGraph(MyMessagesState) | |
| # Add nodes | |
| builder.add_node("assistant_with_tools", assistant_with_tools) | |
| builder.add_node("assistant_without_tools", assistant_without_tools) | |
| builder.add_node("tools", ToolNode(tools)) | |
| builder.add_node("increment_tool_call_cnt", increment_tool_call_cnt) | |
| # Add edges | |
| builder.add_edge(START, "assistant_with_tools") | |
| builder.add_conditional_edges("assistant_with_tools", check_tool_call_status) | |
| builder.add_edge("tools", "increment_tool_call_cnt") | |
| builder.add_edge("increment_tool_call_cnt", "assistant_with_tools") | |
| builder.add_edge("assistant_without_tools", END) | |
| return builder.compile() | |
| def ask(self, question: str, final_answer_only: bool=True) -> str: | |
| if "gemma" in self.model_name: | |
| messages = [ | |
| HumanMessage(f"{system_prompt} {question}") | |
| ] | |
| else: | |
| messages = [ | |
| sys_msg, | |
| HumanMessage(question) | |
| ] | |
| messages = self.graph.invoke({"messages": messages, "tool_call_cnt": 0}) | |
| response = messages["messages"][-1].content.strip() | |
| return self.format_output(response) if final_answer_only else response | |
| def format_output(self, output: str) -> str: | |
| # Find the text behind "FINAL ANSWER: " | |
| if "FINAL ANSWER: " in output: | |
| return output.split("FINAL ANSWER: ")[-1].strip() | |
| # Find the first text between ** and ** (bold markdown) | |
| elif "**" in output: | |
| return output.split("**")[1].strip() | |
| else: | |
| return output | |
| def __call__(self, question: str) -> str: | |
| return self.ask(question) | |
if __name__ == "__main__":
    # Running this module directly hands control to the project's entry point.
    from main import main as run_entry_point
    run_entry_point()