Maximofn committed
Commit e2f16cb · 1 Parent(s): 3896f1b

Re-enable HuggingFace model integration and restore generate endpoint


- Uncomment HuggingFace client initialization
- Restore model calling and graph configuration
- Re-implement generate endpoint with improved docstring
- Add detailed error handling for text generation

Files changed (1)
  1. app.py +58 -47
app.py CHANGED
@@ -7,47 +7,47 @@ from langgraph.checkpoint.memory import MemorySaver
 from langgraph.graph import START, MessagesState, StateGraph
 
 # Initialize the HuggingFace client
-# client = InferenceClient(
-#     model="HuggingFaceTB/SmolLM2-1.7B-Instruct",
-# )
+client = InferenceClient(
+    model="HuggingFaceTB/SmolLM2-1.7B-Instruct",
+)
 
 # Define the function that calls the model
-# def call_model(state: MessagesState):
-#     # Convert LangChain messages to HuggingFace format
-#     hf_messages = []
-#     for msg in state["messages"]:
-#         if isinstance(msg, HumanMessage):
-#             hf_messages.append({"role": "user", "content": msg.content})
-#         elif isinstance(msg, AIMessage):
-#             hf_messages.append({"role": "assistant", "content": msg.content})
+def call_model(state: MessagesState):
+    # Convert LangChain messages to HuggingFace format
+    hf_messages = []
+    for msg in state["messages"]:
+        if isinstance(msg, HumanMessage):
+            hf_messages.append({"role": "user", "content": msg.content})
+        elif isinstance(msg, AIMessage):
+            hf_messages.append({"role": "assistant", "content": msg.content})
 
-#     # Call the API
-#     response = client.chat_completion(
-#         messages=hf_messages,
-#         temperature=0.5,
-#         max_tokens=64,
-#         top_p=0.7
-#     )
+    # Call the API
+    response = client.chat_completion(
+        messages=hf_messages,
+        temperature=0.5,
+        max_tokens=64,
+        top_p=0.7
+    )
 
-#     # Convert the response to LangChain format
-#     ai_message = AIMessage(content=response.choices[0].message.content)
-#     return {"messages": state["messages"] + [ai_message]}
+    # Convert the response to LangChain format
+    ai_message = AIMessage(content=response.choices[0].message.content)
+    return {"messages": state["messages"] + [ai_message]}
 
 # Define the graph
-# workflow = StateGraph(state_schema=MessagesState)
+workflow = StateGraph(state_schema=MessagesState)
 
 # Define the node in the graph
-# workflow.add_edge(START, "model")
-# workflow.add_node("model", call_model)
+workflow.add_edge(START, "model")
+workflow.add_node("model", call_model)
 
 # Add memory
-# memory = MemorySaver()
-# graph_app = workflow.compile(checkpointer=memory)
+memory = MemorySaver()
+graph_app = workflow.compile(checkpointer=memory)
 
 # Define the data model for the request
-# class QueryRequest(BaseModel):
-#     query: str
-#     thread_id: str = "default"
+class QueryRequest(BaseModel):
+    query: str
+    thread_id: str = "default"
 
 # Create the FastAPI application
 app = FastAPI(title="LangChain FastAPI", description="API to generate text using LangChain and LangGraph")
@@ -57,28 +57,39 @@ async def api_home():
     """Welcome endpoint"""
     return {"detail": "Welcome to FastAPI, Langchain, Docker tutorial"}
 
-# @app.post("/generate")
-# async def generate(request: QueryRequest):
-#     """Endpoint to generate text using the language model"""
-#     try:
-#         # Configure the thread ID
-#         config = {"configurable": {"thread_id": request.thread_id}}
+# Generate endpoint
+@app.post("/generate")
+async def generate(request: QueryRequest):
+    """
+    Endpoint to generate text using the language model
+
+    Args:
+        request: QueryRequest
+            query: str
+            thread_id: str = "default"
+
+    Returns:
+        dict: A dictionary containing the generated text and the thread ID
+    """
+    try:
+        # Configure the thread ID
+        config = {"configurable": {"thread_id": request.thread_id}}
 
-#         # Create the input message
-#         input_messages = [HumanMessage(content=request.query)]
+        # Create the input message
+        input_messages = [HumanMessage(content=request.query)]
 
-#         # Invoke the graph
-#         output = graph_app.invoke({"messages": input_messages}, config)
+        # Invoke the graph
+        output = graph_app.invoke({"messages": input_messages}, config)
 
-#         # Get the model response
-#         response = output["messages"][-1].content
+        # Get the model response
+        response = output["messages"][-1].content
 
-#         return {
-#             "generated_text": response,
-#             "thread_id": request.thread_id
-#         }
-#     except Exception as e:
-#         raise HTTPException(status_code=500, detail=f"Error generating text: {str(e)}")
+        return {
+            "generated_text": response,
+            "thread_id": request.thread_id
+        }
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"Error generating text: {str(e)}")
 
 # Add an endpoint to test the token directly
 # @app.get("/test-token")
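
For context beyond the diff: the restored endpoint can be smoke-tested end to end once the app is running. The sketch below is illustrative, not part of the commit; it assumes the app is served locally (e.g. uvicorn app:app --port 8000) and that the requests package is installed. The base URL, port, and thread_id value are all hypothetical.

import requests  # hypothetical test client, not part of app.py

BASE_URL = "http://localhost:8000"  # assumed local address

# First turn: starts a new conversation thread
first = requests.post(
    f"{BASE_URL}/generate",
    json={"query": "Hi, my name is Ada.", "thread_id": "demo"},
)
first.raise_for_status()
print(first.json()["generated_text"])

# Second turn on the same thread_id: MemorySaver restores the thread's
# checkpointed messages, so the model can refer back to the first turn
second = requests.post(
    f"{BASE_URL}/generate",
    json={"query": "What is my name?", "thread_id": "demo"},
)
second.raise_for_status()
print(second.json()["generated_text"])

Because graph_app is compiled with checkpointer=memory, each thread_id keys its own checkpointed message history: reusing a thread_id continues that conversation, while a new value starts a fresh one.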