Add full tools support to the chat template

#45
by Rocketknight1 (HF Staff) - opened
Files changed (2)
  1. README.md +50 -37
  2. tokenizer_config.json +1 -1
README.md CHANGED
@@ -43,18 +43,32 @@ result = tokenizer.decode(out_tokens[0])
  print(result)
  ```

+ ## Preparing inputs with Hugging Face `transformers`
+
+ ```py
+ from transformers import AutoTokenizer
+
+ tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x22B-Instruct-v0.1")
+
+ chat = [{"role": "user", "content": "Explain Machine Learning to me in a nutshell."}]
+
+ tokens = tokenizer.apply_chat_template(chat, return_dict=True, return_tensors="pt", add_generation_prompt=True)
+ ```
+
  ## Inference with hugging face `transformers`

  ```py
  from transformers import AutoModelForCausalLM
-
- model = AutoModelForCausalLM.from_pretrained("mistralai/Mixtral-8x22B-Instruct-v0.1")
+ import torch
+
+ # You can also use 8-bit or 4-bit quantization here
+ model = AutoModelForCausalLM.from_pretrained("mistralai/Mixtral-8x22B-Instruct-v0.1", torch_dtype=torch.bfloat16, device_map="auto")
  model.to("cuda")

- generated_ids = model.generate(tokens, max_new_tokens=1000, do_sample=True)
+ generated_ids = model.generate(**tokens, max_new_tokens=1000, do_sample=True)

- # decode with mistral tokenizer
- result = tokenizer.decode(generated_ids[0].tolist())
+ # decode with HF tokenizer
+ result = tokenizer.decode(generated_ids[0])
  print(result)
  ```

@@ -64,7 +78,7 @@ print(result)
  ---
  The Mixtral-8x22B-Instruct-v0.1 Large Language Model (LLM) is an instruct fine-tuned version of the [Mixtral-8x22B-v0.1](https://huggingface.co/mistralai/Mixtral-8x22B-v0.1).

- ## Run the model
+ ## Function calling example
  ```python
  from transformers import AutoModelForCausalLM
  from mistral_common.protocol.instruct.messages import (
@@ -122,56 +136,55 @@ sp_tokenizer = tokenizer_v3.instruct_tokenizer.tokenizer
  decoded = sp_tokenizer.decode(generated_ids[0])
  print(decoded)
  ```
- Alternatively, you can run this example with the Hugging Face tokenizer.
- To use this example, you'll need transformers version 4.39.0 or higher.
- ```console
- pip install transformers==4.39.0
- ```
+
+ ## Function calling with `transformers`
+
+ To use this example, you'll need `transformers` version 4.42.0 or higher. Please see the
+ [function calling guide](https://huggingface.co/docs/transformers/main/chat_templating#advanced-tool-use--function-calling)
+ in the `transformers` docs for more information.
+
  ```python
  from transformers import AutoModelForCausalLM, AutoTokenizer
+ import torch

  model_id = "mistralai/Mixtral-8x22B-Instruct-v0.1"
  tokenizer = AutoTokenizer.from_pretrained(model_id)
- conversation=[
-     {"role": "user", "content": "What's the weather like in Paris?"},
-     {
-         "role": "tool_calls",
-         "content": [
-             {
-                 "name": "get_current_weather",
-                 "arguments": {"location": "Paris, France", "format": "celsius"},
-             }
-         ]
-     },
-     {
-         "role": "tool_results",
-         "content": {"content": 22}
-     },
-     {"role": "assistant", "content": "The current temperature in Paris, France is 22 degrees Celsius."},
-     {"role": "user", "content": "What about San Francisco?"}
- ]
-
- tools = [{"type": "function", "function": {"name":"get_current_weather", "description": "Get the current weather", "parameters": {"type": "object", "properties": {"location": {"type": "string", "description": "The city and state, e.g. San Francisco, CA"}, "format": {"type": "string", "enum": ["celsius", "fahrenheit"], "description": "The temperature unit to use. Infer this from the users location."}},"required":["location","format"]}}}]
+
+ def get_current_weather(location: str, format: str):
+     """
+     Get the current weather
+
+     Args:
+         location: The city and state, e.g. San Francisco, CA
+         format: The temperature unit to use. Infer this from the users location. (choices: ["celsius", "fahrenheit"])
+     """
+     pass
+
+ conversation = [{"role": "user", "content": "What's the weather like in Paris?"}]
+ tools = [get_current_weather]

  # render the tool use prompt as a string:
  tool_use_prompt = tokenizer.apply_chat_template(
      conversation,
-     chat_template="tool_use",
      tools=tools,
      tokenize=False,
      add_generation_prompt=True,
-
  )
- model = AutoModelForCausalLM.from_pretrained("mistralai/Mixtral-8x22B-Instruct-v0.1")

  inputs = tokenizer(tool_use_prompt, return_tensors="pt")

- outputs = model.generate(**inputs, max_new_tokens=20)
+ model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto")
+
+ outputs = model.generate(**inputs, max_new_tokens=1000)
  print(tokenizer.decode(outputs[0], skip_special_tokens=True))
  ```

+ Note that, for reasons of space, this example does not show a complete cycle of calling a tool and adding the tool call and tool
+ results to the chat history so that the model can use them in its next generation. For a full tool calling example, please
+ see the [function calling guide](https://huggingface.co/docs/transformers/main/chat_templating#advanced-tool-use--function-calling),
+ and note that Mixtral **does** use tool call IDs, so these must be included in your tool calls and tool results. They should be
+ exactly 9 alphanumeric characters.
+
  # Instruct tokenizer
  The HuggingFace tokenizer included in this release should match our own. To compare:
  `pip install mistral-common`
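
As a reviewer aid, here is a minimal sketch of the full tool-calling cycle that the new README note describes but omits for space. It continues from the `tool_use_prompt` example in this diff and reuses its `tokenizer`, `conversation`, and `tools`; the random ID helper and the hard-coded tool output are illustrative assumptions, not part of this PR. The only hard requirement the new template enforces is that every tool call ID is exactly 9 alphanumeric characters.

```python
import random
import string

# Hypothetical ID helper: the new chat template only checks that the ID is
# exactly 9 alphanumeric characters.
tool_call_id = "".join(random.choices(string.ascii_letters + string.digits, k=9))

# Record the assistant's tool call in the chat history.
conversation.append(
    {
        "role": "assistant",
        "tool_calls": [
            {
                "id": tool_call_id,
                "type": "function",
                "function": {
                    "name": "get_current_weather",
                    "arguments": {"location": "Paris, France", "format": "celsius"},
                },
            }
        ],
    }
)

# Record the tool's result, referencing the same 9-character ID.
conversation.append(
    {
        "role": "tool",
        "name": "get_current_weather",
        "content": "22",  # illustrative value "returned" by the tool
        "tool_call_id": tool_call_id,
    }
)

# Re-render the prompt so the model can use the tool result in its next generation.
next_prompt = tokenizer.apply_chat_template(
    conversation,
    tools=tools,
    tokenize=False,
    add_generation_prompt=True,
)
```
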
tokenizer_config.json CHANGED
@@ -6173,7 +6173,7 @@
  }
  },
  "bos_token": "<s>",
- "chat_template": "{%- if messages[0]['role'] == 'system' %}\n {%- set system_message = messages[0]['content'] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}\n {{- raise_exception('After the optional system message, conversation roles must alternate user/assistant/user/assistant/...') }}\n {%- endif %}\n {%- if message['role'] == 'user' %}\n {%- if loop.last and system_message is defined %}\n {{- '[INST] ' + system_message + '\\n\\n' + message['content'] + '[/INST]' }}\n {%- else %}\n {{- '[INST] ' + message['content'] + '[/INST]' }}\n {%- endif %}\n {%- elif message['role'] == 'assistant' %}\n {{- ' ' + message['content'] + eos_token}}\n {%- else %}\n {{- raise_exception('Only user and assistant roles are supported, with the exception of an initial optional system message!') }}\n {%- endif %}\n{%- endfor %}\n",
+ "chat_template": "{%- if messages[0][\"role\"] == \"system\" %}\n {%- set system_message = messages[0][\"content\"] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n{%- set user_messages = loop_messages | selectattr(\"role\", \"equalto\", \"user\") | list %}\n\n{%- for message in loop_messages | rejectattr(\"role\", \"equalto\", \"tool\") | rejectattr(\"role\", \"equalto\", \"tool_results\") | selectattr(\"tool_calls\", \"undefined\") %}\n {%- if (message[\"role\"] == \"user\") != (loop.index0 % 2 == 0) %}\n {{- raise_exception(\"After the optional system message, conversation roles must alternate user/assistant/user/assistant/...\") }}\n {%- endif %}\n{%- endfor %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if message[\"role\"] == \"user\" %}\n {%- if tools is not none and (message == user_messages[-1]) %}\n {{- \"[AVAILABLE_TOOLS] [\" }}\n {%- for tool in tools %}\n {%- set tool = tool.function %}\n {{- '{\"type\": \"function\", \"function\": {' }}\n {%- for key, val in tool.items() if key != \"return\" %}\n {%- if val is string %}\n {{- '\"' + key + '\": \"' + val + '\"' }}\n {%- else %}\n {{- '\"' + key + '\": ' + val|tojson }}\n {%- endif %}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- endif %}\n {%- endfor %}\n {{- \"}}\" }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" }}\n {%- endif %}\n {%- endfor %}\n {{- \"[/AVAILABLE_TOOLS]\" }}\n {%- endif %}\n {%- if loop.last and system_message is defined %}\n {{- \"[INST] \" + system_message + \"\\n\\n\" + message[\"content\"] + \"[/INST]\" }}\n {%- else %}\n {{- \"[INST] \" + message[\"content\"] + \"[/INST]\" }}\n {%- endif %}\n {%- elif message[\"role\"] == \"tool_calls\" or message.tool_calls is defined %}\n {%- if message.tool_calls is defined %}\n {%- set tool_calls = message.tool_calls %}\n {%- else %}\n {%- set tool_calls = message.content %}\n {%- endif %}\n {{- \"[TOOL_CALLS] [\" }}\n {%- for tool_call in tool_calls %}\n {%- set out = tool_call.function|tojson %}\n {{- out[:-1] }}\n {%- if not tool_call.id is defined or tool_call.id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- ', \"id\": \"' + tool_call.id + '\"}' }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" + eos_token }}\n {%- endif %}\n {%- endfor %}\n {%- elif message[\"role\"] == \"assistant\" %}\n {{- \" \" + message[\"content\"] + eos_token}}\n {%- elif message[\"role\"] == \"tool_results\" or message[\"role\"] == \"tool\" %}\n {%- if message.content is defined and message.content.content is defined %}\n {%- set content = message.content.content %}\n {%- else %}\n {%- set content = message.content %}\n {%- endif %}\n {{- '[TOOL_RESULTS] {\"content\": ' + content|string + \", \" }}\n {%- if not message.tool_call_id is defined or message.tool_call_id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- '\"call_id\": \"' + message.tool_call_id + '\"}[/TOOL_RESULTS]' }}\n {%- else %}\n {{- raise_exception(\"Only user and assistant roles are supported, with the exception of an initial optional system message!\") }}\n {%- endif %}\n{%- endfor %}\n",
  "clean_up_tokenization_spaces": false,
  "eos_token": "</s>",
  "legacy": false,