Maximofn commited on
Commit
e4413df
·
1 Parent(s): e2f16cb

Refactor app.py with improved model initialization and endpoint structure

Browse files

- Rename HuggingFace client to 'model' for clarity
- Add docstring to call_model function
- Update generate endpoint route to '/api/generate'
- Remove commented-out test-token endpoint
- Enhance code readability and documentation

Files changed (1) hide show
  1. app.py +14 -18
app.py CHANGED
@@ -6,13 +6,22 @@ from langchain_core.messages import HumanMessage, AIMessage
6
  from langgraph.checkpoint.memory import MemorySaver
7
  from langgraph.graph import START, MessagesState, StateGraph
8
 
9
- # Initialize the HuggingFace client
10
- client = InferenceClient(
11
  model="HuggingFaceTB/SmolLM2-1.7B-Instruct",
12
  )
13
 
14
  # Define the function that calls the model
15
  def call_model(state: MessagesState):
 
 
 
 
 
 
 
 
 
16
  # Convert LangChain messages to HuggingFace format
17
  hf_messages = []
18
  for msg in state["messages"]:
@@ -22,7 +31,7 @@ def call_model(state: MessagesState):
22
  hf_messages.append({"role": "assistant", "content": msg.content})
23
 
24
  # Call the API
25
- response = client.chat_completion(
26
  messages=hf_messages,
27
  temperature=0.5,
28
  max_tokens=64,
@@ -52,13 +61,14 @@ class QueryRequest(BaseModel):
52
  # Create the FastAPI application
53
  app = FastAPI(title="LangChain FastAPI", description="API to generate text using LangChain and LangGraph")
54
 
 
55
  @app.get("/")
56
  async def api_home():
57
  """Welcome endpoint"""
58
  return {"detail": "Welcome to FastAPI, Langchain, Docker tutorial"}
59
 
60
  # Generate endpoint
61
- @app.post("/generate")
62
  async def generate(request: QueryRequest):
63
  """
64
  Endpoint to generate text using the language model
@@ -91,20 +101,6 @@ async def generate(request: QueryRequest):
91
  except Exception as e:
92
  raise HTTPException(status_code=500, detail=f"Error al generar texto: {str(e)}")
93
 
94
- # Add an endpoint to test the token directly
95
- # @app.get("/test-token")
96
- # async def test_token():
97
- # """Endpoint to test the authentication with HuggingFace"""
98
- # try:
99
- # # Make a simple request to verify that the token works
100
- # response = client.chat_completion(
101
- # messages=[{"role": "user", "content": "Hello"}],
102
- # max_tokens=10
103
- # )
104
- # return {"status": "success", "message": "Token is valid", "response": response.choices[0].message.content}
105
- # except Exception as e:
106
- # return {"status": "error", "message": str(e)}
107
-
108
  if __name__ == "__main__":
109
  import uvicorn
110
  uvicorn.run(app, host="0.0.0.0", port=8000)
 
6
  from langgraph.checkpoint.memory import MemorySaver
7
  from langgraph.graph import START, MessagesState, StateGraph
8
 
9
+ # Initialize the HuggingFace model
10
+ model = InferenceClient(
11
  model="HuggingFaceTB/SmolLM2-1.7B-Instruct",
12
  )
13
 
14
  # Define the function that calls the model
15
  def call_model(state: MessagesState):
16
+ """
17
+ Call the model with the given messages
18
+
19
+ Args:
20
+ state: MessagesState
21
+
22
+ Returns:
23
+ dict: A dictionary containing the generated text and the thread ID
24
+ """
25
  # Convert LangChain messages to HuggingFace format
26
  hf_messages = []
27
  for msg in state["messages"]:
 
31
  hf_messages.append({"role": "assistant", "content": msg.content})
32
 
33
  # Call the API
34
+ response = model.chat_completion(
35
  messages=hf_messages,
36
  temperature=0.5,
37
  max_tokens=64,
 
61
  # Create the FastAPI application
62
  app = FastAPI(title="LangChain FastAPI", description="API to generate text using LangChain and LangGraph")
63
 
64
+ # Welcome endpoint
65
  @app.get("/")
66
  async def api_home():
67
  """Welcome endpoint"""
68
  return {"detail": "Welcome to FastAPI, Langchain, Docker tutorial"}
69
 
70
  # Generate endpoint
71
+ @app.post("/api/generate")
72
  async def generate(request: QueryRequest):
73
  """
74
  Endpoint to generate text using the language model
 
101
  except Exception as e:
102
  raise HTTPException(status_code=500, detail=f"Error al generar texto: {str(e)}")
103
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
  if __name__ == "__main__":
105
  import uvicorn
106
  uvicorn.run(app, host="0.0.0.0", port=8000)