wangrongsheng committed on
Commit
cbeb6de
verified
1 Parent(s): 3c94a23

Upload folder using huggingface_hub

Files changed (43)
  1. LICENSE +34 -0
  2. Open Source Software Notice +218 -0
  3. README.md +127 -6
  4. README_EN.md +124 -0
  5. checklist.chk +36 -0
  6. config.json +31 -0
  7. configuration_openpangu_dense.py +56 -0
  8. generation_config.json +11 -0
  9. inference/generate.py +57 -0
  10. inference/vllm_ascend/_build_info.py +3 -0
  11. inference/vllm_ascend/attention/attention.py +1220 -0
  12. inference/vllm_ascend/attention/mla_v1.py +1224 -0
  13. inference/vllm_ascend/entrypoints/openai/reasoning_parsers/__init__.py +6 -0
  14. inference/vllm_ascend/entrypoints/openai/reasoning_parsers/pangu_reasoning_parser.py +171 -0
  15. inference/vllm_ascend/entrypoints/openai/tool_parsers/__init__.py +6 -0
  16. inference/vllm_ascend/entrypoints/openai/tool_parsers/pangu_tool_parser.py +300 -0
  17. inference/vllm_ascend/envs.py +153 -0
  18. inference/vllm_ascend/models/__init__.py +68 -0
  19. inference/vllm_ascend/models/open_pangu.py +1127 -0
  20. inference/vllm_ascend/ops/fused_moe.py +1530 -0
  21. inference/vllm_ascend/patch/worker/patch_common/__init__.py +27 -0
  22. inference/vllm_ascend/patch/worker/patch_common/patch_config.py +97 -0
  23. inference/vllm_ascend/patch/worker/patch_common/patch_parsers.py +26 -0
  24. inference/vllm_ascend/patch/worker/patch_common/patch_sampler.py +159 -0
  25. inference/vllm_ascend/quantization/w8a8.py +757 -0
  26. inference/vllm_ascend/quantization/w8a8_dynamic.py +831 -0
  27. inference/vllm_ascend/tokenizer_config.json +1 -0
  28. inference/vllm_ascend/utils.py +563 -0
  29. inference/vllm_ascend/worker/model_runner_v1.py +0 -0
  30. inference/vllm_ascend/worker/npu_input_batch.py +796 -0
  31. inference/vllm_ascend_for_openpangu_embedded_7b.md +129 -0
  32. inference/vllm_ascend_for_openpangu_embedded_7b.zh.md +129 -0
  33. model-00001-of-00004.safetensors +3 -0
  34. model-00002-of-00004.safetensors +3 -0
  35. model-00003-of-00004.safetensors +3 -0
  36. model-00004-of-00004.safetensors +3 -0
  37. model.safetensors.index.json +486 -0
  38. modeling_openpangu_dense.py +585 -0
  39. modular_openpangu_dense.py +149 -0
  40. special_tokens_map.json +30 -0
  41. tokenization_openpangu.py +273 -0
  42. tokenizer.model +3 -0
  43. tokenizer_config.json +1 -0
LICENSE ADDED
@@ -0,0 +1,34 @@
1
+ OPENPANGU MODEL LICENSE AGREEMENT VERSION 1.0
2
+
3
+ This OPENPANGU MODEL LICENSE AGREEMENT VERSION 1.0 (the "Agreement") is a legal agreement between You and Huawei Technologies Co., Ltd. ("Huawei", "We" or "Us"), and it governs Your reproducing, use, modification, and distribution of openPangu as made available by Huawei under this Agreement.
4
+
5
+ By using, reproducing, modifying, distributing, performing or displaying any portion or element of openPangu, or otherwise accepting the terms of this Agreement, You agree to be bound by this Agreement.
6
+
7
+ 1. Definitions.
8
+ 1.1. “openPangu” or “Model” means openPangu large language models and software, including trained model weights, parameters (including optimizer states), accompanying source code and scripts released under this Agreement.
9
+ 1.2. “Derivative Model” means all (1) modifications to the Model, (2) works based on the Model, and (3) any other derivative works of the Model. For clarity, information or content results from operating or otherwise using the Model is not a Derivative Model.
10
+ 1.3. “You” or “Your” means an individual or Legal Entity exercising permissions granted by this Agreement and/or using the Model for any purpose.
11
+ 1.4. “Third Party” or “Third Parties” means individuals or legal entities that are not under common control with Us or You.
12
+
13
+ 2. License Grant. Subject to Your full compliance with the terms and conditions of this Agreement, We hereby grant to You a perpetual, worldwide, non-exclusive, non-transferable, no-charge, royalty-free license (except as stated in Section 3) to use, reproduce, modify, and distribute the Model.
14
+
15
+ 3. Conditions for License Grant. You represent and warrant that You will not access, download, install, run, deploy, integrate, modify, or otherwise use the Model, directly or indirectly, within the European Union.
16
+
17
+
18
+ 4. Redistribution.
19
+ 4.1. If You distribute the Model or Derivative Model, You shall retain in Your distribution (1) a copy of this agreement, and (2) all copyright notices and other notices of origin included in the Model that are applicable to Your distribution.
20
+ 4.2. Further, if You distribute or make available to Third Parties a product or service (including another AI model) based on the Model, You are required to (1) display the acknowledgement “Powered by openPangu” and (2) include a trademark notice “openPangu is a trademark of Huawei Technologies Co., Ltd.” on related webpages, user manuals, product documentations or other advertising materials mentioning features of the Model.
21
+ 4.3. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for Derivative Model made by You as a whole, provided Your use, reproduction, and distribution of the Model otherwise complies with the terms and conditions of this Agreement.
22
+
23
+ 5. Ownership. We do not claim ownership of any information or content generated using the Model or Derivative Model that is made by You. You are solely responsible for evaluating the accuracy and appropriateness of such information or content for Your use case.
24
+
25
+ 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of Huawei, except as required for complying with Section 4.2.
26
+
27
+ 7. Indemnity. You will indemnify and hold harmless Huawei from and against any claim by any third party arising out of or related to Your use or distribution of the Model or Derivative Model made by You (e.g. a violation of Section 3). For avoidance of doubt, “third party” in this clause includes supervisory authorities.
28
+
29
+ 8. THE MODEL IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE, NONINFRINGEMENT, ACCURACY, OR THE ABSENCE OF LATENT OR OTHER DEFECTS OR ERRORS, WHETHER OR NOT DISCOVERABLE, ALL TO THE GREATEST EXTENT PERMISSIBLE UNDER APPLICABLE LAW.
30
+
31
+ 9. IN NO EVENT SHALL WE BE LIABLE TO YOU FOR ANY DAMAGES, INCLUDING, BUT NOT LIMITED TO ANY DIRECT, OR INDIRECT, SPECIAL OR CONSEQUENTIAL DAMAGES ARISING FROM YOUR USE OR INABILITY TO USE THE MODEL, IN WHOLE OR IN PART, NO MATTER HOW IT’S CAUSED OR THE LEGAL THEORY IT IS BASED ON, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.
32
+
33
+
34
+ END OF THE TERMS AND CONDITIONS
Open Source Software Notice ADDED
@@ -0,0 +1,218 @@
1
+ OPEN SOURCE SOFTWARE NOTICE
2
+
3
+ Please note we provide an open source software notice along with this product and/or this product firmware (in the following just “this product”). The open source software licenses are granted by the respective right holders. And the open source licenses prevail all other license information with regard to the respective open source software contained in the product, including but not limited to End User Software Licensing Agreement. This notice is provided on behalf of Huawei Technologies Co. Ltd. and any of its local subsidiaries which may have provided this product to you in your local country.
4
+
5
+ Warranty Disclaimer
6
+ THE OPEN SOURCE SOFTWARE IN THIS PRODUCT IS DISTRIBUTED IN THE HOPE THAT IT WILL BE USEFUL, BUT WITHOUT ANY WARRANTY, WITHOUT EVEN THE IMPLIED WARRANTY OF MERCHANTABILITY OR FITNESS FOR A PARTICULAR PURPOSE. SEE THE APPLICABLE LICENSES FOR MORE DETAILS.
7
+
8
+ Copyright Notice and License Texts
9
+
10
+ Software: transformers 4.53.2
11
+ Copyright notice:
12
+ Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
13
+
14
+ License Text:
15
+ ----------------------------------------
16
+
17
+ Apache License
18
+ Version 2.0, January 2004
19
+ http://www.apache.org/licenses/
20
+
21
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
22
+
23
+ 1. Definitions.
24
+
25
+ "License" shall mean the terms and conditions for use, reproduction,
26
+ and distribution as defined by Sections 1 through 9 of this document.
27
+
28
+ "Licensor" shall mean the copyright owner or entity authorized by
29
+ the copyright owner that is granting the License.
30
+
31
+ "Legal Entity" shall mean the union of the acting entity and all
32
+ other entities that control, are controlled by, or are under common
33
+ control with that entity. For the purposes of this definition,
34
+ "control" means (i) the power, direct or indirect, to cause the
35
+ direction or management of such entity, whether by contract or
36
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
37
+ outstanding shares, or (iii) beneficial ownership of such entity.
38
+
39
+ "You" (or "Your") shall mean an individual or Legal Entity
40
+ exercising permissions granted by this License.
41
+
42
+ "Source" form shall mean the preferred form for making modifications,
43
+ including but not limited to software source code, documentation
44
+ source, and configuration files.
45
+
46
+ "Object" form shall mean any form resulting from mechanical
47
+ transformation or translation of a Source form, including but
48
+ not limited to compiled object code, generated documentation,
49
+ and conversions to other media types.
50
+
51
+ "Work" shall mean the work of authorship, whether in Source or
52
+ Object form, made available under the License, as indicated by a
53
+ copyright notice that is included in or attached to the work
54
+ (an example is provided in the Appendix below).
55
+
56
+ "Derivative Works" shall mean any work, whether in Source or Object
57
+ form, that is based on (or derived from) the Work and for which the
58
+ editorial revisions, annotations, elaborations, or other modifications
59
+ represent, as a whole, an original work of authorship. For the purposes
60
+ of this License, Derivative Works shall not include works that remain
61
+ separable from, or merely link (or bind by name) to the interfaces of,
62
+ the Work and Derivative Works thereof.
63
+
64
+ "Contribution" shall mean any work of authorship, including
65
+ the original version of the Work and any modifications or additions
66
+ to that Work or Derivative Works thereof, that is intentionally
67
+ submitted to Licensor for inclusion in the Work by the copyright owner
68
+ or by an individual or Legal Entity authorized to submit on behalf of
69
+ the copyright owner. For the purposes of this definition, "submitted"
70
+ means any form of electronic, verbal, or written communication sent
71
+ to the Licensor or its representatives, including but not limited to
72
+ communication on electronic mailing lists, source code control systems,
73
+ and issue tracking systems that are managed by, or on behalf of, the
74
+ Licensor for the purpose of discussing and improving the Work, but
75
+ excluding communication that is conspicuously marked or otherwise
76
+ designated in writing by the copyright owner as "Not a Contribution."
77
+
78
+ "Contributor" shall mean Licensor and any individual or Legal Entity
79
+ on behalf of whom a Contribution has been received by Licensor and
80
+ subsequently incorporated within the Work.
81
+
82
+ 2. Grant of Copyright License. Subject to the terms and conditions of
83
+ this License, each Contributor hereby grants to You a perpetual,
84
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
85
+ copyright license to reproduce, prepare Derivative Works of,
86
+ publicly display, publicly perform, sublicense, and distribute the
87
+ Work and such Derivative Works in Source or Object form.
88
+
89
+ 3. Grant of Patent License. Subject to the terms and conditions of
90
+ this License, each Contributor hereby grants to You a perpetual,
91
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
92
+ (except as stated in this section) patent license to make, have made,
93
+ use, offer to sell, sell, import, and otherwise transfer the Work,
94
+ where such license applies only to those patent claims licensable
95
+ by such Contributor that are necessarily infringed by their
96
+ Contribution(s) alone or by combination of their Contribution(s)
97
+ with the Work to which such Contribution(s) was submitted. If You
98
+ institute patent litigation against any entity (including a
99
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
100
+ or a Contribution incorporated within the Work constitutes direct
101
+ or contributory patent infringement, then any patent licenses
102
+ granted to You under this License for that Work shall terminate
103
+ as of the date such litigation is filed.
104
+
105
+ 4. Redistribution. You may reproduce and distribute copies of the
106
+ Work or Derivative Works thereof in any medium, with or without
107
+ modifications, and in Source or Object form, provided that You
108
+ meet the following conditions:
109
+
110
+ (a) You must give any other recipients of the Work or
111
+ Derivative Works a copy of this License; and
112
+
113
+ (b) You must cause any modified files to carry prominent notices
114
+ stating that You changed the files; and
115
+
116
+ (c) You must retain, in the Source form of any Derivative Works
117
+ that You distribute, all copyright, patent, trademark, and
118
+ attribution notices from the Source form of the Work,
119
+ excluding those notices that do not pertain to any part of
120
+ the Derivative Works; and
121
+
122
+ (d) If the Work includes a "NOTICE" text file as part of its
123
+ distribution, then any Derivative Works that You distribute must
124
+ include a readable copy of the attribution notices contained
125
+ within such NOTICE file, excluding those notices that do not
126
+ pertain to any part of the Derivative Works, in at least one
127
+ of the following places: within a NOTICE text file distributed
128
+ as part of the Derivative Works; within the Source form or
129
+ documentation, if provided along with the Derivative Works; or,
130
+ within a display generated by the Derivative Works, if and
131
+ wherever such third-party notices normally appear. The contents
132
+ of the NOTICE file are for informational purposes only and
133
+ do not modify the License. You may add Your own attribution
134
+ notices within Derivative Works that You distribute, alongside
135
+ or as an addendum to the NOTICE text from the Work, provided
136
+ that such additional attribution notices cannot be construed
137
+ as modifying the License.
138
+
139
+ You may add Your own copyright statement to Your modifications and
140
+ may provide additional or different license terms and conditions
141
+ for use, reproduction, or distribution of Your modifications, or
142
+ for any such Derivative Works as a whole, provided Your use,
143
+ reproduction, and distribution of the Work otherwise complies with
144
+ the conditions stated in this License.
145
+
146
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
147
+ any Contribution intentionally submitted for inclusion in the Work
148
+ by You to the Licensor shall be under the terms and conditions of
149
+ this License, without any additional terms or conditions.
150
+ Notwithstanding the above, nothing herein shall supersede or modify
151
+ the terms of any separate license agreement you may have executed
152
+ with Licensor regarding such Contributions.
153
+
154
+ 6. Trademarks. This License does not grant permission to use the trade
155
+ names, trademarks, service marks, or product names of the Licensor,
156
+ except as required for reasonable and customary use in describing the
157
+ origin of the Work and reproducing the content of the NOTICE file.
158
+
159
+ 7. Disclaimer of Warranty. Unless required by applicable law or
160
+ agreed to in writing, Licensor provides the Work (and each
161
+ Contributor provides its Contributions) on an "AS IS" BASIS,
162
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
163
+ implied, including, without limitation, any warranties or conditions
164
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
165
+ PARTICULAR PURPOSE. You are solely responsible for determining the
166
+ appropriateness of using or redistributing the Work and assume any
167
+ risks associated with Your exercise of permissions under this License.
168
+
169
+ 8. Limitation of Liability. In no event and under no legal theory,
170
+ whether in tort (including negligence), contract, or otherwise,
171
+ unless required by applicable law (such as deliberate and grossly
172
+ negligent acts) or agreed to in writing, shall any Contributor be
173
+ liable to You for damages, including any direct, indirect, special,
174
+ incidental, or consequential damages of any character arising as a
175
+ result of this License or out of the use or inability to use the
176
+ Work (including but not limited to damages for loss of goodwill,
177
+ work stoppage, computer failure or malfunction, or any and all
178
+ other commercial damages or losses), even if such Contributor
179
+ has been advised of the possibility of such damages.
180
+
181
+ 9. Accepting Warranty or Additional Liability. While redistributing
182
+ the Work or Derivative Works thereof, You may choose to offer,
183
+ and charge a fee for, acceptance of support, warranty, indemnity,
184
+ or other liability obligations and/or rights consistent with this
185
+ License. However, in accepting such obligations, You may act only
186
+ on Your own behalf and on Your sole responsibility, not on behalf
187
+ of any other Contributor, and only if You agree to indemnify,
188
+ defend, and hold each Contributor harmless for any liability
189
+ incurred by, or claims asserted against, such Contributor by reason
190
+ of your accepting any such warranty or additional liability.
191
+
192
+ END OF TERMS AND CONDITIONS
193
+
194
+ APPENDIX: How to apply the Apache License to your work.
195
+
196
+ To apply the Apache License to your work, attach the following
197
+ boilerplate notice, with the fields enclosed by brackets "[]"
198
+ replaced with your own identifying information. (Don't include
199
+ the brackets!) The text should be enclosed in the appropriate
200
+ comment syntax for the file format. We also recommend that a
201
+ file or class name and description of purpose be included on the
202
+ same "printed page" as the copyright notice for easier
203
+ identification within third-party archives.
204
+
205
+ Copyright [yyyy] [name of copyright owner]
206
+
207
+ Licensed under the Apache License, Version 2.0 (the "License");
208
+ you may not use this file except in compliance with the License.
209
+ You may obtain a copy of the License at
210
+
211
+ http://www.apache.org/licenses/LICENSE-2.0
212
+
213
+ Unless required by applicable law or agreed to in writing, software
214
+ distributed under the License is distributed on an "AS IS" BASIS,
215
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
216
+ See the License for the specific language governing permissions and
217
+ limitations under the License.
218
+
README.md CHANGED
@@ -1,6 +1,127 @@
1
- ---
2
- license: other
3
- license_name: openpangu-model-license-agreement-version-1.0
4
- license_link: >-
5
- https://ai.gitcode.com/ascend-tribe/openPangu-Embedded-7B-V1.1/blob/main/LICENSE
6
- ---
1
+ # 开源盘古 Embedded-7B-V1.1
2
+
3
+ 中文 | [English](README_EN.md)
4
+
5
+ ## 1. 简介
6
+
7
+ openPangu-Embedded-7B-V1.1 是基于昇腾 NPU 从零训练的高效大语言模型,参数量为 7B(不含词表Embedding)。openPangu-Embedded-7B-V1.1 训练了约 25T tokens,具备快慢思考融合与自适应切换能力。
8
+
9
+ ## 2. 模型架构
10
+
11
+ | | openPangu-Embedded-7B-V1.1 |
12
+ | :---------------------------: | :----------------: |
13
+ | **Architecture** | Dense |
14
+ | **Parameters (Non-Embedding)** | 7B |
15
+ | **Number of Layers** | 34 |
16
+ | **Hidden Dimension** | 12800 |
17
+ | **Attention Mechanism** | GQA |
18
+ | **Number of Attention Heads** | 32 for Q,8 for KV |
19
+ | **Vocabulary Size** | 153k |
20
+ | **Context Length (Natively)** | 32k |
21
+ | **Pretraining Tokens** | 25T |
22
+
23
+ ## 3. 测评结果
24
+
25
+ | 测评集 | 测评指标 | 慢思考v1.0 | 慢思考v1.1 | 自适应v1.1 |
26
+ | :------------: | :-----------------: | :--------: | :--------: | :--------: |
27
+ | **通用能力** | | | | |
28
+ | MMLU-Pro | Exact Match | 76.32 | 75.54 | 72.81 |
29
+ | CMMLU | Acc | 75.59 | 72.94 | 72.18 |
30
+ | ArenaHard_v0.1 | w/o style control | 85.80 | 88.00 | 84.60 |
31
+ | C-Eval | Acc | 83.05 | 84.92 | 83.33 |
32
+ | GPQA-Diamond | Avg@4 | 70.54 | 73.23 | 73.74 |
33
+ | **数学能力** | | | | |
34
+ | MATH-500 | Avg@1 | 95.00 | 97.00 | 96.00 |
35
+ | AIME24 | Avg@16 | 71.57 | 79.38 | 79.02 |
36
+ | AIME25 | Avg@16 | 58.24 | 70.00 | 70.21 |
37
+ | **代码能力** | | | | |
38
+ | LiveCodeBench | Avg@2 (08/24~01/25) | 54.04 | 58.27 | 58.27 |
39
+ | MBPP+ | Avg@2 | 76.06 | 76.46 | 75.66 |
40
+
41
+ **注:** 评测过程中system prompt 为空,且不添加任何额外的思维链(CoT)提示。评测采用 128k 的序列长度进行。
42
+
43
+ 除精度外,我们还在部分数据集上统计了模型的输出长度。通过数据质量驱动的学习策略,自适应快慢思考可以在基本不影响精度的前提下,有效地在简单任务上自动将部分输出切换为快思考,大幅缩短平均输出思维链长度(Length);在难任务上则通过保持慢思考能力,精度持平纯慢思考模型。
44
+
45
+ | 测评集 | 测评指标 | 慢思考v1.1 | 自适应v1.1 |
46
+ | :------------: | :-----------------: | :--------: | :--------: |
47
+ | **通用能力** | | | |
48
+ | CMMLU | Acc | 72.94 | 72.18 |
49
+ | | Length | 2574 | 1338 |
50
+ | C-Eval | Acc | 84.92 | 83.33 |
51
+ | | Length | 2484 | 1723 |
52
+ | **数学能力** | | | |
53
+ | AIME24 | Avg@16 | 79.38 | 79.02 |
54
+ | | Length | 48229 | 49656 |
55
+ | **代码能力** | | | |
56
+ | LiveCodeBench | Avg@2 (08/24~01/25) | 58.27 | 58.27 |
57
+ | | Length | 58140 | 59307 |
58
+
59
+ ## 4. 部署和使用
60
+
61
+ ### 4.1 环境准备
62
+
63
+ ##### 硬件规格
64
+
65
+ Atlas 800T A2 (64GB),驱动与固件安装包获取请参照 [[Atlas 800T A2](https://www.hiascend.com/hardware/firmware-drivers/community?product=4&model=26&cann=8.2.RC1.alpha003&driver=Ascend+HDK+25.0.RC1)]。
66
+
67
+ ##### 软件环境
68
+
69
+ - 操作系统:Linux(推荐 openEuler>=24.03)
70
+ - CANN==8.1.RC1,安装准备及流程请参照 [[CANN Install]](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/82RC1alpha002/softwareinst/instg/instg_0001.html?Mode=PmIns&OS=Ubuntu&Software=cannToolKit)
71
+ - python==3.10
72
+ - torch==2.1.0
73
+ - torch-npu==2.1.0.post12
74
+ - transformers==4.53.2
75
+
76
+ 以上软件配套经过验证,理论可以支持更高版本,如有疑问,可以提交 issue。
77
+
78
+ ### 4.2 权重完整性校验
79
+
80
+ 请参考以下方法对下载内容进行完整性校验,hash 值存储在 checklist.chk 文件中。
81
+
82
+ ```bash
83
+ #!/usr/bin/env bash
84
+ ARCH=$(uname -m)
85
+ MODEL_PATH="${TARGET_FOLDER}/${MODEL_FOLDER_PATH}"
86
+ cd "$MODEL_PATH" || exit 1
87
+ if [ "$ARCH" = "arm64" ]; then
88
+ sha256sum checklist.chk
89
+ else
90
+ sha256sum -c checklist.chk
91
+ fi
92
+ ```
93
+
94
+ ### 4.3 推理样例
95
+
96
+ 下述内容提供 openPangu-Embedded-7B-V1.1 在 `transformers` 框架上进行推理的一个简单示例:
97
+
98
+ > 运行前请修改 generate.py,添加模型路径。
99
+
100
+ ```bash
101
+ cd inference
102
+ python generate.py
103
+ ```
104
+
105
+ openPangu-Embedded-7B-V1.1 模型默认为慢思考模式,可以通过以下手段切换至快慢自适应切换/快思考模式:
106
+
107
+ - 在代码实例`generate.py`中,`auto_thinking_prompt`与`no_thinking_prompt`变量的定义展示了切换至快慢自适应或快思考模式的具体实现:通过在用户输入末尾添加`/auto_think`或`/no_think`标记,可将当前轮次切换至快慢自适应切换/快思考模式。
108
+
109
+ ### 4.4 使用推理框架
110
+
111
+ vllm_ascend:参考[[vllm_ascend_for_openpangu_embedded_7b.zh]](inference/vllm_ascend_for_openpangu_embedded_7b.zh.md)
112
+
113
+ ## 5. 模型许可证
114
+
115
+ 除文件中对开源许可证另有约定外,openPangu-Embedded-7B-V1.1 模型根据 OPENPANGU MODEL LICENSE AGREEMENT VERSION 1.0 授权,旨在允许使用并促进人工智能技术的进一步发展。有关详细信息,请参阅模型存储库根目录中的 [LICENSE](LICENSE) 文件。
116
+
117
+ ## 6. 免责声明
118
+
119
+ 由于 openPangu-Embedded-7B-V1.1(“模型”)所依赖的技术固有的技术限制,以及人工智能生成的内容是由盘古自动生成的,华为无法对以下事项做出任何保证:
120
+
121
+ - 尽管该模型的输出由 AI 算法生成,但不能排除某些信息可能存在缺陷、不合理或引起不适的可能性,生成的内容不代表华为的态度或立场;
122
+ - 无法保证该模型 100% 准确、可靠、功能齐全、及时、安全、无错误、不间断、持续稳定或无任何故障;
123
+ - 该模型的输出内容不构成任何建议或决策,也不保证生成的内容的真实性、完整性、准确性、及时性、合法性、功能性或实用性。生成的内容不能替代医疗、法律等领域的专业人士回答您的问题。生成的内容仅供参考,不代表华为的任何态度、立场或观点。您需要根据实际情况做出独立判断,华为不承担任何责任。
124
+
125
+ ## 7. 反馈
126
+
127
+ 如果有任何意见和建议,请提交issue或联系 [email protected]
README_EN.md ADDED
@@ -0,0 +1,124 @@
1
+ # openPangu-Embedded-7B-V1.1
2
+ [中文](README.md) | English
3
+
4
+ ## 1. Introduction
5
+ openPangu-Embedded-7B-V1.1 is an efficient large language model trained from scratch on Ascend NPUs, with 7B parameters (excluding the vocabulary embedding layer). It was trained on approximately 25T tokens, integrates both fast and slow thinking, and can adaptively switch between the two thinking modes based on the assessed complexity of each query.
6
+
7
+
8
+ ## 2. Model Architecture
9
+
10
+ | | openPangu-Embedded-7B-V1.1 |
11
+ | :---------------------------: | :----------------: |
12
+ | **Architecture** | Dense |
13
+ | **Parameters (Non-Embedding)** | 7B |
14
+ | **Number of Layers** | 34 |
15
+ | **Hidden Dimension** | 12800 |
16
+ | **Attention Mechanism** | GQA |
17
+ | **Number of Attention Heads** | 32 for Q, 8 for KV |
18
+ | **Vocabulary Size** | 153k |
19
+ | **Context Length (Natively)** | 32k |
20
+ | **Pretraining Tokens** | 25T |
21
+
22
+
23
+ ## 3. Results
24
+
25
+ | Benchmark | Metric | Slow-thinking v1.0 | Slow-thinking v1.1 | Adaptive-switching v1.1 |
26
+ | :---: | :---: | :---: | :---: | :---: |
27
+ | **General** | | | | |
28
+ | MMLU-Pro | Exact Match | 76.32 | 75.54 | 72.81 |
29
+ | CMMLU | Acc | 75.59 | 72.94 | 72.18 |
30
+ | ArenaHard_v0.1 | w/o style control | 85.80 | 88.00 | 84.60 |
31
+ | C-Eval | Acc | 83.05 | 84.92 | 83.33 |
32
+ | GPQA-Diamond | Avg@4 | 70.54 | 73.23 | 73.74 |
33
+ | **Math** | | | | |
34
+ | MATH-500 | Avg@1 | 95.00 | 97.00 | 96.00 |
35
+ | AIME24 | Avg@16 | 71.57 | 79.38 | 79.02 |
36
+ | AIME25 | Avg@16 | 58.24 | 70.00 | 70.21 |
37
+ | **Coding** | | | | |
38
+ | LiveCodeBench | Avg@2 (08/24~01/25) | 54.04 | 58.27 | 58.27 |
39
+ | MBPP+ | Avg@2 | 76.06 | 76.46 | 75.66 |
40
+
41
+ **Note:** The system prompt is left empty, and no additional Chain-of-Thought (CoT) prompts are introduced during the evaluation. All evaluations are performed using a sequence length of 128k tokens.
42
+
43
+ In addition to accuracy, we also measured the model's output length on several datasets. With a data-quality-driven learning strategy, the adaptive-switching mode automatically switches part of its output to fast thinking on simple tasks with little impact on accuracy, substantially shortening the average Chain-of-Thought length. On difficult tasks it retains its slow-thinking capability, so accuracy stays on par with the pure slow-thinking model.
44
+
45
+ | Benchmark | Metric | Slow-thinking v1.1 | Adaptive-switching v1.1 |
46
+ | :------------: | :-----------------: | :--------: | :--------: |
47
+ | **General** | | | |
48
+ | CMMLU | Acc | 72.94 | 72.18 |
49
+ | | Length | 2574 | 1338 |
50
+ | C-Eval | Acc | 84.92 | 83.33 |
51
+ | | Length | 2484 | 1723 |
52
+ | **Math** | | | |
53
+ | AIME24 | Avg@16 | 79.38 | 79.02 |
54
+ | | Length | 48229 | 49656 |
55
+ | **Coding** | | | |
56
+ | LiveCodeBench | Avg@2 (08/24~01/25) | 58.27 | 58.27 |
57
+ | | Length | 58140 | 59307 |
58
+
59
+ ## 4. Deployment
60
+
61
+ ### 4.1 Environment
62
+
63
+ ##### Hardware Requirements
64
+
65
+ Atlas 800T A2 (64GB); please refer to [[Atlas 800T A2](https://www.hiascend.com/hardware/firmware-drivers/community?product=4&model=26&cann=8.2.RC1.alpha003&driver=Ascend+HDK+25.0.RC1)] to obtain the driver and firmware installation packages.
66
+
67
+ ##### System Requirements & Dependencies
68
+
69
+ - System: Linux (openEuler ≥ 24.03 recommended)
70
+ - CANN==8.1.RC1: [[CANN Install]](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/82RC1alpha002/softwareinst/instg/instg_0001.html?Mode=PmIns&OS=Ubuntu&Software=cannToolKit)
71
+ - python==3.10
72
+ - torch==2.1.0
73
+ - torch-npu==2.1.0.post12
74
+ - transformers==4.53.2
75
+
76
+ The above software environment has been verified, and theoretically supports newer versions. For any questions, please submit an issue.
77
+
78
+ ### 4.2 Integrity Check
79
+
80
+ Please refer to the following methods to verify the integrity of the downloaded content. The hash values are stored in the `checklist.chk` file.
81
+
82
+ ```bash
83
+ #!/usr/bin/env bash
84
+ ARCH=$(uname -m)
85
+ MODEL_PATH="${TARGET_FOLDER}/${MODEL_FOLDER_PATH}"
86
+ cd "$MODEL_PATH" || exit 1
87
+ if [ "$ARCH" = "arm64" ]; then
88
+ sha256sum checklist.chk
89
+ else
90
+ sha256sum -c checklist.chk
91
+ fi
92
+ ```
93
+
94
+ ### 4.3 Inference Examples
95
+
96
+ The following provides a simple inference example of openPangu-Embedded-7B-V1.1 based on the `transformers` framework:
97
+ >Please modify generate.py and add the model path before running.
98
+ ```bash
99
+ cd inference
100
+ python generate.py
101
+ ```
102
+
103
+ The openPangu-Embedded-7B-V1.1 model is in slow thinking mode by default, and can be switched to adaptive/fast thinking mode by the following means:
104
+ - In the code example `generate.py`, the `auto_thinking_prompt` and `no_thinking_prompt` variables show how the switch is implemented: appending the `/auto_think` or `/no_think` tag to the end of the user input switches the current turn to adaptive or fast thinking mode, respectively. In fast thinking mode, `thinking_content` will be empty; see the sketch that follows.
105
+
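A minimal sketch of this switch, reusing the `tokenizer` and `model` objects loaded in `generate.py` (the tag strings follow the description above; verify the exact template handling against the shipped script):

```python
prompt = "Give me a short introduction to large language model."

# Append the documented control tag to the user turn to change the thinking mode:
#   " /no_think"  -> fast thinking,  " /auto_think" -> adaptive switching.
messages = [{"role": "user", "content": prompt + " /no_think"}]

text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = tokenizer([text], return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=4096, eos_token_id=45892)
print(tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:]))
```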
106
+ ### 4.4 Using an Inference Framework
107
+ vllm_ascend:[[vllm_ascend_for_openpangu_embedded_7b]](inference/vllm_ascend_for_openpangu_embedded_7b.md)
108
+
109
+
110
+
111
+
112
+ ## 5. Model License
113
+
114
+ Unless otherwise noted, the openPangu-Embedded-7B-V1.1 model is licensed under the terms and conditions of the OPENPANGU MODEL LICENSE AGREEMENT VERSION 1.0, which is intended to permit use of the model and to enable the further development of artificial intelligence technologies. Please refer to the [LICENSE](LICENSE) file located in the root directory of the model repository for details.
115
+
116
+ ## 6. Disclaimer
117
+ Due to the technical limitations inherent in the technology on which openPangu-Embedded-7B-V1.1 (the “Model”) relies, and because the AI-generated content is produced automatically by the Model, Huawei cannot make any guarantees regarding the following matters:
118
+ - The output of this Model is generated automatically by AI algorithms; it cannot be ruled out that some of the information may be flawed, unreasonable, or cause discomfort, and the generated content does not represent Huawei's attitude or standpoint;
119
+ - There is no guarantee that this Model is 100% accurate, reliable, functional, timely, secure and safe, error-free, uninterrupted, continuously stable, or free of any faults;
120
+ - The output of this Model does not constitute any advice or decision for you, and there is no guarantee of the authenticity, completeness, accuracy, timeliness, legality, functionality, or practicality of the generated content. The generated content cannot replace professionals in medical, legal, or other fields in answering your questions. It is provided for your reference only and does not represent any attitude, standpoint, or view of Huawei. You need to make independent judgments based on your actual situation, and Huawei does not assume any responsibility.
121
+
122
+
123
+ ## 7. Contact Us
124
+ If you have any comments or suggestions, please submit an issue or contact [email protected].
checklist.chk ADDED
@@ -0,0 +1,36 @@
1
+ 8e300541888d68a6148d18720fa6b2df48f71e2e1606026d55b30af6dcb2ba37 *./config.json
2
+ 7694a0e7b59d7ec2eeebc2fd058f02fe4dc4464b27f82839fc9f425a88555a3a *./configuration_openpangu_dense.py
3
+ 5cbfc09f10ae85f0e9bebc1281541dcc7107d86e34282839277782cbb146117d *./generation_config.json
4
+ e2f59bb91ad4e826fe087090393f438e0b6d41b6d5fe2991b390938030877504 *./inference/generate.py
5
+ ba6d7edcf1cf464d6fd787b12a9bda2a16fea0ac0d5d1e54136baec503d6e696 *./inference/vllm_ascend/attention/attention.py
6
+ 2254aeca0be7b8922318e10c4a950f39afb30ba5fe3b46564a58671b237ac612 *./inference/vllm_ascend/attention/mla_v1.py
7
+ f9577c29bc4dc19a4cc41ccfcca17065402c9dd92221bef987c74808b23ed124 *./inference/vllm_ascend/entrypoints/openai/reasoning_parsers/pangu_reasoning_parser.py
8
+ 9070682b058a79d2b2874ba5e07ce72beff6efb870f75cdac30cdcf6ba8fadc7 *./inference/vllm_ascend/entrypoints/openai/reasoning_parsers/__init__.py
9
+ 91eab52cdc19603b7b705b302e25345d849e18fa66875261a1135d5382392123 *./inference/vllm_ascend/entrypoints/openai/tool_parsers/pangu_tool_parser.py
10
+ d07256c9014f911f81269e65aad6c0d7dd61d4e82f5cb399e05285d5c1bc8fa8 *./inference/vllm_ascend/entrypoints/openai/tool_parsers/__init__.py
11
+ 52a968f10ebaebeb626248afd3e1d1b92f8fbfcaad19ebf05cafbc0bd03192cb *./inference/vllm_ascend/envs.py
12
+ b654e72ece161b3f04080e5c4d2476641c024939ac5308115fe1c65a6c5c7215 *./inference/vllm_ascend/models/open_pangu.py
13
+ e98aa2549f02017a35b07499216fe569e86400684087821820cf2d971c8fcbac *./inference/vllm_ascend/models/__init__.py
14
+ 09273eb0e4696d2fb530881ba1ad9d331897dd81c0cd2f203ed3d0a475b4d39b *./inference/vllm_ascend/ops/fused_moe.py
15
+ 8436ab93933989431160e55627b5dce5326f0fc5ec18263653902764ac8ace7b *./inference/vllm_ascend/patch/worker/patch_common/patch_config.py
16
+ 8c59df8086bde0cd4df674403f83000921a34403651a8ff2b31de9b28768247a *./inference/vllm_ascend/patch/worker/patch_common/patch_parsers.py
17
+ e712ea36caf16c2a9dd21c5288f9d8e34c7fd2face444da44dca6db6c21f6c1b *./inference/vllm_ascend/patch/worker/patch_common/patch_sampler.py
18
+ 63a6ba0d0b0158d4586219c979bf96d5fe87b74123af93f1c8d9ed842db96500 *./inference/vllm_ascend/patch/worker/patch_common/__init__.py
19
+ 743bd96cfc109975a11fe5412c4b5de46f880501dcbbbdd10e11cbeb865fa4f2 *./inference/vllm_ascend/quantization/w8a8.py
20
+ 6adfaa8a67ea9b561dec2e6a2392f6fc85ff376fb2030d8761c34c6c6d3f4cbf *./inference/vllm_ascend/quantization/w8a8_dynamic.py
21
+ 4a95e3bf02d4c88bb614a48b37fe5c66c8c4c28e574cfc8a0e9791fc22704e7d *./inference/vllm_ascend/tokenizer_config.json
22
+ e2457c558f048876afe069d1226e7080ac214478f1a9ac28ae472928b81b5a06 *./inference/vllm_ascend/utils.py
23
+ 62c6734d1283e3d649a6478d2004f46bfee2f7878af7f2849c979b124e355302 *./inference/vllm_ascend/worker/model_runner_v1.py
24
+ bc6505adabc0498ad07b49187858788c65c13dbf9446fd0bcf177a3e1b27220d *./inference/vllm_ascend/worker/npu_input_batch.py
25
+ 4aaf57e6f6d2e139b3847b10ee59d738398ebfc4927a22325b27dad384874aec *./inference/vllm_ascend/_build_info.py
26
+ 2aae256964b995a9416935d2aeb32cb9028f413a8e9bc2030dbe00e02b38ffec *./model-00001-of-00004.safetensors
27
+ 6b0c4bcd085187fc8a1e1aeeec59afac21c876ba9b47e275e87b8ce56e058fec *./model-00002-of-00004.safetensors
28
+ 511d39bdc8d619627aa701820c141afa1c3a1139450e2d2d6892004bc7b61721 *./model-00003-of-00004.safetensors
29
+ 52ae62d754a63d2fcd8d1f40be1332dd161a2105ff7c43622328ef3800d4d891 *./model-00004-of-00004.safetensors
30
+ 9bf645e8399be6d99000eae64bd172b5c457d6d2c44d2257b47eb97a3c41aeda *./model.safetensors.index.json
31
+ f15eaf322af8a0b0f16b26795eb68af836179413d3dbfa4dc44505db6c8b0d6f *./modeling_openpangu_dense.py
32
+ c1f2d87f855b994039c52b1e83c8a7f3d71a2d1eb52946c4a2e862e99f19d8b3 *./modular_openpangu_dense.py
33
+ b34cf5e7c7660889303b6e2d0a346c440356385c9db551d06f6615cf9fc600d1 *./special_tokens_map.json
34
+ c98602d6d1f61792a8bd3393972bbbe7409a205c0bb6299394c74287c26bd723 *./tokenization_openpangu.py
35
+ 6b16f1558c0cd4ae6ef1a2c605713be0a514f50e1ce2d2c878979ce988c148ec *./tokenizer.model
36
+ acb88eac57f8765fedf34e9c10bc16d55c46f0902b0fea74fbf041daca2667ae *./tokenizer_config.json
config.json ADDED
@@ -0,0 +1,31 @@
1
+ {
2
+ "architectures": [
3
+ "PanguEmbeddedForCausalLM"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_openpangu_dense.PanguEmbeddedConfig",
7
+ "AutoModel": "modeling_openpangu_dense.PanguEmbeddedModel",
8
+ "AutoModelForCausalLM": "modeling_openpangu_dense.PanguEmbeddedForCausalLM"
9
+ },
10
+ "bias": true,
11
+ "attention_dropout": 0.0,
12
+ "bos_token_id": 1,
13
+ "pad_token_id": 0,
14
+ "eos_token_id": 45892,
15
+ "hidden_act": "silu",
16
+ "hidden_size": 4096,
17
+ "initializer_range": 0.02,
18
+ "intermediate_size": 12800,
19
+ "max_position_embeddings": 32768,
20
+ "model_type": "PanguEmbedded",
21
+ "num_attention_heads": 32,
22
+ "num_hidden_layers": 34,
23
+ "num_key_value_heads": 8,
24
+ "rms_norm_eps": 1e-05,
25
+ "rope_theta": 16000000.0,
26
+ "tie_word_embeddings": false,
27
+ "torch_dtype": "bfloat16",
28
+ "transformers_version": "4.53.2",
29
+ "use_cache": true,
30
+ "vocab_size": 153376
31
+ }
configuration_openpangu_dense.py ADDED
@@ -0,0 +1,56 @@
1
+ # coding=utf-8
2
+ # Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
3
+
4
+ from transformers.utils import logging
5
+ from transformers.configuration_utils import PretrainedConfig
6
+
7
+
8
+ logger = logging.get_logger(__name__)
9
+
10
+
11
+ class PanguEmbeddedConfig(PretrainedConfig):
12
+
13
+ model_type = "PanguEmbedded"
14
+ _auto_class = "AutoConfig"
15
+
16
+ def __init__(
17
+ self,
18
+ vocab_size=153376,
19
+ hidden_size=4096,
20
+ intermediate_size=12800,
21
+ num_hidden_layers=34,
22
+ num_attention_heads=32,
23
+ num_key_value_heads=8,
24
+ hidden_act="silu",
25
+ max_position_embeddings=32768,
26
+ initializer_range=0.02,
27
+ rms_norm_eps=1e-5,
28
+ use_cache=True,
29
+ pad_token_id=0,
30
+ bos_token_id=1,
31
+ eos_token_id=45892,
32
+ tie_word_embeddings=False,
33
+ rope_theta=16000000.0,
34
+ bias=True,
35
+ **kwargs,
36
+ ):
37
+ self.vocab_size = vocab_size
38
+ self.max_position_embeddings = max_position_embeddings
39
+ self.hidden_size = hidden_size
40
+ self.intermediate_size = intermediate_size
41
+ self.num_hidden_layers = num_hidden_layers
42
+ self.num_attention_heads = num_attention_heads
43
+ self.num_key_value_heads = num_key_value_heads
44
+ self.hidden_act = hidden_act
45
+ self.initializer_range = initializer_range
46
+ self.rms_norm_eps = rms_norm_eps
47
+ self.use_cache = use_cache
48
+ self.rope_theta = rope_theta
49
+ self.bias = bias
50
+ super().__init__(
51
+ pad_token_id=pad_token_id,
52
+ bos_token_id=bos_token_id,
53
+ eos_token_id=eos_token_id,
54
+ tie_word_embeddings=tie_word_embeddings,
55
+ **kwargs,
56
+ )
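As a brief illustration (not part of the upload), the `auto_map` entries in `config.json` let this configuration class be loaded through `AutoConfig`; `path_to_openPangu-Embedded-7B` is the same placeholder path used in `inference/generate.py`:

```python
from transformers import AutoConfig

# trust_remote_code is needed so that configuration_openpangu_dense.py is actually used.
config = AutoConfig.from_pretrained("path_to_openPangu-Embedded-7B", trust_remote_code=True)
print(config.model_type, config.num_hidden_layers, config.num_key_value_heads)
# Expected from config.json: PanguEmbedded 34 8
```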
generation_config.json ADDED
@@ -0,0 +1,11 @@
1
+ {
2
+ "_from_model_config": true,
3
+ "do_sample": true,
4
+ "bos_token_id": 1,
5
+ "pad_token_id": 0,
6
+ "eos_token_id": 45892,
7
+ "temperature": 1.0,
8
+ "top_k": 0,
9
+ "top_p": 0.8,
10
+ "transformers_version": "4.53.2"
11
+ }
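These defaults (sampling enabled, temperature 1.0, top_p 0.8, top_k disabled) are picked up automatically by `model.generate`. A hedged sketch of overriding them for a single call, reusing the `model` and `model_inputs` objects from `inference/generate.py` below, with purely illustrative values:

```python
# Per-call overrides take precedence over generation_config.json.
outputs = model.generate(
    **model_inputs,
    max_new_tokens=1024,
    do_sample=True,
    temperature=0.7,  # sharper than the shipped default of 1.0
    top_p=0.9,        # wider nucleus than the shipped default of 0.8
    eos_token_id=45892,
)
```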
inference/generate.py ADDED
@@ -0,0 +1,57 @@
1
+ # coding=utf-8
2
+ # Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
3
+
4
+ from transformers import AutoModelForCausalLM, AutoTokenizer
5
+
6
+ model_local_path = "path_to_openPangu-Embedded-7B"
7
+
8
+
9
+ # load the tokenizer and the model
10
+ tokenizer = AutoTokenizer.from_pretrained(
11
+ model_local_path,
12
+ use_fast=False,
13
+ trust_remote_code=True,
14
+ local_files_only=True
15
+ )
16
+
17
+ model = AutoModelForCausalLM.from_pretrained(
18
+ model_local_path,
19
+ trust_remote_code=True,
20
+ torch_dtype="auto",
21
+ device_map="npu",
22
+ local_files_only=True
23
+ )
24
+
25
+ # prepare the model input
26
+ sys_prompt = "你必须严格遵守法律法规和社会道德规范。" \
27
+ "生成任何内容时,都应避免涉及暴力、色情、恐怖主义、种族歧视、性别歧视等不当内容。" \
28
+ "一旦检测到输入或输出有此类倾向,应拒绝回答并发出警告。例如,如果输入内容包含暴力威胁或色情描述," \
29
+ "应返回错误信息:“您的输入包含不当内容,无法处理。”"
30
+
31
+ prompt = "Give me a short introduction to large language model."
32
+ no_thinking_prompt = prompt+" /no_think"
33
+ auto_thinking_prompt = prompt+" /auto_think"
34
+ messages = [
35
+ {"role": "system", "content": sys_prompt}, # define your system prompt here
36
+ {"role": "user", "content": prompt}
37
+ ]
38
+ text = tokenizer.apply_chat_template(
39
+ messages,
40
+ tokenize=False,
41
+ add_generation_prompt=True
42
+ )
43
+ model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
44
+
45
+ # conduct text completion
46
+ outputs = model.generate(**model_inputs, max_new_tokens=32768, eos_token_id=45892, return_dict_in_generate=True)
47
+
48
+ input_length = model_inputs.input_ids.shape[1]
49
+ generated_tokens = outputs.sequences[:, input_length:]
50
+ output_sent = tokenizer.decode(generated_tokens[0])
51
+
52
+ # parsing thinking content
53
+ thinking_content = output_sent.split("[unused17]")[0].split("[unused16]")[-1].strip()
54
+ content = output_sent.split("[unused17]")[-1].split("[unused10]")[0].strip()
55
+
56
+ print("\nthinking content:", thinking_content)
57
+ print("\ncontent:", content)
inference/vllm_ascend/_build_info.py ADDED
@@ -0,0 +1,3 @@
1
+ # Auto-generated file
2
+ __soc_version__ = 'ASCEND910B1'
3
+ __sleep_mode_enabled__ = True
inference/vllm_ascend/attention/attention.py ADDED
@@ -0,0 +1,1220 @@
1
+ #
2
+ # Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # This file is a part of the vllm-ascend project.
16
+ #
17
+
18
+ from dataclasses import dataclass
19
+ from typing import Any, Dict, List, Optional, Tuple, Type
20
+
21
+ import numpy as np
22
+ import torch
23
+ import torch_npu
24
+ import torchair._contrib.custom_torch_ops # type: ignore # noqa: F401
25
+ from torch.nn.functional import scaled_dot_product_attention
26
+ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
27
+ AttentionLayer,
28
+ AttentionMetadata, AttentionType,
29
+ MLAAttentionImpl)
30
+ from vllm.attention.backends.utils import (PAD_SLOT_ID, CommonAttentionState,
31
+ CommonMetadataBuilder,
32
+ compute_slot_mapping,
33
+ compute_slot_mapping_start_idx,
34
+ is_block_tables_empty)
35
+ from vllm.utils import async_tensor_h2d, make_tensor_with_pad
36
+
37
+ from vllm_ascend.ascend_config import get_ascend_config
38
+ from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
39
+ from vllm_ascend.ops.cache import concat_and_cache_mla
40
+ from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16,
41
+ enable_custom_op, is_310p, nd_to_nz_2d)
42
+ from vllm_ascend.worker.model_runner import (
43
+ ModelInputForNPUBuilder, ModelInputForNPUWithSamplingMetadata)
44
+
45
+ _ALLOWED_NUM_QUERIES_PER_KV = [32, 64, 128]
46
+
47
+
48
+ class AscendAttentionBackend(AttentionBackend):
49
+
50
+ @staticmethod
51
+ def get_name() -> str:
52
+ return "ASCEND"
53
+
54
+ @staticmethod
55
+ def get_impl_cls() -> Type["AscendAttentionBackendImpl"]:
56
+ return AscendAttentionBackendImpl
57
+
58
+ @staticmethod
59
+ def get_metadata_cls() -> Type["AscendMetadata"]:
60
+ return AscendMetadata
61
+
62
+ @staticmethod
63
+ def get_state_cls() -> Type["CommonAttentionState"]:
64
+ return CommonAttentionState
65
+
66
+ @staticmethod
67
+ def get_kv_cache_shape(
68
+ num_blocks: int,
69
+ block_size: int,
70
+ num_kv_heads: int,
71
+ head_size: int,
72
+ ) -> Tuple[int, ...]:
73
+ if is_310p():
74
+ return (2, num_blocks, num_kv_heads * head_size // 16, block_size,
75
+ 16)
76
+ else:
77
+ return (2, num_blocks, block_size, num_kv_heads, head_size)
78
+
79
+ @staticmethod
80
+ def swap_blocks(
81
+ src_kv_cache: List[torch.Tensor],
82
+ dst_kv_cache: List[torch.Tensor],
83
+ src_to_dst: torch.Tensor,
84
+ ) -> None:
85
+ src_key_cache, src_value_cache = src_kv_cache[0], src_kv_cache[1]
86
+ dst_key_cache, dst_value_cache = dst_kv_cache[0], dst_kv_cache[1]
87
+ src_indices = src_to_dst[:, 0]
88
+ dst_indices = src_to_dst[:, 1]
89
+
90
+ dst_key_cache[dst_indices] = src_key_cache[src_indices].to(
91
+ dst_key_cache.device)
92
+ dst_value_cache[dst_indices] = src_value_cache[src_indices].to(
93
+ dst_key_cache.device)
94
+
95
+ @staticmethod
96
+ def copy_blocks(
97
+ kv_caches: List[torch.Tensor],
98
+ src_to_dists: torch.Tensor,
99
+ ) -> None:
100
+ src_indices = src_to_dists[:, 0]
101
+ dst_indices = src_to_dists[:, 1]
102
+
103
+ for kv_cache in kv_caches:
104
+ key_caches = kv_cache[0]
105
+ value_caches = kv_cache[1]
106
+ key_caches[dst_indices] = key_caches[src_indices]
107
+ value_caches[dst_indices] = value_caches[src_indices]
108
+
109
+ @staticmethod
110
+ def get_builder_cls() -> Type["AscendMetadataBuilder"]:
111
+ return AscendMetadataBuilder
112
+
113
+ @classmethod
114
+ def make_metadata_builder(cls, *args, **kwargs) -> "AscendMetadataBuilder":
115
+ return cls.get_builder_cls()(*args, **kwargs)
116
+
117
+
118
+ class AscendMLAAttentionBackend(AscendAttentionBackend):
119
+
120
+ @staticmethod
121
+ def get_impl_cls() -> Type["AscendMLAAttentionBackendImpl"]:
122
+ return AscendMLAAttentionBackendImpl
123
+
124
+ @staticmethod
125
+ def get_kv_cache_shape(
126
+ num_blocks: int,
127
+ block_size: int,
128
+ num_kv_heads: int,
129
+ head_size: int,
130
+ ) -> Tuple[int, ...]:
131
+ return (num_blocks, block_size, num_kv_heads, head_size)
132
+
133
+
134
+ @dataclass
135
+ class AscendMetadata(AttentionMetadata):
136
+ """Metadata for Ascendbackend.
137
+ * modified from XFormers backend
138
+ NOTE: Any python object stored here is not updated when it is
139
+ cuda-graph replayed. If you have values that need to be changed
140
+ dynamically, it should be stored in tensor. The tensor has to be
141
+ updated from `CUDAGraphRunner.forward` API.
142
+ """
143
+
144
+ # |---------- N-1 iteration --------|
145
+ # |---------------- N iteration ---------------------|
146
+ # |- tokenA -|......................|-- newTokens ---|
147
+ # |---------- context_len ----------|
148
+ # |-------------------- seq_len ----------------------|
149
+ # |-- query_len ---|
150
+
151
+ # FIXME: It is for flash attn.
152
+ # Maximum sequence length among prefill batch. 0 if there are decoding
153
+ # Avoid mypy error
154
+ # Total number of prefill requests.
155
+ num_prefills: int
156
+ # Number of prefill tokens.
157
+ num_prefill_tokens: int
158
+ # (num_tokens,). The indices of the token slots that input tokens will be
159
+ # stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
160
+ # is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
161
+ # in block 0, and 1st slot in block 1, respectively.
162
+ slot_mapping: torch.Tensor
163
+
164
+ # requests only.
165
+ max_prefill_seq_len: int
166
+ # Maximum sequence length among decode batch. 0 if there are prefill
167
+ # requests only.
168
+ max_decode_seq_len: int
169
+
170
+ chunked_prefill_enabled: bool
171
+
172
+ # (batch_size, max_blocks_per_seq).
173
+ # Block addresses per sequence. (Seq id -> list of physical block)
174
+ block_tables: Optional[torch.Tensor]
175
+
176
+ # seq_lens stored as a tensor.
177
+ seq_lens_tensor: Optional[torch.Tensor]
178
+
179
+ # (batch_size,). The sequence length per sequence. Sequence length means
180
+ # the computed tokens + new tokens None if it is a decoding.
181
+ seq_lens: Optional[List[int]] = None
182
+
183
+ # The query lengths of the input sequences
184
+ query_lens: Optional[List[int]] = None
185
+
186
+ # Maximum query length in the batch. None for decoding.
187
+ max_query_len: Optional[int] = None
188
+
189
+ # Self-attention prefill/decode metadata cache
190
+ _cached_prefill_metadata: Optional["AscendMetadata"] = None
191
+ _cached_decode_metadata: Optional["AscendMetadata"] = None
192
+
193
+ # Begin encoder attn & enc/dec cross-attn fields...
194
+
195
+ # Encoder sequence lengths representation
196
+ encoder_seq_lens: Optional[List[int]] = None
197
+ encoder_seq_lens_tensor: Optional[torch.Tensor] = None
198
+
199
+ # Maximum sequence length among encoder sequences
200
+ max_encoder_seq_len: Optional[int] = None
201
+
202
+ # Number of tokens input to encoder
203
+ num_encoder_tokens: Optional[int] = None
204
+
205
+ # Mask for normal situation
206
+ attn_mask: Optional[torch.Tensor] = None
207
+
208
+ # Mask for prefix caching
209
+ compress_mask: Optional[torch.Tensor] = None
210
+
211
+ # Mask for chunked prefill
212
+ chunk_mask: Optional[torch.Tensor] = None
213
+
214
+ # Cross-attention memory-mapping data structures: slot mapping
215
+ # and block tables
216
+ cross_slot_mapping: Optional[torch.Tensor] = None
217
+ cross_block_tables: Optional[torch.Tensor] = None
218
+
219
+ @property
220
+ def prefill_metadata(self) -> Optional["AscendMetadata"]:
221
+ if self.num_prefills == 0:
222
+ return None
223
+
224
+ if self._cached_prefill_metadata is not None:
225
+ # Recover cached prefill-phase attention
226
+ # metadata structure.
227
+ return self._cached_prefill_metadata
228
+
229
+ assert ((self.seq_lens is not None)
230
+ or (self.encoder_seq_lens is not None))
231
+
232
+ # Compute some attn_metadata fields which default to None.
233
+ slot_mapping = (None if self.slot_mapping is None else
234
+ self.slot_mapping[:self.num_prefill_tokens])
235
+ seq_lens = (None if self.seq_lens is None else
236
+ self.seq_lens[:self.num_prefills])
237
+ query_lens = (None if self.query_lens is None else
238
+ self.query_lens[:self.num_prefills])
239
+ block_tables = (None if self.block_tables is None else
240
+ self.block_tables[:self.num_prefills])
241
+
242
+ seq_lens_tensor = (None if self.seq_lens_tensor is None else
243
+ self.seq_lens_tensor[:self.num_prefills])
244
+
245
+ # Construct & cache prefill-phase attention metadata structure.
246
+ self._cached_prefill_metadata = AscendMetadata(
247
+ num_prefills=self.num_prefills,
248
+ num_prefill_tokens=self.num_prefill_tokens,
249
+ num_decode_tokens=0,
250
+ slot_mapping=slot_mapping,
251
+ seq_lens=seq_lens,
252
+ seq_lens_tensor=seq_lens_tensor,
253
+ query_lens=query_lens,
254
+ max_query_len=self.max_query_len,
255
+ max_prefill_seq_len=self.max_prefill_seq_len,
256
+ max_decode_seq_len=0,
257
+ chunked_prefill_enabled=self.chunked_prefill_enabled,
258
+ block_tables=block_tables,
259
+ # Begin encoder & cross attn fields below...
260
+ encoder_seq_lens=self.encoder_seq_lens,
261
+ encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
262
+ max_encoder_seq_len=self.max_encoder_seq_len,
263
+ multi_modal_placeholder_index_maps=self.
264
+ multi_modal_placeholder_index_maps,
265
+ cross_slot_mapping=self.cross_slot_mapping,
266
+ cross_block_tables=self.cross_block_tables,
267
+ enable_kv_scales_calculation=False)
268
+ return self._cached_prefill_metadata
269
+
270
+ @property
271
+ def decode_metadata(self) -> Optional["AscendMetadata"]:
272
+ if self.num_decode_tokens == 0:
273
+ return None
274
+
275
+ if self._cached_decode_metadata is not None:
276
+ # Recover cached decode-phase attention
277
+ # metadata structure.
278
+ return self._cached_decode_metadata
279
+
280
+ # Compute some attn_metadata fields which default to None.
281
+ slot_mapping = (None if self.slot_mapping is None else
282
+ self.slot_mapping[self.num_prefill_tokens:])
283
+ seq_lens = (None if self.seq_lens is None else
284
+ self.seq_lens[self.num_prefills:])
285
+ query_lens = (None if self.query_lens is None else
286
+ self.query_lens[self.num_prefills:])
287
+ block_tables = (None if self.block_tables is None else
288
+ self.block_tables[self.num_prefills:])
289
+ seq_lens_tensor = (None if self.seq_lens_tensor is None else
290
+ self.seq_lens_tensor[self.num_prefills:])
291
+ # Construct & cache decode-phase attention metadata structure.
292
+ self._cached_decode_metadata = AscendMetadata(
293
+ num_prefills=0,
294
+ num_prefill_tokens=0,
295
+ num_decode_tokens=self.num_decode_tokens,
296
+ slot_mapping=slot_mapping,
297
+ seq_lens=seq_lens,
298
+ seq_lens_tensor=seq_lens_tensor,
299
+ query_lens=query_lens,
300
+ max_query_len=self.max_query_len,
301
+ max_prefill_seq_len=0,
302
+ max_decode_seq_len=self.max_decode_seq_len,
303
+ chunked_prefill_enabled=self.chunked_prefill_enabled,
304
+ block_tables=block_tables,
305
+ # Begin encoder & cross attn fields below...
306
+ encoder_seq_lens=self.encoder_seq_lens,
307
+ encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
308
+ max_encoder_seq_len=self.max_encoder_seq_len,
309
+ multi_modal_placeholder_index_maps=self.
310
+ multi_modal_placeholder_index_maps,
311
+ cross_slot_mapping=self.cross_slot_mapping,
312
+ cross_block_tables=self.cross_block_tables,
313
+ enable_kv_scales_calculation=False)
314
+ return self._cached_decode_metadata
315
+
316
+ def advance_step(self,
317
+ model_input: "ModelInputForNPUWithSamplingMetadata",
318
+ sampled_token_ids: Optional[torch.Tensor],
319
+ block_size: int,
320
+ num_seqs: int,
321
+ num_queries: int,
322
+ turn_prefills_into_decodes: bool = False):
323
+ """
324
+ Update metadata in-place to advance one decode step.
325
+ """
326
+ # When using cudagraph, the num_seqs is padded to the next captured
327
+ # batch size, but num_queries tracks the actual number of requests in
328
+ # the batch. For --enforce-eager mode, num_seqs == num_queries
329
+ if num_seqs != num_queries:
330
+ assert num_seqs > num_queries
331
+
332
+ if turn_prefills_into_decodes:
333
+ # When Multi-Step is enabled with Chunked-Prefill, prefills and
334
+ # decodes are scheduled together. In the first step, all the
335
+ # prefills turn into decodes. This update reflects that
336
+ # conversion.
337
+ assert self.num_decode_tokens + self.num_prefills == num_seqs
338
+ self.num_decode_tokens += self.num_prefills
339
+ self.num_prefills = 0
340
+ self.num_prefill_tokens = 0
341
+ self.max_prefill_seq_len = 0
342
+ self.max_query_len = 1
343
+
344
+ self.slot_mapping = self.slot_mapping[:num_seqs]
345
+ else:
346
+ assert self.seq_lens is not None
347
+ assert self.max_decode_seq_len == max(self.seq_lens)
348
+
349
+ assert self.num_prefills == 0
350
+ assert self.num_prefill_tokens == 0
351
+ assert self.num_decode_tokens == num_seqs
352
+ assert self.slot_mapping.shape == (num_seqs, )
353
+
354
+ assert self.seq_lens is not None
355
+ assert len(self.seq_lens) == num_seqs
356
+ assert self.seq_lens_tensor is not None
357
+ assert self.seq_lens_tensor.shape == (num_seqs, )
358
+ assert self.max_query_len == 1
359
+ assert self.max_prefill_seq_len == 0
360
+
361
+ assert self.block_tables is not None
362
+ assert self.block_tables.shape[0] == num_seqs
363
+
364
+ # Update query lengths. Note that we update only queries and not seqs,
365
+ # since tensors may be padded due to captured cuda graph batch size
366
+ for i in range(num_queries):
367
+ self.seq_lens[i] += 1
368
+ self.max_decode_seq_len = max(self.seq_lens)
369
+ if enable_custom_op():
370
+ # Advance a step on NPU for existing inputs for a multi-step runner if custom ops are enabled
371
+ torch.ops._C.advance_step_flashattn_ascendc(
372
+ num_seqs=num_seqs,
373
+ num_queries=num_queries,
374
+ block_size=block_size,
375
+ input_tokens=model_input.input_tokens,
376
+ sampled_token_ids=sampled_token_ids,
377
+ input_positions=model_input.input_positions,
378
+ seq_lens=self.seq_lens_tensor,
379
+ slot_mapping=self.slot_mapping,
380
+ block_tables=self.block_tables)
381
+ else:
382
+ # Use the traditional PyTorch method for updating these tensors.
383
+ # update input_tokens
384
+ sampled_token_ids_list = sampled_token_ids[:
385
+ num_queries].squeeze( # type: ignore
386
+ -1)
387
+ model_input.input_tokens[:
388
+ num_queries] = sampled_token_ids_list # type: ignore
389
+
390
+ # get seq_lens and input_positions
391
+ seq_lens = self.seq_lens_tensor[:num_queries]
392
+ next_seq_lens = seq_lens + 1
393
+ next_input_pos = next_seq_lens - 1
394
+
395
+ # update seq_lens and input_positions
396
+ self.seq_lens_tensor[:num_queries] = next_seq_lens
397
+ model_input.input_positions[:
398
+ num_queries] = next_input_pos # type: ignore
399
+
400
+ # Compute block index and offset
401
+ block_idx = next_input_pos // block_size
402
+ block_offset = next_input_pos % block_size
403
+
404
+ current_block_table = self.block_tables.gather(
405
+ 1, block_idx.unsqueeze(-1)).squeeze(-1)
406
+ slot_num = current_block_table * block_size + block_offset
407
+
408
+ # update slot_mapping
409
+ self.slot_mapping[:num_queries] = slot_num
410
+
411
+
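The eager fallback above derives new KV-cache slots with plain tensor arithmetic. A small self-contained sketch of that computation, using toy values that are not taken from the file:

import torch

block_size = 4
next_input_pos = torch.tensor([5, 9])                 # next token position per decode request
block_tables = torch.tensor([[2, 7, 0], [3, 1, 6]])   # hypothetical physical block ids

block_idx = next_input_pos // block_size               # logical block holding the new token
block_offset = next_input_pos % block_size             # offset inside that block
current_block = block_tables.gather(1, block_idx.unsqueeze(-1)).squeeze(-1)
slot_num = current_block * block_size + block_offset
print(slot_num)  # tensor([29, 25]) -> physical KV-cache slots to write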
412
+ class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]):
413
+
414
+ _attn_mask_builder = None # noqa
415
+
416
+ def __init__(self, input_builder: "ModelInputForNPUBuilder"):
417
+ self.input_builder = input_builder
418
+ self.runner = input_builder.runner
419
+ self.sliding_window = input_builder.sliding_window
420
+ self.block_size = input_builder.block_size
421
+
422
+ self.attn_mask = None
423
+ self.compress_mask = None
424
+ self.chunk_mask = None
425
+ if AscendMetadataBuilder._attn_mask_builder is None:
426
+ AscendMetadataBuilder._attn_mask_builder = AttentionMaskBuilder(
427
+ 128, self.input_builder.runner.model_config.dtype)
428
+
429
+ def _add_seq_group(
430
+ self, inter_data: ModelInputForNPUBuilder.InterDataForSeqGroup,
431
+ chunked_prefill_enabled: bool):
432
+ """Add a sequence group to the metadata. Specifically update/append
433
+ 1. context length.
434
+ 2. block table.
435
+ 3. slot mapping.
436
+ """
437
+ is_prompt = inter_data.is_prompt
438
+ block_tables = inter_data.block_tables
439
+
440
+ for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
441
+ curr_sliding_window_block) in zip(
442
+ inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
443
+ inter_data.orig_seq_lens, inter_data.seq_lens,
444
+ inter_data.query_lens, inter_data.context_lens,
445
+ inter_data.curr_sliding_window_blocks):
446
+ self.context_lens.append(context_len)
447
+ if is_prompt:
448
+ self.num_prefills += 1
449
+ self.num_prefill_tokens += token_len
450
+ self.prefill_seq_lens.append(seq_len)
451
+ else:
452
+ self.num_decode_tokens += query_len
453
+ self.curr_seq_lens.append(curr_seq_len)
454
+
455
+ # Compute block table.
456
+ # TODO(sang): Combine chunked prefill and prefix caching by
457
+ # only allowing multiple of block_size chunk size.
458
+ # NOTE: This only works for oooooooxxx style attention.
459
+ block_table: List[int] = []
460
+ prefix_cache_hit = any([
461
+ inter_data.prefix_cache_hit
462
+ for inter_data in self.input_builder.inter_data_list
463
+ ])
464
+ if prefix_cache_hit:
465
+ # NOTE(woosuk): For flash-attn, the block table should
466
+ # include the entries for the incoming prefill tokens.
467
+ if block_tables is not None:
468
+ block_table = block_tables[seq_id]
469
+ elif ((chunked_prefill_enabled or not is_prompt)
470
+ and block_tables is not None):
471
+ if curr_sliding_window_block == 0:
472
+ block_table = block_tables[seq_id]
473
+ else:
474
+ block_table = block_tables[seq_id][
475
+ -curr_sliding_window_block:]
476
+ self.block_tables.append(block_table)
477
+
478
+ # Compute slot mapping.
479
+ is_profile_run = is_block_tables_empty(block_tables)
480
+ start_idx = compute_slot_mapping_start_idx(is_prompt, query_len,
481
+ context_len,
482
+ self.sliding_window)
483
+ compute_slot_mapping(
484
+ is_profile_run,
485
+ self.slot_mapping,
486
+ seq_id,
487
+ seq_len,
488
+ context_len,
489
+ start_idx,
490
+ self.block_size,
491
+ inter_data.block_tables,
492
+ )
493
+
494
+ def _get_graph_runner_block_tables(
495
+ self, num_seqs: int,
496
+ block_tables: List[List[int]]) -> torch.Tensor:
497
+ # The shape of graph_block_tables is
498
+ # [max batch size, max context len // block size].
499
+
500
+ max_batch_size, max_blocks = self.runner.graph_block_tables.shape
501
+ assert max_batch_size >= num_seqs
502
+
503
+ graph_block_tables = self.runner.graph_block_tables # [:num_seqs]
504
+ for i, block_table in enumerate(block_tables):
505
+ if block_table:
506
+ num_blocks = len(block_table)
507
+ if num_blocks <= max_blocks:
508
+ graph_block_tables[i, :num_blocks] = block_table
509
+ else:
510
+ graph_block_tables[
511
+ i, :max_blocks] = block_table[:max_blocks]
512
+
513
+ return torch.from_numpy(graph_block_tables).to(
514
+ device=self.runner.device, non_blocking=True)
515
+
516
+ def build(
517
+ self,
518
+ seq_lens: List[int],
519
+ query_lens: List[int],
520
+ graph_pad_size: int,
521
+ ):
522
+ """Build attention metadata with on-device tensors.
523
+
524
+ Args:
525
+ seq_lens: The maybe padded sequence lengths of the input sequences.
526
+ query_lens: The query lengths of the input sequences.
527
+ """
528
+ for inter_data in self.input_builder.inter_data_list:
529
+ self._add_seq_group(inter_data,
530
+ self.input_builder.chunked_prefill_enabled)
531
+
532
+ device = self.runner.device
533
+ dtype = self.runner.model_config.dtype
534
+ use_npu_graph = graph_pad_size != -1
535
+
536
+ max_query_len = max(query_lens)
537
+ max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
538
+ max_decode_seq_len = max(self.curr_seq_lens, default=0)
539
+ max_seq_len = max(max_prefill_seq_len, max_decode_seq_len)
540
+ num_decode_tokens = self.num_decode_tokens
541
+
542
+ if self.num_prefills == 0 and use_npu_graph:
543
+ num_seqs = len(seq_lens)
544
+ self.slot_mapping.extend([PAD_SLOT_ID] * graph_pad_size)
545
+ self.block_tables.extend([[]] * graph_pad_size)
546
+ block_tables = self._get_graph_runner_block_tables(
547
+ num_seqs, self.block_tables)
548
+ else:
549
+ block_tables = make_tensor_with_pad(
550
+ self.block_tables,
551
+ pad=0,
552
+ dtype=torch.int32,
553
+ device=device,
554
+ )
555
+
556
+ if self.num_prefills > 0:
557
+ if block_tables is None or block_tables.numel() == 0:
558
+ # normal mask
559
+ self.attn_mask = AscendMetadataBuilder._attn_mask_builder.get_attn_mask( # type: ignore
560
+ max_prefill_seq_len, dtype, device)
561
+ if is_310p():
562
+ mask_nz = nd_to_nz_2d(self.attn_mask)
563
+ mask_nz = torch_npu.npu_format_cast(
564
+ mask_nz.contiguous(), ACL_FORMAT_FRACTAL_NZ)
565
+ self.attn_mask = mask_nz
566
+ elif self.num_decode_tokens == 0 and not self.input_builder.chunked_prefill_enabled:
567
+ # compress mask for prefix cache
568
+ self.compress_mask = AscendMetadataBuilder._attn_mask_builder.get_attn_mask( # type: ignore
569
+ 128, dtype, device)
570
+ else:
571
+ # chunk_mask for chunk prefill
572
+ attn_mask = AscendMetadataBuilder._attn_mask_builder.get_attn_mask( # type: ignore
573
+ max_seq_len, dtype, device)
574
+ if attn_mask.numel() > 1 and attn_mask[0][1] > 0:
575
+ # Do not use in-place multiplication to avoid modifying `attn_mask_cache`!
576
+ attn_mask = attn_mask * -10000
577
+ chunk_mask_list = []
578
+ for i, seq_len in enumerate(seq_lens):
579
+ context_len = self.context_lens[i]
580
+ chunk_mask_list.append(attn_mask[context_len:seq_len])
581
+ self.chunk_mask = torch.cat(chunk_mask_list, 0)
582
+ else:
583
+ self.attn_mask = None
584
+ self.compress_mask = None
585
+ self.chunk_mask = None
586
+
587
+ assert max_query_len > 0, "query_lens: {}".format(query_lens)
588
+
589
+ assert device is not None
590
+ slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.int32,
591
+ device, self.runner.pin_memory)
592
+ seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
593
+ self.runner.pin_memory)
594
+ placeholder_index_maps = {
595
+ modality: placeholder_map.index_map()
596
+ for modality, placeholder_map in
597
+ self.multimodal_placeholder_maps.items()
598
+ }
599
+
600
+ return AscendMetadata(
601
+ num_prefills=self.num_prefills,
602
+ slot_mapping=slot_mapping_tensor,
603
+ num_prefill_tokens=self.num_prefill_tokens,
604
+ num_decode_tokens=num_decode_tokens,
605
+ seq_lens=seq_lens,
606
+ multi_modal_placeholder_index_maps=placeholder_index_maps,
607
+ enable_kv_scales_calculation=True,
608
+ seq_lens_tensor=seq_lens_tensor,
609
+ query_lens=query_lens,
610
+ max_query_len=max_query_len,
611
+ max_prefill_seq_len=max_prefill_seq_len,
612
+ max_decode_seq_len=max_decode_seq_len,
613
+ block_tables=block_tables,
614
+ attn_mask=self.attn_mask,
615
+ compress_mask=self.compress_mask,
616
+ chunk_mask=self.chunk_mask,
617
+ chunked_prefill_enabled=self.input_builder.chunked_prefill_enabled,
618
+ )
619
+
620
+
621
+ class AscendAttentionBackendImpl(AttentionImpl):
622
+
623
+ def __init__(
624
+ self,
625
+ num_heads: int,
626
+ head_size: int,
627
+ scale: float,
628
+ num_kv_heads: int,
629
+ alibi_slopes: Optional[List[float]],
630
+ sliding_window: Optional[int],
631
+ kv_cache_dtype: str,
632
+ blocksparse_params: Optional[Dict[str, Any]] = None,
633
+ logits_soft_cap: Optional[float] = None,
634
+ attn_type: str = AttentionType.DECODER,
635
+ kv_sharing_target_layer_name: Optional[str] = None,
636
+ use_irope: bool = False,
637
+ ) -> None:
638
+ self.num_heads = num_heads
639
+ self.head_size = head_size
640
+ self.scale = float(scale)
641
+ self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
642
+ self.hidden_size = self.num_heads * self.head_size
643
+ self.kv_cache_dtype = kv_cache_dtype
644
+ self.sliding_window = sliding_window
645
+ if alibi_slopes is not None:
646
+ alibi_slopes = torch.tensor(alibi_slopes,
647
+ dtype=torch.float32,
648
+ device="npu")
649
+ self.alibi_slopes = alibi_slopes
650
+ self.attn_type = attn_type
651
+
652
+ assert self.num_heads % self.num_kv_heads == 0
653
+ self.num_queries_per_kv = self.num_heads // self.num_kv_heads
654
+ self.seq_len_cpu_tensor = None
655
+ self.query_len_cpu_tensor = None
656
+ self.key_cache = None
657
+ self.value_cache = None
658
+
659
+ def forward(
660
+ self,
661
+ layer: AttentionLayer,
662
+ query: torch.Tensor,
663
+ key: torch.Tensor,
664
+ value: torch.Tensor,
665
+ kv_cache: torch.Tensor,
666
+ attn_metadata: AscendMetadata,
667
+ attn_type: str = AttentionType.DECODER,
668
+ output: Optional[torch.Tensor] = None,
669
+ ) -> torch.Tensor:
670
+ """Forward pass with Ascend attention.
671
+ Args:
672
+ query: shape = [num_tokens, num_heads * head_size]
673
+ num_tokens = batch_size * seq_len
674
+ key: shape = [num_tokens, num_kv_heads * head_size]
675
+ value: shape = [num_tokens, num_kv_heads * head_size]
676
+ kv_cache: shape = [2, num_blocks, block_size,
677
+ num_kv_heads, head_size]
678
+ key_cache = [num_blocks, block_size,
679
+ num_kv_heads, head_size]
680
+ value_cache = [num_blocks, block_size,
681
+ num_kv_heads, head_size]
682
+ attn_metadata: Metadata for attention.
683
+ Returns:
684
+ shape = [batch_size, seq_len * num_heads * head_size]
685
+ """
686
+ assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
687
+ # View q k v to BSH.
688
+ num_tokens = query.shape[0]
689
+ query = query.view(-1, self.num_heads, self.head_size)
690
+ key = key.view(-1, self.num_kv_heads, self.head_size)
691
+ value = value.view(-1, self.num_kv_heads, self.head_size)
692
+ # TODO: Remove this contiguous in the future.
693
+ value = value.contiguous()
694
+ attn_type = self.attn_type
695
+
696
+ output = torch.empty(num_tokens,
697
+ self.num_heads,
698
+ self.head_size,
699
+ dtype=query.dtype,
700
+ device=query.device)
701
+
702
+ if kv_cache.numel() > 0:
703
+ if self.key_cache is None:
704
+ self.key_cache, self.value_cache = kv_cache[0], kv_cache[1]
705
+ slots = attn_metadata.slot_mapping
706
+
707
+ if hasattr(layer, 'quant_method'):
708
+ isPrefill = attn_metadata.num_prefills > 0
709
+ if isPrefill:
710
+ assert attn_metadata.prefill_metadata is not None
711
+ self.seq_lens_tensor_cpu = torch.from_numpy(
712
+ np.array(attn_metadata.prefill_metadata.seq_lens).astype(
713
+ np.int32))
714
+ else:
715
+ assert attn_metadata.decode_metadata is not None
716
+ self.seq_lens_tensor_cpu = torch.from_numpy(
717
+ np.array(attn_metadata.decode_metadata.seq_lens).astype(
718
+ np.int32))
719
+ block_tables = attn_metadata.decode_metadata.block_tables if attn_metadata.decode_metadata else None
720
+ # Details of kv_cache arrangement in attention quantization
721
+ # are implemented by quant_method.
722
+ layer.quant_method.apply(
723
+ layer,
724
+ query,
725
+ key,
726
+ value,
727
+ self.key_cache,
728
+ self.value_cache,
729
+ self.scale,
730
+ block_tables,
731
+ isPrefill,
732
+ attn_metadata,
733
+ output,
734
+ seq_lens_tensor_cpu=self.seq_lens_tensor_cpu)
735
+ else:
736
+ if self.key_cache is not None:
737
+ torch_npu._npu_reshape_and_cache(key=key,
738
+ value=value,
739
+ key_cache=self.key_cache,
740
+ value_cache=self.value_cache,
741
+ slot_indices=slots)
742
+
743
+ if attn_metadata.num_prefills > 0:
744
+ # Prefix cache disabled and chunked prefill disabled, or no prefix cache hit
745
+ if (attn_metadata.block_tables is None
746
+ or attn_metadata.block_tables.numel() == 0):
747
+ if attn_type == AttentionType.ENCODER_ONLY:
748
+ # TODO: change to use torch_npu encoder attention op, instead
749
+ # of torch sdpa
750
+ query = query.movedim(0, query.dim() - 2)
751
+ key = key.movedim(0, key.dim() - 2)
752
+ value = value.movedim(0, value.dim() - 2)
753
+
754
+ causal_attn = (attn_type == AttentionType.DECODER)
755
+ if attn_metadata.seq_lens is not None:
756
+ seq_lens_q = seq_lens_kv = attn_metadata.seq_lens
757
+ attn_masks = [None] * len(seq_lens_q)
758
+ start_q, start_kv = 0, 0
759
+ for seq_len_q, seq_len_kv, mask in zip(
760
+ seq_lens_q, seq_lens_kv, attn_masks):
761
+ end_q = start_q + seq_len_q
762
+ end_kv = start_kv + seq_len_kv
763
+ sub_out = scaled_dot_product_attention(
764
+ query[None, :, start_q:end_q, :],
765
+ key[None, :, start_kv:end_kv, :],
766
+ value[None, :, start_kv:end_kv, :],
767
+ attn_mask=mask,
768
+ dropout_p=0.0,
769
+ is_causal=causal_attn and mask is None,
770
+ scale=self.scale).squeeze(0).movedim(
771
+ query.dim() - 2, 0)
772
+ output[start_q:end_q, :, :] = sub_out
773
+ start_q, start_kv = end_q, end_kv
774
+ else:
775
+ assert attn_metadata.attn_mask is not None
776
+ mask = attn_metadata.attn_mask
777
+ assert attn_metadata.prefill_metadata is not None
778
+ self.seq_lens_tensor_cpu = torch.from_numpy(
779
+ np.array(attn_metadata.prefill_metadata.seq_lens).
780
+ astype(np.int32))
781
+ if is_310p():
782
+ # align q k v output tensors
783
+ query = aligned_16(query)
784
+ key = aligned_16(key)
785
+ value = aligned_16(value)
786
+ output = aligned_16(output)
787
+
788
+ # do reformat in case of broadcasted tensors
789
+ mask = mask.repeat(
790
+ self.seq_lens_tensor_cpu.size(0), 1, 1, 1)
791
+ mask = torch_npu.npu_format_cast(
792
+ mask.contiguous(), ACL_FORMAT_FRACTAL_NZ)
793
+ torch_npu._npu_flash_attention(
794
+ query=query,
795
+ key=key,
796
+ value=value,
797
+ mask=mask,
798
+ seq_len=self.seq_lens_tensor_cpu,
799
+ scale_value=self.scale,
800
+ num_heads=self.num_heads,
801
+ num_kv_heads=self.num_kv_heads,
802
+ out=output)
803
+ output = output[:num_tokens, :, :]
804
+ # Prefix cache only and cache hit
805
+ elif attn_metadata.num_decode_tokens == 0 and not attn_metadata.chunked_prefill_enabled:
806
+ assert kv_cache is not None
807
+ assert attn_metadata.prefill_metadata is not None
808
+ self.seq_lens_tensor_cpu = torch.from_numpy(
809
+ np.array(
810
+ attn_metadata.prefill_metadata.seq_lens).astype(
811
+ np.int32))
812
+ self.query_lens_tensor_cpu = torch.from_numpy(
813
+ np.array(
814
+ attn_metadata.prefill_metadata.query_lens).astype(
815
+ np.int32))
816
+ block_tables = attn_metadata.prefill_metadata.block_tables
817
+ assert attn_metadata.compress_mask is not None
818
+ compress_mask = attn_metadata.compress_mask
819
+ torch_npu._npu_flash_attention_qlens(
820
+ query=query,
821
+ key_cache=self.key_cache,
822
+ value_cache=self.value_cache,
823
+ block_table=block_tables,
824
+ mask=compress_mask,
825
+ seq_len=self.query_lens_tensor_cpu,
826
+ context_lens=self.seq_lens_tensor_cpu,
827
+ num_kv_heads=self.num_kv_heads,
828
+ num_heads=self.num_heads,
829
+ scale_value=self.scale,
830
+ out=output)
831
+ # Splitfuse
832
+ else:
833
+ assert kv_cache is not None
834
+ self.seq_lens_tensor_cpu = torch.from_numpy(
835
+ np.array(attn_metadata.seq_lens).astype(np.int32))
836
+ self.query_lens_tensor_cpu = torch.from_numpy(
837
+ np.array(attn_metadata.query_lens).astype(np.int32))
838
+ block_tables = attn_metadata.block_tables
839
+ assert attn_metadata.chunk_mask is not None
840
+ chunk_mask = attn_metadata.chunk_mask
841
+ torch_npu._npu_paged_attention_splitfuse(
842
+ query=query,
843
+ key_cache=self.key_cache,
844
+ value_cache=self.value_cache,
845
+ block_table=block_tables,
846
+ context_lens=self.seq_lens_tensor_cpu,
847
+ mask=chunk_mask,
848
+ seq_len=self.query_lens_tensor_cpu,
849
+ num_kv_heads=self.num_kv_heads,
850
+ num_heads=self.num_heads,
851
+ scale_value=self.scale,
852
+ out=output)
853
+ # Decode only
854
+ else:
855
+ assert self.key_cache is not None
856
+ assert self.value_cache is not None
857
+ assert attn_metadata.decode_metadata is not None
858
+ self.seq_lens_tensor_cpu = torch.from_numpy(
859
+ np.array(attn_metadata.decode_metadata.seq_lens).astype(
860
+ np.int32))
861
+ if is_310p():
862
+ # seq_lens_tensor needs to be transferred to the device for 310P
863
+ self.seq_lens_tensor_cpu = self.seq_lens_tensor_cpu.to(
864
+ device=self.key_cache.device)
865
+ block_tables = attn_metadata.decode_metadata.block_tables
866
+ torch_npu._npu_paged_attention(
867
+ query=query,
868
+ key_cache=self.key_cache,
869
+ value_cache=self.value_cache,
870
+ num_kv_heads=self.num_kv_heads,
871
+ num_heads=self.num_heads,
872
+ scale_value=self.scale,
873
+ block_table=block_tables,
874
+ context_lens=self.seq_lens_tensor_cpu,
875
+ out=output)
876
+
877
+ return output.view(num_tokens, self.hidden_size)
878
+
879
+
880
+ class AscendMLAAttentionBackendImpl(MLAAttentionImpl):
881
+
882
+ def __init__(
883
+ self,
884
+ num_heads: int,
885
+ head_size: int,
886
+ scale: float,
887
+ num_kv_heads: int,
888
+ alibi_slopes: Optional[List[float]],
889
+ sliding_window: Optional[int],
890
+ kv_cache_dtype: str,
891
+ blocksparse_params: Optional[Dict[str, Any]] = None,
892
+ logits_soft_cap: Optional[float] = None,
893
+ attn_type: str = AttentionType.DECODER,
894
+ kv_sharing_target_layer_name: Optional[str] = None,
895
+ **extra_impl_args,
896
+ ) -> None:
897
+ self.num_heads = num_heads
898
+ self.head_size = head_size
899
+ self.scale = float(scale)
900
+ self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
901
+ self.hidden_size = self.num_heads * self.head_size
902
+ self.kv_cache_dtype = kv_cache_dtype
903
+ self.sliding_window = sliding_window
904
+ if alibi_slopes is not None:
905
+ alibi_slopes = torch.tensor(alibi_slopes,
906
+ dtype=torch.float32,
907
+ device="npu")
908
+ self.alibi_slopes = alibi_slopes
909
+ self.attn_type = attn_type
910
+
911
+ assert self.num_heads % self.num_kv_heads == 0
912
+ self.num_queries_per_kv = self.num_heads // self.num_kv_heads
913
+ self.seq_len_cpu_tensor = None
914
+
915
+ # MLA Args
916
+ self.q_lora_rank = extra_impl_args['q_lora_rank']
917
+ self.kv_lora_rank = extra_impl_args['kv_lora_rank']
918
+ self.qk_nope_head_dim = extra_impl_args['qk_nope_head_dim']
919
+ self.qk_rope_head_dim = extra_impl_args['qk_rope_head_dim']
920
+ self.qk_head_dim = extra_impl_args['qk_head_dim']
921
+ self.v_head_dim = extra_impl_args['v_head_dim']
922
+ self.rotary_emb = extra_impl_args['rotary_emb']
923
+ self.q_proj = extra_impl_args['q_proj']
924
+ self.kv_b_proj = extra_impl_args['kv_b_proj']
925
+ self.o_proj = extra_impl_args['o_proj']
926
+ self.kv_a_proj_with_mqa = extra_impl_args.get('kv_a_proj_with_mqa',
927
+ None)
928
+ self.kv_a_layernorm = extra_impl_args.get('kv_a_layernorm', None)
929
+ self.k_pe_cache = None
930
+ self.k_nope_cache = None
931
+ self.w_kc = None
932
+ self.w_vc = None
933
+
934
+ ascend_config = get_ascend_config()
935
+ self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
936
+
937
+
938
+ def exec_kv(
939
+ self,
940
+ hidden_states: torch.Tensor,
941
+ cos: torch.Tensor,
942
+ sin: torch.Tensor,
943
+ kv_cache: Tuple,
944
+ slots: torch.Tensor,
945
+ ):
946
+ B = hidden_states.shape[0]
947
+ N = self.num_kv_heads
948
+ S = 1
949
+ kv = self.kv_a_proj_with_mqa(hidden_states)[0]
950
+ # npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
951
+ kv = kv.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
952
+
953
+ k_pe, k_nope, _, _ = torch.ops.npu_inference.npu_kv_rmsnorm_rope_cache(
954
+ kv,
955
+ self.kv_a_layernorm.weight,
956
+ cos,
957
+ sin,
958
+ slots.to(torch.int64),
959
+ kv_cache[1],
960
+ kv_cache[0],
961
+ epsilon=self.kv_a_layernorm.variance_epsilon,
962
+ cache_mode="PA",
963
+ )
964
+
965
+ return k_pe, k_nope
966
+
967
+ def apply_rotary_emb(
968
+ self,
969
+ x: torch.Tensor,
970
+ cos: torch.Tensor,
971
+ sin: torch.Tensor,
972
+ is_neox_style: bool,
973
+ ) -> torch.Tensor:
974
+ """
975
+ Args:
976
+ x: [num_tokens, num_heads, head_size]
977
+ cos: [num_tokens, head_size // 2]
978
+ sin: [num_tokens, head_size // 2]
979
+ is_neox_style: Whether to use the Neox-style or GPT-J-style rotary
980
+ positional embeddings.
981
+ """
982
+ cos = cos.unsqueeze(-2).to(x.dtype)
983
+ sin = sin.unsqueeze(-2).to(x.dtype)
984
+ if is_neox_style:
985
+ x1, x2 = torch.chunk(x, 2, dim=-1)
986
+ else:
987
+ x1 = x[..., ::2]
988
+ x2 = x[..., 1::2]
989
+ o1 = x1 * cos - x2 * sin
990
+ o2 = x2 * cos + x1 * sin
991
+ if is_neox_style:
992
+ return torch.cat((o1, o2), dim=-1)
993
+ else:
994
+ return torch.stack((o1, o2), dim=-1).flatten(-2)
995
+
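A toy numeric sketch (tensors invented for illustration) of the two layouts handled by apply_rotary_emb above: Neox-style rotates pairs formed by splitting the head dimension in half, while GPT-J-style rotates interleaved even/odd pairs.

import torch

x = torch.tensor([1.0, 2.0, 3.0, 4.0])    # one head vector, head_size = 4
cos = torch.tensor([0.0, 1.0])             # hypothetical cos per rotation pair
sin = torch.tensor([1.0, 0.0])             # hypothetical sin per rotation pair

# Neox-style: pair (x[0], x[2]) and (x[1], x[3]) via a half split.
x1, x2 = x.chunk(2)
neox = torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin))

# GPT-J-style: pair (x[0], x[1]) and (x[2], x[3]) via even/odd interleaving.
x1, x2 = x[::2], x[1::2]
gptj = torch.stack((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1).flatten()

print(neox)  # tensor([-3., 2., 1., 4.])
print(gptj)  # tensor([-2., 1., 3., 4.])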
996
+ def rope_single(
997
+ self,
998
+ x: torch.Tensor,
999
+ cos: torch.Tensor,
1000
+ sin: torch.Tensor,
1001
+ ) -> torch.Tensor:
1002
+ B, N, D = x.shape
1003
+ S = 1
1004
+ x = x.view(B, N, S, D)
1005
+ x = torch.ops.npu_inference.npu_interleave_rope(x, cos, sin)
1006
+ return x.view(B, N, D)
1007
+
1008
+ def process_weights_after_loading(self, act_dtype: torch.dtype):
1009
+ if self.w_kc is None or self.w_vc is None:
1010
+ kv_b_proj_weight = self.kv_b_proj.weight.reshape(
1011
+ self.num_heads, self.qk_nope_head_dim + self.v_head_dim,
1012
+ self.kv_lora_rank)
1013
+ self.w_kc = kv_b_proj_weight[:, :self.
1014
+ qk_nope_head_dim, :].contiguous()
1015
+ self.w_vc = kv_b_proj_weight[:,
1016
+ self.qk_nope_head_dim:, :].transpose(
1017
+ 1, 2).contiguous()
1018
+
1019
+ def forward(
1020
+ self,
1021
+ layer: AttentionLayer,
1022
+ hidden_states_or_q_c: torch.Tensor,
1023
+ hidden_states_or_kv_c_normed: torch.Tensor,
1024
+ k_pe: torch.Tensor,
1025
+ kv_cache: torch.Tensor,
1026
+ attn_metadata: AscendMetadata,
1027
+ attn_type: str = AttentionType.DECODER,
1028
+ output: Optional[torch.Tensor] = None,
1029
+ ) -> torch.Tensor:
1030
+ """Forward pass with Ascend attention.
1031
+ Args:
1032
+ hidden_states_or_q_c: shape = [num_tokens, num_heads * head_size]
1033
+ num_tokens = batch_size * seq_len
1034
+ hidden_states_or_kv_c_normed: shape = [num_tokens, num_kv_heads * head_size]
1035
+ k_pe: shape = [num_tokens, num_kv_heads * head_size]
1036
+ kv_cache: shape = [1, num_blocks, block_size,
1037
+ num_kv_heads * head_size]
1038
+ attn_metadata: Metadata for attention.
1039
+ Returns:
1040
+ shape = [batch_size, seq_len * num_heads * head_size]
1041
+ """
1042
+ assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
1043
+ attn_type = self.attn_type
1044
+ if attn_type != AttentionType.DECODER:
1045
+ raise NotImplementedError("Encoder self-attention and "
1046
+ "encoder/decoder cross-attention "
1047
+ "are not implemented for "
1048
+ "PallasAttentionBackendImpl")
1049
+
1050
+ if attn_metadata is None:
1051
+ # for profile run
1052
+ return hidden_states_or_q_c
1053
+
1054
+ num_tokens = hidden_states_or_q_c.shape[0]
1055
+ q = self.q_proj(hidden_states_or_q_c)[0].view(-1, self.num_heads,
1056
+ self.qk_head_dim)
1057
+ q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim],
1058
+ dim=-1)
1059
+ if k_pe is None and attn_metadata.decode_metadata:
1060
+ seq_len = self.rotary_emb.max_position_embeddings
1061
+
1062
+ cos = self.rotary_emb.cos_cached[:seq_len].to(dtype=q_pe.dtype)
1063
+ sin = self.rotary_emb.sin_cached[:seq_len].to(dtype=q_pe.dtype)
1064
+ cos = cos[attn_metadata.input_positions]
1065
+ sin = sin[attn_metadata.input_positions]
1066
+ cos = cos[:, None, None, :]
1067
+ sin = sin[:, None, None, :]
1068
+
1069
+ q_pe = self.rope_single(q_pe, cos, sin)
1070
+ k_pe, k_nope = self.exec_kv(hidden_states_or_kv_c_normed, cos, sin,
1071
+ kv_cache, attn_metadata.slot_mapping)
1072
+ else:
1073
+ if k_pe is None:
1074
+ # NOTE: k_pe is None when graph mode enabled
1075
+ kv_c, k_pe = self.kv_a_proj_with_mqa(
1076
+ hidden_states_or_kv_c_normed)[0].split(
1077
+ [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
1078
+ kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
1079
+ else:
1080
+ kv_c_normed = hidden_states_or_kv_c_normed
1081
+ k_pe = k_pe.view(num_tokens, self.num_kv_heads, -1)
1082
+ if self.rotary_emb.__class__.__name__ == 'RotaryEmbedding':
1083
+ # NOTE: Taken when no RoPE scaling is specified
1084
+ ori_q_pe_shape, ori_k_pe_shape = q_pe.shape, k_pe.shape
1085
+ q_pe = q_pe.reshape(num_tokens, -1)
1086
+ k_pe = k_pe.reshape(num_tokens, -1)
1087
+ q_pe, k_pe = self.rotary_emb(attn_metadata.input_positions,
1088
+ q_pe, k_pe)
1089
+ q_pe = q_pe.view(ori_q_pe_shape)
1090
+ k_pe = k_pe.view(ori_k_pe_shape)
1091
+ else:
1092
+ q_pe, k_pe = self.rotary_emb(attn_metadata.input_positions,
1093
+ q_pe, k_pe)
1094
+
1095
+ if attn_metadata.num_prefills > 0:
1096
+ kv = self.kv_b_proj(kv_c_normed)[0].view(num_tokens,
1097
+ self.num_heads, -1)
1098
+ k_nope, value = kv.split([self.qk_nope_head_dim, self.v_head_dim],
1099
+ dim=-1)
1100
+ else:
1101
+ q_nope_t = torch.transpose(q_nope, 0, 1)
1102
+ q_nope_out = torch.bmm(q_nope_t, self.w_kc)
1103
+ q_nope = torch.transpose(q_nope_out, 0, 1)
1104
+
1105
+ query = torch.cat([q_nope, q_pe], dim=-1).view(num_tokens,
1106
+ self.num_heads, -1)
1107
+
1108
+ # TODO: Replace the env with more flexible expressions
1109
+ if self.torchair_graph_enabled:
1110
+ if len(kv_cache) > 0 and kv_cache[0].numel(
1111
+ ) > 0 and attn_metadata.num_prefills > 0:
1112
+ slots = attn_metadata.slot_mapping
1113
+ # NOTE: Separate the kv cache in advance to avoid OOM or other issues
1114
+ torch_npu._npu_reshape_and_cache(key=kv_c_normed.view(
1115
+ num_tokens, self.num_kv_heads, -1),
1116
+ value=k_pe,
1117
+ key_cache=kv_cache[0],
1118
+ value_cache=kv_cache[1],
1119
+ slot_indices=slots)
1120
+ elif kv_cache.numel() > 0:
1121
+ # TODO replace this naive implement with fusion kernel
1122
+ concat_and_cache_mla(kv_c_normed, k_pe, kv_cache,
1123
+ attn_metadata.slot_mapping)
1124
+
1125
+ if attn_metadata.num_prefills > 0:
1126
+ attn_output = torch.empty(num_tokens,
1127
+ self.num_heads,
1128
+ self.v_head_dim,
1129
+ dtype=query.dtype,
1130
+ device=query.device)
1131
+ if (attn_metadata.block_tables is None
1132
+ or attn_metadata.block_tables.numel() == 0):
1133
+ assert attn_metadata.attn_mask is not None
1134
+ assert attn_metadata.prefill_metadata is not None
1135
+ assert attn_metadata.prefill_metadata.seq_lens is not None
1136
+ mask = attn_metadata.attn_mask
1137
+ self.seq_lens_tensor_cpu = torch.from_numpy(
1138
+ np.array(attn_metadata.prefill_metadata.seq_lens).astype(
1139
+ np.int32))
1140
+ k_pe = k_pe.repeat(1, self.num_heads, 1)
1141
+ key = torch.cat(
1142
+ [k_nope.view(num_tokens, self.num_heads, -1), k_pe], dim=2)
1143
+ torch_npu._npu_flash_attention(
1144
+ query=query,
1145
+ key=key,
1146
+ value=value,
1147
+ mask=mask,
1148
+ seq_len=self.seq_lens_tensor_cpu,
1149
+ scale_value=self.scale,
1150
+ num_heads=self.num_heads,
1151
+ num_kv_heads=self.num_heads,
1152
+ out=attn_output)
1153
+ else:
1154
+ # TODO: Will support prefix cache and chunked prefill soon.
1155
+ raise RuntimeError(
1156
+ "Prefix cache and chunked prefill are currently not supported."
1157
+ )
1158
+ elif attn_metadata.decode_metadata:
1159
+ assert kv_cache is not None
1160
+ if self.torchair_graph_enabled:
1161
+ # shape of query for npu graph mode should be:
1162
+ # [bs, num_heads_per_rank, seq_len, dim]
1163
+ q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1)
1164
+ q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
1165
+ # shape of knope/k_pe for npu graph mode should be:
1166
+ # [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim]
1167
+ block_size = kv_cache[0].shape[1]
1168
+ k_nope = k_nope.view(-1, self.num_kv_heads, block_size,
1169
+ self.kv_lora_rank)
1170
+ k_pe = k_pe.view(-1, self.num_kv_heads, block_size,
1171
+ self.qk_rope_head_dim)
1172
+ attn_output, _ = torch.ops.npu.npu_fused_infer_attention_score(
1173
+ q_nope,
1174
+ k_nope,
1175
+ k_nope,
1176
+ query_rope=q_pe,
1177
+ key_rope=k_pe,
1178
+ num_heads=self.num_heads,
1179
+ num_key_value_heads=self.num_kv_heads,
1180
+ input_layout="BNSD",
1181
+ atten_mask=attn_metadata.attn_mask,
1182
+ scale=self.scale,
1183
+ antiquant_mode=0,
1184
+ antiquant_scale=None,
1185
+ block_table=attn_metadata.block_tables,
1186
+ block_size=block_size,
1187
+ actual_seq_lengths_kv=attn_metadata.seq_lens,
1188
+ )
1189
+ attn_output = attn_output.view(num_tokens, -1,
1190
+ self.kv_lora_rank).transpose(
1191
+ 0, 1)
1192
+ attn_output = torch.bmm(attn_output, self.w_vc).transpose(0, 1)
1193
+ else:
1194
+ # if torch.empty is used here, the preemptive scheduling case of
1195
+ # test_mtp_correctness.py will fail to run.
1196
+ attn_output = torch.randn(
1197
+ [num_tokens, self.num_heads, self.kv_lora_rank],
1198
+ dtype=query.dtype,
1199
+ device=query.device)
1200
+ self.seq_lens_tensor_cpu = torch.from_numpy(
1201
+ np.array(attn_metadata.decode_metadata.seq_lens).astype(
1202
+ np.int32))
1203
+ block_tables = attn_metadata.decode_metadata.block_tables
1204
+ torch_npu._npu_paged_attention_mla(
1205
+ query=query,
1206
+ key_cache=kv_cache,
1207
+ num_kv_heads=self.num_kv_heads,
1208
+ num_heads=self.num_heads,
1209
+ scale_value=self.scale,
1210
+ block_table=block_tables,
1211
+ context_lens=self.seq_lens_tensor_cpu,
1212
+ mla_vheadsize=self.kv_lora_rank,
1213
+ out=attn_output)
1214
+ attn_output_t = torch.transpose(attn_output, 0, 1)
1215
+ attn_output_t = torch.bmm(attn_output_t, self.w_vc)
1216
+ attn_output = torch.transpose(attn_output_t, 0, 1)
1217
+
1218
+ output, _ = self.o_proj(attn_output.reshape(num_tokens, -1))
1219
+
1220
+ return output
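Before the next file, a shape-only sketch of the MLA weight-absorption trick used in the decode branch above (w_kc and w_vc come from process_weights_after_loading); the dimensions below are toy values, not the model's real head sizes.

import torch

num_tokens, num_heads = 2, 4                          # toy sizes
qk_nope_head_dim, kv_lora_rank, v_head_dim = 8, 16, 8

q_nope = torch.randn(num_tokens, num_heads, qk_nope_head_dim)
w_kc = torch.randn(num_heads, qk_nope_head_dim, kv_lora_rank)    # absorbed K up-projection
w_vc = torch.randn(num_heads, kv_lora_rank, v_head_dim)          # absorbed V up-projection

# Absorb W_UK into the query so attention can run in the compressed KV space.
q_absorbed = torch.bmm(q_nope.transpose(0, 1), w_kc).transpose(0, 1)
assert q_absorbed.shape == (num_tokens, num_heads, kv_lora_rank)

# Attention output in the compressed space is projected back with w_vc before o_proj.
attn_out_latent = torch.randn(num_tokens, num_heads, kv_lora_rank)
attn_out = torch.bmm(attn_out_latent.transpose(0, 1), w_vc).transpose(0, 1)
assert attn_out.shape == (num_tokens, num_heads, v_head_dim)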
inference/vllm_ascend/attention/mla_v1.py ADDED
@@ -0,0 +1,1224 @@
1
+ from dataclasses import dataclass
2
+ from typing import TYPE_CHECKING, Any, Optional, Tuple, Type, TypeVar
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch_npu
7
+ from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
8
+ AttentionMetadata,
9
+ MLAAttentionImpl)
10
+ from vllm.attention.backends.utils import PAD_SLOT_ID
11
+ from vllm.config import get_current_vllm_config
12
+ from vllm.distributed import get_tensor_model_parallel_world_size
13
+ from vllm.model_executor.layers.linear import (LinearBase,
14
+ UnquantizedLinearMethod)
15
+ from vllm.utils import cdiv, round_down
16
+
17
+ from vllm_ascend.ascend_config import get_ascend_config
18
+ from vllm_ascend.attention.attention import _ALLOWED_NUM_QUERIES_PER_KV
19
+ from vllm_ascend.attention.attention_v1 import AscendAttentionState
20
+ from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
21
+ from vllm_ascend.multistream.context import get_multistream_comm_context
22
+ from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
23
+ from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
24
+ from vllm_ascend.utils import npu_prefetch, npu_stream_switch, npu_wait_tensor
25
+ from vllm_ascend.worker.npu_input_batch import InputBatch
26
+
27
+ if TYPE_CHECKING:
28
+ from vllm.v1.core.sched.output import SchedulerOutput
29
+
30
+
31
+ @dataclass
32
+ class CommonAttentionMetadata:
33
+ """
34
+ Attention metadata attributes that can be shared by layers in different KV
35
+ cache groups and thus have different block tables.
36
+ """
37
+
38
+ query_start_loc: torch.Tensor
39
+ """(batch_size + 1,), the start location of each request in query Tensor"""
40
+ seq_lens: torch.Tensor
41
+ """(batch_size,), the length of each request including both computed tokens
42
+ and newly scheduled tokens"""
43
+
44
+
45
+ class AscendMLABackend(AttentionBackend):
46
+
47
+ accept_output_buffer: bool = True
48
+
49
+ @staticmethod
50
+ def get_name() -> str:
51
+ return "VLLM_ASCEND_MLA"
52
+
53
+ @staticmethod
54
+ def get_metadata_cls() -> type["AttentionMetadata"]:
55
+ return AscendMLAMetadata
56
+
57
+ @staticmethod
58
+ def get_builder_cls():
59
+ return AscendMLAMetadataBuilder
60
+
61
+ @staticmethod
62
+ def get_kv_cache_shape(num_blocks: int, block_size: int, num_kv_heads: int,
63
+ head_size: int) -> tuple[int, ...]:
64
+ return (num_blocks, block_size, num_kv_heads, head_size)
65
+
66
+ @staticmethod
67
+ def get_impl_cls() -> Type["MLAAttentionImpl"]:
68
+ return AscendMLAImpl
69
+
70
+
71
+ @dataclass
72
+ class AscendMLAPrefillMetadata:
73
+ """ Prefill Specific Metadata for Ascend"""
74
+
75
+ @dataclass
76
+ class ChunkedContextMetadata:
77
+ # New for MLA (compared to FlashAttention)
78
+ # For handling chunked prefill
79
+ cu_seq_lens: torch.Tensor
80
+ starts: torch.Tensor
81
+ seq_tot: list[int]
82
+ max_seq_lens: list[int]
83
+ workspace: torch.Tensor
84
+ chunk_seq_lens: torch.Tensor
85
+
86
+ attn_mask: torch.Tensor
87
+ query_lens: list[int]
88
+ seq_lens: list[int]
89
+ context_lens: torch.Tensor
90
+ input_positions: torch.Tensor
91
+ query_start_loc: torch.Tensor
92
+ block_table: torch.Tensor
93
+ max_query_len: int
94
+ max_seq_lens: int
95
+ chunked_context: Optional[ChunkedContextMetadata] = None
96
+
97
+
98
+ @dataclass
99
+ class AscendMLADecodeMetadata:
100
+ # Input positions for rotary embeddings since for MLA the rotary
101
+ # position embeddings are applied inside the attention backend
102
+ input_positions: torch.Tensor
103
+ block_table: torch.Tensor
104
+ seq_lens: torch.Tensor
105
+ max_seq_lens: int
106
+ seq_lens_list: list[int]
107
+ attn_mask: Optional[torch.Tensor] = None
108
+
109
+
110
+ @dataclass
111
+ class AscendMLAMetadata:
112
+ """Metadata for MLACommon.
113
+
114
+ NOTE: Please read the comment at the top of the file before trying to
115
+ understand this class
116
+ """
117
+ # NOTE(sang): Definition of context_len, query_len, and seq_len.
118
+ # |---------- N-1 iteration --------|
119
+ # |---------------- N iteration ---------------------|
120
+ # |- tokenA -|......................|-- newTokens ---|
121
+ # |---------- context_len ----------|
122
+ # |-------------------- seq_len ---------------------|
123
+ # |-- query_len ---|
124
+
125
+ num_actual_tokens: int # Number of tokens excluding padding.
126
+ slot_mapping: torch.Tensor
127
+ query_start_loc: torch.Tensor
128
+ seq_lens: torch.Tensor
129
+ block_tables: torch.Tensor
130
+
131
+ # New for MLA (compared to FlashAttention)
132
+ # For handling prefill decode split
133
+ num_decodes: int
134
+ num_decode_tokens: int
135
+ num_prefills: int
136
+
137
+ # For logging.
138
+ num_input_tokens: int = 0 # Number of tokens including padding.
139
+
140
+ max_num_tokens_across_dp: int = 0
141
+ with_prefill_across_dp: bool = False
142
+
143
+ query_lens: Optional[list[int]] = None
144
+ # The dimension of the attention heads
145
+ head_dim: Optional[int] = None
146
+ attn_mask: torch.Tensor = None
147
+ # chunked prefill by default if no attn_states passed
148
+ attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
149
+
150
+ decode: Optional[AscendMLADecodeMetadata] = None
151
+ prefill: Optional[AscendMLAPrefillMetadata] = None
152
+
153
+ def __post_init__(self):
154
+ pass
155
+ # supported_head_sizes = AscendMLABackend.get_supported_head_sizes()
156
+ # if self.head_dim is not None and self.head_dim \
157
+ # not in supported_head_sizes:
158
+ # raise ValueError(
159
+ # f"Only {supported_head_sizes} are supported for head_dim,",
160
+ # f"received {self.head_dim}.")
161
+
162
+ def split_metadata_for_multistream(
163
+ self,
164
+ ms_split_config: MSAttentionMetadataSplitConfig,
165
+ ) -> list["AscendMLAMetadata"]:
166
+ """Split metadata for multi-stream with AscendMLAMetadata"""
167
+ return model_input_split_v1_mla_attn(
168
+ ms_split_config=ms_split_config,
169
+ attn_metadata=self,
170
+ _metadata_cls=AscendMLAMetadata,
171
+ )
172
+
173
+
174
+ M = TypeVar("M", bound=AscendMLAMetadata)
175
+
176
+
177
+ class AscendMLAMetadataBuilder:
178
+ """
179
+ NOTE: Please read the comment at the top of the file before trying to
180
+ understand this class
181
+ """
182
+
183
+ # _attn_mask_builder = None
184
+ def __init__(self,
185
+ runner,
186
+ metadata_cls: Optional[AscendMLAMetadata] = None):
187
+ self.metadata_cls: Optional[AscendMLAMetadata] = metadata_cls \
188
+ if metadata_cls is not None else AscendMLAMetadata # type: ignore
189
+ self.runner = runner
190
+ scheduler_config = runner.scheduler_config
191
+ model_config = runner.model_config
192
+ self.block_size = runner.block_size
193
+ self.chunked_prefill_enabled = runner.chunked_prefill_enabled
194
+ if self.chunked_prefill_enabled:
195
+ self.chunked_prefill_workspace_size = min(
196
+ # Make sure there is enough for 8 full-length requests or at least
197
+ # 4 pages of cache per request
198
+ max(8 * model_config.max_model_len,
199
+ 4 * scheduler_config.max_num_seqs * self.block_size),
200
+ # For long-context models, try not to over-allocate
201
+ # kv-cache space, limiting it to 64k tokens,
202
+ # which would result in the workspace being:
203
+ # 2*(576)*(64*1024) = 144mb
204
+ # (assuming 576 MLA head dim, and fp16)
205
+ # which would result in up-projected context being
206
+ # 2*(192*128)*(64*1024) = 3gb
207
+ # (assuming 192 QK head dim, 128 heads, and fp16)
208
+ 128 * 1024)
209
+ assert self.chunked_prefill_workspace_size >= \
210
+ scheduler_config.max_num_seqs * self.block_size
211
+ self.chunked_prefill_workspace = torch.empty(
212
+ (self.chunked_prefill_workspace_size,
213
+ model_config.get_head_size()),
214
+ dtype=model_config.dtype,
215
+ device=runner.device,
216
+ )
217
+ ascend_config = get_ascend_config()
218
+ self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
219
+
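The workspace sizing rule in __init__ above can be sanity-checked with a quick standalone calculation; the scheduler and model settings below are hypothetical.

max_model_len = 32 * 1024
max_num_seqs = 256
block_size = 128

workspace_tokens = min(
    max(8 * max_model_len,                     # 262144: room for 8 full-length requests
        4 * max_num_seqs * block_size),        # 131072: at least 4 pages per request
    128 * 1024,                                # hard cap from the code above
)
assert workspace_tokens == 128 * 1024
assert workspace_tokens >= max_num_seqs * block_size   # invariant asserted in __init__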
220
+ def reorder_batch(self, input_batch: "InputBatch",
221
+ scheduler_output: "SchedulerOutput") -> bool:
222
+ # We now want to reorder the batch so that the "decode" requests are at
223
+ # the front and the "prefill" requests are at the back, using the least amount
224
+ # of swaps possible. (NOTE for now we loosely use "decode" to mean requests
225
+ # where attention is likely memory-bound and "prefill" to mean requests
226
+ # where attention is likely compute-bound, TODO(lucas): figure out a
227
+ # better naming here)
228
+ decodes = []
229
+ prefills = []
230
+ num_decode_tokens = 0
231
+ num_prefill_tokens = 0
232
+
233
+ for i, req_id in enumerate(input_batch.req_ids):
234
+ num_tokens = scheduler_output.num_scheduled_tokens[req_id]
235
+ num_spec_tokens = len(
236
+ scheduler_output.scheduled_spec_decode_tokens.get(req_id, []))
237
+ # For torch air graph mode we treat spec decoding as decode.
238
+ if self.torchair_graph_enabled:
239
+ if num_tokens - num_spec_tokens == 1:
240
+ decodes.append(i)
241
+ num_decode_tokens += num_tokens
242
+ else:
243
+ prefills.append(i)
244
+ num_prefill_tokens += num_tokens
245
+ # For eager mode we treat spec decoding as chunked prefill.
246
+ else:
247
+ if num_tokens == 1:
248
+ decodes.append(i)
249
+ num_decode_tokens += num_tokens
250
+ else:
251
+ prefills.append(i)
252
+ num_prefill_tokens += num_tokens
253
+
254
+ # We hope that this is fairly minimal since decodes
255
+ # should be around for a number of iterations so hopefully they are
256
+ # relatively stationary (and new requests are generally appended to the
257
+ # persistent batch so already should be at the back)
258
+ # To achieve this we loop over the decodes in descending order and
259
+ # the prefills in ascending order. We swap decodes from the "back"
260
+ # i.e. past where the last decode should be in the reordered batch, with
261
+ # prefills from the front of the batch.
262
+ # `decodes` and `prefills` are already in ascending order just based on
263
+ # the above loop
264
+ num_decodes = len(decodes)
265
+ num_prefills = len(prefills)
266
+ first_prefill = 0
267
+ modified_batch = False
268
+
269
+ for i in range(1, min(num_decodes, num_prefills) + 1):
270
+ # If the decode is at the "back" of the batch, i, we can swap it
271
+ # with the prefill closest to the front of the batch
272
+ if decodes[num_decodes - i] >= num_decodes:
273
+ input_batch.swap_states(prefills[first_prefill],
274
+ decodes[num_decodes - i])
275
+ first_prefill += 1
276
+ modified_batch = True
277
+ else:
278
+ break
279
+
280
+ # Save for next `build` call
281
+ # TODO(lucas): this is a bit of a hack, we should probably have a
282
+ # better way of doing this
283
+ self._num_decodes = num_decodes
284
+ self._num_prefills = num_prefills
285
+ self._num_decode_tokens = num_decode_tokens
286
+ self._num_prefill_tokens = num_prefill_tokens
287
+
288
+ return modified_batch
289
+
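A toy walk-through (indices and token counts invented for illustration) of the swap logic in reorder_batch above: decode requests are moved to the front with as few swaps as possible.

num_scheduled = [7, 1, 1, 5, 1]       # tokens per request; 1 token == "decode"
decodes = [i for i, n in enumerate(num_scheduled) if n == 1]   # [1, 2, 4]
prefills = [i for i, n in enumerate(num_scheduled) if n > 1]   # [0, 3]

num_decodes, first_prefill, swaps = len(decodes), 0, []
for i in range(1, min(len(decodes), len(prefills)) + 1):
    # A decode sitting at or past index num_decodes must trade places with a prefill.
    if decodes[num_decodes - i] >= num_decodes:
        swaps.append((prefills[first_prefill], decodes[num_decodes - i]))
        first_prefill += 1
    else:
        break
print(swaps)  # [(0, 4)]: after swapping, decodes occupy indices 0-2 of the batch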
290
+ def _get_graph_runner_block_tables(
291
+ self, num_seqs: int, block_tables: torch.Tensor) -> torch.Tensor:
292
+
293
+ max_batch_size, max_blocks = self.runner.graph_block_tables.shape
294
+ assert max_batch_size >= num_seqs
295
+
296
+ if isinstance(self.runner.graph_block_tables, np.ndarray):
297
+ graph_block_tables = torch.zeros((max_batch_size, max_blocks),
298
+ dtype=block_tables.dtype,
299
+ device=block_tables.device)
300
+ else:
301
+ graph_block_tables = self.runner.graph_block_tables.to(
302
+ device=block_tables.device, dtype=block_tables.dtype)
303
+
304
+ num_blocks = block_tables.size(1)
305
+ if num_blocks <= max_blocks:
306
+ graph_block_tables[:num_seqs, :
307
+ num_blocks] = block_tables[:num_seqs, :
308
+ num_blocks]
309
+ else:
310
+ graph_block_tables[:num_seqs, :
311
+ max_blocks] = block_tables[:num_seqs, :
312
+ max_blocks]
313
+
314
+ return graph_block_tables[:num_seqs, :max_blocks]
315
+
316
+ def build_dummy(self, num_reqs: int,
317
+ num_actual_tokens: int) -> AscendMLAMetadata:
318
+ device = self.runner.device
319
+ _, max_blocks = self.runner.graph_block_tables.shape
320
+ block_table = torch.zeros((num_reqs, max_blocks),
321
+ dtype=torch.int32,
322
+ device=device)
323
+ block_table = self._get_graph_runner_block_tables(
324
+ num_reqs, block_table)
325
+ seq_lens = torch.ones(num_reqs, dtype=torch.int32, device=device)
326
+ input_positions = torch.zeros(num_reqs,
327
+ dtype=torch.int32,
328
+ device=device).long()
329
+ slot_mapping = torch.full((num_reqs, ),
330
+ PAD_SLOT_ID,
331
+ dtype=torch.int32,
332
+ device=device)
333
+ query_start_loc = torch.full((num_reqs, ),
334
+ -1,
335
+ dtype=torch.int32,
336
+ device=device)
337
+ decode_metadata = AscendMLADecodeMetadata(
338
+ input_positions=input_positions,
339
+ block_table=block_table,
340
+ seq_lens=seq_lens,
341
+ seq_lens_list=seq_lens.tolist(),
342
+ max_seq_lens=1,
343
+ attn_mask=self.runner.spec_attn_mask)
344
+ return self.metadata_cls( # type: ignore
345
+ num_input_tokens=num_actual_tokens,
346
+ num_actual_tokens=num_actual_tokens,
347
+ slot_mapping=slot_mapping,
348
+ head_dim=self.runner.model_config.get_head_size(),
349
+ num_decodes=1,
350
+ num_decode_tokens=1,
351
+ num_prefills=0,
352
+ attn_mask=self.runner.attn_mask,
353
+ attn_state=AscendAttentionState.DecodeOnly,
354
+ prefill=None,
355
+ decode=decode_metadata,
356
+ query_start_loc=query_start_loc,
357
+ seq_lens=seq_lens,
358
+ block_tables=block_table,
359
+ )
360
+
361
+ def build(
362
+ self,
363
+ num_reqs: int,
364
+ num_actual_tokens: int,
365
+ max_query_len: int,
366
+ common_attn_metadata: CommonAttentionMetadata,
367
+ common_prefix_len: Optional[int] = None,
368
+ graph_pad_size: int = -1,
369
+ max_num_tokens_across_dp: int = 0,
370
+ with_prefill_across_dp: bool = False,
371
+ ) -> AscendMLAMetadata:
372
+ assert self._num_decodes + self._num_prefills == num_reqs
373
+
374
+ # Note(simon): be careful about the CPU <> GPU memory movement in this
375
+ # function. We should avoid GPU -> CPU sync as much as possible because
376
+ # it blocks on all previous kernels.
377
+ device = self.runner.device
378
+
379
+ block_table = (self.runner.input_batch.block_table[0].
380
+ get_device_tensor()[:num_reqs])
381
+ slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
382
+ device, non_blocking=True)
383
+ input_positions = self.runner.positions_cpu[:num_actual_tokens].to(
384
+ device, non_blocking=True).long()
385
+
386
+ seq_lens_cpu = self.runner.seq_lens_cpu[:num_reqs]
387
+ query_lens = seq_lens_cpu - self.runner.input_batch.num_computed_tokens_cpu_tensor[:
388
+ num_reqs]
389
+ seq_lens = seq_lens_cpu
390
+ max_query_len = query_lens.max().item()
391
+ max_seq_lens = seq_lens.max().item()
392
+ query_start_loc = common_attn_metadata.query_start_loc
393
+
394
+ prefill_metadata = None
395
+ chunked_context_metadata = None
396
+ if self._num_prefills > 0:
397
+ reqs_start = self._num_decodes # prefill_start
398
+ tokens_start = self._num_decode_tokens
399
+ max_query_len = query_lens[tokens_start:].max().item()
400
+ max_seq_lens = seq_lens[tokens_start:].max().item()
401
+ query_start_loc = common_attn_metadata.query_start_loc
402
+ prefill_query_start_loc = query_start_loc[
403
+ reqs_start:] - query_start_loc[reqs_start]
404
+
405
+ context_lens_cpu = self.runner.input_batch.num_computed_tokens_cpu_tensor[
406
+ reqs_start:num_reqs]
407
+ max_context_len_cpu = context_lens_cpu.max().item()
408
+ num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item()
409
+ if self.chunked_prefill_enabled and max_context_len_cpu > 0:
410
+ max_context_chunk = (self.chunked_prefill_workspace_size //
411
+ num_prefills_with_context_cpu)
412
+ max_context_chunk = round_down(max_context_chunk,
413
+ self.block_size)
414
+
415
+ assert max_context_chunk > 0
416
+ num_chunks = cdiv(max_context_len_cpu, max_context_chunk)
417
+ chunk_starts = torch.arange(num_chunks, dtype=torch.int32) \
418
+ .unsqueeze(1).expand(-1, self._num_prefills) * max_context_chunk
419
+ chunk_ends = torch.min(context_lens_cpu.unsqueeze(0),
420
+ chunk_starts + max_context_chunk)
421
+ chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0)
422
+ cu_seq_lens_cpu = torch.zeros(num_chunks,
423
+ self._num_prefills + 1,
424
+ dtype=torch.int32,
425
+ pin_memory=True)
426
+ torch.cumsum(chunk_seq_lens,
427
+ dim=1,
428
+ out=cu_seq_lens_cpu[:, 1:],
429
+ dtype=torch.int32)
430
+ chunked_context_metadata = \
431
+ AscendMLAPrefillMetadata.ChunkedContextMetadata(
432
+ cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True),
433
+ starts=chunk_starts.to(device, non_blocking=True),
434
+ seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
435
+ max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
436
+ chunk_seq_lens=chunk_seq_lens,
437
+ workspace=self.chunked_prefill_workspace,
438
+ )
439
+
440
+ prefill_metadata = AscendMLAPrefillMetadata(
441
+ attn_mask=self.runner.attn_mask,
442
+ query_lens=query_lens[tokens_start:],
443
+ seq_lens=seq_lens,
444
+ context_lens=seq_lens[tokens_start:],
445
+ input_positions=input_positions[tokens_start:],
446
+ block_table=block_table[reqs_start:, ...],
447
+ max_query_len=max_query_len,
448
+ max_seq_lens=max_seq_lens,
449
+ query_start_loc=prefill_query_start_loc,
450
+ chunked_context=chunked_context_metadata,
451
+ )
452
+
453
+ decode_metadata = None
454
+ use_torchair_graph = graph_pad_size != -1
455
+ if self._num_decodes > 0:
456
+ max_seq_lens = seq_lens[:self._num_decodes].max().item()
457
+ seq_lens = seq_lens[:self._num_decode_tokens]
458
+ input_positions = input_positions[:self._num_decode_tokens]
459
+ block_table = block_table[:self._num_decode_tokens, ...]
460
+ if use_torchair_graph and self.runner.attn_state in [
461
+ AscendAttentionState.DecodeOnly,
462
+ AscendAttentionState.SpecDecoding
463
+ ]:
464
+ num_seqs = len(seq_lens)
465
+ if graph_pad_size != 0:
466
+ pad_value = 1
467
+ padded_seq_lens = seq_lens.tolist() + [pad_value
468
+ ] * graph_pad_size
469
+ else:
470
+ padded_seq_lens = seq_lens.tolist()
471
+
472
+ seq_lens = torch.from_numpy(
473
+ np.array(padded_seq_lens).astype(np.int32))
474
+ padding = torch.full((graph_pad_size, ),
475
+ PAD_SLOT_ID,
476
+ dtype=slot_mapping.dtype,
477
+ device=slot_mapping.device)
478
+ slot_mapping = torch.cat([slot_mapping, padding])
479
+ block_table_padding = torch.zeros(
480
+ (graph_pad_size, ) + block_table.shape[1:],
481
+ dtype=block_table.dtype,
482
+ device=block_table.device)
483
+ block_table = torch.cat([block_table, block_table_padding],
484
+ dim=0)
485
+ block_table = self._get_graph_runner_block_tables(
486
+ num_seqs + graph_pad_size, block_table)
487
+ padding_0 = torch.zeros(graph_pad_size,
488
+ dtype=input_positions.dtype,
489
+ device=input_positions.device)
490
+ input_positions = torch.cat([input_positions, padding_0])
491
+
492
+ decode_metadata = AscendMLADecodeMetadata(
493
+ input_positions=input_positions,
494
+ block_table=block_table,
495
+ seq_lens=seq_lens,
496
+ seq_lens_list=seq_lens.tolist(),
497
+ max_seq_lens=max_seq_lens,
498
+ attn_mask=self.runner.spec_attn_mask)
499
+
500
+ return self.metadata_cls( # type: ignore
501
+ num_actual_tokens=num_actual_tokens,
502
+ query_lens=query_lens.tolist(),
503
+ slot_mapping=slot_mapping,
504
+ head_dim=self.runner.model_config.get_head_size(),
505
+ num_decodes=self._num_decodes,
506
+ num_decode_tokens=self._num_decode_tokens,
507
+ num_prefills=self._num_prefills,
508
+ attn_mask=self.runner.attn_mask,
509
+ attn_state=self.runner.attn_state,
510
+ prefill=prefill_metadata,
511
+ decode=decode_metadata,
512
+ query_start_loc=query_start_loc,
513
+ block_tables=block_table,
514
+ seq_lens=seq_lens,
515
+ max_num_tokens_across_dp=max_num_tokens_across_dp,
516
+ with_prefill_across_dp=with_prefill_across_dp,
517
+ )
518
+
519
+
520
+ class AscendMLAImpl(MLAAttentionImpl):
521
+ """
522
+ NOTE: Please read the comment at the top of the file before trying to
523
+ understand this class
524
+ """
525
+
526
+ def __init__(
527
+ self,
528
+ num_heads: int,
529
+ head_size: int,
530
+ scale: float,
531
+ num_kv_heads: int,
532
+ alibi_slopes: Optional[list[float]],
533
+ sliding_window: Optional[int],
534
+ kv_cache_dtype: str,
535
+ blocksparse_params: Optional[dict[str, Any]],
536
+ logits_soft_cap: Optional[float],
537
+ attn_type: str,
538
+ kv_sharing_target_layer_name: Optional[str] = None,
539
+ **kwargs,
540
+ ) -> None:
541
+ self.num_heads = num_heads
542
+ self.head_size = head_size
543
+ self.scale = float(scale)
544
+ self.num_kv_heads = num_kv_heads
545
+ self.kv_cache_dtype = kv_cache_dtype
546
+
547
+ # MLA Args
548
+ self.q_lora_rank = kwargs['q_lora_rank']
549
+ self.kv_lora_rank = kwargs['kv_lora_rank']
550
+ self.qk_nope_head_dim = kwargs['qk_nope_head_dim']
551
+ self.qk_rope_head_dim = kwargs['qk_rope_head_dim']
552
+ self.qk_head_dim = kwargs['qk_head_dim']
553
+ self.v_head_dim = kwargs['v_head_dim']
554
+ self.rotary_emb = kwargs['rotary_emb']
555
+ self.q_proj = kwargs['q_proj']
556
+ self.kv_b_proj = kwargs['kv_b_proj']
557
+ self.o_proj = kwargs['o_proj']
558
+ self.kv_a_proj_with_mqa = kwargs.get('kv_a_proj_with_mqa', None)
559
+ self.kv_a_layernorm = kwargs.get('kv_a_layernorm', None)
560
+ self.num_queries_per_kv = self.num_heads // self.num_kv_heads
561
+ self.tp_size = get_tensor_model_parallel_world_size()
562
+
563
+ ascend_config = get_ascend_config()
564
+ self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
565
+ self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz
566
+
567
+ # Adapt torch air graph mode with spec decoding.
568
+ speculative_config = get_current_vllm_config().speculative_config
569
+ if speculative_config is not None:
570
+ self.spec_token_num = speculative_config.num_speculative_tokens
571
+ assert self.spec_token_num > 0
572
+ self.SHARE_MASK_TRIL_SPARSE = ~torch.tril(torch.ones((2048, 2048), dtype=torch.bool)).npu()
573
+
574
+ def _v_up_proj_and_o_proj(self, x, enable_multistream_mla: bool = False):
575
+ # Convert from (B, N, L) to (N, B, L)
576
+ x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
577
+ # Multiply (N, B, L) x (N, L, V) -> (N, B, V)
578
+ x = torch.bmm(x, self.W_UV)
579
+ # Convert from (N, B, V) to (B, N * V)
580
+ x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
581
+ MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024 # 16MB
582
+ npu_prefetch(self.o_proj.weight,
583
+ x,
584
+ max_size=MAX_O_PROJ_PREFETCH_SIZE,
585
+ enabled=enable_multistream_mla)
586
+ return self.o_proj(x, is_prefill=False)[0]
587
+
588
+ # Return `ql_nope`, `q_pe`
589
+ def _q_proj_and_k_up_proj(self, x):
590
+ q_nope, q_pe = self.q_proj(x)[0]\
591
+ .view(-1, self.num_heads, self.qk_head_dim)\
592
+ .split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
593
+
594
+ # Convert from (B, N, P) to (N, B, P)
595
+ q_nope = q_nope.transpose(0, 1)
596
+ # Multiply (N, B, P) x (N, P, L) -> (N, B, L)
597
+ ql_nope = torch.bmm(q_nope, self.W_UK_T)
598
+ # Convert from (N, B, L) to (B, N, L)
599
+ return ql_nope.transpose(0, 1), q_pe
600
+
601
+ def process_weights_after_loading(self, act_dtype: torch.dtype):
602
+
603
+ def get_layer_weight(layer):
604
+ WEIGHT_NAMES = ("weight", "qweight", "weight_packed")
605
+ for attr in WEIGHT_NAMES:
606
+ if hasattr(layer, attr):
607
+ return getattr(layer, attr)
608
+ raise AttributeError(
609
+ f"Layer '{layer}' has no recognized weight attribute:"
610
+ f" {WEIGHT_NAMES}.")
611
+
612
+ def get_and_maybe_dequant_weights(layer: LinearBase):
613
+ if not isinstance(layer.quant_method, UnquantizedLinearMethod):
614
+ # NOTE: This should only be used offline, since it's O(N^3)
615
+ eye = torch.eye(layer.input_size_per_partition,
616
+ dtype=act_dtype,
617
+ device=get_layer_weight(layer).device)
618
+ dequant_weights = layer.quant_method.apply(layer,
619
+ eye,
620
+ bias=None)
621
+ del eye
622
+ # standardize to (output, input)
623
+ return dequant_weights.T
624
+ return layer.weight
625
+
626
+ # we currently do not have quantized bmm's which are needed for
627
+ # `W_UV` and `W_UK_T`, so we just store fp16/bf16 copies and perform
628
+ # the bmm's in 16-bit, the extra memory overhead of this is fairly low
629
+ kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
630
+ assert kv_b_proj_weight.shape == (
631
+ self.kv_lora_rank,
632
+ self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), (
633
+ f"{kv_b_proj_weight.shape=}, "
634
+ f"{self.kv_lora_rank=}, "
635
+ f"{self.num_heads=}, "
636
+ f"{self.qk_nope_head_dim=}, "
637
+ f"{self.v_head_dim=}")
638
+ kv_b_proj_weight = kv_b_proj_weight.view(
639
+ self.kv_lora_rank,
640
+ self.num_heads,
641
+ self.qk_nope_head_dim + self.v_head_dim,
642
+ )
643
+
644
+ W_UK, W_UV = kv_b_proj_weight.split(
645
+ [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
646
+
647
+ # Convert from (L, N, V) to (N, L, V)
648
+ self.W_UV = W_UV.transpose(0, 1).contiguous()
649
+ # Convert from (L, N, P) to (N, P, L)
650
+ self.W_UK_T = W_UK.permute(1, 2, 0).contiguous()
651
+
652
+ # Waiting for BMM NZ support
653
+ # self.W_UV.data = torch_npu.npu_format_cast(self.W_UV.data, 29)
654
+ # self.W_UK_T.data = torch_npu.npu_format_cast(self.W_UK_T.data, 29)
655
+
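`process_weights_after_loading` splits the `kv_b_proj` weight into `W_UK` and `W_UV` so that `_q_proj_and_k_up_proj` and `_v_up_proj_and_o_proj` can absorb the K/V up-projections into batched matmuls over the low-rank latent. The shape bookkeeping can be reproduced with plain `torch.bmm`; a shape-only sketch with made-up dimensions (illustrative, not part of this file):

import torch

B, N = 4, 8               # tokens, attention heads (illustrative)
L, P, V = 512, 128, 128   # kv_lora_rank, qk_nope_head_dim, v_head_dim

kv_b_proj_weight = torch.randn(L, N, P + V)
W_UK, W_UV = kv_b_proj_weight.split([P, V], dim=-1)
W_UV = W_UV.transpose(0, 1).contiguous()     # (N, L, V)
W_UK_T = W_UK.permute(1, 2, 0).contiguous()  # (N, P, L)

# Absorb the K up-projection into the query: (B, N, P) -> (B, N, L)
q_nope = torch.randn(B, N, P)
ql_nope = torch.bmm(q_nope.transpose(0, 1), W_UK_T).transpose(0, 1)

# Absorb the V up-projection into the attention output: (B, N, L) -> (B, N * V)
attn_latent = torch.randn(B, N, L)
o_proj_input = torch.bmm(attn_latent.transpose(0, 1),
                         W_UV).transpose(0, 1).reshape(B, N * V)
assert ql_nope.shape == (B, N, L) and o_proj_input.shape == (B, N * V)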
656
+ def _compute_prefill_context(
657
+ self,
658
+ query: torch.Tensor,
659
+ kv_c_and_k_pe_cache: torch.Tensor,
660
+ rope_dim: int,
661
+ attn_metadata: AscendMLAMetadata,
662
+ prefix_output: torch.Tensor,
663
+ prefix_lse: torch.Tensor,
664
+ ):
665
+ prefill_metadata = attn_metadata.prefill
666
+ if prefill_metadata is None or prefill_metadata.chunked_context is None:
667
+ return prefix_output, prefix_lse
668
+
669
+ iters = len(prefill_metadata.chunked_context.seq_tot)
670
+ q_pe = query[..., self.qk_nope_head_dim:]
671
+ q_nope = query[..., :self.qk_nope_head_dim]
672
+
673
+ seq_len1 = torch.tensor(prefill_metadata.query_lens, dtype=torch.int32)
674
+ latent_kv_dim = kv_c_and_k_pe_cache.size(3) - rope_dim
675
+ cache_kv_c = kv_c_and_k_pe_cache[:, :, :, :latent_kv_dim]
676
+ cache_k_pe = kv_c_and_k_pe_cache[:, :, :, latent_kv_dim:]
677
+ for i in range(iters):
678
+ toks = prefill_metadata.chunked_context.seq_tot[i]
679
+
680
+ seq_len2 = prefill_metadata.chunked_context.chunk_seq_lens[i]
681
+ seq_len = torch.stack([seq_len1, seq_len2])
682
+ kv_c_normed = torch.empty(toks,
683
+ kv_c_and_k_pe_cache.size(2),
684
+ latent_kv_dim,
685
+ dtype=query.dtype,
686
+ device=query.device)
687
+ k_pe = torch.empty(toks,
688
+ kv_c_and_k_pe_cache.size(2),
689
+ rope_dim,
690
+ dtype=query.dtype,
691
+ device=query.device)
692
+
693
+ torch_npu.atb.npu_paged_cache_load(
694
+ cache_kv_c,
695
+ cache_k_pe,
696
+ prefill_metadata.block_table,
697
+ seq_len2.to(query.device),
698
+ seq_starts=prefill_metadata.chunked_context.starts[i],
699
+ key=kv_c_normed,
700
+ value=k_pe,
701
+ )
702
+
703
+ kv_c_normed = kv_c_normed.squeeze()
704
+ kv_nope = self.kv_b_proj(kv_c_normed)[0].view( \
705
+ -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
706
+ k_nope, v = kv_nope\
707
+ .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
708
+ k_pe = k_pe.expand((*k_nope.shape[:-1], -1))
709
+ mask = torch.triu(
710
+ torch.ones(512, 512, device=query.device, dtype=query.dtype),
711
+ 1)
712
+ torch_npu.atb.npu_ring_mla(
713
+ q_nope=q_nope,
714
+ q_rope=q_pe,
715
+ k_nope=k_nope,
716
+ k_rope=k_pe,
717
+ value=v,
718
+ mask=mask,
719
+ seqlen=seq_len,
720
+ head_num=self.num_heads,
721
+ kv_head_num=self.num_heads,
722
+ pre_out=prefix_output,
723
+ prev_lse=prefix_lse,
724
+ qk_scale=self.scale,
725
+ kernel_type="kernel_type_high_precision",
726
+ mask_type="no_mask",
727
+ input_layout="type_bsnd",
728
+ calc_type="calc_type_default",
729
+ output=prefix_output,
730
+ softmax_lse=prefix_lse)
731
+ return prefix_output, prefix_lse
732
+
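`_compute_prefill_context` below lets `npu_ring_mla` fold each cached-context chunk into the running (`prefix_output`, `prefix_lse`) pair. The underlying math is the usual log-sum-exp merge of two partial softmax-attention results; a reference implementation of that merge in plain PyTorch (a sketch of the math only, not of the NPU kernel):

import torch

def merge_attention_chunks(out_a, lse_a, out_b, lse_b):
    """Merge two partial attention results computed for the same queries.

    out_*: (num_tokens, num_heads, head_dim) softmax-weighted partial sums
    lse_*: (num_tokens, num_heads) log-sum-exp of the corresponding logits
    """
    lse = torch.logaddexp(lse_a, lse_b)         # combined normalizer
    w_a = torch.exp(lse_a - lse).unsqueeze(-1)  # weight of partial result A
    w_b = torch.exp(lse_b - lse).unsqueeze(-1)  # weight of partial result B
    return w_a * out_a + w_b * out_b, lse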
733
+ def _forward_prefill(
734
+ self,
735
+ query: torch.Tensor,
736
+ kv_c_normed: torch.Tensor,
737
+ k_pe: torch.Tensor,
738
+ kv_c_and_k_pe_cache: torch.Tensor,
739
+ attn_metadata: AscendMLAMetadata,
740
+ ) -> torch.Tensor:
741
+ assert attn_metadata.prefill is not None
742
+
743
+ num_tokens = query.size(0)
744
+ attn_output = torch.empty(num_tokens,
745
+ self.num_heads,
746
+ self.v_head_dim,
747
+ dtype=query.dtype,
748
+ device=query.device)
749
+ k_nope, value = self.kv_b_proj(kv_c_normed)[0].view(
750
+ -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim).split(
751
+ [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
752
+ k_pe = k_pe.expand((*k_nope.shape[:-1], -1))
753
+ # Only two kinds of input are possible here: ChunkedPrefill or PrefillNoCache
754
+ ascend_config = get_ascend_config()
755
+
756
+ if attn_metadata.attn_state in [
757
+ AscendAttentionState.ChunkedPrefill,
758
+ AscendAttentionState.SpecDecoding,
759
+ AscendAttentionState.PrefillCacheHit
760
+ ] and not ascend_config.chunked_prefill_for_mla:
761
+ attn_output_torch = torch.empty(num_tokens,
762
+ self.num_heads * self.v_head_dim,
763
+ dtype=query.dtype,
764
+ device=query.device)
765
+ # the current request is a chunked prefill; disable flash attention for chunked prefill
766
+ vanilla_chunked_prefill_mla(
767
+ output=attn_output_torch,
768
+ query=query,
769
+ kv_cache=kv_c_and_k_pe_cache,
770
+ block_tables=attn_metadata.prefill.block_table,
771
+ query_lens=attn_metadata.prefill.query_lens,
772
+ context_lens=attn_metadata.prefill.context_lens,
773
+ kv_b_proj=self.kv_b_proj,
774
+ max_query_len=attn_metadata.prefill.max_query_len,
775
+ max_context_len=attn_metadata.prefill.max_seq_lens,
776
+ nope_dim=self.qk_nope_head_dim,
777
+ rope_dim=self.qk_rope_head_dim,
778
+ v_head_dim=self.v_head_dim,
779
+ scale=self.scale,
780
+ alibi_slopes=None,
781
+ causal=True)
782
+ elif attn_metadata.attn_state in [
783
+ AscendAttentionState.ChunkedPrefill,
784
+ AscendAttentionState.SpecDecoding,
785
+ AscendAttentionState.PrefillCacheHit
786
+ ]:
787
+ attn_lse = torch.empty(self.num_heads,
788
+ num_tokens,
789
+ dtype=torch.float32,
790
+ device=query.device)
791
+ q_pe = query[..., self.qk_nope_head_dim:]
792
+ q_nope = query[..., :self.qk_nope_head_dim]
793
+ mask = torch.triu(
794
+ torch.ones(512, 512, device=query.device, dtype=query.dtype),
795
+ 1)  # 512: the mask only supports size 512
796
+ if attn_metadata.num_prefills > 1:
797
+ mask = mask.unsqueeze(0).repeat(attn_metadata.num_prefills, 1,
798
+ 1)
799
+ torch_npu.atb.npu_ring_mla(
800
+ q_nope=q_nope,
801
+ q_rope=q_pe,
802
+ k_nope=k_nope,
803
+ k_rope=k_pe,
804
+ value=value,
805
+ mask=mask,
806
+ seqlen=torch.tensor(attn_metadata.prefill.query_lens,
807
+ dtype=torch.int32),
808
+ head_num=self.num_heads,
809
+ kv_head_num=self.num_heads,
810
+ pre_out=None,
811
+ prev_lse=None,
812
+ qk_scale=self.scale,
813
+ kernel_type="kernel_type_high_precision",
814
+ mask_type="mask_type_triu",
815
+ input_layout="type_bsnd",
816
+ calc_type="calc_type_first_ring",
817
+ output=attn_output,
818
+ softmax_lse=attn_lse)
819
+ attn_output, attn_lse = self._compute_prefill_context( \
820
+ query, kv_c_and_k_pe_cache, self.qk_rope_head_dim, attn_metadata, attn_output, attn_lse)
821
+
822
+ elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
823
+ key = torch.cat((k_nope, k_pe), dim=-1)
824
+ context_lens_list = torch.cumsum(attn_metadata.prefill.context_lens, dim=0).tolist()
825
+ attn_output = torch_npu.npu_fused_infer_attention_score(
826
+ query,
827
+ key,
828
+ value,
829
+ num_heads=self.num_heads,
830
+ input_layout="TND",
831
+ scale=self.scale,
832
+ sparse_mode=3,
833
+ atten_mask=self.SHARE_MASK_TRIL_SPARSE,
834
+ actual_seq_lengths=context_lens_list,
835
+ actual_seq_lengths_kv=context_lens_list,
836
+ inner_precise=0)[0]
837
+ attn_output = attn_output.view(-1, self.num_heads, self.v_head_dim)
838
+ else:
839
+ raise RuntimeError(
840
+ "Unexpected path reached, AscendMLAImpl should only have PrefillNoCache, PrefillCacheHit, ChunkedPrefill and SpecDecoding scenario in forward prefill, please file a bug to vllm-ascend !"
841
+ )
842
+ attn_output = attn_output.reshape(
843
+ [num_tokens, self.num_heads * self.v_head_dim])
844
+ if attn_metadata.attn_state in [
845
+ AscendAttentionState.ChunkedPrefill,
846
+ AscendAttentionState.SpecDecoding,
847
+ AscendAttentionState.PrefillCacheHit
848
+ ] and not ascend_config.chunked_prefill_for_mla:
849
+ attn_output = attn_output_torch
850
+
851
+ current_ms_metadata = get_multistream_comm_context()
852
+ if current_ms_metadata is None:
853
+ return self.o_proj(attn_output, is_prefill=True)[0]
854
+ else:
855
+ current_ms_metadata.before_comm_event.record()
856
+ with torch.npu.stream(current_ms_metadata.comm_stream):
857
+ current_ms_metadata.before_comm_event.wait()
858
+ return self.o_proj(attn_output, is_prefill=True)[0]
859
+
860
+ def exec_kv(
861
+ self,
862
+ hidden_states: torch.Tensor,
863
+ cos: torch.Tensor,
864
+ sin: torch.Tensor,
865
+ kv_cache: Tuple,
866
+ slots: torch.Tensor,
867
+ ):
868
+
869
+ B = hidden_states.shape[0]
870
+ N = self.num_kv_heads
871
+ S = 1
872
+ kv = self.kv_a_proj_with_mqa(hidden_states)[0]
873
+ # npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
874
+ kv = kv.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
875
+ cache_mode = "PA_NZ" if self.enable_kv_nz else "PA"
876
+ k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
877
+ kv,
878
+ self.kv_a_layernorm.weight,
879
+ cos,
880
+ sin,
881
+ slots.to(torch.int64),
882
+ kv_cache[1],
883
+ kv_cache[0],
884
+ epsilon=self.kv_a_layernorm.variance_epsilon,
885
+ cache_mode=cache_mode,
886
+ )
887
+ return k_pe, k_nope, kv
888
+
889
+ def exec_kv_prefill(
890
+ self,
891
+ hidden_states: torch.Tensor,
892
+ cos: torch.Tensor,
893
+ sin: torch.Tensor,
894
+ kv_cache: Tuple,
895
+ slots: torch.Tensor,
896
+ ):
897
+
898
+ B = hidden_states.shape[0]
899
+ N = self.num_kv_heads
900
+ S = 1
901
+ kv = self.kv_a_proj_with_mqa(hidden_states)[0]
902
+ # npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
903
+ kv = kv.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
904
+ cache_mode = "PA_BLK_NZ" if self.enable_kv_nz else "PA"
905
+ _, _, k_pe, k_nope = torch_npu.npu_kv_rmsnorm_rope_cache(
906
+ kv,
907
+ self.kv_a_layernorm.weight,
908
+ cos,
909
+ sin,
910
+ slots.to(torch.int64),
911
+ kv_cache[1],
912
+ kv_cache[0],
913
+ epsilon=self.kv_a_layernorm.variance_epsilon,
914
+ cache_mode=cache_mode,
915
+ is_output_kv=True,
916
+ )
917
+ return k_pe, k_nope
918
+
919
+ def rope_single(
920
+ self,
921
+ x: torch.Tensor,
922
+ cos: torch.Tensor,
923
+ sin: torch.Tensor,
924
+ ) -> torch.Tensor:
925
+ B, N, D = x.shape
926
+ S = 1
927
+ x = x.view(B, N, S, D)
928
+ x = torch_npu.npu_interleave_rope(x, cos, sin)
929
+ return x.view(B, N, D)
930
+
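`rope_single` above only reshapes to the `[B, N, S, D]` layout expected by `torch_npu.npu_interleave_rope` and back. For readers without an NPU, the rotation itself can be sketched in plain PyTorch under the interleaved (even/odd pair) convention; the exact cos/sin layout of the NPU kernel may differ, so treat this as an approximation:

import torch

def interleaved_rope(x, cos, sin):
    """Rotate interleaved (even, odd) pairs of x by position-dependent angles.

    x:        (..., D) with D even
    cos, sin: broadcastable to (..., D // 2)
    """
    x_even, x_odd = x[..., 0::2], x[..., 1::2]
    rot_even = x_even * cos - x_odd * sin
    rot_odd = x_odd * cos + x_even * sin
    out = torch.empty_like(x)
    out[..., 0::2], out[..., 1::2] = rot_even, rot_odd
    return out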
931
+ def _forward_decode(
932
+ self,
933
+ q_nope: torch.Tensor,
934
+ q_pe: torch.Tensor,
935
+ k_nope: torch.Tensor,
936
+ k_pe: torch.Tensor,
937
+ kv_c_and_k_pe_cache: torch.Tensor,
938
+ attn_metadata: AscendMLAMetadata,
939
+ enable_multistream_mla: bool = False,
940
+ ) -> torch.Tensor:
941
+ decode_meta = attn_metadata.decode
942
+ assert decode_meta is not None
943
+
944
+ q = torch.cat([q_nope, q_pe], dim=-1)
945
+ num_tokens = q.size(0)
946
+ attn_output = torch.empty(
947
+ [num_tokens, self.num_heads, self.kv_lora_rank],
948
+ dtype=q.dtype,
949
+ device=q.device)
950
+ if self.running_in_graph:
951
+ # TorchAir's shape is [bs, num_heads_per_rank, q_seq_len, dim]
952
+ if attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
953
+ assert num_tokens % self.spec_token_num == 0
954
+ q_nope = q_nope.view(num_tokens // (self.spec_token_num + 1),
955
+ self.spec_token_num + 1, self.num_heads,
956
+ -1)
957
+ q_pe = q_pe.view(num_tokens // (self.spec_token_num + 1),
958
+ self.spec_token_num + 1, self.num_heads, -1)
959
+ if not self.enable_kv_nz:
960
+ q_nope = q_nope.transpose(1, 2).contiguous()
961
+ q_pe = q_pe.transpose(1, 2).contiguous()
962
+ sparse_mode = 3
963
+ spec_attn_mask = attn_metadata.decode.attn_mask # type:ignore
964
+ else:
965
+ if self.enable_kv_nz:
966
+ q_nope = q_nope.view(num_tokens, 1, self.num_heads, -1)
967
+ q_pe = q_pe.view(num_tokens, 1, self.num_heads, -1)
968
+ else:
969
+ q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1)
970
+ q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
971
+ sparse_mode = 0
972
+ spec_attn_mask = None
973
+ # shape of knope/k_pe for npu graph mode should be:
974
+ # [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim]
975
+ block_size = kv_c_and_k_pe_cache[0].shape[1]
976
+ if self.enable_kv_nz:
977
+ k_nope = k_nope.view(-1, self.num_kv_heads,
978
+ self.kv_lora_rank // 16, block_size, 16)
979
+ k_pe = k_pe.view(-1, self.num_kv_heads,
980
+ self.qk_rope_head_dim // 16, block_size, 16)
981
+ input_layout = "BSND"
982
+ else:
983
+ k_nope = k_nope.view(-1, self.num_kv_heads, block_size,
984
+ self.kv_lora_rank)
985
+ k_pe = k_pe.view(-1, self.num_kv_heads, block_size,
986
+ self.qk_rope_head_dim)
987
+ input_layout = "BNSD"
988
+
989
+ attn_output, _ = torch_npu.npu_fused_infer_attention_score(
990
+ q_nope,
991
+ k_nope,
992
+ k_nope,
993
+ query_rope=q_pe,
994
+ key_rope=k_pe,
995
+ num_heads=self.num_heads,
996
+ num_key_value_heads=self.num_kv_heads,
997
+ input_layout=input_layout,
998
+ atten_mask=spec_attn_mask,
999
+ sparse_mode=sparse_mode,
1000
+ scale=self.scale,
1001
+ antiquant_mode=0,
1002
+ antiquant_scale=None,
1003
+ block_table=decode_meta.block_table,
1004
+ block_size=block_size,
1005
+ actual_seq_lengths_kv=decode_meta.seq_lens_list,
1006
+ )
1007
+ else:
1008
+ torch_npu._npu_paged_attention_mla(
1009
+ query=q,
1010
+ key_cache=kv_c_and_k_pe_cache,
1011
+ num_kv_heads=self.num_kv_heads,
1012
+ num_heads=self.num_heads,
1013
+ scale_value=self.scale,
1014
+ block_table=attn_metadata.decode.block_table, # type:ignore
1015
+ context_lens=attn_metadata.decode.seq_lens, # type:ignore
1016
+ mla_vheadsize=self.kv_lora_rank,
1017
+ out=attn_output)
1018
+ current_ms_metadata = get_multistream_comm_context()
1019
+ if current_ms_metadata is None:
1020
+ return self._v_up_proj_and_o_proj(attn_output,
1021
+ enable_multistream_mla)
1022
+ else:
1023
+ current_ms_metadata.before_comm_event.record()
1024
+ with torch.npu.stream(current_ms_metadata.comm_stream):
1025
+ current_ms_metadata.before_comm_event.wait()
1026
+ return self._v_up_proj_and_o_proj(attn_output)
1027
+
1028
+ def forward(
1029
+ self,
1030
+ layer: AttentionLayer,
1031
+ hidden_states_or_q_c: torch.Tensor, # query in unified attn
1032
+ hidden_states_or_kv_c_normed: torch.Tensor, # key in unified attn
1033
+ k_pe: torch.Tensor, # value in unified attn
1034
+ kv_cache: torch.Tensor,
1035
+ attn_metadata: M,
1036
+ output: Optional[torch.Tensor] = None,
1037
+ enable_multistream_mla: bool = False,
1038
+ ckq: Optional[torch.Tensor] = None,
1039
+ ) -> torch.Tensor:
1040
+ assert output is not None, "Output tensor must be provided."
1041
+ if attn_metadata is None:
1042
+ # Profiling run.
1043
+ return output
1044
+ self.running_in_graph = self.torchair_graph_enabled and attn_metadata.attn_state in [
1045
+ AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
1046
+ ]
1047
+ num_actual_toks = attn_metadata.num_actual_tokens
1048
+ if k_pe is None and not self.running_in_graph:
1049
+ if not self.torchair_graph_enabled:
1050
+ kv_c, k_pe = self.kv_a_proj_with_mqa(
1051
+ hidden_states_or_kv_c_normed)[0].split(
1052
+ [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
1053
+ kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
1054
+ else:
1055
+ kv_c_normed = hidden_states_or_kv_c_normed
1056
+ assert attn_metadata.num_decodes is not None and \
1057
+ attn_metadata.num_prefills is not None and \
1058
+ attn_metadata.num_decode_tokens is not None
1059
+ has_decode = attn_metadata.num_decodes > 0
1060
+ has_prefill = attn_metadata.num_prefills > 0
1061
+ num_decode_tokens = attn_metadata.num_decode_tokens
1062
+ if not self.running_in_graph:
1063
+ # Inputs and outputs may be padded for CUDA graphs
1064
+ output_padded = output
1065
+ output = output[:num_actual_toks, ...]
1066
+ if not self.torchair_graph_enabled:
1067
+ kv_c_normed = kv_c_normed[:num_actual_toks, ...]
1068
+ prefill_k_c_normed = kv_c_normed[num_decode_tokens:]
1069
+ if not self.running_in_graph:
1070
+ hidden_states_or_q_c = hidden_states_or_q_c[:num_actual_toks, ...]
1071
+ prefill_hs_or_q_c = hidden_states_or_q_c[num_decode_tokens:]
1072
+ if not self.torchair_graph_enabled:
1073
+ decode_hs_or_q_c = hidden_states_or_q_c[:num_decode_tokens]
1074
+ k_pe = k_pe[:num_actual_toks, ...]
1075
+ k_pe = k_pe.unsqueeze(1)
1076
+ decode_k_pe = k_pe[:num_decode_tokens]
1077
+ prefill_k_pe = k_pe[num_decode_tokens:]
1078
+ else:
1079
+ decode_hs_or_q_c = hidden_states_or_q_c
1080
+ if has_decode:
1081
+ decode_k_nope = None
1082
+ assert attn_metadata.decode is not None
1083
+ if self.running_in_graph:
1084
+ seq_len = self.rotary_emb.max_position_embeddings * \
1085
+ getattr(self.rotary_emb, "scaling_factor", 1)
1086
+ cos = self.rotary_emb.cos_cached[:seq_len].to(
1087
+ dtype=decode_hs_or_q_c.dtype)
1088
+ sin = self.rotary_emb.sin_cached[:seq_len].to(
1089
+ dtype=decode_hs_or_q_c.dtype)
1090
+ cos = cos[attn_metadata.decode.input_positions]
1091
+ sin = sin[attn_metadata.decode.input_positions]
1092
+ cos = cos[:, None, None, :]
1093
+ sin = sin[:, None, None, :]
1094
+ with npu_stream_switch("mla_secondary",
1095
+ 0,
1096
+ enabled=enable_multistream_mla):
1097
+ npu_wait_tensor(hidden_states_or_kv_c_normed,
1098
+ ckq,
1099
+ enabled=enable_multistream_mla)
1100
+ decode_k_pe, decode_k_nope, decode_kv = self.exec_kv(
1101
+ hidden_states_or_kv_c_normed, cos, sin, kv_cache,
1102
+ attn_metadata.slot_mapping)
1103
+ # Without explicitly controlling the order, IndexByTensor operations
1104
+ # would be placed after `matmul W_KV_T` hindering the overlapping of
1105
+ # KvRmsNormRopeCache and SingleRope.
1106
+ npu_wait_tensor(decode_hs_or_q_c,
1107
+ cos,
1108
+ enabled=enable_multistream_mla)
1109
+ npu_wait_tensor(decode_hs_or_q_c,
1110
+ sin,
1111
+ enabled=enable_multistream_mla)
1112
+ npu_wait_tensor(decode_hs_or_q_c,
1113
+ decode_kv,
1114
+ enabled=enable_multistream_mla)
1115
+
1116
+ decode_ql_nope, decode_q_pe = \
1117
+ self._q_proj_and_k_up_proj(decode_hs_or_q_c)
1118
+ if self.running_in_graph:
1119
+ with npu_stream_switch("mla_secondary",
1120
+ 0,
1121
+ enabled=enable_multistream_mla):
1122
+ npu_wait_tensor(decode_q_pe,
1123
+ decode_k_pe,
1124
+ enabled=enable_multistream_mla)
1125
+ decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
1126
+ else:
1127
+ decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
1128
+ attn_metadata.decode.input_positions,
1129
+ decode_q_pe.contiguous(),
1130
+ decode_k_pe,
1131
+ max_seq_len=attn_metadata.decode.max_seq_lens)
1132
+ if has_prefill:
1133
+ assert attn_metadata.prefill is not None
1134
+ prefill_q = self.q_proj(prefill_hs_or_q_c)[0]\
1135
+ .view(-1, self.num_heads, self.qk_head_dim)
1136
+ prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:]
1137
+ prefill_q_nope = prefill_q[..., :self.qk_nope_head_dim]
1138
+ if self.torchair_graph_enabled:
1139
+ num_tokens = prefill_hs_or_q_c.shape[0]
1140
+ seq_len = self.rotary_emb.max_position_embeddings * \
1141
+ getattr(self.rotary_emb, "scaling_factor", 1)
1142
+ cos = self.rotary_emb.cos_cached[:seq_len].to(
1143
+ dtype=prefill_q_pe.dtype)
1144
+ sin = self.rotary_emb.sin_cached[:seq_len].to(
1145
+ dtype=prefill_q_pe.dtype)
1146
+ cos = cos[attn_metadata.prefill.input_positions]
1147
+ sin = sin[attn_metadata.prefill.input_positions]
1148
+ cos = cos[:, None, None, :]
1149
+ sin = sin[:, None, None, :]
1150
+
1151
+ prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin)
1152
+ prefill_k_pe, prefill_k_nope = self.exec_kv_prefill(
1153
+ hidden_states_or_kv_c_normed, cos, sin, kv_cache,
1154
+ attn_metadata.slot_mapping)
1155
+
1156
+ kv_c_normed = prefill_k_nope[:num_actual_toks, ...]
1157
+ prefill_k_c_normed = prefill_k_nope[num_decode_tokens:]
1158
+ prefill_k_pe = prefill_k_pe.view(num_tokens, self.num_kv_heads,
1159
+ -1)
1160
+ prefill_q = torch.cat([prefill_q_nope, prefill_q_pe], dim=-1)
1161
+ else:
1162
+ prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb(
1163
+ attn_metadata.prefill.input_positions,
1164
+ prefill_q_pe.contiguous(),
1165
+ prefill_k_pe,
1166
+ max_seq_len=attn_metadata.prefill.max_seq_lens)
1167
+ if self.torchair_graph_enabled:
1168
+ if len(kv_cache) > 0 and kv_cache[0].numel(
1169
+ ) > 0 and attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
1170
+ slots = attn_metadata.slot_mapping
1171
+ # NOTE: Separate the kv cache in advance to avoid OOM or other issues
1172
+ torch_npu._npu_reshape_and_cache(key=kv_c_normed.view(
1173
+ num_tokens, self.num_kv_heads, -1),
1174
+ value=prefill_k_pe,
1175
+ key_cache=kv_cache[0],
1176
+ value_cache=kv_cache[1],
1177
+ slot_indices=slots)
1178
+ elif kv_cache.numel() > 0:
1179
+ key = torch.cat([
1180
+ kv_c_normed.view([num_actual_toks, self.num_kv_heads, -1]),
1181
+ k_pe
1182
+ ],
1183
+ dim=2)
1184
+ torch_npu._npu_reshape_and_cache_siso(
1185
+ key=key,
1186
+ key_cache=kv_cache,
1187
+ slot_indices=attn_metadata.slot_mapping.flatten())
1188
+ if has_prefill:
1189
+ # FIX: the aicore move should also be placed on the comm stream in dbo,
1190
+ # otherwise it may affect the accuracy
1191
+ # TODO: use an elegant way to overlap
1192
+ output_prefill = self._forward_prefill(prefill_q,
1193
+ prefill_k_c_normed,
1194
+ prefill_k_pe, kv_cache,
1195
+ attn_metadata)
1196
+ current_ms_metadata = get_multistream_comm_context()
1197
+ if current_ms_metadata is not None:
1198
+ with torch.npu.stream(current_ms_metadata.comm_stream):
1199
+ output[num_decode_tokens:] = output_prefill
1200
+ current_ms_metadata.after_comm_event.record()
1201
+ else:
1202
+ output[num_decode_tokens:] = output_prefill
1203
+
1204
+ if has_decode:
1205
+ if self.running_in_graph:
1206
+ return self._forward_decode(decode_ql_nope, decode_q_pe,
1207
+ decode_k_nope, decode_k_pe,
1208
+ kv_cache, attn_metadata,
1209
+ enable_multistream_mla)
1210
+ else:
1211
+ output_decode = self._forward_decode(decode_ql_nope,
1212
+ decode_q_pe,
1213
+ decode_k_nope,
1214
+ decode_k_pe, kv_cache,
1215
+ attn_metadata)
1216
+ current_ms_metadata = get_multistream_comm_context()
1217
+ if current_ms_metadata is not None:
1218
+ with torch.npu.stream(current_ms_metadata.comm_stream):
1219
+ output[:num_decode_tokens] = output_decode
1220
+ current_ms_metadata.after_comm_event.record()
1221
+ else:
1222
+ output[:num_decode_tokens] = output_decode
1223
+
1224
+ return output_padded
inference/vllm_ascend/entrypoints/openai/reasoning_parsers/__init__.py ADDED
@@ -0,0 +1,6 @@
+ # Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+ from .pangu_reasoning_parser import PanguReasoningParser
+
+ __all__ = [
+     "PanguReasoningParser"
+ ]
inference/vllm_ascend/entrypoints/openai/reasoning_parsers/pangu_reasoning_parser.py ADDED
@@ -0,0 +1,171 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
+ # Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
4
+
5
+ from collections.abc import Sequence
6
+ from typing import Optional, Union
7
+
8
+ from transformers import PreTrainedTokenizerBase
9
+
10
+ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
11
+ DeltaMessage)
12
+ from vllm.logger import init_logger
13
+ from vllm.reasoning import ReasoningParser, ReasoningParserManager
14
+
15
+ logger = init_logger(__name__)
16
+
17
+
18
+ @ReasoningParserManager.register_module("pangu")
19
+ class PanguReasoningParser(ReasoningParser):
20
+ """
21
+ Reasoning parser for Pangu model.
22
+
23
+ The Pangu model uses [unused16]...[unused17] tokens to denote reasoning
24
+ text. This parser extracts the reasoning content from the model output.
25
+ """
26
+
27
+ start_token_id: int
28
+ end_token_id: int
29
+
30
+ start_token: str = "[unused16]"
31
+ end_token: str = "[unused17]"
32
+
33
+ def __init__(self, tokenizer: PreTrainedTokenizerBase):
34
+ super().__init__(tokenizer)
35
+
36
+ if not self.model_tokenizer:
37
+ raise ValueError(
38
+ "The model tokenizer must be passed to the ReasoningParser "
39
+ "constructor during construction.")
40
+
41
+ self.start_token_id = self.vocab.get(self.start_token)
42
+ self.end_token_id = self.vocab.get(self.end_token)
43
+ if self.start_token_id is None or self.end_token_id is None:
44
+ raise RuntimeError(
45
+ "Pangu reasoning parser could not locate think start/end "
46
+ "tokens in the tokenizer!")
47
+
48
+ def is_reasoning_end(self, input_ids: list[int]) -> bool:
49
+ return self.end_token_id in input_ids
50
+
51
+ def extract_content_ids(self, input_ids: list[int]) -> list[int]:
52
+ """
53
+ Extract the content after the end tokens
54
+ """
55
+ if self.end_token_id not in input_ids[:-1]:
56
+ return []
57
+ else:
58
+ return input_ids[input_ids.index(self.end_token_id) + 1:]
59
+
60
+ def extract_reasoning_content_streaming(
61
+ self,
62
+ previous_text: str,
63
+ current_text: str,
64
+ delta_text: str,
65
+ previous_token_ids: Sequence[int],
66
+ current_token_ids: Sequence[int],
67
+ delta_token_ids: Sequence[int],
68
+ ) -> Union[DeltaMessage, None]:
69
+ """
70
+ Extract reasoning content from a delta message.
71
+ Handles streaming output where previous + delta = current.
72
+ Uses token IDs for faster processing.
73
+ For text [unused16]abc[unused17]xyz:
74
+ - 'abc' goes to reasoning_content
75
+ - 'xyz' goes to content
76
+ """
77
+ # Skip single special tokens
78
+ if len(delta_token_ids) == 1 and (delta_token_ids[0] in [
79
+ self.start_token_id, self.end_token_id
80
+ ]):
81
+ return None
82
+
83
+ # Check if [unused16] is present in previous or delta.
84
+ # Keep compatibility with models that don't generate [unused16] tokens.
85
+ if self.start_token_id in previous_token_ids:
86
+ if self.end_token_id in delta_token_ids:
87
+ # [unused16] in previous, [unused17] in delta,
88
+ # extract reasoning content
89
+ end_index = delta_text.find(self.end_token)
90
+ reasoning_content = delta_text[:end_index]
91
+ content = delta_text[end_index + len(self.end_token):]
92
+ return DeltaMessage(
93
+ reasoning_content=reasoning_content,
94
+ content=content if content else None,
95
+ )
96
+ elif self.end_token_id in previous_token_ids:
97
+ # [unused16] in previous, [unused17] in previous,
98
+ # the reasoning section already ended, regular content continues
99
+ return DeltaMessage(content=delta_text)
100
+ else:
101
+ # [unused16] in previous, no [unused17] in previous or delta,
102
+ # reasoning content continues
103
+ return DeltaMessage(reasoning_content=delta_text)
104
+ elif self.start_token_id in delta_token_ids:
105
+ if self.end_token_id in delta_token_ids:
106
+ # [unused16] in delta, [unused17] in delta, extract reasoning content
107
+ start_index = delta_text.find(self.start_token)
108
+ end_index = delta_text.find(self.end_token)
109
+ reasoning_content = delta_text[start_index +
110
+ len(self.start_token):end_index]
111
+ content = delta_text[end_index + len(self.end_token):]
112
+ return DeltaMessage(
113
+ reasoning_content=reasoning_content,
114
+ content=content if content else None,
115
+ )
116
+ else:
117
+ # [unused16] in delta, no [unused17] in delta,
118
+ # reasoning content continues
119
+ return DeltaMessage(reasoning_content=delta_text)
120
+ else:
121
+ # No [unused16] in previous or delta, also need to check for [unused17].
122
+ # Because the model may have generated [unused17] without [unused16]
123
+ if self.end_token_id in delta_token_ids:
124
+ # [unused17] in delta with more tokens,
125
+ # extract reasoning content and content
126
+ end_index = delta_text.find(self.end_token)
127
+ reasoning_content = delta_text[:end_index]
128
+ content = delta_text[end_index + len(self.end_token):]
129
+ return DeltaMessage(
130
+ reasoning_content=reasoning_content,
131
+ content=content if content else None,
132
+ )
133
+ elif self.end_token_id in previous_token_ids:
134
+ # [unused17] in previous, thinking content ends
135
+ return DeltaMessage(content=delta_text)
136
+ else:
137
+ # no [unused17] in previous or delta, reasoning content continues
138
+ return DeltaMessage(reasoning_content=delta_text)
139
+
140
+ def extract_reasoning_content(
141
+ self, model_output: str, request: ChatCompletionRequest
142
+ ) -> tuple[Optional[str], Optional[str]]:
143
+ """
144
+ Extract reasoning content from the model output.
145
+
146
+ For text [unused16]abc[unused17]xyz:
147
+ - 'abc' goes to reasoning_content
148
+ - 'xyz' goes to content
149
+
150
+ Returns:
151
+ tuple[Optional[str], Optional[str]]: reasoning content and content
152
+ """
153
+
154
+ # Check if the start token is present in the model output, remove it
155
+ # if it is present.
156
+ model_output_parts = model_output.partition(self.start_token)
157
+ model_output = model_output_parts[2] if model_output_parts[
158
+ 1] else model_output_parts[0]
159
+
160
+ # Thus we assume the reasoning content is always at the start.
161
+ if self.end_token not in model_output:
162
+ return model_output, None
163
+ else:
164
+ reasoning_content, _, content = model_output.partition(
165
+ self.end_token)
166
+ # partition is guaranteed to find the end token here, because we
167
+ # already checked above that it is present in the model output,
168
+ # so the split always succeeds.
169
+ # If generation stops right after end-of-think, return null content
170
+ final_content = content or None
171
+ return reasoning_content, final_content
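The non-streaming path reduces to two `partition` calls on the marker strings. A self-contained illustration of the expected behaviour on plain strings (no tokenizer involved; the markers are the `[unused16]`/`[unused17]` tokens defined above):

START, END = "[unused16]", "[unused17]"

def split_reasoning(text: str):
    """Mimics extract_reasoning_content for a few representative inputs."""
    _, found, rest = text.partition(START)
    text = rest if found else text
    if END not in text:
        return text, None                  # everything is reasoning content
    reasoning, _, content = text.partition(END)
    return reasoning, (content or None)

assert split_reasoning("[unused16]abc[unused17]xyz") == ("abc", "xyz")
assert split_reasoning("abc[unused17]xyz") == ("abc", "xyz")
assert split_reasoning("[unused16]abc") == ("abc", None)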
inference/vllm_ascend/entrypoints/openai/tool_parsers/__init__.py ADDED
@@ -0,0 +1,6 @@
+ # Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+ from .pangu_tool_parser import PanguToolParser
+
+ __all__ = [
+     "PanguToolParser"
+ ]
inference/vllm_ascend/entrypoints/openai/tool_parsers/pangu_tool_parser.py ADDED
@@ -0,0 +1,300 @@
1
+ # Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
2
+ # Copyright 2023 The vLLM team.
3
+
4
+ import json
5
+ import re
6
+ from json import JSONDecodeError, JSONDecoder
7
+ from typing import Dict, List, Sequence, Union, Optional
8
+ from pydantic import Field
9
+ import partial_json_parser
10
+ from partial_json_parser.core.options import Allow
11
+ from transformers import PreTrainedTokenizerBase
12
+
13
+ from vllm.entrypoints.chat_utils import random_tool_call_id
14
+ from vllm.entrypoints.openai.tool_parsers.utils import (
15
+ extract_intermediate_diff)
16
+ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
17
+ DeltaFunctionCall, DeltaMessage,
18
+ DeltaToolCall,
19
+ ExtractedToolCallInformation,
20
+ FunctionCall, ToolCall,
21
+ )
22
+ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
23
+ ToolParser, ToolParserManager)
24
+ from vllm.entrypoints.openai.tool_parsers.utils import (find_common_prefix,
25
+ is_complete_json)
26
+ from vllm.logger import init_logger
27
+ import os
28
+
29
+ logger = init_logger(__name__)
30
+
31
+
32
+ @ToolParserManager.register_module("pangu")
33
+ class PanguToolParser(ToolParser):
34
+
35
+ def __init__(self, tokenizer: PreTrainedTokenizerBase, enable_reasoning=False):
36
+ super().__init__(tokenizer)
37
+
38
+ # initialize properties used for state when parsing tool calls in
39
+ # streaming mode
40
+ self.prev_tool_call_arr: List[Dict] = []
41
+ self.current_tool_id: int = -1
42
+ self.current_tool_name_sent: bool = False
43
+ self.streamed_args_for_tool: List[str] = [
44
+ ] # map what has been streamed for each tool so far to a list
45
+
46
+ self.tool_call_start_token = "[unused11]"
47
+ self.tool_call_end_token = "[unused12]"
48
+ self.pattern = re.escape(self.tool_call_start_token) \
49
+ + "(.*?)" + re.escape(self.tool_call_end_token)
50
+ self.tool_call_regex = re.compile(self.pattern, re.DOTALL)
51
+
52
+
53
+ self.tool_call_start_token_id = self.vocab.get(
54
+ self.tool_call_start_token)
55
+ self.tool_call_end_token_id = self.vocab.get(
56
+ self.tool_call_end_token)
57
+
58
+
59
+ if (self.tool_call_start_token_id is None
60
+ or self.tool_call_end_token_id is None):
61
+ raise RuntimeError(
62
+ "Pangu Tool parser could not locate tool calls start/end "
63
+ "tokens in the tokenizer!")
64
+ self.is_complete = []
65
+ self.text_after_start_token = ""
66
+
67
+
68
+ def extract_tool_calls(
69
+ self, model_output: str,
70
+ request: ChatCompletionRequest
71
+ ) -> ExtractedToolCallInformation:
72
+ """
73
+ Extract the tool calls from a complete model response.
74
+ """
75
+ # case -- if a tool call token is not present, return a text response
76
+ if not (self.tool_call_start_token in model_output and \
77
+ model_output.find(self.tool_call_end_token) != -1):
78
+ return ExtractedToolCallInformation(tools_called=False,
79
+ tool_calls=[],
80
+ content=model_output)
81
+
82
+ try:
83
+ raw_function_calls = []
84
+ # use a regex to find the tool call between the tags
85
+ function_call_tuples = self.tool_call_regex.findall(model_output)
86
+
87
+
88
+ # load the JSON, and then use it to build the Function and
89
+ # Tool Call
90
+ for function_call_str in function_call_tuples:
91
+ function_call = json.loads(function_call_str)
92
+ raw_function_calls.extend(function_call)
93
+
94
+
95
+ tool_calls: List[ToolCall] = [
96
+ ToolCall(
97
+ type="function",
98
+ function=FunctionCall(
99
+ name=function_call["name"],
100
+ # function call args are JSON but as a string
101
+ arguments=json.dumps(function_call["arguments"] \
102
+ if "arguments" in function_call \
103
+ else function_call["parameters"], ensure_ascii=False)))
104
+ for function_call in raw_function_calls
105
+ ]
106
+ content = model_output[:model_output.
107
+ find(self.tool_call_start_token)]
108
+
109
+ # get any content before the tool call
110
+ ret = ExtractedToolCallInformation(tools_called=True,
111
+ tool_calls=tool_calls,
112
+ content=content if content else None)
113
+
114
+ return ret
115
+
116
+ except Exception:
117
+ logger.exception("Error in extracting tool call from response.")
118
+ # return information to just treat the tool call as regular JSON
119
+ return ExtractedToolCallInformation(tools_called=False,
120
+ tool_calls=[],
121
+ content=model_output)
122
+
123
+ def extract_tool_calls_streaming(
124
+ self,
125
+ previous_text: str,
126
+ current_text: str,
127
+ delta_text: str,
128
+ previous_token_ids: Sequence[int],
129
+ current_token_ids: Sequence[int],
130
+ delta_token_ids: Sequence[int],
131
+ request: ChatCompletionRequest,
132
+ ) -> Union[DeltaMessage, None]:
133
+
134
+ if (self.tool_call_end_token_id in delta_token_ids
135
+ and len(delta_token_ids) == 1):
136
+ # if it's the only token, return None, so we don't send a chat
137
+ # completion and don't send a control token
138
+ return None
139
+
140
+ if (self.tool_call_end_token in current_text
141
+ and self.tool_call_end_token not in delta_text):
142
+ return DeltaMessage(content=delta_text)
143
+
144
+ if self.tool_call_start_token not in current_text:
145
+ return DeltaMessage(content=delta_text)
146
+
147
+ if self.tool_call_start_token in delta_text:
148
+ texts = delta_text.split(self.tool_call_start_token)
149
+ text_before_start_token = texts[0]
150
+ if text_before_start_token:
151
+ return DeltaMessage(content=text_before_start_token)
152
+
153
+ if (self.tool_call_start_token_id in delta_token_ids
154
+ and len(delta_token_ids) == 1):
155
+ # if it's the only token, return None, so we don't send a chat
156
+ # completion and don't send a control token
157
+ return None
158
+
159
+ # bit mask flags for partial JSON parsing. If the name hasn't been
160
+ # sent yet, don't allow sending
161
+ # an incomplete string since OpenAI only ever (as far as I have
162
+ # seen) allows sending the entire tool/ function name at once.
163
+ flags = Allow.ALL if self.current_tool_name_sent \
164
+ else Allow.ALL & ~Allow.STR
165
+ try:
166
+
167
+ tool_call_portion = current_text.split(
168
+ self.tool_call_start_token)[-1].split(self.tool_call_end_token)[0]
169
+ try:
170
+ tool_call_arr: list[dict] = partial_json_parser.loads(
171
+ tool_call_portion, flags)
172
+
173
+ self.is_complete.append(
174
+ is_complete_json(tool_call_portion))
175
+ except partial_json_parser.core.exceptions.MalformedJSON:
176
+ logger.debug('not enough tokens to parse into JSON yet')
177
+ return None
178
+
179
+ # select as the current tool call the one we're on the state at
180
+ current_tool_call: dict = tool_call_arr[self.current_tool_id] \
181
+ if len(tool_call_arr) > 0 else {}
182
+
183
+ # case -- if no tokens have been streamed for the tool, e.g.
184
+ # only the array brackets, stream nothing
185
+ if len(tool_call_arr) == 0:
186
+ return None
187
+
188
+ # case: we are starting a new tool in the array
189
+ # -> array has > 0 length AND length has moved past cursor
190
+ elif (len(tool_call_arr) > 0
191
+ and len(tool_call_arr) > self.current_tool_id + 1):
192
+
193
+ # if we're moving on to a new call, first make sure we
194
+ # haven't missed anything in the previous one that was
195
+ # auto-generated due to JSON completions, but wasn't
196
+ # streamed to the client yet.
197
+ if self.current_tool_id >= 0:
198
+ cur_arguments = current_tool_call.get("arguments")
199
+ if cur_arguments:
200
+ cur_args_json = json.dumps(cur_arguments,
201
+ ensure_ascii=False)
202
+ sent = len(
203
+ self.streamed_args_for_tool[self.current_tool_id])
204
+ argument_diff = cur_args_json[sent:]
205
+
206
+ logger.debug("got arguments diff: %s", argument_diff)
207
+ delta = DeltaMessage(tool_calls=[
208
+ DeltaToolCall(index=self.current_tool_id,
209
+ function=DeltaFunctionCall(
210
+ arguments=argument_diff).
211
+ model_dump(exclude_none=True))
212
+ ])
213
+ self.streamed_args_for_tool[
214
+ self.current_tool_id] += argument_diff
215
+ else:
216
+ delta = None
217
+ else:
218
+ delta = None
219
+ # re-set stuff pertaining to progress in the current tool
220
+ self.current_tool_id = len(tool_call_arr) - 1
221
+ self.current_tool_name_sent = False
222
+ self.streamed_args_for_tool.append("")
223
+ self.is_complete = []
224
+ logger.debug("starting on new tool %d", self.current_tool_id)
225
+ return delta
226
+
227
+ # if the current tool name hasn't been sent, send if available
228
+ # - otherwise send nothing
229
+ elif not self.current_tool_name_sent:
230
+ function_name = current_tool_call.get("name")
231
+ if function_name:
232
+ delta = DeltaMessage(tool_calls=[
233
+ DeltaToolCall(index=self.current_tool_id,
234
+ type="function",
235
+ id=random_tool_call_id(),
236
+ function=DeltaFunctionCall(
237
+ name=function_name).model_dump(
238
+ exclude_none=True))
239
+ ])
240
+ self.current_tool_name_sent = True
241
+ else:
242
+ delta = None
243
+
244
+ # now we know we're on the same tool call and we're streaming
245
+ # arguments
246
+ else:
247
+ cur_arguments = current_tool_call.get("arguments")
248
+ delta = None
249
+ if (self.is_complete[-1] and not cur_arguments
250
+ and not self.streamed_args_for_tool[-1]):
251
+ argument_diff = "{}"
252
+ delta = DeltaMessage(tool_calls=[
253
+ DeltaToolCall(index=self.current_tool_id,
254
+ function=DeltaFunctionCall(
255
+ arguments=argument_diff).
256
+ model_dump(exclude_none=True))
257
+ ])
258
+ self.streamed_args_for_tool[
259
+ self.current_tool_id] += argument_diff
260
+
261
+ if cur_arguments:
262
+ sent = len(
263
+ self.streamed_args_for_tool[self.current_tool_id])
264
+ cur_args_json = json.dumps(cur_arguments,
265
+ ensure_ascii=False)
266
+ prev_arguments = self.prev_tool_call_arr[
267
+ self.current_tool_id].get("arguments")
268
+
269
+ argument_diff = None
270
+ if self.is_complete[-1]:
271
+ argument_diff = cur_args_json[sent:]
272
+ elif prev_arguments:
273
+ prev_args_json = json.dumps(prev_arguments,
274
+ ensure_ascii=False)
275
+ if cur_args_json != prev_args_json:
276
+
277
+ prefix = find_common_prefix(
278
+ prev_args_json, cur_args_json)
279
+ argument_diff = prefix[sent:]
280
+
281
+ if argument_diff is not None:
282
+ delta = DeltaMessage(tool_calls=[
283
+ DeltaToolCall(index=self.current_tool_id,
284
+ function=DeltaFunctionCall(
285
+ arguments=argument_diff).
286
+ model_dump(exclude_none=True))
287
+ ])
288
+ self.streamed_args_for_tool[
289
+ self.current_tool_id] += argument_diff
290
+
291
+
292
+ self.prev_tool_call_arr = tool_call_arr
293
+ return delta
294
+
295
+ except Exception:
296
+ logger.exception("Error trying to handle streaming tool call.")
297
+ logger.debug(
298
+ "Skipping chunk as a result of tool streaming extraction "
299
+ "error")
300
+ return None
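For the non-streaming case the parser pulls every `[unused11]...[unused12]` span, parses the span as a JSON list, and flattens the calls. A standalone illustration of that regex step (the sample payload is invented for the example):

import json
import re

TOOL_START, TOOL_END = "[unused11]", "[unused12]"
TOOL_CALL_RE = re.compile(
    re.escape(TOOL_START) + "(.*?)" + re.escape(TOOL_END), re.DOTALL)

sample = (
    "Let me check the weather."
    '[unused11][{"name": "get_weather", "arguments": {"city": "Shenzhen"}}]'
    "[unused12]"
)

calls = []
for span in TOOL_CALL_RE.findall(sample):
    calls.extend(json.loads(span))          # each span holds a JSON list

print([call["name"] for call in calls])     # ['get_weather']
print(sample[:sample.find(TOOL_START)])     # content before the tool call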
inference/vllm_ascend/envs.py ADDED
@@ -0,0 +1,153 @@
1
+ #
2
+ # Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
3
+ # This file is a part of the vllm-ascend project.
4
+ #
5
+ # This file is mainly Adapted from vllm-project/vllm/vllm/envs.py
6
+ # Copyright 2023 The vLLM team.
7
+ #
8
+ # Licensed under the Apache License, Version 2.0 (the "License");
9
+ # you may not use this file except in compliance with the License.
10
+ # You may obtain a copy of the License at
11
+ #
12
+ # http://www.apache.org/licenses/LICENSE-2.0
13
+ #
14
+ # Unless required by applicable law or agreed to in writing, software
15
+ # distributed under the License is distributed on an "AS IS" BASIS,
16
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17
+ # See the License for the specific language governing permissions and
18
+ # limitations under the License.
19
+ #
20
+
21
+ import os
22
+ from typing import Any, Callable, Dict
23
+
24
+ # The begin-* and end* here are used by the documentation generator
25
+ # to extract the used env vars.
26
+
27
+ # begin-env-vars-definition
28
+
29
+ env_variables: Dict[str, Callable[[], Any]] = {
30
+ # max compile thread number for package building. Usually, it is set to
31
+ # the number of CPU cores. If not set, the default value is None, which
32
+ # means all number of CPU cores will be used.
33
+ "MAX_JOBS":
34
+ lambda: os.getenv("MAX_JOBS", None),
35
+ # The build type of the package. It can be one of the following values:
36
+ # Release, Debug, RelWithDebugInfo. If not set, the default value is Release.
37
+ "CMAKE_BUILD_TYPE":
38
+ lambda: os.getenv("CMAKE_BUILD_TYPE"),
39
+ # Whether to compile custom kernels. If not set, the default value is True.
40
+ # If set to False, the custom kernels will not be compiled. Please note that
41
+ # the sleep mode feature will be disabled as well if custom kernels are not
42
+ # compiled.
43
+ "COMPILE_CUSTOM_KERNELS":
44
+ lambda: bool(int(os.getenv("COMPILE_CUSTOM_KERNELS", "1"))),
45
+ # The CXX compiler used for compiling the package. If not set, the default
46
+ # value is None, which means the system default CXX compiler will be used.
47
+ "CXX_COMPILER":
48
+ lambda: os.getenv("CXX_COMPILER", None),
49
+ # The C compiler used for compiling the package. If not set, the default
50
+ # value is None, which means the system default C compiler will be used.
51
+ "C_COMPILER":
52
+ lambda: os.getenv("C_COMPILER", None),
53
+ # The version of the Ascend chip. If not set, the default value is
54
+ # ASCEND910B1. It's used for package building. Please make sure that the
55
+ # version is correct.
56
+ "SOC_VERSION":
57
+ lambda: os.getenv("SOC_VERSION", "ASCEND910B1"),
58
+ # If set, vllm-ascend will print verbose logs during compilation
59
+ "VERBOSE":
60
+ lambda: bool(int(os.getenv('VERBOSE', '0'))),
61
+ # The home path for CANN toolkit. If not set, the default value is
62
+ # /usr/local/Ascend/ascend-toolkit/latest
63
+ "ASCEND_HOME_PATH":
64
+ lambda: os.getenv("ASCEND_HOME_PATH", None),
65
+ # The path for HCCN Tool, the tool will be called by disaggregated prefilling
66
+ # case.
67
+ "HCCN_PATH":
68
+ lambda: os.getenv("HCCN_PATH", "/usr/local/Ascend/driver/tools/hccn_tool"),
69
+ # The path for HCCL library, it's used by pyhccl communicator backend. If
70
+ # not set, the default value is libhccl.so。
71
+ "HCCL_SO_PATH":
72
+ # The prefill device id for disaggregated prefilling case.
73
+ lambda: os.environ.get("HCCL_SO_PATH", None),
74
+ "PROMPT_DEVICE_ID":
75
+ lambda: os.getenv("PROMPT_DEVICE_ID", None),
76
+ # The decode device id for disaggregated prefilling case.
77
+ "DECODE_DEVICE_ID":
78
+ lambda: os.getenv("DECODE_DEVICE_ID", None),
79
+ # The port number for llmdatadist communication. If not set, the default
80
+ # value is 26000.
81
+ "LLMDATADIST_COMM_PORT":
82
+ lambda: os.getenv("LLMDATADIST_COMM_PORT", "26000"),
83
+ # The wait time for llmdatadist sync cache. If not set, the default value is
84
+ # 5000ms.
85
+ "LLMDATADIST_SYNC_CACHE_WAIT_TIME":
86
+ lambda: os.getenv("LLMDATADIST_SYNC_CACHE_WAIT_TIME", "5000"),
87
+ # The version of vllm is installed. This value is used for developers who
88
+ # installed vllm from source locally. In this case, the version of vllm is
89
+ # usually changed. For example, if the version of vllm is "0.9.0", but when
90
+ # it's installed from source, the version of vllm is usually set to "0.9.1".
91
+ # In this case, developers need to set this value to "0.9.0" to make sure
92
+ # that the correct package is installed.
93
+ "VLLM_VERSION":
94
+ lambda: os.getenv("VLLM_VERSION", None),
95
+ # Whether to enable the trace recompiles from pytorch.
96
+ "VLLM_ASCEND_TRACE_RECOMPILES":
97
+ lambda: bool(int(os.getenv("VLLM_ASCEND_TRACE_RECOMPILES", '0'))),
98
+ # Whether to enable fused_experts_allgather_ep. MoeInitRoutingV3 and
99
+ # GroupedMatmulFinalizeRouting operators are combined to implement EP.
100
+ "VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP":
101
+ lambda: bool(int(os.getenv("VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP", '0'))
102
+ ),
103
+ "VLLM_ASCEND_ENABLE_DBO":
104
+ lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_DBO", '0'))),
105
+ # Whether to enable the model execute time observe profile. Disable it when
106
+ # running vllm ascend in production environment.
107
+ "VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE":
108
+ lambda: bool(int(os.getenv("VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE", '0'))
109
+ ),
110
+ # MOE_ALL2ALL_BUFFER:
111
+ # 0: default, normal init.
112
+ # 1: enable moe_all2all_buffer.
113
+ "MOE_ALL2ALL_BUFFER":
114
+ lambda: bool(int(os.getenv("MOE_ALL2ALL_BUFFER", '0'))),
115
+ # Some models are optimized by vllm ascend. In some cases, e.g. rlhf
116
+ # training, the optimized model may not be suitable. In this case, set this
117
+ # value to False to disable the optimized model.
118
+ "USE_OPTIMIZED_MODEL":
119
+ lambda: bool(int(os.getenv('USE_OPTIMIZED_MODEL', '1'))),
120
+ # SELECT_GATING_TOPK_SOTFMAX_EXPERTS is the equivalent of select_experts in non-quantized scenarios.
121
+ # In theory, it should have better performance than select_experts.
122
+ # Subsequent versions will remove the SELECT_GATING_TOPK_SOTFMAX_EXPERTS tag and use it as the default mode.
123
+ "SELECT_GATING_TOPK_SOTFMAX_EXPERTS":
124
+ lambda: bool(int(os.getenv("SELECT_GATING_TOPK_SOTFMAX_EXPERTS", '0'))),
125
+ # The tolerance of the kv cache size, if the difference between the
126
+ # actual kv cache size and the cached kv cache size is less than this value,
127
+ # then the cached kv cache size will be used.
128
+ "VLLM_ASCEND_KV_CACHE_MEGABYTES_FLOATING_TOLERANCE":
129
+ lambda: int(
130
+ os.getenv("VLLM_ASCEND_KV_CACHE_MEGABYTES_FLOATING_TOLERANCE", 64)),
131
+ # Whether to enable the topk optimization. It's disabled by default for experimental support
132
+ # We'll make it enabled by default in the future.
133
+ "VLLM_ASCEND_ENABLE_TOPK_TOPP_OPTIMIZATION":
134
+ lambda: bool(
135
+ int(os.getenv("VLLM_ASCEND_ENABLE_TOPK_TOPP_OPTIMIZATION", '0'))),
136
+ # Whether to enable top n sigma sampling
137
+ "VLLM_ASCEND_ENABLE_TOP_N_SIGMA":
138
+ lambda: bool(
139
+ int(os.getenv("VLLM_ASCEND_ENABLE_TOP_N_SIGMA", '0'))),
140
+ }
141
+
142
+ # end-env-vars-definition
143
+
144
+
145
+ def __getattr__(name: str):
146
+ # lazy evaluation of environment variables
147
+ if name in env_variables:
148
+ return env_variables[name]()
149
+ raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
150
+
151
+
152
+ def __dir__():
153
+ return list(env_variables.keys())
inference/vllm_ascend/models/__init__.py ADDED
@@ -0,0 +1,68 @@
1
+ from vllm import ModelRegistry
2
+
3
+ import vllm_ascend.envs as envs
4
+
5
+
6
+ def register_model():
7
+ from .deepseek_dbo import CustomDeepseekDBOForCausalLM # noqa: F401
8
+ from .deepseek_mtp import CustomDeepSeekMTP # noqa: F401
9
+ from .deepseek_v2 import CustomDeepseekV2ForCausalLM # noqa: F401
10
+ from .deepseek_v2 import CustomDeepseekV3ForCausalLM # noqa: F401
11
+ from .open_pangu import PanguUltraMoEForCausalLM # noqa: F401
12
+ from .open_pangu import PanguEmbeddedForCausalLM # noqa: F401
13
+ from .qwen2_5_vl import \
14
+ AscendQwen2_5_VLForConditionalGeneration # noqa: F401
15
+ from .qwen2_vl import AscendQwen2VLForConditionalGeneration # noqa: F401
16
+
17
+ ModelRegistry.register_model(
18
+ "DeepSeekMTPModel",
19
+ "vllm_ascend.models.deepseek_mtp:CustomDeepSeekMTP")
20
+
21
+ ModelRegistry.register_model(
22
+ "Qwen2VLForConditionalGeneration",
23
+ "vllm_ascend.models.qwen2_vl:AscendQwen2VLForConditionalGeneration")
24
+
25
+ if envs.USE_OPTIMIZED_MODEL:
26
+ ModelRegistry.register_model(
27
+ "Qwen2_5_VLForConditionalGeneration",
28
+ "vllm_ascend.models.qwen2_5_vl:AscendQwen2_5_VLForConditionalGeneration"
29
+ )
30
+ else:
31
+ ModelRegistry.register_model(
32
+ "Qwen2_5_VLForConditionalGeneration",
33
+ "vllm_ascend.models.qwen2_5_vl_without_padding:AscendQwen2_5_VLForConditionalGeneration_Without_Padding"
34
+ )
35
+
36
+ if envs.VLLM_ASCEND_ENABLE_DBO:
37
+ ModelRegistry.register_model(
38
+ "DeepseekV2ForCausalLM",
39
+ "vllm_ascend.models.deepseek_dbo:CustomDeepseekDBOForCausalLM")
40
+
41
+ ModelRegistry.register_model(
42
+ "DeepseekV3ForCausalLM",
43
+ "vllm_ascend.models.deepseek_dbo:CustomDeepseekDBOForCausalLM")
44
+
45
+ else:
46
+ ModelRegistry.register_model(
47
+ "DeepseekV2ForCausalLM",
48
+ "vllm_ascend.models.deepseek_v2:CustomDeepseekV2ForCausalLM")
49
+
50
+ ModelRegistry.register_model(
51
+ "DeepseekV3ForCausalLM",
52
+ "vllm_ascend.models.deepseek_v2:CustomDeepseekV3ForCausalLM")
53
+
54
+ ModelRegistry.register_model(
55
+ "Qwen3MoeForCausalLM",
56
+ "vllm_ascend.models.qwen3_moe:CustomQwen3MoeForCausalLM")
57
+
58
+ ModelRegistry.register_model(
59
+ "PanguProMoEForCausalLM",
60
+ "vllm_ascend.models.pangu_moe:PanguProMoEForCausalLM")
61
+
62
+ ModelRegistry.register_model(
63
+ "PanguUltraMoEForCausalLM",
64
+ "vllm_ascend.models.open_pangu:PanguUltraMoEForCausalLM")
65
+
66
+ ModelRegistry.register_model(
67
+ "PanguEmbeddedForCausalLM",
68
+ "vllm_ascend.models.open_pangu:PanguEmbeddedForCausalLM")
inference/vllm_ascend/models/open_pangu.py ADDED
@@ -0,0 +1,1127 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ # Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
3
+ # Copyright 2023 The vLLM team.
4
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
5
+ #
6
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
7
+ # and OPT implementations in this library. It has been modified from its
8
+ # original forms to accommodate minor architectural differences compared
9
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
10
+ #
11
+ # Licensed under the Apache License, Version 2.0 (the "License");
12
+ # you may not use this file except in compliance with the License.
13
+ # You may obtain a copy of the License at
14
+ #
15
+ # http://www.apache.org/licenses/LICENSE-2.0
16
+ #
17
+ # Unless required by applicable law or agreed to in writing, software
18
+ # distributed under the License is distributed on an "AS IS" BASIS,
19
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20
+ # See the License for the specific language governing permissions and
21
+ # limitations under the License.
22
+
23
+ from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
24
+ import torch
25
+ import torch_npu
26
+ import vllm.envs as envs
27
+ from torch import nn
28
+ from transformers import PretrainedConfig
29
+ from vllm.compilation.decorators import support_torch_compile
30
+ from vllm.attention import Attention, AttentionMetadata, AttentionType
31
+ from vllm.config import CacheConfig, ModelConfig, VllmConfig
32
+ from vllm.distributed import (get_tensor_model_parallel_rank,
33
+ get_tensor_model_parallel_world_size,
34
+ get_tp_group, split_tensor_along_last_dim,
35
+ tensor_model_parallel_all_gather,
36
+ tensor_model_parallel_all_reduce,
37
+ tensor_model_parallel_reduce_scatter)
38
+ from vllm.distributed.parallel_state import get_dp_group
39
+ from vllm.forward_context import get_forward_context
40
+ from vllm.model_executor.layers.activation import SiluAndMul
41
+ from vllm.model_executor.layers.layernorm import RMSNorm
42
+ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
43
+ MergedColumnParallelLinear,
44
+ ReplicatedLinear,
45
+ RowParallelLinear,
46
+ UnquantizedLinearMethod,
47
+ QKVParallelLinear)
48
+ from vllm.model_executor.layers.logits_processor import LogitsProcessor
49
+ from vllm.model_executor.layers.quantization import QuantizationConfig
50
+ from vllm.model_executor.layers.rotary_embedding import get_rope, _rotate_gptj
51
+ from vllm.model_executor.layers.sampler import get_sampler
52
+ from vllm.model_executor.layers.vocab_parallel_embedding import (
53
+ ParallelLMHead, VocabParallelEmbedding)
54
+ from vllm.model_executor.model_loader.weight_utils import (
55
+ default_weight_loader, maybe_remap_kv_scale_name)
56
+ from vllm.model_executor.models.utils import (
57
+ make_layers, maybe_prefix, extract_layer_index)
58
+ from vllm_ascend.ascend_config import get_ascend_config
59
+ from vllm_ascend.distributed.parallel_state import get_ep_group
60
+ from vllm_ascend.ops.fused_moe import AscendFusedMoE
61
+ from vllm_ascend.quantization.quant_config import AscendLinearMethod
62
+ from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod
63
+ from vllm_ascend.utils import dispose_tensor, npu_prefetch, get_fused_moe_state, FusedMoEState
64
+ from vllm.model_executor.sampling_metadata import SamplingMetadata
65
+
66
+
67
+ class OpenPanguMergedReplicatedLinear(ReplicatedLinear):
68
+
69
+ def __init__(
70
+ self,
71
+ input_size: int,
72
+ output_sizes: list[int],
73
+ bias: bool = True,
74
+ quant_config: Optional[QuantizationConfig] = None,
75
+ prefix: str = "",
76
+ ):
77
+ self.output_sizes = output_sizes
78
+ super().__init__(input_size,
79
+ sum(output_sizes),
80
+ bias=bias,
81
+ quant_config=quant_config,
82
+ prefix=prefix)
83
+
84
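+ # Copy one shard (e.g. gate_proj or up_proj) into its slice of the merged replicated weight, selected by loaded_shard_id.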
+ def weight_loader(self, param: torch.nn.Parameter,
85
+ loaded_weight: torch.Tensor, loaded_shard_id: int):
86
+ # GGUF format is not supported yet.
87
+ if getattr(param, "is_gguf_weight", False) or getattr(param, "is_gguf_weight_type", False):
88
+ raise ValueError('GGUF format is not supported yet.')
89
+ if loaded_shard_id >= len(self.output_sizes):
90
+ raise ValueError(f'loaded_shard_id {loaded_shard_id} >= len(self.output_sizes) {len(self.output_sizes)}.')
91
+ shard_offset = sum(self.output_sizes[:loaded_shard_id])
92
+ shard_size = self.output_sizes[loaded_shard_id]
93
+ shard = param.data.narrow(param.output_dim, shard_offset, shard_size)
94
+ if shard.size() != loaded_weight.size():
95
+ raise ValueError(f"Tried to load weights of size {loaded_weight.size()} "
96
+ f"to a parameter shard of id {loaded_shard_id} size {shard.size()}.")
97
+ shard.copy_(loaded_weight)
98
+
99
+
100
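+ # RowParallelLinear variant that, during decode steps where the token count divides the TP size, replaces the final all-reduce with a reduce-scatter along the token dimension.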
+ class OpenPanguRowParallelLinearReplaceAllreduce(RowParallelLinear):
101
+
102
+ def forward(
103
+ self,
104
+ input_,
105
+ is_prefill=True
106
+ ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[nn.Parameter]]]:
107
+ if self.input_is_parallel:
108
+ input_parallel = input_
109
+ else:
110
+ tp_rank = get_tensor_model_parallel_rank()
111
+ splitted_input = split_tensor_along_last_dim(
112
+ input_, num_partitions=self.tp_size)
113
+ input_parallel = splitted_input[tp_rank].contiguous()
114
+
115
+ # Matrix multiply.
116
+ if self.quant_method is None:
117
+ raise ValueError('self.quant_method is None.')
118
+ # Only fuse bias add into GEMM for rank 0 (this ensures that
119
+ # bias will not get added more than once in TP>1 case)
120
+ bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
121
+ output_parallel = self.quant_method.apply(self,
122
+ input_parallel,
123
+ bias=bias_)
124
+ if self.reduce_results and self.tp_size > 1:
125
+ if not is_prefill and output_parallel.shape[0] % self.tp_size == 0:
126
+ output = tensor_model_parallel_reduce_scatter(output_parallel,
127
+ dim=0)
128
+ else:
129
+ output = tensor_model_parallel_all_reduce(output_parallel)
130
+ else:
131
+ output = output_parallel
132
+
133
+ output_bias = self.bias if self.skip_bias_add else None
134
+
135
+ if not self.return_bias:
136
+ return output
137
+ return output, output_bias
138
+
139
+
140
+ class OpenPanguRowParallelLinear(RowParallelLinear):
141
+
142
+ def forward(
143
+ self,
144
+ input_,
145
+ is_prefill=True
146
+ ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[nn.Parameter]]]:
147
+ return super().forward(input_)
148
+
149
+
150
+ class OpenPanguRotaryEmbedding(nn.Module):
151
+ def __init__(self,
152
+ head_size: int,
153
+ rotary_dim: int,
154
+ max_position_embeddings: int,
155
+ base: float,
156
+ ):
157
+ super().__init__()
158
+ self.dim = rotary_dim
159
+ self.max_position_embeddings = max_position_embeddings
160
+ self.base = base
161
+ self._set_cos_sin_cache(
162
+ seq_len=max_position_embeddings,
163
+ device='npu',
164
+ dtype=torch.get_default_dtype(),
165
+ )
166
+
167
+ def _set_cos_sin_cache(self,
168
+ seq_len: int,
169
+ device: str,
170
+ dtype: torch.dtype
171
+ ):
172
+ self.max_seq_len = seq_len
173
+ inv_freq = 1.0 / (
174
+ self.base
175
+ ** (torch.arange(0, self.dim, 2, dtype=torch.float32, device=device) / self.dim)
176
+ )
177
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
178
+ t = torch.arange(seq_len, device=device, dtype=torch.float32)
179
+ freqs = torch.outer(t, inv_freq)
180
+ emb = torch.cat((freqs, freqs), dim=-1)
181
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
182
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
183
+
184
+ def forward(self,
185
+ positions: torch.Tensor,
186
+ query: torch.Tensor,
187
+ key: torch.Tensor,
188
+ offsets: Optional[torch.Tensor] = None,
189
+ max_seq_len: Optional[int] = None,
190
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
191
+ if max_seq_len is not None and max_seq_len > self.max_seq_len:
192
+ self._set_cos_sin_cache(max_seq_len, query.device, query.dtype)
193
+ idx = torch.add(positions, offsets) if offsets is not None else positions
194
+ cos = self.cos_cached[idx]
195
+ sin = self.sin_cached[idx]
196
+ # Adapt: adapt cos and sin shape
197
+ cos = cos.view(-1, 1, cos.shape[-1])
198
+ sin = sin.view(-1, 1, sin.shape[-1])
199
+ # Adapt end.
200
+ query_rot = query * cos + _rotate_gptj(query) * sin
201
+ if key is not None:
202
+ key = key * cos + _rotate_gptj(key) * sin
203
+ return query_rot, key
204
+
205
+
206
+ class OpenPanguSiluAndMul(SiluAndMul):
207
+
208
+ def __init__(self,
209
+ *,
210
+ weight_scale: Optional[Callable[[], torch.Tensor]] = None):
211
+ super().__init__()
212
+ self.weight_scale = weight_scale
213
+
214
+ def forward_oot(self, x: Union[torch.Tensor, Tuple[torch.Tensor,
215
+ torch.Tensor]]):
216
+ if isinstance(x, tuple):
217
+ if self.weight_scale is None:
218
+ raise ValueError('self.weight_scale is None.')
219
+ quantized_x, dynamic_scale = x
220
+ return torch_npu.npu_dequant_swiglu_quant(
221
+ x=quantized_x,
222
+ weight_scale=self.weight_scale(),
223
+ activation_scale=dynamic_scale,
224
+ activate_left=True,
225
+ quant_mode=1)
226
+ else:
227
+ return super().forward_oot(x)
228
+
229
+
230
+ def check_ffn_act_fn(act_fn: str):
231
+ if act_fn != "silu":
232
+ raise ValueError(
233
+ f"Unsupported activation: {act_fn}. Only silu is supported for now.")
234
+
235
+
236
+ class OpenPanguMLP(nn.Module):
237
+
238
+ def __init__(
239
+ self,
240
+ hidden_size: int,
241
+ intermediate_size: int,
242
+ hidden_act: str,
243
+ quant_config: Optional[QuantizationConfig] = None,
244
+ bias: bool = False,
245
+ reduce_results: bool = True,
246
+ force_replicate: bool = False,
247
+ prefix: str = "",
248
+ ) -> None:
249
+ super().__init__()
250
+ if not force_replicate:
251
+ self.gate_up_proj = MergedColumnParallelLinear(
252
+ hidden_size, [intermediate_size] * 2,
253
+ bias=bias,
254
+ quant_config=quant_config,
255
+ prefix=f"{prefix}.gate_up_proj")
256
+ self.down_proj = RowParallelLinear(intermediate_size,
257
+ hidden_size,
258
+ bias=bias,
259
+ quant_config=quant_config,
260
+ reduce_results=reduce_results,
261
+ prefix=f"{prefix}.down_proj")
262
+ else:
263
+ self.gate_up_proj = OpenPanguMergedReplicatedLinear(
264
+ hidden_size, [intermediate_size] * 2,
265
+ bias=bias,
266
+ quant_config=quant_config,
267
+ prefix=f"{prefix}.gate_up_proj")
268
+ self.down_proj = ReplicatedLinear(intermediate_size,
269
+ hidden_size,
270
+ bias=bias,
271
+ quant_config=quant_config,
272
+ prefix=f"{prefix}.down_proj")
273
+
274
+ check_ffn_act_fn(hidden_act)
275
+
276
+ quant_method = self.gate_up_proj.quant_method
277
+ if isinstance(quant_method, UnquantizedLinearMethod):
278
+ self.act_fn = OpenPanguSiluAndMul()
279
+ elif (isinstance(quant_method, AscendLinearMethod) and isinstance(
280
+ quant_method.quant_method, AscendW8A8DynamicLinearMethod)):
281
+ # TODO(sdmyzlp): Currently preserved as before:
282
+ # 1. The only quantization supported for silu is W8A8Dynamic
283
+ # 2. Output dtype of gate_up/down is fixed to be int32/bfloat16
284
+ #
285
+ # Maybe one can implement a better and more general configuration
286
+ # scheme, e.g. by somehow passing around the tweaked `quant_config`
287
+ self.act_fn = OpenPanguSiluAndMul(
288
+ # Use lazy binding, for `weight_scale_fp32` is accessible
289
+ # only after `process_weights_after_loading`.
290
+ weight_scale=lambda: self.gate_up_proj.weight_scale_fp32)
291
+ # To be consumed by AscendW8A8DynamicLinearMethod.apply()
292
+ self.gate_up_proj._ascend_quant_config = {
293
+ "output_dtype": torch.int32,
294
+ "pertoken_scale": False,
295
+ "return_scale": True,
296
+ }
297
+ self.down_proj._ascend_quant_config = {
298
+ "output_dtype": torch.bfloat16,
299
+ "pertoken_scale": True,
300
+ "return_scale": False,
301
+ }
302
+ else:
303
+ raise NotImplementedError(
304
+ f"Quantization with [{type(quant_method)}] is NOT supported")
305
+
306
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
307
+ return self.down_proj(self.act_fn(self.gate_up_proj(x)[0]))[0]
308
+
309
+
310
+ class OpenPanguMoE(nn.Module):
311
+
312
+ top_k: int
313
+
314
+ def __init__(
315
+ self,
316
+ config: PretrainedConfig,
317
+ quant_config: Optional[QuantizationConfig] = None,
318
+ prefix: str = "",
319
+ ):
320
+ super().__init__()
321
+ ascend_config = get_ascend_config()
322
+ self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
323
+ self.enable_multistream_moe = \
324
+ ascend_config.torchair_graph_config.enable_multistream_moe
325
+ self.routed_scaling_factor = config.routed_scaling_factor
326
+ check_ffn_act_fn(config.hidden_act)
327
+
328
+ self.gate = ReplicatedLinear(config.hidden_size,
329
+ config.num_routed_experts,
330
+ bias=False,
331
+ quant_config=None,
332
+ prefix=f"{prefix}.gate")
333
+
334
+ self.experts = AscendFusedMoE(
335
+ num_experts=config.num_routed_experts,
336
+ top_k=config.num_experts_per_tok,
337
+ hidden_size=config.hidden_size,
338
+ intermediate_size=config.moe_intermediate_size,
339
+ reduce_results=False,
340
+ renormalize=config.norm_topk_prob,
341
+ quant_config=quant_config,
342
+ use_grouped_topk=True,
343
+ num_expert_group=1,
344
+ topk_group=1,
345
+ prefix=f"{prefix}.experts",
346
+ scoring_func='sigmoid',
347
+ e_score_correction_bias=None)
348
+
349
+ if config.num_shared_experts is not None:
350
+ self.all_reduce_merge = self.experts.all_reduce_merge
351
+ reduce_results = not self.all_reduce_merge
352
+ intermediate_size = (config.moe_intermediate_size * config.num_shared_experts)
353
+ self.shared_experts = OpenPanguMLP(
354
+ hidden_size=config.hidden_size,
355
+ intermediate_size=intermediate_size,
356
+ hidden_act=config.hidden_act,
357
+ quant_config=quant_config,
358
+ reduce_results=reduce_results,
359
+ force_replicate=self.enable_multistream_moe,
360
+ prefix=f"{prefix}.shared_experts",
361
+ )
362
+ else:
363
+ self.shared_experts = None # type: ignore
364
+
365
+ self.tp_size = get_tensor_model_parallel_world_size()
366
+ self.dp_size = get_dp_group().world_size
367
+ self.tp_group = get_tp_group().device_group
368
+ self.tp_rank = get_tp_group().rank_in_group
369
+ self.ep_group = get_ep_group()
370
+
371
+ self.params_dtype = torch.get_default_dtype()
372
+ self.rm_router_logits = self.experts.rm_router_logits
373
+
374
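+ # Stored on the class so every MoE layer shares the same top_k value.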
+ self.__class__.top_k = config.num_experts_per_tok
375
+
376
+ def forward(self,
377
+ hidden_states: torch.Tensor,
378
+ attn_metadata: Optional[AttentionMetadata] = None,
379
+ replace_allreduce: bool = False) -> torch.Tensor:
380
+
381
+ if attn_metadata is None:
382
+ attn_metadata = get_forward_context().attn_metadata
383
+ # when profile runs, force experts to load balanced tokens
384
+ # to avoid high memory consumption on a single rank.
385
+ # TODO: need a better flag to indicate whether in profile run or not.
386
+ if attn_metadata is None:
387
+ # for profile run
388
+ is_prefill = True
389
+ fused_moe_state = get_fused_moe_state(self.ep_group.world_size, is_prefill, True)
390
+ enable_force_load_balance = fused_moe_state != FusedMoEState.AllGatherEP
391
+ else:
392
+ is_prefill = attn_metadata.num_prefills > 0
393
+ enable_force_load_balance = False
394
+ if hasattr(attn_metadata, 'with_prefill_across_dp'):
395
+ is_prefill = is_prefill or attn_metadata.with_prefill_across_dp
396
+ fused_moe_state = get_fused_moe_state(self.ep_group.world_size, is_prefill, True)
397
+
398
+ # router_logits: (num_tokens, n_experts)
399
+ router_logits = None
400
+ if not self.rm_router_logits or fused_moe_state == FusedMoEState.All2All:
401
+ router_logits, _ = self.gate(hidden_states.float())
402
+
403
+ routed_hidden_states, shared_hidden_states = self.experts(
404
+ hidden_states=hidden_states,
405
+ router_logits=router_logits,
406
+ is_prefill=is_prefill,
407
+ top_k=self.__class__.top_k,
408
+ enable_force_load_balance=enable_force_load_balance,
409
+ shared_experts=self.shared_experts,
410
+ gate=self.gate,
411
+ replace_allreduce=replace_allreduce)
412
+
413
+ if self.all_reduce_merge and fused_moe_state == FusedMoEState.All2All:
414
+ shared_hidden_states = tensor_model_parallel_all_reduce(shared_hidden_states)
415
+ hidden_states = routed_hidden_states * self.routed_scaling_factor + shared_hidden_states
416
+ if self.all_reduce_merge and fused_moe_state != FusedMoEState.All2All:
417
+ # With all_reduce_merge enabled, shared_experts skips the all_reduce inside its MLP; the all_reduce is deferred until both shared and routed experts have finished.
418
+ hidden_states = tensor_model_parallel_all_reduce(hidden_states)
419
+
420
+ return hidden_states
421
+
422
+
423
+ class OpenPanguMLAAttention(nn.Module):
424
+
425
+ def __init__(
426
+ self,
427
+ config: PretrainedConfig,
428
+ hidden_size: int,
429
+ num_heads: int,
430
+ attention_qk_dim: int,
431
+ attention_qk_rope_dim: int,
432
+ attention_v_dim: int,
433
+ attention_q_lora_dim: Optional[int],
434
+ attention_kv_lora_dim: int,
435
+ rope_theta: float = 10000,
436
+ max_position_embeddings: int = 8192,
437
+ cache_config: Optional[CacheConfig] = None,
438
+ quant_config: Optional[QuantizationConfig] = None,
439
+ prefix: str = "",
440
+ ) -> None:
441
+ super().__init__()
442
+ ascend_config = get_ascend_config()
443
+ self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
444
+ self.enable_multistream_mla = ascend_config.torchair_graph_config.enable_multistream_mla
445
+
446
+ self.hidden_size = hidden_size
447
+ self.num_heads = num_heads
448
+ self.attention_qk_dim = attention_qk_dim
449
+ self.attention_qk_rope_dim = attention_qk_rope_dim
450
+ self.qk_head_dim = attention_qk_dim + attention_qk_rope_dim
451
+ self.attention_v_dim = attention_v_dim
452
+ self.attention_q_lora_dim = attention_q_lora_dim
453
+ self.attention_kv_lora_dim = attention_kv_lora_dim
454
+ self.rope_theta = rope_theta
455
+
456
+ tp_size = get_tensor_model_parallel_world_size()
457
+ if num_heads % tp_size != 0:
458
+ raise ValueError(f'num_heads {num_heads} is not divisible by tp_size {tp_size}.')
459
+ self.num_local_heads = num_heads // tp_size
460
+
461
+ self.scaling = self.qk_head_dim**-0.5
462
+ self.max_position_embeddings = max_position_embeddings
463
+
464
+ self.prefix = prefix
465
+ self.debug_layer_idx = int(self.prefix.split(".")[-2])
466
+
467
+ if self.attention_q_lora_dim is not None:
468
+ self.q_a_proj = ReplicatedLinear(self.hidden_size,
469
+ self.attention_q_lora_dim,
470
+ bias=False,
471
+ quant_config=quant_config,
472
+ prefix=f"{prefix}.q_a_proj")
473
+ self.q_a_layernorm = RMSNorm(self.attention_q_lora_dim, eps=config.rms_norm_eps)
474
+ self.q_b_proj = ColumnParallelLinear(attention_q_lora_dim,
475
+ self.num_heads * self.qk_head_dim,
476
+ bias=False,
477
+ quant_config=quant_config,
478
+ prefix=f"{prefix}.q_b_proj")
479
+ else:
480
+ self.q_proj = ColumnParallelLinear(self.hidden_size,
481
+ self.num_heads * self.qk_head_dim,
482
+ bias=False,
483
+ quant_config=quant_config,
484
+ prefix=f"{prefix}.q_proj")
485
+
486
+ self.kv_a_proj_with_mqa = ReplicatedLinear(
487
+ self.hidden_size,
488
+ self.attention_kv_lora_dim + self.attention_qk_rope_dim,
489
+ bias=False,
490
+ quant_config=quant_config,
491
+ prefix=f"{prefix}.kv_a_proj_with_mqa")
492
+ self.kv_a_layernorm = RMSNorm(self.attention_kv_lora_dim,
493
+ eps=config.rms_norm_eps)
494
+ self.kv_b_proj = ColumnParallelLinear(
495
+ self.attention_kv_lora_dim,
496
+ self.num_heads * (self.attention_qk_dim + self.attention_v_dim),
497
+ bias=False,
498
+ quant_config=quant_config,
499
+ prefix=f"{prefix}.kv_b_proj")
500
+ if (getattr(config, 'num_routed_experts', None) is not None
501
+ and self.debug_layer_idx >= config.num_dense_layers and
502
+ ascend_config.torchair_graph_config.enable_multistream_moe):
503
+ self.o_proj = OpenPanguRowParallelLinearReplaceAllreduce(
504
+ self.num_heads * self.attention_v_dim,
505
+ self.hidden_size,
506
+ bias=False,
507
+ quant_config=quant_config,
508
+ prefix=f"{prefix}.o_proj")
509
+ else:
510
+ self.o_proj = OpenPanguRowParallelLinear(
511
+ self.num_heads * self.attention_v_dim,
512
+ self.hidden_size,
513
+ bias=False,
514
+ quant_config=quant_config,
515
+ prefix=f"{prefix}.o_proj")
516
+
517
+ self.rotary_emb = OpenPanguRotaryEmbedding(attention_qk_rope_dim,
518
+ rotary_dim=attention_qk_rope_dim,
519
+ max_position_embeddings=max_position_embeddings,
520
+ base=rope_theta)
521
+
522
+ self.mla_attn = Attention(
523
+ num_heads=self.num_local_heads,
524
+ head_size=self.attention_kv_lora_dim + self.attention_qk_rope_dim,
525
+ scale=self.scaling,
526
+ num_kv_heads=1,
527
+ cache_config=cache_config,
528
+ quant_config=quant_config,
529
+ prefix=f"{prefix}.attn",
530
+ use_mla=True,
531
+ # MLA Args
532
+ q_lora_rank=self.attention_q_lora_dim,
533
+ kv_lora_rank=self.attention_kv_lora_dim,
534
+ qk_nope_head_dim=self.attention_qk_dim,
535
+ qk_rope_head_dim=self.attention_qk_rope_dim,
536
+ qk_head_dim=self.qk_head_dim,
537
+ v_head_dim=self.attention_v_dim,
538
+ rotary_emb=self.rotary_emb,
539
+ q_proj=self.q_proj if self.attention_q_lora_dim is None else self.q_b_proj,
540
+ kv_a_proj_with_mqa=self.kv_a_proj_with_mqa,
541
+ kv_a_layernorm=self.kv_a_layernorm,
542
+ kv_b_proj=self.kv_b_proj,
543
+ o_proj=self.o_proj,
544
+ )
545
+
546
+ def forward(
547
+ self,
548
+ positions: torch.Tensor,
549
+ hidden_states: torch.Tensor,
550
+ kv_cache: Optional[torch.Tensor] = None,
551
+ attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
552
+ enable_multistream_mla = (self.enable_multistream_mla
553
+ and attn_metadata is not None
554
+ and not attn_metadata.with_prefill_across_dp
555
+ and attn_metadata.num_decodes > 0)
556
+ forward_kwargs = {"enable_multistream_mla": enable_multistream_mla}
557
+ if self.attention_q_lora_dim is not None:
558
+ npu_prefetch(self.q_a_proj.weight,
559
+ hidden_states,
560
+ enabled=enable_multistream_mla)
561
+ ckq = self.q_a_proj(hidden_states)[0]
562
+ hidden_states_or_q_c = self.q_a_layernorm(ckq)
563
+ forward_kwargs['ckq'] = ckq
564
+ else:
565
+ hidden_states_or_q_c = hidden_states
566
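+ # Torchair graph mode calls the MLA implementation directly with a preallocated output buffer; the eager path below splits the latent KV projection and goes through the regular attention call.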
+ if self.torchair_graph_enabled:
567
+ if envs.VLLM_USE_V1:
568
+ output_shape = hidden_states.shape
569
+ output = torch.empty(output_shape,
570
+ dtype=hidden_states_or_q_c.dtype,
571
+ device=hidden_states_or_q_c.device)
572
+ forward_kwargs['output'] = output
573
+
574
+ output = self.mla_attn.impl.forward(self.mla_attn,
575
+ hidden_states_or_q_c,
576
+ hidden_states, None, kv_cache,
577
+ attn_metadata,
578
+ **forward_kwargs)
579
+ if envs.VLLM_USE_V1:
580
+ output = output.view(-1, output_shape[-1])
581
+ return output
582
+ else:
583
+ kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split(
584
+ [self.attention_kv_lora_dim, self.attention_qk_rope_dim], dim=-1)
585
+ kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
586
+ return self.mla_attn(hidden_states_or_q_c,
587
+ kv_c_normed,
588
+ k_pe,
589
+ output_shape=hidden_states.shape)
590
+
591
+
592
+ class OpenPanguEmbeddedAttention(nn.Module):
593
+
594
+ def __init__(
595
+ self,
596
+ config: PretrainedConfig,
597
+ hidden_size: int,
598
+ num_heads: int,
599
+ num_kv_heads: int,
600
+ rope_theta: float = 10000,
601
+ rope_scaling: Optional[dict[str, Any]] = None,
602
+ max_position_embeddings: int = 8192,
603
+ quant_config: Optional[QuantizationConfig] = None,
604
+ bias: bool = False,
605
+ bias_o_proj: bool = False,
606
+ cache_config: Optional[CacheConfig] = None,
607
+ prefix: str = "",
608
+ attn_type: str = AttentionType.DECODER,
609
+ ) -> None:
610
+ super().__init__()
611
+ layer_idx = extract_layer_index(prefix)
612
+ self.hidden_size = hidden_size
613
+ tp_size = get_tensor_model_parallel_world_size()
614
+ self.total_num_heads = num_heads
615
+ if self.total_num_heads % tp_size != 0:
616
+ raise ValueError(f'total_num_heads {self.total_num_heads} is not divisible by tp_size {tp_size}.')
617
+ self.num_heads = self.total_num_heads // tp_size
618
+ self.total_num_kv_heads = num_kv_heads
619
+ if self.total_num_kv_heads >= tp_size and self.total_num_kv_heads % tp_size != 0:
620
+ # Number of KV heads is greater than TP size, so we partition
621
+ # the KV heads across multiple tensor parallel NPUs.
622
+ raise ValueError(f'Number of KV heads is greater than or equal to TP size, but total_num_kv_heads {self.total_num_kv_heads} '
623
+ f'is not divisible by tp_size {tp_size}.')
624
+ elif self.total_num_kv_heads < tp_size and tp_size % self.total_num_kv_heads != 0:
625
+ # Number of KV heads is less than TP size, so we replicate
626
+ # the KV heads across multiple tensor parallel NPUs.
627
+ raise ValueError(f'Number of KV heads is less than TP size, but tp_size {tp_size} '
628
+ f'is not divisible by total_num_kv_heads {self.total_num_kv_heads}.')
629
+ self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
630
+ # MistralConfig has an optional head_dim introduced by Mistral-Nemo
631
+ head_dim = getattr(config, "head_dim", None)
632
+ if head_dim is None:
633
+ head_dim = self.hidden_size // self.total_num_heads
634
+ self.head_dim = head_dim
635
+ # Phi models introduced a partial_rotary_factor parameter in the config
636
+ self.partial_rotary_factor = getattr(config, "partial_rotary_factor", 1)
637
+ self.q_size = self.num_heads * self.head_dim
638
+ self.kv_size = self.num_kv_heads * self.head_dim
639
+ self.scaling = self.head_dim**-0.5
640
+ self.rope_theta = rope_theta
641
+ self.max_position_embeddings = max_position_embeddings
642
+
643
+ self.qkv_proj = QKVParallelLinear(
644
+ hidden_size=hidden_size,
645
+ head_size=self.head_dim,
646
+ total_num_heads=self.total_num_heads,
647
+ total_num_kv_heads=self.total_num_kv_heads,
648
+ bias=bias,
649
+ quant_config=quant_config,
650
+ prefix=f"{prefix}.qkv_proj",
651
+ )
652
+
653
+ self.o_proj = RowParallelLinear(
654
+ input_size=self.total_num_heads * self.head_dim,
655
+ output_size=hidden_size,
656
+ bias=bias_o_proj,
657
+ quant_config=quant_config,
658
+ prefix=f"{prefix}.o_proj",
659
+ )
660
+
661
+ self._init_rotary_emb(config,
662
+ rope_scaling=rope_scaling,
663
+ quant_config=quant_config)
664
+
665
+ if hasattr(config, "interleaved_sliding_window"):
666
+ interleaved_sliding_window = config.interleaved_sliding_window
667
+ if isinstance(interleaved_sliding_window, int):
668
+ sliding_window = interleaved_sliding_window
669
+ elif isinstance(interleaved_sliding_window, list):
670
+ sw_idx = layer_idx % len(interleaved_sliding_window)
671
+ sliding_window = interleaved_sliding_window[sw_idx]
672
+ else:
673
+ raise ValueError(
674
+ f"{type(interleaved_sliding_window)} is not supported.")
675
+ else:
676
+ sliding_window = None
677
+
678
+ self.attn = Attention(
679
+ self.num_heads,
680
+ self.head_dim,
681
+ self.scaling,
682
+ num_kv_heads=self.num_kv_heads,
683
+ cache_config=cache_config,
684
+ quant_config=quant_config,
685
+ per_layer_sliding_window=sliding_window,
686
+ attn_type=attn_type,
687
+ prefix=f"{prefix}.attn",
688
+ )
689
+
690
+ def forward(
691
+ self,
692
+ positions: torch.Tensor,
693
+ hidden_states: torch.Tensor,
694
+ kv_cache: Optional[torch.Tensor] = None,
695
+ attn_metadata: Optional[AttentionMetadata] = None
696
+ ) -> torch.Tensor:
697
+ qkv, _ = self.qkv_proj(hidden_states)
698
+ q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
699
+ q, k = self.rotary_emb(positions, q, k)
700
+ attn_output = self.attn(q, k, v)
701
+ output, _ = self.o_proj(attn_output)
702
+ return output
703
+
704
+ def _init_rotary_emb(self, config: PretrainedConfig,
705
+ rope_scaling: Optional[dict[str, Any]],
706
+ quant_config: Optional[QuantizationConfig]) -> None:
707
+ is_neox_style = True
708
+ is_gguf = quant_config and quant_config.get_name() == "gguf"
709
+ if is_gguf and config.model_type == "Pangu":
710
+ is_neox_style = False
711
+
712
+ self.rotary_emb = get_rope(
713
+ self.head_dim,
714
+ rotary_dim=self.head_dim,
715
+ max_position=self.max_position_embeddings,
716
+ base=self.rope_theta,
717
+ rope_scaling=rope_scaling,
718
+ is_neox_style=is_neox_style,
719
+ #partial_rotary_factor=self.partial_rotary_factor,
720
+ )
721
+
722
+
723
+ class OpenPanguDecoderLayer(nn.Module):
724
+
725
+ def __init__(
726
+ self,
727
+ config: PretrainedConfig,
728
+ prefix: str,
729
+ model_config: ModelConfig,
730
+ cache_config: Optional[CacheConfig] = None,
731
+ quant_config: Optional[QuantizationConfig] = None,
732
+ ) -> None:
733
+ super().__init__()
734
+ self.hidden_size = config.hidden_size
735
+ rope_theta = getattr(config, "rope_theta", 10000)
736
+ max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
737
+
738
+ layer_idx = int(prefix.split(sep='.')[-1])
739
+ self.layer_idx = layer_idx
740
+ self.layers = config.num_hidden_layers
741
+ self.tp_size = get_tensor_model_parallel_world_size()
742
+ self.tp_rank = get_tp_group().rank_in_group
743
+ ascend_config = get_ascend_config()
744
+
745
+ self.use_mla = hasattr(config, 'attention_qk_dim') and hasattr(config, 'attention_qk_rope_dim') \
746
+ and hasattr(config, 'attention_v_dim') and hasattr(config, 'attention_kv_lora_dim')
747
+ if self.use_mla:
748
+ self.self_attn = OpenPanguMLAAttention(
749
+ config=config,
750
+ hidden_size=self.hidden_size,
751
+ num_heads=config.num_attention_heads,
752
+ attention_qk_dim=config.attention_qk_dim,
753
+ attention_qk_rope_dim=config.attention_qk_rope_dim,
754
+ attention_v_dim=config.attention_v_dim,
755
+ attention_q_lora_dim=config.attention_q_lora_dim
756
+ if hasattr(config, "attention_q_lora_dim") else None,
757
+ attention_kv_lora_dim=config.attention_kv_lora_dim,
758
+ rope_theta=rope_theta,
759
+ max_position_embeddings=max_position_embeddings,
760
+ cache_config=cache_config,
761
+ quant_config=quant_config,
762
+ prefix=f"{prefix}.self_attn",
763
+ )
764
+ else:
765
+ attention_bias = getattr(config, "attention_bias", False) or getattr(
766
+ config, "bias", False)
767
+ bias_o_proj = attention_bias
768
+ if hasattr(config, 'qkv_bias'):
769
+ attention_bias = config.qkv_bias
770
+ # By default, PanguEmbedded uses causal attention as it is a decoder-only model.
771
+ # You can override the HF config with `is_causal=False` to enable
772
+ # bidirectional attention, which is used in some embedding models
773
+ if getattr(config, "is_causal", True):
774
+ attn_type = AttentionType.DECODER
775
+ else:
776
+ attn_type = AttentionType.ENCODER_ONLY
777
+ self.self_attn = OpenPanguEmbeddedAttention(
778
+ config=config,
779
+ hidden_size=self.hidden_size,
780
+ num_heads=config.num_attention_heads,
781
+ num_kv_heads=getattr(config, "num_key_value_heads", config.num_attention_heads),
782
+ rope_theta=rope_theta,
783
+ rope_scaling=getattr(config, "rope_scaling", None),
784
+ max_position_embeddings=max_position_embeddings,
785
+ quant_config=quant_config,
786
+ bias=attention_bias,
787
+ bias_o_proj=bias_o_proj,
788
+ cache_config=cache_config,
789
+ prefix=f"{prefix}.self_attn",
790
+ attn_type=attn_type,
791
+ )
792
+
793
+ if getattr(config, 'num_routed_experts', None) is not None and layer_idx >= config.num_dense_layers:
794
+ self.mlp = OpenPanguMoE(
795
+ config=config,
796
+ quant_config=quant_config,
797
+ prefix=f"{prefix}.mlp",
798
+ )
799
+ self.mla_moe_communication = ascend_config.torchair_graph_config.enable_multistream_moe \
800
+ and model_config.use_mla and envs.VLLM_USE_V1 and self.tp_size > 1
801
+ else:
802
+ self.mlp = OpenPanguMLP(
803
+ hidden_size=self.hidden_size,
804
+ intermediate_size=config.intermediate_size,
805
+ hidden_act=config.hidden_act,
806
+ quant_config=quant_config,
807
+ bias=getattr(config, "mlp_bias", False),
808
+ prefix=f"{prefix}.mlp",
809
+ )
810
+ self.mla_moe_communication = False
811
+ self.routed_scaling_factor = getattr(config, 'routed_scaling_factor', None)
812
+ self.num_dense_layers = getattr(config, 'num_dense_layers', None)
813
+
814
+ self.input_layernorm = RMSNorm(config.hidden_size,
815
+ eps=config.rms_norm_eps)
816
+ self.post_attention_layernorm = RMSNorm(config.hidden_size,
817
+ eps=config.rms_norm_eps)
818
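+ # Sandwich norm adds two extra RMSNorms: the attention output is normalized before pre_mlp_layernorm merges it with the residual, and the MLP output is normalized by post_mlp_layernorm.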
+ if getattr(config, 'sandwich_norm', False):
819
+ self.sandwich_norm = True
820
+ self.pre_mlp_layernorm = RMSNorm(config.hidden_size,
821
+ eps=config.rms_norm_eps)
822
+ self.post_mlp_layernorm = RMSNorm(config.hidden_size,
823
+ eps=config.rms_norm_eps)
824
+ else:
825
+ self.sandwich_norm = False
826
+
827
+ def forward(
828
+ self,
829
+ positions: torch.Tensor,
830
+ hidden_states: torch.Tensor,
831
+ residual: Optional[torch.Tensor],
832
+ kv_cache: Optional[torch.Tensor] = None,
833
+ attn_metadata: Optional[AttentionMetadata] = None,
834
+ replace_allreduce: bool = False,
835
+ ) -> torch.Tensor:
836
+ # Self Attention
837
+ if self.use_mla and attn_metadata is not None and attn_metadata.num_decodes > 0:
838
+ mla_moe_communication = self.mla_moe_communication and replace_allreduce
839
+ else:
840
+ mla_moe_communication = False
841
+ if residual is None:
842
+ residual = hidden_states
843
+ hidden_states = self.input_layernorm(hidden_states)
844
+ else:
845
+ previous_hidden_states, previous_residual = hidden_states, residual
846
+ hidden_states, residual = self.input_layernorm(
847
+ hidden_states, residual)
848
+ # Dispose hidden_states and residual from the previous layer
849
+ # to save npu memory because they're no longer used.
850
+ dispose_tensor(previous_hidden_states)
851
+ dispose_tensor(previous_residual)
852
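+ # Under MLA + multistream MoE, hidden states arrive token-scattered from the previous layer, so gather the full token dimension back before attention; o_proj reduce-scatters again and the residual is chunked below to match.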
+ if mla_moe_communication and self.layer_idx > self.num_dense_layers:
853
+ hidden_states = tensor_model_parallel_all_gather(hidden_states,
854
+ dim=0)
855
+
856
+ hidden_states = self.self_attn(
857
+ positions=positions,
858
+ hidden_states=hidden_states,
859
+ kv_cache=kv_cache,
860
+ attn_metadata=attn_metadata,
861
+ )
862
+
863
+ if mla_moe_communication and residual.shape[0] != hidden_states.shape[0]:
864
+ chunk_hidden_states = torch.tensor_split(residual,
865
+ self.tp_size,
866
+ dim=0)
867
+ residual = chunk_hidden_states[self.tp_rank]
868
+
869
+ if self.routed_scaling_factor is not None and hidden_states.dtype == torch.float16:
870
+ # Fix FP16 overflow
871
+ # We scale both hidden_states and residual before
872
+ # rmsnorm, and rmsnorm result would not affect by scale.
873
+ hidden_states *= 1. / self.routed_scaling_factor
874
+ if self.layer_idx == 0:
875
+ # The residual is shared by all layers, we only scale it on
876
+ # first layer.
877
+ residual *= 1. / self.routed_scaling_factor
878
+
879
+ if self.sandwich_norm:
880
+ hidden_states = self.post_attention_layernorm(
881
+ hidden_states)
882
+ hidden_states, residual = self.pre_mlp_layernorm(
883
+ hidden_states, residual)
884
+ else:
885
+ hidden_states, residual = self.post_attention_layernorm(
886
+ hidden_states, residual)
887
+
888
+ # Fully Connected
889
+ if isinstance(self.mlp, OpenPanguMoE):
890
+ hidden_states = self.mlp(hidden_states,
891
+ attn_metadata,
892
+ replace_allreduce=mla_moe_communication)
893
+ else:
894
+ hidden_states = self.mlp(hidden_states)
895
+
896
+ if self.routed_scaling_factor is not None and isinstance(self.mlp, OpenPanguMLP) \
897
+ and hidden_states.dtype == torch.float16:
898
+ hidden_states *= 1. / self.routed_scaling_factor
899
+
900
+ if self.sandwich_norm:
901
+ hidden_states = self.post_mlp_layernorm(hidden_states)
902
+
903
+ if mla_moe_communication and self.layer_idx == self.layers - 1:
904
+ hidden_states = tensor_model_parallel_all_gather(hidden_states,
905
+ dim=0)
906
+ residual = tensor_model_parallel_all_gather(residual, dim=0)
907
+
908
+ return hidden_states, residual
909
+
910
+
911
+ @support_torch_compile
912
+ class OpenPanguModel(nn.Module):
913
+
914
+ fall_back_to_pt_during_load = False
915
+
916
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
917
+ super().__init__()
918
+
919
+ config = vllm_config.model_config.hf_config
920
+ model_config = vllm_config.model_config
921
+ cache_config = vllm_config.cache_config
922
+ quant_config = vllm_config.quant_config
923
+
924
+ self.padding_idx = config.pad_token_id
925
+ self.vocab_size = config.vocab_size
926
+ self.tp_size = get_tensor_model_parallel_world_size()
927
+
928
+ self.embed_tokens = VocabParallelEmbedding(
929
+ config.vocab_size,
930
+ config.hidden_size,
931
+ quant_config=quant_config,
932
+ prefix=f"{prefix}.embed_tokens")
933
+
934
+ self.start_layer, self.end_layer, self.layers = make_layers(
935
+ config.num_hidden_layers,
936
+ lambda prefix: OpenPanguDecoderLayer(
937
+ config,
938
+ prefix,
939
+ model_config=model_config,
940
+ cache_config=cache_config,
941
+ quant_config=quant_config,
942
+ ),
943
+ prefix=f"{prefix}.layers")
944
+
945
+ self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
946
+
947
+ def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
948
+ return self.embed_tokens(input_ids)
949
+
950
+ def forward(
951
+ self,
952
+ input_ids: torch.Tensor,
953
+ positions: torch.Tensor,
954
+ kv_caches: Optional[List[torch.Tensor]] = None,
955
+ attn_metadata: Optional[AttentionMetadata] = None,
956
+ inputs_embeds: Optional[torch.Tensor] = None,
957
+ **kwargs,
958
+ ) -> torch.Tensor:
959
+ if inputs_embeds is not None:
960
+ hidden_states = inputs_embeds
961
+ else:
962
+ hidden_states = self.get_input_embeddings(input_ids)
963
+ residual = None
964
+
965
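+ # Token-level reduce-scatter/all-gather is only valid when the token count divides evenly across TP ranks; otherwise the layers fall back to all-reduce.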
+ replace_allreduce = hidden_states.shape[0] % self.tp_size == 0
966
+
967
+ for i in range(self.start_layer, self.end_layer):
968
+ layer = self.layers[i]
969
+ hidden_states, residual = layer(
970
+ positions,
971
+ hidden_states,
972
+ residual,
973
+ kv_caches[i -
974
+ self.start_layer] if kv_caches is not None else None,
975
+ attn_metadata,
976
+ replace_allreduce=replace_allreduce)
977
+
978
+ hidden_states, _ = self.norm(hidden_states, residual)
979
+ return hidden_states
980
+
981
+
982
+ class OpenPanguForCausalLM(nn.Module):
983
+ packed_modules_mapping = {
984
+ "gate_up_proj": ["gate_proj", "up_proj"],
985
+ "experts":
986
+ ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"]
987
+ }
988
+
989
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
990
+ super().__init__()
991
+ config = vllm_config.model_config.hf_config
992
+ quant_config = vllm_config.quant_config
993
+ self.config = config
994
+ self.quant_config = quant_config
995
+ self.model = OpenPanguModel(vllm_config=vllm_config,
996
+ prefix=maybe_prefix(prefix, "model"))
997
+ self.lm_head = ParallelLMHead(config.vocab_size,
998
+ config.hidden_size,
999
+ quant_config=quant_config,
1000
+ prefix=maybe_prefix(prefix, "lm_head"))
1001
+ self.logits_processor = LogitsProcessor(config.vocab_size)
1002
+ self.sampler = get_sampler()
1003
+
1004
+ def load_attn_mlp_weight(self,
1005
+ attn_mlp_replace_mapping: List[Tuple[str, str, int]],
1006
+ params_dict: Dict[str, Any],
1007
+ weight_name: str,
1008
+ loaded_weight: torch.Tensor,
1009
+ loaded_params: set[str]) -> bool:
1010
+ for (param_name, origin_name, shard_id) in attn_mlp_replace_mapping:
1011
+ if origin_name not in weight_name or \
1012
+ (("mlp.experts." in weight_name) and weight_name not in params_dict):
1013
+ continue
1014
+ weight_name = weight_name.replace(origin_name, param_name)
1015
+ if weight_name.endswith(".bias") and weight_name not in params_dict:
1016
+ continue
1017
+ param = params_dict[weight_name]
1018
+ weight_loader = param.weight_loader
1019
+ weight_loader(param, loaded_weight, shard_id)
1020
+ loaded_params.add(weight_name)
1021
+ return True
1022
+ return False
1023
+
1024
+ def load_expert_weight(self,
1025
+ expert_merge_mapping: List[Tuple[str, str, int, str]],
1026
+ params_dict: Dict[str, Any],
1027
+ weight_name: str,
1028
+ loaded_weight: torch.Tensor,
1029
+ loaded_params: set[str]) -> bool:
1030
+ for mapping in expert_merge_mapping:
1031
+ param_name, origin_name, expert_id, shard_id = mapping
1032
+ if origin_name not in weight_name:
1033
+ continue
1034
+ weight_name = weight_name.replace(origin_name, param_name)
1035
+ param = params_dict[weight_name]
1036
+ weight_loader = param.weight_loader
1037
+ weight_loader(param,
1038
+ loaded_weight,
1039
+ weight_name,
1040
+ shard_id=shard_id,
1041
+ expert_id=expert_id,
1042
+ return_success=False)
1043
+ loaded_params.add(weight_name)
1044
+ return True
1045
+ return False
1046
+
1047
+ def load_weights(self, weights: Iterable[tuple[str,
1048
+ torch.Tensor]]) -> set[str]:
1049
+ # (param_name, shard_name, shard_id)
1050
+ attn_mlp_replace_mapping = [
1051
+ (".qkv_proj", ".q_proj", "q"),
1052
+ (".qkv_proj", ".k_proj", "k"),
1053
+ (".qkv_proj", ".v_proj", "v"),
1054
+ (".gate_up_proj", ".gate_proj", 0),
1055
+ (".gate_up_proj", ".up_proj", 1),
1056
+ ]
1057
+ has_experts = hasattr(self.config, 'num_routed_experts')
1058
+ if has_experts:
1059
+ expert_merge_mapping = AscendFusedMoE.make_expert_params_mapping(
1060
+ ckpt_gate_proj_name="gate_proj",
1061
+ ckpt_down_proj_name="down_proj",
1062
+ ckpt_up_proj_name="up_proj",
1063
+ num_experts=self.config.num_routed_experts)
1064
+
1065
+ params_dict = dict(self.named_parameters())
1066
+ loaded_params: set[str] = set()
1067
+ for name, loaded_weight in weights:
1068
+ if "rotary_emb.inv_freq" in name:
1069
+ continue
1070
+ if 'layers' in name: # skip spec decode layers for main model
1071
+ layer_idx = int(name.split('layers.')[-1].split('.')[0])
1072
+ if layer_idx > self.config.num_hidden_layers:
1073
+ continue
1074
+
1075
+ if 'layers' in name and hasattr(self.config, "num_mtp_layers") \
1076
+ and (self.config.num_mtp_layers > 0):
1077
+ layer_idx = int(name.split('layers.')[-1].split('.')[0])
1078
+ mtp_idx = layer_idx - self.config.num_hidden_layers
1079
+ if mtp_idx >= 0 and mtp_idx < self.config.num_mtp_layers:
1080
+ continue # skip spec decode layers for main model
1081
+ if self.load_attn_mlp_weight(attn_mlp_replace_mapping, params_dict, name, loaded_weight, loaded_params):
1082
+ continue
1083
+ elif has_experts and self.load_expert_weight(expert_merge_mapping, params_dict, name, loaded_weight, loaded_params):
1084
+ continue
1085
+ else:
1086
+ if name.endswith(".bias") and name not in params_dict:
1087
+ continue
1088
+ name = maybe_remap_kv_scale_name(name, params_dict)
1089
+ if name is None:
1090
+ continue
1091
+ param = params_dict[name]
1092
+ weight_loader = getattr(param, "weight_loader",
1093
+ default_weight_loader)
1094
+ weight_loader(param, loaded_weight)
1095
+ loaded_params.add(name)
1096
+ if self.config.tie_word_embeddings:
1097
+ self.lm_head.weight = self.model.embed_tokens.weight
1098
+ return loaded_params
1099
+
1100
+ def forward(
1101
+ self,
1102
+ input_ids: torch.Tensor,
1103
+ positions: torch.Tensor,
1104
+ kv_caches: Optional[List[torch.Tensor]] = None,
1105
+ attn_metadata: Optional[AttentionMetadata] = None,
1106
+ inputs_embeds: Optional[torch.Tensor] = None,
1107
+ **kwargs,
1108
+ ) -> torch.Tensor:
1109
+ hidden_states = self.model(input_ids, positions, kv_caches,
1110
+ attn_metadata, inputs_embeds)
1111
+ return hidden_states
1112
+
1113
+ def compute_logits(
1114
+ self,
1115
+ hidden_states: torch.Tensor,
1116
+ sampling_metadata: SamplingMetadata,
1117
+ ) -> Optional[torch.Tensor]:
1118
+ logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata)
1119
+ return logits
1120
+
1121
+
1122
+ class PanguUltraMoEForCausalLM(OpenPanguForCausalLM):
1123
+ pass
1124
+
1125
+
1126
+ class PanguEmbeddedForCausalLM(OpenPanguForCausalLM):
1127
+ pass
inference/vllm_ascend/ops/fused_moe.py ADDED
@@ -0,0 +1,1530 @@
1
+ # Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
2
+ # Copyright 2023 The vLLM team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # This file is a part of the vllm-ascend project.
16
+ # Adapted from vllm/tests/kernels/test_moe.py
17
+
18
+ import os
19
+ from typing import Any, Callable, List, Optional, Tuple, Union
20
+
21
+ import torch
22
+ import torch.distributed as dist
23
+ import torch_npu
24
+ from torch import nn
25
+ from vllm.config import get_current_vllm_config
26
+ from vllm.distributed import (GroupCoordinator, get_tensor_model_parallel_rank,
27
+ get_tensor_model_parallel_world_size,
28
+ tensor_model_parallel_all_reduce)
29
+ from vllm.distributed.parallel_state import get_dp_group, get_tp_group
30
+ from vllm.forward_context import get_forward_context
31
+ from vllm.model_executor.layers.fused_moe.config import \
32
+ FusedMoEConfig # isort: skip
33
+ from vllm.model_executor.layers.fused_moe.config import \
34
+ FusedMoEParallelConfig # isort: skip
35
+ from vllm.model_executor.layers.fused_moe.layer import (
36
+ FusedMoE, UnquantizedFusedMoEMethod, determine_expert_map)
37
+ from vllm.model_executor.layers.quantization.base_config import \
38
+ QuantizationConfig
39
+
40
+ import vllm_ascend.envs as envs_ascend
41
+ from vllm_ascend.ascend_config import get_ascend_config
42
+ from vllm_ascend.distributed.communication_op import \
43
+ data_parallel_reduce_scatter
44
+ from vllm_ascend.distributed.parallel_state import get_ep_group, get_etp_group
45
+ from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
46
+ from vllm_ascend.utils import (FusedMoEState, dispose_tensor,
47
+ get_all_reduce_merge_state, get_fused_moe_state,
48
+ get_rm_router_logits_state, is_310p,
49
+ npu_stream_switch, npu_wait_tensor)
50
+
51
+ MOE_ALL2ALL_BUFFER: bool = envs_ascend.MOE_ALL2ALL_BUFFER
52
+ SELECT_GATING_TOPK_SOTFMAX_EXPERTS: bool = envs_ascend.SELECT_GATING_TOPK_SOTFMAX_EXPERTS
53
+
54
+
55
+ def process_topk_ids(topk_ids: torch.Tensor, expert_num: int, ep_size: int,
56
+ max_row_per_ep_rank: int, num_tokens: int,
57
+ top_k: int) -> tuple[torch.Tensor, torch.Tensor]:
58
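+ # Pad/truncate topk_ids per EP rank: each rank keeps at most max_row_per_ep_rank entries (overflow slots are filled with expert_num), and unpad_indices maps each original entry to its position in the kept list (-1 if dropped).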
+ original_total_elements = num_tokens * top_k
59
+ device = topk_ids.device
60
+ original_dtype = topk_ids.dtype
61
+
62
+ if original_total_elements == 0:
63
+ output_len = ep_size * max_row_per_ep_rank
64
+ topk_ids_pad = torch.full((output_len, ),
65
+ expert_num,
66
+ dtype=original_dtype,
67
+ device=device)
68
+ unpad_indices = torch.full((original_total_elements, ),
69
+ -1,
70
+ dtype=torch.long,
71
+ device=device)
72
+ return topk_ids_pad, unpad_indices
73
+
74
+ experts_per_ep_rank_val = expert_num // ep_size
75
+ if experts_per_ep_rank_val == 0:
76
+ raise ValueError(
77
+ "expert_num // ep_size is 0, which leads to division by zero in ep_rank calculation. "
78
+ "Ensure expert_num >= ep_size.")
79
+
80
+ assigned_ep_rank = (topk_ids.float() /
81
+ experts_per_ep_rank_val).to(original_dtype)
82
+ indices_arange = torch.arange(topk_ids.shape[0], device=device)
83
+
84
+ is_new_segment = torch.cat(
85
+ (torch.tensor([True], device=device), assigned_ep_rank[1:]
86
+ != assigned_ep_rank[:-1]))
87
+ temp_start_markers = torch.full_like(indices_arange,
88
+ -1,
89
+ dtype=indices_arange.dtype)
90
+ temp_start_markers[is_new_segment] = indices_arange[is_new_segment]
91
+ start_offset_for_each_token = torch.cummax(temp_start_markers, dim=0)[0]
92
+ token_intra_ep_rank_idx = indices_arange - start_offset_for_each_token
93
+ is_kept_mask = token_intra_ep_rank_idx < max_row_per_ep_rank
94
+ cumsum_kept = torch.cumsum(is_kept_mask.float(), dim=0).to(torch.long)
95
+ indices_in_rec_cond_list_for_all = cumsum_kept - 1
96
+ unpad_indices = torch.where(
97
+ is_kept_mask, indices_in_rec_cond_list_for_all,
98
+ torch.tensor(-1, device=device, dtype=torch.long))
99
+ output_len = ep_size * max_row_per_ep_rank
100
+ topk_ids_pad = torch.full((output_len, ),
101
+ expert_num,
102
+ dtype=original_dtype,
103
+ device=device)
104
+ if topk_ids.shape[0] > 0:
105
+ all_destination_indices = assigned_ep_rank * max_row_per_ep_rank + token_intra_ep_rank_idx
106
+ temp_pad_buffer = torch.full((output_len + 1, ),
107
+ expert_num,
108
+ dtype=original_dtype,
109
+ device=device)
110
+ output_len_tensor = torch.tensor(output_len,
111
+ dtype=torch.long,
112
+ device=device)
113
+ scatter_indices = torch.where(is_kept_mask, all_destination_indices,
114
+ output_len_tensor)
115
+ temp_pad_buffer.scatter_(0, scatter_indices, topk_ids)
116
+ topk_ids_pad = temp_pad_buffer[:output_len]
117
+ return topk_ids_pad, unpad_indices
118
+
119
+
120
+ def fused_experts_with_mc2(
121
+ hidden_states: torch.Tensor,
122
+ w1: torch.Tensor,
123
+ w2: torch.Tensor,
124
+ topk_weights: torch.Tensor,
125
+ topk_ids: torch.Tensor,
126
+ top_k: int,
127
+ expert_map: torch.Tensor = None,
128
+ moe_all_to_all_group_name: Optional[str] = None,
129
+ shared_experts: Optional[Any] = None
130
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
131
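+ # MC2 pipeline: npu_moe_distribute_dispatch routes tokens to expert ranks, grouped matmul + swiglu run the local experts, and npu_moe_distribute_combine gathers the results; shared experts (if any) overlap on a secondary stream.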
+ global_bs = 0
132
+ moe_expert_num = len(expert_map)
133
+ kwargs_mc2 = {
134
+ "x": hidden_states,
135
+ "expert_ids": topk_ids,
136
+ "expert_shard_type": 0,
137
+ "shared_expert_rank_num": 0,
138
+ "moe_expert_num": moe_expert_num,
139
+ "global_bs": global_bs,
140
+ }
141
+
142
+ rank = torch.distributed.get_rank()
143
+
144
+ quant_mode = 0
145
+ ep_group = get_ep_group().device_group
146
+ local_rank = torch.distributed.get_rank(group=ep_group)
147
+ all_to_all_group_size = torch.distributed.get_world_size(ep_group)
148
+
149
+ tp_size = get_etp_group().world_size
150
+ tp_rank = rank % tp_size
151
+
152
+ stage1_kwargs = {
153
+ "scales": None,
154
+ "quant_mode": quant_mode,
155
+ "group_ep": moe_all_to_all_group_name,
156
+ "ep_world_size": all_to_all_group_size,
157
+ "ep_rank_id": local_rank,
158
+ # "group_tp": self.moe_rs_group_name,
159
+ "group_tp": moe_all_to_all_group_name,
160
+ "tp_world_size": tp_size,
161
+ "tp_rank_id": tp_rank,
162
+ }
163
+ kwargs_mc2.update(stage1_kwargs)
164
+
165
+ output = torch_npu.npu_moe_distribute_dispatch(**kwargs_mc2)
166
+ expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts = output[
167
+ 0:5]
168
+
169
+ if shared_experts is not None:
170
+ with npu_stream_switch("moe_secondary", 0):
171
+ npu_wait_tensor(hidden_states, topk_weights)
172
+ shared_gate_up, _ = shared_experts.gate_up_proj(hidden_states)
173
+ npu_wait_tensor(shared_gate_up, expand_x)
174
+ shared_act = shared_experts.act_fn(shared_gate_up)
175
+
176
+ w1 = w1.transpose(1, 2)
177
+
178
+ group_list = expert_token_nums.to(torch.int64)
179
+ gate_up_out_list = torch_npu.npu_grouped_matmul(
180
+ x=[expand_x],
181
+ weight=[w1],
182
+ split_item=2,
183
+ # 1 means count mode, to avoid cumulative operation of the group list
184
+ group_list_type=1,
185
+ group_type=0,
186
+ group_list=group_list,
187
+ )
188
+
189
+ # TODO: Remove this in the future.
190
+ gate_up_out = torch.cat(gate_up_out_list, dim=0)
191
+ gate_up_out = torch_npu.npu_swiglu(gate_up_out)
192
+
193
+ w2 = w2.transpose(1, 2)
194
+ down_out_list = torch_npu.npu_grouped_matmul(
195
+ x=[gate_up_out],
196
+ weight=[w2],
197
+ split_item=2,
198
+ group_list_type=1,
199
+ group_type=0,
200
+ group_list=group_list,
201
+ )
202
+
203
+ down_out_list = torch.cat(down_out_list, dim=0)
204
+
205
+ # moeCombine
206
+ kwargs_mc2 = {
207
+ "expand_x": down_out_list,
208
+ "expert_ids": topk_ids,
209
+ "expand_idx": expand_idx,
210
+ "expert_scales": topk_weights.to(torch.float32),
211
+ "expert_shard_type": 0,
212
+ "shared_expert_rank_num": 0,
213
+ "moe_expert_num": moe_expert_num,
214
+ "global_bs": 0,
215
+ }
216
+ tp_recv_counts = output[5]
217
+ stage3_kwargs = {
218
+ "ep_send_counts": ep_recv_counts,
219
+ "group_ep": moe_all_to_all_group_name,
220
+ "ep_world_size": all_to_all_group_size,
221
+ "ep_rank_id": local_rank,
222
+ "tp_send_counts": tp_recv_counts,
223
+ # "group_tp": self.moe_rs_group_name,
224
+ "group_tp": moe_all_to_all_group_name,
225
+ "tp_world_size": tp_size,
226
+ "tp_rank_id": tp_rank,
227
+ }
228
+ kwargs_mc2.update(stage3_kwargs)
229
+
230
+ hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs_mc2)
231
+
232
+ if shared_experts is None:
233
+ return hidden_states
234
+ else:
235
+ with npu_stream_switch("moe_secondary", 0):
236
+ npu_wait_tensor(shared_act, down_out_list)
237
+ shared_hidden_states, _ = shared_experts.down_proj(shared_act)
238
+ return hidden_states, shared_hidden_states
239
+
240
+
241
+ def apply_mlp(hidden_states_wrapper: List[torch.Tensor],
242
+ w1: torch.Tensor,
243
+ w2: torch.Tensor,
244
+ group_list: torch.Tensor,
245
+ group_list_type: int = 1) -> torch.Tensor:
246
+ """
247
+ apply MLP: gate_up_proj -> swiglu -> down_proj
248
+
249
+ Args:
250
+ hidden_states_wrapper: wrapper of input hidden states with shape (num_tokens, hidden_size).
251
+ w1: expert weights1 with shape
252
+ (num_experts, hidden_size, intermediate_size * 2)
253
+ w2: expert weights2 with shape
254
+ (num_experts, intermediate_size, hidden_size)
255
+ group_list: number of tokens for each expert, follow cumsum mode, and
256
+ per-expert counts depending on group_list_type, with shape (num_experts,).
257
+ transpose_weight:
258
+ w1: (num_experts, intermediate_size * 2, hidden_size) ->
259
+ (num_experts, hidden_size, intermediate_size * 2)
260
+ w2: (num_experts, hidden_size, intermediate_size) ->
261
+ (num_experts, intermediate_size, hidden_size)
262
+
263
+ Returns:
264
+ hidden_states: output hidden states after MLP.
265
+ """
266
+
267
+ assert len(hidden_states_wrapper) == 1
268
+ hidden_states = hidden_states_wrapper.pop()
269
+
270
+ w1 = w1.transpose(1, 2)
271
+ hidden_states = torch_npu.npu_grouped_matmul(
272
+ x=[hidden_states],
273
+ weight=[w1],
274
+ split_item=2,
275
+ group_list_type=group_list_type,
276
+ group_type=0,
277
+ group_list=group_list,
278
+ )
279
+
280
+ hidden_states = torch.cat(hidden_states, dim=0)
281
+ hidden_states = torch_npu.npu_swiglu(hidden_states)
282
+
283
+ w2 = w2.transpose(1, 2)
284
+ hidden_states = torch_npu.npu_grouped_matmul(
285
+ x=[hidden_states],
286
+ weight=[w2],
287
+ split_item=2,
288
+ group_list_type=group_list_type,
289
+ group_type=0,
290
+ group_list=group_list,
291
+ )
292
+
293
+ hidden_states = torch.cat(hidden_states, dim=0)
294
+ return hidden_states
295
+
296
+
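apply_mlp runs two grouped matmuls with a SwiGLU in between over tokens that are already sorted by expert. For readability, here is a plain-PyTorch reference of the same computation, assuming count-mode group_list (group_list_type=1) and the usual silu(gate) * up SwiGLU split; this is a sketch, not the NPU kernel.

import torch

def grouped_mlp_reference(x, w1, w2, group_list):
    # x:  (num_tokens, hidden), tokens pre-sorted by expert
    # w1: (num_experts, hidden, 2 * inter)   (already transposed, as inside apply_mlp)
    # w2: (num_experts, inter, hidden)
    # group_list: per-expert token counts (count mode)
    outputs, start = [], 0
    for e, count in enumerate(group_list.tolist()):
        if count == 0:
            continue
        chunk = x[start:start + count]
        gate, up = (chunk @ w1[e]).chunk(2, dim=-1)
        act = torch.nn.functional.silu(gate) * up      # SwiGLU
        outputs.append(act @ w2[e])
        start += count
    return torch.cat(outputs, dim=0)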
297
+ def fused_experts_with_all2all(
298
+ hidden_states: torch.Tensor,
299
+ w1: torch.Tensor,
300
+ w2: torch.Tensor,
301
+ topk_weights: torch.Tensor,
302
+ topk_ids: torch.Tensor,
303
+ top_k: int,
304
+ expert_map: torch.Tensor = None,
305
+ ep_group: GroupCoordinator = None,
306
+ ):
307
+ original_shape = hidden_states.shape
308
+ if len(original_shape) == 3:
309
+ hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
310
+
311
+ num_tokens, _ = hidden_states.shape
312
+ num_experts = w1.shape[0]
313
+ device = hidden_states.device
314
+
315
+ if expert_map is not None:
316
+ global_num_experts = len(expert_map)
317
+ local_num_experts = global_num_experts // ep_group.world_size
318
+ row_idx_len = num_tokens * top_k
319
+ row_idx = (torch.arange(0,
320
+ row_idx_len,
321
+ dtype=torch.int32,
322
+ device=device).view(top_k, -1).permute(
323
+ 1, 0).contiguous())
324
+ hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
325
+ hidden_states,
326
+ row_idx=row_idx,
327
+ expert_idx=topk_ids,
328
+ active_num=num_tokens)
329
+
330
+ global_expert_tokens = torch.bincount(expanded_expert_idx,
331
+ minlength=global_num_experts)
332
+ scatter_sizes = global_expert_tokens.view(ep_group.world_size,
333
+ -1).sum(-1)
334
+
335
+ gather_sizes = torch.empty_like(scatter_sizes)
336
+ dist.all_to_all_single(gather_sizes,
337
+ scatter_sizes,
338
+ group=ep_group.device_group)
339
+ scatter_size_list = scatter_sizes.cpu().tolist()
340
+ gather_size_list = gather_sizes.cpu().tolist()
341
+
342
+ expanded_expert_idx = expanded_expert_idx % local_num_experts
343
+ hidden_states = ep_group.all_to_all(hidden_states, 0, 0,
344
+ scatter_size_list,
345
+ gather_size_list)
346
+ local_expert_idx = ep_group.all_to_all(expanded_expert_idx, 0, 0,
347
+ scatter_size_list,
348
+ gather_size_list)
349
+
350
+ sorted_local_expert_idx, sorted_idx = torch.sort(local_expert_idx)
351
+
352
+ expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
353
+ sorted_local_expert_idx, local_num_experts).to(torch.int64)
354
+
355
+ hidden_states = hidden_states[sorted_idx]
356
+ else:
357
+ row_idx_len = num_tokens * top_k
358
+ row_idx = torch.arange(0,
359
+ row_idx_len,
360
+ dtype=torch.int32,
361
+ device=topk_weights.device).view(
362
+ top_k, -1).permute(1, 0).contiguous()
363
+ hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
364
+ hidden_states,
365
+ row_idx=row_idx,
366
+ expert_idx=topk_ids,
367
+ active_num=num_tokens)
368
+
369
+ expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
370
+ expanded_expert_idx, num_experts)
371
+ expert_tokens = expert_tokens.to(torch.int64)
372
+
373
+ w1 = w1.transpose(1, 2)
374
+ gate_up_out_list = torch_npu.npu_grouped_matmul(
375
+ x=[hidden_states],
376
+ weight=[w1],
377
+ split_item=2,
378
+ group_list_type=0,
379
+ group_type=0,
380
+ group_list=expert_tokens,
381
+ )
382
+
383
+ # TODO: Remove this in the future.
384
+ hidden_states = torch.cat(gate_up_out_list, dim=0)
385
+ hidden_states = torch_npu.npu_swiglu(hidden_states)
386
+
387
+ w2 = w2.transpose(1, 2)
388
+ down_out_list = torch_npu.npu_grouped_matmul(
389
+ x=[hidden_states],
390
+ weight=[w2],
391
+ split_item=2,
392
+ group_list_type=0,
393
+ group_type=0,
394
+ group_list=expert_tokens,
395
+ )
396
+
397
+ hidden_states = torch.cat(down_out_list, dim=0)
398
+
399
+ if expert_map is not None:
400
+ resorted_idx = torch.argsort(sorted_idx)
401
+ hidden_states = hidden_states[resorted_idx]
402
+ hidden_states = ep_group.all_to_all(hidden_states, 0, 0,
403
+ gather_size_list,
404
+ scatter_size_list)
405
+
406
+ final_hidden_states = torch_npu.npu_moe_finalize_routing(
407
+ hidden_states,
408
+ skip1=None,
409
+ skip2=None,
410
+ bias=None,
411
+ scales=topk_weights,
412
+ expanded_src_to_dst_row=expanded_row_idx,
413
+ export_for_source_row=topk_ids,
414
+ )
415
+ else:
416
+ # TODO: device memory is reordered twice here; replace the current
417
+ # implementation when suitable operators become available.
418
+ final_hidden_states = torch_npu.npu_moe_finalize_routing(
419
+ hidden_states,
420
+ skip1=None,
421
+ skip2=None,
422
+ bias=None,
423
+ scales=topk_weights,
424
+ expanded_src_to_dst_row=expanded_row_idx,
425
+ export_for_source_row=topk_ids,
426
+ )
427
+ if len(original_shape) == 3:
428
+ final_hidden_states = final_hidden_states.view(original_shape)
429
+ return final_hidden_states
430
+
431
+
432
+ # currently expert parallelism implemented with all2all
433
+ # is under-optimized.
434
+ def fused_experts_with_all2all_buffer(
435
+ hidden_states: torch.Tensor,
436
+ w1: torch.Tensor,
437
+ w2: torch.Tensor,
438
+ topk_weights: torch.Tensor,
439
+ topk_ids: torch.Tensor,
440
+ top_k: int,
441
+ max_model_len: int,
442
+ global_batch_size: int,
443
+ expert_map: torch.Tensor = None,
444
+ ep_group: GroupCoordinator = None,
445
+ ):
446
+ original_shape = hidden_states.shape
447
+ if len(original_shape) == 3:
448
+ hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
449
+
450
+ num_tokens, _ = hidden_states.shape
451
+ device = hidden_states.device
452
+
453
+ global_num_experts = len(expert_map)
454
+ local_num_experts = global_num_experts // ep_group.world_size
455
+ row_idx_len = num_tokens * top_k
456
+ row_idx = (torch.arange(0, row_idx_len, dtype=torch.int32,
457
+ device=device).view(top_k,
458
+ -1).permute(1, 0).contiguous())
459
+ hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
460
+ hidden_states,
461
+ row_idx=row_idx,
462
+ expert_idx=topk_ids,
463
+ active_num=num_tokens)
464
+
465
+ max_row_per_ep_rank = (-(-global_batch_size // ep_group.world_size) *
466
+ max_model_len // ep_group.world_size +
467
+ 1) * top_k * 2
468
+ expert_idx_buffer_scatter, unpad_indices = process_topk_ids(
469
+ expanded_expert_idx, global_num_experts, ep_group.world_size,
470
+ max_row_per_ep_rank, num_tokens, top_k)
471
+ hidden_states_pad_idx = torch.zeros(
472
+ expert_idx_buffer_scatter.shape,
473
+ dtype=expert_idx_buffer_scatter.dtype,
474
+ device=expert_idx_buffer_scatter.device)
475
+ non_pad_len = torch.sum((expert_idx_buffer_scatter
476
+ != global_num_experts).to(torch.int32))
477
+ hidden_states_pad_idx[expert_idx_buffer_scatter !=
478
+ global_num_experts] = torch.arange(
479
+ non_pad_len,
480
+ dtype=expert_idx_buffer_scatter.dtype,
481
+ device=hidden_states.device)
482
+
483
+ hidden_states_buffer_scatter = hidden_states[hidden_states_pad_idx]
484
+ expert_idx_buffer_gather = torch.empty_like(
485
+ expert_idx_buffer_scatter,
486
+ dtype=expert_idx_buffer_scatter.dtype,
487
+ device=expert_idx_buffer_scatter.device)
488
+ hidden_states_buffer_gather = torch.empty_like(
489
+ hidden_states_buffer_scatter,
490
+ dtype=hidden_states_buffer_scatter.dtype,
491
+ device=hidden_states_buffer_scatter.device)
492
+ dist.all_to_all_single(expert_idx_buffer_gather,
493
+ expert_idx_buffer_scatter,
494
+ group=ep_group.device_group)
495
+ dist.all_to_all_single(hidden_states_buffer_gather,
496
+ hidden_states_buffer_scatter,
497
+ group=ep_group.device_group)
498
+ mask = expert_idx_buffer_gather != global_num_experts
499
+ local_expert_idx = expert_idx_buffer_gather[mask] - ep_group.rank * (
500
+ global_num_experts // ep_group.world_size)
501
+ hidden_states = hidden_states_buffer_gather[mask]
502
+ idx_type = local_expert_idx.dtype
503
+ sorted_local_expert_idx, sorted_idx = torch.sort(local_expert_idx.float())
504
+ sorted_local_expert_idx = sorted_local_expert_idx.to(idx_type)
505
+
506
+ expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
507
+ sorted_local_expert_idx, local_num_experts).to(torch.int64)
508
+ hidden_states = hidden_states[sorted_idx]
509
+ group_list_type = 0
510
+
511
+ hidden_states_wrapper = [hidden_states]
512
+ del hidden_states
513
+
514
+ hidden_states = apply_mlp(hidden_states_wrapper,
515
+ w1,
516
+ w2,
517
+ expert_tokens,
518
+ group_list_type=group_list_type)
519
+
520
+ resorted_idx = torch.argsort(sorted_idx.float()).to(sorted_idx.dtype)
521
+ hidden_states = hidden_states[resorted_idx]
522
+ hidden_states_scatter = torch.zeros(
523
+ (mask.shape[0], hidden_states.shape[1]),
524
+ dtype=hidden_states.dtype,
525
+ device=hidden_states.device)
526
+ hidden_states_scatter[mask] = hidden_states
527
+ hidden_states_gatter = torch.empty_like(
528
+ hidden_states_scatter,
529
+ dtype=hidden_states_scatter.dtype,
530
+ device=hidden_states_scatter.device)
531
+ dist.all_to_all_single(hidden_states_gatter,
532
+ hidden_states_scatter,
533
+ group=ep_group.device_group)
534
+ hidden_states_gatter = hidden_states_gatter[expert_idx_buffer_scatter !=
535
+ global_num_experts]
536
+ if hidden_states_gatter.shape[0] != row_idx_len:
537
+ hidden_states = torch.zeros((row_idx_len, hidden_states.shape[1]),
538
+ dtype=hidden_states.dtype,
539
+ device=hidden_states.device)
540
+ hidden_states[unpad_indices != -1] = hidden_states_gatter
541
+ else:
542
+ # TODO: device memory is reordered twice here; replace this path when suitable operators become available.
543
+ hidden_states = hidden_states_gatter
544
+ final_hidden_states = torch_npu.npu_moe_finalize_routing(
545
+ hidden_states,
546
+ skip1=None,
547
+ skip2=None,
548
+ bias=None,
549
+ scales=topk_weights,
550
+ expanded_src_to_dst_row=expanded_row_idx,
551
+ export_for_source_row=topk_ids,
552
+ )
553
+
554
+ if len(original_shape) == 3:
555
+ final_hidden_states = final_hidden_states.view(original_shape)
556
+ return final_hidden_states
557
+
558
+
559
+ def fused_experts_moge(
560
+ hidden_states: torch.Tensor,
561
+ w1: torch.Tensor,
562
+ w2: torch.Tensor,
563
+ topk_weights: torch.Tensor,
564
+ topk_ids: torch.Tensor,
565
+ top_k: int,
566
+ global_num_experts: int,
567
+ expert_map: torch.Tensor = None,
568
+ apply_router_weight_on_input: bool = False,
569
+ ) -> torch.Tensor:
570
+ """
571
+ Fused experts computation for grouped experts (MoGE): tokens are sorted by expert id and processed with per-expert grouped matmuls.
+
572
+ Args:
573
+ hidden_states: Hidden states of shape (num_tokens, hidden_size).
574
+ w1: Expert weights1 of shape (num_experts, intermediate_size * 2, hidden_size).
575
+ w2: Expert weights2 of shape (num_experts, hidden_size, intermediate_size).
576
+ topk_weights: Routing weights of shape (num_tokens, top_k).
577
+ topk_ids: Selected expert IDs of shape (num_tokens, top_k).
578
+ top_k: Number of experts to select.
579
+ expert_map: Expert mapping of shape (num_experts,).
580
+
581
+ Returns:
582
+ hidden_states: Hidden states after routing.
583
+ """
584
+ ep_size = get_ep_group().world_size
585
+ local_num_experts = global_num_experts // ep_size
586
+ local_num_group = top_k // ep_size
587
+
588
+ if apply_router_weight_on_input:
589
+ assert (topk_weights.dim() == 2
590
+ ), "`topk_weights` should be in shape (num_tokens, topk)"
591
+ _, topk = topk_weights.shape
592
+ assert (
593
+ topk == 1
594
+ ), "Only support topk=1 when `apply_router_weight_on_input` is True"
595
+ hidden_states = hidden_states * topk_weights.to(hidden_states.dtype)
596
+
597
+ bsz, _ = hidden_states.shape
598
+ flatten_topk_ids = topk_ids.view(-1)
599
+ sorted_topk_ids = torch.argsort(flatten_topk_ids.float())
600
+ sorted_topk_ids = sorted_topk_ids.to(torch.int32)
601
+ sorted_hidden_states = hidden_states.index_select(
602
+ 0, sorted_topk_ids // local_num_group)
603
+
604
+ experts_id = torch.arange(0,
605
+ local_num_experts,
606
+ dtype=topk_ids.dtype,
607
+ device=topk_ids.device)
608
+ num_tokens_per_expert = (flatten_topk_ids.unsqueeze(-1) == experts_id).to(
609
+ torch.float32).sum(0)
610
+ topk_scales = topk_weights.view(-1).index_select(
611
+ 0, sorted_topk_ids).unsqueeze(-1)
612
+ group_list = num_tokens_per_expert.cumsum(dim=0).to(torch.int64)
613
+
614
+ w1 = w1.transpose(1, 2)
615
+ gate_up_out = torch_npu.npu_grouped_matmul(
616
+ x=[sorted_hidden_states],
617
+ weight=[w1],
618
+ split_item=2,
619
+ group_list_type=0,
620
+ group_type=0,
621
+ group_list=group_list,
622
+ )[0]
623
+
624
+ if is_310p():
625
+ gate_up_out = torch_npu.npu_swiglu(gate_up_out.to(torch.float32)).to(
626
+ torch.float16)
627
+ else:
628
+ gate_up_out = torch_npu.npu_swiglu(gate_up_out)
629
+ gate_up_out *= topk_scales
630
+
631
+ w2 = w2.transpose(1, 2)
632
+ down_out_list = torch_npu.npu_grouped_matmul(
633
+ x=[gate_up_out],
634
+ weight=[w2],
635
+ split_item=2,
636
+ group_list_type=0,
637
+ group_type=0,
638
+ group_list=group_list,
639
+ )[0]
640
+
641
+ unsorted_topk_ids = torch.argsort(sorted_topk_ids.float()).to(torch.int32)
642
+ unsorted_hidden_states = down_out_list.index_select(0, unsorted_topk_ids)
643
+ final_hidden_states = unsorted_hidden_states.reshape(
644
+ bsz, top_k // ep_size, -1).sum(1)
645
+
646
+ return final_hidden_states
647
+
648
+
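The routing bookkeeping in fused_experts_moge can be seen on a toy input: flatten the per-token expert choices, sort tokens by expert id, and build a cumulative group_list so each expert's tokens form one contiguous block for the grouped matmul. The values below are illustrative only.

import torch

topk_ids = torch.tensor([[1, 0], [0, 1]])      # 2 tokens, top_k = 2, 2 local experts
flat = topk_ids.view(-1)                       # [1, 0, 0, 1]
order = torch.argsort(flat.float()).to(torch.int32)   # reorders rows so expert-0 tokens come first
tokens_per_expert = (flat.unsqueeze(-1) == torch.arange(2)).sum(0)   # [2, 2]
group_list = tokens_per_expert.cumsum(0).to(torch.int64)             # [2, 4]
# Rows 0..1 of the sorted tokens belong to expert 0 and rows 2..3 to expert 1,
# which is exactly what the grouped matmul expects from a cumsum-mode group_list.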
649
+ def fused_experts(
650
+ hidden_states: torch.Tensor,
651
+ w1: torch.Tensor,
652
+ w2: torch.Tensor,
653
+ topk_weights: torch.Tensor,
654
+ topk_ids: torch.Tensor,
655
+ top_k: int,
656
+ expert_map: torch.Tensor = None,
657
+ apply_router_weight_on_input: bool = False,
658
+ max_num_tokens: Optional[int] = None,
659
+ ) -> torch.Tensor:
660
+ """
661
+ Fused experts with top-k routing.
662
+
663
+ Args:
664
+ hidden_states: Hidden states of shape (num_tokens, hidden_size).
665
+ w1: Expert weights1 of shape (num_experts, intermediate_size * 2, hidden_size).
666
+ w2: Expert weights2 of shape (num_experts, hidden_size, intermediate_size).
667
+ topk_weights: Routing weights of shape (num_tokens, top_k).
668
+ topk_ids: Selected expert IDs of shape (num_tokens, top_k).
669
+ top_k: Number of experts to select.
670
+ expert_map: Expert mapping of shape (num_experts,).
671
+
672
+ Returns:
673
+ hidden_states: Hidden states after routing.
674
+ """
675
+ """
676
+ # Check constraints.
677
+ assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
678
+ assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
679
+ assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
680
+ assert w1.is_contiguous(), "Expert weights1 must be contiguous"
681
+ assert w2.is_contiguous(), "Expert weights2 must be contiguous"
682
+ """
683
+ # if torch.distributed.get_rank() == 0:
684
+ # print(w1.shape)
685
+ # print(hidden_states.shape)
686
+
687
+ original_shape = hidden_states.shape
688
+ # assert len(original_shape) == 2
689
+
690
+ num_tokens = hidden_states.shape[:-1].numel()
691
+ num_experts = w1.shape[0]
692
+ dtype = hidden_states.dtype
693
+ device = hidden_states.device
694
+ # assert dtype in [torch.float32, torch.float16, torch.bfloat16
695
+ # ], "Only float32, float16, and bfloat16 are supported"
696
+
697
+ if apply_router_weight_on_input:
698
+ assert (topk_weights.dim() == 2
699
+ ), "`topk_weights` should be in shape (num_tokens, topk)"
700
+ _, topk = topk_weights.shape
701
+ assert (
702
+ topk == 1
703
+ ), "Only support topk=1 when `apply_router_weight_on_input` is True"
704
+ hidden_states = hidden_states * topk_weights.to(hidden_states.dtype)
705
+
706
+ if expert_map is not None:
707
+ # Generate token indices and flatten
708
+ token_indices = (torch.arange(num_tokens,
709
+ device=device,
710
+ dtype=torch.int64).unsqueeze(1).expand(
711
+ -1, top_k).reshape(-1))
712
+
713
+ # Flatten token-to-expert mappings and map to local experts
714
+ weights_flat = topk_weights.view(-1)
715
+ experts_flat = topk_ids.view(-1)
716
+ local_experts_flat = expert_map[experts_flat]
717
+
718
+ # Filter valid token-expert pairs
719
+ mask = local_experts_flat != -1
720
+ filtered_weights = torch.where(
721
+ mask, weights_flat, torch.zeros_like(weights_flat)).to(dtype)
722
+ filtered_experts = torch.where(
723
+ mask, local_experts_flat,
724
+ torch.full_like(local_experts_flat,
725
+ num_experts)).to(topk_ids.dtype)
726
+
727
+ # Sort by local expert IDs
728
+ sort_indices = torch.argsort(filtered_experts.view(torch.float32))
729
+ sorted_token_indices = token_indices[sort_indices]
730
+ sorted_weights = filtered_weights[sort_indices]
731
+
732
+ # Compute token counts with minlength of num_experts
733
+ # This is equivalent to but faster than:
734
+ # >>> token_counts = torch.bincount(filtered_experts, minlength=num_experts)[:-1]
735
+ token_counts = torch.zeros(num_experts + 1,
736
+ device=device,
737
+ dtype=torch.int64)
738
+ ones = torch.ones_like(filtered_experts, dtype=torch.int64)
739
+ token_counts.scatter_add_(0, filtered_experts.to(torch.int64), ones)
740
+ token_counts = token_counts[:num_experts]
741
+ expert_tokens = torch.cumsum(token_counts, dim=0, dtype=torch.int64)
742
+
743
+ # Rearrange hidden_states
744
+ sorted_hidden_states = hidden_states[sorted_token_indices]
745
+ else:
746
+ row_idx_len = num_tokens * top_k
747
+ row_idx = (torch.arange(0,
748
+ row_idx_len,
749
+ dtype=torch.int32,
750
+ device=device).view(top_k, -1).permute(
751
+ 1, 0).contiguous())
752
+ active_num = max_num_tokens if max_num_tokens is not None else num_tokens
753
+ sorted_hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
754
+ hidden_states,
755
+ row_idx=row_idx,
756
+ expert_idx=topk_ids,
757
+ active_num=active_num)
758
+
759
+ expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
760
+ expanded_expert_idx, num_experts)
761
+ expert_tokens = expert_tokens.to(torch.int64)
762
+
763
+ w1 = w1.transpose(1, 2)
764
+ gate_up_out_list = torch_npu.npu_grouped_matmul(
765
+ x=[sorted_hidden_states],
766
+ weight=[w1],
767
+ split_item=2,
768
+ group_list_type=0,
769
+ group_type=0,
770
+ group_list=expert_tokens,
771
+ )
772
+
773
+ # TODO: Remove this in the future.
774
+ gate_up_out = torch.cat(gate_up_out_list, dim=0)
775
+ gate_up_out = torch_npu.npu_swiglu(gate_up_out)
776
+
777
+ w2 = w2.transpose(1, 2)
778
+ down_out_list = torch_npu.npu_grouped_matmul(
779
+ x=[gate_up_out],
780
+ weight=[w2],
781
+ split_item=2,
782
+ group_list_type=0,
783
+ group_type=0,
784
+ group_list=expert_tokens,
785
+ )
786
+
787
+ down_out_list = torch.cat(down_out_list, dim=0)
788
+
789
+ if expert_map is not None:
790
+ weighted_down_out = down_out_list * sorted_weights.unsqueeze(1)
791
+
792
+ final_hidden_states = torch.zeros(*original_shape,
793
+ device=hidden_states.device,
794
+ dtype=dtype)
795
+
796
+ # TODO: npu_grouped_matmul outputs random values at [num_valid_tokens:, ...].
797
+ # Those values can be NaN, and index_add_ would mix them into valid rows and harm accuracy;
798
+ # remove this mask and filtering once the operator is fixed.
799
+ num_valid_tokens = mask.sum()
800
+ valid_token_mask = torch.arange(
801
+ 0, sorted_token_indices.shape[0],
802
+ device=device).unsqueeze(1) < num_valid_tokens
803
+ valid_output = torch.where(
804
+ valid_token_mask, weighted_down_out,
805
+ torch.zeros_like(weighted_down_out)).to(dtype)
806
+ final_hidden_states.index_add_(0, sorted_token_indices, valid_output)
807
+ else:
808
+ scales = torch.ones_like(
809
+ topk_weights) if apply_router_weight_on_input else topk_weights
810
+ # TODO: device memory is reordered twice here; replace the current
811
+ # implementation when suitable operators become available.
812
+ final_hidden_states = torch_npu.npu_moe_finalize_routing(
813
+ down_out_list,
814
+ skip1=None,
815
+ skip2=None,
816
+ bias=None,
817
+ scales=scales,
818
+ expanded_src_to_dst_row=expanded_row_idx,
819
+ export_for_source_row=topk_ids,
820
+ )
821
+
822
+ return final_hidden_states
823
+
824
+
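The expert_map branch in fused_experts maps global expert ids to local ones: pairs routed to experts owned by other ranks map to -1, get their routing weights zeroed, and are parked on a dummy expert id so they drop out of the final index_add_. A toy illustration of that mapping (values are hypothetical):

import torch

num_local_experts = 2
expert_map = torch.tensor([-1, -1, 0, 1])     # this rank owns global experts 2 and 3
topk_ids = torch.tensor([[0, 2], [3, 1]])
local = expert_map[topk_ids.view(-1)]         # [-1, 0, 1, -1]
mask = local != -1
filtered_experts = torch.where(mask, local,
                               torch.full_like(local, num_local_experts))
# -> [2, 0, 1, 2]: the two foreign pairs land on dummy expert id 2 and their
# routing weights are zeroed, so they contribute nothing when results are added back.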
825
+ def fused_experts_allgather_ep(
826
+ hidden_states: torch.Tensor,
827
+ w1: torch.Tensor,
828
+ w2: torch.Tensor,
829
+ topk_weights: torch.Tensor,
830
+ topk_ids: torch.Tensor,
831
+ is_prefill: bool
832
+ ):
833
+ local_rank = torch.distributed.get_rank(group=get_ep_group().device_group)
834
+ num_experts_per_ep = w1.shape[0]
835
+ local_expert_indices_offset = local_rank * num_experts_per_ep
836
+ global_local_mask = (topk_ids >= local_expert_indices_offset) & \
837
+ (topk_ids <= local_expert_indices_offset + num_experts_per_ep - 1)
838
+ non_global_local_mask = (~global_local_mask).to(torch.int32)
839
+ global_local_mask = global_local_mask.to(torch.int32)
840
+ row_idx = torch.arange(topk_ids.numel(), device=topk_ids.device, dtype=torch.int32).view(
841
+ -1, topk_ids.shape[0]).transpose(0, 1).contiguous()
842
+
843
+ topk_ids -= local_expert_indices_offset
844
+ local_topk_ids_mask_with_max = topk_ids * global_local_mask + non_global_local_mask * num_experts_per_ep
845
+ sorted_tokens, expanded_src_to_dst_row, expanded_expert_idx = torch_npu.npu_moe_init_routing(
846
+ x=hidden_states,
847
+ row_idx=row_idx,
848
+ expert_idx=local_topk_ids_mask_with_max,
849
+ active_num=topk_ids.shape[0]*topk_ids.shape[1]
850
+ )
851
+ if expanded_expert_idx.shape[0] > 8192:
852
+ expert_tokens = torch_npu.npu_moe_compute_expert_tokens(expanded_expert_idx, num_experts_per_ep + 1)
853
+ expert_tokens = expert_tokens[:-1]
854
+ else:
855
+ expert_tokens = torch_npu.npu_moe_compute_expert_tokens(expanded_expert_idx, num_experts_per_ep)
856
+ expert_tokens = expert_tokens.to(torch.int64)
857
+
858
+ w1 = w1.transpose(1, 2)
859
+ gate_up_out = torch_npu.npu_grouped_matmul(
860
+ x=[sorted_tokens],
861
+ weight=[w1],
862
+ group_list=expert_tokens,
863
+ split_item=3,
864
+ group_type=0
865
+ )[0]
866
+ gate_up_out = torch_npu.npu_swiglu(gate_up_out)
867
+
868
+ w2 = w2.transpose(1, 2)
869
+ down_out = torch_npu.npu_grouped_matmul(
870
+ x=[gate_up_out],
871
+ weight=[w2],
872
+ group_list=expert_tokens,
873
+ split_item=3,
874
+ group_type=0
875
+ )[0]
876
+
877
+ if is_prefill:
878
+ down_out[expert_tokens[-1]:] = 0
879
+ else:
880
+ sorted_tokens_mask = expanded_expert_idx != num_experts_per_ep
881
+ down_out *= sorted_tokens_mask.unsqueeze(1)
882
+
883
+ final_hidden_states = torch_npu.npu_moe_finalize_routing(
884
+ expanded_permuted_rows=down_out,
885
+ skip1=None,
886
+ skip2=None,
887
+ bias=None,
888
+ scales=topk_weights.to(down_out.dtype),
889
+ expanded_src_to_dst_row=expanded_src_to_dst_row,
890
+ export_for_source_row=topk_ids
891
+ )
892
+ return final_hidden_states
893
+
894
+
895
+ def select_gating_top_k_softmax_experts(
896
+ hidden_states: torch.Tensor, router_logits: torch.Tensor, top_k: int,
897
+ renormalize: bool) -> tuple[torch.Tensor, torch.Tensor]:
898
+ """
899
+ Select top-k experts based on router logits.
900
+ Only supports float16, bfloat16 and float32.
901
+
902
+ Args:
903
+ hidden_states: Hidden states of shape (num_tokens, hidden_size).
904
+ router_logits: Router logits of shape (num_tokens, num_experts).
905
+ top_k: Number of experts to select.
906
+ renormalize: Whether to renormalize the routing weights.
907
+
908
+ Returns:
909
+ topk_weights: Routing weights of shape (num_tokens, top_k).
910
+ topk_ids: Selected expert IDs of shape (num_tokens, top_k).
911
+
912
+ Raises:
913
+ ValueError: If an unsupported scoring function is provided.
914
+ """
915
+ topk_weights, topk_ids, row_idx = torch_npu.npu_moe_gating_top_k_softmax(
916
+ router_logits, None, k=top_k)
917
+
918
+ # # Required by npu_moe_init_routing
919
+ # topk_weights = topk_weights.to(hidden_states.dtype)
920
+ # topk_ids = topk_ids.to(torch.int32)
921
+
922
+ if renormalize:
923
+ topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
924
+
925
+ return topk_weights, topk_ids
926
+
927
+
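select_gating_top_k_softmax_experts delegates softmax gating and top-k selection to a single fused NPU op. For reference, the equivalent unfused computation in plain PyTorch looks roughly like this (a sketch, not the device kernel):

import torch

def gating_topk_softmax_reference(router_logits, top_k, renormalize=True):
    weights = router_logits.softmax(dim=-1)
    topk_weights, topk_ids = weights.topk(top_k, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids.to(torch.int32)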
928
+ def native_grouped_topk(
929
+ topk_weights: torch.Tensor,
930
+ num_expert_group: Optional[int],
931
+ topk_group: Optional[int],
932
+ ):
933
+ topk_group = 0 if topk_group is None else topk_group
934
+ num_expert_group = 0 if num_expert_group is None else num_expert_group
935
+
936
+ num_token = topk_weights.shape[0]
937
+ grouped_weights = topk_weights.view(num_token, num_expert_group,
938
+ -1).max(dim=-1).values
939
+ topk_group_indices = torch.topk(grouped_weights.to(torch.float32),
940
+ k=topk_group,
941
+ dim=-1,
942
+ sorted=False)[1]
943
+ topk_group_mask = torch.zeros_like(grouped_weights)
944
+ topk_group_mask.scatter_(1, topk_group_indices, 1)
945
+ topk_weight_mask = (topk_group_mask.unsqueeze(-1).expand(
946
+ num_token, num_expert_group,
947
+ topk_weights.shape[-1] // num_expert_group).reshape(num_token, -1))
948
+ topk_weights = topk_weights.masked_fill(~topk_weight_mask.bool(), 0.0)
949
+
950
+ return topk_weights
951
+
952
+
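A small worked example of native_grouped_topk (values are illustrative): with 8 experts split into 4 groups and topk_group=2, only the scores in the two best-scoring groups survive; everything else is masked to 0.0 before the per-token top-k is taken.

import torch

scores = torch.tensor([[0.9, 0.1, 0.2, 0.8, 0.05, 0.02, 0.3, 0.4]])
masked = native_grouped_topk(scores, num_expert_group=4, topk_group=2)
# Group maxima are 0.9, 0.8, 0.05 and 0.4, so groups 0 and 1 are kept:
# tensor([[0.9, 0.1, 0.2, 0.8, 0.0, 0.0, 0.0, 0.0]])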
953
+ def select_experts(
954
+ hidden_states: torch.Tensor,
955
+ router_logits: torch.Tensor,
956
+ top_k: int,
957
+ use_grouped_topk: bool,
958
+ renormalize: bool,
959
+ topk_group: Optional[int] = None,
960
+ num_expert_group: Optional[int] = None,
961
+ custom_routing_function: Optional[Callable] = None,
962
+ scoring_func: str = "softmax",
963
+ e_score_correction_bias: Optional[torch.Tensor] = None,
964
+ global_num_experts: Optional[torch.Tensor] = None
965
+ ) -> tuple[torch.Tensor, torch.Tensor]:
966
+ """
967
+ Select top-k experts based on router logits.
968
+
969
+ Args:
970
+ hidden_states: Hidden states of shape (num_tokens, hidden_size).
971
+ router_logits: Router logits of shape (num_tokens, num_experts).
972
+ top_k: Number of experts to select.
973
+ use_grouped_topk: Whether to group experts before selecting top-k.
974
+ renormalize: Whether to renormalize the routing weights.
975
+ topk_group: Number of expert groups to select from.
976
+ num_expert_group: Number of experts in each group.
977
+ custom_routing_function: Custom routing function.
978
+ scoring_func: Scoring function to use.
979
+ e_score_correction_bias: Correction bias to apply to expert scores.
980
+
981
+ Returns:
982
+ topk_weights: Routing weights of shape (num_tokens, top_k).
983
+ topk_ids: Selected expert IDs of shape (num_tokens, top_k).
984
+
985
+ Raises:
986
+ ValueError: If an unsupported scoring function is provided.
987
+ """
988
+
989
+ if scoring_func == "softmax":
990
+ # NOTE: vLLM uses dtype=torch.float here
991
+ topk_weights = router_logits.softmax(dim=-1)
992
+ elif scoring_func == "sigmoid":
993
+ topk_weights = router_logits.sigmoid()
994
+ else:
995
+ raise ValueError(f"Unsupported scoring function: {scoring_func}")
996
+
997
+ if use_grouped_topk:
998
+ assert topk_group is not None
999
+ assert num_expert_group is not None
1000
+
1001
+ if e_score_correction_bias is not None:
1002
+ # Store original scores before applying correction bias. We use biased
1003
+ # scores for expert selection but original scores for routing weights
1004
+ original_weights = topk_weights
1005
+ topk_weights = topk_weights + e_score_correction_bias.unsqueeze(0)
1006
+
1007
+ # TODO: Change to npu_group_topk when the latest CANN and NNAL are available
1008
+ # >>> torch_npu._npu_group_topk(topk_weights, group_num=num_expert_group, k=topk_group)
1009
+ topk_weights = native_grouped_topk(topk_weights, num_expert_group,
1010
+ topk_group)
1011
+ # TODO: bfloat16 is not supported by torch.topk in GE graph mode.
1012
+ if e_score_correction_bias is not None:
1013
+ topk_ids = torch.topk(topk_weights.to(torch.float32),
1014
+ k=top_k,
1015
+ dim=-1,
1016
+ sorted=False)[1]
1017
+ # Use original unbiased scores for the routing weights
1018
+ topk_weights = original_weights.gather(1, topk_ids)
1019
+ else:
1020
+ topk_weights, topk_ids = torch.topk(topk_weights.to(torch.float32),
1021
+ k=top_k,
1022
+ dim=-1,
1023
+ sorted=False)
1024
+ elif custom_routing_function is None:
1025
+ topk_weights, topk_ids = topk_weights.topk(top_k, dim=-1)
1026
+ else:
1027
+ topk_weights, topk_ids = custom_routing_function(
1028
+ hidden_states=hidden_states,
1029
+ gating_output=router_logits,
1030
+ topk=top_k,
1031
+ renormalize=renormalize,
1032
+ global_num_experts=global_num_experts)
1033
+ # Required by npu_moe_init_routing
1034
+ topk_ids = topk_ids.to(torch.int32)
1035
+ return topk_weights, topk_ids
1036
+
1037
+ # Required by npu_moe_init_routing
1038
+ topk_ids = topk_ids.to(torch.int32)
1039
+
1040
+ if renormalize:
1041
+ topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
1042
+
1043
+ return topk_weights, topk_ids
1044
+
1045
+
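The e_score_correction_bias handling in select_experts separates selection from weighting: the bias only influences which experts get picked, while the returned routing weights are gathered from the original, unbiased scores. A toy sketch of that idea (the numbers are hypothetical):

import torch

scores = torch.tensor([[0.2, 0.5, 0.3]])
bias = torch.tensor([0.4, 0.0, 0.0])                    # nudges expert 0 into the top-k
topk_ids = (scores + bias).topk(k=2, dim=-1).indices    # experts 0 and 1 are selected
topk_weights = scores.gather(1, topk_ids)               # but weighted by the original 0.2 and 0.5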
1046
+ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
1047
+
1048
+ def __init__(self, moe: FusedMoEConfig = None):
1049
+
1050
+ super().__init__(moe=moe)
1051
+ vllm_config = get_current_vllm_config()
1052
+
1053
+ self.ep_group = get_ep_group()
1054
+ self.ep_size = self.ep_group.world_size
1055
+ self.global_batch_size = vllm_config.scheduler_config.max_num_seqs
1056
+ self.local_batch_size = self.global_batch_size // self.ep_size
1057
+ self.max_model_len = vllm_config.model_config.max_model_len
1058
+
1059
+ ascend_config = get_ascend_config()
1060
+ self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
1061
+
1062
+ try:
1063
+ device_group = self.ep_group.device_group
1064
+ # TODO: Try local_rank = ep_group.rank_in_group
1065
+ local_rank = torch.distributed.get_rank(group=device_group)
1066
+ backend = device_group._get_backend(torch.device("npu"))
1067
+ self.moe_all_to_all_group_name = backend.get_hccl_comm_name(
1068
+ local_rank)
1069
+ except AttributeError:
1070
+ self.moe_all_to_all_group_name = None
1071
+
1072
+ def process_weights_after_loading(self, layer):
1073
+ super(UnquantizedFusedMoEMethod,
1074
+ self).process_weights_after_loading(layer)
1075
+ layer.w13_weight = torch.nn.Parameter(self._maybe_pad_weight(
1076
+ layer.w13_weight.data),
1077
+ requires_grad=False)
1078
+ layer.w2_weight = torch.nn.Parameter(self._maybe_pad_weight(
1079
+ layer.w2_weight.data),
1080
+ requires_grad=False)
1081
+
1082
+ def apply(
1083
+ self,
1084
+ layer: torch.nn.Module,
1085
+ x: torch.Tensor,
1086
+ router_logits: torch.Tensor,
1087
+ top_k: int,
1088
+ renormalize: bool,
1089
+ use_grouped_topk: bool = False,
1090
+ global_num_experts: int = -1,
1091
+ expert_map: Optional[torch.Tensor] = None,
1092
+ topk_group: Optional[int] = None,
1093
+ num_expert_group: Optional[int] = None,
1094
+ custom_routing_function: Optional[Callable] = None,
1095
+ scoring_func: str = "softmax",
1096
+ e_score_correction_bias: Optional[torch.Tensor] = None,
1097
+ is_prefill: bool = False,
1098
+ enable_force_load_balance: bool = False,
1099
+ shared_experts: Optional[Any] = None,
1100
+ **kwargs,
1101
+ ) -> torch.Tensor:
1102
+ use_grouped_topk = (topk_group > 1 or num_expert_group > 1)
1103
+ is_deepseek_v3_r1 = global_num_experts == 256
1104
+ # NOTE: npu_moe_gating_top_k currently only supports the `group_count=256` pattern
1105
+ if use_grouped_topk and is_deepseek_v3_r1:
1106
+ topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
1107
+ router_logits,
1108
+ k=top_k,  # top_k is currently set to 8 for this pattern
1109
+ bias=e_score_correction_bias,
1110
+ k_group=topk_group,  # fixed at 4
1111
+ group_count=num_expert_group,  # fixed at 8
1112
+ group_select_mode=1,  # 0: max score within the group; 1: sum of the group's top-2 scores (fixed)
1113
+ renorm=0,  # 0: softmax then topk (fixed); 1: topk then softmax
1114
+ norm_type=1,  # 0: softmax; 1: sigmoid (fixed)
1115
+ # out_flag=False,  # TODO new API; whether to emit the third output
1116
+ # y2_flag=False,  # old API; whether to emit the third output
1117
+ routed_scaling_factor=1,
1118
+ eps=float(1e-20))
1119
+ elif use_grouped_topk and SELECT_GATING_TOPK_SOTFMAX_EXPERTS:
1120
+ topk_weights, topk_ids = select_gating_top_k_softmax_experts(
1121
+ hidden_states=x,
1122
+ router_logits=router_logits,
1123
+ top_k=top_k,
1124
+ renormalize=renormalize)
1125
+ else:
1126
+ topk_weights, topk_ids = select_experts(
1127
+ hidden_states=x,
1128
+ router_logits=router_logits,
1129
+ top_k=top_k,
1130
+ use_grouped_topk=use_grouped_topk,
1131
+ renormalize=renormalize,
1132
+ topk_group=topk_group,
1133
+ num_expert_group=num_expert_group,
1134
+ custom_routing_function=custom_routing_function,
1135
+ scoring_func=scoring_func,
1136
+ e_score_correction_bias=e_score_correction_bias,
1137
+ )
1138
+
1139
+ topk_weights = topk_weights.to(x.dtype)
1140
+ # This is a naive implementation of expert load balancing, intended
1141
+ # to avoid accumulating too many tokens on a single rank.
1142
+ # Currently it is only activated during profile runs.
1143
+ if enable_force_load_balance:
1144
+ topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
1145
+
1146
+ fused_moe_state = get_fused_moe_state(self.ep_group.world_size,
1147
+ is_prefill, is_deepseek_v3_r1)
1148
+ if fused_moe_state == FusedMoEState.MC2:
1149
+ return fused_experts_with_mc2(
1150
+ hidden_states=x,
1151
+ w1=layer.w13_weight,
1152
+ w2=layer.w2_weight,
1153
+ topk_weights=topk_weights,
1154
+ topk_ids=topk_ids,
1155
+ top_k=top_k,
1156
+ expert_map=expert_map,
1157
+ moe_all_to_all_group_name=self.moe_all_to_all_group_name,
1158
+ shared_experts=shared_experts)
1159
+ elif fused_moe_state == FusedMoEState.AllGatherEP:
1160
+ return fused_experts_allgather_ep(
1161
+ hidden_states=x,
1162
+ w1=layer.w13_weight,
1163
+ w2=layer.w2_weight,
1164
+ topk_weights=topk_weights,
1165
+ topk_ids=topk_ids,
1166
+ is_prefill=is_prefill)
1167
+ elif fused_moe_state in [
1168
+ FusedMoEState.AllGather, FusedMoEState.NaiveMulticast
1169
+ ]:
1170
+ return fused_experts(hidden_states=x,
1171
+ w1=layer.w13_weight,
1172
+ w2=layer.w2_weight,
1173
+ topk_weights=topk_weights,
1174
+ topk_ids=topk_ids,
1175
+ top_k=top_k,
1176
+ expert_map=expert_map)
1177
+ elif MOE_ALL2ALL_BUFFER:
1178
+ return fused_experts_with_all2all_buffer(
1179
+ hidden_states=x,
1180
+ w1=layer.w13_weight,
1181
+ w2=layer.w2_weight,
1182
+ topk_weights=topk_weights,
1183
+ topk_ids=topk_ids,
1184
+ top_k=top_k,
1185
+ max_model_len=self.max_model_len,
1186
+ global_batch_size=self.global_batch_size,
1187
+ expert_map=expert_map,
1188
+ ep_group=get_ep_group())
1189
+ else:
1190
+ return fused_experts_with_all2all(hidden_states=x,
1191
+ w1=layer.w13_weight,
1192
+ w2=layer.w2_weight,
1193
+ topk_weights=topk_weights,
1194
+ topk_ids=topk_ids,
1195
+ top_k=top_k,
1196
+ expert_map=expert_map,
1197
+ ep_group=get_ep_group())
1198
+
1199
+
1200
+ class AscendFusedMoE(FusedMoE):
1201
+
1202
+ # The moe_counter parameter is required during the initialization of EPLB
1203
+ # to identify the current layer index within the MOE model.
1204
+ moe_counter = -1
1205
+
1206
+ def __init__(
1207
+ self,
1208
+ num_experts: int, # Global number of experts
1209
+ top_k: int,
1210
+ hidden_size: int,
1211
+ intermediate_size: int,
1212
+ params_dtype: Optional[torch.dtype] = None,
1213
+ reduce_results: bool = False,
1214
+ renormalize: bool = True,
1215
+ use_grouped_topk: bool = False,
1216
+ num_expert_group: Optional[int] = None,
1217
+ topk_group: Optional[int] = None,
1218
+ quant_config: Optional[QuantizationConfig] = None,
1219
+ tp_size: Optional[int] = None,
1220
+ ep_size: Optional[int] = None,
1221
+ dp_size: Optional[int] = None,
1222
+ prefix: str = "",
1223
+ custom_routing_function: Optional[Callable] = None,
1224
+ scoring_func: str = "softmax",
1225
+ e_score_correction_bias: Optional[torch.Tensor] = None,
1226
+ activation: str = "silu",
1227
+ apply_router_weight_on_input: bool = False,
1228
+ ):
1229
+ # TODO: This does not initialize the FusedMoE base class;
1230
+ # fix this and make __init__() of AscendFusedMoE clearer
1231
+ super(FusedMoE, self).__init__()
1232
+
1233
+ AscendFusedMoE.moe_counter += 1
1234
+ self.moe_instance_id = AscendFusedMoE.moe_counter
1235
+
1236
+ if params_dtype is None:
1237
+ params_dtype = torch.get_default_dtype()
1238
+
1239
+ vllm_config = get_current_vllm_config()
1240
+
1241
+ self.moe_parallel_config = FusedMoEParallelConfig.make(
1242
+ tp_size_=(tp_size if tp_size is not None else
1243
+ get_tensor_model_parallel_world_size()),
1244
+ dp_size_=(dp_size
1245
+ if dp_size is not None else get_dp_group().world_size),
1246
+ vllm_parallel_config=vllm_config.parallel_config)
1247
+
1248
+ self.top_k = top_k
1249
+ self.num_experts = num_experts
1250
+ self.global_num_experts = num_experts
1251
+ assert intermediate_size % self.tp_size == 0
1252
+ self.intermediate_size_per_partition = intermediate_size // self.tp_size
1253
+ self.reduce_results = reduce_results
1254
+ self.renormalize = renormalize
1255
+ self.use_grouped_topk = use_grouped_topk
1256
+ if self.use_grouped_topk:
1257
+ assert num_expert_group is not None and topk_group is not None
1258
+ self.num_expert_group = num_expert_group
1259
+ self.topk_group = topk_group
1260
+ self.custom_routing_function = custom_routing_function
1261
+ self.scoring_func = scoring_func
1262
+ self.e_score_correction_bias = e_score_correction_bias
1263
+ self.expert_map = None
1264
+ self.activation = activation
1265
+ self.log2phy = None
1266
+ self.global_redundant_expert_num = 0
1267
+
1268
+ is_deepseek_v3_r1 = self.global_num_experts == 256
1269
+ self.rm_router_logits = get_rm_router_logits_state(
1270
+ self.moe_parallel_config.ep_size, self.dp_size, is_deepseek_v3_r1)
1271
+ self.all_reduce_merge = get_all_reduce_merge_state(
1272
+ self.moe_parallel_config.ep_size, is_deepseek_v3_r1)
1273
+
1274
+ ascend_config = get_ascend_config()
1275
+ expert_map_path = ascend_config.expert_map_path
1276
+ if expert_map_path and os.path.exists(expert_map_path):
1277
+ # moe expert load balance
1278
+ expert_load_balancer = ExpertLoadBalancer(expert_map_path,
1279
+ self.global_num_experts)
1280
+ self.local_num_experts, self.expert_map = \
1281
+ expert_load_balancer.get_rank_placement_map(
1282
+ self.moe_instance_id,
1283
+ get_ep_group().rank_in_group)
1284
+ self.log2phy = expert_load_balancer.get_rank_log2phy_map(
1285
+ self.moe_instance_id,
1286
+ get_ep_group().rank_in_group)
1287
+ self.global_redundant_expert_num = \
1288
+ expert_load_balancer.get_global_redundant_expert_num()
1289
+ else:
1290
+ # Create a tensor of size num_experts filled with -1
1291
+ self.local_num_experts, self.expert_map = determine_expert_map(
1292
+ self.ep_size,
1293
+ get_ep_group().rank_in_group, self.global_num_experts)
1294
+
1295
+ self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
1296
+ self.enable_multistream_moe = \
1297
+ ascend_config.torchair_graph_config.enable_multistream_moe
1298
+
1299
+ if self.scoring_func != "softmax" and not self.use_grouped_topk:
1300
+ raise ValueError("Only softmax scoring function is supported for "
1301
+ "non-grouped topk.")
1302
+ moe = FusedMoEConfig.make(
1303
+ num_experts=self.global_num_experts,
1304
+ experts_per_token=top_k,
1305
+ hidden_dim=hidden_size,
1306
+ num_local_experts=self.local_num_experts,
1307
+ moe_parallel_config=self.moe_parallel_config,
1308
+ # TODO (bnell): this needs to be fixed for quantized types.
1309
+ in_dtype=params_dtype,
1310
+ quant_config=quant_config)
1311
+
1312
+ if quant_config is None:
1313
+ self.quant_method = AscendUnquantizedFusedMoEMethod(moe)
1314
+ else:
1315
+ self.quant_method = quant_config.get_quant_method(self, prefix)
1316
+
1317
+ assert self.quant_method is not None
1318
+
1319
+ local_num_experts = torch.sum(self.expert_map != -1) \
1320
+ if self.expert_map is not None else num_experts
1321
+
1322
+ moe_quant_params = {
1323
+ "num_experts": local_num_experts,
1324
+ "hidden_size": hidden_size,
1325
+ "intermediate_size_per_partition":
1326
+ self.intermediate_size_per_partition,
1327
+ "params_dtype": params_dtype,
1328
+ "weight_loader": self.weight_loader,
1329
+ }
1330
+ # need full intermediate size pre-sharding for WNA16 act order
1331
+ if (self.quant_method.__class__.__name__
1332
+ in ("GPTQMarlinMoEMethod", "CompressedTensorsWNA16MoEMethod")):
1333
+ moe_quant_params["intermediate_size_full"] = intermediate_size
1334
+
1335
+ self.ep_group = get_ep_group()
1336
+ # NOTE: self.tp_group is not expert_tp_group
1337
+ self.tp_group = get_tp_group().device_group
1338
+ self.quant_method.create_weights(layer=self, **moe_quant_params)
1339
+
1340
+ def naive_multicast(self, x: torch.Tensor,
1341
+ cu_tokens_across_dp_cpu: torch.Tensor):
1342
+ assert (len(x.shape) == 2)
1343
+ buffer = torch.empty((cu_tokens_across_dp_cpu[-1], x.size(1)),
1344
+ device=x.device,
1345
+ dtype=x.dtype)
1346
+ start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
1347
+ self.dp_rank - 1]
1348
+ end = cu_tokens_across_dp_cpu[self.dp_rank]
1349
+ buffer[start:end, :].copy_(x)
1350
+ for idx in range(self.dp_size):
1351
+ start = 0 if idx == 0 else cu_tokens_across_dp_cpu[idx - 1]
1352
+ end = cu_tokens_across_dp_cpu[idx]
1353
+ get_dp_group().broadcast(buffer[start:end, :], idx)
1354
+ return buffer
1355
+
1356
+ def forward(self,
1357
+ hidden_states: torch.Tensor,
1358
+ router_logits: torch.Tensor,
1359
+ is_prefill: bool,
1360
+ enable_force_load_balance: bool = False,
1361
+ top_k: Optional[int] = None,
1362
+ shared_experts: Optional[Any] = None,
1363
+ gate=None,
1364
+ replace_allreduce: bool = False):
1365
+
1366
+ assert self.quant_method is not None
1367
+
1368
+ if top_k:
1369
+ real_top_k = top_k
1370
+ else:
1371
+ real_top_k = self.top_k
1372
+
1373
+ num_tokens, hidden_size = hidden_states.shape
1374
+ is_deepseek_v3_r1 = self.global_num_experts == 256
1375
+
1376
+ fused_moe_state = get_fused_moe_state(self.moe_parallel_config.ep_size,
1377
+ is_prefill, is_deepseek_v3_r1)
1378
+ if shared_experts:
1379
+ if not self.enable_multistream_moe or fused_moe_state != FusedMoEState.MC2:
1380
+ # When all_reduce_merge is enabled, shared_experts skips the all_reduce inside its MLP; the all_reduce is deferred until both shared_experts and the routed experts have finished.
1381
+ shared_hidden_states = shared_experts(hidden_states)
1382
+
1383
+ tp_size = get_tensor_model_parallel_world_size()
1384
+ if (tp_size > 1 and fused_moe_state not in [
1385
+ FusedMoEState.AllGather, FusedMoEState.AllGatherEP,
1386
+ FusedMoEState.NaiveMulticast
1387
+ ] and not replace_allreduce):
1388
+ if num_tokens < tp_size:
1389
+ hidden_states = nn.functional.pad(
1390
+ hidden_states, (0, 0, 0, tp_size - num_tokens))
1391
+ router_logits = nn.functional.pad(
1392
+ router_logits, (0, 0, 0, tp_size - num_tokens))
1393
+ chunk_hidden_states = torch.tensor_split(hidden_states,
1394
+ tp_size,
1395
+ dim=0)
1396
+ chunk_router_logits = torch.tensor_split(router_logits,
1397
+ tp_size,
1398
+ dim=0)
1399
+ tp_rank = get_tensor_model_parallel_rank()
1400
+ hidden_states = chunk_hidden_states[tp_rank]
1401
+ router_logits = chunk_router_logits[tp_rank]
1402
+
1403
+ if self.dp_size > 1:
1404
+ if fused_moe_state in (FusedMoEState.AllGather, FusedMoEState.AllGatherEP):
1405
+ # NOTE: In torchair graph mode, the input has already been padded in model_runner_v1
1406
+ if not self.torchair_graph_enabled or is_prefill:
1407
+ attn_metadata = get_forward_context().attn_metadata
1408
+ if attn_metadata is not None:
1409
+ max_num_tokens_across_dp = attn_metadata.max_num_tokens_across_dp
1410
+ if num_tokens < max_num_tokens_across_dp:
1411
+ hidden_states = nn.functional.pad(
1412
+ hidden_states,
1413
+ (0, 0, 0,
1414
+ max_num_tokens_across_dp - num_tokens))
1415
+ if not self.rm_router_logits:
1416
+ router_logits = nn.functional.pad(
1417
+ router_logits,
1418
+ (0, 0, 0,
1419
+ max_num_tokens_across_dp - num_tokens))
1420
+ hidden_states = get_dp_group().all_gather(hidden_states, 0)
1421
+ if self.rm_router_logits:
1422
+ router_logits, _ = gate(hidden_states.float())
1423
+ else:
1424
+ router_logits = get_dp_group().all_gather(router_logits, 0)
1425
+
1426
+ elif fused_moe_state == FusedMoEState.NaiveMulticast:
1427
+ cu_tokens_across_dp_cpu = get_forward_context(
1428
+ ).dp_metadata.cu_tokens_across_dp_cpu
1429
+ hidden_states = self.naive_multicast(hidden_states,
1430
+ cu_tokens_across_dp_cpu)
1431
+ if self.rm_router_logits:
1432
+ router_logits, _ = gate(hidden_states.float())
1433
+ else:
1434
+ router_logits = self.naive_multicast(
1435
+ router_logits, cu_tokens_across_dp_cpu)
1436
+
1437
+ # Matrix multiply.
1438
+ e_hidden_states = self.quant_method.apply(
1439
+ layer=self,
1440
+ x=hidden_states,
1441
+ router_logits=router_logits,
1442
+ top_k=real_top_k,
1443
+ renormalize=self.renormalize,
1444
+ use_grouped_topk=self.use_grouped_topk,
1445
+ global_num_experts=self.global_num_experts,
1446
+ expert_map=self.expert_map,
1447
+ topk_group=self.topk_group,
1448
+ num_expert_group=self.num_expert_group,
1449
+ custom_routing_function=self.custom_routing_function,
1450
+ scoring_func=self.scoring_func,
1451
+ e_score_correction_bias=self.e_score_correction_bias,
1452
+ is_prefill=is_prefill,
1453
+ enable_force_load_balance=enable_force_load_balance,
1454
+ log2phy=self.log2phy,
1455
+ global_redundant_expert_num=self.global_redundant_expert_num,
1456
+ shared_experts=shared_experts if self.torchair_graph_enabled
1457
+ and self.enable_multistream_moe and not is_prefill else None,
1458
+ )
1459
+
1460
+ if shared_experts:
1461
+ if isinstance(e_hidden_states, tuple):
1462
+ e_hidden_states, shared_hidden_states = e_hidden_states
1463
+
1464
+ if (tp_size > 1 and fused_moe_state not in [
1465
+ FusedMoEState.AllGather, FusedMoEState.AllGatherEP,
1466
+ FusedMoEState.NaiveMulticast
1467
+ ] and not replace_allreduce):
1468
+ dist.all_gather(list(chunk_hidden_states), e_hidden_states,
1469
+ self.tp_group)
1470
+ final_hidden_states = torch.cat(chunk_hidden_states, dim=0)
1471
+ if num_tokens < tp_size:
1472
+ final_hidden_states = final_hidden_states[:num_tokens]
1473
+ dispose_tensor(e_hidden_states)
1474
+ elif self.dp_size > 1:
1475
+ if fused_moe_state == FusedMoEState.NaiveMulticast:
1476
+ start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
1477
+ self.dp_rank - 1]
1478
+ end = cu_tokens_across_dp_cpu[self.dp_rank]
1479
+ final_hidden_states = get_dp_group().all_reduce(
1480
+ e_hidden_states)
1481
+ final_hidden_states = final_hidden_states[start:end, :]
1482
+ dispose_tensor(e_hidden_states)
1483
+ elif fused_moe_state in (FusedMoEState.AllGather, FusedMoEState.AllGatherEP):
1484
+ final_hidden_states = data_parallel_reduce_scatter(
1485
+ e_hidden_states, dim=0)
1486
+ final_hidden_states = final_hidden_states[:num_tokens]
1487
+ dispose_tensor(e_hidden_states)
1488
+ else:
1489
+ final_hidden_states = e_hidden_states
1490
+
1491
+ if tp_size > 1 and not self.all_reduce_merge and fused_moe_state in [
1492
+ FusedMoEState.AllGather, FusedMoEState.AllGatherEP,
1493
+ FusedMoEState.NaiveMulticast
1494
+ ]:
1495
+ final_hidden_states = tensor_model_parallel_all_reduce(
1496
+ final_hidden_states)
1497
+
1498
+ if shared_experts:
1499
+ return final_hidden_states, shared_hidden_states
1500
+ else:
1501
+ return final_hidden_states
1502
+
1503
+ # ----------------------------------------- TBO-related --------------------------------------------
1504
+
1505
+ def _forward_ms_fused_moe_comp(
1506
+ self,
1507
+ hidden_states: torch.Tensor,
1508
+ router_logits: torch.Tensor,
1509
+ is_prefill: bool,
1510
+ real_top_k,
1511
+ enable_force_load_balance: bool = False,
1512
+ ):
1513
+ hidden_states = self.quant_method.apply(
1514
+ layer=self,
1515
+ x=hidden_states,
1516
+ router_logits=router_logits,
1517
+ top_k=real_top_k,
1518
+ renormalize=self.renormalize,
1519
+ use_grouped_topk=self.use_grouped_topk,
1520
+ global_num_experts=self.global_num_experts,
1521
+ expert_map=self.expert_map,
1522
+ topk_group=self.topk_group,
1523
+ num_expert_group=self.num_expert_group,
1524
+ custom_routing_function=self.custom_routing_function,
1525
+ scoring_func=self.scoring_func,
1526
+ e_score_correction_bias=self.e_score_correction_bias,
1527
+ is_prefill=is_prefill,
1528
+ enable_force_load_balance=enable_force_load_balance)
1529
+
1530
+ return hidden_states
inference/vllm_ascend/patch/worker/patch_common/__init__.py ADDED
@@ -0,0 +1,27 @@
1
+ #
2
+ # Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
3
+ # This file is a part of the vllm-ascend project.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ #
17
+
18
+ # patch_utils should be the first import, because it will be used by other
19
+ # patch files.
20
+ import vllm_ascend.patch.worker.patch_common.patch_utils # noqa isort:skip
21
+ import vllm_ascend.patch.worker.patch_common.patch_distributed # noqa
22
+ import vllm_ascend.patch.worker.patch_common.patch_minicpm # noqa
23
+ import vllm_ascend.patch.worker.patch_common.patch_multi_step_worker # noqa
24
+ import vllm_ascend.patch.worker.patch_common.patch_sampler # noqa
25
+ import vllm_ascend.patch.worker.patch_common.patch_spec_decode_worker # noqa
26
+ import vllm_ascend.patch.worker.patch_common.patch_config # noqa
27
+ import vllm_ascend.patch.worker.patch_common.patch_parsers # noqa
inference/vllm_ascend/patch/worker/patch_common/patch_config.py ADDED
@@ -0,0 +1,97 @@
1
+ #
2
+ # Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
3
+ # This file is a part of the vllm-ascend project.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ #
17
+ from vllm.config import ModelConfig
18
+
19
+
20
+ def get_attr_by_names(src_config, attrs, default_value):
21
+ for attr in attrs:
22
+ value = getattr(src_config, attr, 0)
23
+ if value > 0:
24
+ return value
25
+ return default_value
26
+
27
+
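get_attr_by_names lets the patched config helpers fall back through the attribute names used by different model families and return the first positive value. A short usage example (the config class below is a hypothetical stand-in for hf_text_config):

class _FakeHFConfig:
    n_routed_experts = 64          # DeepSeek-style attribute name

num_experts = get_attr_by_names(_FakeHFConfig(),
                                ["moe_num_experts", "n_routed_experts"],
                                0)
# -> 64; if none of the names were present (or all were <= 0), the default 0 would be returned.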
28
+ def _verify_with_expert_parallelism(self) -> None:
29
+ num_expert_names = [
30
+ "moe_num_experts", # Dbrx
31
+ "num_experts", # Jamba
32
+ "n_routed_experts", # DeepSeek
33
+ "num_local_experts", # Mixtral
34
+ "num_routed_experts", # Pangu
35
+ ]
36
+ num_experts = 0
37
+ for name in num_expert_names:
38
+ num_experts = getattr(self.hf_text_config, name, 0)
39
+ if num_experts > 0:
40
+ break
41
+ if num_experts < 1:
42
+ raise ValueError(
43
+ "Number of experts in the model must be greater than 0 "
44
+ "when expert parallelism is enabled.")
45
+
46
+
47
+ @property
48
+ def is_deepseek_mla(self) -> bool:
49
+ kv_lora_dim_names = ['attention_kv_lora_dim', 'kv_lora_rank']
50
+ kv_lora_dim = get_attr_by_names(self.hf_text_config, kv_lora_dim_names, None)
51
+ if not hasattr(self.hf_text_config, "model_type"):
52
+ return False
53
+ elif self.hf_text_config.model_type in \
54
+ ('deepseek_v2', 'deepseek_v3', 'deepseek_mtp', 'pangu_ultra_moe'):
55
+ return kv_lora_dim is not None
56
+ elif self.hf_text_config.model_type == 'eagle':
57
+ # if the model is an EAGLE module, check for the
58
+ # underlying architecture
59
+ return self.hf_text_config.model.model_type in \
60
+ ('deepseek_v2', 'deepseek_v3', 'pangu_ultra_moe') \
61
+ and kv_lora_dim is not None
62
+ return False
63
+
64
+
65
+ def get_head_size(self) -> int:
66
+ if self.is_deepseek_mla:
67
+ qk_rope_dim_names = ['attention_qk_rope_dim', 'qk_rope_head_dim']
68
+ kv_lora_dim_names = ['attention_kv_lora_dim', 'kv_lora_rank']
69
+ qk_rope_dim = get_attr_by_names(self.hf_text_config, qk_rope_dim_names, 0)
70
+ kv_lora_dim = get_attr_by_names(self.hf_text_config, kv_lora_dim_names, 0)
71
+ if self.use_mla:
72
+ return kv_lora_dim + qk_rope_dim
73
+ else:
74
+ qk_dim_names = ['attention_qk_dim', 'qk_nope_head_dim']
75
+ qk_dim = get_attr_by_names(self.hf_text_config, qk_dim_names, 0)
76
+ if qk_rope_dim and qk_dim:
77
+ return qk_rope_dim + qk_dim
78
+ if hasattr(self.hf_text_config,
79
+ "model_type") and (self.hf_text_config.model_type
80
+ == "zamba2"):
81
+ return self.hf_text_config.attention_head_dim
82
+
83
+ if self.is_attention_free:
84
+ return 0
85
+
86
+ # NOTE: Some configs may set head_dim=None in the config
87
+ if getattr(self.hf_text_config, "head_dim", None) is not None:
88
+ return self.hf_text_config.head_dim
89
+
90
+ # FIXME(woosuk): This may not be true for all models.
91
+ return (self.hf_text_config.hidden_size //
92
+ self.hf_text_config.num_attention_heads)
93
+
94
+
95
+ ModelConfig._verify_with_expert_parallelism = _verify_with_expert_parallelism
96
+ ModelConfig.is_deepseek_mla = is_deepseek_mla
97
+ ModelConfig.get_head_size = get_head_size
inference/vllm_ascend/patch/worker/patch_common/patch_parsers.py ADDED
@@ -0,0 +1,26 @@
1
+ #
2
+ # Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
3
+ # This file is a part of the vllm-ascend project.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ #
17
+
18
+
19
+ from vllm.entrypoints.openai import tool_parsers
20
+ from vllm_ascend.entrypoints.openai.tool_parsers import PanguToolParser
21
+ tool_parsers.__all__.append("PanguToolParser")
22
+
23
+
24
+ from vllm import reasoning
25
+ from vllm_ascend.entrypoints.openai.reasoning_parsers import PanguReasoningParser
26
+ reasoning.__all__.append("PanguReasoningParser")
inference/vllm_ascend/patch/worker/patch_common/patch_sampler.py ADDED
1
+ #
2
+ # Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ # This file is a part of the vllm-ascend project.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+ #
18
+
19
+ from typing import Optional
20
+
21
+ import torch
22
+ import torch_npu
23
+ from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler, random_sample
24
+ from vllm.v1.sample.sampler import Sampler, _SAMPLING_EPS
25
+ from vllm.v1.sample.metadata import SamplingMetadata
26
+ from vllm_ascend import envs
27
+
28
+
29
+ def apply_top_k_top_p(
30
+ logits: torch.Tensor,
31
+ k: torch.Tensor,
32
+ p: torch.Tensor,
33
+ ) -> torch.Tensor:
34
+ if p is not None and k is not None:
35
+ # npu_top_k_top_p's parameter order is (logits, p, k), not (logits, k, p)
36
+ return torch_npu.npu_top_k_top_p(logits, p, k)
37
+
38
+ probs = logits.softmax(dim=-1)
39
+ probs_sort, _ = probs.sort(dim=-1, descending=False)
40
+
41
+ if k is not None:
42
+ top_k_count = probs_sort.size(1) - k.to(torch.long) # shape: (batch, )
43
+ top_k_count = top_k_count.unsqueeze(dim=1)
44
+ top_k_cutoff = probs_sort.gather(-1, top_k_count)
45
+
46
+ # Make sure rows with no top-k restriction are left as no-ops.
47
+ no_top_k_mask = (k == logits.shape[1]).unsqueeze(dim=1)
48
+ top_k_cutoff.masked_fill_(no_top_k_mask, -float("inf"))
49
+
50
+ elements_to_discard = probs < top_k_cutoff
51
+ logits.masked_fill_(elements_to_discard, -float("inf"))
52
+
53
+ if p is not None:
54
+ cumprob = torch.cumsum(probs_sort, dim=-1)
55
+ top_p_mask = cumprob <= 1 - p.unsqueeze(dim=1)
56
+ top_p_mask[:, -1] = False # at least one
57
+
58
+ top_p_count = top_p_mask.sum(dim=-1).unsqueeze(1)
59
+ top_p_cutoff = probs_sort.gather(-1, top_p_count)
60
+ elements_to_discard = probs < top_p_cutoff
61
+ logits.masked_fill_(elements_to_discard, -float("inf"))
62
+
63
+ return logits
64
+
65
+
66
+ def topk_topp_forward_native(
67
+ self,
68
+ logits: torch.Tensor,
69
+ generators: dict[int, torch.Generator],
70
+ k: Optional[torch.Tensor],
71
+ p: Optional[torch.Tensor],
72
+ ) -> torch.Tensor:
73
+ """
74
+ PyTorch-native implementation of top-k and top-p sampling.
75
+
76
+ The logits tensor may be updated in-place.
77
+ """
78
+ logits = apply_top_k_top_p(logits, k, p)
79
+ probs = logits.softmax(dim=-1, dtype=torch.float32)
80
+ return random_sample(probs, generators)
81
+
82
+
83
+ def apply_top_n_sigma(
84
+ logits: torch.Tensor,
85
+ sampling_metadata: SamplingMetadata,
86
+ ):
87
+ if sampling_metadata.no_top_n_sigma:
88
+ return logits
89
+
90
+ top_n_sigma = sampling_metadata.top_n_sigma[:, None]
91
+ top_n_sigma_mask = (top_n_sigma != -1)
92
+ filter_value = -3.4028e+38
93
+ max_vals, _ = logits.max(dim=-1, keepdim=True)
94
+ std_vals = logits.std(dim=-1, keepdim=True)
95
+ threshold = max_vals - top_n_sigma * std_vals
96
+ threshold[~top_n_sigma_mask] = filter_value
97
+ mask = (logits < threshold)
98
+ logits = torch.where(mask, filter_value, logits)
99
+ return logits
100
+
101
+
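A tiny CPU sketch of the top-n-sigma rule implemented above, with made-up numbers: logits more than n standard deviations below the row maximum are filtered out.

import torch

# One row of logits; keep only values within n_sigma std devs of the max.
logits = torch.tensor([[5.0, 4.5, 2.0, -1.0, -3.0]])
n_sigma = 1.0
threshold = (logits.max(dim=-1, keepdim=True).values
             - n_sigma * logits.std(dim=-1, keepdim=True))
print(logits.masked_fill(logits < threshold, float("-inf")))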
102
+ def sample(
103
+ self,
104
+ logits: torch.Tensor,
105
+ sampling_metadata: SamplingMetadata,
106
+ ) -> torch.Tensor:
107
+ """Sample logits based on sampling metadata.
108
+
109
+ The various logits processing functions called in this method
110
+ may update the logits tensor in-place.
111
+ """
112
+
113
+ assert not (sampling_metadata.all_greedy
114
+ and sampling_metadata.all_random)
115
+ if sampling_metadata.all_random:
116
+ greedy_sampled = None
117
+ else:
118
+ greedy_sampled = self.greedy_sample(logits)
119
+ if sampling_metadata.all_greedy:
120
+ return greedy_sampled
121
+
122
+ assert sampling_metadata.temperature is not None
123
+
124
+ # Apply temperature.
125
+ logits = self.apply_temperature(logits, sampling_metadata.temperature)
126
+
127
+ # Apply logits processors that only apply to random sampling
128
+ # (argmax invariant)
129
+ for processor in sampling_metadata.logitsprocs.argmax_invariant:
130
+ logits = processor.apply(logits)
131
+
132
+ # Apply top_n_sigma
133
+ logits = apply_top_n_sigma(logits, sampling_metadata)
134
+
135
+ # Apply top_k and/or top_p.
136
+ random_sampled = self.topk_topp_sampler(
137
+ logits,
138
+ sampling_metadata.generators,
139
+ sampling_metadata.top_k,
140
+ sampling_metadata.top_p,
141
+ )
142
+
143
+ if greedy_sampled is None:
144
+ return random_sampled
145
+
146
+ sampled = torch.where(
147
+ sampling_metadata.temperature < _SAMPLING_EPS,
148
+ greedy_sampled,
149
+ random_sampled,
150
+ out=greedy_sampled, # Reuse tensor
151
+ )
152
+ return sampled
153
+
154
+
155
+ if envs.VLLM_ASCEND_ENABLE_TOPK_TOPP_OPTIMIZATION:
156
+ TopKTopPSampler.forward_native = topk_topp_forward_native
157
+
158
+ if envs.VLLM_ASCEND_ENABLE_TOP_N_SIGMA:
159
+ Sampler.sample = sample
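To make the fallback path of `apply_top_k_top_p` easier to follow, here is a small CPU-only sketch that mirrors the same masking logic in plain PyTorch; the helper name and sample numbers are made up, and it does not call the `torch_npu` kernel:

import torch

def top_k_top_p_reference(logits, k, p):
    # Ascending sort of probabilities, mirroring the fallback above.
    probs = logits.softmax(dim=-1)
    probs_sort, _ = probs.sort(dim=-1, descending=False)

    # Top-k: keep the k largest probabilities per row.
    top_k_count = (probs_sort.size(1) - k.to(torch.long)).unsqueeze(1)
    top_k_cutoff = probs_sort.gather(-1, top_k_count)
    no_top_k_mask = (k == logits.shape[1]).unsqueeze(1)
    top_k_cutoff.masked_fill_(no_top_k_mask, -float("inf"))
    logits = logits.masked_fill(probs < top_k_cutoff, -float("inf"))

    # Top-p: drop the low-probability tail whose cumulative mass is <= 1 - p.
    cumprob = torch.cumsum(probs_sort, dim=-1)
    top_p_mask = cumprob <= 1 - p.unsqueeze(1)
    top_p_mask[:, -1] = False  # always keep at least one token
    top_p_count = top_p_mask.sum(dim=-1).unsqueeze(1)
    top_p_cutoff = probs_sort.gather(-1, top_p_count)
    return logits.masked_fill(probs < top_p_cutoff, -float("inf"))

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])
k = torch.tensor([2])      # keep the 2 highest-probability tokens
p = torch.tensor([0.9])    # then apply nucleus filtering
print(top_k_top_p_reference(logits, k, p))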
inference/vllm_ascend/quantization/w8a8.py ADDED
@@ -0,0 +1,757 @@
1
+ #
2
+ # Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
3
+ # This file is a part of the vllm-ascend project.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ #
17
+
18
+ from typing import Any, Callable, Dict, Optional
19
+
20
+ import torch
21
+ import torch_npu
22
+ from vllm.attention.backends.abstract import AttentionType
23
+
24
+ from vllm_ascend.attention.attention_v1 import AscendAttentionState
25
+ from vllm_ascend.distributed.parallel_state import get_ep_group
26
+ from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p
27
+
28
+
29
+ def quant_per_tensor(in_tensor: torch.Tensor,
30
+ input_scale: torch.Tensor,
31
+ input_offset: torch.Tensor,
32
+ function=False):
33
+ return torch_npu.npu_quantize(in_tensor, input_scale, input_offset,
34
+ torch.qint8, -1, function)
35
+
36
+
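Conceptually, `quant_per_tensor` maps activations into int8 with a single scale and offset; the plain-PyTorch sketch below illustrates that idea only. The exact convention of `torch_npu.npu_quantize` (multiply vs. divide by the scale, offset placement) is not asserted here, and the helper name and sample values are made up.

import torch

def quant_per_tensor_reference(x, scale, offset):
    # Affine per-tensor quantization: q = clamp(round(x / scale + offset), int8 range).
    q = torch.round(x / scale + offset)
    return q.clamp(-128, 127).to(torch.int8)

x = torch.randn(4, 8)
scale = x.abs().max() / 127.0   # simple symmetric scale, assumed for the sketch
offset = torch.tensor(0.0)      # symmetric case: zero offset
print(quant_per_tensor_reference(x, scale, offset).dtype)  # torch.int8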
37
+ class AscendW8A8LinearMethod:
38
+ """Linear method for Ascend W8A8.
39
+
40
+ Args:
41
+ w_sym: whether the linear weight is symmetrically quantized.
42
+ """
43
+
44
+ def __init__(self) -> None:
45
+ # aclnn quant matmul requires matrix B to be transposed; enabled by default.
46
+ self.transpose_weight = not is_310p()
47
+
48
+ @staticmethod
49
+ def get_weight(
50
+ input_size: int,
51
+ output_size: int,
52
+ params_dtype: torch.dtype = torch.bfloat16,
53
+ ) -> Dict[str, Any]:
54
+ params_dict = {
55
+ "weight": torch.empty(output_size, input_size, dtype=torch.int8)
56
+ }
57
+ return params_dict
58
+
59
+ @staticmethod
60
+ def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]:
61
+ params_dict = {}
62
+ params_dict["input_scale"] = torch.empty(1, dtype=params_dtype)
63
+ params_dict["input_offset"] = torch.empty(1, dtype=torch.int8)
64
+ return params_dict
65
+
66
+ @staticmethod
67
+ def get_perchannel_param(
68
+ output_size: int,
69
+ params_dtype: torch.dtype,
70
+ ) -> Dict[str, Any]:
71
+ params_dict = {}
72
+ params_dict["quant_bias"] = torch.empty(output_size, dtype=torch.int32)
73
+ if params_dtype == torch.bfloat16:
74
+ params_dict["deq_scale"] = torch.empty(output_size,
75
+ dtype=torch.float32)
76
+ elif params_dtype == torch.float16:
77
+ params_dict["deq_scale"] = torch.empty(output_size,
78
+ dtype=torch.int64)
79
+ params_dict["weight_scale"] = torch.empty(output_size,
80
+ 1,
81
+ dtype=params_dtype)
82
+ params_dict["weight_offset"] = torch.empty(output_size,
83
+ 1,
84
+ dtype=params_dtype)
85
+ return params_dict
86
+
87
+ @staticmethod
88
+ def apply(
89
+ layer: torch.nn.Module,
90
+ x: torch.Tensor,
91
+ bias: Optional[torch.Tensor] = None,
92
+ tp_rank: Optional[int] = 0,
93
+ ) -> torch.Tensor:
94
+ original_dtype = x.dtype
95
+ if original_dtype != torch.int8:
96
+ x = quant_per_tensor(x, layer.aclnn_input_scale,
97
+ layer.aclnn_input_offset)
98
+ quant_bias = layer.quant_bias if tp_rank == 0 else None
99
+ if is_310p():
100
+ # On the 300I Duo platform, we need to transpose again if
101
+ # using nz. This transpose can be skipped in torchair.
102
+ output = torch_npu.npu_quant_matmul(
103
+ x,
104
+ layer.weight.data.transpose(1, 0),
105
+ layer.deq_scale,
106
+ bias=quant_bias,
107
+ output_dtype=original_dtype,
108
+ )
109
+ else:
110
+ output = torch_npu.npu_quant_matmul(
111
+ x,
112
+ layer.weight,
113
+ layer.deq_scale,
114
+ bias=quant_bias,
115
+ output_dtype=original_dtype,
116
+ )
117
+ return output
118
+
119
+ def process_weights_after_loading(self, layer):
120
+ expanding_factor = layer.weight.data.shape[1]
121
+ layer.aclnn_input_scale = 1 / torch.nn.Parameter(
122
+ layer.input_scale.data.repeat(expanding_factor),
123
+ requires_grad=False)
124
+ layer.aclnn_input_offset = torch.nn.Parameter(
125
+ layer.input_offset.data.repeat(expanding_factor),
126
+ requires_grad=False).to(layer.aclnn_input_scale.dtype)
127
+ if self.transpose_weight:
128
+ layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
129
+ layer.weight.data = torch_npu.npu_format_cast(layer.weight.data,
130
+ ACL_FORMAT_FRACTAL_NZ)
131
+ layer.weight_scale.data = torch.flatten(layer.weight_scale.data)
132
+ layer.weight_offset.data = torch.flatten(layer.weight_offset.data)
133
+
134
+
135
+ class AscendW8A8FusedMoEMethod:
136
+ """FusedMoe method for Ascend W8A8.
137
+ """
138
+
139
+ def __init__(self):
140
+ self.transpose_weight = True
141
+
142
+ @staticmethod
143
+ def get_weight(num_experts: int, intermediate_size_per_partition: int,
144
+ hidden_sizes: int,
145
+ params_dtype: torch.dtype) -> Dict[str, Any]:
146
+ param_dict = {}
147
+ param_dict["w13_weight"] = torch.empty(num_experts,
148
+ 2 *
149
+ intermediate_size_per_partition,
150
+ hidden_sizes,
151
+ dtype=torch.int8,
152
+ requires_grad=False)
153
+ param_dict["w2_weight"] = torch.empty(num_experts,
154
+ hidden_sizes,
155
+ intermediate_size_per_partition,
156
+ dtype=torch.int8,
157
+ requires_grad=False)
158
+ return param_dict
159
+
160
+ @staticmethod
161
+ def get_dynamic_quant_param(num_experts: int,
162
+ intermediate_size_per_partition: int,
163
+ hidden_sizes: int,
164
+ params_dtype: torch.dtype) -> Dict[str, Any]:
165
+ param_dict = {}
166
+ param_dict["w13_weight_scale"] = torch.empty(
167
+ num_experts,
168
+ 2 * intermediate_size_per_partition,
169
+ 1,
170
+ dtype=torch.float32)
171
+ param_dict["w13_weight_offset"] = torch.empty(
172
+ num_experts,
173
+ 2 * intermediate_size_per_partition,
174
+ 1,
175
+ dtype=torch.float16)
176
+ param_dict["w2_weight_scale"] = torch.empty(num_experts,
177
+ hidden_sizes,
178
+ 1,
179
+ dtype=torch.float32)
180
+ param_dict["w2_weight_offset"] = torch.empty(num_experts,
181
+ hidden_sizes,
182
+ 1,
183
+ dtype=torch.float16)
184
+ param_dict["w2_deq_scale"] = torch.empty(num_experts,
185
+ hidden_sizes,
186
+ dtype=torch.float32)
187
+ param_dict["w13_deq_scale"] = torch.empty(
188
+ num_experts,
189
+ 2 * intermediate_size_per_partition,
190
+ dtype=torch.float32)
191
+ param_dict["w2_input_scale"] = torch.empty(num_experts,
192
+ 1,
193
+ dtype=torch.float32)
194
+ param_dict["w13_input_scale"] = torch.empty(num_experts,
195
+ 1,
196
+ dtype=torch.float32)
197
+ param_dict["w2_input_offset"] = torch.empty(num_experts,
198
+ 1,
199
+ dtype=torch.int8)
200
+ param_dict["w13_input_offset"] = torch.empty(num_experts,
201
+ 1,
202
+ dtype=torch.int8)
203
+ param_dict["quant_bias"] = torch.empty(num_experts,
204
+ hidden_sizes,
205
+ dtype=torch.int32)
206
+
207
+ return param_dict
208
+
209
+ def apply(
210
+ self,
211
+ layer: torch.nn.Module,
212
+ x: torch.Tensor,
213
+ router_logits: torch.Tensor,
214
+ top_k: int,
215
+ renormalize: bool,
216
+ use_grouped_topk: bool = False,
217
+ global_num_experts: int = -1,
218
+ expert_map: Optional[torch.Tensor] = None,
219
+ topk_group: Optional[int] = None,
220
+ num_expert_group: Optional[int] = None,
221
+ custom_routing_function: Optional[Callable] = None,
222
+ scoring_func: str = "softmax",
223
+ e_score_correction_bias: Optional[torch.Tensor] = None,
224
+ is_prefill: bool = True,
225
+ enable_force_load_balance: bool = False,
226
+ log2phy: torch.Tensor = None,
227
+ global_redundant_expert_num: int = 0,
228
+ shared_experts: Optional[Any] = None,
229
+ **kwargs,
230
+ ) -> torch.Tensor:
231
+ assert router_logits.shape[
232
+ 1] == global_num_experts, "Number of global experts mismatch"
233
+
234
+ topk_weights, topk_ids = select_experts(
235
+ hidden_states=x,
236
+ router_logits=router_logits,
237
+ top_k=top_k,
238
+ use_grouped_topk=use_grouped_topk,
239
+ renormalize=renormalize,
240
+ topk_group=topk_group,
241
+ num_expert_group=num_expert_group,
242
+ custom_routing_function=custom_routing_function,
243
+ scoring_func=scoring_func,
244
+ e_score_correction_bias=e_score_correction_bias,
245
+ global_num_experts=global_num_experts,
246
+ )
247
+
248
+ if is_310p():
249
+ return fused_experts_310p(hidden_states=x,
250
+ w1=layer.w13_weight,
251
+ w1_scale=layer.w13_weight_scale,
252
+ w1_input_scale=layer.w13_input_scale,
253
+ w2=layer.w2_weight,
254
+ w2_scale=layer.w2_weight_scale,
255
+ w2_input_scale=layer.w2_input_scale,
256
+ topk_weights=topk_weights,
257
+ topk_ids=topk_ids,
258
+ top_k=top_k,
259
+ global_num_experts=global_num_experts,
260
+ expert_map=expert_map)
261
+ return fused_experts(hidden_states=x,
262
+ w1=layer.w13_weight,
263
+ w1_scale=layer.w13_weight_scale,
264
+ w1_input_scale=layer.w13_input_scale,
265
+ w1_input_offset=layer.w13_input_offset,
266
+ w2=layer.w2_weight,
267
+ w2_scale=layer.w2_weight_scale,
268
+ w2_input_scale=layer.w2_input_scale,
269
+ w2_input_offset=layer.w2_input_offset,
270
+ topk_weights=topk_weights,
271
+ topk_ids=topk_ids,
272
+ top_k=top_k,
273
+ global_num_experts=global_num_experts,
274
+ expert_map=expert_map)
275
+
276
+ def process_weights_after_loading(self, layer):
277
+ if not is_310p():
278
+ layer.w13_weight.data = layer.w13_weight.data.transpose(
279
+ 1, 2).contiguous()
280
+ layer.w2_weight.data = layer.w2_weight.data.transpose(
281
+ 1, 2).contiguous()
282
+ layer.w13_weight_scale.data = layer.w13_weight_scale.data.view(
283
+ layer.w13_weight_scale.data.shape[0], -1)
284
+
285
+ layer.w13_weight_offset.data = layer.w13_weight_offset.data.view(
286
+ layer.w13_weight_offset.data.shape[0], -1)
287
+ layer.w2_weight_scale.data = layer.w2_weight_scale.data.view(
288
+ layer.w2_weight_scale.data.shape[0], -1)
289
+ layer.w2_weight_offset.data = layer.w2_weight_offset.data.view(
290
+ layer.w2_weight_offset.data.shape[0], -1)
291
+ expanding_factor_w13 = layer.w13_weight.data.shape[1]
292
+ expanding_factor_w2 = layer.w2_weight.data.shape[1]
293
+
294
+ if is_310p():
295
+ layer.w13_input_scale.data = torch.nn.Parameter(
296
+ layer.w13_input_scale.data.max())
297
+ layer.w2_input_scale.data = torch.nn.Parameter(
298
+ layer.w2_input_scale.data.max())
299
+ else:
300
+ layer.w13_input_scale.data = torch.nn.Parameter(
301
+ layer.w13_input_scale.data.repeat(1,
302
+ expanding_factor_w13)[0:1])
303
+ layer.w2_input_scale.data = torch.nn.Parameter(
304
+ layer.w2_input_scale.data.repeat(1, expanding_factor_w2)[0:1])
305
+
306
+ layer.w13_input_offset.data = torch.nn.Parameter(
307
+ layer.w13_input_scale.data.repeat(1, expanding_factor_w13)[0:1])
308
+ layer.w2_input_offset.data = torch.nn.Parameter(
309
+ layer.w2_input_scale.data.repeat(1, expanding_factor_w2)[0:1])
310
+
311
+ # Convert weights to ACL_FORMAT_FRACTAL_NZ.
312
+ # npu_quant_grouped_matmul_dequant in eager mode does not accept
313
+ # ACL_FORMAT_FRACTAL_NZ.
314
+ if not is_310p():
315
+ layer.w13_weight.data = torch_npu.npu_format_cast(
316
+ layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ).contiguous()
317
+ layer.w2_weight.data = torch_npu.npu_format_cast(
318
+ layer.w2_weight.data, ACL_FORMAT_FRACTAL_NZ).contiguous()
319
+
320
+
321
+ class AscendC8KVCacheMethod:
322
+
323
+ def __init__(self) -> None:
324
+ self.antiquant_scale_comb = None
325
+
326
+ @staticmethod
327
+ def create_weights(layer) -> None:
328
+ param_dict = {} # num_kv_heads * head_size
329
+ param_dict["key_antiquant_scale"] = torch.empty(layer.num_kv_heads *
330
+ layer.head_size,
331
+ dtype=torch.float16,
332
+ requires_grad=False)
333
+ param_dict["value_antiquant_scale"] = torch.empty(layer.num_kv_heads *
334
+ layer.head_size,
335
+ dtype=torch.float16,
336
+ requires_grad=False)
337
+ for weight_name, weight_param in param_dict.items():
338
+ param = torch.nn.Parameter(weight_param, requires_grad=False)
339
+ layer.register_parameter(weight_name, param)
340
+
341
+ def process_weights_after_loading(self, layer):
342
+ self.antiquant_scale_comb = torch.cat(
343
+ (layer.key_antiquant_scale.data.unsqueeze(0),
344
+ layer.value_antiquant_scale.data.unsqueeze(0)),
345
+ dim=0).to(torch.float16).contiguous()
346
+
347
+ def apply(self, layer, query, key, value, kv_cache, attn_metadata,
348
+ attn_type, scale, output) -> torch.Tensor:
349
+ num_tokens = query.shape[0]
350
+ if attn_metadata is None:
351
+ return output.view(num_tokens, layer.num_heads * layer.head_size)
352
+ assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
353
+ if attn_type != AttentionType.DECODER:
354
+ raise NotImplementedError("Encoder self-attention and "
355
+ "encoder/decoder cross-attention "
356
+ "are not implemented for "
357
+ "PallasAttentionBackendImpl")
358
+
359
+ # C8
360
+ quant_key = quant_per_tensor(
361
+ key.view(-1, layer.num_kv_heads * layer.head_size),
362
+ layer.key_antiquant_scale.data.view(-1), None, True)
363
+ quant_value = quant_per_tensor(
364
+ value.view(-1, layer.num_kv_heads * layer.head_size),
365
+ layer.value_antiquant_scale.data.view(-1), None, True)
366
+
367
+ # View q k v to BSH.
368
+ query = query.view(-1, layer.num_heads, layer.head_size)
369
+ key = key.view(-1, layer.num_kv_heads, layer.head_size)
370
+ value = value.view(-1, layer.num_kv_heads, layer.head_size)
371
+ # TODO: Remove this contiguous in the future.
372
+ value = value.contiguous()
373
+
374
+ if kv_cache[0].numel() > 0:
375
+ # if key_cache is None:
376
+ key_cache, value_cache = kv_cache[0], kv_cache[1]
377
+ slots = attn_metadata.slot_mapping
378
+
379
+ block_size = key_cache.shape[1]
380
+ slots_indices = slots.reshape(-1, 1)
381
+ block_indices = slots_indices // block_size
382
+ slots_indices = slots_indices % block_size
383
+ indices = torch.cat((block_indices, slots_indices), dim=1)
384
+
385
+ # C8
386
+ torch_npu.npu_scatter_nd_update_(key_cache, indices, quant_key)
387
+ torch_npu.npu_scatter_nd_update_(value_cache, indices, quant_value)
388
+
389
+ # V0-Style scheduler situation.
390
+ if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
391
+ assert attn_metadata is not None
392
+ assert attn_metadata.attn_mask is not None
393
+ mask = attn_metadata.attn_mask
394
+ torch_npu._npu_flash_attention(query=query,
395
+ key=key,
396
+ value=value,
397
+ mask=mask,
398
+ seq_len=attn_metadata.seq_lens,
399
+ scale_value=scale,
400
+ num_heads=layer.num_heads,
401
+ num_kv_heads=layer.num_kv_heads,
402
+ out=output.reshape(query.shape))
403
+
404
+ elif attn_metadata.attn_state == AscendAttentionState.PrefillCacheHit:
405
+ raise NotImplementedError("kv cache int8 are not "
406
+ "implemented for "
407
+ "PrefillCacheHit")
408
+ elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
409
+ if hasattr(attn_metadata, "decode"):
410
+ # torch_air
411
+ decode_meta = attn_metadata.decode
412
+ seq_lens = decode_meta.seq_lens_list
413
+ else:
414
+ seq_lens = attn_metadata.seq_lens
415
+ block_size = key_cache.shape[1]
416
+ query = query.view(num_tokens, 1, layer.num_heads *
417
+ layer.head_size).contiguous()
418
+
419
+ # [num_blocks, block_size, N, D] --> [num_blocks, N, block_size, D]
420
+ key = key_cache
421
+ value = value_cache
422
+
423
+ output = torch_npu.npu_incre_flash_attention(
424
+ query,
425
+ key,
426
+ value,
427
+ num_key_value_heads=layer.num_kv_heads,
428
+ num_heads=layer.num_heads,
429
+ actual_seq_lengths=seq_lens,
430
+ scale_value=scale,
431
+ input_layout='BSH',
432
+ block_size=block_size,
433
+ block_table=attn_metadata.block_tables,
434
+ antiquant_scale=self.antiquant_scale_comb,
435
+ )
436
+
437
+ # Normal V1 situation.
438
+ else:
439
+ raise NotImplementedError("kv cache int8 are not "
440
+ "implemented for "
441
+ "other case")
442
+ return output
443
+
444
+
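The scatter into the paged KV cache above splits each flat slot id into a (block index, in-block offset) pair; a small standalone sketch of that index arithmetic with made-up values:

import torch

block_size = 16
slots = torch.tensor([0, 17, 35])        # flat slot ids from the slot mapping
block_indices = slots // block_size      # tensor([0, 1, 2])
offsets = slots % block_size             # tensor([0, 1, 3])
indices = torch.stack((block_indices, offsets), dim=1)
print(indices)  # each row is the (block, in-block offset) destination for one token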
445
+ def fused_experts_310p(
446
+ hidden_states: torch.Tensor,
447
+ w1: torch.Tensor,
448
+ w1_scale: torch.Tensor,
449
+ w1_input_scale: torch.Tensor,
450
+ w2: torch.Tensor,
451
+ w2_scale: torch.Tensor,
452
+ w2_input_scale: torch.Tensor,
453
+ topk_weights: torch.Tensor,
454
+ topk_ids: torch.Tensor,
455
+ top_k: int,
456
+ global_num_experts: int,
457
+ expert_map: torch.Tensor = None,
458
+ ) -> torch.Tensor:
459
+ ep_size = get_ep_group().world_size
460
+ local_num_experts = global_num_experts // ep_size
461
+ local_num_group = top_k // ep_size
462
+
463
+ bsz, _ = hidden_states.shape
464
+ flatten_topk_ids = topk_ids.view(-1)
465
+ sorted_topk_ids = torch.argsort(flatten_topk_ids.float())
466
+ sorted_topk_ids = sorted_topk_ids.to(torch.int32)
467
+ sorted_hidden_states = hidden_states.index_select(
468
+ 0, sorted_topk_ids // local_num_group)
469
+
470
+ experts_id = torch.arange(0,
471
+ local_num_experts,
472
+ dtype=topk_ids.dtype,
473
+ device=topk_ids.device)
474
+ num_tokens_per_expert = (flatten_topk_ids.unsqueeze(-1) == experts_id).to(
475
+ torch.float32).sum(0)
476
+ topk_scales = topk_weights.view(-1).index_select(
477
+ 0, sorted_topk_ids).unsqueeze(-1)
478
+ group_list = num_tokens_per_expert.cumsum(dim=0).to(torch.int64)
479
+
480
+ gate_up_out = torch_npu.npu_quant_grouped_matmul_dequant(
481
+ x=sorted_hidden_states,
482
+ quantized_weight=w1,
483
+ weight_scale=w1_scale,
484
+ group_list=group_list,
485
+ x_scale=w1_input_scale,
486
+ quant_mode="pertensor")
487
+
488
+ gate_up_out = torch_npu.npu_swiglu(gate_up_out.to(torch.float32)).to(
489
+ torch.float16)
490
+ gate_up_out *= topk_scales
491
+
492
+ down_out = torch_npu.npu_quant_grouped_matmul_dequant(
493
+ x=gate_up_out,
494
+ quantized_weight=w2,
495
+ weight_scale=w2_scale,
496
+ group_list=group_list,
497
+ x_scale=w2_input_scale,
498
+ quant_mode="pertensor")
499
+
500
+ unsorted_topk_ids = torch.argsort(sorted_topk_ids.float()).to(torch.int32)
501
+ unsorted_hidden_states = down_out.index_select(0, unsorted_topk_ids)
502
+ final_hidden_states = unsorted_hidden_states.reshape(
503
+ bsz, top_k // ep_size, -1).sum(1)
504
+
505
+ return final_hidden_states
506
+
507
+
508
+ def fused_experts(
509
+ hidden_states: torch.Tensor,
510
+ w1: torch.Tensor,
511
+ w1_scale: torch.Tensor,
512
+ w1_input_scale: torch.Tensor,
513
+ w1_input_offset: torch.Tensor,
514
+ w2: torch.Tensor,
515
+ w2_scale: torch.Tensor,
516
+ w2_input_scale: torch.Tensor,
517
+ w2_input_offset: torch.Tensor,
518
+ topk_weights: torch.Tensor,
519
+ topk_ids: torch.Tensor,
520
+ top_k: int,
521
+ global_num_experts: int,
522
+ expert_map: torch.Tensor = None,
523
+ ) -> torch.Tensor:
524
+ """
525
+ Fused experts with top-k routing.
526
+
527
+ Args:
528
+ hidden_states: Hidden states of shape (num_tokens, hidden_size).
529
+ w1: Expert weights1 of shape (num_experts, intermediate_size * 2, hidden_size).
530
+ w2: Expert weights2 of shape (num_experts, hidden_size, intermediate_size).
531
+ topk_weights: Routing weights of shape (num_tokens, top_k).
532
+ topk_ids: Selected expert IDs of shape (num_tokens, top_k).
533
+ top_k: Number of experts to select.
534
+ expert_map: Expert mapping of shape (num_experts,).
535
+
536
+ Returns:
537
+ hidden_states: Hidden states after routing.
538
+ """
539
+ """
540
+ # Check constraints.
541
+ assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
542
+ assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
543
+ assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
544
+ assert w1.is_contiguous(), "Expert weights1 must be contiguous"
545
+ assert w2.is_contiguous(), "Expert weights2 must be contiguous"
546
+ """
547
+
548
+ original_dtype = hidden_states.dtype
549
+ ep_size = get_ep_group().world_size
550
+ local_num_experts = global_num_experts // ep_size
551
+ w1_input_scale, _ = w1_input_scale.max(0)
552
+ quant_sorted_hidden_states = quant_per_tensor(
553
+ hidden_states,
554
+ w1_input_scale,
555
+ None,
556
+ True,
557
+ )
558
+ if expert_map is not None:
559
+ expanded_x, expanded_row_idx, expert_token_count, expanded_scale = torch_npu.npu_moe_init_routing_v2(
560
+ quant_sorted_hidden_states,
561
+ topk_ids,
562
+ scale=None,
563
+ active_num=topk_ids.numel(),
564
+ expert_capacity=-1,
565
+ expert_num=local_num_experts,
566
+ drop_pad_mode=0,
567
+ expert_tokens_num_type=1,
568
+ expert_tokens_num_flag=True,
569
+ quant_mode=-1,
570
+ active_expert_range=[0, local_num_experts],
571
+ row_idx_type=0,
572
+ )
573
+
574
+ else:
575
+ raise NotImplementedError(
576
+ "The quantified version of MOE class models "
577
+ "currently does not support tensor parallelism")
578
+ if expanded_x.dtype != w1.dtype:
579
+ w1_input_scale, _ = w1_input_scale.max(0)
580
+ quant_sorted_hidden_states = quant_per_tensor(
581
+ expanded_x,
582
+ w1_input_scale,
583
+ None,
584
+ True,
585
+ )
586
+ else:
587
+ quant_sorted_hidden_states = expanded_x
588
+ gate_up_out = torch_npu.npu_grouped_matmul(
589
+ x=[quant_sorted_hidden_states],
590
+ weight=[w1],
591
+ scale=[w1_scale * w1_input_scale[0]],
592
+ split_item=2,
593
+ group_list_type=1,
594
+ group_type=0,
595
+ group_list=expert_token_count,
596
+ output_dtype=original_dtype,
597
+ )[0]
598
+ gate_up_out = torch_npu.npu_swiglu(gate_up_out)
599
+
600
+ if gate_up_out.dtype != w2.dtype:
601
+ w2_input_scale, _ = w2_input_scale.max(0)
602
+ quant_gate_up_out = quant_per_tensor(
603
+ gate_up_out,
604
+ w2_input_scale,
605
+ None,
606
+ True,
607
+ )
608
+ else:
609
+ quant_gate_up_out = gate_up_out
610
+
611
+ down_out = torch_npu.npu_grouped_matmul(
612
+ x=[quant_gate_up_out],
613
+ weight=[w2],
614
+ scale=[w2_scale * w2_input_scale[0]],
615
+ split_item=2,
616
+ group_list_type=1,
617
+ group_type=0,
618
+ group_list=expert_token_count,
619
+ output_dtype=original_dtype,
620
+ )[0]
621
+
622
+ if expert_map is not None:
623
+ final_hidden_states = torch_npu.npu_moe_finalize_routing(
624
+ down_out,
625
+ skip1=None,
626
+ skip2=None,
627
+ bias=None,
628
+ scales=topk_weights.to(down_out.dtype),
629
+ expanded_src_to_dst_row=expanded_row_idx,
630
+ export_for_source_row=topk_ids,
631
+ drop_pad_mode=2,
632
+ )
633
+ else:
634
+ raise NotImplementedError(
635
+ "The quantified version of MOE class models "
636
+ "currently does not support tensor parallelism")
637
+
638
+ return final_hidden_states
639
+
640
+
641
+ def select_experts(
642
+ hidden_states: torch.Tensor,
643
+ router_logits: torch.Tensor,
644
+ top_k: int,
645
+ use_grouped_topk: bool,
646
+ renormalize: bool,
647
+ topk_group: Optional[int] = None,
648
+ num_expert_group: Optional[int] = None,
649
+ custom_routing_function: Optional[Callable] = None,
650
+ scoring_func: str = "softmax",
651
+ e_score_correction_bias: Optional[torch.Tensor] = None,
652
+ global_num_experts=-1,
653
+ ) -> tuple[torch.Tensor, torch.Tensor]:
654
+ """
655
+ Select top-k experts based on router logits.
656
+
657
+ Args:
658
+ hidden_states: Hidden states of shape (num_tokens, hidden_size).
659
+ router_logits: Router logits of shape (num_tokens, num_experts).
660
+ top_k: Number of experts to select.
661
+ use_grouped_topk: Whether to group experts before selecting top-k.
662
+ renormalize: Whether to renormalize the routing weights.
663
+ topk_group: Number of expert groups to select from.
664
+ num_expert_group: Number of experts in each group.
665
+ custom_routing_function: Custom routing function.
666
+ scoring_func: Scoring function to use.
667
+ e_score_correction_bias: Correction bias to apply to expert scores.
668
+
669
+ Returns:
670
+ topk_weights: Routing weights of shape (num_tokens, top_k).
671
+ topk_ids: Selected expert IDs of shape (num_tokens, top_k).
672
+
673
+ Raises:
674
+ ValueError: If an unsupported scoring function is provided.
675
+ """
676
+
677
+ if scoring_func == "softmax":
678
+ # NOTE: vLLM use dtype=torch.float here
679
+ topk_weights = router_logits.softmax(dim=-1)
680
+ elif scoring_func == "sigmoid":
681
+ topk_weights = router_logits.sigmoid()
682
+ else:
683
+ raise ValueError(f"Unsupported scoring function: {scoring_func}")
684
+
685
+ if use_grouped_topk:
686
+ assert topk_group is not None
687
+ assert num_expert_group is not None
688
+
689
+ if e_score_correction_bias is not None:
690
+ # Store original scores before applying correction bias. We use biased
691
+ # scores for expert selection but original scores for routing weights
692
+ original_weights = topk_weights
693
+ topk_weights = topk_weights + e_score_correction_bias.unsqueeze(0)
694
+
695
+ # TODO: Change to npu_group_topk when the latest CANN and NNAL is available
696
+ # >>> torch_npu._npu_group_topk(topk_weights, group_num=num_expert_group, k=topk_group)
697
+ topk_weights = native_grouped_topk(topk_weights, num_expert_group,
698
+ topk_group)
699
+ # TODO bfloat16 is not supported in torch.topk with ge graph.
700
+ if e_score_correction_bias is not None:
701
+ topk_ids = torch.topk(topk_weights.to(torch.float32),
702
+ k=top_k,
703
+ dim=-1,
704
+ sorted=False)[1]
705
+ # Use original unbiased scores for the routing weights
706
+ topk_weights = original_weights.gather(1, topk_ids)
707
+ else:
708
+ topk_weights, topk_ids = torch.topk(topk_weights.to(torch.float32),
709
+ k=top_k,
710
+ dim=-1,
711
+ sorted=False)
712
+ elif custom_routing_function is None:
713
+ topk_weights, topk_ids = topk_weights.topk(top_k, dim=-1)
714
+ else:
715
+ topk_weights, topk_ids = custom_routing_function(
716
+ hidden_states=hidden_states,
717
+ gating_output=router_logits,
718
+ topk=top_k,
719
+ renormalize=renormalize,
720
+ global_num_experts=global_num_experts,
721
+ )
722
+ # Required by npu_moe_init_routing
723
+ topk_ids = topk_ids.to(torch.int32)
724
+ return topk_weights, topk_ids
725
+
726
+ # Required by npu_moe_init_routing
727
+ topk_ids = topk_ids.to(torch.int32)
728
+
729
+ if renormalize:
730
+ topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
731
+
732
+ return topk_weights, topk_ids
733
+
734
+
735
+ def native_grouped_topk(
736
+ topk_weights: torch.Tensor,
737
+ num_expert_group: Optional[int],
738
+ topk_group: Optional[int],
739
+ ):
740
+ topk_group = 0 if topk_group is None else topk_group
741
+ num_expert_group = 0 if num_expert_group is None else num_expert_group
742
+
743
+ num_token = topk_weights.shape[0]
744
+ grouped_weights = topk_weights.view(num_token, num_expert_group,
745
+ -1).max(dim=-1).values
746
+ topk_group_indices = torch.topk(grouped_weights.to(torch.float32),
747
+ k=topk_group,
748
+ dim=-1,
749
+ sorted=False)[1]
750
+ topk_group_mask = torch.zeros_like(grouped_weights)
751
+ topk_group_mask.scatter_(1, topk_group_indices, 1)
752
+ topk_weight_mask = (topk_group_mask.unsqueeze(-1).expand(
753
+ num_token, num_expert_group,
754
+ topk_weights.shape[-1] // num_expert_group).reshape(num_token, -1))
755
+ topk_weights = topk_weights.masked_fill(~topk_weight_mask.bool(), 0.0)
756
+
757
+ return topk_weights
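As a small illustration of the grouped top-k masking performed by `native_grouped_topk`, with made-up scores (one token, 8 experts in 4 groups of 2, keeping the best 2 groups):

import torch

scores = torch.tensor([[0.1, 0.2, 0.9, 0.1, 0.3, 0.8, 0.2, 0.1]])
num_expert_group, topk_group = 4, 2

grouped = scores.view(1, num_expert_group, -1).max(dim=-1).values   # [[0.2, 0.9, 0.8, 0.2]]
keep_groups = torch.topk(grouped, k=topk_group, dim=-1, sorted=False)[1]
group_mask = torch.zeros_like(grouped).scatter_(1, keep_groups, 1)
expert_mask = group_mask.unsqueeze(-1).expand(1, num_expert_group, 2).reshape(1, -1)
masked_scores = scores.masked_fill(~expert_mask.bool(), 0.0)
print(masked_scores)  # experts outside groups 1 and 2 are zeroed out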
inference/vllm_ascend/quantization/w8a8_dynamic.py ADDED
@@ -0,0 +1,831 @@
1
+ #
2
+ # Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
3
+ # This file is a part of the vllm-ascend project.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ #
17
+
18
+ from typing import Any, Callable, Dict, Optional, Tuple, Union
19
+
20
+ import torch
21
+ import torch.distributed as dist
22
+ import torch_npu
23
+ from vllm.distributed import GroupCoordinator
24
+
25
+ import vllm_ascend.envs as envs
26
+ from vllm_ascend.ascend_config import get_ascend_config
27
+ from vllm_ascend.distributed.parallel_state import get_ep_group
28
+ from vllm_ascend.ops.fused_moe import select_experts
29
+ from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, FusedMoEState,
30
+ dispose_tensor, get_fused_moe_state,
31
+ npu_stream_switch, npu_wait_tensor)
32
+
33
+
34
+ def apply_mlp(hidden_states: torch.Tensor,
35
+ w1: torch.Tensor,
36
+ w1_scale: torch.Tensor,
37
+ w2: torch.Tensor,
38
+ w2_scale: torch.Tensor,
39
+ group_list: torch.Tensor,
40
+ dynamic_scale: torch.Tensor = None,
41
+ group_list_type: int = 1) -> torch.Tensor:
42
+ """
43
+ apply MLP: gate_up_proj -> swiglu -> down_proj
44
+
45
+ Args:
46
+ hidden_states: input hidden states with shape (num_tokens, hidden_size).
47
+ w1: expert weights1 with shape
48
+ (num_experts, hidden_size, intermediate_size * 2)
49
+ w1_scale: weights1 scale with shape (num_experts, intermediate_size * 2)
50
+ w2: expert weights2 with shape
51
+ (num_experts, intermediate_size, hidden_size)
52
+ w2_scale: weights2 scale with shape (num_experts, hidden_size)
53
+ group_list: number of tokens for each expert, following cumsum mode, and
54
+ with shape (num_experts).
55
+ transpose_weight:
56
+ w1: (num_experts, intermediate_size * 2, hidden_size) ->
57
+ (num_experts, hidden_size, intermediate_size * 2)
58
+ w2: (num_experts, hidden_size, intermediate_size) ->
59
+ (num_experts, intermediate_size, hidden_size)
60
+
61
+ Returns:
62
+ hidden_states: output hidden states after MLP.
63
+ """
64
+
65
+ if dynamic_scale is None:
66
+ unquantized_hidden_states = hidden_states
67
+ hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(
68
+ hidden_states)
69
+ # Dispose the original unquantized hidden states
70
+ # to save npu memory because they're no longer used.
71
+ dispose_tensor(unquantized_hidden_states)
72
+ else:
73
+ pertoken_scale = dynamic_scale
74
+
75
+ # gmm1: gate_up_proj
76
+ hidden_states = torch_npu.npu_grouped_matmul(
77
+ x=[hidden_states],
78
+ weight=[w1],
79
+ scale=[w1_scale],
80
+ per_token_scale=[pertoken_scale],
81
+ split_item=2,
82
+ group_list_type=group_list_type,
83
+ group_type=0,
84
+ group_list=group_list,
85
+ output_dtype=w2_scale.dtype)[0]
86
+
87
+ # act_fn: swiglu
88
+ hidden_states = torch_npu.npu_swiglu(hidden_states)
89
+ hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(
90
+ hidden_states)
91
+
92
+ # gmm2: down_proj
93
+ hidden_states = torch_npu.npu_grouped_matmul(
94
+ x=[hidden_states],
95
+ weight=[w2],
96
+ scale=[w2_scale],
97
+ per_token_scale=[swiglu_out_scale],
98
+ split_item=2,
99
+ group_list_type=group_list_type,
100
+ group_type=0,
101
+ group_list=group_list,
102
+ output_dtype=w2_scale.dtype)[0]
103
+
104
+ return hidden_states
105
+
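The grouped matmuls above need to know how many expert-sorted rows belong to each expert, either as per-expert counts or as their running sum; which `group_list_type` value selects which form is determined by the NPU op and is not asserted here. A tiny sketch of the two representations:

import torch

# Tokens routed to 3 local experts: expert 0 gets 2 rows, expert 1 gets 3, expert 2 gets 4.
tokens_per_expert = torch.tensor([2, 3, 4], dtype=torch.int64)
cumulative = torch.cumsum(tokens_per_expert, dim=0)   # tensor([2, 5, 9])

# With the cumulative form, expert i owns rows [cumulative[i-1], cumulative[i]) of the
# expert-sorted activation matrix: expert 0 -> rows 0..1, expert 1 -> 2..4, expert 2 -> 5..8.
print(tokens_per_expert, cumulative)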
106
+
107
+ def fused_experts_with_mc2(
108
+ hidden_states: torch.Tensor,
109
+ w1: torch.Tensor,
110
+ w2: torch.Tensor,
111
+ w1_scale: torch.Tensor,
112
+ w2_scale: torch.Tensor,
113
+ topk_weights: torch.Tensor,
114
+ topk_ids: torch.Tensor,
115
+ top_k: int,
116
+ expert_map: torch.Tensor = None,
117
+ moe_all_to_all_group_name: str = "",
118
+ log2phy: torch.Tensor = None,
119
+ global_redundant_expert_num: int = 0,
120
+ shared_experts: Optional[Any] = None,
121
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
122
+ if log2phy is not None:
123
+ topk_ids = log2phy[topk_ids]
124
+ global_bs = 0
125
+ moe_expert_num = len(expert_map) + global_redundant_expert_num
126
+ # hidden_states = hidden_states.bfloat16()
127
+ kwargs_mc2 = {
128
+ "x": hidden_states,
129
+ "expert_ids": topk_ids,
130
+ "expert_shard_type": 0,
131
+ "shared_expert_rank_num": 0,
132
+ "moe_expert_num": moe_expert_num,
133
+ "global_bs": global_bs,
134
+ "expert_scales": topk_weights.to(torch.float32),
135
+ }
136
+
137
+ rank = torch.distributed.get_rank()
138
+
139
+ quant_mode = 2
140
+ ep_group = get_ep_group().device_group
141
+ local_rank = torch.distributed.get_rank(group=ep_group)
142
+ all_to_all_group_size = torch.distributed.get_world_size(ep_group)
143
+
144
+ world_size = torch.distributed.get_world_size()
145
+ tp_size = world_size // all_to_all_group_size
146
+ tp_rank = rank % tp_size
147
+
148
+ stage1_kwargs = {
149
+ "scales": None,
150
+ "quant_mode": quant_mode,
151
+ "group_ep": moe_all_to_all_group_name,
152
+ "ep_world_size": all_to_all_group_size,
153
+ "ep_rank_id": local_rank,
154
+ # "group_tp": self.moe_rs_group_name,
155
+ "group_tp": moe_all_to_all_group_name,
156
+ "tp_world_size": tp_size,
157
+ "tp_rank_id": tp_rank,
158
+ }
159
+ kwargs_mc2.update(stage1_kwargs)
160
+
161
+ output = torch_npu.npu_moe_distribute_dispatch(**kwargs_mc2)
162
+ # comm_stream.wait_stream(torch.npu.current_stream())
163
+ expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts, _, expand_scales = output[
164
+ 0:7]
165
+
166
+ if shared_experts is not None:
167
+ with npu_stream_switch("moe_secondary", 0):
168
+ npu_wait_tensor(hidden_states, topk_weights)
169
+ shared_gate_up, _ = shared_experts.gate_up_proj(hidden_states)
170
+ npu_wait_tensor(shared_gate_up[0], expand_x)
171
+ shared_act = shared_experts.act_fn(shared_gate_up)
172
+
173
+ # `expand_x` will be disposed in the `apply_mlp` function
174
+ down_out_list = apply_mlp(expand_x,
175
+ w1,
176
+ w1_scale,
177
+ w2,
178
+ w2_scale,
179
+ expert_token_nums,
180
+ dynamic_scale=dynamic_scale)
181
+
182
+ # moeCombine
183
+ kwargs_mc2 = {
184
+ "expand_x": down_out_list,
185
+ "expert_ids": topk_ids,
186
+ "expand_idx": expand_idx,
187
+ "expert_scales": topk_weights.to(torch.float32),
188
+ "expert_shard_type": 0,
189
+ "shared_expert_rank_num": 0,
190
+ "moe_expert_num": moe_expert_num,
191
+ "global_bs": 0,
192
+ "expand_scales": expand_scales,
193
+ }
194
+ tp_recv_counts = torch.empty(1,
195
+ dtype=torch.int32,
196
+ device=hidden_states.device)
197
+ stage3_kwargs = {
198
+ "ep_send_counts": ep_recv_counts,
199
+ "group_ep": moe_all_to_all_group_name,
200
+ "ep_world_size": all_to_all_group_size,
201
+ "ep_rank_id": local_rank,
202
+ "tp_send_counts": tp_recv_counts,
203
+ # "group_tp": self.moe_rs_group_name,
204
+ "group_tp": moe_all_to_all_group_name,
205
+ "tp_world_size": tp_size,
206
+ "tp_rank_id": tp_rank,
207
+ }
208
+ kwargs_mc2.update(stage3_kwargs)
209
+
210
+ hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs_mc2)
211
+
212
+ if shared_experts is None:
213
+ return hidden_states
214
+ else:
215
+ with npu_stream_switch("moe_secondary", 0):
216
+ npu_wait_tensor(shared_act[0], down_out_list)
217
+ shared_output, _ = shared_experts.down_proj(shared_act)
218
+ return hidden_states, shared_output
219
+
220
+
221
+ # currently expert parallelism implemented with all2all
222
+ # is under-optimized.
223
+ def fused_experts_with_all2all(
224
+ hidden_states: torch.Tensor,
225
+ w1: torch.Tensor,
226
+ w1_scale: torch.Tensor,
227
+ w2: torch.Tensor,
228
+ w2_scale: torch.Tensor,
229
+ topk_weights: torch.Tensor,
230
+ topk_ids: torch.Tensor,
231
+ top_k: int,
232
+ expert_map: torch.Tensor = None,
233
+ ep_group: GroupCoordinator = None,
234
+ log2phy: torch.Tensor = None,
235
+ global_redundant_expert_num: int = 0,
236
+ ):
237
+ if log2phy is not None:
238
+ topk_ids = log2phy[topk_ids]
239
+ original_shape = hidden_states.shape
240
+ if len(original_shape) == 3:
241
+ hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
242
+
243
+ num_tokens, _ = hidden_states.shape
244
+ num_experts = w1.shape[0]
245
+ device = hidden_states.device
246
+
247
+ if expert_map is not None:
248
+ global_num_experts = len(expert_map) + global_redundant_expert_num
249
+ local_num_experts = global_num_experts // ep_group.world_size
250
+ row_idx_len = num_tokens * top_k
251
+ row_idx = (torch.arange(0,
252
+ row_idx_len,
253
+ dtype=torch.int32,
254
+ device=device).view(top_k, -1).permute(
255
+ 1, 0).contiguous())
256
+ hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
257
+ hidden_states,
258
+ row_idx=row_idx,
259
+ expert_idx=topk_ids,
260
+ active_num=num_tokens)
261
+
262
+ global_expert_tokens = torch.bincount(expanded_expert_idx,
263
+ minlength=global_num_experts)
264
+ scatter_sizes = global_expert_tokens.view(ep_group.world_size,
265
+ -1).sum(-1)
266
+
267
+ gather_sizes = torch.empty_like(scatter_sizes)
268
+ dist.all_to_all_single(gather_sizes,
269
+ scatter_sizes,
270
+ group=ep_group.device_group)
271
+ scatter_size_list = scatter_sizes.cpu().tolist()
272
+ gather_size_list = gather_sizes.cpu().tolist()
273
+
274
+ expanded_expert_idx = expanded_expert_idx % local_num_experts
275
+ hidden_states = ep_group.all_to_all(hidden_states, 0, 0,
276
+ scatter_size_list,
277
+ gather_size_list)
278
+ local_expert_idx = ep_group.all_to_all(expanded_expert_idx, 0, 0,
279
+ scatter_size_list,
280
+ gather_size_list)
281
+
282
+ sorted_local_expert_idx, sorted_idx = torch.sort(local_expert_idx)
283
+
284
+ expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
285
+ sorted_local_expert_idx, local_num_experts).to(torch.int64)
286
+
287
+ hidden_states = hidden_states[sorted_idx]
288
+ group_list_type = 0
289
+ else:
290
+ row_idx_len = num_tokens * top_k
291
+ row_idx = torch.arange(0,
292
+ row_idx_len,
293
+ dtype=torch.int32,
294
+ device=topk_weights.device).view(
295
+ top_k, -1).permute(1, 0).contiguous()
296
+ hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
297
+ hidden_states,
298
+ row_idx=row_idx,
299
+ expert_idx=topk_ids,
300
+ active_num=num_tokens)
301
+
302
+ expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
303
+ expanded_expert_idx, num_experts)
304
+ expert_tokens = expert_tokens.to(torch.int64)
305
+ group_list_type = 0
306
+
307
+ # `hidden_states` will be disposed in the `apply_mlp` function
308
+ hidden_states = apply_mlp(
309
+ hidden_states,
310
+ w1,
311
+ w1_scale, #17
312
+ w2,
313
+ w2_scale,
314
+ expert_tokens, #16
315
+ group_list_type=group_list_type)
316
+
317
+ if expert_map is not None:
318
+ resorted_idx = torch.argsort(sorted_idx)
319
+ hidden_states = hidden_states[resorted_idx]
320
+ hidden_states = ep_group.all_to_all(hidden_states, 0, 0,
321
+ gather_size_list,
322
+ scatter_size_list)
323
+
324
+ final_hidden_states = torch_npu.npu_moe_finalize_routing(
325
+ hidden_states,
326
+ skip1=None,
327
+ skip2=None,
328
+ bias=None,
329
+ scales=topk_weights,
330
+ expanded_src_to_dst_row=expanded_row_idx,
331
+ export_for_source_row=topk_ids,
332
+ )
333
+ else:
334
+ # TODO: Reorder device memory 2 times here, replace the current
335
+ # implementation here when suitable operators become available.
336
+ final_hidden_states = torch_npu.npu_moe_finalize_routing(
337
+ hidden_states,
338
+ skip1=None,
339
+ skip2=None,
340
+ bias=None,
341
+ scales=topk_weights,
342
+ expanded_src_to_dst_row=expanded_row_idx,
343
+ export_for_source_row=topk_ids,
344
+ )
345
+ if len(original_shape) == 3:
346
+ final_hidden_states = final_hidden_states.view(original_shape)
347
+ return final_hidden_states
348
+
349
+
350
+ def fused_experts_with_allgather(hidden_states: torch.Tensor,
351
+ w1: torch.Tensor,
352
+ w1_scale: torch.Tensor,
353
+ w2: torch.Tensor,
354
+ w2_scale: torch.Tensor,
355
+ topk_weights: torch.Tensor,
356
+ topk_ids: torch.Tensor,
357
+ top_k: int,
358
+ expert_map: torch.Tensor = None):
359
+ original_shape = hidden_states.shape
360
+ if len(original_shape) == 3:
361
+ hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
362
+ num_tokens = hidden_states.shape[0]
363
+ batch_size, hidden_size = hidden_states.shape
364
+
365
+ ep_group = get_ep_group().device_group
366
+ ep_rank = torch.distributed.get_rank(group=ep_group)
367
+ ep_size = torch.distributed.get_world_size(ep_group)
368
+
369
+ global_num_experts = len(expert_map)
370
+ local_num_experts = global_num_experts // ep_size
371
+
372
+ hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(hidden_states)
373
+
374
+ hidden_states, expanded_x_idx, expert_tokens, pertoken_scale = torch_npu.npu_moe_init_routing_v2(
375
+ hidden_states,
376
+ topk_ids,
377
+ scale=pertoken_scale,
378
+ offset=None,
379
+ active_num=num_tokens * top_k,
380
+ expert_num=global_num_experts,
381
+ expert_tokens_num_type=1,
382
+ expert_tokens_num_flag=True,
383
+ active_expert_range=[
384
+ ep_rank * local_num_experts, (ep_rank + 1) * local_num_experts
385
+ ],
386
+ quant_mode=-1,
387
+ row_idx_type=0)
388
+ group_list_type = 1
389
+
390
+
391
+ hidden_states = torch_npu.npu_grouped_matmul(
392
+ x=[hidden_states],
393
+ weight=[w1],
394
+ split_item=3,
395
+ group_list_type=group_list_type,
396
+ group_type=0,
397
+ group_list=expert_tokens,
398
+ output_dtype=torch.int32)[0]
399
+
400
+ # act_fn: swiglu
401
+ hidden_states, pertoken_scale = torch_npu.npu_dequant_swiglu_quant(
402
+ x=hidden_states,
403
+ weight_scale=w1_scale.to(torch.float32),
404
+ activation_scale=pertoken_scale,
405
+ bias=None,
406
+ quant_scale=None,
407
+ quant_offset=None,
408
+ group_index=expert_tokens,
409
+ activate_left=True,
410
+ quant_mode=1,
411
+ )
412
+
413
+ hidden_states = torch_npu.npu_grouped_matmul(
414
+ x=[hidden_states],
415
+ weight=[w2],
416
+ scale=[w2_scale.to(torch.bfloat16)],
417
+ per_token_scale=[pertoken_scale.view(-1)],
418
+ split_item=3,
419
+ group_list_type=group_list_type,
420
+ group_type=0,
421
+ group_list=expert_tokens,
422
+ output_dtype=torch.bfloat16)[0]
423
+
424
+ final_hidden_states = torch_npu.npu_moe_finalize_routing(
425
+ expanded_permuted_rows=hidden_states.unsqueeze(1),
426
+ skip1=None,
427
+ skip2=None,
428
+ bias=None,
429
+ scales=topk_weights.to(torch.bfloat16),
430
+ expanded_src_to_dst_row=expanded_x_idx.to(torch.int32),
431
+ export_for_source_row=topk_ids,
432
+ drop_pad_mode=3
433
+ ).to(torch.bfloat16)
434
+
435
+ if len(original_shape) == 3:
436
+ final_hidden_states = final_hidden_states.view(original_shape)
437
+
438
+ return final_hidden_states
439
+
440
+
441
+ def fused_experts(hidden_states: torch.Tensor,
442
+ w1: torch.Tensor,
443
+ w1_scale: torch.Tensor,
444
+ w2: torch.Tensor,
445
+ w2_scale: torch.Tensor,
446
+ topk_weights: torch.Tensor,
447
+ topk_ids: torch.Tensor,
448
+ top_k: int,
449
+ expert_map: torch.Tensor = None):
450
+ original_shape = hidden_states.shape
451
+ if len(original_shape) == 3:
452
+ hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
453
+
454
+ num_tokens, _ = hidden_states.shape
455
+ num_experts = w1.shape[0]
456
+ dtype = hidden_states.dtype
457
+ device = hidden_states.device
458
+
459
+ if expert_map is not None:
460
+ # Generate token indices and flatten
461
+ token_indices = (torch.arange(num_tokens,
462
+ device=device,
463
+ dtype=torch.int64).unsqueeze(1).expand(
464
+ -1, top_k).reshape(-1))
465
+
466
+ # Flatten token-to-expert mappings and map to local experts
467
+ weights_flat = topk_weights.view(-1)
468
+ experts_flat = topk_ids.view(-1)
469
+ local_experts_flat = expert_map[experts_flat]
470
+
471
+ # Filter valid token-expert pairs
472
+ mask = local_experts_flat != -1
473
+ filtered_weights = torch.where(
474
+ mask, weights_flat, torch.zeros_like(weights_flat)).to(dtype)
475
+ filtered_experts = torch.where(
476
+ mask, local_experts_flat,
477
+ torch.full_like(local_experts_flat,
478
+ num_experts)).to(topk_ids.dtype)
479
+
480
+ # Sort by local expert IDs
481
+ sort_indices = torch.argsort(filtered_experts)
482
+ sorted_token_indices = token_indices[sort_indices]
483
+ sorted_weights = filtered_weights[sort_indices]
484
+
485
+ # Compute token counts with minlength of num_experts
486
+ # This is equivalent to but faster than:
487
+ # >>> token_counts = torch.bincount(filtered_experts, minlength=num_experts)[:-1]
488
+ token_counts = torch.zeros(num_experts + 1,
489
+ device=device,
490
+ dtype=torch.int64)
491
+ ones = torch.ones_like(filtered_experts, dtype=torch.int64)
492
+ token_counts.scatter_add_(0, filtered_experts.to(torch.int64), ones)
493
+ expert_tokens = token_counts[:num_experts]
494
+ # Rearrange hidden_states
495
+ hidden_states = hidden_states[sorted_token_indices]
496
+ group_list_type = 1
497
+ else:
498
+ row_idx_len = num_tokens * top_k
499
+ row_idx = torch.arange(0,
500
+ row_idx_len,
501
+ dtype=torch.int32,
502
+ device=topk_weights.device).view(
503
+ top_k, -1).permute(1, 0).contiguous()
504
+ hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
505
+ hidden_states,
506
+ row_idx=row_idx,
507
+ expert_idx=topk_ids,
508
+ active_num=num_tokens)
509
+
510
+ expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
511
+ expanded_expert_idx, num_experts)
512
+ expert_tokens = expert_tokens.to(torch.int64)
513
+ group_list_type = 0
514
+
515
+ # `hidden_states` will be disposed in the `apply_mlp` function
516
+ hidden_states = apply_mlp(hidden_states,
517
+ w1,
518
+ w1_scale,
519
+ w2,
520
+ w2_scale,
521
+ expert_tokens,
522
+ group_list_type=group_list_type)
523
+
524
+ if expert_map is not None:
525
+ hidden_states.mul_(sorted_weights.unsqueeze(1))
526
+ final_hidden_states = torch.zeros(*original_shape,
527
+ device=device,
528
+ dtype=dtype)
529
+
530
+ num_valid_tokens = mask.sum()
531
+ valid_token_mask = torch.arange(
532
+ 0, sorted_token_indices.shape[0],
533
+ device=device).unsqueeze(1) < num_valid_tokens
534
+ hidden_states = hidden_states.masked_fill_(~valid_token_mask,
535
+ 0).to(dtype)
536
+ final_hidden_states.index_add_(0, sorted_token_indices, hidden_states)
537
+ else:
538
+ # TODO: Reorder device memory 2 times here, replace the current
539
+ # implementation here when suitable operators become available.
540
+ final_hidden_states = torch_npu.npu_moe_finalize_routing(
541
+ hidden_states,
542
+ skip1=None,
543
+ skip2=None,
544
+ bias=None,
545
+ scales=topk_weights,
546
+ expanded_src_to_dst_row=expanded_row_idx,
547
+ export_for_source_row=topk_ids,
548
+ )
549
+
550
+ if len(original_shape) == 3:
551
+ final_hidden_states = final_hidden_states.view(original_shape)
552
+ return final_hidden_states
553
+
554
+
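A small sketch of the `expert_map` bookkeeping used in the branch above: global expert ids are mapped to local ones, entries owned by other EP ranks become -1 and are filtered out, and the remaining tokens are counted per local expert. The values are made up.

import torch

# Map global expert ids to local ones; -1 marks experts owned by another EP rank.
expert_map = torch.tensor([-1, -1, 0, 1])          # this rank owns global experts 2 and 3
topk_ids = torch.tensor([[2, 0], [3, 2]])          # per-token selected global experts
local = expert_map[topk_ids.view(-1)]              # tensor([ 0, -1,  1,  0])
mask = local != -1
counts = torch.bincount(local[mask], minlength=2)  # tokens per local expert: tensor([2, 1])
print(local, counts)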
555
+ class AscendW8A8DynamicLinearMethod:
556
+ """Linear method for Ascend W8A8_DYNAMIC.
557
+ """
558
+
559
+ def __init__(self):
560
+ self.transpose_weight = True
561
+
562
+ @staticmethod
563
+ def get_weight(input_size: int, output_size: int,
564
+ params_dtype: torch.dtype) -> Dict[str, Any]:
565
+ params_dict = {
566
+ "weight": torch.empty(output_size, input_size, dtype=torch.int8)
567
+ }
568
+ return params_dict
569
+
570
+ @staticmethod
571
+ def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]:
572
+ return {}
573
+
574
+ @staticmethod
575
+ def get_perchannel_param(
576
+ output_size: int,
577
+ params_dtype: torch.dtype,
578
+ ) -> Dict[str, Any]:
579
+ params_dict = {}
580
+ params_dict["weight_scale"] = torch.empty(output_size,
581
+ 1,
582
+ dtype=params_dtype)
583
+ params_dict["weight_offset"] = torch.empty(output_size,
584
+ 1,
585
+ dtype=params_dtype)
586
+ return params_dict
587
+
588
+ @staticmethod
589
+ def apply(
590
+ layer: torch.nn.Module,
591
+ x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
592
+ bias: Optional[torch.Tensor] = None,
593
+ tp_rank: Optional[int] = 0,
594
+ ) -> torch.Tensor:
595
+ config = getattr(layer, "_ascend_quant_config", {})
596
+ if not isinstance(x, tuple):
597
+ output_dtype = config.get("output_dtype", x.dtype)
598
+ quantized_x, dynamic_scale = torch_npu.npu_dynamic_quant(x)
599
+ else:
600
+ assert "output_dtype" in config.keys(), (
601
+ f"DynamicLinearMethod needs explicitly specified `output_dtype`"
602
+ f"for pre-quantized input, got config [{config}]")
603
+ output_dtype = config["output_dtype"]
604
+ quantized_x, dynamic_scale = x
605
+ pertoken_scale = (dynamic_scale
606
+ if config.get("pertoken_scale", True) else None)
607
+
608
+ output = torch_npu.npu_quant_matmul(
609
+ quantized_x,
610
+ layer.weight,
611
+ layer.weight_scale,
612
+ pertoken_scale=pertoken_scale,
613
+ bias=bias,
614
+ output_dtype=output_dtype,
615
+ )
616
+ return ((output, dynamic_scale)
617
+ if config.get("return_scale", False) else output)
618
+
619
+ def process_weights_after_loading(self, layer):
620
+ if self.transpose_weight:
621
+ layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
622
+ # cast quantized weight tensors in NZ format (29) for higher inference speed
623
+ layer.weight.data = torch_npu.npu_format_cast(layer.weight.data, 29)
624
+ layer.weight_scale.data = layer.weight_scale.data.flatten()
625
+ layer.weight_scale_fp32 = layer.weight_scale.data.to(torch.float32)
626
+ layer.weight_offset.data = layer.weight_offset.data.flatten()
627
+
628
+
629
+ class AscendW8A8DynamicFusedMoEMethod:
630
+ """FusedMoe method for Ascend W8A8_DYNAMIC.
631
+ """
632
+
633
+ def __init__(self):
634
+ self.transpose_weight = True
635
+
636
+ self.ep_group = get_ep_group()
637
+
638
+ ascend_config = get_ascend_config()
639
+ self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
640
+
641
+ try:
642
+ device_group = self.ep_group.device_group
643
+ # TODO: Try local_rank = ep_group.rank_in_group
644
+ local_rank = torch.distributed.get_rank(group=device_group)
645
+ backend = device_group._get_backend(torch.device("npu"))
646
+ self.moe_all_to_all_group_name = backend.get_hccl_comm_name(
647
+ local_rank)
648
+ except AttributeError:
649
+ self.moe_all_to_all_group_name = ""
650
+
651
+ @staticmethod
652
+ def get_weight(num_experts: int, intermediate_size_per_partition: int,
653
+ hidden_sizes: int,
654
+ params_dtype: torch.dtype) -> Dict[str, Any]:
655
+ param_dict = {}
656
+ param_dict["w13_weight"] = torch.empty(num_experts,
657
+ 2 *
658
+ intermediate_size_per_partition,
659
+ hidden_sizes,
660
+ dtype=torch.int8)
661
+ param_dict["w2_weight"] = torch.empty(num_experts,
662
+ hidden_sizes,
663
+ intermediate_size_per_partition,
664
+ dtype=torch.int8)
665
+ return param_dict
666
+
667
+ @staticmethod
668
+ def get_dynamic_quant_param(num_experts: int,
669
+ intermediate_size_per_partition: int,
670
+ hidden_sizes: int,
671
+ params_dtype: torch.dtype) -> Dict[str, Any]:
672
+ param_dict = {}
673
+ param_dict["w13_weight_scale"] = torch.empty(
674
+ num_experts,
675
+ 2 * intermediate_size_per_partition,
676
+ 1,
677
+ dtype=params_dtype)
678
+ param_dict["w13_weight_offset"] = torch.empty(
679
+ num_experts,
680
+ 2 * intermediate_size_per_partition,
681
+ 1,
682
+ dtype=params_dtype)
683
+ param_dict["w2_weight_scale"] = torch.empty(num_experts,
684
+ hidden_sizes,
685
+ 1,
686
+ dtype=params_dtype)
687
+ param_dict["w2_weight_offset"] = torch.empty(num_experts,
688
+ hidden_sizes,
689
+ 1,
690
+ dtype=params_dtype)
691
+ return param_dict
692
+
693
+ def apply(
694
+ self,
695
+ layer: torch.nn.Module,
696
+ x: torch.Tensor,
697
+ router_logits: torch.Tensor,
698
+ top_k: int,
699
+ renormalize: bool,
700
+ use_grouped_topk: bool = False,
701
+ global_num_experts: int = -1,
702
+ expert_map: Optional[torch.Tensor] = None,
703
+ topk_group: Optional[int] = None,
704
+ num_expert_group: Optional[int] = None,
705
+ custom_routing_function: Optional[Callable] = None,
706
+ scoring_func: str = "softmax",
707
+ e_score_correction_bias: Optional[torch.Tensor] = None,
708
+ is_prefill: bool = True,
709
+ enable_force_load_balance: bool = True,
710
+ log2phy: torch.Tensor = None,
711
+ global_redundant_expert_num: int = 0,
712
+ shared_experts: Optional[Any] = None,
713
+ **kwargs,
714
+ ) -> torch.Tensor:
715
+ assert router_logits.shape[
716
+ 1] == global_num_experts, "Number of global experts mismatch"
717
+
718
+ is_deepseek_v3_r1 = global_num_experts == 256
719
+ use_grouped_topk = (topk_group > 1 or num_expert_group > 1)
720
+
721
+ # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
722
+ if use_grouped_topk and is_deepseek_v3_r1:
723
+ topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
724
+ router_logits,
725
+ k=top_k, # top_k is currently set to 8
726
+ bias=e_score_correction_bias,
727
+ k_group=topk_group, # fix: 4
728
+ group_count=num_expert_group, # fix 8
729
+ group_select_mode=1, # 0: max within each group; 1: sum of top-2 per group (fixed)
730
+ renorm=0, # 0: softmax->topk(fix); 1: topk->softmax
731
+ norm_type=1, # 0: softmax; 1: sigmoid(fix)
732
+ # out_flag=False, # todo new api; whether to return the third output
733
+ # y2_flag=False, # old api; whether to return the third output
734
+ routed_scaling_factor=1,
735
+ eps=float(1e-20))
736
+ else:
737
+ topk_weights, topk_ids = select_experts(
738
+ hidden_states=x,
739
+ router_logits=router_logits,
740
+ top_k=top_k,
741
+ use_grouped_topk=use_grouped_topk,
742
+ renormalize=renormalize,
743
+ topk_group=topk_group,
744
+ num_expert_group=num_expert_group,
745
+ custom_routing_function=custom_routing_function,
746
+ scoring_func=scoring_func,
747
+ e_score_correction_bias=e_score_correction_bias,
748
+ )
749
+
750
+ # this is a naive implementation of expert load balancing, used to
751
+ # avoid accumulating too many tokens on a single rank.
752
+ # currently it is only activated when doing profile runs.
753
+ if enable_force_load_balance:
754
+ topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
755
+
756
+ topk_weights = topk_weights.to(x.dtype)
757
+
758
+ fused_moe_state = get_fused_moe_state(self.ep_group.world_size,
759
+ is_prefill, is_deepseek_v3_r1)
760
+ if fused_moe_state == FusedMoEState.AllGatherEP:
761
+ return fused_experts_with_allgather(
762
+ hidden_states=x,
763
+ w1=layer.w13_weight,
764
+ w1_scale=layer.w13_weight_scale,
765
+ w2=layer.w2_weight,
766
+ w2_scale=layer.w2_weight_scale,
767
+ topk_weights=topk_weights,
768
+ topk_ids=topk_ids,
769
+ top_k=top_k,
770
+ expert_map=expert_map)
771
+ elif fused_moe_state == FusedMoEState.MC2:
772
+ return fused_experts_with_mc2(
773
+ hidden_states=x,
774
+ w1=layer.w13_weight,
775
+ w2=layer.w2_weight,
776
+ w1_scale=layer.w13_weight_scale,
777
+ w2_scale=layer.w2_weight_scale,
778
+ topk_weights=topk_weights,
779
+ topk_ids=topk_ids,
780
+ top_k=top_k,
781
+ expert_map=expert_map,
782
+ moe_all_to_all_group_name=self.moe_all_to_all_group_name,
783
+ log2phy=log2phy,
784
+ global_redundant_expert_num=global_redundant_expert_num,
785
+ shared_experts=shared_experts)
786
+ elif fused_moe_state in [
787
+ FusedMoEState.AllGather, FusedMoEState.NaiveMulticast
788
+ ]:
789
+ return fused_experts(hidden_states=x,
790
+ w1=layer.w13_weight,
791
+ w1_scale=layer.w13_weight_scale,
792
+ w2=layer.w2_weight,
793
+ w2_scale=layer.w2_weight_scale,
794
+ topk_weights=topk_weights,
795
+ topk_ids=topk_ids,
796
+ top_k=top_k,
797
+ expert_map=expert_map)
798
+ else:
799
+ # The current implementation of deepseek moe splits hidden_states
800
+ # according to tp_size before they are fed into the fused_moe module.
801
+ # Therefore, all2all is needed no matter how dp/tp is set so as to
802
+ # dispatch/combine tokens.
803
+ return fused_experts_with_all2all(
804
+ hidden_states=x,
805
+ w1=layer.w13_weight,
806
+ w1_scale=layer.w13_weight_scale,
807
+ w2=layer.w2_weight,
808
+ w2_scale=layer.w2_weight_scale,
809
+ topk_weights=topk_weights,
810
+ topk_ids=topk_ids,
811
+ top_k=top_k,
812
+ expert_map=expert_map,
813
+ ep_group=self.ep_group,
814
+ log2phy=log2phy,
815
+ global_redundant_expert_num=global_redundant_expert_num,
816
+ )
817
+
818
+ def process_weights_after_loading(self, layer):
819
+ if self.transpose_weight:
820
+ layer.w13_weight.data = layer.w13_weight.data.transpose(
821
+ 1, 2).contiguous()
822
+ layer.w2_weight.data = layer.w2_weight.data.transpose(
823
+ 1, 2).contiguous()
824
+ layer.w13_weight_scale.data = layer.w13_weight_scale.data.view(
825
+ layer.w13_weight_scale.data.shape[0], -1)
826
+ layer.w13_weight_offset.data = layer.w13_weight_offset.data.view(
827
+ layer.w13_weight_offset.data.shape[0], -1)
828
+ layer.w2_weight_scale.data = layer.w2_weight_scale.data.view(
829
+ layer.w2_weight_scale.data.shape[0], -1)
830
+ layer.w2_weight_offset.data = layer.w2_weight_offset.data.view(
831
+ layer.w2_weight_offset.data.shape[0], -1)
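Side note on the routing in apply() above: the non-grouped branch delegates to select_experts (defined elsewhere in vllm-ascend) with softmax scoring and optional renormalization. The sketch below shows only that core softmax top-k idea for orientation; it is not the real helper and ignores grouped top-k, bias correction and the other options.

import torch

def naive_softmax_topk_routing(router_logits, top_k, renormalize=True):
    scores = torch.softmax(router_logits, dim=-1)               # per-expert probabilities
    topk_weights, topk_ids = torch.topk(scores, top_k, dim=-1)  # keep the k best experts per token
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids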
inference/vllm_ascend/tokenizer_config.json ADDED
@@ -0,0 +1 @@
1
+ {"add_bos_token": true, "add_eos_token": false, "add_prefix_space": true, "added_tokens_decoder": {"0": {"content": "<unk>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "1": {"content": "<s>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "2": {"content": "</s>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45806": {"content": "<|User|>:", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45813": {"content": "<|Bot|>:", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45830": {"content": "[unused0]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45840": {"content": "[unused1]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45846": {"content": "[unused2]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45849": {"content": "[unused3]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45861": {"content": "[unused4]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45866": {"content": "[unused5]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45874": {"content": "[unused6]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45883": {"content": "[unused7]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45884": {"content": "[unused8]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45887": {"content": "[unused9]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45892": {"content": "[unused10]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45920": {"content": "[unused11]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45932": {"content": "[unused12]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45938": {"content": "[unused13]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45953": {"content": "[unused14]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45968": {"content": "[unused15]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45974": {"content": "[unused16]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45982": {"content": "[unused17]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45986": {"content": "[unused18]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "46005": {"content": "[unused19]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "46007": {"content": "[unused20]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "46014": {"content": "[unused21]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, 
"special": true}, "46017": {"content": "[unused22]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "46028": {"content": "[unused23]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "46032": {"content": "[unused24]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "46081": {"content": "[unused25]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "46086": {"content": "[unused26]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "46101": {"content": "[unused27]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "46183": {"content": "[unused28]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "46230": {"content": "[unused29]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "46245": {"content": "[unused30]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "46257": {"content": "[unused31]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "144208": {"content": "[unused32]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "144209": {"content": "[unused33]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}}, "auto_map": {"AutoTokenizer": ["tokenization_openpangu.PanguTokenizer", null]}, "bos_token": "<s>", "clean_up_tokenization_spaces": false, "eos_token": "[unused10]", "legacy": true, "model_max_length": 1000000000000000019884624838656, "pad_token": "<unk>", "sp_model_kwargs": {}, "spaces_between_special_tokens": false, "tokenizer_class": "PanguTokenizer", "unk_token": "<unk>", "use_default_system_prompt": false, "chat_template": "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<s>[unused9]系统:[unused10]' }}{% endif %}{% if message['role'] == 'system' %}{{ '<s>[unused9]系统:' + message['content'] + '[unused10]' }}{% endif %}{% if message['role'] == 'assistant' %}{{'[unused9]助手:' + message['content'] + '[unused10]'}}{% endif %}{% if message['role'] == 'tool' %}{{'[unused9]工具:' + message['content'] + '[unused10]'}}{% endif %}{% if message['role'] == 'function' %}{{'[unused9]方法:' + message['content'] + '[unused10]'}}{% endif %}{% if message['role'] == 'user' %}{{'[unused9]用户:' + message['content'] + '[unused10]'}}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '[unused9]助手:' }}{% endif %}"}
inference/vllm_ascend/utils.py ADDED
@@ -0,0 +1,563 @@
1
+ #
2
+ # Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
3
+ # Copyright 2023 The vLLM team.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ # This file is a part of the vllm-ascend project.
17
+ # Adapted from vllm-project/vllm/vllm/worker/worker.py
18
+ #
19
+
20
+ import atexit
21
+ import fcntl
22
+ import math
23
+ import os
24
+ import shutil
25
+ from contextlib import contextmanager, nullcontext
26
+ from enum import Enum
27
+ from threading import Lock
28
+ from typing import TYPE_CHECKING, List, Tuple
29
+
30
+ import torch
31
+ import torch_npu # noqa: F401
32
+ from packaging.version import InvalidVersion, Version
33
+ from torch_npu.npu.streams import Event
34
+ from vllm.logger import logger
35
+
36
+ import vllm_ascend.envs as envs
37
+ from vllm_ascend.ascend_config import get_ascend_config
38
+
39
+ try:
40
+ # Recent release of torchair has moved these ops to `.scope`.
41
+ from torchair.scope import npu_stream_switch as _npu_stream_switch
42
+ from torchair.scope import npu_wait_tensor as _npu_wait_tensor
43
+ except ImportError:
44
+ from torchair.ops import NpuStreamSwitch as _npu_stream_switch
45
+ from torchair.ops import npu_wait_tensor as _npu_wait_tensor
46
+
47
+ if TYPE_CHECKING:
48
+ from vllm.config import VllmConfig
49
+ else:
50
+ VllmConfig = None
51
+
52
+ # NOTE: Currently, we can only capture 1920 graphs at most,
53
+ # due to the limitation of ACL graph. This number is bounded by
54
+ # the number of streams, which is 2048, we save 128 streams
55
+ # as a buffer.
56
+ # Maximum number of graphs that can be captured by ACL Graph
57
+ MAX_CAPTURE_SIZE = 1920
58
+
59
+ ASCEND_QUATIZATION_METHOD = "ascend"
60
+ SOC_VERSION_INFERENCE_SERIES = ["Ascend310P3"]
61
+
62
+ ACL_FORMAT_FRACTAL_ND = 2
63
+ ACL_FORMAT_FRACTAL_NZ = 29
64
+
65
+ _CUSTOM_OP_ENABLED = None
66
+ _IS_310P = None
67
+ _SLEEP_MODE_ENABLED = None
68
+ _CURRENT_STREAM = None
69
+
70
+
71
+ def is_310p():
72
+ global _IS_310P
73
+ if _IS_310P is None:
74
+ from vllm_ascend import _build_info # type: ignore
75
+ _IS_310P = _build_info.__soc_version__.lower().startswith("ascend310p")
76
+ return _IS_310P
77
+
78
+
79
+ def sleep_mode_enabled():
80
+ global _SLEEP_MODE_ENABLED
81
+ if _SLEEP_MODE_ENABLED is None:
82
+ from vllm_ascend import _build_info # type: ignore
83
+ _SLEEP_MODE_ENABLED = _build_info.__sleep_mode_enabled__
84
+ return _SLEEP_MODE_ENABLED
85
+
86
+
87
+ def _round_up(x: int, align: int):
88
+ # round up x to align, for example, if align is 16, x will be rounded up to 16, 32, 48, etc.
89
+ # input: 15, 16 -> output: 16
90
+ # input: 17, 16 -> output: 32
91
+ # input: 30, 16 -> output: 32
92
+ # input: 33, 16 -> output: 48
93
+ # ...
94
+ return (x + align - 1) // align * align
95
+
96
+
97
+ def _custom_pad(x, pad_dims):
98
+ # pad the input tensor to the shape of pad_dims
99
+ # input: (13, 30), pad_dims: [0, 2, 0, 3]
100
+ # output: (16, 32)
101
+ return torch.nn.functional.pad(x, pad_dims)
102
+
103
+
104
+ def _custom_reshape(x, target_shape):
105
+ # reshape the input tensor to the shape of target_shape
106
+ # input: (16, 32), target_shape: [1, 16, 2, 16]
107
+ # output: (1, 16, 2, 16)
108
+ return x.reshape(target_shape)
109
+
110
+
111
+ def _custom_transpose(x, dim1, dim2):
112
+ # transpose the input tensor
113
+ # input: (1, 16, 2, 16), dim1: 1, dim2: 2
114
+ # output: (1, 2, 16, 16)
115
+ return x.transpose(dim1, dim2)
116
+
117
+
118
+ def nd_to_nz_2d(in_tensor: torch.Tensor) -> torch.Tensor:
119
+ # in_tensor: (13, 30)
120
+ aux_dims = [1, 0, 0, 16]
121
+ # aux_dims[1]: 16
122
+ aux_dims[1] = _round_up(in_tensor.size(0), 16)
123
+ # aux_dims[2]: 2
124
+ aux_dims[2] = _round_up(in_tensor.size(1), 16) // 16
125
+
126
+ # after: aux_dims: [1, 16, 2, 16]
127
+
128
+ pad_dims = [0, 0, 0, 0]
129
+ # pad_dims[1]: 2
130
+ pad_dims[1] = _round_up(in_tensor.size(1), 16) - in_tensor.size(1)
131
+ # pad_dims[3]: 3
132
+ pad_dims[3] = _round_up(in_tensor.size(0), 16) - in_tensor.size(0)
133
+
134
+ # after: pad_dims: [0, 2, 0, 3]
135
+
136
+ # return: (1, 2, 16, 16)
137
+ return _custom_transpose(
138
+ _custom_reshape(_custom_pad(in_tensor, pad_dims), aux_dims), 1,
139
+ 2).contiguous()
140
+
141
+
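A quick, illustrative shape check of the (13, 30) example traced in the comments above:

x = torch.randn(13, 30)
out = nd_to_nz_2d(x)
# rows padded 13 -> 16, columns padded 30 -> 32, i.e. 2 blocks of 16
assert out.shape == (1, 2, 16, 16)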
142
+ def nd_to_nz_spec(mask_tensor: torch.Tensor) -> torch.Tensor:
143
+ num_tokens = mask_tensor.shape[0]
144
+ max_seq_len = mask_tensor.shape[1]
145
+
146
+ tokens_pad = (num_tokens + 15) // 16 * 16
147
+ max_seq_len_pad = (max_seq_len + 15) // 16 * 16
148
+
149
+ mask_tensor_pad = \
150
+ torch.zeros((1, tokens_pad, max_seq_len_pad), dtype=mask_tensor.dtype, device=mask_tensor.device)
151
+ mask_tensor_pad[0][:num_tokens, :max_seq_len] = mask_tensor
152
+ mask = mask_tensor_pad.reshape(
153
+ (1, tokens_pad, max_seq_len_pad // 16, 16)).permute(0, 2, 1, 3)
154
+ return mask
155
+
156
+
157
+ def aligned_16(tensor: torch.Tensor):
158
+ """Aligned tensor for 310P"""
159
+
160
+ # Get the size of the current 0th dimension
161
+ n = tensor.size(0)
162
+
163
+ # Calculate the aligned size
164
+ n_aligned = ((n + 15) // 16) * 16
165
+
166
+ # If already aligned, return the original tensor
167
+ if n == n_aligned:
168
+ return tensor
169
+
170
+ # Create a new tensor with shape (n_aligned, H, W) and fill it with zeros
171
+ new_tensor = torch.zeros(n_aligned,
172
+ *tensor.shape[1:],
173
+ dtype=tensor.dtype,
174
+ device=tensor.device)
175
+
176
+ # Copy the original tensor to the first N positions of the new tensor
177
+ new_tensor[:n] = tensor
178
+
179
+ return new_tensor
180
+
181
+
182
+ def maybe_converting_weight_acl_format(model, format=ACL_FORMAT_FRACTAL_NZ):
183
+ # currently, there are some operations which do not support ACL_FORMAT_FRACTAL_NZ
184
+ # in eager mode but support it in torchair graph mode. since ACL_FORMAT_FRACTAL_NZ
185
+ # is strongly preferred over ACL_FORMAT_FRACTAL_ND on 300I Duo, we add this
186
+ # conversion when using torchair graph mode on 300I Duo platform.
187
+ # TODO: we will remove this conversion if npu_quant_grouped_matmul_dequant
188
+ # accepts weight format of ACL_FORMAT_FRACTAL_NZ in eager mode.
189
+ from vllm.model_executor.layers.fused_moe.layer import FusedMoE
190
+
191
+ use_torchair = get_ascend_config().torchair_graph_config.enabled
192
+ if not is_310p() or not use_torchair:
193
+ return
194
+ for module in model.modules():
195
+ if isinstance(module, FusedMoE):
196
+ if torch_npu.get_npu_format(module.w13_weight.data) == format:
197
+ return
198
+ module.w13_weight.data = torch_npu.npu_format_cast(
199
+ module.w13_weight.data, format)
200
+ module.w2_weight.data = torch_npu.npu_format_cast(
201
+ module.w2_weight.data, format)
202
+
203
+
204
+ def try_register_lib(lib_name: str, lib_info: str = ""):
205
+ import importlib
206
+ import importlib.util
207
+ try:
208
+ module_spec = importlib.util.find_spec(lib_name)
209
+ if module_spec is not None:
210
+ importlib.import_module(lib_name)
211
+ if lib_info:
212
+ logger.info(lib_info)
213
+ except Exception:
214
+ pass
215
+
216
+
217
+ def enable_custom_op():
218
+ """
219
+ Enable lazy init for vllm_ascend_C to avoid early initialization of CANN's RTS component.
220
+ Ensure that ASCEND_RT_VISIBLE_DEVICES can be dynamically modified before torch.npu.set_device().
221
+ """
222
+ global _CUSTOM_OP_ENABLED
223
+ if _CUSTOM_OP_ENABLED is not None:
224
+ return _CUSTOM_OP_ENABLED
225
+ try:
226
+ # register custom ops into torch_library here
227
+ import vllm_ascend.vllm_ascend_C # type: ignore # noqa: F401
228
+ _CUSTOM_OP_ENABLED = True
229
+ except ImportError:
230
+ _CUSTOM_OP_ENABLED = False
231
+ logger.warning(
232
+ "Warning: Failed to register custom ops, all custom ops will be disabled"
233
+ )
234
+ return _CUSTOM_OP_ENABLED
235
+
236
+
237
+ def find_hccl_library() -> str:
238
+ """
239
+ We either use the library file specified by the `HCCL_SO_PATH`
240
+ environment variable, or we find the library file brought by PyTorch.
241
+ After importing `torch`, `libhccl.so` can be
242
+ found by `ctypes` automatically.
243
+ """
244
+ so_file = envs.HCCL_SO_PATH
245
+
246
+ # manually load the hccl library
247
+ if so_file:
248
+ logger.info("Found hccl from environment variable HCCL_SO_PATH=%s",
249
+ so_file)
250
+ else:
251
+ if torch.version.cann is not None:
252
+ so_file = "libhccl.so"
253
+ else:
254
+ raise ValueError("HCCL only supports Ascend NPU backends.")
255
+ logger.info("Found hccl from library %s", so_file)
256
+ return so_file
257
+
258
+
259
+ def current_stream() -> torch.npu.Stream:
260
+ """
261
+ replace `torch.npu.current_stream()` with `vllm.utils.current_stream()`.
262
+ it turns out that `torch.npu.current_stream()` is quite expensive,
263
+ as it will construct a new stream object at each call.
264
+ here we patch `torch.npu.set_stream` to keep track of the current stream
265
+ directly, so that we can avoid calling `torch.npu.current_stream()`.
266
+
267
+ """
268
+ global _CURRENT_STREAM
269
+ if _CURRENT_STREAM is None:
270
+ # when this function is called before any stream is set,
271
+ # we return the default stream.
272
+ _CURRENT_STREAM = torch.npu.current_stream()
273
+ return _CURRENT_STREAM
274
+
275
+
276
+ def adapt_patch(is_global_patch: bool = False):
277
+ if is_global_patch:
278
+ from vllm_ascend.patch import platform # noqa: F401
279
+ else:
280
+ from vllm_ascend.patch import worker # noqa: F401
281
+
282
+
283
+ def vllm_version_is(target_vllm_version: str):
284
+ if envs.VLLM_VERSION is not None:
285
+ vllm_version = envs.VLLM_VERSION
286
+ else:
287
+ import vllm
288
+ vllm_version = vllm.__version__
289
+ try:
290
+ return Version(vllm_version) == Version(target_vllm_version)
291
+ except InvalidVersion:
292
+ raise ValueError(
293
+ f"Invalid vllm version {vllm_version} found. A dev version of vllm "
294
+ "is installed probably. Set the environment variable VLLM_VERSION "
295
+ "to control it by hand. And please make sure the value follows the "
296
+ "format of x.y.z.")
297
+
298
+
299
+ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:
300
+ """Update ACL graph capture sizes based on hardware limitations"""
301
+ # Store original configuration and temporarily clear it
302
+ compilation_config = vllm_config.compilation_config
303
+ original_sizes, compilation_config.cudagraph_capture_sizes = \
304
+ compilation_config.cudagraph_capture_sizes, None
305
+
306
+ # Calculate parallel configuration factor
307
+ num_hidden_layers = vllm_config.model_config.hf_config.num_hidden_layers
308
+ parallel_config = vllm_config.parallel_config
309
+
310
+ # TODO: Find out whether we need to take into account the pp_size
311
+ parallel_factor = 1 + sum(size > 1 for size in [
312
+ parallel_config.data_parallel_size_local,
313
+ parallel_config.tensor_parallel_size,
314
+ parallel_config.expert_parallel_size,
315
+ parallel_config.expert_tensor_parallel_size,
316
+ ])
317
+
318
+ # Calculate maximum supported batch sizes considering model architecture
319
+ max_num_batch_sizes = math.floor(MAX_CAPTURE_SIZE /
320
+ (num_hidden_layers + 1) / parallel_factor)
321
+ logger.info("Calculated maximum supported batch sizes for ACL graph: %s",
322
+ max_num_batch_sizes)
323
+
324
+ # If original sizes exceed maximum, sample a representative subset
325
+ if max_num_batch_sizes < len(original_sizes):
326
+ # Sample uniformly from original sizes
327
+ step = (len(original_sizes) - 1) / (max_num_batch_sizes - 1)
328
+ indices = [round(i * step) for i in range(max_num_batch_sizes)]
329
+
330
+ # Ensure first and last elements are preserved
331
+ indices[0], indices[-1] = 0, len(original_sizes) - 1
332
+
333
+ sampled_sizes = [original_sizes[i] for i in indices]
334
+ compilation_config.init_with_cudagraph_sizes(sampled_sizes)
335
+
336
+ logger.info(
337
+ "Adjusted ACL graph batch sizes for %s model (layers: %d): %d → %d sizes",
338
+ vllm_config.model_config.architectures[0],
339
+ num_hidden_layers,
340
+ len(original_sizes),
341
+ len(compilation_config.
342
+ cudagraph_capture_sizes # type: ignore[arg-type]
343
+ ))
344
+ else:
345
+ # No adjustment needed
346
+ compilation_config.cudagraph_capture_sizes = original_sizes
347
+ logger.info(
348
+ "No adjustment needed for ACL graph batch sizes: %s model (layers: %d) with %d sizes",
349
+ vllm_config.model_config.architectures[0], num_hidden_layers,
350
+ len(original_sizes))
351
+
352
+
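To make the uniform sampling above concrete, here is the same index selection in isolation with made-up numbers (hypothetical capture sizes, not defaults):

original_sizes = [1, 2, 4, 8, 16, 24, 32]        # hypothetical capture sizes
max_num_batch_sizes = 5
step = (len(original_sizes) - 1) / (max_num_batch_sizes - 1)
indices = [round(i * step) for i in range(max_num_batch_sizes)]   # [0, 2, 3, 4, 6]
indices[0], indices[-1] = 0, len(original_sizes) - 1
sampled = [original_sizes[i] for i in indices]   # -> [1, 4, 8, 16, 32]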
353
+ # TODO(wxy): Move to ops module
354
+ def dispose_tensor(x: torch.Tensor):
355
+ x.set_(torch.empty((0, ), device=x.device, dtype=x.dtype))
356
+
357
+
358
+ class ProfileExecuteDuration:
359
+ _instance = None
360
+ _observations: List[Tuple[str, Event, Event]] = []
361
+ _lock = Lock()
362
+
363
+ def __new__(cls):
364
+ with cls._lock:
365
+ if cls._instance is None:
366
+ cls._instance = super().__new__(cls)
367
+ atexit.register(cls._instance.destroy)
368
+ return cls._instance
369
+
370
+ def destroy(self):
371
+ with self._lock:
372
+ self._observations.clear()
373
+
374
+ @contextmanager
375
+ def capture_async(self, duration_tag: str):
376
+ if not envs.VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE:
377
+ yield
378
+ return
379
+
380
+ observe_start = Event(enable_timing=True)
381
+ observe_start.record()
382
+ try:
383
+ yield
384
+ finally:
385
+ observe_end = Event(enable_timing=True)
386
+ observe_end.record()
387
+ with self._lock:
388
+ self._observations.append(
389
+ (duration_tag, observe_start, observe_end))
390
+
391
+ def pop_captured_sync(self) -> dict:
392
+ """Pop and synchronize all events in the observation list"""
393
+ durations: dict[str, float] = {}
394
+ if not envs.VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE:
395
+ return durations
396
+
397
+ while self._observations:
398
+ with self._lock:
399
+ tag, observe_start, observe_end = self._observations.pop()
400
+ observe_end.synchronize()
401
+ durations[tag] = observe_start.elapsed_time(observe_end)
402
+
403
+ return durations
404
+
405
+
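Hypothetical usage of ProfileExecuteDuration (the stage names are made up; nothing is recorded unless VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE is enabled):

prof = ProfileExecuteDuration()              # process-wide singleton
with prof.capture_async("prepare_inputs"):
    pass                                     # placeholder for real work
with prof.capture_async("forward"):
    pass                                     # placeholder for real work
durations_ms = prof.pop_captured_sync()      # e.g. {"forward": 12.3, "prepare_inputs": 0.8}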
406
+ # TODO(wxy): Move to ops module
407
+ def npu_stream_switch(tag: str, priority: int, *, enabled: bool = True):
408
+ return _npu_stream_switch(tag, priority) if enabled else nullcontext()
409
+
410
+
411
+ # TODO(wxy): Move to ops module
412
+ def npu_wait_tensor(self: torch.Tensor,
413
+ dependency: torch.Tensor,
414
+ *,
415
+ enabled: bool = True):
416
+ return _npu_wait_tensor(self, dependency) if enabled else self
417
+
418
+
419
+ # TODO(wxy): Move to ops module
420
+ def npu_prefetch(input: torch.Tensor,
421
+ dependency: torch.Tensor,
422
+ max_size: int = 0,
423
+ *,
424
+ enabled: bool = True):
425
+ if not enabled:
426
+ return
427
+ input_size = input.element_size() * input.numel()
428
+ if max_size <= 0 or max_size > input_size:
429
+ max_size = input_size
430
+ torch_npu.npu_prefetch(input, dependency, max_size)
431
+
432
+
433
+ # TODO(zzzzwwjj): move this into forward_context
434
+ class FusedMoEState(Enum):
435
+ AllGather = 0
436
+ All2All = 1
437
+ MC2 = 2
438
+ AllGatherEP = 3
439
+ NaiveMulticast = 4
440
+
441
+
442
+ # TODO(ttanzhiqiang): rm_router_logits
443
+ # Triggered only when dp > 1.
444
+ # In theory this optimization only applies to AllGather and AllGatherEP: in the dp scenario the previous flow was gate + two communications, whereas now it is one communication + gate, which saves some communication time. All MoE AllGather/AllGatherEP paths could follow the same logic, but other MoE models (e.g. qwen3-235b) have not adapted their dp path yet, so a switch guards it to avoid breaking them.
445
+ def get_rm_router_logits_state(ep_size: int, dp_size: int,
446
+ is_deepseek_v3_r1: bool):
447
+ # the fusion operator torch_npu.npu_grouped_matmul_finalize_routing called by allgather ep
448
+ # only supports deepseek v3/r1
449
+ if dp_size > 1:
450
+ if (envs.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP and ep_size > 1
451
+ and is_deepseek_v3_r1):
452
+ return True
453
+ elif ep_size == 1 and is_deepseek_v3_r1:
454
+ return True
455
+ return False
456
+
457
+
458
+ # TODO(ttanzhiqiang): all_reduce merge
459
+ # When all_reduce_merge is enabled, shared_experts skips the all_reduce in its MLP and waits until shared_experts + router_experts have both finished before a single all_reduce is performed.
460
+ # Currently, all_reduce_merge is enabled by default in the AllGather, AllGatherEP and NaiveMulticast scenarios of the deepseek model.
461
+ def get_all_reduce_merge_state(ep_size: int, is_deepseek_v3_r1: bool):
462
+ # the fusion operator torch_npu.npu_grouped_matmul_finalize_routing called by allgather ep
463
+ # only supports deepseek v3/r1
464
+ if (envs.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP and ep_size > 1
465
+ and is_deepseek_v3_r1):
466
+ return True
467
+ elif ep_size == 1 and is_deepseek_v3_r1:
468
+ return True
469
+ return False
470
+
471
+
472
+ # TODO(zzzzwwjj): add soc_version to choose branch
473
+ def get_fused_moe_state(ep_size: int, with_prefill: bool,
474
+ is_deepseek_v3_r1: bool):
475
+ # the fusion operator torch_npu.npu_grouped_matmul_finalize_routing called by allgather ep
476
+ # only supports deepseek v3/r1
477
+ if (envs.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP and ep_size > 1
478
+ and is_deepseek_v3_r1 and not with_prefill):
479
+ return FusedMoEState.AllGatherEP
480
+ elif ep_size == 1:
481
+ if with_prefill:
482
+ return FusedMoEState.NaiveMulticast
483
+ else:
484
+ return FusedMoEState.AllGather
485
+ # NOTE: MC2 needs ep_size >= 16, and all2all cannot be used in torchair graph mode.
486
+ elif ep_size < 16 or with_prefill:
487
+ return FusedMoEState.All2All
488
+ else:
489
+ return FusedMoEState.MC2
490
+
491
+
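For orientation, here is how the branches above resolve for a DeepSeek-V3/R1-sized MoE when VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP is not set (illustrative table derived directly from the code):

# ep_size  with_prefill  ->  FusedMoEState
#    1        True       ->  NaiveMulticast
#    1        False      ->  AllGather
#    8        False      ->  All2All   (MC2 needs ep_size >= 16)
#   16        True       ->  All2All   (prefill falls back to all2all)
#   16        False      ->  MC2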
492
+ KV_CACHE_BYTES_CACHE_PATH_NAME = ".kv_cache_bytes"
493
+ KV_CACHE_BYTES_CACHE_FILE_NAME = "kv_cache_bytes"
494
+ TORCHAIR_CACHE_PATH_NAME = ".torchair_cache"
495
+ TORCHAIR_CACHE_DIR = os.getenv(
496
+ 'TORCHAIR_CACHE_HOME', os.path.join(os.getcwd(), TORCHAIR_CACHE_PATH_NAME))
497
+
498
+
499
+ def get_torchair_current_work_dir(file_name=None):
500
+ if file_name is None:
501
+ return TORCHAIR_CACHE_DIR
502
+ return os.path.join(TORCHAIR_CACHE_DIR, file_name)
503
+
504
+
505
+ def check_torchair_cache_exist():
506
+ res = False
507
+ torch_air_abs_path = get_torchair_current_work_dir()
508
+ if os.path.exists(torch_air_abs_path):
509
+ file_list = os.listdir(torch_air_abs_path)
510
+ if len(file_list) != 0:
511
+ res = True
512
+ return res
513
+
514
+
515
+ def check_kv_cache_bytes_cache_exist():
516
+ res = False
517
+ kv_cache_bytes_cache_abs_path = get_torchair_current_work_dir(
518
+ KV_CACHE_BYTES_CACHE_PATH_NAME)
519
+ if os.path.exists(kv_cache_bytes_cache_abs_path):
520
+ file_list = os.listdir(kv_cache_bytes_cache_abs_path)
521
+ if len(file_list) != 0:
522
+ res = True
523
+ return res
524
+
525
+
526
+ def read_kv_cache_bytes_from_file(rank) -> int:
527
+ kv_cache_bytes = -1
528
+ kv_cache_bytes_cache_abs_path = get_torchair_current_work_dir(
529
+ KV_CACHE_BYTES_CACHE_PATH_NAME)
530
+ kv_cache_bytes_file = os.path.join(
531
+ kv_cache_bytes_cache_abs_path,
532
+ f"{rank}_{KV_CACHE_BYTES_CACHE_FILE_NAME}")
533
+ with open(kv_cache_bytes_file, "r", encoding="utf-8") as f:
534
+ with file_lock(f, fcntl.LOCK_SH):
535
+ kv_cache_bytes = int(f.readline())
536
+ return kv_cache_bytes
537
+
538
+
539
+ @contextmanager
540
+ def file_lock(file_descriptor, lock_type):
541
+ fcntl.flock(file_descriptor, lock_type)
542
+ try:
543
+ yield
544
+ finally:
545
+ fcntl.flock(file_descriptor, fcntl.LOCK_UN)
546
+
547
+
548
+ def write_kv_cache_bytes_to_file(rank, kv_cache_bytes):
549
+ kv_cache_bytes_cache_abs_path = get_torchair_current_work_dir(
550
+ KV_CACHE_BYTES_CACHE_PATH_NAME)
551
+ os.makedirs(kv_cache_bytes_cache_abs_path, exist_ok=True)
552
+ kv_cache_bytes_file = os.path.join(
553
+ kv_cache_bytes_cache_abs_path,
554
+ f"{rank}_{KV_CACHE_BYTES_CACHE_FILE_NAME}")
555
+ with open(kv_cache_bytes_file, "w", encoding="utf-8") as f:
556
+ with file_lock(f, fcntl.LOCK_EX):
557
+ f.write(f"{kv_cache_bytes}")
558
+
559
+
560
+ def delete_torchair_cache_file():
561
+ torch_air_abs_path = get_torchair_current_work_dir()
562
+ if os.path.exists(torch_air_abs_path):
563
+ shutil.rmtree(torch_air_abs_path)
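Illustrative round-trip of the torchair KV-cache byte cache (the value is arbitrary): each rank writes the budget it measured and re-reads it on the next cached start-up.

write_kv_cache_bytes_to_file(rank=0, kv_cache_bytes=1_073_741_824)
if check_kv_cache_bytes_cache_exist():
    budget = read_kv_cache_bytes_from_file(0)   # -> 1073741824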
inference/vllm_ascend/worker/model_runner_v1.py ADDED
The diff for this file is too large to render. See raw diff
 
inference/vllm_ascend/worker/npu_input_batch.py ADDED
@@ -0,0 +1,796 @@
1
+ #
2
+ # Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
3
+ # Copyright 2023 The vLLM team.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ # This file is a part of the vllm-ascend project.
17
+ # Adapted from vllm-project/vllm/vllm/worker/gpu_input_batch.py
18
+ #
19
+
20
+ from dataclasses import dataclass
21
+ from typing import Optional, cast, Union
22
+
23
+ import numpy as np
24
+ import torch
25
+ from vllm.lora.request import LoRARequest
26
+ from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
27
+ from vllm.pooling_params import PoolingParams
28
+ from vllm.sampling_params import SamplingParams, SamplingType
29
+ from vllm.utils import swap_dict_values
30
+ from vllm.v1.outputs import LogprobsTensors
31
+ from vllm.v1.sample.logits_processor import init_builtin_logitsprocs
32
+ from vllm.v1.sample.metadata import SamplingMetadata
33
+ from vllm.v1.spec_decode.utils import is_spec_decode_unsupported
34
+ from vllm.v1.utils import copy_slice
35
+ from vllm.v1.worker.block_table import MultiGroupBlockTable
36
+
37
+ from vllm_ascend.pool.metadata import PoolingMetadata
38
+
39
+ _SAMPLING_EPS = 1e-5
40
+
41
+
42
+ @dataclass
43
+ class CachedRequestState:
44
+
45
+ req_id: str
46
+ prompt_token_ids: list[int]
47
+ mm_inputs: list[MultiModalKwargs]
48
+ mm_positions: list[PlaceholderRange]
49
+ sampling_params: Optional[SamplingParams]
50
+ pooling_params: Optional[PoolingParams]
51
+ generator: Optional[torch.Generator]
52
+
53
+ block_ids: tuple[list[int], ...]
54
+ num_computed_tokens: int
55
+ output_token_ids: list[int]
56
+
57
+ mrope_positions: Optional[torch.Tensor] = None
58
+ mrope_position_delta: Optional[int] = None
59
+
60
+ lora_request: Optional[LoRARequest] = None
61
+
62
+ def __post_init__(self):
63
+ self.num_prompt_tokens = len(self.prompt_token_ids)
64
+
65
+ @property
66
+ def num_tokens(self) -> int:
67
+ return self.num_prompt_tokens + len(self.output_token_ids)
68
+
69
+ def get_token_id(self, idx: int) -> int:
70
+ if idx < self.num_prompt_tokens:
71
+ return self.prompt_token_ids[idx]
72
+ else:
73
+ return self.output_token_ids[idx - self.num_prompt_tokens]
74
+
75
+ @dataclass
76
+ class SamplingMetadataTopNSigma(SamplingMetadata):
77
+ top_n_sigma: torch.Tensor
78
+ no_top_n_sigma: bool
79
+
80
+ class InputBatch:
81
+
82
+ def __init__(
83
+ self,
84
+ max_num_reqs: int,
85
+ max_model_len: int,
86
+ max_num_batched_tokens: int,
87
+ device: torch.device,
88
+ pin_memory: bool,
89
+ vocab_size: int,
90
+ block_sizes: list[int], # The block_size of each kv cache group
91
+ logits_processing_needs_token_ids: bool = False,
92
+ is_spec_decode: bool = False,
93
+ ):
94
+ self.is_spec_decode = is_spec_decode
95
+ self.max_num_reqs = max_num_reqs
96
+ self.max_model_len = max_model_len
97
+ self.max_num_batched_tokens = max_num_batched_tokens
98
+ self.device = device
99
+ self.pin_memory = pin_memory
100
+ self.vocab_size = vocab_size
101
+ self.logits_processing_needs_token_ids = (
102
+ logits_processing_needs_token_ids)
103
+
104
+ self._req_ids: list[Optional[str]] = []
105
+ self.req_id_to_index: dict[str, int] = {}
106
+
107
+ # TODO(woosuk): This buffer could be too large if max_model_len is big.
108
+ # Find a way to reduce the CPU memory usage.
109
+ # This buffer is not directly transferred to the NPU, so it does not
110
+ # need to be pinned.
111
+ self.token_ids_cpu_tensor = torch.zeros(
112
+ (max_num_reqs, max_model_len),
113
+ device="cpu",
114
+ dtype=torch.int32,
115
+ pin_memory=False,
116
+ )
117
+ self.token_ids_cpu = self.token_ids_cpu_tensor.numpy()
118
+ self.num_tokens = np.zeros(max_num_reqs, dtype=np.int32)
119
+ self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int32)
120
+ self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32)
121
+ self.num_computed_tokens_cpu_tensor = torch.zeros(
122
+ (max_num_reqs, ),
123
+ device="cpu",
124
+ dtype=torch.int32,
125
+ pin_memory=pin_memory,
126
+ )
127
+ self.num_computed_tokens_cpu = \
128
+ self.num_computed_tokens_cpu_tensor.numpy()
129
+
130
+ # Block table.
131
+ self.block_table = MultiGroupBlockTable(
132
+ max_num_reqs=max_num_reqs,
133
+ max_model_len=max_model_len,
134
+ max_num_batched_tokens=max_num_batched_tokens,
135
+ pin_memory=pin_memory,
136
+ device=device,
137
+ block_sizes=block_sizes,
138
+ )
139
+
140
+ # Sampling-related.
141
+ self.temperature = torch.empty((max_num_reqs, ),
142
+ dtype=torch.float32,
143
+ device=device)
144
+ self.temperature_cpu_tensor = torch.empty((max_num_reqs, ),
145
+ dtype=torch.float32,
146
+ device="cpu",
147
+ pin_memory=pin_memory)
148
+ self.temperature_cpu = self.temperature_cpu_tensor.numpy()
149
+ self.greedy_reqs: set[str] = set()
150
+ self.random_reqs: set[str] = set()
151
+
152
+ self.top_p = torch.empty((max_num_reqs, ),
153
+ dtype=torch.float32,
154
+ device=device)
155
+ self.top_p_cpu_tensor = torch.empty((max_num_reqs, ),
156
+ dtype=torch.float32,
157
+ device="cpu",
158
+ pin_memory=pin_memory)
159
+ self.top_p_cpu = self.top_p_cpu_tensor.numpy()
160
+ self.top_p_reqs: set[str] = set()
161
+
162
+ self.top_k = torch.empty((max_num_reqs, ),
163
+ dtype=torch.int32,
164
+ device=device)
165
+ self.top_k_cpu_tensor = torch.empty((max_num_reqs, ),
166
+ dtype=torch.int32,
167
+ device="cpu",
168
+ pin_memory=pin_memory)
169
+ self.top_k_cpu = self.top_k_cpu_tensor.numpy()
170
+ self.top_k_reqs: set[str] = set()
171
+
172
+ # IDs of requests which do not support spec decoding
173
+ self.spec_decode_unsupported_reqs: set[str] = set()
174
+
175
+ self.min_p = torch.empty((max_num_reqs, ),
176
+ dtype=torch.float32,
177
+ device=device)
178
+ self.min_p_cpu_tensor = torch.empty((max_num_reqs, ),
179
+ dtype=torch.float32,
180
+ device="cpu",
181
+ pin_memory=pin_memory)
182
+ self.min_p_cpu = self.min_p_cpu_tensor.numpy()
183
+ self.min_p_reqs: set[str] = set()
184
+
185
+ # top-n-sigma sampling parameter
186
+ self.top_n_sigma = torch.empty((max_num_reqs, ),
187
+ dtype=torch.float,
188
+ device=device)
189
+ self.top_n_sigma_cpu_tensor = torch.empty(
190
+ (max_num_reqs, ),
191
+ dtype=torch.float,
192
+ device="cpu",
193
+ pin_memory=pin_memory)
194
+ self.top_n_sigma_cpu = \
195
+ self.top_n_sigma_cpu_tensor.numpy()
196
+ self.top_n_sigma_reqs: set[str] = set()
197
+
198
+ # Frequency penalty related data structures
199
+ self.frequency_penalties = torch.empty((max_num_reqs, ),
200
+ dtype=torch.float,
201
+ device=device)
202
+ self.frequency_penalties_cpu_tensor = torch.empty(
203
+ (max_num_reqs, ),
204
+ dtype=torch.float,
205
+ device="cpu",
206
+ pin_memory=pin_memory)
207
+ self.frequency_penalties_cpu = \
208
+ self.frequency_penalties_cpu_tensor.numpy()
209
+ self.frequency_penalties_reqs: set[str] = set()
210
+
211
+ # Presence penalty related data structures
212
+ self.presence_penalties = torch.empty((max_num_reqs, ),
213
+ dtype=torch.float,
214
+ device=device)
215
+ self.presence_penalties_cpu_tensor = torch.empty((max_num_reqs, ),
216
+ dtype=torch.float,
217
+ device="cpu",
218
+ pin_memory=pin_memory)
219
+ self.presence_penalties_cpu = self.presence_penalties_cpu_tensor.numpy(
220
+ )
221
+ self.presence_penalties_reqs: set[str] = set()
222
+
223
+ # Repetition penalty related data structures
224
+ self.repetition_penalties = torch.empty((max_num_reqs, ),
225
+ dtype=torch.float,
226
+ device=device)
227
+ self.repetition_penalties_cpu_tensor = torch.empty(
228
+ (max_num_reqs, ),
229
+ dtype=torch.float,
230
+ device="cpu",
231
+ pin_memory=pin_memory)
232
+ self.repetition_penalties_cpu = \
233
+ self.repetition_penalties_cpu_tensor.numpy()
234
+ self.repetition_penalties_reqs: set[str] = set()
235
+
236
+ # req_index -> (min_tokens, stop_token_ids)
237
+ self.min_tokens: dict[int, tuple[int, set[int]]] = {}
238
+
239
+ # lora related
240
+ self.request_lora_mapping = np.zeros((self.max_num_reqs, ),
241
+ dtype=np.int32)
242
+ self.lora_id_to_request_ids: dict[int, set[str]] = {}
243
+ self.lora_id_to_lora_request: dict[int, LoRARequest] = {}
244
+
245
+ # req_index -> generator
246
+ # NOTE(woosuk): The indices of the requests that do not have their own
247
+ # generator should not be included in the dictionary.
248
+ self.generators: dict[int, torch.Generator] = {}
249
+
250
+ self.num_logprobs: dict[str, int] = {}
251
+ # NOTE(rob): num_prompt_logprobs only includes reqs
252
+ # that are currently in the prefill phase.
253
+ self.num_prompt_logprobs: dict[str, int] = {}
254
+
255
+ # To accumulate prompt logprobs tensor chunks across prefill steps.
256
+ self.in_progress_prompt_logprobs_cpu: dict[str, LogprobsTensors] = {}
257
+
258
+ self.logit_bias: list[Optional[dict[int,
259
+ float]]] = [None] * max_num_reqs
260
+ self.has_allowed_token_ids: set[str] = set()
261
+ # NOTE(lufang): In the mask tensor, if the corresponding token is allowed,
262
+ # the value is False. Since we use masked_fill_ to set -inf.
263
+ self.allowed_token_ids_mask: Optional[torch.Tensor] = None
264
+ self.allowed_token_ids_mask_cpu_tensor: Optional[torch.Tensor] = None
265
+
266
+ # req_index -> bad_words_token_ids
267
+ self.bad_words_token_ids: dict[int, list[list[int]]] = {}
268
+
269
+ self.req_output_token_ids: list[Optional[list[int]]] = []
270
+
271
+ # Define logits processors.
272
+ # TODO(andy): logits processor list should be extensible via engine
273
+ # constructor argument; for now the list is fixed.
274
+ self.logitsprocs = init_builtin_logitsprocs(
275
+ pin_memory_available=pin_memory,
276
+ max_num_reqs=max_num_reqs + 1,
277
+ device=device)
278
+
279
+ # This is updated each time the batch constituents change.
280
+ self.sampling_metadata = self._make_sampling_metadata()
281
+
282
+ self.pooling_params: dict[str, PoolingParams] = {}
283
+
284
+ @property
285
+ def req_ids(self) -> list[str]:
286
+ # None elements should only be present transiently
287
+ # while performing state updates to the batch.
288
+ return cast(list[str], self._req_ids)
289
+
290
+ def add_request(
291
+ self,
292
+ request: "CachedRequestState",
293
+ req_index: Optional[int] = None,
294
+ ) -> None:
295
+ if req_index is None:
296
+ req_index = self.num_reqs
297
+ assert req_index < self.max_num_reqs
298
+
299
+ req_id = request.req_id
300
+ if req_index == len(self._req_ids):
301
+ self._req_ids.append(req_id)
302
+ self.req_output_token_ids.append(request.output_token_ids)
303
+ else:
304
+ self._req_ids[req_index] = req_id
305
+ self.req_output_token_ids[req_index] = request.output_token_ids
306
+
307
+ self.req_id_to_index[req_id] = req_index
308
+
309
+ # Copy the prompt token ids and output token ids.
310
+ num_prompt_tokens = len(request.prompt_token_ids)
311
+ self.num_prompt_tokens[req_index] = num_prompt_tokens
312
+ self.token_ids_cpu[
313
+ req_index, :num_prompt_tokens] = request.prompt_token_ids
314
+ start_idx = num_prompt_tokens
315
+ end_idx = start_idx + len(request.output_token_ids)
316
+ self.token_ids_cpu[req_index,
317
+ start_idx:end_idx] = request.output_token_ids
318
+ # Number of token ids in token_ids_cpu.
319
+ # NOTE(woosuk): This may include spec decode tokens.
320
+ self.num_tokens[req_index] = request.num_tokens
321
+ # Number of tokens without spec decode tokens.
322
+ self.num_tokens_no_spec[req_index] = request.num_tokens
323
+
324
+ self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
325
+ self.block_table.add_row(request.block_ids, req_index)
326
+
327
+ if sampling_params := request.sampling_params:
328
+ if self.is_spec_decode and is_spec_decode_unsupported(
329
+ sampling_params):
330
+ self.spec_decode_unsupported_reqs.add(req_id)
331
+ if sampling_params.sampling_type == SamplingType.GREEDY:
332
+ # Avoid later division by zero.
333
+ self.temperature_cpu[req_index] = -1.0
334
+ self.greedy_reqs.add(req_id)
335
+ else:
336
+ self.temperature_cpu[req_index] = sampling_params.temperature
337
+ self.random_reqs.add(req_id)
338
+
339
+ self.top_p_cpu[req_index] = sampling_params.top_p
340
+ if sampling_params.top_p < 1:
341
+ self.top_p_reqs.add(req_id)
342
+ top_k = sampling_params.top_k
343
+ if 0 < top_k < self.vocab_size:
344
+ self.top_k_reqs.add(req_id)
345
+ else:
346
+ top_k = self.vocab_size
347
+ self.top_k_cpu[req_index] = top_k
348
+ self.min_p_cpu[req_index] = sampling_params.min_p
349
+ self.frequency_penalties_cpu[
350
+ req_index] = sampling_params.frequency_penalty
351
+ if sampling_params.min_p > _SAMPLING_EPS:
352
+ self.min_p_reqs.add(req_id)
353
+ if sampling_params.frequency_penalty != 0.0:
354
+ self.frequency_penalties_reqs.add(req_id)
355
+ self.presence_penalties_cpu[
356
+ req_index] = sampling_params.presence_penalty
357
+ if sampling_params.presence_penalty != 0.0:
358
+ self.presence_penalties_reqs.add(req_id)
359
+ self.repetition_penalties_cpu[
360
+ req_index] = sampling_params.repetition_penalty
361
+ if sampling_params.repetition_penalty != 1.0:
362
+ self.repetition_penalties_reqs.add(req_id)
363
+ if sampling_params.min_tokens:
364
+ self.min_tokens[req_index] = (
365
+ sampling_params.min_tokens,
366
+ sampling_params.all_stop_token_ids)
367
+
368
+ if sampling_params.extra_args and "top_n_sigma" in sampling_params.extra_args:
369
+ self.top_n_sigma_cpu[
370
+ req_index] = sampling_params.extra_args["top_n_sigma"]
371
+ self.top_n_sigma_reqs.add(req_id)
372
+ else:
373
+ self.top_n_sigma_cpu[req_index] = -1
374
+
375
+ # NOTE(woosuk): self.generators should not include the requests that
376
+ # do not have their own generator.
377
+ if request.generator is not None:
378
+ self.generators[req_index] = request.generator
379
+
380
+ if sampling_params.logprobs is not None:
381
+ self.num_logprobs[req_id] = sampling_params.logprobs
382
+ if sampling_params.prompt_logprobs is not None:
383
+ self.num_prompt_logprobs[
384
+ req_id] = sampling_params.prompt_logprobs
385
+ if sampling_params.logit_bias is not None:
386
+ self.logit_bias[req_index] = sampling_params.logit_bias
387
+
388
+ if sampling_params.allowed_token_ids:
389
+ self.has_allowed_token_ids.add(req_id)
390
+ if self.allowed_token_ids_mask_cpu_tensor is None:
391
+ # Lazy allocation for this tensor, which can be large.
392
+ # False means we don't fill with -inf.
393
+ self.allowed_token_ids_mask = torch.zeros(
394
+ self.max_num_reqs,
395
+ self.vocab_size,
396
+ dtype=torch.bool,
397
+ device=self.device)
398
+ self.allowed_token_ids_mask_cpu_tensor = torch.zeros(
399
+ self.max_num_reqs,
400
+ self.vocab_size,
401
+ dtype=torch.bool,
402
+ device="cpu")
403
+ self.allowed_token_ids_mask_cpu_tensor[req_index] = True
404
+ # False means we don't fill with -inf.
405
+ self.allowed_token_ids_mask_cpu_tensor[req_index][
406
+ sampling_params.allowed_token_ids] = False
407
+
408
+ if sampling_params.bad_words_token_ids:
409
+ self.bad_words_token_ids[
410
+ req_index] = sampling_params.bad_words_token_ids
411
+ else:
412
+ assert request.pooling_params is not None
413
+ self.pooling_params[req_id] = request.pooling_params
414
+
415
+ # Add request lora ID
416
+ if request.lora_request:
417
+ lora_id = request.lora_request.lora_int_id
418
+ if lora_id not in self.lora_id_to_request_ids:
419
+ self.lora_id_to_request_ids[lora_id] = set()
420
+
421
+ self.request_lora_mapping[req_index] = lora_id
422
+ self.lora_id_to_request_ids[lora_id].add(request.req_id)
423
+ self.lora_id_to_lora_request[lora_id] = request.lora_request
424
+ else:
425
+ # No LoRA
426
+ self.request_lora_mapping[req_index] = 0
427
+
428
+ def remove_request(self, req_id: str) -> Optional[int]:
429
+ """This method must always be followed by a call to condense()."""
430
+
431
+ req_index = self.req_id_to_index.pop(req_id, None)
432
+ if req_index is None:
433
+ return None
434
+ self._req_ids[req_index] = None
435
+ self.req_output_token_ids[req_index] = None
436
+
437
+ self.greedy_reqs.discard(req_id)
438
+ self.random_reqs.discard(req_id)
439
+ self.top_p_reqs.discard(req_id)
440
+ self.top_k_reqs.discard(req_id)
441
+ self.min_p_reqs.discard(req_id)
442
+ self.min_tokens.pop(req_index, None)
443
+ self.frequency_penalties_reqs.discard(req_id)
444
+ self.presence_penalties_reqs.discard(req_id)
445
+ self.repetition_penalties_reqs.discard(req_id)
446
+ self.spec_decode_unsupported_reqs.discard(req_id)
447
+ self.top_n_sigma_reqs.discard(req_id)
448
+ self.generators.pop(req_index, None)
449
+ self.num_logprobs.pop(req_id, None)
450
+ self.num_prompt_logprobs.pop(req_id, None)
451
+ self.in_progress_prompt_logprobs_cpu.pop(req_id, None)
452
+
453
+ # LoRA
454
+ lora_id = self.request_lora_mapping[req_index]
455
+ if lora_id != 0:
456
+ self.lora_id_to_request_ids[lora_id].discard(req_id)
457
+ if len(self.lora_id_to_request_ids[lora_id]) == 0:
458
+ self.lora_id_to_request_ids.pop(lora_id)
459
+ self.lora_id_to_lora_request.pop(lora_id)
460
+ self.request_lora_mapping[req_index] = 0
461
+
462
+ self.logit_bias[req_index] = None
463
+ self.has_allowed_token_ids.discard(req_id)
464
+ if self.allowed_token_ids_mask_cpu_tensor is not None:
465
+ # False means we don't fill with -inf.
466
+ self.allowed_token_ids_mask_cpu_tensor[req_index].fill_(False)
467
+ self.bad_words_token_ids.pop(req_index, None)
468
+ self.pooling_params.pop(req_id, None)
469
+ return req_index
470
+
471
+ def swap_states(self, i1: int, i2: int) -> None:
472
+ old_id_i1 = self._req_ids[i1]
473
+ old_id_i2 = self._req_ids[i2]
474
+ self._req_ids[i1], self._req_ids[i2] =\
475
+ self._req_ids[i2], self._req_ids[i1] # noqa
476
+ self.req_output_token_ids[i1], self.req_output_token_ids[i2] =\
477
+ self.req_output_token_ids[i2], self.req_output_token_ids[i1]
478
+ assert old_id_i1 is not None and old_id_i2 is not None
479
+ self.req_id_to_index[old_id_i1], self.req_id_to_index[old_id_i2] =\
480
+ self.req_id_to_index[old_id_i2], self.req_id_to_index[old_id_i1]
481
+ self.num_tokens[i1], self.num_tokens[i2] =\
482
+ self.num_tokens[i2], self.num_tokens[i1]
483
+ self.num_tokens_no_spec[i1], self.num_tokens_no_spec[i2] =\
484
+ self.num_tokens_no_spec[i2], self.num_tokens_no_spec[i1]
485
+ self.num_prompt_tokens[i1], self.num_prompt_tokens[i2] =\
486
+ self.num_prompt_tokens[i2], self.num_prompt_tokens[i1]
487
+ self.num_computed_tokens_cpu[i1], self.num_computed_tokens_cpu[i2] =\
488
+ self.num_computed_tokens_cpu[i2], self.num_computed_tokens_cpu[i1]
489
+ self.temperature_cpu[i1], self.temperature_cpu[i2] =\
490
+ self.temperature_cpu[i2], self.temperature_cpu[i1]
491
+ self.top_p_cpu[i1], self.top_p_cpu[i2] =\
492
+ self.top_p_cpu[i2], self.top_p_cpu[i1]
493
+ self.top_k_cpu[i1], self.top_k_cpu[i2] =\
494
+ self.top_k_cpu[i2], self.top_k_cpu[i1]
495
+ self.frequency_penalties_cpu[i1], self.frequency_penalties_cpu[i2] =\
496
+ self.frequency_penalties_cpu[i2], self.frequency_penalties_cpu[i1]
497
+ self.presence_penalties_cpu[i1], self.presence_penalties_cpu[i2] =\
498
+ self.presence_penalties_cpu[i2], self.presence_penalties_cpu[i1]
499
+ self.repetition_penalties_cpu[i1], self.repetition_penalties_cpu[i2] =\
500
+ self.repetition_penalties_cpu[i2], self.repetition_penalties_cpu[i1]
501
+ self.min_p_cpu[i1], self.min_p_cpu[i2] =\
502
+ self.min_p_cpu[i2], self.min_p_cpu[i1]
503
+ self.top_n_sigma_cpu[i1], self.top_n_sigma_cpu[i2] =\
504
+ self.top_n_sigma_cpu[i2], self.top_n_sigma_cpu[i1]
505
+
506
+ # NOTE: the following is unsafe
507
+ # self.token_ids_cpu[i1, ...], self.token_ids_cpu[i2, ...], =\
508
+ # self.token_ids_cpu[i2, ...], self.token_ids_cpu[i1, ...]
509
+ # instead, we need to temporarily copy the data for one of the indices
510
+ # TODO(lucas): optimize this by only copying valid indices
511
+ tmp = self.token_ids_cpu[i1, ...].copy()
512
+ self.token_ids_cpu[i1, ...] = self.token_ids_cpu[i2, ...]
513
+ self.token_ids_cpu[i2, ...] = tmp
514
+
515
+ swap_dict_values(self.generators, i1, i2)
516
+ swap_dict_values(self.min_tokens, i1, i2)
517
+ swap_dict_values(self.bad_words_token_ids, i1, i2)
518
+
519
+ self.request_lora_mapping[i1], self.request_lora_mapping[i2] =\
520
+ self.request_lora_mapping[i2], self.request_lora_mapping[i1]
521
+ self.logit_bias[i1], self.logit_bias[i2] =\
522
+ self.logit_bias[i2], self.logit_bias[i1]
523
+
524
+ if self.allowed_token_ids_mask_cpu_tensor is not None:
525
+ self.allowed_token_ids_mask_cpu_tensor[i1], \
526
+ self.allowed_token_ids_mask_cpu_tensor[i2] =\
527
+ self.allowed_token_ids_mask_cpu_tensor[i2], \
528
+ self.allowed_token_ids_mask_cpu_tensor[i1]
529
+ self.block_table.swap_row(i1, i2)
530
+
531
+ def condense(self, empty_req_indices: list[int]) -> None:
532
+ """Move non-empty requests down into lower, empty indices.
533
+
534
+ Args:
535
+ empty_req_indices: empty batch indices, sorted descending.
536
+ """
537
+ num_reqs = self.num_reqs
538
+ if num_reqs == 0:
539
+ # The batched states are empty.
540
+ self._req_ids.clear()
541
+ self.req_output_token_ids.clear()
542
+ return
543
+
544
+ # NOTE(woosuk): This function assumes that the empty_req_indices
545
+ # is sorted in descending order.
546
+ last_req_index = num_reqs + len(empty_req_indices) - 1
547
+ while empty_req_indices:
548
+ # Find the largest non-empty index.
549
+ while last_req_index in empty_req_indices:
550
+ last_req_index -= 1
551
+
552
+ # Find the smallest empty index.
553
+ empty_index = empty_req_indices.pop()
554
+ if empty_index >= last_req_index:
555
+ break
556
+
557
+ # Swap the states.
558
+ req_id = self._req_ids[last_req_index]
559
+ output_token_ids = self.req_output_token_ids[last_req_index]
560
+ assert req_id is not None
561
+ self._req_ids[empty_index] = req_id
562
+ self._req_ids[last_req_index] = None
563
+ self.req_output_token_ids[empty_index] = output_token_ids
564
+ self.req_output_token_ids[last_req_index] = None
565
+ self.req_id_to_index[req_id] = empty_index
566
+
567
+ num_tokens = self.num_tokens[last_req_index]
568
+ self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[
569
+ last_req_index, :num_tokens]
570
+ self.num_tokens[empty_index] = num_tokens
571
+ self.num_tokens_no_spec[empty_index] = self.num_tokens_no_spec[
572
+ last_req_index]
573
+ self.num_prompt_tokens[empty_index] = self.num_prompt_tokens[
574
+ last_req_index]
575
+ self.num_computed_tokens_cpu[
576
+ empty_index] = self.num_computed_tokens_cpu[last_req_index]
577
+ self.block_table.move_row(last_req_index, empty_index)
578
+ self.temperature_cpu[empty_index] = self.temperature_cpu[
579
+ last_req_index]
580
+ self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index]
581
+ self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index]
582
+ self.frequency_penalties_cpu[
583
+ empty_index] = self.frequency_penalties_cpu[last_req_index]
584
+ self.presence_penalties_cpu[
585
+ empty_index] = self.presence_penalties_cpu[last_req_index]
586
+ self.repetition_penalties_cpu[
587
+ empty_index] = self.repetition_penalties_cpu[last_req_index]
588
+ self.min_p_cpu[empty_index] = self.min_p_cpu[last_req_index]
589
+ self.top_n_sigma_cpu[
590
+ empty_index] = self.top_n_sigma_cpu[last_req_index]
591
+ generator = self.generators.pop(last_req_index, None)
592
+ if generator is not None:
593
+ self.generators[empty_index] = generator
594
+
595
+ min_token = self.min_tokens.pop(last_req_index, None)
596
+ if min_token is not None:
597
+ self.min_tokens[empty_index] = min_token
598
+
599
+ self.request_lora_mapping[empty_index] = self.request_lora_mapping[
600
+ last_req_index]
601
+
602
+ self.logit_bias[empty_index] = self.logit_bias[last_req_index]
603
+
604
+ if self.allowed_token_ids_mask_cpu_tensor is not None:
605
+ self.allowed_token_ids_mask_cpu_tensor[
606
+ empty_index] = self.allowed_token_ids_mask_cpu_tensor[
607
+ last_req_index]
608
+
609
+ bad_words_token_ids = self.bad_words_token_ids.pop(
610
+ last_req_index, None)
611
+ if bad_words_token_ids is not None:
612
+ self.bad_words_token_ids[empty_index] = bad_words_token_ids
613
+ # Decrement last_req_index since it is now empty.
614
+ last_req_index -= 1
615
+
616
+ # Trim lists to the batch size.
617
+ del self._req_ids[self.num_reqs:]
618
+ del self.req_output_token_ids[self.num_reqs:]
619
+
620
+ def refresh_sampling_metadata(self):
621
+ self.sampling_metadata = self._make_sampling_metadata()
622
+
623
+ def _make_sampling_metadata(self) -> Union[SamplingMetadata, SamplingMetadataTopNSigma]:
624
+ num_reqs = self.num_reqs
625
+ if not self.all_greedy:
626
+ temperature = copy_slice(self.temperature_cpu_tensor,
627
+ self.temperature, num_reqs)
628
+ else:
629
+ temperature = None
630
+ if not self.no_top_p:
631
+ copy_slice(self.top_p_cpu_tensor, self.top_p, num_reqs)
632
+ if not self.no_top_k:
633
+ copy_slice(self.top_k_cpu_tensor, self.top_k, num_reqs)
634
+ if not self.no_min_p:
635
+ copy_slice(self.min_p_cpu_tensor, self.min_p, num_reqs)
636
+
637
+ if not self.no_penalties:
638
+ # Since syncing these tensors is expensive, only copy them
639
+ # if necessary i.e. if there are requests which require
640
+ # penalties to be applied during sampling.
641
+ copy_slice(self.frequency_penalties_cpu_tensor,
642
+ self.frequency_penalties, num_reqs)
643
+ copy_slice(self.presence_penalties_cpu_tensor,
644
+ self.presence_penalties, num_reqs)
645
+ copy_slice(self.repetition_penalties_cpu_tensor,
646
+ self.repetition_penalties, num_reqs)
647
+
648
+ if not self.no_top_n_sigma:
649
+ copy_slice(self.top_n_sigma_cpu_tensor,
650
+ self.top_n_sigma, num_reqs)
651
+
652
+
653
+ needs_prompt_token_ids = (not self.no_penalties or
654
+ (self.num_reqs > 0
655
+ and self.logits_processing_needs_token_ids))
656
+ if needs_prompt_token_ids:
657
+ # The prompt tokens are used only for applying penalties or
658
+ # step pooling during the sampling/pooling process.
659
+ # Hence copy these tensors only when there are requests which
660
+ # need penalties/step_pooler to be applied.
661
+ prompt_token_ids = self._make_prompt_token_ids_tensor()
662
+ else:
663
+ prompt_token_ids = None
664
+
665
+ allowed_token_ids_mask: Optional[torch.Tensor] = None
666
+ if not self.no_allowed_token_ids:
667
+ assert self.allowed_token_ids_mask is not None
668
+ copy_slice(self.allowed_token_ids_mask_cpu_tensor,
669
+ self.allowed_token_ids_mask, num_reqs)
670
+ allowed_token_ids_mask = self.allowed_token_ids_mask[:num_reqs]
671
+
672
+ return SamplingMetadataTopNSigma(
673
+ temperature=temperature,
674
+ all_greedy=self.all_greedy,
675
+ all_random=self.all_random,
676
+ top_p=None if self.no_top_p else self.top_p[:num_reqs],
677
+ top_k=None if self.no_top_k else self.top_k[:num_reqs],
678
+ generators=self.generators,
679
+ max_num_logprobs=self.max_num_logprobs,
680
+ prompt_token_ids=prompt_token_ids,
681
+ frequency_penalties=self.frequency_penalties[:num_reqs],
682
+ presence_penalties=self.presence_penalties[:num_reqs],
683
+ repetition_penalties=self.repetition_penalties[:num_reqs],
684
+ top_n_sigma=self.top_n_sigma[:num_reqs],
685
+ output_token_ids=cast(list[list[int]], self.req_output_token_ids),
686
+ no_penalties=self.no_penalties,
687
+ no_top_n_sigma=self.no_top_n_sigma,
688
+ allowed_token_ids_mask=allowed_token_ids_mask,
689
+ bad_words_token_ids=self.bad_words_token_ids,
690
+ logitsprocs=self.logitsprocs,
691
+ )
692
+
693
+ @property
694
+ def pooling_metadata(self) -> PoolingMetadata:
695
+ if len(self.pooling_params) == 0:
696
+ pooling_params = []
697
+ else:
698
+ # Note: for now this assumes that all requests in the batch
699
+ # are either sampling or pooling requests
700
+ assert len(self.req_ids) == len(self.pooling_params)
701
+ pooling_params = [
702
+ self.pooling_params[req_id] for req_id in self.req_ids
703
+ ]
704
+
705
+ return PoolingMetadata(
706
+ prompt_lens=torch.from_numpy(
707
+ self.num_prompt_tokens[:self.num_reqs]).to(self.device),
708
+ prompt_token_ids=self.sampling_metadata.prompt_token_ids,
709
+ pooling_params=pooling_params,
710
+ )
711
+
712
+ def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
713
+ max_prompt_len = self.num_prompt_tokens[:self.num_reqs].max()
714
+ prompt_token_ids_cpu_tensor = torch.empty(
715
+ (self.num_reqs, max_prompt_len),
716
+ device="cpu",
717
+ dtype=torch.int64,
718
+ pin_memory=self.pin_memory,
719
+ )
720
+ prompt_token_ids = prompt_token_ids_cpu_tensor.numpy()
721
+ prompt_token_ids[:] = self.token_ids_cpu[:self.
722
+ num_reqs, :max_prompt_len]
723
+ # Use the value of vocab_size as a pad since we don't have a
724
+ # token_id of this value.
725
+ for i in range(self.num_reqs):
726
+ prompt_token_ids[i, self.num_prompt_tokens[i]:] = self.vocab_size
727
+ return prompt_token_ids_cpu_tensor.to(device=self.device,
728
+ non_blocking=True)
729
+
730
+ def make_lora_inputs(
731
+ self, num_scheduled_tokens: np.ndarray
732
+ ) -> tuple[tuple[int, ...], tuple[int, ...], set[LoRARequest]]:
733
+ """
734
+ Given the num_scheduled_tokens for each request in the batch, return
735
+ data structures used to activate the current LoRAs.
736
+ Returns:
737
+ 1. prompt_lora_mapping: A tuple of size self.num_reqs where,
738
+ prompt_lora_mapping[i] is the LoRA id to use for the ith prompt.
739
+ 2. token_lora_mapping: A tuple of size np.sum(num_scheduled_tokens)
740
+ where, token_lora_mapping[i] is the LoRA id to use for the ith token.
741
+ 3. lora_requests: Set of relevant LoRA requests.
742
+ """
743
+
744
+ req_lora_mapping = self.request_lora_mapping[:self.num_reqs]
745
+ prompt_lora_mapping = tuple(req_lora_mapping)
746
+ token_lora_mapping = tuple(
747
+ req_lora_mapping.repeat(num_scheduled_tokens))
748
+ active_lora_requests: set[LoRARequest] = set(
749
+ self.lora_id_to_lora_request.values())
750
+
751
+ return prompt_lora_mapping, token_lora_mapping, active_lora_requests
752
+
753
+ @property
754
+ def num_reqs(self) -> int:
755
+ return len(self.req_id_to_index)
756
+
757
+ @property
758
+ def all_greedy(self) -> bool:
759
+ return len(self.random_reqs) == 0
760
+
761
+ @property
762
+ def all_random(self) -> bool:
763
+ return len(self.greedy_reqs) == 0
764
+
765
+ @property
766
+ def no_top_p(self) -> bool:
767
+ return len(self.top_p_reqs) == 0
768
+
769
+ @property
770
+ def no_top_k(self) -> bool:
771
+ return len(self.top_k_reqs) == 0
772
+
773
+ @property
774
+ def no_min_p(self) -> bool:
775
+ return len(self.min_p_reqs) == 0
776
+
777
+ @property
778
+ def no_penalties(self) -> bool:
779
+ return (len(self.presence_penalties_reqs) == 0
780
+ and len(self.frequency_penalties_reqs) == 0
781
+ and len(self.repetition_penalties_reqs) == 0)
782
+ @property
783
+ def no_top_n_sigma(self) -> bool:
784
+ return len(self.top_n_sigma_reqs) == 0
785
+
786
+ @property
787
+ def max_num_logprobs(self) -> Optional[int]:
788
+ return max(self.num_logprobs.values()) if self.num_logprobs else None
789
+
790
+ @property
791
+ def no_prompt_logprob(self) -> bool:
792
+ return not self.num_prompt_logprobs
793
+
794
+ @property
795
+ def no_allowed_token_ids(self) -> bool:
796
+ return len(self.has_allowed_token_ids) == 0
inference/vllm_ascend_for_openpangu_embedded_7b.md ADDED
@@ -0,0 +1,129 @@
1
+ ## Deployment Guide for openPangu Embedded 7B Based on [vllm-ascend](https://github.com/vllm-project/vllm-ascend)
2
+
3
+ ### Deployment Environment Description
4
+
5
+ The Atlas 800T A2 (64 GB) supports deployment of openPangu Embedded 7B (bf16) on 4 cards. The vllm-ascend community image v0.9.1-dev is used and needs to be pulled on every node involved in the deployment.
6
+ ```bash
7
+ docker pull quay.io/ascend/vllm-ascend:v0.9.1-dev
8
+ ```
9
+
10
+ ### Docker Startup and Inference Code Adaptation
11
+
12
+ Perform the following operations on all nodes.
13
+
14
+ Run the following command to start the Docker container:
15
+ ```bash
16
+ # Update the vllm-ascend image
17
+ export IMAGE=quay.io/ascend/vllm-ascend:v0.9.1-dev # Use correct image id
18
+ export NAME=vllm-ascend # Custom docker name
19
+
20
+ # Run the container using the defined variables
21
+ # Note: if you are running Docker with a bridge network, please expose the ports needed for multi-node communication in advance
22
+ # To prevent device interference from other docker containers, add the argument "--privileged"
23
+ docker run --rm \
24
+ --name $NAME \
25
+ --network host \
26
+ --device /dev/davinci0 \
27
+ --device /dev/davinci1 \
28
+ --device /dev/davinci2 \
29
+ --device /dev/davinci3 \
30
+ --device /dev/davinci4 \
31
+ --device /dev/davinci5 \
32
+ --device /dev/davinci6 \
33
+ --device /dev/davinci7 \
34
+ --device /dev/davinci_manager \
35
+ --device /dev/devmm_svm \
36
+ --device /dev/hisi_hdc \
37
+ -v /usr/local/dcmi:/usr/local/dcmi \
38
+ -v /usr/local/Ascend/driver/tools/hccn_tool:/usr/local/Ascend/driver/tools/hccn_tool \
39
+ -v /usr/local/bin/npu-smi:/usr/local/bin/npu-smi \
40
+ -v /usr/local/Ascend/driver/lib64/:/usr/local/Ascend/driver/lib64/ \
41
+ -v /usr/local/Ascend/driver/version.info:/usr/local/Ascend/driver/version.info \
42
+ -v /etc/ascend_install.info:/etc/ascend_install.info \
43
+ -v /mnt/sfs_turbo/.cache:/root/.cache \
44
+ -it $IMAGE bash
45
+ ```
46
+ If not inside the container, enter the container as the root user:
47
+ ```bash
48
+ docker exec -itu root $NAME /bin/bash
49
+ ```
50
+
51
+ Install vllm (v0.9.2) to replace the built-in vllm code in the image.
52
+ ```bash
53
+ pip install --no-deps vllm==0.9.2 pybase64==1.4.1
54
+ ```
55
+
56
+ Download [vllm-ascend (v0.9.2rc1)](https://github.com/vllm-project/vllm-ascend/releases/tag/v0.9.2rc1) and replace the built-in vllm-ascend code in the image (/vllm-workspace/vllm-ascend/). For example, download [Source code (tar.gz)](https://github.com/vllm-project/vllm-ascend/archive/refs/tags/v0.9.2rc1.tar.gz) from Assets to get v0.9.2rc1.tar.gz, then extract and replace:
57
+
58
+ ```bash
59
+ tar -zxvf vllm-ascend-0.9.2rc1.tar.gz -C /vllm-workspace/vllm-ascend/ --strip-components=1
60
+ export PYTHONPATH=/vllm-workspace/vllm-ascend/:${PYTHONPATH}
61
+ ```
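+
+ As a quick sanity check that the replacement code is the one being picked up (a minimal sketch; it assumes the `pip install` and `PYTHONPATH` steps above were done in the same shell), the resolved version and path can be printed:
+
+ ```python
+ # Verify that vllm 0.9.2 and the patched vllm-ascend resolve as expected.
+ import vllm
+ import vllm_ascend
+
+ print("vllm version:", vllm.__version__)          # expected: 0.9.2
+ print("vllm_ascend path:", vllm_ascend.__file__)  # expected: under /vllm-workspace/vllm-ascend/
+ ```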
62
+
63
+ Use the Pangu model-adapted vllm-ascend code from the current repository to replace parts of the code in `/vllm-workspace/vllm-ascend/vllm_ascend/`:
64
+
65
+ ```bash
66
+ yes | cp -r inference/vllm_ascend/* /vllm-workspace/vllm-ascend/vllm_ascend/
67
+ ```
68
+
69
+ Replace the `tokenizer_config.json` file with the version that adds the `special token` entries: [old file](../tokenizer_config.json) -> [new file](./vllm_ascend/tokenizer_config.json)
70
+ ```bash
71
+ cp ./vllm_ascend/tokenizer_config.json ../tokenizer_config.json
72
+ ```
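+
+ To confirm the updated `tokenizer_config.json` is picked up, the tokenizer can be loaded directly with `transformers` (a minimal sketch; the checkpoint directory below is an assumed example and should match the `LOCAL_CKPT_DIR` used in the next section):
+
+ ```python
+ from transformers import AutoTokenizer
+
+ # Assumed local checkpoint directory; point this at the actual bf16 weights.
+ ckpt_dir = "/root/.cache/pangu_embedded_7b"
+
+ # use_fast=False mirrors the "--tokenizer-mode slow" setting used when serving.
+ tokenizer = AutoTokenizer.from_pretrained(ckpt_dir, trust_remote_code=True, use_fast=False)
+ print(tokenizer("Who are you?").input_ids)
+ ```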
73
+
74
+ ### openPangu Embedded 7B Inference
75
+
76
+ Perform the following operations on all nodes.
77
+
78
+ Configuration:
79
+ ```bash
80
+ export VLLM_USE_V1=1
81
+ # Specifying HOST=127.0.0.1 (localhost) means the server can only be accessed from the master device.
82
+ # Specifying HOST=0.0.0.0 allows the vLLM server to be accessed from other devices on the same network or even from the internet, provided proper network configuration (e.g., firewall rules, port forwarding) is in place.
83
+ HOST=xxx.xxx.xxx.xxx
84
+ PORT=8080
85
+ ```
86
+
87
+ openPangu Embedded 7B running command:
88
+ ```bash
89
+ export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3
90
+ LOCAL_CKPT_DIR=/root/.cache/pangu_embedded_7b # The pangu_embedded_7b bf16 weight
91
+ SERVED_MODEL_NAME=pangu_embedded_7b
92
+
93
+ vllm serve $LOCAL_CKPT_DIR \
94
+ --served-model-name $SERVED_MODEL_NAME \
95
+ --tensor-parallel-size 4 \
96
+ --trust-remote-code \
97
+ --host $HOST \
98
+ --port $PORT \
99
+ --max-num-seqs 32 \
100
+ --max-model-len 32768 \
101
+ --max-num-batched-tokens 4096 \
102
+ --tokenizer-mode "slow" \
103
+ --dtype bfloat16 \
104
+ --distributed-executor-backend mp \
105
+ --gpu-memory-utilization 0.93 \
106
+ --no-enable-prefix-caching \
107
+ --no-enable-chunked-prefill
108
+ ```
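+
+ Before sending requests, you can wait for the server to come up by polling the OpenAI-compatible model list endpoint (a minimal sketch using only the Python standard library; replace the base URL with the HOST/PORT configured above):
+
+ ```python
+ import json
+ import time
+ import urllib.request
+
+ # Assumed to match the HOST/PORT exported before launching the server.
+ base_url = "http://127.0.0.1:8080"
+
+ for _ in range(60):
+     try:
+         with urllib.request.urlopen(f"{base_url}/v1/models", timeout=5) as resp:
+             models = json.load(resp)
+             print("Server ready, serving:", [m["id"] for m in models["data"]])
+             break
+     except OSError:
+         time.sleep(5)  # server is still starting up
+ else:
+     raise RuntimeError("vLLM server did not become ready in time")
+ ```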
109
+
110
+ ### Test Request
111
+
112
+ After the server is launched, send a test request from the master node or another node:
113
+
114
+ ```bash
115
+ MASTER_NODE_IP=xxx.xxx.xxx.xxx # server node ip
116
+ curl http://${MASTER_NODE_IP}:${PORT}/v1/chat/completions \
117
+ -H "Content-Type: application/json" \
118
+ -d '{
119
+ "model": "'$SERVED_MODEL_NAME'",
120
+ "messages": [
121
+ {
122
+ "role": "user",
123
+ "content": "Who are you?"
124
+ }
125
+ ],
126
+ "max_tokens": 512,
127
+ "temperature": 0
128
+ }'
129
+ ```
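+
+ Since the server exposes an OpenAI-compatible API, the same request can also be sent from Python with the `openai` client (a minimal sketch; it assumes `pip install openai` on the client side, and that the IP, port, and served model name match the values above):
+
+ ```python
+ from openai import OpenAI
+
+ # Assumed values; replace with the actual master node IP, port, and served model name.
+ client = OpenAI(base_url="http://xxx.xxx.xxx.xxx:8080/v1", api_key="EMPTY")
+
+ response = client.chat.completions.create(
+     model="pangu_embedded_7b",
+     messages=[{"role": "user", "content": "Who are you?"}],
+     max_tokens=512,
+     temperature=0,
+ )
+ print(response.choices[0].message.content)
+ ```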
inference/vllm_ascend_for_openpangu_embedded_7b.zh.md ADDED
@@ -0,0 +1,129 @@
1
+ ## Deployment Guide for openPangu Embedded 7B on [vllm-ascend](https://github.com/vllm-project/vllm-ascend)
2
+
3
+ ### Deployment Environment Description
4
+
5
+ The Atlas 800T A2 (64 GB) supports deployment of openPangu Embedded 7B (bf16) on 4 cards, using the vllm-ascend community image v0.9.1-dev.
6
+ ```bash
7
+ docker pull quay.io/ascend/vllm-ascend:v0.9.1-dev
8
+ ```
9
+
10
+ ### Image Startup and Inference Code Adaptation
11
+
12
+ Perform the following operations on all nodes.
13
+
14
+ Start the container:
15
+ ```bash
16
+ # Update the vllm-ascend image
17
+ export IMAGE=quay.io/ascend/vllm-ascend:v0.9.1-dev # Use correct image id
18
+ export NAME=vllm-ascend # Custom docker name
19
+
20
+ # Run the container using the defined variables
21
+ # Note: if you are running Docker with a bridge network, please expose the ports needed for multi-node communication in advance
22
+ # To prevent device interference from other docker containers, add the argument "--privileged"
23
+ docker run --rm \
24
+ --name $NAME \
25
+ --network host \
26
+ --device /dev/davinci0 \
27
+ --device /dev/davinci1 \
28
+ --device /dev/davinci2 \
29
+ --device /dev/davinci3 \
30
+ --device /dev/davinci4 \
31
+ --device /dev/davinci5 \
32
+ --device /dev/davinci6 \
33
+ --device /dev/davinci7 \
34
+ --device /dev/davinci_manager \
35
+ --device /dev/devmm_svm \
36
+ --device /dev/hisi_hdc \
37
+ -v /usr/local/dcmi:/usr/local/dcmi \
38
+ -v /usr/local/Ascend/driver/tools/hccn_tool:/usr/local/Ascend/driver/tools/hccn_tool \
39
+ -v /usr/local/bin/npu-smi:/usr/local/bin/npu-smi \
40
+ -v /usr/local/Ascend/driver/lib64/:/usr/local/Ascend/driver/lib64/ \
41
+ -v /usr/local/Ascend/driver/version.info:/usr/local/Ascend/driver/version.info \
42
+ -v /etc/ascend_install.info:/etc/ascend_install.info \
43
+ -v /mnt/sfs_turbo/.cache:/root/.cache \
44
+ -it $IMAGE bash
45
+ ```
46
+
47
+ If not inside the container, enter the container as the root user:
48
+ ```bash
49
+ docker exec -itu root $NAME /bin/bash
50
+ ```
51
+
52
+ Install vllm (v0.9.2) to replace the built-in vllm code in the image.
53
+ ```bash
54
+ pip install --no-deps vllm==0.9.2 pybase64==1.4.1
55
+ ```
56
+
57
+ Download [vllm-ascend (v0.9.2rc1)](https://github.com/vllm-project/vllm-ascend/releases/tag/v0.9.2rc1) and replace the built-in vllm-ascend code in the image (`/vllm-workspace/vllm-ascend/`). For example, download [Source code
58
+ (tar.gz)](https://github.com/vllm-project/vllm-ascend/archive/refs/tags/v0.9.2rc1.tar.gz) from Assets to get v0.9.2rc1.tar.gz, then extract and replace:
59
+ ```bash
60
+ tar -zxvf vllm-ascend-0.9.2rc1.tar.gz -C /vllm-workspace/vllm-ascend/ --strip-components=1
61
+ export PYTHONPATH=/vllm-workspace/vllm-ascend/:${PYTHONPATH}
62
+ ```
63
+
64
+ Use the Pangu model-adapted vllm-ascend code from the current repository to replace parts of the code in `/vllm-workspace/vllm-ascend/vllm_ascend/`:
65
+ ```bash
66
+ yes | cp -r inference/vllm_ascend/* /vllm-workspace/vllm-ascend/vllm_ascend/
67
+ ```
68
+
69
+ Replace the `tokenizer_config.json` file with the version that adds the `special token` entries: [old file](../tokenizer_config.json) -> [new file](./vllm_ascend/tokenizer_config.json)
70
+ ```bash
71
+ cp ./vllm_ascend/tokenizer_config.json ../tokenizer_config.json
72
+ ```
73
+
74
+ ### openPangu Embedded 7B Inference
75
+
76
+ Perform the following operations on all nodes.
77
+
78
+ Configuration:
79
+ ```bash
80
+ export VLLM_USE_V1=1
81
+ # Specifying HOST=127.0.0.1 (localhost) means the server can only be accessed from the master device.
82
+ # Specifying HOST=0.0.0.0 allows the vLLM server to be accessed from other devices on the same network or even from the internet, provided proper network configuration (e.g., firewall rules, port forwarding) is in place.
83
+ HOST=xxx.xxx.xxx.xxx
84
+ PORT=8080
85
+ ```
86
+
87
+ openPangu Embedded 7B running command:
88
+ ```bash
89
+ export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3
90
+ LOCAL_CKPT_DIR=/root/.cache/pangu_embedded_7b # The pangu_embedded_7b bf16 weight
91
+ SERVED_MODEL_NAME=pangu_embedded_7b
92
+
93
+ vllm serve $LOCAL_CKPT_DIR \
94
+ --served-model-name $SERVED_MODEL_NAME \
95
+ --tensor-parallel-size 4 \
96
+ --trust-remote-code \
97
+ --host $HOST \
98
+ --port $PORT \
99
+ --max-num-seqs 32 \
100
+ --max-model-len 32768 \
101
+ --max-num-batched-tokens 4096 \
102
+ --tokenizer-mode "slow" \
103
+ --dtype bfloat16 \
104
+ --distributed-executor-backend mp \
105
+ --gpu-memory-utilization 0.93 \
106
+ --no-enable-prefix-caching \
107
+ --no-enable-chunked-prefill
108
+ ```
109
+
110
+ ### Test Request
111
+
112
+ After the server is launched, send a test request to the master node, from the master node itself or from another node:
113
+
114
+ ```bash
115
+ MASTER_NODE_IP=xxx.xxx.xxx.xxx # server node ip
116
+ curl http://${MASTER_NODE_IP}:${PORT}/v1/chat/completions \
117
+ -H "Content-Type: application/json" \
118
+ -d '{
119
+ "model": "'$SERVED_MODEL_NAME'",
120
+ "messages": [
121
+ {
122
+ "role": "user",
123
+ "content": "Who are you?"
124
+ }
125
+ ],
126
+ "max_tokens": 512,
127
+ "temperature": 0
128
+ }'
129
+ ```
model-00001-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8af395f1ed0b7b4ca70828c4d87db0cc72e3e5720c3d462265dcc87a977fd1dd
3
+ size 4926842416
model-00002-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fb8081a9578a22fb500f851dcfd5af07029af73516663b589724d09215dc7edc
3
+ size 4991686576
model-00003-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ea5b8e7ac544c08954a47e1a98ea7b5c6fc26f3c12c39469159896158fa43e9c
3
+ size 4886853760
model-00004-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b65417ef013c072961910a91db017836bbfbf300422fd3985652c7807f45adc9
3
+ size 1256456320
model.safetensors.index.json ADDED
@@ -0,0 +1,486 @@
1
+ {
2
+ "metadata": {
3
+ "total_size": 16061784576
4
+ },
5
+ "weight_map": {
6
+ "lm_head.weight": "model-00004-of-00004.safetensors",
7
+ "model.embed_tokens.weight": "model-00001-of-00004.safetensors",
8
+ "model.layers.0.input_layernorm.weight": "model-00001-of-00004.safetensors",
9
+ "model.layers.0.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
10
+ "model.layers.0.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
11
+ "model.layers.0.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
12
+ "model.layers.0.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
13
+ "model.layers.0.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
14
+ "model.layers.0.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
15
+ "model.layers.0.self_attn.o_proj.bias": "model-00001-of-00004.safetensors",
16
+ "model.layers.0.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
17
+ "model.layers.0.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
18
+ "model.layers.0.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
19
+ "model.layers.0.self_attn.rotary_emb.inv_freq": "model-00001-of-00004.safetensors",
20
+ "model.layers.0.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
21
+ "model.layers.0.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
22
+ "model.layers.1.input_layernorm.weight": "model-00001-of-00004.safetensors",
23
+ "model.layers.1.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
24
+ "model.layers.1.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
25
+ "model.layers.1.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
26
+ "model.layers.1.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
27
+ "model.layers.1.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
28
+ "model.layers.1.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
29
+ "model.layers.1.self_attn.o_proj.bias": "model-00001-of-00004.safetensors",
30
+ "model.layers.1.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
31
+ "model.layers.1.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
32
+ "model.layers.1.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
33
+ "model.layers.1.self_attn.rotary_emb.inv_freq": "model-00001-of-00004.safetensors",
34
+ "model.layers.1.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
35
+ "model.layers.1.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
36
+ "model.layers.10.input_layernorm.weight": "model-00002-of-00004.safetensors",
37
+ "model.layers.10.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
38
+ "model.layers.10.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
39
+ "model.layers.10.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
40
+ "model.layers.10.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
41
+ "model.layers.10.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
42
+ "model.layers.10.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
43
+ "model.layers.10.self_attn.o_proj.bias": "model-00002-of-00004.safetensors",
44
+ "model.layers.10.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
45
+ "model.layers.10.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
46
+ "model.layers.10.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
47
+ "model.layers.10.self_attn.rotary_emb.inv_freq": "model-00002-of-00004.safetensors",
48
+ "model.layers.10.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
49
+ "model.layers.10.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
50
+ "model.layers.11.input_layernorm.weight": "model-00002-of-00004.safetensors",
51
+ "model.layers.11.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
52
+ "model.layers.11.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
53
+ "model.layers.11.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
54
+ "model.layers.11.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
55
+ "model.layers.11.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
56
+ "model.layers.11.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
57
+ "model.layers.11.self_attn.o_proj.bias": "model-00002-of-00004.safetensors",
58
+ "model.layers.11.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
59
+ "model.layers.11.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
60
+ "model.layers.11.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
61
+ "model.layers.11.self_attn.rotary_emb.inv_freq": "model-00002-of-00004.safetensors",
62
+ "model.layers.11.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
63
+ "model.layers.11.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
64
+ "model.layers.12.input_layernorm.weight": "model-00002-of-00004.safetensors",
65
+ "model.layers.12.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
66
+ "model.layers.12.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
67
+ "model.layers.12.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
68
+ "model.layers.12.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
69
+ "model.layers.12.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
70
+ "model.layers.12.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
71
+ "model.layers.12.self_attn.o_proj.bias": "model-00002-of-00004.safetensors",
72
+ "model.layers.12.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
73
+ "model.layers.12.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
74
+ "model.layers.12.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
75
+ "model.layers.12.self_attn.rotary_emb.inv_freq": "model-00002-of-00004.safetensors",
76
+ "model.layers.12.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
77
+ "model.layers.12.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
78
+ "model.layers.13.input_layernorm.weight": "model-00002-of-00004.safetensors",
79
+ "model.layers.13.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
80
+ "model.layers.13.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
81
+ "model.layers.13.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
82
+ "model.layers.13.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
83
+ "model.layers.13.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
84
+ "model.layers.13.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
85
+ "model.layers.13.self_attn.o_proj.bias": "model-00002-of-00004.safetensors",
86
+ "model.layers.13.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
87
+ "model.layers.13.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
88
+ "model.layers.13.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
89
+ "model.layers.13.self_attn.rotary_emb.inv_freq": "model-00002-of-00004.safetensors",
90
+ "model.layers.13.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
91
+ "model.layers.13.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
92
+ "model.layers.14.input_layernorm.weight": "model-00002-of-00004.safetensors",
93
+ "model.layers.14.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
94
+ "model.layers.14.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
95
+ "model.layers.14.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
96
+ "model.layers.14.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
97
+ "model.layers.14.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
98
+ "model.layers.14.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
99
+ "model.layers.14.self_attn.o_proj.bias": "model-00002-of-00004.safetensors",
100
+ "model.layers.14.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
101
+ "model.layers.14.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
102
+ "model.layers.14.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
103
+ "model.layers.14.self_attn.rotary_emb.inv_freq": "model-00002-of-00004.safetensors",
104
+ "model.layers.14.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
105
+ "model.layers.14.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
106
+ "model.layers.15.input_layernorm.weight": "model-00002-of-00004.safetensors",
107
+ "model.layers.15.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
108
+ "model.layers.15.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
109
+ "model.layers.15.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
110
+ "model.layers.15.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
111
+ "model.layers.15.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
112
+ "model.layers.15.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
113
+ "model.layers.15.self_attn.o_proj.bias": "model-00002-of-00004.safetensors",
114
+ "model.layers.15.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
115
+ "model.layers.15.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
116
+ "model.layers.15.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
117
+ "model.layers.15.self_attn.rotary_emb.inv_freq": "model-00002-of-00004.safetensors",
118
+ "model.layers.15.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
119
+ "model.layers.15.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
120
+ "model.layers.16.input_layernorm.weight": "model-00002-of-00004.safetensors",
121
+ "model.layers.16.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
122
+ "model.layers.16.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
123
+ "model.layers.16.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
124
+ "model.layers.16.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
125
+ "model.layers.16.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
126
+ "model.layers.16.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
127
+ "model.layers.16.self_attn.o_proj.bias": "model-00002-of-00004.safetensors",
128
+ "model.layers.16.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
129
+ "model.layers.16.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
130
+ "model.layers.16.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
131
+ "model.layers.16.self_attn.rotary_emb.inv_freq": "model-00002-of-00004.safetensors",
132
+ "model.layers.16.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
133
+ "model.layers.16.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
134
+ "model.layers.17.input_layernorm.weight": "model-00002-of-00004.safetensors",
135
+ "model.layers.17.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
136
+ "model.layers.17.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
137
+ "model.layers.17.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
138
+ "model.layers.17.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
139
+ "model.layers.17.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
140
+ "model.layers.17.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
141
+ "model.layers.17.self_attn.o_proj.bias": "model-00002-of-00004.safetensors",
142
+ "model.layers.17.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
143
+ "model.layers.17.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
144
+ "model.layers.17.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
145
+ "model.layers.17.self_attn.rotary_emb.inv_freq": "model-00002-of-00004.safetensors",
146
+ "model.layers.17.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
147
+ "model.layers.17.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
148
+ "model.layers.18.input_layernorm.weight": "model-00002-of-00004.safetensors",
149
+ "model.layers.18.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
150
+ "model.layers.18.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
151
+ "model.layers.18.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
152
+ "model.layers.18.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
153
+ "model.layers.18.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
154
+ "model.layers.18.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
155
+ "model.layers.18.self_attn.o_proj.bias": "model-00002-of-00004.safetensors",
156
+ "model.layers.18.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
157
+ "model.layers.18.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
158
+ "model.layers.18.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
159
+ "model.layers.18.self_attn.rotary_emb.inv_freq": "model-00002-of-00004.safetensors",
160
+ "model.layers.18.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
161
+ "model.layers.18.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
162
+ "model.layers.19.input_layernorm.weight": "model-00002-of-00004.safetensors",
163
+ "model.layers.19.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
164
+ "model.layers.19.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
165
+ "model.layers.19.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
166
+ "model.layers.19.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
167
+ "model.layers.19.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
168
+ "model.layers.19.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
169
+ "model.layers.19.self_attn.o_proj.bias": "model-00002-of-00004.safetensors",
170
+ "model.layers.19.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
171
+ "model.layers.19.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
172
+ "model.layers.19.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
173
+ "model.layers.19.self_attn.rotary_emb.inv_freq": "model-00002-of-00004.safetensors",
174
+ "model.layers.19.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
175
+ "model.layers.19.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
176
+ "model.layers.2.input_layernorm.weight": "model-00001-of-00004.safetensors",
177
+ "model.layers.2.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
178
+ "model.layers.2.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
179
+ "model.layers.2.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
180
+ "model.layers.2.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
181
+ "model.layers.2.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
182
+ "model.layers.2.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
183
+ "model.layers.2.self_attn.o_proj.bias": "model-00001-of-00004.safetensors",
184
+ "model.layers.2.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
185
+ "model.layers.2.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
186
+ "model.layers.2.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
187
+ "model.layers.2.self_attn.rotary_emb.inv_freq": "model-00001-of-00004.safetensors",
188
+ "model.layers.2.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
189
+ "model.layers.2.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
190
+ "model.layers.20.input_layernorm.weight": "model-00002-of-00004.safetensors",
191
+ "model.layers.20.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
192
+ "model.layers.20.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
193
+ "model.layers.20.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
194
+ "model.layers.20.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
195
+ "model.layers.20.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
196
+ "model.layers.20.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
197
+ "model.layers.20.self_attn.o_proj.bias": "model-00002-of-00004.safetensors",
198
+ "model.layers.20.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
199
+ "model.layers.20.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
200
+ "model.layers.20.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
201
+ "model.layers.20.self_attn.rotary_emb.inv_freq": "model-00002-of-00004.safetensors",
202
+ "model.layers.20.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
203
+ "model.layers.20.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
204
+ "model.layers.21.input_layernorm.weight": "model-00003-of-00004.safetensors",
205
+ "model.layers.21.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
206
+ "model.layers.21.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
207
+ "model.layers.21.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
208
+ "model.layers.21.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
209
+ "model.layers.21.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
210
+ "model.layers.21.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
211
+ "model.layers.21.self_attn.o_proj.bias": "model-00002-of-00004.safetensors",
212
+ "model.layers.21.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
213
+ "model.layers.21.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
214
+ "model.layers.21.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
215
+ "model.layers.21.self_attn.rotary_emb.inv_freq": "model-00002-of-00004.safetensors",
216
+ "model.layers.21.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
217
+ "model.layers.21.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
218
+ "model.layers.22.input_layernorm.weight": "model-00003-of-00004.safetensors",
219
+ "model.layers.22.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
220
+ "model.layers.22.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
221
+ "model.layers.22.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
222
+ "model.layers.22.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
223
+ "model.layers.22.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
224
+ "model.layers.22.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
225
+ "model.layers.22.self_attn.o_proj.bias": "model-00003-of-00004.safetensors",
226
+ "model.layers.22.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
227
+ "model.layers.22.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
228
+ "model.layers.22.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
229
+ "model.layers.22.self_attn.rotary_emb.inv_freq": "model-00003-of-00004.safetensors",
230
+ "model.layers.22.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
231
+ "model.layers.22.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
232
+ "model.layers.23.input_layernorm.weight": "model-00003-of-00004.safetensors",
233
+ "model.layers.23.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
234
+ "model.layers.23.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
235
+ "model.layers.23.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
236
+ "model.layers.23.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
237
+ "model.layers.23.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
238
+ "model.layers.23.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
239
+ "model.layers.23.self_attn.o_proj.bias": "model-00003-of-00004.safetensors",
240
+ "model.layers.23.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
241
+ "model.layers.23.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
242
+ "model.layers.23.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
243
+ "model.layers.23.self_attn.rotary_emb.inv_freq": "model-00003-of-00004.safetensors",
244
+ "model.layers.23.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
245
+ "model.layers.23.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
246
+ "model.layers.24.input_layernorm.weight": "model-00003-of-00004.safetensors",
247
+ "model.layers.24.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
248
+ "model.layers.24.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
249
+ "model.layers.24.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
250
+ "model.layers.24.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
251
+ "model.layers.24.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
252
+ "model.layers.24.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
253
+ "model.layers.24.self_attn.o_proj.bias": "model-00003-of-00004.safetensors",
254
+ "model.layers.24.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
255
+ "model.layers.24.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
256
+ "model.layers.24.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
257
+ "model.layers.24.self_attn.rotary_emb.inv_freq": "model-00003-of-00004.safetensors",
258
+ "model.layers.24.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
259
+ "model.layers.24.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
260
+ "model.layers.25.input_layernorm.weight": "model-00003-of-00004.safetensors",
261
+ "model.layers.25.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
262
+ "model.layers.25.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
263
+ "model.layers.25.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
264
+ "model.layers.25.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
265
+ "model.layers.25.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
266
+ "model.layers.25.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
267
+ "model.layers.25.self_attn.o_proj.bias": "model-00003-of-00004.safetensors",
268
+ "model.layers.25.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
269
+ "model.layers.25.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
270
+ "model.layers.25.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
271
+ "model.layers.25.self_attn.rotary_emb.inv_freq": "model-00003-of-00004.safetensors",
272
+ "model.layers.25.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
273
+ "model.layers.25.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
274
+ "model.layers.26.input_layernorm.weight": "model-00003-of-00004.safetensors",
275
+ "model.layers.26.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
276
+ "model.layers.26.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
277
+ "model.layers.26.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
278
+ "model.layers.26.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
279
+ "model.layers.26.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
280
+ "model.layers.26.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
281
+ "model.layers.26.self_attn.o_proj.bias": "model-00003-of-00004.safetensors",
282
+ "model.layers.26.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
283
+ "model.layers.26.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
284
+ "model.layers.26.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
285
+ "model.layers.26.self_attn.rotary_emb.inv_freq": "model-00003-of-00004.safetensors",
286
+ "model.layers.26.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
287
+ "model.layers.26.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
288
+ "model.layers.27.input_layernorm.weight": "model-00003-of-00004.safetensors",
289
+ "model.layers.27.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
290
+ "model.layers.27.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
291
+ "model.layers.27.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
292
+ "model.layers.27.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
293
+ "model.layers.27.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
294
+ "model.layers.27.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
295
+ "model.layers.27.self_attn.o_proj.bias": "model-00003-of-00004.safetensors",
296
+ "model.layers.27.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
297
+ "model.layers.27.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
298
+ "model.layers.27.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
299
+ "model.layers.27.self_attn.rotary_emb.inv_freq": "model-00003-of-00004.safetensors",
300
+ "model.layers.27.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
301
+ "model.layers.27.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
302
+ "model.layers.28.input_layernorm.weight": "model-00003-of-00004.safetensors",
303
+ "model.layers.28.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
304
+ "model.layers.28.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
305
+ "model.layers.28.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
306
+ "model.layers.28.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
307
+ "model.layers.28.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
308
+ "model.layers.28.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
309
+ "model.layers.28.self_attn.o_proj.bias": "model-00003-of-00004.safetensors",
310
+ "model.layers.28.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
311
+ "model.layers.28.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
312
+ "model.layers.28.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
313
+ "model.layers.28.self_attn.rotary_emb.inv_freq": "model-00003-of-00004.safetensors",
314
+ "model.layers.28.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
315
+ "model.layers.28.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
316
+ "model.layers.29.input_layernorm.weight": "model-00003-of-00004.safetensors",
317
+ "model.layers.29.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
318
+ "model.layers.29.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
319
+ "model.layers.29.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
320
+ "model.layers.29.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
321
+ "model.layers.29.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
322
+ "model.layers.29.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
323
+ "model.layers.29.self_attn.o_proj.bias": "model-00003-of-00004.safetensors",
324
+ "model.layers.29.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
325
+ "model.layers.29.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
326
+ "model.layers.29.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
327
+ "model.layers.29.self_attn.rotary_emb.inv_freq": "model-00003-of-00004.safetensors",
328
+ "model.layers.29.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
329
+ "model.layers.29.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
330
+ "model.layers.3.input_layernorm.weight": "model-00001-of-00004.safetensors",
331
+ "model.layers.3.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
332
+ "model.layers.3.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
333
+ "model.layers.3.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
334
+ "model.layers.3.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
335
+ "model.layers.3.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
336
+ "model.layers.3.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
337
+ "model.layers.3.self_attn.o_proj.bias": "model-00001-of-00004.safetensors",
338
+ "model.layers.3.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
339
+ "model.layers.3.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
340
+ "model.layers.3.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
341
+ "model.layers.3.self_attn.rotary_emb.inv_freq": "model-00001-of-00004.safetensors",
342
+ "model.layers.3.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
343
+ "model.layers.3.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
344
+ "model.layers.30.input_layernorm.weight": "model-00003-of-00004.safetensors",
345
+ "model.layers.30.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
346
+ "model.layers.30.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
347
+ "model.layers.30.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
348
+ "model.layers.30.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
349
+ "model.layers.30.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
350
+ "model.layers.30.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
351
+ "model.layers.30.self_attn.o_proj.bias": "model-00003-of-00004.safetensors",
352
+ "model.layers.30.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
353
+ "model.layers.30.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
354
+ "model.layers.30.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
355
+ "model.layers.30.self_attn.rotary_emb.inv_freq": "model-00003-of-00004.safetensors",
356
+ "model.layers.30.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
357
+ "model.layers.30.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
358
+ "model.layers.31.input_layernorm.weight": "model-00003-of-00004.safetensors",
359
+ "model.layers.31.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
360
+ "model.layers.31.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
361
+ "model.layers.31.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
362
+ "model.layers.31.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
363
+ "model.layers.31.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
364
+ "model.layers.31.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
365
+ "model.layers.31.self_attn.o_proj.bias": "model-00003-of-00004.safetensors",
366
+ "model.layers.31.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
367
+ "model.layers.31.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
368
+ "model.layers.31.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
369
+ "model.layers.31.self_attn.rotary_emb.inv_freq": "model-00003-of-00004.safetensors",
370
+ "model.layers.31.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
371
+ "model.layers.31.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
372
+ "model.layers.32.input_layernorm.weight": "model-00003-of-00004.safetensors",
373
+ "model.layers.32.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
374
+ "model.layers.32.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
375
+ "model.layers.32.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
376
+ "model.layers.32.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
377
+ "model.layers.32.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
378
+ "model.layers.32.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
379
+ "model.layers.32.self_attn.o_proj.bias": "model-00003-of-00004.safetensors",
380
+ "model.layers.32.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
381
+ "model.layers.32.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
382
+ "model.layers.32.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
383
+ "model.layers.32.self_attn.rotary_emb.inv_freq": "model-00003-of-00004.safetensors",
384
+ "model.layers.32.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
385
+ "model.layers.32.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
386
+ "model.layers.33.input_layernorm.weight": "model-00003-of-00004.safetensors",
387
+ "model.layers.33.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
388
+ "model.layers.33.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
389
+ "model.layers.33.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
390
+ "model.layers.33.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
391
+ "model.layers.33.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
392
+ "model.layers.33.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
393
+ "model.layers.33.self_attn.o_proj.bias": "model-00003-of-00004.safetensors",
394
+ "model.layers.33.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
395
+ "model.layers.33.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
396
+ "model.layers.33.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
397
+ "model.layers.33.self_attn.rotary_emb.inv_freq": "model-00003-of-00004.safetensors",
398
+ "model.layers.33.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
399
+ "model.layers.33.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
400
+ "model.layers.4.input_layernorm.weight": "model-00001-of-00004.safetensors",
401
+ "model.layers.4.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
402
+ "model.layers.4.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
403
+ "model.layers.4.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
404
+ "model.layers.4.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
405
+ "model.layers.4.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
406
+ "model.layers.4.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
407
+ "model.layers.4.self_attn.o_proj.bias": "model-00001-of-00004.safetensors",
408
+ "model.layers.4.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
409
+ "model.layers.4.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
410
+ "model.layers.4.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
411
+ "model.layers.4.self_attn.rotary_emb.inv_freq": "model-00001-of-00004.safetensors",
412
+ "model.layers.4.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
413
+ "model.layers.4.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
414
+ "model.layers.5.input_layernorm.weight": "model-00001-of-00004.safetensors",
415
+ "model.layers.5.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
416
+ "model.layers.5.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
417
+ "model.layers.5.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
418
+ "model.layers.5.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
419
+ "model.layers.5.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
420
+ "model.layers.5.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
421
+ "model.layers.5.self_attn.o_proj.bias": "model-00001-of-00004.safetensors",
422
+ "model.layers.5.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
423
+ "model.layers.5.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
424
+ "model.layers.5.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
425
+ "model.layers.5.self_attn.rotary_emb.inv_freq": "model-00001-of-00004.safetensors",
426
+ "model.layers.5.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
427
+ "model.layers.5.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
428
+ "model.layers.6.input_layernorm.weight": "model-00001-of-00004.safetensors",
429
+ "model.layers.6.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
430
+ "model.layers.6.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
431
+ "model.layers.6.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
432
+ "model.layers.6.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
433
+ "model.layers.6.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
434
+ "model.layers.6.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
435
+ "model.layers.6.self_attn.o_proj.bias": "model-00001-of-00004.safetensors",
436
+ "model.layers.6.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
437
+ "model.layers.6.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
438
+ "model.layers.6.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
439
+ "model.layers.6.self_attn.rotary_emb.inv_freq": "model-00001-of-00004.safetensors",
440
+ "model.layers.6.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
441
+ "model.layers.6.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
442
+ "model.layers.7.input_layernorm.weight": "model-00001-of-00004.safetensors",
443
+ "model.layers.7.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
444
+ "model.layers.7.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
445
+ "model.layers.7.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
446
+ "model.layers.7.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
447
+ "model.layers.7.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
448
+ "model.layers.7.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
449
+ "model.layers.7.self_attn.o_proj.bias": "model-00001-of-00004.safetensors",
450
+ "model.layers.7.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
451
+ "model.layers.7.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
452
+ "model.layers.7.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
453
+ "model.layers.7.self_attn.rotary_emb.inv_freq": "model-00001-of-00004.safetensors",
454
+ "model.layers.7.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
455
+ "model.layers.7.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
456
+ "model.layers.8.input_layernorm.weight": "model-00001-of-00004.safetensors",
457
+ "model.layers.8.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
458
+ "model.layers.8.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
459
+ "model.layers.8.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
460
+ "model.layers.8.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
461
+ "model.layers.8.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
462
+ "model.layers.8.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
463
+ "model.layers.8.self_attn.o_proj.bias": "model-00001-of-00004.safetensors",
464
+ "model.layers.8.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
465
+ "model.layers.8.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
466
+ "model.layers.8.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
467
+ "model.layers.8.self_attn.rotary_emb.inv_freq": "model-00001-of-00004.safetensors",
468
+ "model.layers.8.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
469
+ "model.layers.8.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
470
+ "model.layers.9.input_layernorm.weight": "model-00002-of-00004.safetensors",
471
+ "model.layers.9.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
472
+ "model.layers.9.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
473
+ "model.layers.9.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
474
+ "model.layers.9.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
475
+ "model.layers.9.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
476
+ "model.layers.9.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
477
+ "model.layers.9.self_attn.o_proj.bias": "model-00001-of-00004.safetensors",
478
+ "model.layers.9.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
479
+ "model.layers.9.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
480
+ "model.layers.9.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
481
+ "model.layers.9.self_attn.rotary_emb.inv_freq": "model-00001-of-00004.safetensors",
482
+ "model.layers.9.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
483
+ "model.layers.9.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
484
+ "model.norm.weight": "model-00003-of-00004.safetensors"
485
+ }
486
+ }
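The weight_map above pairs every parameter name with the safetensors shard that stores it; note that a single layer can span two shards (layer 9's attention weights sit in shard 1 while its MLP weights sit in shard 2). As a quick, hypothetical illustration of how the index is consumed, the sketch below resolves one tensor to its shard and loads only that tensor; the checkpoint directory path is a placeholder, not part of this repository.

# Minimal sketch (not part of the repo): look up a tensor's shard via the index file.
# Assumes the shards and model.safetensors.index.json sit in a local checkpoint directory.
import json
from pathlib import Path

from safetensors import safe_open

ckpt_dir = Path("path/to/openPangu-Embedded-7B")  # placeholder path

with open(ckpt_dir / "model.safetensors.index.json") as f:
    index = json.load(f)

name = "model.layers.32.self_attn.q_proj.weight"
shard_file = index["weight_map"][name]  # -> "model-00003-of-00004.safetensors"

with safe_open(str(ckpt_dir / shard_file), framework="pt") as shard:
    tensor = shard.get_tensor(name)  # loads just this tensor, not the whole shard

print(name, tuple(tensor.shape), "from", shard_file)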
modeling_openpangu_dense.py ADDED
@@ -0,0 +1,585 @@
1
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
2
+ # This file was automatically generated from modular_openpangu_dense.py.
3
+ # Do NOT edit this file manually as any edits will be overwritten by the generation of
4
+ # the file from the modular. If any change should be done, please apply the change to the
5
+ # modular_openpangu_dense.py file directly. One of our CI enforces this.
6
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
7
+
8
+ # coding=utf-8
9
+ # Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
10
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
11
+ #
12
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
13
+ # and OPT implementations in this library. It has been modified from its
14
+ # original forms to accommodate minor architectural differences compared
15
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
16
+ #
17
+ # Licensed under the Apache License, Version 2.0 (the "License");
18
+ # you may not use this file except in compliance with the License.
19
+ # You may obtain a copy of the License at
20
+ #
21
+ # http://www.apache.org/licenses/LICENSE-2.0
22
+ #
23
+ # Unless required by applicable law or agreed to in writing, software
24
+ # distributed under the License is distributed on an "AS IS" BASIS,
25
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
26
+ # See the License for the specific language governing permissions and
27
+ # limitations under the License.
28
+
29
+ from typing import Callable, Optional, Union
30
+
31
+ import torch
32
+ from torch import nn
33
+
34
+ import torch_npu
35
+ from torch_npu.contrib import transfer_to_npu
36
+ if "910" in torch.npu.get_device_name():
37
+ NPU_ATTN_INFR = True
38
+ print("[INFO] torch_npu detected. Using NPU fused infer attention.")
39
+ else:
40
+ NPU_ATTN_INFR = False
41
+
42
+ from transformers.activations import ACT2FN
43
+ from transformers.cache_utils import Cache, DynamicCache
44
+ from transformers.generation import GenerationMixin
45
+ from transformers.masking_utils import create_causal_mask
46
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
47
+ from transformers.modeling_layers import GradientCheckpointingLayer
48
+ from transformers.modeling_outputs import (
49
+ BaseModelOutputWithPast,
50
+ CausalLMOutputWithPast,
51
+ SequenceClassifierOutputWithPast,
52
+ )
53
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
54
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
55
+ from transformers.processing_utils import Unpack
56
+ from transformers.utils import LossKwargs, auto_docstring, can_return_tuple, logging
57
+ from .configuration_openpangu_dense import PanguEmbeddedConfig
58
+
59
+
60
+ logger = logging.get_logger(__name__)
61
+
62
+
63
+ class PanguEmbeddedRMSNorm(nn.Module):
64
+ def __init__(self, hidden_size, eps=1e-6):
65
+ """
66
+ PanguEmbeddedRMSNorm is equivalent to T5LayerNorm
67
+ """
68
+ super().__init__()
69
+ self.weight = nn.Parameter(torch.ones(hidden_size))
70
+ self.variance_epsilon = eps
71
+
72
+ def forward(self, hidden_states):
73
+ input_dtype = hidden_states.dtype
74
+ hidden_states = hidden_states.to(torch.float32)
75
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
76
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
77
+ return self.weight * hidden_states.to(input_dtype)
78
+
79
+ def extra_repr(self):
80
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
81
+
82
+
83
+ class PanguEmbeddedRotaryEmbedding(nn.Module):
84
+ def __init__(self, config: PanguEmbeddedConfig, device=None):
85
+ super().__init__()
86
+ # BC: "rope_type" was originally "type"
87
+ if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
88
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
89
+ else:
90
+ self.rope_type = "default"
91
+ self.max_seq_len_cached = config.max_position_embeddings
92
+ self.original_max_seq_len = config.max_position_embeddings
93
+
94
+ self.config = config
95
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
96
+
97
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
98
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
99
+ self.original_inv_freq = self.inv_freq
100
+
101
+ @torch.no_grad()
102
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
103
+ def forward(self, x, position_ids):
104
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
105
+ position_ids_expanded = position_ids[:, None, :].float()
106
+
107
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
108
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
109
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
110
+ emb = torch.cat((freqs, freqs), dim=-1)
111
+ cos = emb.cos() * self.attention_scaling
112
+ sin = emb.sin() * self.attention_scaling
113
+
114
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
115
+
116
+
117
+ def rotate_half(x):
118
+ """Rotates half the hidden dims of the input."""
119
+ x1 = x[..., : x.shape[-1] // 2]
120
+ x2 = x[..., x.shape[-1] // 2 :]
121
+ return torch.cat((-x2, x1), dim=-1)
122
+
123
+
124
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
125
+ """Applies Rotary Position Embedding to the query and key tensors.
126
+
127
+ Args:
128
+ q (`torch.Tensor`): The query tensor.
129
+ k (`torch.Tensor`): The key tensor.
130
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
131
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
132
+ position_ids (`torch.Tensor`, *optional*):
133
+ Deprecated and unused.
134
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
135
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
136
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
137
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
138
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
139
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
140
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
141
+ Returns:
142
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
143
+ """
144
+ cos = cos.unsqueeze(unsqueeze_dim)
145
+ sin = sin.unsqueeze(unsqueeze_dim)
146
+ q_embed = (q * cos) + (rotate_half(q) * sin)
147
+ k_embed = (k * cos) + (rotate_half(k) * sin)
148
+ return q_embed, k_embed
149
+
150
+
151
+ class PanguEmbeddedMLP(nn.Module):
152
+ def __init__(self, config):
153
+ super().__init__()
154
+ self.config = config
155
+ self.hidden_size = config.hidden_size
156
+ self.intermediate_size = config.intermediate_size
157
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
158
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
159
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
160
+ self.act_fn = ACT2FN[config.hidden_act]
161
+
162
+ def forward(self, x):
163
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
164
+ return down_proj
165
+
166
+
167
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
168
+ """
169
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
170
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
171
+ """
172
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
173
+ if n_rep == 1:
174
+ return hidden_states
175
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
176
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
177
+
178
+
179
+ def eager_attention_forward(
180
+ module: nn.Module,
181
+ query: torch.Tensor,
182
+ key: torch.Tensor,
183
+ value: torch.Tensor,
184
+ attention_mask: Optional[torch.Tensor],
185
+ scaling: float,
186
+ dropout: float = 0.0,
187
+ **kwargs,
188
+ ):
189
+ key_states = repeat_kv(key, module.num_key_value_groups)
190
+ value_states = repeat_kv(value, module.num_key_value_groups)
191
+
192
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
193
+ if attention_mask is not None:
194
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
195
+ attn_weights = attn_weights + causal_mask
196
+
197
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
198
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
199
+ attn_output = torch.matmul(attn_weights, value_states)
200
+ attn_output = attn_output.transpose(1, 2).contiguous()
201
+
202
+ return attn_output, attn_weights
203
+
204
+
205
+ class PanguEmbeddedAttention(nn.Module):
206
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
207
+
208
+ def __init__(self, config: PanguEmbeddedConfig, layer_idx: int):
209
+ super().__init__()
210
+ self.config = config
211
+ self.layer_idx = layer_idx
212
+ self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
213
+ self.num_heads = config.num_attention_heads
214
+ self.num_key_value_heads = config.num_key_value_heads
215
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
216
+ self.scaling = self.head_dim**-0.5
217
+ self.attention_dropout = config.attention_dropout
218
+ self.is_causal = True
219
+
220
+ self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.bias)
221
+ self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.bias)
222
+ self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.bias)
223
+ self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.bias)
224
+
225
+ def forward(
226
+ self,
227
+ hidden_states: torch.Tensor,
228
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
229
+ attention_mask: Optional[torch.Tensor],
230
+ past_key_value: Optional[Cache] = None,
231
+ cache_position: Optional[torch.LongTensor] = None,
232
+ **kwargs: Unpack[FlashAttentionKwargs],
233
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
234
+ input_shape = hidden_states.shape[:-1]
235
+ hidden_shape = (*input_shape, -1, self.head_dim)
236
+
237
+ query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
238
+ key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
239
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
240
+
241
+ cos, sin = position_embeddings
242
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
243
+
244
+ if past_key_value is not None:
245
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
246
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
247
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
248
+
249
+ attention_interface: Callable = eager_attention_forward
250
+ if self.config._attn_implementation != "eager":
251
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
252
+
253
+ if not self.training and NPU_ATTN_INFR:
254
+ q_len = input_shape[1]
255
+ if attention_mask is not None:
256
+ attention_mask = ~attention_mask.bool()
257
+ elif q_len > 1:
258
+ attention_mask = torch.triu(torch.ones([q_len, q_len]), diagonal=1).bool().unsqueeze(0).unsqueeze(0).to(query_states.device)
259
+
260
+ attn_output, _ = torch_npu.npu_fused_infer_attention_score(
261
+ query_states, key_states, value_states,
262
+ num_heads=self.num_heads, num_key_value_heads=self.num_key_value_heads,
263
+ input_layout="BNSD", atten_mask=attention_mask, scale=self.scaling)
264
+ attn_output = attn_output.transpose(1, 2)
265
+ attn_weights = None
266
+ else:
267
+ attn_output, attn_weights = attention_interface(
268
+ self,
269
+ query_states,
270
+ key_states,
271
+ value_states,
272
+ attention_mask,
273
+ dropout=0.0 if not self.training else self.attention_dropout,
274
+ scaling=self.scaling,
275
+ **kwargs,
276
+ )
277
+
278
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
279
+ attn_output = self.o_proj(attn_output)
280
+ return attn_output, attn_weights
281
+
282
+
283
+ class PanguEmbeddedDecoderLayer(GradientCheckpointingLayer):
284
+ def __init__(self, config: PanguEmbeddedConfig, layer_idx: int):
285
+ super().__init__()
286
+ self.hidden_size = config.hidden_size
287
+ self.self_attn = PanguEmbeddedAttention(config=config, layer_idx=layer_idx)
288
+ self.mlp = PanguEmbeddedMLP(config)
289
+ self.input_layernorm = PanguEmbeddedRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
290
+ self.post_attention_layernorm = PanguEmbeddedRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
291
+
292
+ def forward(
293
+ self,
294
+ hidden_states: torch.Tensor,
295
+ attention_mask: Optional[torch.Tensor] = None,
296
+ position_ids: Optional[torch.LongTensor] = None,
297
+ past_key_value: Optional[Cache] = None,
298
+ output_attentions: Optional[bool] = False,
299
+ use_cache: Optional[bool] = False,
300
+ cache_position: Optional[torch.LongTensor] = None,
301
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
302
+ **kwargs: Unpack[FlashAttentionKwargs],
303
+ ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
304
+ residual = hidden_states
305
+ hidden_states = self.input_layernorm(hidden_states)
306
+
307
+ # Self Attention
308
+ hidden_states, self_attn_weights = self.self_attn(
309
+ hidden_states=hidden_states,
310
+ attention_mask=attention_mask,
311
+ position_ids=position_ids,
312
+ past_key_value=past_key_value,
313
+ output_attentions=output_attentions,
314
+ use_cache=use_cache,
315
+ cache_position=cache_position,
316
+ position_embeddings=position_embeddings,
317
+ **kwargs,
318
+ )
319
+ hidden_states = residual + hidden_states
320
+
321
+ # Fully Connected
322
+ residual = hidden_states
323
+ hidden_states = self.post_attention_layernorm(hidden_states)
324
+ hidden_states = self.mlp(hidden_states)
325
+ hidden_states = residual + hidden_states
326
+
327
+ outputs = (hidden_states,)
328
+ if output_attentions:
329
+ outputs += (self_attn_weights,)
330
+
331
+ return outputs
332
+
333
+
334
+ @auto_docstring
335
+ class PanguEmbeddedPreTrainedModel(PreTrainedModel):
336
+ config_class = PanguEmbeddedConfig
337
+ base_model_prefix = "model"
338
+ supports_gradient_checkpointing = True
339
+ _no_split_modules = ["PanguEmbeddedDecoderLayer"]
340
+ _skip_keys_device_placement = ["past_key_values"]
341
+ _supports_flash_attn_3 = True
342
+ _supports_flash_attn_2 = True
343
+ _supports_sdpa = True
344
+ _supports_flex_attn = True
345
+ _supports_cache_class = True
346
+ _supports_quantized_cache = True
347
+ _supports_static_cache = True
348
+ _supports_attention_backend = True
349
+
350
+ def _init_weights(self, module):
351
+ std = self.config.initializer_range
352
+ if isinstance(module, nn.Linear):
353
+ module.weight.data.normal_(mean=0.0, std=std)
354
+ if module.bias is not None:
355
+ module.bias.data.zero_()
356
+ elif isinstance(module, nn.Embedding):
357
+ module.weight.data.normal_(mean=0.0, std=std)
358
+ if module.padding_idx is not None:
359
+ module.weight.data[module.padding_idx].zero_()
360
+ elif isinstance(module, PanguEmbeddedRMSNorm):
361
+ module.weight.data.fill_(1.0)
362
+
363
+
364
+ @auto_docstring
365
+ class PanguEmbeddedModel(PanguEmbeddedPreTrainedModel):
366
+ def __init__(self, config: PanguEmbeddedConfig):
367
+ super().__init__(config)
368
+ self.padding_idx = config.pad_token_id
369
+ self.vocab_size = config.vocab_size
370
+
371
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
372
+ self.layers = nn.ModuleList(
373
+ [PanguEmbeddedDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
374
+ )
375
+ self.norm = PanguEmbeddedRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
376
+ self.rotary_emb = PanguEmbeddedRotaryEmbedding(config=config)
377
+ self.gradient_checkpointing = False
378
+
379
+ # Initialize weights and apply final processing
380
+ self.post_init()
381
+
382
+ def get_input_embeddings(self):
383
+ return self.embed_tokens
384
+
385
+ def set_input_embeddings(self, value):
386
+ self.embed_tokens = value
387
+
388
+ @can_return_tuple
389
+ @auto_docstring
390
+ def forward(
391
+ self,
392
+ input_ids: Optional[torch.LongTensor] = None,
393
+ attention_mask: Optional[torch.Tensor] = None,
394
+ position_ids: Optional[torch.LongTensor] = None,
395
+ past_key_values: Optional[Cache] = None,
396
+ inputs_embeds: Optional[torch.FloatTensor] = None,
397
+ use_cache: Optional[bool] = None,
398
+ output_attentions: Optional[bool] = None,
399
+ output_hidden_states: Optional[bool] = None,
400
+ cache_position: Optional[torch.LongTensor] = None,
401
+ **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
402
+ ) -> BaseModelOutputWithPast:
403
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
404
+ output_hidden_states = (
405
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
406
+ )
407
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
408
+
409
+ if (input_ids is None) ^ (inputs_embeds is not None):
410
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
411
+
412
+ if self.gradient_checkpointing and self.training and use_cache:
413
+ logger.warning_once(
414
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
415
+ )
416
+ use_cache = False
417
+
418
+ # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache
419
+ if not isinstance(past_key_values, (type(None), Cache)):
420
+ raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.")
421
+
422
+ if inputs_embeds is None:
423
+ inputs_embeds = self.embed_tokens(input_ids)
424
+
425
+ if use_cache and past_key_values is None:
426
+ past_key_values = DynamicCache()
427
+
428
+ if cache_position is None:
429
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
430
+ cache_position = torch.arange(
431
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
432
+ )
433
+
434
+ if position_ids is None:
435
+ position_ids = cache_position.unsqueeze(0)
436
+
437
+ causal_mask = create_causal_mask(
438
+ config=self.config,
439
+ input_embeds=inputs_embeds,
440
+ attention_mask=attention_mask,
441
+ cache_position=cache_position,
442
+ past_key_values=past_key_values,
443
+ position_ids=position_ids,
444
+ )
445
+
446
+ hidden_states = inputs_embeds
447
+
448
+ # create position embeddings to be shared across the decoder layers
449
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
450
+
451
+ # decoder layers
452
+ all_hidden_states = () if output_hidden_states else None
453
+ all_self_attns = () if output_attentions else None
454
+
455
+ for decoder_layer in self.layers[: self.config.num_hidden_layers]:
456
+ if output_hidden_states:
457
+ all_hidden_states += (hidden_states,)
458
+
459
+ layer_outputs = decoder_layer(
460
+ hidden_states,
461
+ attention_mask=causal_mask,
462
+ position_ids=position_ids,
463
+ past_key_value=past_key_values,
464
+ output_attentions=output_attentions,
465
+ use_cache=use_cache,
466
+ cache_position=cache_position,
467
+ position_embeddings=position_embeddings,
468
+ **flash_attn_kwargs,
469
+ )
470
+
471
+ hidden_states = layer_outputs[0]
472
+
473
+ if output_attentions:
474
+ all_self_attns += (layer_outputs[1],)
475
+
476
+ hidden_states = self.norm(hidden_states)
477
+
478
+ # add hidden states from the last decoder layer
479
+ if output_hidden_states:
480
+ all_hidden_states += (hidden_states,)
481
+
482
+ return BaseModelOutputWithPast(
483
+ last_hidden_state=hidden_states,
484
+ past_key_values=past_key_values if use_cache else None,
485
+ hidden_states=all_hidden_states,
486
+ attentions=all_self_attns,
487
+ )
488
+
489
+
490
+ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
491
+
492
+
493
+ @auto_docstring
494
+ class PanguEmbeddedForCausalLM(PanguEmbeddedPreTrainedModel, GenerationMixin):
495
+ _tied_weights_keys = ["lm_head.weight"]
496
+ _tp_plan = {"lm_head": "colwise_rep"}
497
+ _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
498
+
499
+ def __init__(self, config):
500
+ super().__init__(config)
501
+ self.model = PanguEmbeddedModel(config)
502
+ self.vocab_size = config.vocab_size
503
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
504
+
505
+ # Initialize weights and apply final processing
506
+ self.post_init()
507
+
508
+ def get_input_embeddings(self):
509
+ return self.model.embed_tokens
510
+
511
+ def set_input_embeddings(self, value):
512
+ self.model.embed_tokens = value
513
+
514
+ def get_output_embeddings(self):
515
+ return self.lm_head
516
+
517
+ def set_output_embeddings(self, new_embeddings):
518
+ self.lm_head = new_embeddings
519
+
520
+ def set_decoder(self, decoder):
521
+ self.model = decoder
522
+
523
+ def get_decoder(self):
524
+ return self.model
525
+
526
+ @can_return_tuple
527
+ @auto_docstring
528
+ def forward(
529
+ self,
530
+ input_ids: Optional[torch.LongTensor] = None,
531
+ attention_mask: Optional[torch.Tensor] = None,
532
+ position_ids: Optional[torch.LongTensor] = None,
533
+ past_key_values: Optional[Cache] = None,
534
+ inputs_embeds: Optional[torch.FloatTensor] = None,
535
+ labels: Optional[torch.LongTensor] = None,
536
+ use_cache: Optional[bool] = None,
537
+ output_attentions: Optional[bool] = None,
538
+ output_hidden_states: Optional[bool] = None,
539
+ cache_position: Optional[torch.LongTensor] = None,
540
+ logits_to_keep: Union[int, torch.Tensor] = 0,
541
+ **kwargs: Unpack[KwargsForCausalLM],
542
+ ) -> CausalLMOutputWithPast:
543
+
544
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
545
+ output_hidden_states = (
546
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
547
+ )
548
+
549
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
550
+ outputs: BaseModelOutputWithPast = self.model(
551
+ input_ids=input_ids,
552
+ attention_mask=attention_mask,
553
+ position_ids=position_ids,
554
+ past_key_values=past_key_values,
555
+ inputs_embeds=inputs_embeds,
556
+ use_cache=use_cache,
557
+ output_attentions=output_attentions,
558
+ output_hidden_states=output_hidden_states,
559
+ cache_position=cache_position,
560
+ **kwargs,
561
+ )
562
+
563
+ hidden_states = outputs.last_hidden_state
564
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
565
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
566
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
567
+
568
+ loss = None
569
+ if labels is not None:
570
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
571
+
572
+ return CausalLMOutputWithPast(
573
+ loss=loss,
574
+ logits=logits,
575
+ past_key_values=outputs.past_key_values,
576
+ hidden_states=outputs.hidden_states,
577
+ attentions=outputs.attentions,
578
+ )
579
+
580
+
581
+ __all__ = [
582
+ "PanguEmbeddedForCausalLM",
583
+ "PanguEmbeddedModel",
584
+ "PanguEmbeddedPreTrainedModel",
585
+ ]
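Since modeling_openpangu_dense.py imports torch_npu unconditionally, it is meant to run in an Ascend NPU environment. A minimal usage sketch under that assumption follows; the checkpoint path is a placeholder and the generation settings are illustrative only, not values documented by this repository.

# Minimal sketch, assuming an Ascend environment where torch_npu is importable
# and a local copy of this repository at a placeholder path.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_path = "path/to/openPangu-Embedded-7B"  # placeholder

tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,  # resolves PanguEmbeddedForCausalLM from this repo
)

messages = [{"role": "user", "content": "Briefly introduce large language models."}]
input_ids = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt"
).to(model.device)

output_ids = model.generate(input_ids, max_new_tokens=128)
print(tokenizer.decode(output_ids[0][input_ids.shape[-1]:], skip_special_tokens=True))

apply_chat_template renders the [unused9]/[unused10] role markers defined later in tokenizer_config.json, so no manual prompt formatting is needed.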
modular_openpangu_dense.py ADDED
@@ -0,0 +1,149 @@
1
+ # coding=utf-8
2
+ # Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
3
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
6
+ # and OPT implementations in this library. It has been modified from its
7
+ # original forms to accommodate minor architectural differences compared
8
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
9
+ #
10
+ # Licensed under the Apache License, Version 2.0 (the "License");
11
+ # you may not use this file except in compliance with the License.
12
+ # You may obtain a copy of the License at
13
+ #
14
+ # http://www.apache.org/licenses/LICENSE-2.0
15
+ #
16
+ # Unless required by applicable law or agreed to in writing, software
17
+ # distributed under the License is distributed on an "AS IS" BASIS,
18
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19
+ # See the License for the specific language governing permissions and
20
+ # limitations under the License.
21
+
22
+ from typing import Callable, Optional, Tuple
23
+
24
+ import torch
25
+ from torch import nn
26
+
27
+ import torch_npu
28
+ from torch_npu.contrib import transfer_to_npu
29
+ if "910" in torch.npu.get_device_name():
30
+ NPU_ATTN_INFR = True
31
+ print("[INFO] torch_npu detected. Using NPU fused infer attention.")
32
+ else:
33
+ NPU_ATTN_INFR = False
34
+
35
+ from transformers.cache_utils import Cache
36
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
37
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
38
+ from transformers.processing_utils import Unpack
39
+ from transformers.utils import logging
40
+ from transformers.models.llama.modeling_llama import (
41
+ LlamaAttention,
42
+ LlamaDecoderLayer,
43
+ LlamaForCausalLM,
44
+ LlamaForSequenceClassification,
45
+ LlamaMLP,
46
+ LlamaModel,
47
+ apply_rotary_pos_emb,
48
+ eager_attention_forward,
49
+ )
50
+ from .configuration_openpangu_dense import PanguEmbeddedConfig
51
+
52
+
53
+ logger = logging.get_logger(__name__)
54
+
55
+
56
+ class PanguEmbeddedMLP(LlamaMLP):
57
+ def __init__(self, config):
58
+ super().__init__(config)
59
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
60
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
61
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
62
+
63
+
64
+ class PanguEmbeddedAttention(LlamaAttention):
65
+ def __init__(self, config: PanguEmbeddedConfig, layer_idx: int):
66
+ super().__init__()
67
+ self.config = config
68
+ self.layer_idx = layer_idx
69
+ self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
70
+ self.num_heads = config.num_attention_heads
71
+ self.num_key_value_heads = config.num_key_value_heads
72
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
73
+ self.scaling = self.head_dim**-0.5
74
+ self.attention_dropout = config.attention_dropout
75
+ self.is_causal = True
76
+
77
+ self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.bias)
78
+ self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.bias)
79
+ self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.bias)
80
+ self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.bias)
81
+
82
+ def forward(
83
+ self,
84
+ hidden_states: torch.Tensor,
85
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
86
+ attention_mask: Optional[torch.Tensor],
87
+ past_key_value: Optional[Cache] = None,
88
+ cache_position: Optional[torch.LongTensor] = None,
89
+ **kwargs: Unpack[FlashAttentionKwargs],
90
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
91
+ input_shape = hidden_states.shape[:-1]
92
+ hidden_shape = (*input_shape, -1, self.head_dim)
93
+
94
+ query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
95
+ key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
96
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
97
+
98
+ cos, sin = position_embeddings
99
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
100
+
101
+ if past_key_value is not None:
102
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
103
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
104
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
105
+
106
+ attention_interface: Callable = eager_attention_forward
107
+ if self.config._attn_implementation != "eager":
108
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
109
+
110
+ if not self.training and NPU_ATTN_INFR:
111
+ q_len = input_shape[1]
112
+ if attention_mask is not None:
113
+ attention_mask = ~attention_mask.bool()
114
+ elif q_len > 1:
115
+ attention_mask = torch.triu(torch.ones([q_len, q_len]), diagonal=1).bool().unsqueeze(0).unsqueeze(0).to(query_states.device)
116
+
117
+ attn_output, _ = torch_npu.npu_fused_infer_attention_score(
118
+ query_states, key_states, value_states,
119
+ num_heads=self.num_heads, num_key_value_heads=self.num_key_value_heads,
120
+ input_layout="BNSD", atten_mask=attention_mask, scale=self.scaling)
121
+ attn_output = attn_output.transpose(1, 2)
122
+ attn_weights = None
123
+ else:
124
+ attn_output, attn_weights = attention_interface(
125
+ self,
126
+ query_states,
127
+ key_states,
128
+ value_states,
129
+ attention_mask,
130
+ dropout=0.0 if not self.training else self.attention_dropout,
131
+ scaling=self.scaling,
132
+ **kwargs,
133
+ )
134
+
135
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
136
+ attn_output = self.o_proj(attn_output)
137
+ return attn_output, attn_weights
138
+
139
+
140
+ class PanguEmbeddedDecoderLayer(LlamaDecoderLayer):
141
+ pass
142
+
143
+
144
+ class PanguEmbeddedModel(LlamaModel):
145
+ pass
146
+
147
+
148
+ class PanguEmbeddedForCausalLM(LlamaForCausalLM):
149
+ pass
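Both attention implementations above build the same boolean mask for the fused NPU kernel when no attention mask is supplied: positions in the strict upper triangle are masked out. A small torch-only illustration of that mask for a prompt of length 4:

# Illustration of the causal mask built in the NPU branch above (True = masked position).
import torch

q_len = 4
mask = torch.triu(torch.ones([q_len, q_len]), diagonal=1).bool().unsqueeze(0).unsqueeze(0)
print(mask[0, 0].int())
# tensor([[0, 1, 1, 1],
#         [0, 0, 1, 1],
#         [0, 0, 0, 1],
#         [0, 0, 0, 0]], dtype=torch.int32)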
special_tokens_map.json ADDED
@@ -0,0 +1,30 @@
1
+ {
2
+ "bos_token": {
3
+ "content": "<s>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "[unused10]",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "pad_token": {
17
+ "content": "<unk>",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ },
23
+ "unk_token": {
24
+ "content": "<unk>",
25
+ "lstrip": false,
26
+ "normalized": false,
27
+ "rstrip": false,
28
+ "single_word": false
29
+ }
30
+ }
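These settings surface directly as tokenizer attributes once the repository is loaded; in particular, generation stops on "[unused10]" rather than "</s>". A short, hypothetical check (placeholder path again):

# Sketch: the values in special_tokens_map.json become tokenizer attributes after loading.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("path/to/openPangu-Embedded-7B", trust_remote_code=True)

print(tokenizer.bos_token)  # "<s>"
print(tokenizer.eos_token)  # "[unused10]" - the end-of-turn marker used by the chat template
print(tokenizer.pad_token)  # "<unk>"
print(tokenizer.unk_token)  # "<unk>"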
tokenization_openpangu.py ADDED
@@ -0,0 +1,273 @@
1
+ # coding=utf-8
2
+ # Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
3
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
6
+ # and OPT implementations in this library. It has been modified from its
7
+ # original forms to accommodate minor architectural differences compared
8
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
9
+ #
10
+ # Licensed under the Apache License, Version 2.0 (the "License");
11
+ # you may not use this file except in compliance with the License.
12
+ # You may obtain a copy of the License at
13
+ #
14
+ # http://www.apache.org/licenses/LICENSE-2.0
15
+ #
16
+ # Unless required by applicable law or agreed to in writing, software
17
+ # distributed under the License is distributed on an "AS IS" BASIS,
18
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19
+ # See the License for the specific language governing permissions and
20
+ # limitations under the License.
21
+
22
+ import os
23
+ from shutil import copyfile
24
+ from typing import Any, Dict, List, Optional, Tuple
25
+
26
+ import sentencepiece as spm
27
+
28
+ from transformers.tokenization_utils import PreTrainedTokenizer
29
+ from transformers.utils import logging
30
+
31
+
32
+ logger = logging.get_logger(__name__)
33
+
34
+ VOCAB_FILES_NAMES = {"vocab_file": "./tokenizer.model"}
35
+
36
+ PRETRAINED_VOCAB_FILES_MAP = {}
37
+
38
+
39
+ def convert_bool(string):
40
+ if isinstance(string, str):
41
+ if string.lower() == "true":
42
+ return True
43
+ elif string.lower() == "false":
44
+ return False
45
+ else:
46
+ return string
47
+ else:
48
+ return string
49
+
50
+
51
+ class PanguTokenizer(PreTrainedTokenizer):
52
+ """
53
+ Construct a Pangu tokenizer. Based on a SentencePiece model.
54
+
55
+ Args:
56
+ vocab_file (`str`):
57
+ Path to the vocabulary file.
58
+ """
59
+
60
+ vocab_files_names = VOCAB_FILES_NAMES
61
+ pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
62
+ model_input_names = ["input_ids", "attention_mask"]
63
+ _auto_class = "AutoTokenizer"
64
+
65
+ def __init__(
66
+ self,
67
+ vocab_file,
68
+ unk_token="<unk>",
69
+ bos_token="<s>",
70
+ eos_token="</s>",
71
+ pad_token="</s>",
72
+ sp_model_kwargs: Optional[Dict[str, Any]] = None,
73
+ add_bos_token=True,
74
+ add_eos_token=False,
75
+ decode_with_prefix_space=False,
76
+ clean_up_tokenization_spaces=False,
77
+ **kwargs,
78
+ ):
79
+ self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
80
+ self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
81
+ self.sp_model.Load(vocab_file)
82
+ super().__init__(
83
+ bos_token=bos_token,
84
+ eos_token=eos_token,
85
+ unk_token=unk_token,
86
+ pad_token=pad_token,
87
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
88
+ **kwargs,
89
+ )
90
+ self.vocab_file = vocab_file
91
+ self.add_bos_token = convert_bool(add_bos_token)
92
+ self.add_eos_token = add_eos_token
93
+ self.decode_with_prefix_space = decode_with_prefix_space
94
+ self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
95
+ self.sp_model.Load(vocab_file)
96
+ self._no_prefix_space_tokens = None
97
+
98
+ """ Initialisation"""
99
+
100
+ @property
101
+ def no_prefix_space_tokens(self):
102
+ if self._no_prefix_space_tokens is None:
103
+ vocab = self.convert_ids_to_tokens(list(range(self.vocab_size)))
104
+ self._no_prefix_space_tokens = {i for i, tok in enumerate(vocab) if not tok.startswith("▁")}
105
+ return self._no_prefix_space_tokens
106
+
107
+ @property
108
+ def vocab_size(self):
109
+ """Returns vocab size"""
110
+ return self.sp_model.get_piece_size()
111
+
112
+ @property
113
+ def bos_token_id(self) -> Optional[int]:
114
+ return self.sp_model.bos_id()
115
+
116
+ @property
117
+ def eos_token_id(self) -> Optional[int]:
118
+ return super().eos_token_id
119
+
120
+ def get_vocab(self):
121
+ """Returns vocab as a dict"""
122
+ vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
123
+ vocab.update(self.added_tokens_encoder)
124
+ return vocab
125
+
126
+ def _tokenize(self, text):
127
+ """Returns a tokenized string."""
128
+ return self.sp_model.encode(text, out_type=str)
129
+
130
+ def _convert_token_to_id(self, token):
131
+ """Converts a token (str) in an id using the vocab."""
132
+ return self.sp_model.piece_to_id(token)
133
+
134
+ def _convert_id_to_token(self, index):
135
+ """Converts an index (integer) in a token (str) using the vocab."""
136
+ token = self.sp_model.IdToPiece(index)
137
+ return token
138
+
139
+ def _maybe_add_prefix_space(self, tokens, decoded):
140
+ if tokens and tokens[0] not in self.no_prefix_space_tokens:
141
+ return " " + decoded
142
+ else:
143
+ return decoded
144
+
145
+ def convert_tokens_to_string(self, tokens):
146
+ """Converts a sequence of tokens (string) in a single string."""
147
+ current_sub_tokens = []
148
+ out_string = ""
149
+ prev_is_special = False
150
+ for token in tokens:
151
+ # make sure that special tokens are not decoded using sentencepiece model
152
+ if token in self.all_special_tokens:
153
+ # Decode the current sub-tokens first
154
+ if current_sub_tokens:
155
+ out_string += self.sp_model.decode(current_sub_tokens)
156
+ current_sub_tokens = []
157
+ # Append the special token without adding extra spaces
158
+ out_string += token
159
+ prev_is_special = True
160
+ else:
161
+ current_sub_tokens.append(token)
162
+ prev_is_special = False
163
+ # Decode any remaining sub-tokens
164
+ if current_sub_tokens:
165
+ out_string += self.sp_model.decode(current_sub_tokens)
166
+ # Clean up leading and trailing spaces
167
+ if self.clean_up_tokenization_spaces:
168
+ out_string = self.clean_up_tokenization(out_string)
169
+ out_string = self._maybe_add_prefix_space(tokens=tokens, decoded=out_string)
170
+ return out_string[1:]
171
+
172
+ # Override decode to set spaces_between_special_tokens to False by default
173
+ def decode(self,
174
+ token_ids,
175
+ spaces_between_special_tokens: bool = False,
176
+ **kwargs):
177
+ return super().decode(
178
+ token_ids=token_ids,
179
+ spaces_between_special_tokens=spaces_between_special_tokens,
180
+ **kwargs,
181
+ )
182
+
183
+ def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) -> Tuple[str]:
184
+ """
185
+ Save the vocabulary and special tokens file to a directory.
186
+
187
+ Args:
188
+ save_directory (`str`):
189
+ The directory in which to save the vocabulary.
190
+
191
+ Returns:
192
+ `Tuple(str)`: Paths to the files saved.
193
+ """
194
+ if not os.path.isdir(save_directory):
195
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
196
+ return ("",)
197
+ out_vocab_file = os.path.join(
198
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
199
+ )
200
+
201
+ if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
202
+ copyfile(self.vocab_file, out_vocab_file)
203
+ elif not os.path.isfile(self.vocab_file):
204
+ with open(out_vocab_file, "wb") as fi:
205
+ content_spiece_model = self.sp_model.serialized_model_proto()
206
+ fi.write(content_spiece_model)
207
+
208
+ return (out_vocab_file,)
209
+
210
+ def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
211
+ if self.add_bos_token:
212
+ bos_token_ids = [self.bos_token_id]
213
+ else:
214
+ bos_token_ids = []
215
+
216
+ output = bos_token_ids + token_ids_0
217
+
218
+ if token_ids_1 is not None:
219
+ output = output + token_ids_1
220
+
221
+ if self.add_eos_token:
222
+ output = output + [self.eos_token_id]
223
+
224
+ return output
225
+
226
+ def get_special_tokens_mask(
227
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
228
+ ) -> List[int]:
229
+ """
230
+ Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
231
+ special tokens using the tokenizer `prepare_for_model` method.
232
+
233
+ Args:
234
+ token_ids_0 (`List[int]`):
235
+ List of IDs.
236
+ token_ids_1 (`List[int]`, *optional*):
237
+ Optional second list of IDs for sequence pairs.
238
+ already_has_special_tokens (`bool`, *optional*, defaults to `False`):
239
+ Whether or not the token list is already formatted with special tokens for the model.
240
+
241
+ Returns:
242
+ `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
243
+ """
244
+ if already_has_special_tokens:
245
+ return super().get_special_tokens_mask(
246
+ token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
247
+ )
248
+
249
+ if token_ids_1 is None:
250
+ return [1] + ([0] * len(token_ids_0)) + [1]
251
+ return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1]
252
+
253
+ def create_token_type_ids_from_sequences(
254
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
255
+ ) -> List[int]:
256
+ """
257
+ Create a mask from the two sequences passed to be used in a sequence-pair classification task. This model does not
258
+ make use of token type ids, therefore a list of zeros is returned.
259
+
260
+ Args:
261
+ token_ids_0 (`List[int]`):
262
+ List of IDs.
263
+ token_ids_1 (`List[int]`, *optional*):
264
+ Optional second list of IDs for sequence pairs.
265
+
266
+ Returns:
267
+ `List[int]`: List of zeros.
268
+ """
269
+ eos = [self.eos_token_id]
270
+
271
+ if token_ids_1 is None:
272
+ return len(token_ids_0 + eos) * [0]
273
+ return len(token_ids_0 + eos + token_ids_1 + eos) * [0]
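With the defaults above (add_bos_token=True, add_eos_token=False), build_inputs_with_special_tokens prepends the BOS id and appends nothing. A short sketch of that behaviour, again using a placeholder checkpoint path:

# Sketch of PanguTokenizer's default special-token behaviour (BOS prepended, no EOS appended).
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("path/to/openPangu-Embedded-7B", trust_remote_code=True)

ids = tokenizer("Hello, Pangu!")["input_ids"]
print(ids[0] == tokenizer.bos_token_id)   # True: <s> is prepended
print(ids[-1] == tokenizer.eos_token_id)  # False: no EOS is appended by default
print(tokenizer.decode(ids))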
tokenizer.model ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6b16f1558c0cd4ae6ef1a2c605713be0a514f50e1ce2d2c878979ce988c148ec
3
+ size 2477809
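tokenizer.model is stored as a Git LFS pointer, so accidentally downloading only the pointer file (a few hundred bytes) is a common failure mode. A quick, hypothetical integrity check against the oid above (placeholder path):

# Sketch: confirm the real SentencePiece model (2,477,809 bytes) was fetched, not the LFS pointer.
import hashlib

expected = "6b16f1558c0cd4ae6ef1a2c605713be0a514f50e1ce2d2c878979ce988c148ec"

with open("path/to/openPangu-Embedded-7B/tokenizer.model", "rb") as f:  # placeholder path
    digest = hashlib.sha256(f.read()).hexdigest()

print(digest == expected)

The tokenizer_config.json that follows defines the chat template used at inference time: each turn is wrapped in "[unused9]<role>:" ... "[unused10]" markers, with the role names written in Chinese (系统 = system, 用户 = user, 助手 = assistant). A hedged sketch of rendering it into a raw prompt string, assuming the same placeholder path:

# Sketch: rendering the chat template from tokenizer_config.json into the raw prompt string.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("path/to/openPangu-Embedded-7B", trust_remote_code=True)

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is RMSNorm?"},
]
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
print(prompt)
# [unused9]系统:You are a helpful assistant.[unused10][unused9]用户:What is RMSNorm?[unused10][unused9]助手: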
tokenizer_config.json ADDED
@@ -0,0 +1 @@
1
+ {"add_bos_token": true, "add_eos_token": false, "add_prefix_space": true, "added_tokens_decoder": {"0": {"content": "<unk>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "1": {"content": "<s>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "2": {"content": "</s>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45806": {"content": "<|User|>:", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45813": {"content": "<|Bot|>:", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45830": {"content": "[unused0]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45840": {"content": "[unused1]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45846": {"content": "[unused2]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45849": {"content": "[unused3]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45861": {"content": "[unused4]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45866": {"content": "[unused5]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45874": {"content": "[unused6]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45883": {"content": "[unused7]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45884": {"content": "[unused8]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45887": {"content": "[unused9]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45892": {"content": "[unused10]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45920": {"content": "[unused11]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45932": {"content": "[unused12]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45938": {"content": "[unused13]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45953": {"content": "[unused14]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45968": {"content": "[unused15]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45974": {"content": "[unused16]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45982": {"content": "[unused17]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45986": {"content": "[unused18]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "46005": {"content": "[unused19]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "46007": {"content": "[unused20]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "46014": {"content": "[unused21]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, 
"special": true}, "46017": {"content": "[unused22]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "46028": {"content": "[unused23]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "46032": {"content": "[unused24]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "46081": {"content": "[unused25]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "46086": {"content": "[unused26]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "46101": {"content": "[unused27]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "46183": {"content": "[unused28]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "46230": {"content": "[unused29]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "46245": {"content": "[unused30]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "46257": {"content": "[unused31]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "144208": {"content": "[unused32]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "144209": {"content": "[unused33]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}}, "auto_map": {"AutoTokenizer": ["tokenization_openpangu.PanguTokenizer", null]}, "bos_token": "<s>", "clean_up_tokenization_spaces": false, "eos_token": "[unused10]", "legacy": true, "model_max_length": 1000000000000000019884624838656, "pad_token": "<unk>", "sp_model_kwargs": {}, "spaces_between_special_tokens": false, "tokenizer_class": "PanguTokenizer", "unk_token": "<unk>", "use_default_system_prompt": false, "chat_template": "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '[unused9]系统:[unused10]' }}{% endif %}{% if message['role'] == 'system' %}{{ '[unused9]系统:' + message['content'] + '[unused10]' }}{% endif %}{% if message['role'] == 'assistant' %}{{'[unused9]助手:' + message['content'] + '[unused10]'}}{% endif %}{% if message['role'] == 'tool' %}{{'[unused9]工具:' + message['content'] + '[unused10]'}}{% endif %}{% if message['role'] == 'function' %}{{'[unused9]方法:' + message['content'] + '[unused10]'}}{% endif %}{% if message['role'] == 'user' %}{{'[unused9]用户:' + message['content'] + '[unused10]'}}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '[unused9]助手:' }}{% endif %}"}