File size: 3,332 Bytes
f00a406
78c6106
 
f00a406
78c6106
 
 
f00a406
78c6106
 
 
956e282
af05832
78c6106
34aabf7
e55b892
78c6106
 
 
 
f00a406
159fcfe
78c6106
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
159fcfe
78c6106
 
 
 
 
 
 
159fcfe
78c6106
 
 
b0424cf
 
78c6106
f00a406
b0424cf
78c6106
 
f00a406
b0424cf
78c6106
 
f00a406
b0424cf
78c6106
 
 
f00a406
159fcfe
 
f00a406
e4413df
f00a406
b0424cf
159fcfe
f00a406
 
e2f16cb
78c6106
 
 
 
e2f16cb
78c6106
 
 
 
 
 
 
 
 
 
 
f00a406
78c6106
 
f00a406
78c6106
 
f00a406
78c6106
 
f00a406
78c6106
 
 
 
 
 
f00a406
f0a3a20
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from huggingface_hub import InferenceClient

from langchain_core.messages import HumanMessage, AIMessage
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import START, MessagesState, StateGraph

import os
from dotenv import load_dotenv
load_dotenv()

# HuggingFace API token, read from the environment (load_dotenv() above also
# pulls it from a local .env file if present). None when the variable is unset.
HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")

# Initialize the HuggingFace inference client for the chat model.
# Reuse the token read above instead of querying the environment again.
model = InferenceClient(
    model="Qwen/Qwen2.5-72B-Instruct",
    api_key=HUGGINGFACE_TOKEN,
)

# Define the function that calls the model
def call_model(state: MessagesState):
    """
    Send the conversation history to the HuggingFace chat model and return its reply.

    Args:
        state: MessagesState holding the conversation so far under "messages".

    Returns:
        dict: {"messages": [AIMessage]} containing only the new model reply.
        The MessagesState reducer appends it to the stored history, so the
        existing messages do not need to be returned again.
    """
    # Convert LangChain messages to the HuggingFace chat format.
    # Only human and AI turns are forwarded; any other message type is skipped.
    hf_messages = []
    for msg in state["messages"]:
        if isinstance(msg, HumanMessage):
            hf_messages.append({"role": "user", "content": msg.content})
        elif isinstance(msg, AIMessage):
            hf_messages.append({"role": "assistant", "content": msg.content})

    # Call the API (synchronous HTTP request)
    response = model.chat_completion(
        messages=hf_messages,
        temperature=0.5,
        max_tokens=64,
        top_p=0.7,
    )

    # huggingface_hub types message.content as Optional[str]; AIMessage rejects
    # None, so fall back to an empty string.
    content = response.choices[0].message.content or ""

    # Return only the delta; the add_messages reducer of MessagesState merges it.
    return {"messages": [AIMessage(content=content)]}

# Build a single-node graph whose state is the running message list.
workflow = StateGraph(state_schema=MessagesState)

# Register the only node first, then route every run from START into it.
workflow.add_node("model", call_model)
workflow.add_edge(START, "model")

# Compile with an in-process memory checkpointer; state is stored per
# "thread_id" supplied in the invocation config.
memory = MemorySaver()
graph_app = workflow.compile(checkpointer=memory)

# Define the data model for the request
# JSON body accepted by POST /generate; validated by pydantic.
class QueryRequest(BaseModel):
    # The user message sent to the model as a new human turn.
    query: str
    # Conversation key passed to the graph's checkpointer: requests sharing a
    # thread_id share history. Clients that omit it all use "default".
    thread_id: str = "default"

# Create the FastAPI application.
# The title and description are shown in the generated OpenAPI docs.
app = FastAPI(
    title="LangChain FastAPI",
    description="API to generate text using LangChain and LangGraph",
)

# Welcome endpoint (GET /)
@app.get("/")
async def api_home():
    """Welcome endpoint"""
    # Static greeting, serialized by FastAPI as the JSON response body.
    greeting = "Welcome to FastAPI, Langchain, Docker tutorial"
    return dict(detail=greeting)

# Generate endpoint
@app.post("/generate")
async def generate(request: QueryRequest):
    """
    Endpoint to generate text using the language model.

    The query is added as a new user turn to the conversation identified by
    ``thread_id``; earlier turns of that thread are restored by the graph's
    checkpointer.

    Args:
        request: QueryRequest with ``query`` (the user message) and
            ``thread_id`` (conversation key, defaults to "default").

    Returns:
        dict: A dictionary containing the generated text and the thread ID.

    Raises:
        HTTPException: status 500 if running the graph or the model call fails.
    """
    # The checkpointer keys stored conversation state by this thread ID.
    config = {"configurable": {"thread_id": request.thread_id}}

    # Only the new user turn is passed in; prior turns come from the checkpointer.
    input_messages = [HumanMessage(content=request.query)]

    try:
        # Use ainvoke rather than invoke: the model node makes a blocking HTTP
        # request, and calling the sync API here would stall the event loop for
        # every other in-flight request.
        output = await graph_app.ainvoke({"messages": input_messages}, config)

        # The last message in the returned state is the model's reply.
        response = output["messages"][-1].content
    except Exception as e:
        # Chain the original error so the traceback is preserved in server logs.
        raise HTTPException(status_code=500, detail=f"Error al generar texto: {str(e)}") from e

    return {
        "generated_text": response,
        "thread_id": request.thread_id,
    }

if __name__ == "__main__":
    # Serve the API directly with uvicorn when this file is run as a script.
    import uvicorn

    uvicorn.run(
        app,
        host="0.0.0.0",
        port=7860,
    )