Maximofn committed on
Commit 34aabf7 · 1 Parent(s): 167addd

Switch from HuggingFace InferenceClient to local model loading


- Replace InferenceClient with local model loading using transformers library
- Use AutoModelForCausalLM and AutoTokenizer for direct model initialization
- Create a text generation pipeline with custom generation parameters (see the smoke-test sketch after this list)
- Modify call_model function to work with local model generation
- Improve token and model loading logging
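For a quick sanity check of the new generation path, the snippet below is a minimal standalone sketch, not part of the commit: it loads the same model and pipeline parameters as the diff, builds the same "User:/Assistant:" prompt shape that call_model now uses, and prints one reply. The example question is made up for illustration, and the token arguments are omitted here since they are only needed for gated or private repositories.

from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

MODEL_NAME = "HuggingFaceTB/SmolLM2-1.7B-Instruct"

# Load the tokenizer and model locally, as the commit does in app.py
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, device_map="auto")

# Same generation parameters as the new pipeline in the diff below
generator = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=64,
    do_sample=True,
    temperature=0.5,
    top_p=0.7,
    pad_token_id=tokenizer.eos_token_id,
)

# call_model builds a plain "User:/Assistant:" prompt and keeps only the new text
prompt = "User: What is the capital of France?\nAssistant: "
reply = generator(prompt, return_full_text=False)[0]["generated_text"]
print(reply.strip())

Note that the commit formats the conversation as a plain "User:/Assistant:" string rather than applying the tokenizer's chat template (tokenizer.apply_chat_template), which is an alternative worth considering for an instruct-tuned checkpoint.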

Files changed (1)
  1. app.py +39 -20
app.py CHANGED
@@ -1,6 +1,6 @@
 from fastapi import FastAPI, HTTPException
 from pydantic import BaseModel
-from huggingface_hub import InferenceClient
+from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
 
 from langchain_core.messages import HumanMessage, AIMessage
 from langgraph.checkpoint.memory import MemorySaver
@@ -10,19 +10,41 @@ import os
 from dotenv import load_dotenv
 load_dotenv()
 
+# HuggingFace token
 HUGGINGFACE_TOKEN = os.environ.get("HUGGINGFACE_TOKEN", os.getenv("HUGGINGFACE_TOKEN"))
-print(HUGGINGFACE_TOKEN)
+print(f"Token HuggingFace: {HUGGINGFACE_TOKEN}")
 
-# Initialize the HuggingFace model
-model = InferenceClient(
-    model="HuggingFaceTB/SmolLM2-1.7B-Instruct",
-    api_key=os.getenv("HUGGINGFACE_TOKEN")
+# Model to use
+MODEL_NAME = "HuggingFaceTB/SmolLM2-1.7B-Instruct"
+
+# Initialize the model and tokenizer locally with authentication
+print(f"Loading model {MODEL_NAME} locally...")
+tokenizer = AutoTokenizer.from_pretrained(
+    MODEL_NAME,
+    token=HUGGINGFACE_TOKEN  # Add token for authentication
+)
+model = AutoModelForCausalLM.from_pretrained(
+    MODEL_NAME,
+    device_map="auto",
+    token=HUGGINGFACE_TOKEN  # Add token for authentication
+)
+
+# Create a pipeline to facilitate generation
+generator = pipeline(
+    "text-generation",
+    model=model,
+    tokenizer=tokenizer,
+    max_new_tokens=64,
+    do_sample=True,
+    temperature=0.5,
+    top_p=0.7,
+    pad_token_id=tokenizer.eos_token_id
 )
 
 # Define the function that calls the model
 def call_model(state: MessagesState):
     """
-    Call the model with the given messages
+    Call the local model with the given messages
 
     Args:
         state: MessagesState
@@ -30,24 +52,21 @@ def call_model(state: MessagesState):
     Returns:
         dict: A dictionary containing the generated text and the thread ID
     """
-    # Convert LangChain messages to HuggingFace format
-    hf_messages = []
+    # Convert LangChain messages to a format that the local model can understand
+    prompt = ""
     for msg in state["messages"]:
         if isinstance(msg, HumanMessage):
-            hf_messages.append({"role": "user", "content": msg.content})
+            prompt += f"User: {msg.content}\n"
         elif isinstance(msg, AIMessage):
-            hf_messages.append({"role": "assistant", "content": msg.content})
+            prompt += f"Assistant: {msg.content}\n"
+
+    prompt += "Assistant: "
 
-    # Call the API
-    response = model.chat_completion(
-        messages=hf_messages,
-        temperature=0.5,
-        max_tokens=64,
-        top_p=0.7
-    )
+    # Generate response with the local model
+    response = generator(prompt, return_full_text=False)[0]['generated_text']
 
-    # Convert the response to LangChain format
-    ai_message = AIMessage(content=response.choices[0].message.content)
+    # Convert the response to the LangChain format
+    ai_message = AIMessage(content=response.strip())
     return {"messages": state["messages"] + [ai_message]}
 
 # Define the graph
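The hunk stops just before the graph definition, which this commit leaves untouched and which is not shown in the diff. For context, below is a minimal sketch of how call_model is typically wired into a LangGraph StateGraph with the MemorySaver checkpointer imported above; the node name, thread ID, and the echo stand-in for call_model are illustrative assumptions, not code from this repository.

from langchain_core.messages import AIMessage, HumanMessage
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import START, MessagesState, StateGraph

def call_model(state: MessagesState):
    # Stand-in for the call_model defined in app.py above; just echoes the last message
    last = state["messages"][-1].content
    return {"messages": state["messages"] + [AIMessage(content=f"echo: {last}")]}

# Build the graph: a single node that runs the model on the accumulated messages
workflow = StateGraph(state_schema=MessagesState)
workflow.add_node("model", call_model)
workflow.add_edge(START, "model")

# Compile with an in-memory checkpointer so each thread keeps its own history
memory = MemorySaver()
graph = workflow.compile(checkpointer=memory)

config = {"configurable": {"thread_id": "demo-thread"}}
result = graph.invoke({"messages": [HumanMessage(content="Hello!")]}, config)
print(result["messages"][-1].content)

The thread_id in the config is what lets MemorySaver keep a separate conversation history per chat thread, which matches the "thread ID" mentioned in call_model's docstring.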