# SmolLM2_backend / app.py
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from huggingface_hub import InferenceClient
from langchain_core.messages import HumanMessage, AIMessage
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import START, MessagesState, StateGraph
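# Presumed dependencies for this file (package names on PyPI, not pinned here):
#   fastapi, uvicorn, pydantic, huggingface_hub, langchain-core, langgraph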
# Initialize the HuggingFace Inference API client
client = InferenceClient(
    model="HuggingFaceTB/SmolLM2-1.7B-Instruct",
)
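# Note: no token is passed explicitly above. InferenceClient falls back to the
# HF_TOKEN environment variable or a cached `huggingface-cli login` token, so
# the Space (or local environment) is assumed to provide credentials that way.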
# Define the function that calls the model
def call_model(state: MessagesState):
    # Convert LangChain messages to the HuggingFace chat format
    hf_messages = []
    for msg in state["messages"]:
        if isinstance(msg, HumanMessage):
            hf_messages.append({"role": "user", "content": msg.content})
        elif isinstance(msg, AIMessage):
            hf_messages.append({"role": "assistant", "content": msg.content})
    # Call the Inference API
    response = client.chat_completion(
        messages=hf_messages,
        temperature=0.5,
        max_tokens=64,  # responses are truncated at 64 tokens; raise for longer outputs
        top_p=0.7
    )
    # Convert the response back to a LangChain message
    ai_message = AIMessage(content=response.choices[0].message.content)
    # MessagesState appends returned messages via its add_messages reducer,
    # so only the new message needs to be returned here
    return {"messages": [ai_message]}
# Define the graph
workflow = StateGraph(state_schema=MessagesState)
# Define the node in the graph and wire it to the entry point
workflow.add_node("model", call_model)
workflow.add_edge(START, "model")
# Add memory
memory = MemorySaver()
graph_app = workflow.compile(checkpointer=memory)
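# Note: MemorySaver keeps checkpoints in process memory only, so per-thread
# conversation history is lost whenever the server restarts.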
# Define the data model for the request
class QueryRequest(BaseModel):
    query: str
    thread_id: str = "default"
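# Example request body (values are illustrative):
#   {"query": "What is the capital of France?", "thread_id": "user-123"}
# Reusing the same thread_id continues that thread's conversation history.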
# Create the FastAPI application
app = FastAPI(title="LangChain FastAPI", description="API to generate text using LangChain and LangGraph")
@app.get("/")
async def api_home():
    """Welcome endpoint"""
    return {"detail": "Welcome to FastAPI, Langchain, Docker tutorial"}
# Generate endpoint
@app.post("/generate")
async def generate(request: QueryRequest):
    """
    Endpoint to generate text using the language model

    Args:
        request: QueryRequest
            query: str
            thread_id: str = "default"

    Returns:
        dict: A dictionary containing the generated text and the thread ID
    """
    try:
        # Configure the thread ID
        config = {"configurable": {"thread_id": request.thread_id}}
        # Create the input message
        input_messages = [HumanMessage(content=request.query)]
        # Invoke the graph
        output = graph_app.invoke({"messages": input_messages}, config)
        # Get the model response
        response = output["messages"][-1].content
        return {
            "generated_text": response,
            "thread_id": request.thread_id
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error generating text: {str(e)}")
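# Example usage (assuming the server is running locally on port 8000):
#   curl -X POST http://localhost:8000/generate \
#        -H "Content-Type: application/json" \
#        -d '{"query": "Hello, who are you?", "thread_id": "demo"}'
# Expected shape of the response:
#   {"generated_text": "...", "thread_id": "demo"}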
# Add an endpoint to test the token directly
# @app.get("/test-token")
# async def test_token():
#     """Endpoint to test the authentication with HuggingFace"""
#     try:
#         # Make a simple request to verify that the token works
#         response = client.chat_completion(
#             messages=[{"role": "user", "content": "Hello"}],
#             max_tokens=10
#         )
#         return {"status": "success", "message": "Token is valid", "response": response.choices[0].message.content}
#     except Exception as e:
#         return {"status": "error", "message": str(e)}
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)