Maximofn committed on
Commit e55b892 · Parent: 543aa6c

Switch from local SmolLM2 model to Qwen2.5-72B-Instruct via HuggingFace Inference API


- Replace local model loading with HuggingFace InferenceClient
- Update call_model function to use new API-based chat completion method
- Simplify model initialization and remove local model loading code
- Switch from SmolLM2-135M-Instruct to Qwen2.5-72B-Instruct model
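
For reference, the API-based call pattern this commit adopts can be exercised on its own. The snippet below is an illustrative sketch, not part of the commit; it assumes a valid HUGGINGFACE_TOKEN is exported in the environment and a huggingface_hub version that accepts the api_key alias (older releases take token= instead).

import os
from huggingface_hub import InferenceClient

# Standalone check of the Inference API path (illustrative, not from the commit)
client = InferenceClient(
    model="Qwen/Qwen2.5-72B-Instruct",
    api_key=os.environ["HUGGINGFACE_TOKEN"],  # assumes the token is set
)

response = client.chat_completion(
    messages=[{"role": "user", "content": "Hello!"}],
    max_tokens=64,     # same generation settings the commit uses
    temperature=0.5,
    top_p=0.7,
)
print(response.choices[0].message.content)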

Files changed (1): app.py (+20 -48)
app.py CHANGED
@@ -1,6 +1,6 @@
 from fastapi import FastAPI, HTTPException
 from pydantic import BaseModel
-from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
+from huggingface_hub import InferenceClient

 from langchain_core.messages import HumanMessage, AIMessage
 from langgraph.checkpoint.memory import MemorySaver
@@ -10,47 +10,19 @@ import os
 from dotenv import load_dotenv
 load_dotenv()

-# Configure cache directory
-os.environ["TRANSFORMERS_CACHE"] = "/tmp/transformers_cache"
-os.environ["HF_HOME"] = "/tmp/hf_home"
-
-# HuggingFace token
 HUGGINGFACE_TOKEN = os.environ.get("HUGGINGFACE_TOKEN", os.getenv("HUGGINGFACE_TOKEN"))
-print(f"Token HuggingFace: {HUGGINGFACE_TOKEN}")
-
-# Model to use
-MODEL_NAME = "HuggingFaceTB/SmolLM2-135M-Instruct"
+print(HUGGINGFACE_TOKEN)

-# Initialize the model and tokenizer locally with authentication
-print(f"Loading model {MODEL_NAME} locally...")
-tokenizer = AutoTokenizer.from_pretrained(
-    MODEL_NAME,
-    token=HUGGINGFACE_TOKEN,  # Add token for authentication
-    cache_dir="/tmp/transformers_cache"  # Specify cache directory
-)
-model = AutoModelForCausalLM.from_pretrained(
-    MODEL_NAME,
-    device_map="auto",
-    token=HUGGINGFACE_TOKEN,  # Add token for authentication
-    cache_dir="/tmp/transformers_cache"  # Specify cache directory
-)
-
-# Create a pipeline to facilitate generation
-generator = pipeline(
-    "text-generation",
-    model=model,
-    tokenizer=tokenizer,
-    max_new_tokens=64,
-    do_sample=True,
-    temperature=0.5,
-    top_p=0.7,
-    pad_token_id=tokenizer.eos_token_id
+# Initialize the HuggingFace model
+model = InferenceClient(
+    model="Qwen/Qwen2.5-72B-Instruct",
+    api_key=os.getenv("HUGGINGFACE_TOKEN")
 )

 # Define the function that calls the model
 def call_model(state: MessagesState):
     """
-    Call the local model with the given messages
+    Call the model with the given messages

     Args:
         state: MessagesState
@@ -58,24 +30,24 @@ def call_model(state: MessagesState):
     Returns:
         dict: A dictionary containing the generated text and the thread ID
     """
-    # System prompt to guide the model's behavior
-    system_prompt = "You are a friendly Chatbot. Always reply in the language in which the user is writing to you."
-
-    # Convert LangChain messages to a format that the local model can understand
-    prompt = f"System: {system_prompt}\n\n"
+    # Convert LangChain messages to HuggingFace format
+    hf_messages = []
     for msg in state["messages"]:
         if isinstance(msg, HumanMessage):
-            prompt += f"User: {msg.content}\n"
+            hf_messages.append({"role": "user", "content": msg.content})
         elif isinstance(msg, AIMessage):
-            prompt += f"Assistant: {msg.content}\n"
-
-    prompt += "Assistant: "
+            hf_messages.append({"role": "assistant", "content": msg.content})

-    # Generate response with the local model
-    response = generator(prompt, return_full_text=False)[0]['generated_text']
+    # Call the API
+    response = model.chat_completion(
+        messages=hf_messages,
+        temperature=0.5,
+        max_tokens=64,
+        top_p=0.7
+    )

-    # Convert the response to the LangChain format
-    ai_message = AIMessage(content=response.strip())
+    # Convert the response to LangChain format
+    ai_message = AIMessage(content=response.choices[0].message.content)
     return {"messages": state["messages"] + [ai_message]}

 # Define the graph
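
The message-conversion loop is the core of the new call_model. As a hedged sketch of the same mapping in isolation (the to_hf_messages helper below is hypothetical, not in the commit):

from langchain_core.messages import HumanMessage, AIMessage

def to_hf_messages(messages):
    # Hypothetical helper (not in the commit): replicates the loop in
    # call_model that maps LangChain messages to chat_completion dicts.
    hf_messages = []
    for msg in messages:
        if isinstance(msg, HumanMessage):
            hf_messages.append({"role": "user", "content": msg.content})
        elif isinstance(msg, AIMessage):
            hf_messages.append({"role": "assistant", "content": msg.content})
    return hf_messages

# A two-turn history becomes two role/content dicts:
history = [HumanMessage(content="Hi"), AIMessage(content="Hello!")]
assert to_hf_messages(history) == [
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "Hello!"},
]

The design change: instead of concatenating the history into a single prompt string for a local text-generation pipeline, the new code sends structured role/content messages, which is the format chat_completion expects.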