Upload folder using huggingface_hub
Browse files- LICENSE +34 -0
- Open Source Software Notice +218 -0
- README.md +127 -6
- README_EN.md +124 -0
- checklist.chk +36 -0
- config.json +31 -0
- configuration_openpangu_dense.py +56 -0
- generation_config.json +11 -0
- inference/generate.py +57 -0
- inference/vllm_ascend/_build_info.py +3 -0
- inference/vllm_ascend/attention/attention.py +1220 -0
- inference/vllm_ascend/attention/mla_v1.py +1224 -0
- inference/vllm_ascend/entrypoints/openai/reasoning_parsers/__init__.py +6 -0
- inference/vllm_ascend/entrypoints/openai/reasoning_parsers/pangu_reasoning_parser.py +171 -0
- inference/vllm_ascend/entrypoints/openai/tool_parsers/__init__.py +6 -0
- inference/vllm_ascend/entrypoints/openai/tool_parsers/pangu_tool_parser.py +300 -0
- inference/vllm_ascend/envs.py +153 -0
- inference/vllm_ascend/models/__init__.py +68 -0
- inference/vllm_ascend/models/open_pangu.py +1127 -0
- inference/vllm_ascend/ops/fused_moe.py +1530 -0
- inference/vllm_ascend/patch/worker/patch_common/__init__.py +27 -0
- inference/vllm_ascend/patch/worker/patch_common/patch_config.py +97 -0
- inference/vllm_ascend/patch/worker/patch_common/patch_parsers.py +26 -0
- inference/vllm_ascend/patch/worker/patch_common/patch_sampler.py +159 -0
- inference/vllm_ascend/quantization/w8a8.py +757 -0
- inference/vllm_ascend/quantization/w8a8_dynamic.py +831 -0
- inference/vllm_ascend/tokenizer_config.json +1 -0
- inference/vllm_ascend/utils.py +563 -0
- inference/vllm_ascend/worker/model_runner_v1.py +0 -0
- inference/vllm_ascend/worker/npu_input_batch.py +796 -0
- inference/vllm_ascend_for_openpangu_embedded_7b.md +129 -0
- inference/vllm_ascend_for_openpangu_embedded_7b.zh.md +129 -0
- model-00001-of-00004.safetensors +3 -0
- model-00002-of-00004.safetensors +3 -0
- model-00003-of-00004.safetensors +3 -0
- model-00004-of-00004.safetensors +3 -0
- model.safetensors.index.json +486 -0
- modeling_openpangu_dense.py +585 -0
- modular_openpangu_dense.py +149 -0
- special_tokens_map.json +30 -0
- tokenization_openpangu.py +273 -0
- tokenizer.model +3 -0
- tokenizer_config.json +1 -0
LICENSE
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
OPENPANGU MODEL LICENSE AGREEMENT VERSION 1.0
|
| 2 |
+
|
| 3 |
+
This OPENPANGU MODEL LICENSE AGREEMENT VERSION 1.0 (the "Agreement") is a legal agreement between You and Huawei Technologies Co., Ltd. ("Huawei", "We" or "Us"), and it governs Your reproducing, use, modification, and distribution of openPangu as made available by Huawei under this Agreement.
|
| 4 |
+
|
| 5 |
+
By using, reproducing, modifying, distributing, performing or displaying any portion or element of openPangu, or otherwise accepting the terms of this Agreement, You agree to be bound by this Agreement.
|
| 6 |
+
|
| 7 |
+
1. Definitions.
|
| 8 |
+
1.1. “openPangu” or “Model” means openPangu large language models and software, including trained model weights, parameters (including optimizer states), accompanying source code and scripts released under this Agreement.
|
| 9 |
+
1.2. “Derivative Model” means all (1) modifications to the Model, (2) works based on the Model, and (3) any other derivative works of the Model. For clarity, information or content results from operating or otherwise using the Model is not a Derivative Model.
|
| 10 |
+
1.3. “You” or “Your” means an individual or Legal Entity exercising permissions granted by this Agreement and/or using the Model for any purpose.
|
| 11 |
+
1.4. “Third Party” or “Third Parties” means individuals or legal entities that are not under common control with Us or You.
|
| 12 |
+
|
| 13 |
+
2. License Grant. Subject to Your full compliance with the terms and conditions of this Agreement, We hereby grant to You a perpetual, worldwide, non-exclusive, non-transferable, no-charge, royalty-free license (except as stated in Section 3) to use, reproduce, modify, and distribute the Model.
|
| 14 |
+
|
| 15 |
+
3. Conditions for License Grant. You represent and warrant that You will not, access, download, install, run, deploy, integrate, modify, or otherwise use the Model, directly or indirectly, within the European Union.
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
4. Redistribution.
|
| 19 |
+
4.1. If You distribute the Model or Derivative Model, You shall retain in Your distribution (1) a copy of this agreement, and (2) all copyright notices and other notices of origin included in the Model that are applicable to Your distribution.
|
| 20 |
+
4.2. Further, if You distribute or make available to Third Parties a product or service (including another AI model) based on the Model, You are required to (1) display the acknowledgement “Powered by openPangu” and (2) include a trademark notice “openPangu is a trademark of Huawei Technologies Co., Ltd.” on related webpages, user manuals, product documentations or other advertising materials mentioning features of the Model.
|
| 21 |
+
4.3. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for Derivative Model made by You as a whole, provided Your use, reproduction, and distribution of the Model otherwise complies with the terms and conditions of this Agreement.
|
| 22 |
+
|
| 23 |
+
5. Ownership. We do not claim ownership to any information or content generated using the Model or Derivative Model that are made by You. You are solely responsible for evaluating the accuracy and appropriateness of such information or content for Your use case.
|
| 24 |
+
|
| 25 |
+
6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of Huawei, except as required for complying with Section 4.2.
|
| 26 |
+
|
| 27 |
+
7. Indemnity. You will indemnify and hold harmless Huawei from and against any claim by any third party arising out of or related to Your use or distribution of the Model or Derivative Model made by You (e.g. a violation against Section 3). For avoidance of doubt, “third party” in this clause include supervisory authorities.
|
| 28 |
+
|
| 29 |
+
8. THE MODEL IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE, NONINFRINGEMENT, ACCURACY, OR THE ABSENCE OF LATENT OR OTHER DEFECTS OR ERRORS, WHETHER OR NOT DISCOVERABLE, ALL TO THE GREATEST EXTENT PERMISSIBLE UNDER APPLICABLE LAW.
|
| 30 |
+
|
| 31 |
+
9. IN NO EVENT SHALL WE BE LIABLE TO YOU FOR ANY DAMAGES, INCLUDING, BUT NOT LIMITED TO ANY DIRECT, OR INDIRECT, SPECIAL OR CONSEQUENTIAL DAMAGES ARISING FROM YOUR USE OR INABILITY TO USE THE MODEL, IN WHOLE OR IN PART, NO MATTER HOW IT’S CAUSED OR THE LEGAL THEORY IT IS BASED ON, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
END OF THE TERMS AND CONDITIONS
|
Open Source Software Notice
ADDED
|
@@ -0,0 +1,218 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
OPEN SOURCE SOFTWARE NOTICE
|
| 2 |
+
|
| 3 |
+
Please note we provide an open source software notice along with this product and/or this product firmware (in the following just “this product”). The open source software licenses are granted by the respective right holders. And the open source licenses prevail all other license information with regard to the respective open source software contained in the product, including but not limited to End User Software Licensing Agreement. This notice is provided on behalf of Huawei Technologies Co. Ltd. and any of its local subsidiaries which may have provided this product to you in your local country.
|
| 4 |
+
|
| 5 |
+
Warranty Disclaimer
|
| 6 |
+
THE OPEN SOURCE SOFTWARE IN THIS PRODUCT IS DISTRIBUTED IN THE HOPE THAT IT WILL BE USEFUL, BUT WITHOUT ANY WARRANTY, WITHOUT EVEN THE IMPLIED WARRANTY OF MERCHANTABILITY OR FITNESS FOR A PARTICULAR PURPOSE. SEE THE APPLICABLE LICENSES FOR MORE DETAILS.
|
| 7 |
+
|
| 8 |
+
Copyright Notice and License Texts
|
| 9 |
+
|
| 10 |
+
Software: transformers 4.53.2
|
| 11 |
+
Copyright notice:
|
| 12 |
+
Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
| 13 |
+
|
| 14 |
+
License Text:
|
| 15 |
+
----------------------------------------
|
| 16 |
+
|
| 17 |
+
Apache License
|
| 18 |
+
Version 2.0, January 2004
|
| 19 |
+
http://www.apache.org/licenses/
|
| 20 |
+
|
| 21 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 22 |
+
|
| 23 |
+
1. Definitions.
|
| 24 |
+
|
| 25 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 26 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 27 |
+
|
| 28 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 29 |
+
the copyright owner that is granting the License.
|
| 30 |
+
|
| 31 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 32 |
+
other entities that control, are controlled by, or are under common
|
| 33 |
+
control with that entity. For the purposes of this definition,
|
| 34 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 35 |
+
direction or management of such entity, whether by contract or
|
| 36 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 37 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 38 |
+
|
| 39 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 40 |
+
exercising permissions granted by this License.
|
| 41 |
+
|
| 42 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 43 |
+
including but not limited to software source code, documentation
|
| 44 |
+
source, and configuration files.
|
| 45 |
+
|
| 46 |
+
"Object" form shall mean any form resulting from mechanical
|
| 47 |
+
transformation or translation of a Source form, including but
|
| 48 |
+
not limited to compiled object code, generated documentation,
|
| 49 |
+
and conversions to other media types.
|
| 50 |
+
|
| 51 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 52 |
+
Object form, made available under the License, as indicated by a
|
| 53 |
+
copyright notice that is included in or attached to the work
|
| 54 |
+
(an example is provided in the Appendix below).
|
| 55 |
+
|
| 56 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 57 |
+
form, that is based on (or derived from) the Work and for which the
|
| 58 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 59 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 60 |
+
of this License, Derivative Works shall not include works that remain
|
| 61 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 62 |
+
the Work and Derivative Works thereof.
|
| 63 |
+
|
| 64 |
+
"Contribution" shall mean any work of authorship, including
|
| 65 |
+
the original version of the Work and any modifications or additions
|
| 66 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 67 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
| 68 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 69 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 70 |
+
means any form of electronic, verbal, or written communication sent
|
| 71 |
+
to the Licensor or its representatives, including but not limited to
|
| 72 |
+
communication on electronic mailing lists, source code control systems,
|
| 73 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 74 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 75 |
+
excluding communication that is conspicuously marked or otherwise
|
| 76 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 77 |
+
|
| 78 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 79 |
+
on behalf of whom a Contribution has been received by Licensor and
|
| 80 |
+
subsequently incorporated within the Work.
|
| 81 |
+
|
| 82 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 83 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 84 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 85 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 86 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 87 |
+
Work and such Derivative Works in Source or Object form.
|
| 88 |
+
|
| 89 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 90 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 91 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 92 |
+
(except as stated in this section) patent license to make, have made,
|
| 93 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 94 |
+
where such license applies only to those patent claims licensable
|
| 95 |
+
by such Contributor that are necessarily infringed by their
|
| 96 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 97 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 98 |
+
institute patent litigation against any entity (including a
|
| 99 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 100 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 101 |
+
or contributory patent infringement, then any patent licenses
|
| 102 |
+
granted to You under this License for that Work shall terminate
|
| 103 |
+
as of the date such litigation is filed.
|
| 104 |
+
|
| 105 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 106 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 107 |
+
modifications, and in Source or Object form, provided that You
|
| 108 |
+
meet the following conditions:
|
| 109 |
+
|
| 110 |
+
(a) You must give any other recipients of the Work or
|
| 111 |
+
Derivative Works a copy of this License; and
|
| 112 |
+
|
| 113 |
+
(b) You must cause any modified files to carry prominent notices
|
| 114 |
+
stating that You changed the files; and
|
| 115 |
+
|
| 116 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 117 |
+
that You distribute, all copyright, patent, trademark, and
|
| 118 |
+
attribution notices from the Source form of the Work,
|
| 119 |
+
excluding those notices that do not pertain to any part of
|
| 120 |
+
the Derivative Works; and
|
| 121 |
+
|
| 122 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 123 |
+
distribution, then any Derivative Works that You distribute must
|
| 124 |
+
include a readable copy of the attribution notices contained
|
| 125 |
+
within such NOTICE file, excluding those notices that do not
|
| 126 |
+
pertain to any part of the Derivative Works, in at least one
|
| 127 |
+
of the following places: within a NOTICE text file distributed
|
| 128 |
+
as part of the Derivative Works; within the Source form or
|
| 129 |
+
documentation, if provided along with the Derivative Works; or,
|
| 130 |
+
within a display generated by the Derivative Works, if and
|
| 131 |
+
wherever such third-party notices normally appear. The contents
|
| 132 |
+
of the NOTICE file are for informational purposes only and
|
| 133 |
+
do not modify the License. You may add Your own attribution
|
| 134 |
+
notices within Derivative Works that You distribute, alongside
|
| 135 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 136 |
+
that such additional attribution notices cannot be construed
|
| 137 |
+
as modifying the License.
|
| 138 |
+
|
| 139 |
+
You may add Your own copyright statement to Your modifications and
|
| 140 |
+
may provide additional or different license terms and conditions
|
| 141 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 142 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 143 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 144 |
+
the conditions stated in this License.
|
| 145 |
+
|
| 146 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 147 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 148 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 149 |
+
this License, without any additional terms or conditions.
|
| 150 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 151 |
+
the terms of any separate license agreement you may have executed
|
| 152 |
+
with Licensor regarding such Contributions.
|
| 153 |
+
|
| 154 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 155 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 156 |
+
except as required for reasonable and customary use in describing the
|
| 157 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 158 |
+
|
| 159 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 160 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 161 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 162 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 163 |
+
implied, including, without limitation, any warranties or conditions
|
| 164 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 165 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 166 |
+
appropriateness of using or redistributing the Work and assume any
|
| 167 |
+
risks associated with Your exercise of permissions under this License.
|
| 168 |
+
|
| 169 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 170 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 171 |
+
unless required by applicable law (such as deliberate and grossly
|
| 172 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 173 |
+
liable to You for damages, including any direct, indirect, special,
|
| 174 |
+
incidental, or consequential damages of any character arising as a
|
| 175 |
+
result of this License or out of the use or inability to use the
|
| 176 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 177 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 178 |
+
other commercial damages or losses), even if such Contributor
|
| 179 |
+
has been advised of the possibility of such damages.
|
| 180 |
+
|
| 181 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 182 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 183 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 184 |
+
or other liability obligations and/or rights consistent with this
|
| 185 |
+
License. However, in accepting such obligations, You may act only
|
| 186 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 187 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 188 |
+
defend, and hold each Contributor harmless for any liability
|
| 189 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 190 |
+
of your accepting any such warranty or additional liability.
|
| 191 |
+
|
| 192 |
+
END OF TERMS AND CONDITIONS
|
| 193 |
+
|
| 194 |
+
APPENDIX: How to apply the Apache License to your work.
|
| 195 |
+
|
| 196 |
+
To apply the Apache License to your work, attach the following
|
| 197 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
| 198 |
+
replaced with your own identifying information. (Don't include
|
| 199 |
+
the brackets!) The text should be enclosed in the appropriate
|
| 200 |
+
comment syntax for the file format. We also recommend that a
|
| 201 |
+
file or class name and description of purpose be included on the
|
| 202 |
+
same "printed page" as the copyright notice for easier
|
| 203 |
+
identification within third-party archives.
|
| 204 |
+
|
| 205 |
+
Copyright [yyyy] [name of copyright owner]
|
| 206 |
+
|
| 207 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 208 |
+
you may not use this file except in compliance with the License.
|
| 209 |
+
You may obtain a copy of the License at
|
| 210 |
+
|
| 211 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 212 |
+
|
| 213 |
+
Unless required by applicable law or agreed to in writing, software
|
| 214 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 215 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 216 |
+
See the License for the specific language governing permissions and
|
| 217 |
+
limitations under the License.
|
| 218 |
+
|
README.md
CHANGED
|
@@ -1,6 +1,127 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 开源盘古 Embedded-7B-V1.1
|
| 2 |
+
|
| 3 |
+
中文 | [English](README_EN.md)
|
| 4 |
+
|
| 5 |
+
## 1. 简介
|
| 6 |
+
|
| 7 |
+
openPangu-Embedded-7B-V1.1 是基于昇腾 NPU 从零训练的高效大语言模型,参数量为 7B(不含词表Embedding)。openPangu-Embedded-7B-V1.1 训练了约 25T tokens,具备快慢思考融合与自适应切换能力。
|
| 8 |
+
|
| 9 |
+
## 2. 模型架构
|
| 10 |
+
|
| 11 |
+
| | openPangu-Embedded-7B-V1.1 |
|
| 12 |
+
| :---------------------------: | :----------------: |
|
| 13 |
+
| **Architecture** | Dense |
|
| 14 |
+
| **Parameters (Non-Embedding)** | 7B |
|
| 15 |
+
| **Number of Layers** | 34 |
|
| 16 |
+
| **Hidden Dimension** | 12800 |
|
| 17 |
+
| **Attention Mechanism** | GQA |
|
| 18 |
+
| **Number of Attention Heads** | 32 for Q,8 for KV |
|
| 19 |
+
| **Vocabulary Size** | 153k |
|
| 20 |
+
| **Context Length (Natively)** | 32k |
|
| 21 |
+
| **Pretraining Tokens** | 25T |
|
| 22 |
+
|
| 23 |
+
## 3. 测评结果
|
| 24 |
+
|
| 25 |
+
| 测评集 | 测评指标 | 慢思考v1.0 | 慢思考v1.1 | 自适应v1.1 |
|
| 26 |
+
| :------------: | :-----------------: | :--------: | :--------: | :--------: |
|
| 27 |
+
| **通用能力** | | | | |
|
| 28 |
+
| MMLU-Pro | Exact Match | 76.32 | 75.54 | 72.81 |
|
| 29 |
+
| CMMLU | Acc | 75.59 | 72.94 | 72.18 |
|
| 30 |
+
| ArenaHard_v0.1 | w/o style control | 85.80 | 88.00 | 84.60 |
|
| 31 |
+
| C-Eval | Acc | 83.05 | 84.92 | 83.33 |
|
| 32 |
+
| GPQA-Diamond | Avg@4 | 70.54 | 73.23 | 73.74 |
|
| 33 |
+
| **数学能力** | | | | |
|
| 34 |
+
| MATH-500 | Avg@1 | 95.00 | 97.00 | 96.00 |
|
| 35 |
+
| AIME24 | Avg@16 | 71.57 | 79.38 | 79.02 |
|
| 36 |
+
| AIME25 | Avg@16 | 58.24 | 70.00 | 70.21 |
|
| 37 |
+
| **代码能力** | | | | |
|
| 38 |
+
| LiveCodeBench | Avg@2 (08/24~01/25) | 54.04 | 58.27 | 58.27 |
|
| 39 |
+
| MBPP+ | Avg@2 | 76.06 | 76.46 | 75.66 |
|
| 40 |
+
|
| 41 |
+
**注:** 评测过程中system prompt 为空,且不添加任何额外的思维链(CoT)提示。评测采用 128k 的序列长度进行。
|
| 42 |
+
|
| 43 |
+
除精度外,我们还在部分数据集上统计了模型的输出长度,通过数据质量驱动的学习策略,自适应快慢思考可以在基本不影响精度地前提下,有效地在简单任务上自动切换部分输出为快思考,大幅缩短平均输出思维链长度(Length);在难任务通过保持慢思考能力,精度持平纯慢思考模型。
|
| 44 |
+
|
| 45 |
+
| 测评集 | 测评指标 | 慢思考v1.1 | 自适应v1.1 |
|
| 46 |
+
| :------------: | :-----------------: | :--------: | :--------: |
|
| 47 |
+
| **通用能力** | | | |
|
| 48 |
+
| CMMLU | Acc | 72.94 | 72.18 |
|
| 49 |
+
| | Length | 2574 | 1338 |
|
| 50 |
+
| C-Eval | Acc | 84.92 | 83.33 |
|
| 51 |
+
| | Length | 2484 | 1723 |
|
| 52 |
+
| **数学能力** | | | |
|
| 53 |
+
| AIME24 | Avg@16 | 79.38 | 79.02 |
|
| 54 |
+
| | Length | 48229 | 49656 |
|
| 55 |
+
| **代码能力** | | | |
|
| 56 |
+
| LiveCodeBench | Avg@2 (08/24~01/25) | 58.27 | 58.27 |
|
| 57 |
+
| | Length | 58140 | 59307 |
|
| 58 |
+
|
| 59 |
+
## 4. 部署和使用
|
| 60 |
+
|
| 61 |
+
### 4.1 环境准备
|
| 62 |
+
|
| 63 |
+
##### 硬件规格
|
| 64 |
+
|
| 65 |
+
Atlas 800T A2 (64GB),驱动与固件安装包获取请参照 [[Atlas 800T A2](https://www.hiascend.com/hardware/firmware-drivers/community?product=4&model=26&cann=8.2.RC1.alpha003&driver=Ascend+HDK+25.0.RC1)]。
|
| 66 |
+
|
| 67 |
+
##### 软件环境
|
| 68 |
+
|
| 69 |
+
- 操作系统:Linux(推荐 openEuler>=24.03)
|
| 70 |
+
- CANN==8.1.RC1,安装准备及流程请参照 [[CANN Install]](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/82RC1alpha002/softwareinst/instg/instg_0001.html?Mode=PmIns&OS=Ubuntu&Software=cannToolKit)
|
| 71 |
+
- python==3.10
|
| 72 |
+
- torch==2.1.0
|
| 73 |
+
- torch-npu==2.1.0.post12
|
| 74 |
+
- transformers==4.53.2
|
| 75 |
+
|
| 76 |
+
以上软件配套经过验证,理论可以支持更高版本,如有疑问,可以提交 issue。
|
| 77 |
+
|
| 78 |
+
### 4.2 权重完整性校验
|
| 79 |
+
|
| 80 |
+
请参考以下方法对下载内容进行完整性校验,hash 值存储在 checklist.chk 文件中。
|
| 81 |
+
|
| 82 |
+
```
|
| 83 |
+
#!/usr/bin/env bash
|
| 84 |
+
ARCH=$(uname -m)
|
| 85 |
+
MODEL_PATH="${TARGET_FOLDER}/${MODEL_FOLDER_PATH}"
|
| 86 |
+
cd "$MODEL_PATH" || exit 1
|
| 87 |
+
if [ "$ARCH" = "arm64" ]; then
|
| 88 |
+
sha256sum checklist.chk
|
| 89 |
+
else
|
| 90 |
+
sha256sum -c checklist.chk
|
| 91 |
+
fi
|
| 92 |
+
```
|
| 93 |
+
|
| 94 |
+
### 4.3 推理样例
|
| 95 |
+
|
| 96 |
+
下述内容提供 openPangu-Embedded-7B-V1.1 在 `transformers` 框架上进行推理的一个简单示例:
|
| 97 |
+
|
| 98 |
+
> 运行前请修改 generate.py,添加模型路径。
|
| 99 |
+
|
| 100 |
+
```bash
|
| 101 |
+
cd inference
|
| 102 |
+
python generate.py
|
| 103 |
+
```
|
| 104 |
+
|
| 105 |
+
openPangu-Embedded-7B-V1.1 模型默认为慢思考模式,可以通过以下手段切换至快慢自适应切换/快思考模式:
|
| 106 |
+
|
| 107 |
+
- 在代码实例`generate.py`中,`auto_thinking_prompt`与`no_thinking_prompt`变量的定义展示了切换至快慢自适应或快思考模式的具体实现:通过在用户输入末尾添加`/auto_think`或`/no_think`标记,可将当前轮次切换至快慢自适应切换/快思考模式。
|
| 108 |
+
|
| 109 |
+
### 4.4 使用推理框架
|
| 110 |
+
|
| 111 |
+
vllm_ascend:参考[[vllm_ascend_for_openpangu_embedded_7b.zh]](inference/vllm_ascend_for_openpangu_embedded_7b.zh.md)
|
| 112 |
+
|
| 113 |
+
## 5. 模型许可证
|
| 114 |
+
|
| 115 |
+
除文件中对开源许可证另有约定外,openPangu-Embedded-7B-V1.1 模型根据 OPENPANGU MODEL LICENSE AGREEMENT VERSION 1.0 授权,旨在允许使用并促进人工智能技术的进一步发展。有关详细信息,请参阅模型存储库根目录中的 [LICENSE](LICENSE) 文件。
|
| 116 |
+
|
| 117 |
+
## 6. 免责声明
|
| 118 |
+
|
| 119 |
+
由于 openPangu-Embedded-7B-V1.1(“模型”)所依赖的技术固有的技术限制,以及人工智能生成的内容是由盘古自动生成的,华为无法对以下事项做出任何保证:
|
| 120 |
+
|
| 121 |
+
- 尽管该模型的输出由 AI 算法生成,但不能排除某些信息可能存在缺陷、不合理或引起不适的可能性,生成的内容不代表华为的态度或立场;
|
| 122 |
+
- 无法保证该模型 100% 准确、可靠、功能齐全、及时、安全、无错误、不间断、持续稳定或无任何故障;
|
| 123 |
+
- 该模型的输出内容不构成任何建议或决策,也不保证生成的内容的真实性、完整性、准确性、及时性、合法性、功能性或实用性。生成的内容不能替代医疗、法律等领域的专业人士回答您的问题。生成的内容仅供参考,不代表华为的任何态度、立场或观点。您需要根据实际情况做出独立判断,华为不承担任何责任。
|
| 124 |
+
|
| 125 |
+
## 7. 反馈
|
| 126 |
+
|
| 127 |
+
如果有任何意见和建议,请提交issue或联系 [email protected]。
|
README_EN.md
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# openPangu-Embedded-7B-V1.1
|
| 2 |
+
[中文](README.md) | English
|
| 3 |
+
|
| 4 |
+
## 1. Introduction
|
| 5 |
+
The openPangu-Embedded-7B-V1.1 is an efficient large language model trained from scratch based on the Ascend NPU. It contains 7 billion parameters (excluding the vocabulary embedding layer). The openPangu-Embedded-7B-V1.1 has been trained on approximately 25T tokens. The model is capable of integrating both fast and slow thinking, and can adaptively switch between two thinking mode based on assessed query complexity.
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
## 2. Model Architecture
|
| 9 |
+
|
| 10 |
+
| | openPangu-Embedded-7B-V1.1 |
|
| 11 |
+
| :---------------------------: | :----------------: |
|
| 12 |
+
| **Architecture** | Dense |
|
| 13 |
+
| **Parameters (Non-Embedding)** | 7B |
|
| 14 |
+
| **Number of Layers** | 34 |
|
| 15 |
+
| **Hidden Dimension** | 12800 |
|
| 16 |
+
| **Attention Mechanism** | GQA |
|
| 17 |
+
| **Number of Attention Heads** | 32 for Q,8 for KV |
|
| 18 |
+
| **Vocabulary Size** | 153k |
|
| 19 |
+
| **Context Length (Natively)** | 32k |
|
| 20 |
+
| **Pretraining Tokens** | 25T |
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
## 3. Results
|
| 24 |
+
|
| 25 |
+
| Benchmark | Metric | Slow-thinking v1.0 | Slow-thinking v1.1 | Adaptive-switching v1.1 |
|
| 26 |
+
| :---: | :---: | :---: | :---: | :---: |
|
| 27 |
+
| **General** | | | | |
|
| 28 |
+
| MMLU-Pro | Exact Match | 76.32 | 75.54 | 72.81 |
|
| 29 |
+
| CMMLU | Acc | 75.59 | 72.94 | 72.18 |
|
| 30 |
+
| ArenaHard_v0.1 | w/o style control | 85.80 | 88.00 | 84.60 |
|
| 31 |
+
| C-Eval | Acc | 83.05 | 84.92 | 83.33 |
|
| 32 |
+
| GPQA-Diamond | Avg@4 | 70.54 | 73.23 | 73.74 |
|
| 33 |
+
| **Math** | | | | |
|
| 34 |
+
| MATH-500 | Avg@1 | 95.00 | 97.00 | 96.00 |
|
| 35 |
+
| AIME24 | Avg@16 | 71.57 | 79.38 | 79.02 |
|
| 36 |
+
| AIME25 | Avg@16 | 58.24 | 70.00 | 70.21 |
|
| 37 |
+
| **Coding** | | | | |
|
| 38 |
+
| LiveCodeBench | Avg@2 (08/24~01/25) | 54.04 | 58.27 | 58.27 |
|
| 39 |
+
| MBPP+ | Avg@2 | 76.06 | 76.46 | 75.66 |
|
| 40 |
+
|
| 41 |
+
**Note:** The system prompt is left empty, and no additional Chain-of-Thought (CoT) prompts are introduced during the evaluation. All evaluations are performed using a sequence length of 128k tokens.
|
| 42 |
+
|
| 43 |
+
In addition to accuracy, we also analyzed the model's output length on some datasets. Through data quality-driven learning strategy, adaptive-switching mode can effectively automatically switch some outputs to fast thinking on simple tasks without significantly affecting accuracy, significantly shortening the average Chain-of-Thought length. On difficult tasks, by maintaining slow thinking capabilities, the accuracy is comparable to that of a pure slow-thinking model.
|
| 44 |
+
|
| 45 |
+
| Benchmark | Metric | Slow-thinking v1.1 | Adaptive-switching v1.1 |
|
| 46 |
+
| :------------: | :-----------------: | :--------: | :--------: |
|
| 47 |
+
| **General** | | | |
|
| 48 |
+
| CMMLU | Acc | 72.94 | 72.18 |
|
| 49 |
+
| | Length | 2574 | 1338 |
|
| 50 |
+
| C-Eval | Acc | 84.92 | 83.33 |
|
| 51 |
+
| | Length | 2484 | 1723 |
|
| 52 |
+
| **Math** | | | |
|
| 53 |
+
| AIME24 | Avg@16 | 79.38 | 79.02 |
|
| 54 |
+
| | Length | 48229 | 49656 |
|
| 55 |
+
| **Coding** | | | |
|
| 56 |
+
| LiveCodeBench | Avg@2 (08/24~01/25) | 58.27 | 58.27 |
|
| 57 |
+
| | Length | 58140 | 59307 |
|
| 58 |
+
|
| 59 |
+
## 4. Deployment
|
| 60 |
+
|
| 61 |
+
### 4.1 Environment
|
| 62 |
+
|
| 63 |
+
##### Hardware Requirements
|
| 64 |
+
|
| 65 |
+
Atlas 800T A2 (64GB), please refer to [[Atlas 800T A2](https://www.hiascend.com/hardware/firmware-drivers/community?product=4&model=26&cann=8.2.RC1.alpha003&driver=Ascend+HDK+25.0.RC1)] for obtaining the driver and firmware installation packages.
|
| 66 |
+
|
| 67 |
+
#### System Requirements & Dependencies
|
| 68 |
+
|
| 69 |
+
- System: Linux (OpenEuler ≥ 24.03 recommended)
|
| 70 |
+
- CANN==8.1.RC1: [[CANN Install]](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/82RC1alpha002/softwareinst/instg/instg_0001.html?Mode=PmIns&OS=Ubuntu&Software=cannToolKit)
|
| 71 |
+
- python==3.10
|
| 72 |
+
- torch==2.1.0
|
| 73 |
+
- torch-npu==2.1.0.post12
|
| 74 |
+
- transformers==4.53.2
|
| 75 |
+
|
| 76 |
+
The above software environment has been verified, and theoretically supports newer versions. For any questions, please submit an issue.
|
| 77 |
+
|
| 78 |
+
### 4.2 Integrity Check
|
| 79 |
+
|
| 80 |
+
Please refer to the following methods to verify the integrity of the downloaded content. The hash values are stored in the `checklist.chk` file.
|
| 81 |
+
|
| 82 |
+
```
|
| 83 |
+
#!/usr/bin/env bash
|
| 84 |
+
ARCH=$(uname -m)
|
| 85 |
+
MODEL_PATH="${TARGET_FOLDER}/${MODEL_FOLDER_PATH}"
|
| 86 |
+
cd "$MODEL_PATH" || exit 1
|
| 87 |
+
if [ "$ARCH" = "arm64" ]; then
|
| 88 |
+
sha256sum checklist.chk
|
| 89 |
+
else
|
| 90 |
+
sha256sum -c checklist.chk
|
| 91 |
+
fi
|
| 92 |
+
```
|
| 93 |
+
|
| 94 |
+
### 4.3 Inference Examples
|
| 95 |
+
|
| 96 |
+
The following provides a simple inference example of openPangu-Embedded-7B-V1.1 based on the `transformers` framework:
|
| 97 |
+
>Please modify generate.py and add the model path before running.
|
| 98 |
+
```bash
|
| 99 |
+
cd inference
|
| 100 |
+
python generate.py
|
| 101 |
+
```
|
| 102 |
+
|
| 103 |
+
The openPangu-Embedded-7B-V1.1 model is in slow thinking mode by default, and can be switched to adaptive/fast thinking mode by the following means:
|
| 104 |
+
- In the code example `generate.py`, the definition of the `auto_thinking_prompt` and `no_thinking_prompt` variables demonstrates the specific implementation for switching to adaptive/fast thinking mode: by appending the `/auto_think` or `/no_think` tag at the end of user input, the current turn can be switched to fast thinking mode. In this mode, `thinking_content` will be an empty value.
|
| 105 |
+
|
| 106 |
+
### 4.4 Using Inference Framework
|
| 107 |
+
vllm_ascend:[[vllm_ascend_for_openpangu_embedded_7b]](inference/vllm_ascend_for_openpangu_embedded_7b.md)
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
## 5. Model License
|
| 113 |
+
|
| 114 |
+
Unless otherwise noted, openPangu-Embedded-7B-V1.1 model is licensed under the terms and conditions of OPENPANGU MODEL LICENSE AGREEMENT VERSION 1.0, which is intended to be used permissively and enable the further development of artificial intelligence technologies. Please refer to the [LICENSE](LICENSE) file located in the root directory of the model repository for details.
|
| 115 |
+
|
| 116 |
+
## 6. Disclaimer
|
| 117 |
+
Due to the technical limitations inherent in the technology on which the openPangu-Embedded-7B-V1.1 (“Model”) relies and the fact that the artificial intelligence generated content is automatically produced by Model, Huawei cannot make any guarantees regarding the following matters:
|
| 118 |
+
- The output of this Model is automatically generated via AI algorithms, it does not rule out the possibility that some of the information may be flawed, unreasonable, or cause discomfort, and the generated content does not represent Huawei's attitude or standpoint;
|
| 119 |
+
- There is no guarantee that this Model is 100% accurate, reliable, functional, timely, secure and safety, error-free, uninterrupted, continuously stable, or free of any faults;
|
| 120 |
+
- The output of this Model does not constitute any advices or decisions for you, and it does not guarantee the authenticity, completeness, accuracy, timeliness, legality, functionality, or practicality of the generated content. The generated content cannot replace professionals in medical, legal, and other fields in answering your questions. The generated content is for your reference only and does not represent any attitude, standpoint, or position of Huawei. You need to make independent judgments based on your actual situation, and Huawei does not assume any responsibilities.
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
## 7. Contact Us
|
| 124 |
+
If you have any comments or suggestions, please submit an issue or contact [email protected].
|
checklist.chk
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
8e300541888d68a6148d18720fa6b2df48f71e2e1606026d55b30af6dcb2ba37 *./config.json
|
| 2 |
+
7694a0e7b59d7ec2eeebc2fd058f02fe4dc4464b27f82839fc9f425a88555a3a *./configuration_openpangu_dense.py
|
| 3 |
+
5cbfc09f10ae85f0e9bebc1281541dcc7107d86e34282839277782cbb146117d *./generation_config.json
|
| 4 |
+
e2f59bb91ad4e826fe087090393f438e0b6d41b6d5fe2991b390938030877504 *./inference/generate.py
|
| 5 |
+
ba6d7edcf1cf464d6fd787b12a9bda2a16fea0ac0d5d1e54136baec503d6e696 *./inference/vllm_ascend/attention/attention.py
|
| 6 |
+
2254aeca0be7b8922318e10c4a950f39afb30ba5fe3b46564a58671b237ac612 *./inference/vllm_ascend/attention/mla_v1.py
|
| 7 |
+
f9577c29bc4dc19a4cc41ccfcca17065402c9dd92221bef987c74808b23ed124 *./inference/vllm_ascend/entrypoints/openai/reasoning_parsers/pangu_reasoning_parser.py
|
| 8 |
+
9070682b058a79d2b2874ba5e07ce72beff6efb870f75cdac30cdcf6ba8fadc7 *./inference/vllm_ascend/entrypoints/openai/reasoning_parsers/__init__.py
|
| 9 |
+
91eab52cdc19603b7b705b302e25345d849e18fa66875261a1135d5382392123 *./inference/vllm_ascend/entrypoints/openai/tool_parsers/pangu_tool_parser.py
|
| 10 |
+
d07256c9014f911f81269e65aad6c0d7dd61d4e82f5cb399e05285d5c1bc8fa8 *./inference/vllm_ascend/entrypoints/openai/tool_parsers/__init__.py
|
| 11 |
+
52a968f10ebaebeb626248afd3e1d1b92f8fbfcaad19ebf05cafbc0bd03192cb *./inference/vllm_ascend/envs.py
|
| 12 |
+
b654e72ece161b3f04080e5c4d2476641c024939ac5308115fe1c65a6c5c7215 *./inference/vllm_ascend/models/open_pangu.py
|
| 13 |
+
e98aa2549f02017a35b07499216fe569e86400684087821820cf2d971c8fcbac *./inference/vllm_ascend/models/__init__.py
|
| 14 |
+
09273eb0e4696d2fb530881ba1ad9d331897dd81c0cd2f203ed3d0a475b4d39b *./inference/vllm_ascend/ops/fused_moe.py
|
| 15 |
+
8436ab93933989431160e55627b5dce5326f0fc5ec18263653902764ac8ace7b *./inference/vllm_ascend/patch/worker/patch_common/patch_config.py
|
| 16 |
+
8c59df8086bde0cd4df674403f83000921a34403651a8ff2b31de9b28768247a *./inference/vllm_ascend/patch/worker/patch_common/patch_parsers.py
|
| 17 |
+
e712ea36caf16c2a9dd21c5288f9d8e34c7fd2face444da44dca6db6c21f6c1b *./inference/vllm_ascend/patch/worker/patch_common/patch_sampler.py
|
| 18 |
+
63a6ba0d0b0158d4586219c979bf96d5fe87b74123af93f1c8d9ed842db96500 *./inference/vllm_ascend/patch/worker/patch_common/__init__.py
|
| 19 |
+
743bd96cfc109975a11fe5412c4b5de46f880501dcbbbdd10e11cbeb865fa4f2 *./inference/vllm_ascend/quantization/w8a8.py
|
| 20 |
+
6adfaa8a67ea9b561dec2e6a2392f6fc85ff376fb2030d8761c34c6c6d3f4cbf *./inference/vllm_ascend/quantization/w8a8_dynamic.py
|
| 21 |
+
4a95e3bf02d4c88bb614a48b37fe5c66c8c4c28e574cfc8a0e9791fc22704e7d *./inference/vllm_ascend/tokenizer_config.json
|
| 22 |
+
e2457c558f048876afe069d1226e7080ac214478f1a9ac28ae472928b81b5a06 *./inference/vllm_ascend/utils.py
|
| 23 |
+
62c6734d1283e3d649a6478d2004f46bfee2f7878af7f2849c979b124e355302 *./inference/vllm_ascend/worker/model_runner_v1.py
|
| 24 |
+
bc6505adabc0498ad07b49187858788c65c13dbf9446fd0bcf177a3e1b27220d *./inference/vllm_ascend/worker/npu_input_batch.py
|
| 25 |
+
4aaf57e6f6d2e139b3847b10ee59d738398ebfc4927a22325b27dad384874aec *./inference/vllm_ascend/_build_info.py
|
| 26 |
+
2aae256964b995a9416935d2aeb32cb9028f413a8e9bc2030dbe00e02b38ffec *./model-00001-of-00004.safetensors
|
| 27 |
+
6b0c4bcd085187fc8a1e1aeeec59afac21c876ba9b47e275e87b8ce56e058fec *./model-00002-of-00004.safetensors
|
| 28 |
+
511d39bdc8d619627aa701820c141afa1c3a1139450e2d2d6892004bc7b61721 *./model-00003-of-00004.safetensors
|
| 29 |
+
52ae62d754a63d2fcd8d1f40be1332dd161a2105ff7c43622328ef3800d4d891 *./model-00004-of-00004.safetensors
|
| 30 |
+
9bf645e8399be6d99000eae64bd172b5c457d6d2c44d2257b47eb97a3c41aeda *./model.safetensors.index.json
|
| 31 |
+
f15eaf322af8a0b0f16b26795eb68af836179413d3dbfa4dc44505db6c8b0d6f *./modeling_openpangu_dense.py
|
| 32 |
+
c1f2d87f855b994039c52b1e83c8a7f3d71a2d1eb52946c4a2e862e99f19d8b3 *./modular_openpangu_dense.py
|
| 33 |
+
b34cf5e7c7660889303b6e2d0a346c440356385c9db551d06f6615cf9fc600d1 *./special_tokens_map.json
|
| 34 |
+
c98602d6d1f61792a8bd3393972bbbe7409a205c0bb6299394c74287c26bd723 *./tokenization_openpangu.py
|
| 35 |
+
6b16f1558c0cd4ae6ef1a2c605713be0a514f50e1ce2d2c878979ce988c148ec *./tokenizer.model
|
| 36 |
+
acb88eac57f8765fedf34e9c10bc16d55c46f0902b0fea74fbf041daca2667ae *./tokenizer_config.json
|
config.json
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"architectures": [
|
| 3 |
+
"PanguEmbeddedForCausalLM"
|
| 4 |
+
],
|
| 5 |
+
"auto_map": {
|
| 6 |
+
"AutoConfig": "configuration_openpangu_dense.PanguEmbeddedConfig",
|
| 7 |
+
"AutoModel": "modeling_openpangu_dense.PanguEmbeddedModel",
|
| 8 |
+
"AutoModelForCausalLM": "modeling_openpangu_dense.PanguEmbeddedForCausalLM"
|
| 9 |
+
},
|
| 10 |
+
"bias": true,
|
| 11 |
+
"attention_dropout": 0.0,
|
| 12 |
+
"bos_token_id": 1,
|
| 13 |
+
"pad_token_id": 0,
|
| 14 |
+
"eos_token_id": 45892,
|
| 15 |
+
"hidden_act": "silu",
|
| 16 |
+
"hidden_size": 4096,
|
| 17 |
+
"initializer_range": 0.02,
|
| 18 |
+
"intermediate_size": 12800,
|
| 19 |
+
"max_position_embeddings": 32768,
|
| 20 |
+
"model_type": "PanguEmbedded",
|
| 21 |
+
"num_attention_heads": 32,
|
| 22 |
+
"num_hidden_layers": 34,
|
| 23 |
+
"num_key_value_heads": 8,
|
| 24 |
+
"rms_norm_eps": 1e-05,
|
| 25 |
+
"rope_theta": 16000000.0,
|
| 26 |
+
"tie_word_embeddings": false,
|
| 27 |
+
"torch_dtype": "bfloat16",
|
| 28 |
+
"transformers_version": "4.53.2",
|
| 29 |
+
"use_cache": true,
|
| 30 |
+
"vocab_size": 153376
|
| 31 |
+
}
|
configuration_openpangu_dense.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
|
| 3 |
+
|
| 4 |
+
from transformers.utils import logging
|
| 5 |
+
from transformers.configuration_utils import PretrainedConfig
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
logger = logging.get_logger(__name__)
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class PanguEmbeddedConfig(PretrainedConfig):
|
| 12 |
+
|
| 13 |
+
model_type = "PanguEmbedded"
|
| 14 |
+
_auto_class = "AutoConfig"
|
| 15 |
+
|
| 16 |
+
def __init__(
|
| 17 |
+
self,
|
| 18 |
+
vocab_size=153376,
|
| 19 |
+
hidden_size=4096,
|
| 20 |
+
intermediate_size=12800,
|
| 21 |
+
num_hidden_layers=34,
|
| 22 |
+
num_attention_heads=32,
|
| 23 |
+
num_key_value_heads=8,
|
| 24 |
+
hidden_act="silu",
|
| 25 |
+
max_position_embeddings=32768,
|
| 26 |
+
initializer_range=0.02,
|
| 27 |
+
rms_norm_eps=1e-5,
|
| 28 |
+
use_cache=True,
|
| 29 |
+
pad_token_id=0,
|
| 30 |
+
bos_token_id=1,
|
| 31 |
+
eos_token_id=45892,
|
| 32 |
+
tie_word_embeddings=False,
|
| 33 |
+
rope_theta=16000000.0,
|
| 34 |
+
bias=True,
|
| 35 |
+
**kwargs,
|
| 36 |
+
):
|
| 37 |
+
self.vocab_size = vocab_size
|
| 38 |
+
self.max_position_embeddings = max_position_embeddings
|
| 39 |
+
self.hidden_size = hidden_size
|
| 40 |
+
self.intermediate_size = intermediate_size
|
| 41 |
+
self.num_hidden_layers = num_hidden_layers
|
| 42 |
+
self.num_attention_heads = num_attention_heads
|
| 43 |
+
self.num_key_value_heads = num_key_value_heads
|
| 44 |
+
self.hidden_act = hidden_act
|
| 45 |
+
self.initializer_range = initializer_range
|
| 46 |
+
self.rms_norm_eps = rms_norm_eps
|
| 47 |
+
self.use_cache = use_cache
|
| 48 |
+
self.rope_theta = rope_theta
|
| 49 |
+
self.bias = bias
|
| 50 |
+
super().__init__(
|
| 51 |
+
pad_token_id=pad_token_id,
|
| 52 |
+
bos_token_id=bos_token_id,
|
| 53 |
+
eos_token_id=eos_token_id,
|
| 54 |
+
tie_word_embeddings=tie_word_embeddings,
|
| 55 |
+
**kwargs,
|
| 56 |
+
)
|
generation_config.json
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_from_model_config": true,
|
| 3 |
+
"do_sample": true,
|
| 4 |
+
"bos_token_id": 1,
|
| 5 |
+
"pad_token_id": 0,
|
| 6 |
+
"eos_token_id": 45892,
|
| 7 |
+
"temperature": 1.0,
|
| 8 |
+
"top_k": 0,
|
| 9 |
+
"top_p": 0.8,
|
| 10 |
+
"transformers_version": "4.53.2"
|
| 11 |
+
}
|
inference/generate.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
|
| 3 |
+
|
| 4 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 5 |
+
|
| 6 |
+
model_local_path = "path_to_openPangu-Embedded-7B"
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
# load the tokenizer and the model
|
| 10 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
| 11 |
+
model_local_path,
|
| 12 |
+
use_fast=False,
|
| 13 |
+
trust_remote_code=True,
|
| 14 |
+
local_files_only=True
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 18 |
+
model_local_path,
|
| 19 |
+
trust_remote_code=True,
|
| 20 |
+
torch_dtype="auto",
|
| 21 |
+
device_map="npu",
|
| 22 |
+
local_files_only=True
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
# prepare the model input
|
| 26 |
+
sys_prompt = "你必须严格遵守法律法规和社会道德规范。" \
|
| 27 |
+
"生成任何内容时,都应避免涉及暴力、色情、恐怖主义、种族歧视、性别歧视等不当内容。" \
|
| 28 |
+
"一旦检测到输入或输出有此类倾向,应拒绝回答并发出警告。例如,如果输入内容包含暴力威胁或色情描述," \
|
| 29 |
+
"应返回错误信息:“您的输入包含不当内容,无法处理。”"
|
| 30 |
+
|
| 31 |
+
prompt = "Give me a short introduction to large language model."
|
| 32 |
+
no_thinking_prompt = prompt+" /no_think"
|
| 33 |
+
auto_thinking_prompt = prompt+" /auto_think"
|
| 34 |
+
messages = [
|
| 35 |
+
{"role": "system", "content": sys_prompt}, # define your system prompt here
|
| 36 |
+
{"role": "user", "content": prompt}
|
| 37 |
+
]
|
| 38 |
+
text = tokenizer.apply_chat_template(
|
| 39 |
+
messages,
|
| 40 |
+
tokenize=False,
|
| 41 |
+
add_generation_prompt=True
|
| 42 |
+
)
|
| 43 |
+
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
|
| 44 |
+
|
| 45 |
+
# conduct text completion
|
| 46 |
+
outputs = model.generate(**model_inputs, max_new_tokens=32768, eos_token_id=45892, return_dict_in_generate=True)
|
| 47 |
+
|
| 48 |
+
input_length = model_inputs.input_ids.shape[1]
|
| 49 |
+
generated_tokens = outputs.sequences[:, input_length:]
|
| 50 |
+
output_sent = tokenizer.decode(generated_tokens[0])
|
| 51 |
+
|
| 52 |
+
# parsing thinking content
|
| 53 |
+
thinking_content = output_sent.split("[unused17]")[0].split("[unused16]")[-1].strip()
|
| 54 |
+
content = output_sent.split("[unused17]")[-1].split("[unused10]")[0].strip()
|
| 55 |
+
|
| 56 |
+
print("\nthinking content:", thinking_content)
|
| 57 |
+
print("\ncontent:", content)
|
inference/vllm_ascend/_build_info.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Auto-generated file
|
| 2 |
+
__soc_version__ = 'ASCEND910B1'
|
| 3 |
+
__sleep_mode_enabled__ = True
|
inference/vllm_ascend/attention/attention.py
ADDED
|
@@ -0,0 +1,1220 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#
|
| 2 |
+
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
# This file is a part of the vllm-ascend project.
|
| 16 |
+
#
|
| 17 |
+
|
| 18 |
+
from dataclasses import dataclass
|
| 19 |
+
from typing import Any, Dict, List, Optional, Tuple, Type
|
| 20 |
+
|
| 21 |
+
import numpy as np
|
| 22 |
+
import torch
|
| 23 |
+
import torch_npu
|
| 24 |
+
import torchair._contrib.custom_torch_ops # type: ignore # noqa: F401
|
| 25 |
+
from torch.nn.functional import scaled_dot_product_attention
|
| 26 |
+
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
| 27 |
+
AttentionLayer,
|
| 28 |
+
AttentionMetadata, AttentionType,
|
| 29 |
+
MLAAttentionImpl)
|
| 30 |
+
from vllm.attention.backends.utils import (PAD_SLOT_ID, CommonAttentionState,
|
| 31 |
+
CommonMetadataBuilder,
|
| 32 |
+
compute_slot_mapping,
|
| 33 |
+
compute_slot_mapping_start_idx,
|
| 34 |
+
is_block_tables_empty)
|
| 35 |
+
from vllm.utils import async_tensor_h2d, make_tensor_with_pad
|
| 36 |
+
|
| 37 |
+
from vllm_ascend.ascend_config import get_ascend_config
|
| 38 |
+
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
|
| 39 |
+
from vllm_ascend.ops.cache import concat_and_cache_mla
|
| 40 |
+
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16,
|
| 41 |
+
enable_custom_op, is_310p, nd_to_nz_2d)
|
| 42 |
+
from vllm_ascend.worker.model_runner import (
|
| 43 |
+
ModelInputForNPUBuilder, ModelInputForNPUWithSamplingMetadata)
|
| 44 |
+
|
| 45 |
+
_ALLOWED_NUM_QUERIES_PER_KV = [32, 64, 128]
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class AscendAttentionBackend(AttentionBackend):
|
| 49 |
+
|
| 50 |
+
@staticmethod
|
| 51 |
+
def get_name() -> str:
|
| 52 |
+
return "ASCEND"
|
| 53 |
+
|
| 54 |
+
@staticmethod
|
| 55 |
+
def get_impl_cls() -> Type["AscendAttentionBackendImpl"]:
|
| 56 |
+
return AscendAttentionBackendImpl
|
| 57 |
+
|
| 58 |
+
@staticmethod
|
| 59 |
+
def get_metadata_cls() -> Type["AscendMetadata"]:
|
| 60 |
+
return AscendMetadata
|
| 61 |
+
|
| 62 |
+
@staticmethod
|
| 63 |
+
def get_state_cls() -> Type["CommonAttentionState"]:
|
| 64 |
+
return CommonAttentionState
|
| 65 |
+
|
| 66 |
+
@staticmethod
|
| 67 |
+
def get_kv_cache_shape(
|
| 68 |
+
num_blocks: int,
|
| 69 |
+
block_size: int,
|
| 70 |
+
num_kv_heads: int,
|
| 71 |
+
head_size: int,
|
| 72 |
+
) -> Tuple[int, ...]:
|
| 73 |
+
if is_310p():
|
| 74 |
+
return (2, num_blocks, num_kv_heads * head_size // 16, block_size,
|
| 75 |
+
16)
|
| 76 |
+
else:
|
| 77 |
+
return (2, num_blocks, block_size, num_kv_heads, head_size)
|
| 78 |
+
|
| 79 |
+
@staticmethod
|
| 80 |
+
def swap_blocks(
|
| 81 |
+
src_kv_cache: List[torch.Tensor],
|
| 82 |
+
dst_kv_cache: List[torch.Tensor],
|
| 83 |
+
src_to_dst: torch.Tensor,
|
| 84 |
+
) -> None:
|
| 85 |
+
src_key_cache, src_value_cache = src_kv_cache[0], src_kv_cache[1]
|
| 86 |
+
dst_key_cache, dst_value_cache = dst_kv_cache[0], dst_kv_cache[1]
|
| 87 |
+
src_indices = src_to_dst[:, 0]
|
| 88 |
+
dst_indices = src_to_dst[:, 1]
|
| 89 |
+
|
| 90 |
+
dst_key_cache[dst_indices] = src_key_cache[src_indices].to(
|
| 91 |
+
dst_key_cache.device)
|
| 92 |
+
dst_value_cache[dst_indices] = src_value_cache[src_indices].to(
|
| 93 |
+
dst_key_cache.device)
|
| 94 |
+
|
| 95 |
+
@staticmethod
|
| 96 |
+
def copy_blocks(
|
| 97 |
+
kv_caches: List[torch.Tensor],
|
| 98 |
+
src_to_dists: torch.Tensor,
|
| 99 |
+
) -> None:
|
| 100 |
+
src_indices = src_to_dists[:, 0]
|
| 101 |
+
dst_indices = src_to_dists[:, 1]
|
| 102 |
+
|
| 103 |
+
for kv_cache in kv_caches:
|
| 104 |
+
key_caches = kv_cache[0]
|
| 105 |
+
value_caches = kv_cache[1]
|
| 106 |
+
key_caches[dst_indices] = key_caches[src_indices]
|
| 107 |
+
value_caches[dst_indices] = value_caches[src_indices]
|
| 108 |
+
|
| 109 |
+
@staticmethod
|
| 110 |
+
def get_builder_cls() -> Type["AscendMetadataBuilder"]:
|
| 111 |
+
return AscendMetadataBuilder
|
| 112 |
+
|
| 113 |
+
@classmethod
|
| 114 |
+
def make_metadata_builder(cls, *args, **kwargs) -> "AscendMetadataBuilder":
|
| 115 |
+
return cls.get_builder_cls()(*args, **kwargs)
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
class AscendMLAAttentionBackend(AscendAttentionBackend):
|
| 119 |
+
|
| 120 |
+
@staticmethod
|
| 121 |
+
def get_impl_cls() -> Type["AscendMLAAttentionBackendImpl"]:
|
| 122 |
+
return AscendMLAAttentionBackendImpl
|
| 123 |
+
|
| 124 |
+
@staticmethod
|
| 125 |
+
def get_kv_cache_shape(
|
| 126 |
+
num_blocks: int,
|
| 127 |
+
block_size: int,
|
| 128 |
+
num_kv_heads: int,
|
| 129 |
+
head_size: int,
|
| 130 |
+
) -> Tuple[int, ...]:
|
| 131 |
+
return (num_blocks, block_size, num_kv_heads, head_size)
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
@dataclass
|
| 135 |
+
class AscendMetadata(AttentionMetadata):
|
| 136 |
+
"""Metadata for Ascendbackend.
|
| 137 |
+
* modified from XFormersbackend
|
| 138 |
+
NOTE: Any python object stored here is not updated when it is
|
| 139 |
+
cuda-graph replayed. If you have values that need to be changed
|
| 140 |
+
dynamically, it should be stored in tensor. The tensor has to be
|
| 141 |
+
updated from `CUDAGraphRunner.forward` API.
|
| 142 |
+
"""
|
| 143 |
+
|
| 144 |
+
# |---------- N-1 iteration --------|
|
| 145 |
+
# |---------------- N iteration ---------------------|
|
| 146 |
+
# |- tokenA -|......................|-- newTokens ---|
|
| 147 |
+
# |---------- context_len ----------|
|
| 148 |
+
# |-------------------- seq_len ----------------------|
|
| 149 |
+
# |-- query_len ---|
|
| 150 |
+
|
| 151 |
+
# FIXME: It is for flash attn.
|
| 152 |
+
# Maximum sequence length among prefill batch. 0 if there are decoding
|
| 153 |
+
# Avoid mypy error
|
| 154 |
+
# Total number of prefill requests.
|
| 155 |
+
num_prefills: int
|
| 156 |
+
# Number of prefill tokens.
|
| 157 |
+
num_prefill_tokens: int
|
| 158 |
+
# (num_tokens,). The indices of the token slots that input tokens will be
|
| 159 |
+
# stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
|
| 160 |
+
# is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
|
| 161 |
+
# in block 0, and 1st slot in block 1, respectively.
|
| 162 |
+
slot_mapping: torch.Tensor
|
| 163 |
+
|
| 164 |
+
# requests only.
|
| 165 |
+
max_prefill_seq_len: int
|
| 166 |
+
# Maximum sequence length among decode batch. 0 if there are prefill
|
| 167 |
+
# requests only.
|
| 168 |
+
max_decode_seq_len: int
|
| 169 |
+
|
| 170 |
+
chunked_prefill_enabled: bool
|
| 171 |
+
|
| 172 |
+
# (batch_size, max_blocks_per_seq).
|
| 173 |
+
# Block addresses per sequence. (Seq id -> list of physical block)
|
| 174 |
+
block_tables: Optional[torch.Tensor]
|
| 175 |
+
|
| 176 |
+
# seq_lens stored as a tensor.
|
| 177 |
+
seq_lens_tensor: Optional[torch.Tensor]
|
| 178 |
+
|
| 179 |
+
# (batch_size,). The sequence length per sequence. Sequence length means
|
| 180 |
+
# the computed tokens + new tokens None if it is a decoding.
|
| 181 |
+
seq_lens: Optional[List[int]] = None
|
| 182 |
+
|
| 183 |
+
# The query lengths of the input sequences
|
| 184 |
+
query_lens: Optional[List[int]] = None
|
| 185 |
+
|
| 186 |
+
# Maximum query length in the batch. None for decoding.
|
| 187 |
+
max_query_len: Optional[int] = None
|
| 188 |
+
|
| 189 |
+
# Self-attention prefill/decode metadata cache
|
| 190 |
+
_cached_prefill_metadata: Optional["AscendMetadata"] = None
|
| 191 |
+
_cached_decode_metadata: Optional["AscendMetadata"] = None
|
| 192 |
+
|
| 193 |
+
# Begin encoder attn & enc/dec cross-attn fields...
|
| 194 |
+
|
| 195 |
+
# Encoder sequence lengths representation
|
| 196 |
+
encoder_seq_lens: Optional[List[int]] = None
|
| 197 |
+
encoder_seq_lens_tensor: Optional[torch.Tensor] = None
|
| 198 |
+
|
| 199 |
+
# Maximum sequence length among encoder sequences
|
| 200 |
+
max_encoder_seq_len: Optional[int] = None
|
| 201 |
+
|
| 202 |
+
# Number of tokens input to encoder
|
| 203 |
+
num_encoder_tokens: Optional[int] = None
|
| 204 |
+
|
| 205 |
+
# Mask for normal situation
|
| 206 |
+
attn_mask: Optional[torch.Tensor] = None
|
| 207 |
+
|
| 208 |
+
# Mask for prefix caching
|
| 209 |
+
compress_mask: Optional[torch.Tensor] = None
|
| 210 |
+
|
| 211 |
+
# Mask for chunked prefill
|
| 212 |
+
chunk_mask: Optional[torch.Tensor] = None
|
| 213 |
+
|
| 214 |
+
# Cross-attention memory-mapping data structures: slot mapping
|
| 215 |
+
# and block tables
|
| 216 |
+
cross_slot_mapping: Optional[torch.Tensor] = None
|
| 217 |
+
cross_block_tables: Optional[torch.Tensor] = None
|
| 218 |
+
|
| 219 |
+
@property
|
| 220 |
+
def prefill_metadata(self) -> Optional["AscendMetadata"]:
|
| 221 |
+
if self.num_prefills == 0:
|
| 222 |
+
return None
|
| 223 |
+
|
| 224 |
+
if self._cached_prefill_metadata is not None:
|
| 225 |
+
# Recover cached prefill-phase attention
|
| 226 |
+
# metadata structure.
|
| 227 |
+
return self._cached_prefill_metadata
|
| 228 |
+
|
| 229 |
+
assert ((self.seq_lens is not None)
|
| 230 |
+
or (self.encoder_seq_lens is not None))
|
| 231 |
+
|
| 232 |
+
# Compute some attn_metadata fields which default to None.
|
| 233 |
+
slot_mapping = (None if self.slot_mapping is None else
|
| 234 |
+
self.slot_mapping[:self.num_prefill_tokens])
|
| 235 |
+
seq_lens = (None if self.seq_lens is None else
|
| 236 |
+
self.seq_lens[:self.num_prefills])
|
| 237 |
+
query_lens = (None if self.query_lens is None else
|
| 238 |
+
self.query_lens[:self.num_prefills])
|
| 239 |
+
block_tables = (None if self.block_tables is None else
|
| 240 |
+
self.block_tables[:self.num_prefills])
|
| 241 |
+
|
| 242 |
+
seq_lens_tensor = (None if self.seq_lens_tensor is None else
|
| 243 |
+
self.seq_lens_tensor[:self.num_prefills])
|
| 244 |
+
|
| 245 |
+
# Construct & cache prefill-phase attention metadata structure.
|
| 246 |
+
self._cached_prefill_metadata = AscendMetadata(
|
| 247 |
+
num_prefills=self.num_prefills,
|
| 248 |
+
num_prefill_tokens=self.num_prefill_tokens,
|
| 249 |
+
num_decode_tokens=0,
|
| 250 |
+
slot_mapping=slot_mapping,
|
| 251 |
+
seq_lens=seq_lens,
|
| 252 |
+
seq_lens_tensor=seq_lens_tensor,
|
| 253 |
+
query_lens=query_lens,
|
| 254 |
+
max_query_len=self.max_query_len,
|
| 255 |
+
max_prefill_seq_len=self.max_prefill_seq_len,
|
| 256 |
+
max_decode_seq_len=0,
|
| 257 |
+
chunked_prefill_enabled=self.chunked_prefill_enabled,
|
| 258 |
+
block_tables=block_tables,
|
| 259 |
+
# Begin encoder & cross attn fields below...
|
| 260 |
+
encoder_seq_lens=self.encoder_seq_lens,
|
| 261 |
+
encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
|
| 262 |
+
max_encoder_seq_len=self.max_encoder_seq_len,
|
| 263 |
+
multi_modal_placeholder_index_maps=self.
|
| 264 |
+
multi_modal_placeholder_index_maps,
|
| 265 |
+
cross_slot_mapping=self.cross_slot_mapping,
|
| 266 |
+
cross_block_tables=self.cross_block_tables,
|
| 267 |
+
enable_kv_scales_calculation=False)
|
| 268 |
+
return self._cached_prefill_metadata
|
| 269 |
+
|
| 270 |
+
@property
|
| 271 |
+
def decode_metadata(self) -> Optional["AscendMetadata"]:
|
| 272 |
+
if self.num_decode_tokens == 0:
|
| 273 |
+
return None
|
| 274 |
+
|
| 275 |
+
if self._cached_decode_metadata is not None:
|
| 276 |
+
# Recover cached decode-phase attention
|
| 277 |
+
# metadata structure.
|
| 278 |
+
return self._cached_decode_metadata
|
| 279 |
+
|
| 280 |
+
# Compute some attn_metadata fields which default to None.
|
| 281 |
+
slot_mapping = (None if self.slot_mapping is None else
|
| 282 |
+
self.slot_mapping[self.num_prefill_tokens:])
|
| 283 |
+
seq_lens = (None if self.seq_lens is None else
|
| 284 |
+
self.seq_lens[self.num_prefills:])
|
| 285 |
+
query_lens = (None if self.query_lens is None else
|
| 286 |
+
self.query_lens[self.num_prefills:])
|
| 287 |
+
block_tables = (None if self.block_tables is None else
|
| 288 |
+
self.block_tables[self.num_prefills:])
|
| 289 |
+
seq_lens_tensor = (None if self.seq_lens_tensor is None else
|
| 290 |
+
self.seq_lens_tensor[self.num_prefills:])
|
| 291 |
+
# Construct & cache decode-phase attention metadata structure.
|
| 292 |
+
self._cached_decode_metadata = AscendMetadata(
|
| 293 |
+
num_prefills=0,
|
| 294 |
+
num_prefill_tokens=0,
|
| 295 |
+
num_decode_tokens=self.num_decode_tokens,
|
| 296 |
+
slot_mapping=slot_mapping,
|
| 297 |
+
seq_lens=seq_lens,
|
| 298 |
+
seq_lens_tensor=seq_lens_tensor,
|
| 299 |
+
query_lens=query_lens,
|
| 300 |
+
max_query_len=self.max_query_len,
|
| 301 |
+
max_prefill_seq_len=0,
|
| 302 |
+
max_decode_seq_len=self.max_decode_seq_len,
|
| 303 |
+
chunked_prefill_enabled=self.chunked_prefill_enabled,
|
| 304 |
+
block_tables=block_tables,
|
| 305 |
+
# Begin encoder & cross attn fields below...
|
| 306 |
+
encoder_seq_lens=self.encoder_seq_lens,
|
| 307 |
+
encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
|
| 308 |
+
max_encoder_seq_len=self.max_encoder_seq_len,
|
| 309 |
+
multi_modal_placeholder_index_maps=self.
|
| 310 |
+
multi_modal_placeholder_index_maps,
|
| 311 |
+
cross_slot_mapping=self.cross_slot_mapping,
|
| 312 |
+
cross_block_tables=self.cross_block_tables,
|
| 313 |
+
enable_kv_scales_calculation=False)
|
| 314 |
+
return self._cached_decode_metadata
|
| 315 |
+
|
| 316 |
+
def advance_step(self,
|
| 317 |
+
model_input: "ModelInputForNPUWithSamplingMetadata",
|
| 318 |
+
sampled_token_ids: Optional[torch.Tensor],
|
| 319 |
+
block_size: int,
|
| 320 |
+
num_seqs: int,
|
| 321 |
+
num_queries: int,
|
| 322 |
+
turn_prefills_into_decodes: bool = False):
|
| 323 |
+
"""
|
| 324 |
+
Update metadata in-place to advance one decode step.
|
| 325 |
+
"""
|
| 326 |
+
# When using cudagraph, the num_seqs is padded to the next captured
|
| 327 |
+
# batch sized, but num_queries tracks the actual number of requests in
|
| 328 |
+
# the batch. For --enforce-eager mode, num_seqs == num_queries
|
| 329 |
+
if num_seqs != num_queries:
|
| 330 |
+
assert num_seqs > num_queries
|
| 331 |
+
|
| 332 |
+
if turn_prefills_into_decodes:
|
| 333 |
+
# When Mutli-Step is enabled with Chunked-Prefill, prefills and
|
| 334 |
+
# decodes are scheduled together. In the first step, all the
|
| 335 |
+
# prefills turn into decodes. This update reflects that
|
| 336 |
+
# conversion.
|
| 337 |
+
assert self.num_decode_tokens + self.num_prefills == num_seqs
|
| 338 |
+
self.num_decode_tokens += self.num_prefills
|
| 339 |
+
self.num_prefills = 0
|
| 340 |
+
self.num_prefill_tokens = 0
|
| 341 |
+
self.max_prefill_seq_len = 0
|
| 342 |
+
self.max_query_len = 1
|
| 343 |
+
|
| 344 |
+
self.slot_mapping = self.slot_mapping[:num_seqs]
|
| 345 |
+
else:
|
| 346 |
+
assert self.seq_lens is not None
|
| 347 |
+
assert self.max_decode_seq_len == max(self.seq_lens)
|
| 348 |
+
|
| 349 |
+
assert self.num_prefills == 0
|
| 350 |
+
assert self.num_prefill_tokens == 0
|
| 351 |
+
assert self.num_decode_tokens == num_seqs
|
| 352 |
+
assert self.slot_mapping.shape == (num_seqs, )
|
| 353 |
+
|
| 354 |
+
assert self.seq_lens is not None
|
| 355 |
+
assert len(self.seq_lens) == num_seqs
|
| 356 |
+
assert self.seq_lens_tensor is not None
|
| 357 |
+
assert self.seq_lens_tensor.shape == (num_seqs, )
|
| 358 |
+
assert self.max_query_len == 1
|
| 359 |
+
assert self.max_prefill_seq_len == 0
|
| 360 |
+
|
| 361 |
+
assert self.block_tables is not None
|
| 362 |
+
assert self.block_tables.shape[0] == num_seqs
|
| 363 |
+
|
| 364 |
+
# Update query lengths. Note that we update only queries and not seqs,
|
| 365 |
+
# since tensors may be padded due to captured cuda graph batch size
|
| 366 |
+
for i in range(num_queries):
|
| 367 |
+
self.seq_lens[i] += 1
|
| 368 |
+
self.max_decode_seq_len = max(self.seq_lens)
|
| 369 |
+
if enable_custom_op():
|
| 370 |
+
#advance a step on NPU for existing inputs for a multi-step runner if custom ops is enabled
|
| 371 |
+
torch.ops._C.advance_step_flashattn_ascendc(
|
| 372 |
+
num_seqs=num_seqs,
|
| 373 |
+
num_queries=num_queries,
|
| 374 |
+
block_size=block_size,
|
| 375 |
+
input_tokens=model_input.input_tokens,
|
| 376 |
+
sampled_token_ids=sampled_token_ids,
|
| 377 |
+
input_positions=model_input.input_positions,
|
| 378 |
+
seq_lens=self.seq_lens_tensor,
|
| 379 |
+
slot_mapping=self.slot_mapping,
|
| 380 |
+
block_tables=self.block_tables)
|
| 381 |
+
else:
|
| 382 |
+
# use traditional Pytorch method for updating these tensors.
|
| 383 |
+
# update input_tokens
|
| 384 |
+
sampled_token_ids_list = sampled_token_ids[:
|
| 385 |
+
num_queries].squeeze( # type: ignore
|
| 386 |
+
-1)
|
| 387 |
+
model_input.input_tokens[:
|
| 388 |
+
num_queries] = sampled_token_ids_list # type: ignore
|
| 389 |
+
|
| 390 |
+
# get seq_lens and input_positions
|
| 391 |
+
seq_lens = self.seq_lens_tensor[:num_queries]
|
| 392 |
+
next_seq_lens = seq_lens + 1
|
| 393 |
+
next_input_pos = next_seq_lens - 1
|
| 394 |
+
|
| 395 |
+
# update seq_lens and input_positions
|
| 396 |
+
self.seq_lens_tensor[:num_queries] = next_seq_lens
|
| 397 |
+
model_input.input_positions[:
|
| 398 |
+
num_queries] = next_input_pos # type: ignore
|
| 399 |
+
|
| 400 |
+
# 计算 block index 和 offset
|
| 401 |
+
block_idx = next_input_pos // block_size
|
| 402 |
+
block_offset = next_input_pos % block_size
|
| 403 |
+
|
| 404 |
+
current_block_table = self.block_tables.gather(
|
| 405 |
+
1, block_idx.unsqueeze(-1)).squeeze(-1)
|
| 406 |
+
slot_num = current_block_table * block_size + block_offset
|
| 407 |
+
|
| 408 |
+
# update slot_mapping
|
| 409 |
+
self.slot_mapping[:num_queries] = slot_num
|
| 410 |
+
|
| 411 |
+
|
| 412 |
+
class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]):
|
| 413 |
+
|
| 414 |
+
_attn_mask_builder = None # noqa
|
| 415 |
+
|
| 416 |
+
def __init__(self, input_builder: "ModelInputForNPUBuilder"):
|
| 417 |
+
self.input_builder = input_builder
|
| 418 |
+
self.runner = input_builder.runner
|
| 419 |
+
self.sliding_window = input_builder.sliding_window
|
| 420 |
+
self.block_size = input_builder.block_size
|
| 421 |
+
|
| 422 |
+
self.attn_mask = None
|
| 423 |
+
self.compress_mask = None
|
| 424 |
+
self.chunk_mask = None
|
| 425 |
+
if AscendMetadataBuilder._attn_mask_builder is None:
|
| 426 |
+
AscendMetadataBuilder._attn_mask_builder = AttentionMaskBuilder(
|
| 427 |
+
128, self.input_builder.runner.model_config.dtype)
|
| 428 |
+
|
| 429 |
+
def _add_seq_group(
|
| 430 |
+
self, inter_data: ModelInputForNPUBuilder.InterDataForSeqGroup,
|
| 431 |
+
chunked_prefill_enabled: bool):
|
| 432 |
+
"""Add a sequence group to the metadata. Specifically update/append
|
| 433 |
+
1. context length.
|
| 434 |
+
2. block table.
|
| 435 |
+
3. slot mapping.
|
| 436 |
+
"""
|
| 437 |
+
is_prompt = inter_data.is_prompt
|
| 438 |
+
block_tables = inter_data.block_tables
|
| 439 |
+
|
| 440 |
+
for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
|
| 441 |
+
curr_sliding_window_block) in zip(
|
| 442 |
+
inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
|
| 443 |
+
inter_data.orig_seq_lens, inter_data.seq_lens,
|
| 444 |
+
inter_data.query_lens, inter_data.context_lens,
|
| 445 |
+
inter_data.curr_sliding_window_blocks):
|
| 446 |
+
self.context_lens.append(context_len)
|
| 447 |
+
if is_prompt:
|
| 448 |
+
self.num_prefills += 1
|
| 449 |
+
self.num_prefill_tokens += token_len
|
| 450 |
+
self.prefill_seq_lens.append(seq_len)
|
| 451 |
+
else:
|
| 452 |
+
self.num_decode_tokens += query_len
|
| 453 |
+
self.curr_seq_lens.append(curr_seq_len)
|
| 454 |
+
|
| 455 |
+
# Compute block table.
|
| 456 |
+
# TODO(sang): Combine chunked prefill and prefix caching by
|
| 457 |
+
# only allowing multiple of block_size chunk size.
|
| 458 |
+
# NOTE: This only works for oooooooxxx style attention.
|
| 459 |
+
block_table: List[int] = []
|
| 460 |
+
prefix_cache_hit = any([
|
| 461 |
+
inter_data.prefix_cache_hit
|
| 462 |
+
for inter_data in self.input_builder.inter_data_list
|
| 463 |
+
])
|
| 464 |
+
if prefix_cache_hit:
|
| 465 |
+
# NOTE(woosuk): For flash-attn, the block table should
|
| 466 |
+
# include the entries for the incoming prefill tokens.
|
| 467 |
+
if block_tables is not None:
|
| 468 |
+
block_table = block_tables[seq_id]
|
| 469 |
+
elif ((chunked_prefill_enabled or not is_prompt)
|
| 470 |
+
and block_tables is not None):
|
| 471 |
+
if curr_sliding_window_block == 0:
|
| 472 |
+
block_table = block_tables[seq_id]
|
| 473 |
+
else:
|
| 474 |
+
block_table = block_tables[seq_id][
|
| 475 |
+
-curr_sliding_window_block:]
|
| 476 |
+
self.block_tables.append(block_table)
|
| 477 |
+
|
| 478 |
+
# Compute slot mapping.
|
| 479 |
+
is_profile_run = is_block_tables_empty(block_tables)
|
| 480 |
+
start_idx = compute_slot_mapping_start_idx(is_prompt, query_len,
|
| 481 |
+
context_len,
|
| 482 |
+
self.sliding_window)
|
| 483 |
+
compute_slot_mapping(
|
| 484 |
+
is_profile_run,
|
| 485 |
+
self.slot_mapping,
|
| 486 |
+
seq_id,
|
| 487 |
+
seq_len,
|
| 488 |
+
context_len,
|
| 489 |
+
start_idx,
|
| 490 |
+
self.block_size,
|
| 491 |
+
inter_data.block_tables,
|
| 492 |
+
)
|
| 493 |
+
|
| 494 |
+
def _get_graph_runner_block_tables(
|
| 495 |
+
self, num_seqs: int,
|
| 496 |
+
block_tables: List[List[int]]) -> torch.Tensor:
|
| 497 |
+
# The shape of graph_block_tables is
|
| 498 |
+
# [max batch size, max context len // block size].
|
| 499 |
+
|
| 500 |
+
max_batch_size, max_blocks = self.runner.graph_block_tables.shape
|
| 501 |
+
assert max_batch_size >= num_seqs
|
| 502 |
+
|
| 503 |
+
graph_block_tables = self.runner.graph_block_tables # [:num_seqs]
|
| 504 |
+
for i, block_table in enumerate(block_tables):
|
| 505 |
+
if block_table:
|
| 506 |
+
num_blocks = len(block_table)
|
| 507 |
+
if num_blocks <= max_blocks:
|
| 508 |
+
graph_block_tables[i, :num_blocks] = block_table
|
| 509 |
+
else:
|
| 510 |
+
graph_block_tables[
|
| 511 |
+
i, :max_blocks] = block_table[:max_blocks]
|
| 512 |
+
|
| 513 |
+
return torch.from_numpy(graph_block_tables).to(
|
| 514 |
+
device=self.runner.device, non_blocking=True)
|
| 515 |
+
|
| 516 |
+
def build(
|
| 517 |
+
self,
|
| 518 |
+
seq_lens: List[int],
|
| 519 |
+
query_lens: List[int],
|
| 520 |
+
graph_pad_size: int,
|
| 521 |
+
):
|
| 522 |
+
"""Build attention metadata with on-device tensors.
|
| 523 |
+
|
| 524 |
+
Args:
|
| 525 |
+
seq_lens: The maybe padded sequence lengths of the input sequences.
|
| 526 |
+
query_lens: The query lengths of the input sequences.
|
| 527 |
+
"""
|
| 528 |
+
for inter_data in self.input_builder.inter_data_list:
|
| 529 |
+
self._add_seq_group(inter_data,
|
| 530 |
+
self.input_builder.chunked_prefill_enabled)
|
| 531 |
+
|
| 532 |
+
device = self.runner.device
|
| 533 |
+
dtype = self.runner.model_config.dtype
|
| 534 |
+
use_npu_graph = graph_pad_size != -1
|
| 535 |
+
|
| 536 |
+
max_query_len = max(query_lens)
|
| 537 |
+
max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
|
| 538 |
+
max_decode_seq_len = max(self.curr_seq_lens, default=0)
|
| 539 |
+
max_seq_len = max(max_prefill_seq_len, max_decode_seq_len)
|
| 540 |
+
num_decode_tokens = self.num_decode_tokens
|
| 541 |
+
|
| 542 |
+
if self.num_prefills == 0 and use_npu_graph:
|
| 543 |
+
num_seqs = len(seq_lens)
|
| 544 |
+
self.slot_mapping.extend([PAD_SLOT_ID] * graph_pad_size)
|
| 545 |
+
self.block_tables.extend([[]] * graph_pad_size)
|
| 546 |
+
block_tables = self._get_graph_runner_block_tables(
|
| 547 |
+
num_seqs, self.block_tables)
|
| 548 |
+
else:
|
| 549 |
+
block_tables = make_tensor_with_pad(
|
| 550 |
+
self.block_tables,
|
| 551 |
+
pad=0,
|
| 552 |
+
dtype=torch.int32,
|
| 553 |
+
device=device,
|
| 554 |
+
)
|
| 555 |
+
|
| 556 |
+
if self.num_prefills > 0:
|
| 557 |
+
if block_tables is None or block_tables.numel() == 0:
|
| 558 |
+
# normal mask
|
| 559 |
+
self.attn_mask = AscendMetadataBuilder._attn_mask_builder.get_attn_mask( # type: ignore
|
| 560 |
+
max_prefill_seq_len, dtype, device)
|
| 561 |
+
if is_310p():
|
| 562 |
+
mask_nz = nd_to_nz_2d(self.attn_mask)
|
| 563 |
+
mask_nz = torch_npu.npu_format_cast(
|
| 564 |
+
mask_nz.contiguous(), ACL_FORMAT_FRACTAL_NZ)
|
| 565 |
+
self.attn_mask = mask_nz
|
| 566 |
+
elif self.num_decode_tokens == 0 and not self.input_builder.chunked_prefill_enabled:
|
| 567 |
+
# compress mask for prefix cache
|
| 568 |
+
self.compress_mask = AscendMetadataBuilder._attn_mask_builder.get_attn_mask( # type: ignore
|
| 569 |
+
128, dtype, device)
|
| 570 |
+
else:
|
| 571 |
+
# chunk_mask for chunk prefill
|
| 572 |
+
attn_mask = AscendMetadataBuilder._attn_mask_builder.get_attn_mask( # type: ignore
|
| 573 |
+
max_seq_len, dtype, device)
|
| 574 |
+
if attn_mask.numel() > 1 and attn_mask[0][1] > 0:
|
| 575 |
+
# Do not use in-place multiplication to avoid modifying `attn_mask_cache`!
|
| 576 |
+
attn_mask = attn_mask * -10000
|
| 577 |
+
chunk_mask_list = []
|
| 578 |
+
for i, seq_len in enumerate(seq_lens):
|
| 579 |
+
context_len = self.context_lens[i]
|
| 580 |
+
chunk_mask_list.append(attn_mask[context_len:seq_len])
|
| 581 |
+
self.chunk_mask = torch.cat(chunk_mask_list, 0)
|
| 582 |
+
else:
|
| 583 |
+
self.attn_mask = None
|
| 584 |
+
self.compress_mask = None
|
| 585 |
+
self.chunk_mask = None
|
| 586 |
+
|
| 587 |
+
assert max_query_len > 0, "query_lens: {}".format(query_lens)
|
| 588 |
+
|
| 589 |
+
assert device is not None
|
| 590 |
+
slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.int32,
|
| 591 |
+
device, self.runner.pin_memory)
|
| 592 |
+
seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
|
| 593 |
+
self.runner.pin_memory)
|
| 594 |
+
placeholder_index_maps = {
|
| 595 |
+
modality: placeholder_map.index_map()
|
| 596 |
+
for modality, placeholder_map in
|
| 597 |
+
self.multimodal_placeholder_maps.items()
|
| 598 |
+
}
|
| 599 |
+
|
| 600 |
+
return AscendMetadata(
|
| 601 |
+
num_prefills=self.num_prefills,
|
| 602 |
+
slot_mapping=slot_mapping_tensor,
|
| 603 |
+
num_prefill_tokens=self.num_prefill_tokens,
|
| 604 |
+
num_decode_tokens=num_decode_tokens,
|
| 605 |
+
seq_lens=seq_lens,
|
| 606 |
+
multi_modal_placeholder_index_maps=placeholder_index_maps,
|
| 607 |
+
enable_kv_scales_calculation=True,
|
| 608 |
+
seq_lens_tensor=seq_lens_tensor,
|
| 609 |
+
query_lens=query_lens,
|
| 610 |
+
max_query_len=max_query_len,
|
| 611 |
+
max_prefill_seq_len=max_prefill_seq_len,
|
| 612 |
+
max_decode_seq_len=max_decode_seq_len,
|
| 613 |
+
block_tables=block_tables,
|
| 614 |
+
attn_mask=self.attn_mask,
|
| 615 |
+
compress_mask=self.compress_mask,
|
| 616 |
+
chunk_mask=self.chunk_mask,
|
| 617 |
+
chunked_prefill_enabled=self.input_builder.chunked_prefill_enabled,
|
| 618 |
+
)
|
| 619 |
+
|
| 620 |
+
|
| 621 |
+
class AscendAttentionBackendImpl(AttentionImpl):
|
| 622 |
+
|
| 623 |
+
def __init__(
|
| 624 |
+
self,
|
| 625 |
+
num_heads: int,
|
| 626 |
+
head_size: int,
|
| 627 |
+
scale: float,
|
| 628 |
+
num_kv_heads: int,
|
| 629 |
+
alibi_slopes: Optional[List[float]],
|
| 630 |
+
sliding_window: Optional[int],
|
| 631 |
+
kv_cache_dtype: str,
|
| 632 |
+
blocksparse_params: Optional[Dict[str, Any]] = None,
|
| 633 |
+
logits_soft_cap: Optional[float] = None,
|
| 634 |
+
attn_type: str = AttentionType.DECODER,
|
| 635 |
+
kv_sharing_target_layer_name: Optional[str] = None,
|
| 636 |
+
use_irope: bool = False,
|
| 637 |
+
) -> None:
|
| 638 |
+
self.num_heads = num_heads
|
| 639 |
+
self.head_size = head_size
|
| 640 |
+
self.scale = float(scale)
|
| 641 |
+
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
|
| 642 |
+
self.hidden_size = self.num_heads * self.head_size
|
| 643 |
+
self.kv_cache_dtype = kv_cache_dtype
|
| 644 |
+
self.sliding_window = sliding_window
|
| 645 |
+
if alibi_slopes is not None:
|
| 646 |
+
alibi_slopes = torch.tensor(alibi_slopes,
|
| 647 |
+
dtype=torch.float32,
|
| 648 |
+
device="npu")
|
| 649 |
+
self.alibi_slopes = alibi_slopes
|
| 650 |
+
self.attn_type = attn_type
|
| 651 |
+
|
| 652 |
+
assert self.num_heads % self.num_kv_heads == 0
|
| 653 |
+
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
| 654 |
+
self.seq_len_cpu_tensor = None
|
| 655 |
+
self.query_len_cpu_tensor = None
|
| 656 |
+
self.key_cache = None
|
| 657 |
+
self.value_cache = None
|
| 658 |
+
|
| 659 |
+
def forward(
|
| 660 |
+
self,
|
| 661 |
+
layer: AttentionLayer,
|
| 662 |
+
query: torch.Tensor,
|
| 663 |
+
key: torch.Tensor,
|
| 664 |
+
value: torch.Tensor,
|
| 665 |
+
kv_cache: torch.Tensor,
|
| 666 |
+
attn_metadata: AscendMetadata,
|
| 667 |
+
attn_type: str = AttentionType.DECODER,
|
| 668 |
+
output: Optional[torch.Tensor] = None,
|
| 669 |
+
) -> torch.Tensor:
|
| 670 |
+
"""Forward pass with Ascend attention.
|
| 671 |
+
Args:
|
| 672 |
+
query: shape = [num_tokens, num_heads * head_size]
|
| 673 |
+
num_tokens = batch_size * seq_len
|
| 674 |
+
key: shape = [num_tokens, num_kv_heads * head_size]
|
| 675 |
+
value: shape = [num_tokens, num_kv_heads * head_size]
|
| 676 |
+
kv_cache: shape = [2, num_blocks, block_size,
|
| 677 |
+
num_kv_heads, head_size]
|
| 678 |
+
key_cache = [num_blocks, block_size,
|
| 679 |
+
num_kv_heads, head_size]
|
| 680 |
+
value_cache = [num_blocks, block_size,
|
| 681 |
+
num_kv_heads, head_size]
|
| 682 |
+
attn_metadata: Metadata for attention.
|
| 683 |
+
Returns:
|
| 684 |
+
shape = [batch_size, seq_len * num_heads * head_size]
|
| 685 |
+
"""
|
| 686 |
+
assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
|
| 687 |
+
# View q k v to BSH.
|
| 688 |
+
num_tokens = query.shape[0]
|
| 689 |
+
query = query.view(-1, self.num_heads, self.head_size)
|
| 690 |
+
key = key.view(-1, self.num_kv_heads, self.head_size)
|
| 691 |
+
value = value.view(-1, self.num_kv_heads, self.head_size)
|
| 692 |
+
# TODO: Remove this contiguous in the future.
|
| 693 |
+
value = value.contiguous()
|
| 694 |
+
attn_type = self.attn_type
|
| 695 |
+
|
| 696 |
+
output = torch.empty(num_tokens,
|
| 697 |
+
self.num_heads,
|
| 698 |
+
self.head_size,
|
| 699 |
+
dtype=query.dtype,
|
| 700 |
+
device=query.device)
|
| 701 |
+
|
| 702 |
+
if kv_cache.numel() > 0:
|
| 703 |
+
if self.key_cache is None:
|
| 704 |
+
self.key_cache, self.value_cache = kv_cache[0], kv_cache[1]
|
| 705 |
+
slots = attn_metadata.slot_mapping
|
| 706 |
+
|
| 707 |
+
if hasattr(layer, 'quant_method'):
|
| 708 |
+
isPrefill = True if attn_metadata.num_prefills > 0 else False
|
| 709 |
+
if isPrefill:
|
| 710 |
+
assert attn_metadata.prefill_metadata is not None
|
| 711 |
+
self.seq_lens_tensor_cpu = torch.from_numpy(
|
| 712 |
+
np.array(attn_metadata.prefill_metadata.seq_lens).astype(
|
| 713 |
+
np.int32))
|
| 714 |
+
else:
|
| 715 |
+
assert attn_metadata.decode_metadata is not None
|
| 716 |
+
self.seq_lens_tensor_cpu = torch.from_numpy(
|
| 717 |
+
np.array(attn_metadata.decode_metadata.seq_lens).astype(
|
| 718 |
+
np.int32))
|
| 719 |
+
block_tables = attn_metadata.decode_metadata.block_tables if attn_metadata.decode_metadata else None
|
| 720 |
+
# Details of kv_cache arrangement in attention quantization
|
| 721 |
+
# are implemented by quant_method.
|
| 722 |
+
layer.quant_method.apply(
|
| 723 |
+
layer,
|
| 724 |
+
query,
|
| 725 |
+
key,
|
| 726 |
+
value,
|
| 727 |
+
self.key_cache,
|
| 728 |
+
self.value_cache,
|
| 729 |
+
self.scale,
|
| 730 |
+
block_tables,
|
| 731 |
+
isPrefill,
|
| 732 |
+
attn_metadata,
|
| 733 |
+
output,
|
| 734 |
+
seq_lens_tensor_cpu=self.seq_lens_tensor_cpu)
|
| 735 |
+
else:
|
| 736 |
+
if self.key_cache is not None:
|
| 737 |
+
torch_npu._npu_reshape_and_cache(key=key,
|
| 738 |
+
value=value,
|
| 739 |
+
key_cache=self.key_cache,
|
| 740 |
+
value_cache=self.value_cache,
|
| 741 |
+
slot_indices=slots)
|
| 742 |
+
|
| 743 |
+
if attn_metadata.num_prefills > 0:
|
| 744 |
+
# Prefix cache disabled and chunk prefill disabled or no prefix cache hit
|
| 745 |
+
if (attn_metadata.block_tables is None
|
| 746 |
+
or attn_metadata.block_tables.numel() == 0):
|
| 747 |
+
if attn_type == AttentionType.ENCODER_ONLY:
|
| 748 |
+
# TODO: change to use torch_npu encoder attention op, instead
|
| 749 |
+
# of torch sdpa
|
| 750 |
+
query = query.movedim(0, query.dim() - 2)
|
| 751 |
+
key = key.movedim(0, key.dim() - 2)
|
| 752 |
+
value = value.movedim(0, value.dim() - 2)
|
| 753 |
+
|
| 754 |
+
causal_attn = (attn_type == AttentionType.DECODER)
|
| 755 |
+
if attn_metadata.seq_lens is not None:
|
| 756 |
+
seq_lens_q = seq_lens_kv = attn_metadata.seq_lens
|
| 757 |
+
attn_masks = [None] * len(seq_lens_q)
|
| 758 |
+
start_q, start_kv = 0, 0
|
| 759 |
+
for seq_len_q, seq_len_kv, mask in zip(
|
| 760 |
+
seq_lens_q, seq_lens_kv, attn_masks):
|
| 761 |
+
end_q = start_q + seq_len_q
|
| 762 |
+
end_kv = start_kv + seq_len_kv
|
| 763 |
+
sub_out = scaled_dot_product_attention(
|
| 764 |
+
query[None, :, start_q:end_q, :],
|
| 765 |
+
key[None, :, start_kv:end_kv, :],
|
| 766 |
+
value[None, :, start_kv:end_kv, :],
|
| 767 |
+
attn_mask=mask,
|
| 768 |
+
dropout_p=0.0,
|
| 769 |
+
is_causal=causal_attn and mask is None,
|
| 770 |
+
scale=self.scale).squeeze(0).movedim(
|
| 771 |
+
query.dim() - 2, 0)
|
| 772 |
+
output[start_q:end_q, :, :] = sub_out
|
| 773 |
+
start_q, start_kv = end_q, end_kv
|
| 774 |
+
else:
|
| 775 |
+
assert attn_metadata.attn_mask is not None
|
| 776 |
+
mask = attn_metadata.attn_mask
|
| 777 |
+
assert attn_metadata.prefill_metadata is not None
|
| 778 |
+
self.seq_lens_tensor_cpu = torch.from_numpy(
|
| 779 |
+
np.array(attn_metadata.prefill_metadata.seq_lens).
|
| 780 |
+
astype(np.int32))
|
| 781 |
+
if is_310p():
|
| 782 |
+
# align q k v output tensors
|
| 783 |
+
query = aligned_16(query)
|
| 784 |
+
key = aligned_16(key)
|
| 785 |
+
value = aligned_16(value)
|
| 786 |
+
output = aligned_16(output)
|
| 787 |
+
|
| 788 |
+
# do reformat in case of broadcasted tensors
|
| 789 |
+
mask = mask.repeat(
|
| 790 |
+
self.seq_lens_tensor_cpu.size(0), 1, 1, 1)
|
| 791 |
+
mask = torch_npu.npu_format_cast(
|
| 792 |
+
mask.contiguous(), ACL_FORMAT_FRACTAL_NZ)
|
| 793 |
+
torch_npu._npu_flash_attention(
|
| 794 |
+
query=query,
|
| 795 |
+
key=key,
|
| 796 |
+
value=value,
|
| 797 |
+
mask=mask,
|
| 798 |
+
seq_len=self.seq_lens_tensor_cpu,
|
| 799 |
+
scale_value=self.scale,
|
| 800 |
+
num_heads=self.num_heads,
|
| 801 |
+
num_kv_heads=self.num_kv_heads,
|
| 802 |
+
out=output)
|
| 803 |
+
output = output[:num_tokens, :, :]
|
| 804 |
+
# Prefix cache only and cache hit
|
| 805 |
+
elif attn_metadata.num_decode_tokens == 0 and not attn_metadata.chunked_prefill_enabled:
|
| 806 |
+
assert kv_cache is not None
|
| 807 |
+
assert attn_metadata.prefill_metadata is not None
|
| 808 |
+
self.seq_lens_tensor_cpu = torch.from_numpy(
|
| 809 |
+
np.array(
|
| 810 |
+
attn_metadata.prefill_metadata.seq_lens).astype(
|
| 811 |
+
np.int32))
|
| 812 |
+
self.query_lens_tensor_cpu = torch.from_numpy(
|
| 813 |
+
np.array(
|
| 814 |
+
attn_metadata.prefill_metadata.query_lens).astype(
|
| 815 |
+
np.int32))
|
| 816 |
+
block_tables = attn_metadata.prefill_metadata.block_tables
|
| 817 |
+
assert attn_metadata.compress_mask is not None
|
| 818 |
+
compress_mask = attn_metadata.compress_mask
|
| 819 |
+
torch_npu._npu_flash_attention_qlens(
|
| 820 |
+
query=query,
|
| 821 |
+
key_cache=self.key_cache,
|
| 822 |
+
value_cache=self.value_cache,
|
| 823 |
+
block_table=block_tables,
|
| 824 |
+
mask=compress_mask,
|
| 825 |
+
seq_len=self.query_lens_tensor_cpu,
|
| 826 |
+
context_lens=self.seq_lens_tensor_cpu,
|
| 827 |
+
num_kv_heads=self.num_kv_heads,
|
| 828 |
+
num_heads=self.num_heads,
|
| 829 |
+
scale_value=self.scale,
|
| 830 |
+
out=output)
|
| 831 |
+
# Splitfuse
|
| 832 |
+
else:
|
| 833 |
+
assert kv_cache is not None
|
| 834 |
+
self.seq_lens_tensor_cpu = torch.from_numpy(
|
| 835 |
+
np.array(attn_metadata.seq_lens).astype(np.int32))
|
| 836 |
+
self.query_lens_tensor_cpu = torch.from_numpy(
|
| 837 |
+
np.array(attn_metadata.query_lens).astype(np.int32))
|
| 838 |
+
block_tables = attn_metadata.block_tables
|
| 839 |
+
assert attn_metadata.chunk_mask is not None
|
| 840 |
+
chunk_mask = attn_metadata.chunk_mask
|
| 841 |
+
torch_npu._npu_paged_attention_splitfuse(
|
| 842 |
+
query=query,
|
| 843 |
+
key_cache=self.key_cache,
|
| 844 |
+
value_cache=self.value_cache,
|
| 845 |
+
block_table=block_tables,
|
| 846 |
+
context_lens=self.seq_lens_tensor_cpu,
|
| 847 |
+
mask=chunk_mask,
|
| 848 |
+
seq_len=self.query_lens_tensor_cpu,
|
| 849 |
+
num_kv_heads=self.num_kv_heads,
|
| 850 |
+
num_heads=self.num_heads,
|
| 851 |
+
scale_value=self.scale,
|
| 852 |
+
out=output)
|
| 853 |
+
# Decode only
|
| 854 |
+
else:
|
| 855 |
+
assert self.key_cache is not None
|
| 856 |
+
assert self.value_cache is not None
|
| 857 |
+
assert attn_metadata.decode_metadata is not None
|
| 858 |
+
self.seq_lens_tensor_cpu = torch.from_numpy(
|
| 859 |
+
np.array(attn_metadata.decode_metadata.seq_lens).astype(
|
| 860 |
+
np.int32))
|
| 861 |
+
if is_310p():
|
| 862 |
+
# # seq_lens_tensor needs to be transferred to the device for 310P
|
| 863 |
+
self.seq_lens_tensor_cpu = self.seq_lens_tensor_cpu.to(
|
| 864 |
+
device=self.key_cache.device)
|
| 865 |
+
block_tables = attn_metadata.decode_metadata.block_tables
|
| 866 |
+
torch_npu._npu_paged_attention(
|
| 867 |
+
query=query,
|
| 868 |
+
key_cache=self.key_cache,
|
| 869 |
+
value_cache=self.value_cache,
|
| 870 |
+
num_kv_heads=self.num_kv_heads,
|
| 871 |
+
num_heads=self.num_heads,
|
| 872 |
+
scale_value=self.scale,
|
| 873 |
+
block_table=block_tables,
|
| 874 |
+
context_lens=self.seq_lens_tensor_cpu,
|
| 875 |
+
out=output)
|
| 876 |
+
|
| 877 |
+
return output.view(num_tokens, self.hidden_size)
|
| 878 |
+
|
| 879 |
+
|
| 880 |
+
class AscendMLAAttentionBackendImpl(MLAAttentionImpl):
|
| 881 |
+
|
| 882 |
+
def __init__(
|
| 883 |
+
self,
|
| 884 |
+
num_heads: int,
|
| 885 |
+
head_size: int,
|
| 886 |
+
scale: float,
|
| 887 |
+
num_kv_heads: int,
|
| 888 |
+
alibi_slopes: Optional[List[float]],
|
| 889 |
+
sliding_window: Optional[int],
|
| 890 |
+
kv_cache_dtype: str,
|
| 891 |
+
blocksparse_params: Optional[Dict[str, Any]] = None,
|
| 892 |
+
logits_soft_cap: Optional[float] = None,
|
| 893 |
+
attn_type: str = AttentionType.DECODER,
|
| 894 |
+
kv_sharing_target_layer_name: Optional[str] = None,
|
| 895 |
+
**extra_impl_args,
|
| 896 |
+
) -> None:
|
| 897 |
+
self.num_heads = num_heads
|
| 898 |
+
self.head_size = head_size
|
| 899 |
+
self.scale = float(scale)
|
| 900 |
+
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
|
| 901 |
+
self.hidden_size = self.num_heads * self.head_size
|
| 902 |
+
self.kv_cache_dtype = kv_cache_dtype
|
| 903 |
+
self.sliding_window = sliding_window
|
| 904 |
+
if alibi_slopes is not None:
|
| 905 |
+
alibi_slopes = torch.tensor(alibi_slopes,
|
| 906 |
+
dtype=torch.float32,
|
| 907 |
+
device="npu")
|
| 908 |
+
self.alibi_slopes = alibi_slopes
|
| 909 |
+
self.attn_type = attn_type
|
| 910 |
+
|
| 911 |
+
assert self.num_heads % self.num_kv_heads == 0
|
| 912 |
+
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
| 913 |
+
self.seq_len_cpu_tensor = None
|
| 914 |
+
|
| 915 |
+
# MLA Args
|
| 916 |
+
self.q_lora_rank = extra_impl_args['q_lora_rank']
|
| 917 |
+
self.kv_lora_rank = extra_impl_args['kv_lora_rank']
|
| 918 |
+
self.qk_nope_head_dim = extra_impl_args['qk_nope_head_dim']
|
| 919 |
+
self.qk_rope_head_dim = extra_impl_args['qk_rope_head_dim']
|
| 920 |
+
self.qk_head_dim = extra_impl_args['qk_head_dim']
|
| 921 |
+
self.v_head_dim = extra_impl_args['v_head_dim']
|
| 922 |
+
self.rotary_emb = extra_impl_args['rotary_emb']
|
| 923 |
+
self.q_proj = extra_impl_args['q_proj']
|
| 924 |
+
self.kv_b_proj = extra_impl_args['kv_b_proj']
|
| 925 |
+
self.o_proj = extra_impl_args['o_proj']
|
| 926 |
+
self.kv_a_proj_with_mqa = extra_impl_args.get('kv_a_proj_with_mqa',
|
| 927 |
+
None)
|
| 928 |
+
self.kv_a_layernorm = extra_impl_args.get('kv_a_layernorm', None)
|
| 929 |
+
self.k_pe_cache = None
|
| 930 |
+
self.k_nope_cache = None
|
| 931 |
+
self.w_kc = None
|
| 932 |
+
self.w_vc = None
|
| 933 |
+
|
| 934 |
+
ascend_config = get_ascend_config()
|
| 935 |
+
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
|
| 936 |
+
|
| 937 |
+
|
| 938 |
+
def exec_kv(
|
| 939 |
+
self,
|
| 940 |
+
hidden_states: torch.Tensor,
|
| 941 |
+
cos: torch.Tensor,
|
| 942 |
+
sin: torch.Tensor,
|
| 943 |
+
kv_cache: Tuple,
|
| 944 |
+
slots: torch.Tensor,
|
| 945 |
+
):
|
| 946 |
+
B = hidden_states.shape[0]
|
| 947 |
+
N = self.num_kv_heads
|
| 948 |
+
S = 1
|
| 949 |
+
kv = self.kv_a_proj_with_mqa(hidden_states)[0]
|
| 950 |
+
# npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
|
| 951 |
+
kv = kv.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
|
| 952 |
+
|
| 953 |
+
k_pe, k_nope, _, _ = torch.ops.npu_inference.npu_kv_rmsnorm_rope_cache(
|
| 954 |
+
kv,
|
| 955 |
+
self.kv_a_layernorm.weight,
|
| 956 |
+
cos,
|
| 957 |
+
sin,
|
| 958 |
+
slots.to(torch.int64),
|
| 959 |
+
kv_cache[1],
|
| 960 |
+
kv_cache[0],
|
| 961 |
+
epsilon=self.kv_a_layernorm.variance_epsilon,
|
| 962 |
+
cache_mode="PA",
|
| 963 |
+
)
|
| 964 |
+
|
| 965 |
+
return k_pe, k_nope
|
| 966 |
+
|
| 967 |
+
def apply_rotary_emb(
|
| 968 |
+
self,
|
| 969 |
+
x: torch.Tensor,
|
| 970 |
+
cos: torch.Tensor,
|
| 971 |
+
sin: torch.Tensor,
|
| 972 |
+
is_neox_style: bool,
|
| 973 |
+
) -> torch.Tensor:
|
| 974 |
+
"""
|
| 975 |
+
Args:
|
| 976 |
+
x: [num_tokens, num_heads, head_size]
|
| 977 |
+
cos: [num_tokens, head_size // 2]
|
| 978 |
+
sin: [num_tokens, head_size // 2]
|
| 979 |
+
is_neox_style: Whether to use the Neox-style or GPT-J-style rotary
|
| 980 |
+
positional embeddings.
|
| 981 |
+
"""
|
| 982 |
+
cos = cos.unsqueeze(-2).to(x.dtype)
|
| 983 |
+
sin = sin.unsqueeze(-2).to(x.dtype)
|
| 984 |
+
if is_neox_style:
|
| 985 |
+
x1, x2 = torch.chunk(x, 2, dim=-1)
|
| 986 |
+
else:
|
| 987 |
+
x1 = x[..., ::2]
|
| 988 |
+
x2 = x[..., 1::2]
|
| 989 |
+
o1 = x1 * cos - x2 * sin
|
| 990 |
+
o2 = x2 * cos + x1 * sin
|
| 991 |
+
if is_neox_style:
|
| 992 |
+
return torch.cat((o1, o2), dim=-1)
|
| 993 |
+
else:
|
| 994 |
+
return torch.stack((o1, o2), dim=-1).flatten(-2)
|
| 995 |
+
|
| 996 |
+
def rope_single(
|
| 997 |
+
self,
|
| 998 |
+
x: torch.Tensor,
|
| 999 |
+
cos: torch.Tensor,
|
| 1000 |
+
sin: torch.Tensor,
|
| 1001 |
+
) -> torch.Tensor:
|
| 1002 |
+
B, N, D = x.shape
|
| 1003 |
+
S = 1
|
| 1004 |
+
x = x.view(B, N, S, D)
|
| 1005 |
+
x = torch.ops.npu_inference.npu_interleave_rope(x, cos, sin)
|
| 1006 |
+
return x.view(B, N, D)
|
| 1007 |
+
|
| 1008 |
+
def process_weights_after_loading(self, act_dtype: torch.dtype):
|
| 1009 |
+
if self.w_kc is None or self.w_vc is None:
|
| 1010 |
+
kv_b_proj_weight = self.kv_b_proj.weight.reshape(
|
| 1011 |
+
self.num_heads, self.qk_nope_head_dim + self.v_head_dim,
|
| 1012 |
+
self.kv_lora_rank)
|
| 1013 |
+
self.w_kc = kv_b_proj_weight[:, :self.
|
| 1014 |
+
qk_nope_head_dim, :].contiguous()
|
| 1015 |
+
self.w_vc = kv_b_proj_weight[:,
|
| 1016 |
+
self.qk_nope_head_dim:, :].transpose(
|
| 1017 |
+
1, 2).contiguous()
|
| 1018 |
+
|
| 1019 |
+
def forward(
|
| 1020 |
+
self,
|
| 1021 |
+
layer: AttentionLayer,
|
| 1022 |
+
hidden_states_or_q_c: torch.Tensor,
|
| 1023 |
+
hidden_states_or_kv_c_normed: torch.Tensor,
|
| 1024 |
+
k_pe: torch.Tensor,
|
| 1025 |
+
kv_cache: torch.Tensor,
|
| 1026 |
+
attn_metadata: AscendMetadata,
|
| 1027 |
+
attn_type: str = AttentionType.DECODER,
|
| 1028 |
+
output: Optional[torch.Tensor] = None,
|
| 1029 |
+
) -> torch.Tensor:
|
| 1030 |
+
"""Forward pass with Ascend attention.
|
| 1031 |
+
Args:
|
| 1032 |
+
hidden_states_or_q_c: shape = [num_tokens, num_heads * head_size]
|
| 1033 |
+
num_tokens = batch_size * seq_len
|
| 1034 |
+
hidden_states_or_kv_c_normed: shape = [num_tokens, num_kv_heads * head_size]
|
| 1035 |
+
k_pe: shape = [num_tokens, num_kv_heads * head_size]
|
| 1036 |
+
kv_cache: shape = [1, num_blocks, block_size,
|
| 1037 |
+
num_kv_heads * head_size]
|
| 1038 |
+
attn_metadata: Metadata for attention.
|
| 1039 |
+
Returns:
|
| 1040 |
+
shape = [batch_size, seq_len * num_heads * head_size]
|
| 1041 |
+
"""
|
| 1042 |
+
assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
|
| 1043 |
+
attn_type = self.attn_type
|
| 1044 |
+
if attn_type != AttentionType.DECODER:
|
| 1045 |
+
raise NotImplementedError("Encoder self-attention and "
|
| 1046 |
+
"encoder/decoder cross-attention "
|
| 1047 |
+
"are not implemented for "
|
| 1048 |
+
"PallasAttentionBackendImpl")
|
| 1049 |
+
|
| 1050 |
+
if attn_metadata is None:
|
| 1051 |
+
# for profile run
|
| 1052 |
+
return hidden_states_or_q_c
|
| 1053 |
+
|
| 1054 |
+
num_tokens = hidden_states_or_q_c.shape[0]
|
| 1055 |
+
q = self.q_proj(hidden_states_or_q_c)[0].view(-1, self.num_heads,
|
| 1056 |
+
self.qk_head_dim)
|
| 1057 |
+
q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim],
|
| 1058 |
+
dim=-1)
|
| 1059 |
+
if k_pe is None and attn_metadata.decode_metadata:
|
| 1060 |
+
seq_len = self.rotary_emb.max_position_embeddings
|
| 1061 |
+
|
| 1062 |
+
cos = self.rotary_emb.cos_cached[:seq_len].to(dtype=q_pe.dtype)
|
| 1063 |
+
sin = self.rotary_emb.sin_cached[:seq_len].to(dtype=q_pe.dtype)
|
| 1064 |
+
cos = cos[attn_metadata.input_positions]
|
| 1065 |
+
sin = sin[attn_metadata.input_positions]
|
| 1066 |
+
cos = cos[:, None, None, :]
|
| 1067 |
+
sin = sin[:, None, None, :]
|
| 1068 |
+
|
| 1069 |
+
q_pe = self.rope_single(q_pe, cos, sin)
|
| 1070 |
+
k_pe, k_nope = self.exec_kv(hidden_states_or_kv_c_normed, cos, sin,
|
| 1071 |
+
kv_cache, attn_metadata.slot_mapping)
|
| 1072 |
+
else:
|
| 1073 |
+
if k_pe is None:
|
| 1074 |
+
# NOTE: k_pe is None when graph mode enabled
|
| 1075 |
+
kv_c, k_pe = self.kv_a_proj_with_mqa(
|
| 1076 |
+
hidden_states_or_kv_c_normed)[0].split(
|
| 1077 |
+
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
|
| 1078 |
+
kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
|
| 1079 |
+
else:
|
| 1080 |
+
kv_c_normed = hidden_states_or_kv_c_normed
|
| 1081 |
+
k_pe = k_pe.view(num_tokens, self.num_kv_heads, -1)
|
| 1082 |
+
if self.rotary_emb.__class__.__name__ == 'RotaryEmbedding':
|
| 1083 |
+
# NOTE: When scaling not specified
|
| 1084 |
+
ori_q_pe_shape, ori_k_pe_shape = q_pe.shape, k_pe.shape
|
| 1085 |
+
q_pe = q_pe.reshape(num_tokens, -1)
|
| 1086 |
+
k_pe = k_pe.reshape(num_tokens, -1)
|
| 1087 |
+
q_pe, k_pe = self.rotary_emb(attn_metadata.input_positions,
|
| 1088 |
+
q_pe, k_pe)
|
| 1089 |
+
q_pe = q_pe.view(ori_q_pe_shape)
|
| 1090 |
+
k_pe = k_pe.view(ori_k_pe_shape)
|
| 1091 |
+
else:
|
| 1092 |
+
q_pe, k_pe = self.rotary_emb(attn_metadata.input_positions,
|
| 1093 |
+
q_pe, k_pe)
|
| 1094 |
+
|
| 1095 |
+
if attn_metadata.num_prefills > 0:
|
| 1096 |
+
kv = self.kv_b_proj(kv_c_normed)[0].view(num_tokens,
|
| 1097 |
+
self.num_heads, -1)
|
| 1098 |
+
k_nope, value = kv.split([self.qk_nope_head_dim, self.v_head_dim],
|
| 1099 |
+
dim=-1)
|
| 1100 |
+
else:
|
| 1101 |
+
q_nope_t = torch.transpose(q_nope, 0, 1)
|
| 1102 |
+
q_nope_out = torch.bmm(q_nope_t, self.w_kc)
|
| 1103 |
+
q_nope = torch.transpose(q_nope_out, 0, 1)
|
| 1104 |
+
|
| 1105 |
+
query = torch.cat([q_nope, q_pe], dim=-1).view(num_tokens,
|
| 1106 |
+
self.num_heads, -1)
|
| 1107 |
+
|
| 1108 |
+
# TODO: Replace the env with more flexible expressions
|
| 1109 |
+
if self.torchair_graph_enabled:
|
| 1110 |
+
if len(kv_cache) > 0 and kv_cache[0].numel(
|
| 1111 |
+
) > 0 and attn_metadata.num_prefills > 0:
|
| 1112 |
+
slots = attn_metadata.slot_mapping
|
| 1113 |
+
# NOTE: Separate the kv cache in advance to avoid OOM or other issues
|
| 1114 |
+
torch_npu._npu_reshape_and_cache(key=kv_c_normed.view(
|
| 1115 |
+
num_tokens, self.num_kv_heads, -1),
|
| 1116 |
+
value=k_pe,
|
| 1117 |
+
key_cache=kv_cache[0],
|
| 1118 |
+
value_cache=kv_cache[1],
|
| 1119 |
+
slot_indices=slots)
|
| 1120 |
+
elif kv_cache.numel() > 0:
|
| 1121 |
+
# TODO replace this naive implement with fusion kernel
|
| 1122 |
+
concat_and_cache_mla(kv_c_normed, k_pe, kv_cache,
|
| 1123 |
+
attn_metadata.slot_mapping)
|
| 1124 |
+
|
| 1125 |
+
if attn_metadata.num_prefills > 0:
|
| 1126 |
+
attn_output = torch.empty(num_tokens,
|
| 1127 |
+
self.num_heads,
|
| 1128 |
+
self.v_head_dim,
|
| 1129 |
+
dtype=query.dtype,
|
| 1130 |
+
device=query.device)
|
| 1131 |
+
if (attn_metadata.block_tables is None
|
| 1132 |
+
or attn_metadata.block_tables.numel() == 0):
|
| 1133 |
+
assert attn_metadata.attn_mask is not None
|
| 1134 |
+
assert attn_metadata.prefill_metadata is not None
|
| 1135 |
+
assert attn_metadata.prefill_metadata.seq_lens is not None
|
| 1136 |
+
mask = attn_metadata.attn_mask
|
| 1137 |
+
self.seq_lens_tensor_cpu = torch.from_numpy(
|
| 1138 |
+
np.array(attn_metadata.prefill_metadata.seq_lens).astype(
|
| 1139 |
+
np.int32))
|
| 1140 |
+
k_pe = k_pe.repeat(1, self.num_heads, 1)
|
| 1141 |
+
key = torch.cat(
|
| 1142 |
+
[k_nope.view(num_tokens, self.num_heads, -1), k_pe], dim=2)
|
| 1143 |
+
torch_npu._npu_flash_attention(
|
| 1144 |
+
query=query,
|
| 1145 |
+
key=key,
|
| 1146 |
+
value=value,
|
| 1147 |
+
mask=mask,
|
| 1148 |
+
seq_len=self.seq_lens_tensor_cpu,
|
| 1149 |
+
scale_value=self.scale,
|
| 1150 |
+
num_heads=self.num_heads,
|
| 1151 |
+
num_kv_heads=self.num_heads,
|
| 1152 |
+
out=attn_output)
|
| 1153 |
+
else:
|
| 1154 |
+
# TODO: Will support prefix cache and chunked prefill soon.
|
| 1155 |
+
raise RuntimeError(
|
| 1156 |
+
"Prefix cache and chunked prefill are currently not supported."
|
| 1157 |
+
)
|
| 1158 |
+
elif attn_metadata.decode_metadata:
|
| 1159 |
+
assert kv_cache is not None
|
| 1160 |
+
if self.torchair_graph_enabled:
|
| 1161 |
+
# shape of query for npu graph mode should be:
|
| 1162 |
+
# [bs, num_heads_per_rank, seq_len, dim]
|
| 1163 |
+
q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1)
|
| 1164 |
+
q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
|
| 1165 |
+
# shape of knope/k_pe for npu graph mode should be:
|
| 1166 |
+
# [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim]
|
| 1167 |
+
block_size = kv_cache[0].shape[1]
|
| 1168 |
+
k_nope = k_nope.view(-1, self.num_kv_heads, block_size,
|
| 1169 |
+
self.kv_lora_rank)
|
| 1170 |
+
k_pe = k_pe.view(-1, self.num_kv_heads, block_size,
|
| 1171 |
+
self.qk_rope_head_dim)
|
| 1172 |
+
attn_output, _ = torch.ops.npu.npu_fused_infer_attention_score(
|
| 1173 |
+
q_nope,
|
| 1174 |
+
k_nope,
|
| 1175 |
+
k_nope,
|
| 1176 |
+
query_rope=q_pe,
|
| 1177 |
+
key_rope=k_pe,
|
| 1178 |
+
num_heads=self.num_heads,
|
| 1179 |
+
num_key_value_heads=self.num_kv_heads,
|
| 1180 |
+
input_layout="BNSD",
|
| 1181 |
+
atten_mask=attn_metadata.attn_mask,
|
| 1182 |
+
scale=self.scale,
|
| 1183 |
+
antiquant_mode=0,
|
| 1184 |
+
antiquant_scale=None,
|
| 1185 |
+
block_table=attn_metadata.block_tables,
|
| 1186 |
+
block_size=block_size,
|
| 1187 |
+
actual_seq_lengths_kv=attn_metadata.seq_lens,
|
| 1188 |
+
)
|
| 1189 |
+
attn_output = attn_output.view(num_tokens, -1,
|
| 1190 |
+
self.kv_lora_rank).transpose(
|
| 1191 |
+
0, 1)
|
| 1192 |
+
attn_output = torch.bmm(attn_output, self.w_vc).transpose(0, 1)
|
| 1193 |
+
else:
|
| 1194 |
+
# if torch.empty is used here, the preemptive scheduling case of
|
| 1195 |
+
# test_mtp_correctness.py will fail to run.
|
| 1196 |
+
attn_output = torch.randn(
|
| 1197 |
+
[num_tokens, self.num_heads, self.kv_lora_rank],
|
| 1198 |
+
dtype=query.dtype,
|
| 1199 |
+
device=query.device)
|
| 1200 |
+
self.seq_lens_tensor_cpu = torch.from_numpy(
|
| 1201 |
+
np.array(attn_metadata.decode_metadata.seq_lens).astype(
|
| 1202 |
+
np.int32))
|
| 1203 |
+
block_tables = attn_metadata.decode_metadata.block_tables
|
| 1204 |
+
torch_npu._npu_paged_attention_mla(
|
| 1205 |
+
query=query,
|
| 1206 |
+
key_cache=kv_cache,
|
| 1207 |
+
num_kv_heads=self.num_kv_heads,
|
| 1208 |
+
num_heads=self.num_heads,
|
| 1209 |
+
scale_value=self.scale,
|
| 1210 |
+
block_table=block_tables,
|
| 1211 |
+
context_lens=self.seq_lens_tensor_cpu,
|
| 1212 |
+
mla_vheadsize=self.kv_lora_rank,
|
| 1213 |
+
out=attn_output)
|
| 1214 |
+
attn_output_t = torch.transpose(attn_output, 0, 1)
|
| 1215 |
+
attn_output_t = torch.bmm(attn_output_t, self.w_vc)
|
| 1216 |
+
attn_output = torch.transpose(attn_output_t, 0, 1)
|
| 1217 |
+
|
| 1218 |
+
output, _ = self.o_proj(attn_output.reshape(num_tokens, -1))
|
| 1219 |
+
|
| 1220 |
+
return output
|
inference/vllm_ascend/attention/mla_v1.py
ADDED
|
@@ -0,0 +1,1224 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
from typing import TYPE_CHECKING, Any, Optional, Tuple, Type, TypeVar
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
import torch_npu
|
| 7 |
+
from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
|
| 8 |
+
AttentionMetadata,
|
| 9 |
+
MLAAttentionImpl)
|
| 10 |
+
from vllm.attention.backends.utils import PAD_SLOT_ID
|
| 11 |
+
from vllm.config import get_current_vllm_config
|
| 12 |
+
from vllm.distributed import get_tensor_model_parallel_world_size
|
| 13 |
+
from vllm.model_executor.layers.linear import (LinearBase,
|
| 14 |
+
UnquantizedLinearMethod)
|
| 15 |
+
from vllm.utils import cdiv, round_down
|
| 16 |
+
|
| 17 |
+
from vllm_ascend.ascend_config import get_ascend_config
|
| 18 |
+
from vllm_ascend.attention.attention import _ALLOWED_NUM_QUERIES_PER_KV
|
| 19 |
+
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
| 20 |
+
from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
|
| 21 |
+
from vllm_ascend.multistream.context import get_multistream_comm_context
|
| 22 |
+
from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
|
| 23 |
+
from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
|
| 24 |
+
from vllm_ascend.utils import npu_prefetch, npu_stream_switch, npu_wait_tensor
|
| 25 |
+
from vllm_ascend.worker.npu_input_batch import InputBatch
|
| 26 |
+
|
| 27 |
+
if TYPE_CHECKING:
|
| 28 |
+
from vllm.v1.core.sched.output import SchedulerOutput
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
@dataclass
|
| 32 |
+
class CommonAttentionMetadata:
|
| 33 |
+
"""
|
| 34 |
+
Attention metadata attributes that can be shared by layers in different KV
|
| 35 |
+
cache groups and thus having different block table.
|
| 36 |
+
"""
|
| 37 |
+
|
| 38 |
+
query_start_loc: torch.Tensor
|
| 39 |
+
"""(batch_size + 1,), the start location of each request in query Tensor"""
|
| 40 |
+
seq_lens: torch.Tensor
|
| 41 |
+
"""(batch_size,), the length of each request including both computed tokens
|
| 42 |
+
and newly scheduled tokens"""
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class AscendMLABackend(AttentionBackend):
|
| 46 |
+
|
| 47 |
+
accept_output_buffer: bool = True
|
| 48 |
+
|
| 49 |
+
@staticmethod
|
| 50 |
+
def get_name() -> str:
|
| 51 |
+
return "VLLM_ASCEND_MLA"
|
| 52 |
+
|
| 53 |
+
@staticmethod
|
| 54 |
+
def get_metadata_cls() -> type["AttentionMetadata"]:
|
| 55 |
+
return AscendMLAMetadata
|
| 56 |
+
|
| 57 |
+
@staticmethod
|
| 58 |
+
def get_builder_cls():
|
| 59 |
+
return AscendMLAMetadataBuilder
|
| 60 |
+
|
| 61 |
+
@staticmethod
|
| 62 |
+
def get_kv_cache_shape(num_blocks: int, block_size: int, num_kv_heads: int,
|
| 63 |
+
head_size: int) -> tuple[int, ...]:
|
| 64 |
+
return (num_blocks, block_size, num_kv_heads, head_size)
|
| 65 |
+
|
| 66 |
+
@staticmethod
|
| 67 |
+
def get_impl_cls() -> Type["MLAAttentionImpl"]:
|
| 68 |
+
return AscendMLAImpl
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
@dataclass
|
| 72 |
+
class AscendMLAPrefillMetadata:
|
| 73 |
+
""" Prefill Specific Metadata for Ascend"""
|
| 74 |
+
|
| 75 |
+
@dataclass
|
| 76 |
+
class ChunkedContextMetadata:
|
| 77 |
+
# New for MLA (compared to FlashAttention)
|
| 78 |
+
# For handling chunked prefill
|
| 79 |
+
cu_seq_lens: torch.Tensor
|
| 80 |
+
starts: torch.Tensor
|
| 81 |
+
seq_tot: list[int]
|
| 82 |
+
max_seq_lens: list[int]
|
| 83 |
+
workspace: torch.Tensor
|
| 84 |
+
chunk_seq_lens: torch.Tensor
|
| 85 |
+
|
| 86 |
+
attn_mask: torch.Tensor
|
| 87 |
+
query_lens: list[int]
|
| 88 |
+
seq_lens: list[int]
|
| 89 |
+
context_lens: torch.Tensor
|
| 90 |
+
input_positions: torch.Tensor
|
| 91 |
+
query_start_loc: torch.Tensor
|
| 92 |
+
block_table: torch.Tensor
|
| 93 |
+
max_query_len: int
|
| 94 |
+
max_seq_lens: int
|
| 95 |
+
chunked_context: Optional[ChunkedContextMetadata] = None
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
@dataclass
|
| 99 |
+
class AscendMLADecodeMetadata:
|
| 100 |
+
# Input positions for rotrary embeddings since for MLA the rotary
|
| 101 |
+
# position embeddings are applied inside the attention backend
|
| 102 |
+
input_positions: torch.Tensor
|
| 103 |
+
block_table: torch.Tensor
|
| 104 |
+
seq_lens: torch.Tensor
|
| 105 |
+
max_seq_lens: int
|
| 106 |
+
seq_lens_list: list[int]
|
| 107 |
+
attn_mask: Optional[torch.Tensor] = None
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
@dataclass
|
| 111 |
+
class AscendMLAMetadata:
|
| 112 |
+
"""Metadata for MLACommon.
|
| 113 |
+
|
| 114 |
+
NOTE: Please read the comment at the top of the file before trying to
|
| 115 |
+
understand this class
|
| 116 |
+
"""
|
| 117 |
+
# NOTE(sang): Definition of context_len, query_len, and seq_len.
|
| 118 |
+
# |---------- N-1 iteration --------|
|
| 119 |
+
# |---------------- N iteration ---------------------|
|
| 120 |
+
# |- tokenA -|......................|-- newTokens ---|
|
| 121 |
+
# |---------- context_len ----------|
|
| 122 |
+
# |-------------------- seq_len ---------------------|
|
| 123 |
+
# |-- query_len ---|
|
| 124 |
+
|
| 125 |
+
num_actual_tokens: int # Number of tokens excluding padding.
|
| 126 |
+
slot_mapping: torch.Tensor
|
| 127 |
+
query_start_loc: torch.Tensor
|
| 128 |
+
seq_lens: torch.Tensor
|
| 129 |
+
block_tables: torch.Tensor
|
| 130 |
+
|
| 131 |
+
# New for MLA (compared to FlashAttention)
|
| 132 |
+
# For handling prefill decode split
|
| 133 |
+
num_decodes: int
|
| 134 |
+
num_decode_tokens: int
|
| 135 |
+
num_prefills: int
|
| 136 |
+
|
| 137 |
+
# For logging.
|
| 138 |
+
num_input_tokens: int = 0 # Number of tokens including padding.
|
| 139 |
+
|
| 140 |
+
max_num_tokens_across_dp: int = 0
|
| 141 |
+
with_prefill_across_dp: bool = False
|
| 142 |
+
|
| 143 |
+
query_lens: Optional[list[int]] = None
|
| 144 |
+
# The dimension of the attention heads
|
| 145 |
+
head_dim: Optional[int] = None
|
| 146 |
+
attn_mask: torch.Tensor = None
|
| 147 |
+
# chunked prefill by default if no attn_states passed
|
| 148 |
+
attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
|
| 149 |
+
|
| 150 |
+
decode: Optional[AscendMLADecodeMetadata] = None
|
| 151 |
+
prefill: Optional[AscendMLAPrefillMetadata] = None
|
| 152 |
+
|
| 153 |
+
def __post_init__(self):
|
| 154 |
+
pass
|
| 155 |
+
# supported_head_sizes = AscendMLABackend.get_supported_head_sizes()
|
| 156 |
+
# if self.head_dim is not None and self.head_dim \
|
| 157 |
+
# not in supported_head_sizes:
|
| 158 |
+
# raise ValueError(
|
| 159 |
+
# f"Only {supported_head_sizes} are supported for head_dim,",
|
| 160 |
+
# f"received {self.head_dim}.")
|
| 161 |
+
|
| 162 |
+
def split_metadata_for_multistream(
|
| 163 |
+
self,
|
| 164 |
+
ms_split_config: MSAttentionMetadataSplitConfig,
|
| 165 |
+
) -> list["AscendMLAMetadata"]:
|
| 166 |
+
"""Split metadata for multi-stream with AscendMLAMetadata"""
|
| 167 |
+
return model_input_split_v1_mla_attn(
|
| 168 |
+
ms_split_config=ms_split_config,
|
| 169 |
+
attn_metadata=self,
|
| 170 |
+
_metadata_cls=AscendMLAMetadata,
|
| 171 |
+
)
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
M = TypeVar("M", bound=AscendMLAMetadata)
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
class AscendMLAMetadataBuilder:
|
| 178 |
+
"""
|
| 179 |
+
NOTE: Please read the comment at the top of the file before trying to
|
| 180 |
+
understand this class
|
| 181 |
+
"""
|
| 182 |
+
|
| 183 |
+
# _attn_mask_builder = None
|
| 184 |
+
def __init__(self,
|
| 185 |
+
runner,
|
| 186 |
+
metadata_cls: Optional[AscendMLAMetadata] = None):
|
| 187 |
+
self.metadata_cls: Optional[AscendMLAMetadata] = metadata_cls \
|
| 188 |
+
if metadata_cls is not None else AscendMLAMetadata # type: ignore
|
| 189 |
+
self.runner = runner
|
| 190 |
+
scheduler_config = runner.scheduler_config
|
| 191 |
+
model_config = runner.model_config
|
| 192 |
+
self.block_size = runner.block_size
|
| 193 |
+
self.chunked_prefill_enabled = runner.chunked_prefill_enabled
|
| 194 |
+
if self.chunked_prefill_enabled:
|
| 195 |
+
self.chunked_prefill_workspace_size = min(
|
| 196 |
+
# Max sure there is enough for 8 full length request or at least
|
| 197 |
+
# 4 pages of cache per request
|
| 198 |
+
max(8 * model_config.max_model_len,
|
| 199 |
+
4 * scheduler_config.max_num_seqs * self.block_size),
|
| 200 |
+
# For long-context models try not to over-allocate limiting
|
| 201 |
+
# kv-cache space, limiting it to 64k tokens,
|
| 202 |
+
# which would result in the workspace being:
|
| 203 |
+
# 2*(576)*(64*1024) = 144mb
|
| 204 |
+
# (assuming 576 MLA head dim, and fp16)
|
| 205 |
+
# which would result in up-projected context being
|
| 206 |
+
# 2*(192*128)*(64*1024) = 3gb
|
| 207 |
+
# (assuming 192 QK head dim, 128 heads, and fp16)
|
| 208 |
+
128 * 1024)
|
| 209 |
+
assert self.chunked_prefill_workspace_size >= \
|
| 210 |
+
scheduler_config.max_num_seqs * self.block_size
|
| 211 |
+
self.chunked_prefill_workspace = torch.empty(
|
| 212 |
+
(self.chunked_prefill_workspace_size,
|
| 213 |
+
model_config.get_head_size()),
|
| 214 |
+
dtype=model_config.dtype,
|
| 215 |
+
device=runner.device,
|
| 216 |
+
)
|
| 217 |
+
ascend_config = get_ascend_config()
|
| 218 |
+
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
|
| 219 |
+
|
| 220 |
+
def reorder_batch(self, input_batch: "InputBatch",
|
| 221 |
+
scheduler_output: "SchedulerOutput") -> bool:
|
| 222 |
+
# We now want to reorder the batch so that the "decode" requests are at
|
| 223 |
+
# the front and the "prefill" requests are at the using the least amount
|
| 224 |
+
# swaps possible. (NOTE for now we loosely use "decode" to mean requests
|
| 225 |
+
# where attention is likely memory-bound and "prefill" to mean requests
|
| 226 |
+
# where attention is likely compute-bound, TODO(lucas): figure out a
|
| 227 |
+
# better naming here)
|
| 228 |
+
decodes = []
|
| 229 |
+
prefills = []
|
| 230 |
+
num_decode_tokens = 0
|
| 231 |
+
num_prefill_tokens = 0
|
| 232 |
+
|
| 233 |
+
for i, req_id in enumerate(input_batch.req_ids):
|
| 234 |
+
num_tokens = scheduler_output.num_scheduled_tokens[req_id]
|
| 235 |
+
num_spec_tokens = len(
|
| 236 |
+
scheduler_output.scheduled_spec_decode_tokens.get(req_id, []))
|
| 237 |
+
# For torch air graph mode we treat spec decoding as decode.
|
| 238 |
+
if self.torchair_graph_enabled:
|
| 239 |
+
if num_tokens - num_spec_tokens == 1:
|
| 240 |
+
decodes.append(i)
|
| 241 |
+
num_decode_tokens += num_tokens
|
| 242 |
+
else:
|
| 243 |
+
prefills.append(i)
|
| 244 |
+
num_prefill_tokens += num_tokens
|
| 245 |
+
# For eager mode we treat spec decoding as chunked prefill.
|
| 246 |
+
else:
|
| 247 |
+
if num_tokens == 1:
|
| 248 |
+
decodes.append(i)
|
| 249 |
+
num_decode_tokens += num_tokens
|
| 250 |
+
else:
|
| 251 |
+
prefills.append(i)
|
| 252 |
+
num_prefill_tokens += num_tokens
|
| 253 |
+
|
| 254 |
+
# We hope that this is fairly minimal since decodes
|
| 255 |
+
# should be around for a number of iterations so hopefully they are
|
| 256 |
+
# relatively stationary (and new request are generally appended to the
|
| 257 |
+
# persistent batch so already should be at the back)
|
| 258 |
+
# To achieve this we loop over the decodes in descending order and
|
| 259 |
+
# the prefills in ascending order. We swap decodes from the "back"
|
| 260 |
+
# i.e. past where the last decode should be in the reodorered with
|
| 261 |
+
# prefills from the front of the batch.
|
| 262 |
+
# `decodes` and `prefills` are already in ascending order just based on
|
| 263 |
+
# the above loop
|
| 264 |
+
num_decodes = len(decodes)
|
| 265 |
+
num_prefills = len(prefills)
|
| 266 |
+
first_prefill = 0
|
| 267 |
+
modified_batch = False
|
| 268 |
+
|
| 269 |
+
for i in range(1, min(num_decodes, num_prefills) + 1):
|
| 270 |
+
# If the decode is at the "back" of the batch, i, we can swap it
|
| 271 |
+
# with the prefill closest to the front of the batch
|
| 272 |
+
if decodes[num_decodes - i] >= num_decodes:
|
| 273 |
+
input_batch.swap_states(prefills[first_prefill],
|
| 274 |
+
decodes[num_decodes - i])
|
| 275 |
+
first_prefill += 1
|
| 276 |
+
modified_batch = True
|
| 277 |
+
else:
|
| 278 |
+
break
|
| 279 |
+
|
| 280 |
+
# Save for next `build` call
|
| 281 |
+
# TODO(lucas): this is a bit of a hack, we should probably have a
|
| 282 |
+
# better way of doing this
|
| 283 |
+
self._num_decodes = num_decodes
|
| 284 |
+
self._num_prefills = num_prefills
|
| 285 |
+
self._num_decode_tokens = num_decode_tokens
|
| 286 |
+
self._num_prefill_tokens = num_prefill_tokens
|
| 287 |
+
|
| 288 |
+
return modified_batch
|
| 289 |
+
|
| 290 |
+
def _get_graph_runner_block_tables(
|
| 291 |
+
self, num_seqs: int, block_tables: torch.Tensor) -> torch.Tensor:
|
| 292 |
+
|
| 293 |
+
max_batch_size, max_blocks = self.runner.graph_block_tables.shape
|
| 294 |
+
assert max_batch_size >= num_seqs
|
| 295 |
+
|
| 296 |
+
if isinstance(self.runner.graph_block_tables, np.ndarray):
|
| 297 |
+
graph_block_tables = torch.zeros((max_batch_size, max_blocks),
|
| 298 |
+
dtype=block_tables.dtype,
|
| 299 |
+
device=block_tables.device)
|
| 300 |
+
else:
|
| 301 |
+
graph_block_tables = self.runner.graph_block_tables.to(
|
| 302 |
+
device=block_tables.device, dtype=block_tables.dtype)
|
| 303 |
+
|
| 304 |
+
num_blocks = block_tables.size(1)
|
| 305 |
+
if num_blocks <= max_blocks:
|
| 306 |
+
graph_block_tables[:num_seqs, :
|
| 307 |
+
num_blocks] = block_tables[:num_seqs, :
|
| 308 |
+
num_blocks]
|
| 309 |
+
else:
|
| 310 |
+
graph_block_tables[:num_seqs, :
|
| 311 |
+
max_blocks] = block_tables[:num_seqs, :
|
| 312 |
+
max_blocks]
|
| 313 |
+
|
| 314 |
+
return graph_block_tables[:num_seqs, :max_blocks]
|
| 315 |
+
|
| 316 |
+
def build_dummy(self, num_reqs: int,
|
| 317 |
+
num_actual_tokens: int) -> AscendMLAMetadata:
|
| 318 |
+
device = self.runner.device
|
| 319 |
+
_, max_blocks = self.runner.graph_block_tables.shape
|
| 320 |
+
block_table = torch.zeros((num_reqs, max_blocks),
|
| 321 |
+
dtype=torch.int32,
|
| 322 |
+
device=device)
|
| 323 |
+
block_table = self._get_graph_runner_block_tables(
|
| 324 |
+
num_reqs, block_table)
|
| 325 |
+
seq_lens = torch.ones(num_reqs, dtype=torch.int32, device=device)
|
| 326 |
+
input_positions = torch.zeros(num_reqs,
|
| 327 |
+
dtype=torch.int32,
|
| 328 |
+
device=device).long()
|
| 329 |
+
slot_mapping = torch.full((num_reqs, ),
|
| 330 |
+
PAD_SLOT_ID,
|
| 331 |
+
dtype=torch.int32,
|
| 332 |
+
device=device)
|
| 333 |
+
query_start_loc = torch.full((num_reqs, ),
|
| 334 |
+
-1,
|
| 335 |
+
dtype=torch.int32,
|
| 336 |
+
device=device)
|
| 337 |
+
decode_metadata = AscendMLADecodeMetadata(
|
| 338 |
+
input_positions=input_positions,
|
| 339 |
+
block_table=block_table,
|
| 340 |
+
seq_lens=seq_lens,
|
| 341 |
+
seq_lens_list=seq_lens.tolist(),
|
| 342 |
+
max_seq_lens=1,
|
| 343 |
+
attn_mask=self.runner.spec_attn_mask)
|
| 344 |
+
return self.metadata_cls( # type: ignore
|
| 345 |
+
num_input_tokens=num_actual_tokens,
|
| 346 |
+
num_actual_tokens=num_actual_tokens,
|
| 347 |
+
slot_mapping=slot_mapping,
|
| 348 |
+
head_dim=self.runner.model_config.get_head_size(),
|
| 349 |
+
num_decodes=1,
|
| 350 |
+
num_decode_tokens=1,
|
| 351 |
+
num_prefills=0,
|
| 352 |
+
attn_mask=self.runner.attn_mask,
|
| 353 |
+
attn_state=AscendAttentionState.DecodeOnly,
|
| 354 |
+
prefill=None,
|
| 355 |
+
decode=decode_metadata,
|
| 356 |
+
query_start_loc=query_start_loc,
|
| 357 |
+
seq_lens=seq_lens,
|
| 358 |
+
block_tables=block_table,
|
| 359 |
+
)
|
| 360 |
+
|
| 361 |
+
def build(
|
| 362 |
+
self,
|
| 363 |
+
num_reqs: int,
|
| 364 |
+
num_actual_tokens: int,
|
| 365 |
+
max_query_len: int,
|
| 366 |
+
common_attn_metadata: CommonAttentionMetadata,
|
| 367 |
+
common_prefix_len: Optional[int] = None,
|
| 368 |
+
graph_pad_size: int = -1,
|
| 369 |
+
max_num_tokens_across_dp: int = 0,
|
| 370 |
+
with_prefill_across_dp: bool = False,
|
| 371 |
+
) -> AscendMLAMetadata:
|
| 372 |
+
assert self._num_decodes + self._num_prefills == num_reqs
|
| 373 |
+
|
| 374 |
+
# Note(simon): be careful about the CPU <> GPU memory movement in this
|
| 375 |
+
# function. We should avoid GPU -> CPU sync as much as possible because
|
| 376 |
+
# it blocks on all previous kernels.
|
| 377 |
+
device = self.runner.device
|
| 378 |
+
|
| 379 |
+
block_table = (self.runner.input_batch.block_table[0].
|
| 380 |
+
get_device_tensor()[:num_reqs])
|
| 381 |
+
slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
|
| 382 |
+
device, non_blocking=True)
|
| 383 |
+
input_positions = self.runner.positions_cpu[:num_actual_tokens].to(
|
| 384 |
+
device, non_blocking=True).long()
|
| 385 |
+
|
| 386 |
+
seq_lens_cpu = self.runner.seq_lens_cpu[:num_reqs]
|
| 387 |
+
query_lens = seq_lens_cpu - self.runner.input_batch.num_computed_tokens_cpu_tensor[:
|
| 388 |
+
num_reqs]
|
| 389 |
+
seq_lens = seq_lens_cpu
|
| 390 |
+
max_query_len = query_lens.max().item()
|
| 391 |
+
max_seq_lens = seq_lens.max().item()
|
| 392 |
+
query_start_loc = common_attn_metadata.query_start_loc
|
| 393 |
+
|
| 394 |
+
prefill_metadata = None
|
| 395 |
+
chunked_context_metadata = None
|
| 396 |
+
if self._num_prefills > 0:
|
| 397 |
+
reqs_start = self._num_decodes # prefill_start
|
| 398 |
+
tokens_start = self._num_decode_tokens
|
| 399 |
+
max_query_len = query_lens[tokens_start:].max().item()
|
| 400 |
+
max_seq_lens = seq_lens[tokens_start:].max().item()
|
| 401 |
+
query_start_loc = common_attn_metadata.query_start_loc
|
| 402 |
+
prefill_query_start_loc = query_start_loc[
|
| 403 |
+
reqs_start:] - query_start_loc[reqs_start]
|
| 404 |
+
|
| 405 |
+
context_lens_cpu = self.runner.input_batch.num_computed_tokens_cpu_tensor[
|
| 406 |
+
reqs_start:num_reqs]
|
| 407 |
+
max_context_len_cpu = context_lens_cpu.max().item()
|
| 408 |
+
num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item()
|
| 409 |
+
if self.chunked_prefill_enabled and max_context_len_cpu > 0:
|
| 410 |
+
max_context_chunk = (self.chunked_prefill_workspace_size //
|
| 411 |
+
num_prefills_with_context_cpu)
|
| 412 |
+
max_context_chunk = round_down(max_context_chunk,
|
| 413 |
+
self.block_size)
|
| 414 |
+
|
| 415 |
+
assert max_context_chunk > 0
|
| 416 |
+
num_chunks = cdiv(max_context_len_cpu, max_context_chunk)
|
| 417 |
+
chunk_starts = torch.arange(num_chunks, dtype=torch.int32) \
|
| 418 |
+
.unsqueeze(1).expand(-1, self._num_prefills) * max_context_chunk
|
| 419 |
+
chunk_ends = torch.min(context_lens_cpu.unsqueeze(0),
|
| 420 |
+
chunk_starts + max_context_chunk)
|
| 421 |
+
chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0)
|
| 422 |
+
cu_seq_lens_cpu = torch.zeros(num_chunks,
|
| 423 |
+
self._num_prefills + 1,
|
| 424 |
+
dtype=torch.int32,
|
| 425 |
+
pin_memory=True)
|
| 426 |
+
torch.cumsum(chunk_seq_lens,
|
| 427 |
+
dim=1,
|
| 428 |
+
out=cu_seq_lens_cpu[:, 1:],
|
| 429 |
+
dtype=torch.int32)
|
| 430 |
+
chunked_context_metadata = \
|
| 431 |
+
AscendMLAPrefillMetadata.ChunkedContextMetadata(
|
| 432 |
+
cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True),
|
| 433 |
+
starts=chunk_starts.to(device, non_blocking=True),
|
| 434 |
+
seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
|
| 435 |
+
max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
|
| 436 |
+
chunk_seq_lens=chunk_seq_lens,
|
| 437 |
+
workspace=self.chunked_prefill_workspace,
|
| 438 |
+
)
|
| 439 |
+
|
| 440 |
+
prefill_metadata = AscendMLAPrefillMetadata(
|
| 441 |
+
attn_mask=self.runner.attn_mask,
|
| 442 |
+
query_lens=query_lens[tokens_start:],
|
| 443 |
+
seq_lens=seq_lens,
|
| 444 |
+
context_lens=seq_lens[tokens_start:],
|
| 445 |
+
input_positions=input_positions[tokens_start:],
|
| 446 |
+
block_table=block_table[reqs_start:, ...],
|
| 447 |
+
max_query_len=max_query_len,
|
| 448 |
+
max_seq_lens=max_seq_lens,
|
| 449 |
+
query_start_loc=prefill_query_start_loc,
|
| 450 |
+
chunked_context=chunked_context_metadata,
|
| 451 |
+
)
|
| 452 |
+
|
| 453 |
+
decode_metadata = None
|
| 454 |
+
use_torchair_graph = graph_pad_size != -1
|
| 455 |
+
if self._num_decodes > 0:
|
| 456 |
+
max_seq_lens = seq_lens[:self._num_decodes].max().item()
|
| 457 |
+
seq_lens = seq_lens[:self._num_decode_tokens]
|
| 458 |
+
input_positions = input_positions[:self._num_decode_tokens]
|
| 459 |
+
block_table = block_table[:self._num_decode_tokens, ...]
|
| 460 |
+
if use_torchair_graph and self.runner.attn_state in [
|
| 461 |
+
AscendAttentionState.DecodeOnly,
|
| 462 |
+
AscendAttentionState.SpecDecoding
|
| 463 |
+
]:
|
| 464 |
+
num_seqs = len(seq_lens)
|
| 465 |
+
if graph_pad_size != 0:
|
| 466 |
+
pad_value = 1
|
| 467 |
+
padded_seq_lens = seq_lens.tolist() + [pad_value
|
| 468 |
+
] * graph_pad_size
|
| 469 |
+
else:
|
| 470 |
+
padded_seq_lens = seq_lens.tolist()
|
| 471 |
+
|
| 472 |
+
seq_lens = torch.from_numpy(
|
| 473 |
+
np.array(padded_seq_lens).astype(np.int32))
|
| 474 |
+
padding = torch.full((graph_pad_size, ),
|
| 475 |
+
PAD_SLOT_ID,
|
| 476 |
+
dtype=slot_mapping.dtype,
|
| 477 |
+
device=slot_mapping.device)
|
| 478 |
+
slot_mapping = torch.cat([slot_mapping, padding])
|
| 479 |
+
block_table_padding = torch.zeros(
|
| 480 |
+
(graph_pad_size, ) + block_table.shape[1:],
|
| 481 |
+
dtype=block_table.dtype,
|
| 482 |
+
device=block_table.device)
|
| 483 |
+
block_table = torch.cat([block_table, block_table_padding],
|
| 484 |
+
dim=0)
|
| 485 |
+
block_table = self._get_graph_runner_block_tables(
|
| 486 |
+
num_seqs + graph_pad_size, block_table)
|
| 487 |
+
padding_0 = torch.zeros(graph_pad_size,
|
| 488 |
+
dtype=input_positions.dtype,
|
| 489 |
+
device=input_positions.device)
|
| 490 |
+
input_positions = torch.cat([input_positions, padding_0])
|
| 491 |
+
|
| 492 |
+
decode_metadata = AscendMLADecodeMetadata(
|
| 493 |
+
input_positions=input_positions,
|
| 494 |
+
block_table=block_table,
|
| 495 |
+
seq_lens=seq_lens,
|
| 496 |
+
seq_lens_list=seq_lens.tolist(),
|
| 497 |
+
max_seq_lens=max_seq_lens,
|
| 498 |
+
attn_mask=self.runner.spec_attn_mask)
|
| 499 |
+
|
| 500 |
+
return self.metadata_cls( # type: ignore
|
| 501 |
+
num_actual_tokens=num_actual_tokens,
|
| 502 |
+
query_lens=query_lens.tolist(),
|
| 503 |
+
slot_mapping=slot_mapping,
|
| 504 |
+
head_dim=self.runner.model_config.get_head_size(),
|
| 505 |
+
num_decodes=self._num_decodes,
|
| 506 |
+
num_decode_tokens=self._num_decode_tokens,
|
| 507 |
+
num_prefills=self._num_prefills,
|
| 508 |
+
attn_mask=self.runner.attn_mask,
|
| 509 |
+
attn_state=self.runner.attn_state,
|
| 510 |
+
prefill=prefill_metadata,
|
| 511 |
+
decode=decode_metadata,
|
| 512 |
+
query_start_loc=query_start_loc,
|
| 513 |
+
block_tables=block_table,
|
| 514 |
+
seq_lens=seq_lens,
|
| 515 |
+
max_num_tokens_across_dp=max_num_tokens_across_dp,
|
| 516 |
+
with_prefill_across_dp=with_prefill_across_dp,
|
| 517 |
+
)
|
| 518 |
+
|
| 519 |
+
|
| 520 |
+
class AscendMLAImpl(MLAAttentionImpl):
|
| 521 |
+
"""
|
| 522 |
+
NOTE: Please read the comment at the top of the file before trying to
|
| 523 |
+
understand this class
|
| 524 |
+
"""
|
| 525 |
+
|
| 526 |
+
def __init__(
|
| 527 |
+
self,
|
| 528 |
+
num_heads: int,
|
| 529 |
+
head_size: int,
|
| 530 |
+
scale: float,
|
| 531 |
+
num_kv_heads: int,
|
| 532 |
+
alibi_slopes: Optional[list[float]],
|
| 533 |
+
sliding_window: Optional[int],
|
| 534 |
+
kv_cache_dtype: str,
|
| 535 |
+
blocksparse_params: Optional[dict[str, Any]],
|
| 536 |
+
logits_soft_cap: Optional[float],
|
| 537 |
+
attn_type: str,
|
| 538 |
+
kv_sharing_target_layer_name: Optional[str] = None,
|
| 539 |
+
**kwargs,
|
| 540 |
+
) -> None:
|
| 541 |
+
self.num_heads = num_heads
|
| 542 |
+
self.head_size = head_size
|
| 543 |
+
self.scale = float(scale)
|
| 544 |
+
self.num_kv_heads = num_kv_heads
|
| 545 |
+
self.kv_cache_dtype = kv_cache_dtype
|
| 546 |
+
|
| 547 |
+
# MLA Args
|
| 548 |
+
self.q_lora_rank = kwargs['q_lora_rank']
|
| 549 |
+
self.kv_lora_rank = kwargs['kv_lora_rank']
|
| 550 |
+
self.qk_nope_head_dim = kwargs['qk_nope_head_dim']
|
| 551 |
+
self.qk_rope_head_dim = kwargs['qk_rope_head_dim']
|
| 552 |
+
self.qk_head_dim = kwargs['qk_head_dim']
|
| 553 |
+
self.v_head_dim = kwargs['v_head_dim']
|
| 554 |
+
self.rotary_emb = kwargs['rotary_emb']
|
| 555 |
+
self.q_proj = kwargs['q_proj']
|
| 556 |
+
self.kv_b_proj = kwargs['kv_b_proj']
|
| 557 |
+
self.o_proj = kwargs['o_proj']
|
| 558 |
+
self.kv_a_proj_with_mqa = kwargs.get('kv_a_proj_with_mqa', None)
|
| 559 |
+
self.kv_a_layernorm = kwargs.get('kv_a_layernorm', None)
|
| 560 |
+
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
| 561 |
+
self.tp_size = get_tensor_model_parallel_world_size()
|
| 562 |
+
|
| 563 |
+
ascend_config = get_ascend_config()
|
| 564 |
+
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
|
| 565 |
+
self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz
|
| 566 |
+
|
| 567 |
+
# Adapt torch air graph mode with spec decoding.
|
| 568 |
+
speculative_config = get_current_vllm_config().speculative_config
|
| 569 |
+
if speculative_config is not None:
|
| 570 |
+
self.spec_token_num = speculative_config.num_speculative_tokens
|
| 571 |
+
assert self.spec_token_num > 0
|
| 572 |
+
self.SHARE_MASK_TRIL_SPARSE = ~torch.tril(torch.ones((2048, 2048), dtype=torch.bool)).npu()
|
| 573 |
+
|
| 574 |
+
def _v_up_proj_and_o_proj(self, x, enable_multistream_mla: bool = False):
|
| 575 |
+
# Convert from (B, N, L) to (N, B, L)
|
| 576 |
+
x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
|
| 577 |
+
# Multiply (N, B, L) x (N, L, V) -> (N, B, V)
|
| 578 |
+
x = torch.bmm(x, self.W_UV)
|
| 579 |
+
# Convert from (N, B, V) to (B, N * V)
|
| 580 |
+
x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
|
| 581 |
+
MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024 # 16MB
|
| 582 |
+
npu_prefetch(self.o_proj.weight,
|
| 583 |
+
x,
|
| 584 |
+
max_size=MAX_O_PROJ_PREFETCH_SIZE,
|
| 585 |
+
enabled=enable_multistream_mla)
|
| 586 |
+
return self.o_proj(x, is_prefill=False)[0]
|
| 587 |
+
|
| 588 |
+
# Return `ql_nope`, `q_pe`
|
| 589 |
+
def _q_proj_and_k_up_proj(self, x):
|
| 590 |
+
q_nope, q_pe = self.q_proj(x)[0]\
|
| 591 |
+
.view(-1, self.num_heads, self.qk_head_dim)\
|
| 592 |
+
.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
|
| 593 |
+
|
| 594 |
+
# Convert from (B, N, P) to (N, B, P)
|
| 595 |
+
q_nope = q_nope.transpose(0, 1)
|
| 596 |
+
# Multiply (N, B, P) x (N, P, L) -> (N, B, L)
|
| 597 |
+
ql_nope = torch.bmm(q_nope, self.W_UK_T)
|
| 598 |
+
# Convert from (N, B, L) to (B, N, L)
|
| 599 |
+
return ql_nope.transpose(0, 1), q_pe
|
| 600 |
+
|
| 601 |
+
def process_weights_after_loading(self, act_dtype: torch.dtype):
|
| 602 |
+
|
| 603 |
+
def get_layer_weight(layer):
|
| 604 |
+
WEIGHT_NAMES = ("weight", "qweight", "weight_packed")
|
| 605 |
+
for attr in WEIGHT_NAMES:
|
| 606 |
+
if hasattr(layer, attr):
|
| 607 |
+
return getattr(layer, attr)
|
| 608 |
+
raise AttributeError(
|
| 609 |
+
f"Layer '{layer}' has no recognized weight attribute:"
|
| 610 |
+
f" {WEIGHT_NAMES}.")
|
| 611 |
+
|
| 612 |
+
def get_and_maybe_dequant_weights(layer: LinearBase):
|
| 613 |
+
if not isinstance(layer.quant_method, UnquantizedLinearMethod):
|
| 614 |
+
# NOTE: This should only be used offline, since it's O(N^3)
|
| 615 |
+
eye = torch.eye(layer.input_size_per_partition,
|
| 616 |
+
dtype=act_dtype,
|
| 617 |
+
device=get_layer_weight(layer).device)
|
| 618 |
+
dequant_weights = layer.quant_method.apply(layer,
|
| 619 |
+
eye,
|
| 620 |
+
bias=None)
|
| 621 |
+
del eye
|
| 622 |
+
# standardize to (output, input)
|
| 623 |
+
return dequant_weights.T
|
| 624 |
+
return layer.weight
|
| 625 |
+
|
| 626 |
+
# we currently do not have quantized bmm's which are needed for
|
| 627 |
+
# `W_UV` and `W_UK_T`, we we just store fp16/bf16 copies and perform
|
| 628 |
+
# the bmm's in 16-bit, the extra memory overhead of this is fairly low
|
| 629 |
+
kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
|
| 630 |
+
assert kv_b_proj_weight.shape == (
|
| 631 |
+
self.kv_lora_rank,
|
| 632 |
+
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), (
|
| 633 |
+
f"{kv_b_proj_weight.shape=}, "
|
| 634 |
+
f"{self.kv_lora_rank=}, "
|
| 635 |
+
f"{self.num_heads=}, "
|
| 636 |
+
f"{self.qk_nope_head_dim=}, "
|
| 637 |
+
f"{self.v_head_dim=}")
|
| 638 |
+
kv_b_proj_weight = kv_b_proj_weight.view(
|
| 639 |
+
self.kv_lora_rank,
|
| 640 |
+
self.num_heads,
|
| 641 |
+
self.qk_nope_head_dim + self.v_head_dim,
|
| 642 |
+
)
|
| 643 |
+
|
| 644 |
+
W_UK, W_UV = kv_b_proj_weight.split(
|
| 645 |
+
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
| 646 |
+
|
| 647 |
+
# Convert from (L, N, V) to (N, L, V)
|
| 648 |
+
self.W_UV = W_UV.transpose(0, 1).contiguous()
|
| 649 |
+
# Convert from (L, N, P) to (N, P, L)
|
| 650 |
+
self.W_UK_T = W_UK.permute(1, 2, 0).contiguous()
|
| 651 |
+
|
| 652 |
+
# Waiting for BMM NZ support
|
| 653 |
+
# self.W_UV.data = torch_npu.npu_format_cast(self.W_UV.data, 29)
|
| 654 |
+
# self.W_UK_T.data = torch_npu.npu_format_cast(self.W_UK_T.data, 29)
|
| 655 |
+
|
| 656 |
+
def _compute_prefill_context(
|
| 657 |
+
self,
|
| 658 |
+
query: torch.Tensor,
|
| 659 |
+
kv_c_and_k_pe_cache: torch.Tensor,
|
| 660 |
+
rope_dim: int,
|
| 661 |
+
attn_metadata: AscendMLAMetadata,
|
| 662 |
+
prefix_output: torch.Tensor,
|
| 663 |
+
prefix_lse: torch.Tensor,
|
| 664 |
+
):
|
| 665 |
+
prefill_metadata = attn_metadata.prefill
|
| 666 |
+
if prefill_metadata is None or prefill_metadata.chunked_context is None:
|
| 667 |
+
return prefix_output, prefix_lse
|
| 668 |
+
|
| 669 |
+
iters = len(prefill_metadata.chunked_context.seq_tot)
|
| 670 |
+
q_pe = query[..., self.qk_nope_head_dim:]
|
| 671 |
+
q_nope = query[..., :self.qk_nope_head_dim]
|
| 672 |
+
|
| 673 |
+
seq_len1 = torch.tensor(prefill_metadata.query_lens, dtype=torch.int32)
|
| 674 |
+
latent_kv_dim = kv_c_and_k_pe_cache.size(3) - rope_dim
|
| 675 |
+
cache_kv_c = kv_c_and_k_pe_cache[:, :, :, :latent_kv_dim]
|
| 676 |
+
cache_k_pe = kv_c_and_k_pe_cache[:, :, :, latent_kv_dim:]
|
| 677 |
+
for i in range(iters):
|
| 678 |
+
toks = prefill_metadata.chunked_context.seq_tot[i]
|
| 679 |
+
|
| 680 |
+
seq_len2 = prefill_metadata.chunked_context.chunk_seq_lens[i]
|
| 681 |
+
seq_len = torch.stack([seq_len1, seq_len2])
|
| 682 |
+
kv_c_normed = torch.empty(toks,
|
| 683 |
+
kv_c_and_k_pe_cache.size(2),
|
| 684 |
+
latent_kv_dim,
|
| 685 |
+
dtype=query.dtype,
|
| 686 |
+
device=query.device)
|
| 687 |
+
k_pe = torch.empty(toks,
|
| 688 |
+
kv_c_and_k_pe_cache.size(2),
|
| 689 |
+
rope_dim,
|
| 690 |
+
dtype=query.dtype,
|
| 691 |
+
device=query.device)
|
| 692 |
+
|
| 693 |
+
torch_npu.atb.npu_paged_cache_load(
|
| 694 |
+
cache_kv_c,
|
| 695 |
+
cache_k_pe,
|
| 696 |
+
prefill_metadata.block_table,
|
| 697 |
+
seq_len2.to(query.device),
|
| 698 |
+
seq_starts=prefill_metadata.chunked_context.starts[i],
|
| 699 |
+
key=kv_c_normed,
|
| 700 |
+
value=k_pe,
|
| 701 |
+
)
|
| 702 |
+
|
| 703 |
+
kv_c_normed = kv_c_normed.squeeze()
|
| 704 |
+
kv_nope = self.kv_b_proj(kv_c_normed)[0].view( \
|
| 705 |
+
-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
|
| 706 |
+
k_nope, v = kv_nope\
|
| 707 |
+
.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
| 708 |
+
k_pe = k_pe.expand((*k_nope.shape[:-1], -1))
|
| 709 |
+
mask = torch.triu(
|
| 710 |
+
torch.ones(512, 512, device=query.device, dtype=query.dtype),
|
| 711 |
+
1)
|
| 712 |
+
torch_npu.atb.npu_ring_mla(
|
| 713 |
+
q_nope=q_nope,
|
| 714 |
+
q_rope=q_pe,
|
| 715 |
+
k_nope=k_nope,
|
| 716 |
+
k_rope=k_pe,
|
| 717 |
+
value=v,
|
| 718 |
+
mask=mask,
|
| 719 |
+
seqlen=seq_len,
|
| 720 |
+
head_num=self.num_heads,
|
| 721 |
+
kv_head_num=self.num_heads,
|
| 722 |
+
pre_out=prefix_output,
|
| 723 |
+
prev_lse=prefix_lse,
|
| 724 |
+
qk_scale=self.scale,
|
| 725 |
+
kernel_type="kernel_type_high_precision",
|
| 726 |
+
mask_type="no_mask",
|
| 727 |
+
input_layout="type_bsnd",
|
| 728 |
+
calc_type="calc_type_default",
|
| 729 |
+
output=prefix_output,
|
| 730 |
+
softmax_lse=prefix_lse)
|
| 731 |
+
return prefix_output, prefix_lse
|
| 732 |
+
|
| 733 |
+
def _forward_prefill(
|
| 734 |
+
self,
|
| 735 |
+
query: torch.Tensor,
|
| 736 |
+
kv_c_normed: torch.Tensor,
|
| 737 |
+
k_pe: torch.Tensor,
|
| 738 |
+
kv_c_and_k_pe_cache: torch.Tensor,
|
| 739 |
+
attn_metadata: AscendMLAMetadata,
|
| 740 |
+
) -> torch.Tensor:
|
| 741 |
+
assert attn_metadata.prefill is not None
|
| 742 |
+
|
| 743 |
+
num_tokens = query.size(0)
|
| 744 |
+
attn_output = torch.empty(num_tokens,
|
| 745 |
+
self.num_heads,
|
| 746 |
+
self.v_head_dim,
|
| 747 |
+
dtype=query.dtype,
|
| 748 |
+
device=query.device)
|
| 749 |
+
k_nope, value = self.kv_b_proj(kv_c_normed)[0].view(
|
| 750 |
+
-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim).split(
|
| 751 |
+
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
| 752 |
+
k_pe = k_pe.expand((*k_nope.shape[:-1], -1))
|
| 753 |
+
# Here is only 2 possibility of input, ChunkedPrefill or PrefillNoCache
|
| 754 |
+
ascend_config = get_ascend_config()
|
| 755 |
+
|
| 756 |
+
if attn_metadata.attn_state in [
|
| 757 |
+
AscendAttentionState.ChunkedPrefill,
|
| 758 |
+
AscendAttentionState.SpecDecoding,
|
| 759 |
+
AscendAttentionState.PrefillCacheHit
|
| 760 |
+
] and not ascend_config.chunked_prefill_for_mla:
|
| 761 |
+
attn_output_torch = torch.empty(num_tokens,
|
| 762 |
+
self.num_heads * self.v_head_dim,
|
| 763 |
+
dtype=query.dtype,
|
| 764 |
+
device=query.device)
|
| 765 |
+
# current requests is chunked in prefill, disable flash attention with chunked prefill
|
| 766 |
+
vanilla_chunked_prefill_mla(
|
| 767 |
+
output=attn_output_torch,
|
| 768 |
+
query=query,
|
| 769 |
+
kv_cache=kv_c_and_k_pe_cache,
|
| 770 |
+
block_tables=attn_metadata.prefill.block_table,
|
| 771 |
+
query_lens=attn_metadata.prefill.query_lens,
|
| 772 |
+
context_lens=attn_metadata.prefill.context_lens,
|
| 773 |
+
kv_b_proj=self.kv_b_proj,
|
| 774 |
+
max_query_len=attn_metadata.prefill.max_query_len,
|
| 775 |
+
max_context_len=attn_metadata.prefill.max_seq_lens,
|
| 776 |
+
nope_dim=self.qk_nope_head_dim,
|
| 777 |
+
rope_dim=self.qk_rope_head_dim,
|
| 778 |
+
v_head_dim=self.v_head_dim,
|
| 779 |
+
scale=self.scale,
|
| 780 |
+
alibi_slopes=None,
|
| 781 |
+
causal=True)
|
| 782 |
+
elif attn_metadata.attn_state in [
|
| 783 |
+
AscendAttentionState.ChunkedPrefill,
|
| 784 |
+
AscendAttentionState.SpecDecoding,
|
| 785 |
+
AscendAttentionState.PrefillCacheHit
|
| 786 |
+
]:
|
| 787 |
+
attn_lse = torch.empty(self.num_heads,
|
| 788 |
+
num_tokens,
|
| 789 |
+
dtype=torch.float32,
|
| 790 |
+
device=query.device)
|
| 791 |
+
q_pe = query[..., self.qk_nope_head_dim:]
|
| 792 |
+
q_nope = query[..., :self.qk_nope_head_dim]
|
| 793 |
+
mask = torch.triu(
|
| 794 |
+
torch.ones(512, 512, device=query.device, dtype=query.dtype),
|
| 795 |
+
1) # 512: mask only support 512
|
| 796 |
+
if attn_metadata.num_prefills > 1:
|
| 797 |
+
mask = mask.unsqueeze(0).repeat(attn_metadata.num_prefills, 1,
|
| 798 |
+
1)
|
| 799 |
+
torch_npu.atb.npu_ring_mla(
|
| 800 |
+
q_nope=q_nope,
|
| 801 |
+
q_rope=q_pe,
|
| 802 |
+
k_nope=k_nope,
|
| 803 |
+
k_rope=k_pe,
|
| 804 |
+
value=value,
|
| 805 |
+
mask=mask,
|
| 806 |
+
seqlen=torch.tensor(attn_metadata.prefill.query_lens,
|
| 807 |
+
dtype=torch.int32),
|
| 808 |
+
head_num=self.num_heads,
|
| 809 |
+
kv_head_num=self.num_heads,
|
| 810 |
+
pre_out=None,
|
| 811 |
+
prev_lse=None,
|
| 812 |
+
qk_scale=self.scale,
|
| 813 |
+
kernel_type="kernel_type_high_precision",
|
| 814 |
+
mask_type="mask_type_triu",
|
| 815 |
+
input_layout="type_bsnd",
|
| 816 |
+
calc_type="calc_type_first_ring",
|
| 817 |
+
output=attn_output,
|
| 818 |
+
softmax_lse=attn_lse)
|
| 819 |
+
attn_output, attn_lse = self._compute_prefill_context( \
|
| 820 |
+
query, kv_c_and_k_pe_cache, self.qk_rope_head_dim, attn_metadata, attn_output, attn_lse)
|
| 821 |
+
|
| 822 |
+
elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
|
| 823 |
+
key = torch.cat((k_nope, k_pe), dim=-1)
|
| 824 |
+
context_lens_list = torch.cumsum(attn_metadata.prefill.context_lens, dim=0).tolist()
|
| 825 |
+
attn_output = torch_npu.npu_fused_infer_attention_score(
|
| 826 |
+
query,
|
| 827 |
+
key,
|
| 828 |
+
value,
|
| 829 |
+
num_heads=self.num_heads,
|
| 830 |
+
input_layout="TND",
|
| 831 |
+
scale=self.scale,
|
| 832 |
+
sparse_mode=3,
|
| 833 |
+
atten_mask=self.SHARE_MASK_TRIL_SPARSE,
|
| 834 |
+
actual_seq_lengths=context_lens_list,
|
| 835 |
+
actual_seq_lengths_kv=context_lens_list,
|
| 836 |
+
inner_precise=0)[0]
|
| 837 |
+
attn_output = attn_output.view(-1, self.num_heads, self.v_head_dim)
|
| 838 |
+
else:
|
| 839 |
+
raise RuntimeError(
|
| 840 |
+
"Unexpected path reached, AscendMLAImpl should only have PrefillNoCache, PrefillCacheHit, ChunkedPrefill and SpecDecoding scenario in forward prefill, please file a bug to vllm-ascend !"
|
| 841 |
+
)
|
| 842 |
+
attn_output = attn_output.reshape(
|
| 843 |
+
[num_tokens, self.num_heads * self.v_head_dim])
|
| 844 |
+
if attn_metadata.attn_state in [
|
| 845 |
+
AscendAttentionState.ChunkedPrefill,
|
| 846 |
+
AscendAttentionState.SpecDecoding,
|
| 847 |
+
AscendAttentionState.PrefillCacheHit
|
| 848 |
+
] and not ascend_config.chunked_prefill_for_mla:
|
| 849 |
+
attn_output = attn_output_torch
|
| 850 |
+
|
| 851 |
+
current_ms_metadata = get_multistream_comm_context()
|
| 852 |
+
if current_ms_metadata is None:
|
| 853 |
+
return self.o_proj(attn_output, is_prefill=True)[0]
|
| 854 |
+
else:
|
| 855 |
+
current_ms_metadata.before_comm_event.record()
|
| 856 |
+
with torch.npu.stream(current_ms_metadata.comm_stream):
|
| 857 |
+
current_ms_metadata.before_comm_event.wait()
|
| 858 |
+
return self.o_proj(attn_output, is_prefill=True)[0]
|
| 859 |
+
|
| 860 |
+
def exec_kv(
|
| 861 |
+
self,
|
| 862 |
+
hidden_states: torch.Tensor,
|
| 863 |
+
cos: torch.Tensor,
|
| 864 |
+
sin: torch.Tensor,
|
| 865 |
+
kv_cache: Tuple,
|
| 866 |
+
slots: torch.Tensor,
|
| 867 |
+
):
|
| 868 |
+
|
| 869 |
+
B = hidden_states.shape[0]
|
| 870 |
+
N = self.num_kv_heads
|
| 871 |
+
S = 1
|
| 872 |
+
kv = self.kv_a_proj_with_mqa(hidden_states)[0]
|
| 873 |
+
# npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
|
| 874 |
+
kv = kv.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
|
| 875 |
+
cache_mode = "PA_NZ" if self.enable_kv_nz else "PA"
|
| 876 |
+
k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
|
| 877 |
+
kv,
|
| 878 |
+
self.kv_a_layernorm.weight,
|
| 879 |
+
cos,
|
| 880 |
+
sin,
|
| 881 |
+
slots.to(torch.int64),
|
| 882 |
+
kv_cache[1],
|
| 883 |
+
kv_cache[0],
|
| 884 |
+
epsilon=self.kv_a_layernorm.variance_epsilon,
|
| 885 |
+
cache_mode=cache_mode,
|
| 886 |
+
)
|
| 887 |
+
return k_pe, k_nope, kv
|
| 888 |
+
|
| 889 |
+
def exec_kv_prefill(
|
| 890 |
+
self,
|
| 891 |
+
hidden_states: torch.Tensor,
|
| 892 |
+
cos: torch.Tensor,
|
| 893 |
+
sin: torch.Tensor,
|
| 894 |
+
kv_cache: Tuple,
|
| 895 |
+
slots: torch.Tensor,
|
| 896 |
+
):
|
| 897 |
+
|
| 898 |
+
B = hidden_states.shape[0]
|
| 899 |
+
N = self.num_kv_heads
|
| 900 |
+
S = 1
|
| 901 |
+
kv = self.kv_a_proj_with_mqa(hidden_states)[0]
|
| 902 |
+
# npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
|
| 903 |
+
kv = kv.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
|
| 904 |
+
cache_mode = "PA_BLK_NZ" if self.enable_kv_nz else "PA"
|
| 905 |
+
_, _, k_pe, k_nope = torch_npu.npu_kv_rmsnorm_rope_cache(
|
| 906 |
+
kv,
|
| 907 |
+
self.kv_a_layernorm.weight,
|
| 908 |
+
cos,
|
| 909 |
+
sin,
|
| 910 |
+
slots.to(torch.int64),
|
| 911 |
+
kv_cache[1],
|
| 912 |
+
kv_cache[0],
|
| 913 |
+
epsilon=self.kv_a_layernorm.variance_epsilon,
|
| 914 |
+
cache_mode=cache_mode,
|
| 915 |
+
is_output_kv=True,
|
| 916 |
+
)
|
| 917 |
+
return k_pe, k_nope
|
| 918 |
+
|
| 919 |
+
def rope_single(
|
| 920 |
+
self,
|
| 921 |
+
x: torch.Tensor,
|
| 922 |
+
cos: torch.Tensor,
|
| 923 |
+
sin: torch.Tensor,
|
| 924 |
+
) -> torch.Tensor:
|
| 925 |
+
B, N, D = x.shape
|
| 926 |
+
S = 1
|
| 927 |
+
x = x.view(B, N, S, D)
|
| 928 |
+
x = torch_npu.npu_interleave_rope(x, cos, sin)
|
| 929 |
+
return x.view(B, N, D)
|
| 930 |
+
|
| 931 |
+
def _forward_decode(
|
| 932 |
+
self,
|
| 933 |
+
q_nope: torch.Tensor,
|
| 934 |
+
q_pe: torch.Tensor,
|
| 935 |
+
k_nope: torch.Tensor,
|
| 936 |
+
k_pe: torch.Tensor,
|
| 937 |
+
kv_c_and_k_pe_cache: torch.Tensor,
|
| 938 |
+
attn_metadata: AscendMLAMetadata,
|
| 939 |
+
enable_multistream_mla: bool = False,
|
| 940 |
+
) -> torch.Tensor:
|
| 941 |
+
decode_meta = attn_metadata.decode
|
| 942 |
+
assert decode_meta is not None
|
| 943 |
+
|
| 944 |
+
q = torch.cat([q_nope, q_pe], dim=-1)
|
| 945 |
+
num_tokens = q.size(0)
|
| 946 |
+
attn_output = torch.empty(
|
| 947 |
+
[num_tokens, self.num_heads, self.kv_lora_rank],
|
| 948 |
+
dtype=q.dtype,
|
| 949 |
+
device=q.device)
|
| 950 |
+
if self.running_in_graph:
|
| 951 |
+
# TorchAir's shape is [bs, num_heads_per_rank, q_seq_len, dim]
|
| 952 |
+
if attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
|
| 953 |
+
assert num_tokens % self.spec_token_num == 0
|
| 954 |
+
q_nope = q_nope.view(num_tokens // (self.spec_token_num + 1),
|
| 955 |
+
self.spec_token_num + 1, self.num_heads,
|
| 956 |
+
-1)
|
| 957 |
+
q_pe = q_pe.view(num_tokens // (self.spec_token_num + 1),
|
| 958 |
+
self.spec_token_num + 1, self.num_heads, -1)
|
| 959 |
+
if not self.enable_kv_nz:
|
| 960 |
+
q_nope = q_nope.transpose(1, 2).contiguous()
|
| 961 |
+
q_pe = q_pe.transpose(1, 2).contiguous()
|
| 962 |
+
sparse_mode = 3
|
| 963 |
+
spec_attn_mask = attn_metadata.decode.attn_mask # type:ignore
|
| 964 |
+
else:
|
| 965 |
+
if self.enable_kv_nz:
|
| 966 |
+
q_nope = q_nope.view(num_tokens, 1, self.num_heads, -1)
|
| 967 |
+
q_pe = q_pe.view(num_tokens, 1, self.num_heads, -1)
|
| 968 |
+
else:
|
| 969 |
+
q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1)
|
| 970 |
+
q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
|
| 971 |
+
sparse_mode = 0
|
| 972 |
+
spec_attn_mask = None
|
| 973 |
+
# shape of knope/k_pe for npu graph mode should be:
|
| 974 |
+
# [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim]
|
| 975 |
+
block_size = kv_c_and_k_pe_cache[0].shape[1]
|
| 976 |
+
if self.enable_kv_nz:
|
| 977 |
+
k_nope = k_nope.view(-1, self.num_kv_heads,
|
| 978 |
+
self.kv_lora_rank // 16, block_size, 16)
|
| 979 |
+
k_pe = k_pe.view(-1, self.num_kv_heads,
|
| 980 |
+
self.qk_rope_head_dim // 16, block_size, 16)
|
| 981 |
+
input_layout = "BSND"
|
| 982 |
+
else:
|
| 983 |
+
k_nope = k_nope.view(-1, self.num_kv_heads, block_size,
|
| 984 |
+
self.kv_lora_rank)
|
| 985 |
+
k_pe = k_pe.view(-1, self.num_kv_heads, block_size,
|
| 986 |
+
self.qk_rope_head_dim)
|
| 987 |
+
input_layout = "BNSD"
|
| 988 |
+
|
| 989 |
+
attn_output, _ = torch_npu.npu_fused_infer_attention_score(
|
| 990 |
+
q_nope,
|
| 991 |
+
k_nope,
|
| 992 |
+
k_nope,
|
| 993 |
+
query_rope=q_pe,
|
| 994 |
+
key_rope=k_pe,
|
| 995 |
+
num_heads=self.num_heads,
|
| 996 |
+
num_key_value_heads=self.num_kv_heads,
|
| 997 |
+
input_layout=input_layout,
|
| 998 |
+
atten_mask=spec_attn_mask,
|
| 999 |
+
sparse_mode=sparse_mode,
|
| 1000 |
+
scale=self.scale,
|
| 1001 |
+
antiquant_mode=0,
|
| 1002 |
+
antiquant_scale=None,
|
| 1003 |
+
block_table=decode_meta.block_table,
|
| 1004 |
+
block_size=block_size,
|
| 1005 |
+
actual_seq_lengths_kv=decode_meta.seq_lens_list,
|
| 1006 |
+
)
|
| 1007 |
+
else:
|
| 1008 |
+
torch_npu._npu_paged_attention_mla(
|
| 1009 |
+
query=q,
|
| 1010 |
+
key_cache=kv_c_and_k_pe_cache,
|
| 1011 |
+
num_kv_heads=self.num_kv_heads,
|
| 1012 |
+
num_heads=self.num_heads,
|
| 1013 |
+
scale_value=self.scale,
|
| 1014 |
+
block_table=attn_metadata.decode.block_table, # type:ignore
|
| 1015 |
+
context_lens=attn_metadata.decode.seq_lens, # type:ignore
|
| 1016 |
+
mla_vheadsize=self.kv_lora_rank,
|
| 1017 |
+
out=attn_output)
|
| 1018 |
+
current_ms_metadata = get_multistream_comm_context()
|
| 1019 |
+
if current_ms_metadata is None:
|
| 1020 |
+
return self._v_up_proj_and_o_proj(attn_output,
|
| 1021 |
+
enable_multistream_mla)
|
| 1022 |
+
else:
|
| 1023 |
+
current_ms_metadata.before_comm_event.record()
|
| 1024 |
+
with torch.npu.stream(current_ms_metadata.comm_stream):
|
| 1025 |
+
current_ms_metadata.before_comm_event.wait()
|
| 1026 |
+
return self._v_up_proj_and_o_proj(attn_output)
|
| 1027 |
+
|
| 1028 |
+
def forward(
|
| 1029 |
+
self,
|
| 1030 |
+
layer: AttentionLayer,
|
| 1031 |
+
hidden_states_or_q_c: torch.Tensor, # query in unified attn
|
| 1032 |
+
hidden_states_or_kv_c_normed: torch.Tensor, # key in unified attn
|
| 1033 |
+
k_pe: torch.Tensor, # value in unified attn
|
| 1034 |
+
kv_cache: torch.Tensor,
|
| 1035 |
+
attn_metadata: M,
|
| 1036 |
+
output: Optional[torch.Tensor] = None,
|
| 1037 |
+
enable_multistream_mla: bool = False,
|
| 1038 |
+
ckq: Optional[torch.Tensor] = None,
|
| 1039 |
+
) -> torch.Tensor:
|
| 1040 |
+
assert output is not None, "Output tensor must be provided."
|
| 1041 |
+
if attn_metadata is None:
|
| 1042 |
+
# Profiling run.
|
| 1043 |
+
return output
|
| 1044 |
+
self.running_in_graph = self.torchair_graph_enabled and attn_metadata.attn_state in [
|
| 1045 |
+
AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
|
| 1046 |
+
]
|
| 1047 |
+
num_actual_toks = attn_metadata.num_actual_tokens
|
| 1048 |
+
if k_pe is None and not self.running_in_graph:
|
| 1049 |
+
if not self.torchair_graph_enabled:
|
| 1050 |
+
kv_c, k_pe = self.kv_a_proj_with_mqa(
|
| 1051 |
+
hidden_states_or_kv_c_normed)[0].split(
|
| 1052 |
+
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
|
| 1053 |
+
kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
|
| 1054 |
+
else:
|
| 1055 |
+
kv_c_normed = hidden_states_or_kv_c_normed
|
| 1056 |
+
assert attn_metadata.num_decodes is not None and \
|
| 1057 |
+
attn_metadata.num_prefills is not None and \
|
| 1058 |
+
attn_metadata.num_decode_tokens is not None
|
| 1059 |
+
has_decode = attn_metadata.num_decodes > 0
|
| 1060 |
+
has_prefill = attn_metadata.num_prefills > 0
|
| 1061 |
+
num_decode_tokens = attn_metadata.num_decode_tokens
|
| 1062 |
+
if not self.running_in_graph:
|
| 1063 |
+
# Inputs and outputs may be padded for CUDA graphs
|
| 1064 |
+
output_padded = output
|
| 1065 |
+
output = output[:num_actual_toks, ...]
|
| 1066 |
+
if not self.torchair_graph_enabled:
|
| 1067 |
+
kv_c_normed = kv_c_normed[:num_actual_toks, ...]
|
| 1068 |
+
prefill_k_c_normed = kv_c_normed[num_decode_tokens:]
|
| 1069 |
+
if not self.running_in_graph:
|
| 1070 |
+
hidden_states_or_q_c = hidden_states_or_q_c[:num_actual_toks, ...]
|
| 1071 |
+
prefill_hs_or_q_c = hidden_states_or_q_c[num_decode_tokens:]
|
| 1072 |
+
if not self.torchair_graph_enabled:
|
| 1073 |
+
decode_hs_or_q_c = hidden_states_or_q_c[:num_decode_tokens]
|
| 1074 |
+
k_pe = k_pe[:num_actual_toks, ...]
|
| 1075 |
+
k_pe = k_pe.unsqueeze(1)
|
| 1076 |
+
decode_k_pe = k_pe[:num_decode_tokens]
|
| 1077 |
+
prefill_k_pe = k_pe[num_decode_tokens:]
|
| 1078 |
+
else:
|
| 1079 |
+
decode_hs_or_q_c = hidden_states_or_q_c
|
| 1080 |
+
if has_decode:
|
| 1081 |
+
decode_k_nope = None
|
| 1082 |
+
assert attn_metadata.decode is not None
|
| 1083 |
+
if self.running_in_graph:
|
| 1084 |
+
seq_len = self.rotary_emb.max_position_embeddings * \
|
| 1085 |
+
getattr(self.rotary_emb, "scaling_factor", 1)
|
| 1086 |
+
cos = self.rotary_emb.cos_cached[:seq_len].to(
|
| 1087 |
+
dtype=decode_hs_or_q_c.dtype)
|
| 1088 |
+
sin = self.rotary_emb.sin_cached[:seq_len].to(
|
| 1089 |
+
dtype=decode_hs_or_q_c.dtype)
|
| 1090 |
+
cos = cos[attn_metadata.decode.input_positions]
|
| 1091 |
+
sin = sin[attn_metadata.decode.input_positions]
|
| 1092 |
+
cos = cos[:, None, None, :]
|
| 1093 |
+
sin = sin[:, None, None, :]
|
| 1094 |
+
with npu_stream_switch("mla_secondary",
|
| 1095 |
+
0,
|
| 1096 |
+
enabled=enable_multistream_mla):
|
| 1097 |
+
npu_wait_tensor(hidden_states_or_kv_c_normed,
|
| 1098 |
+
ckq,
|
| 1099 |
+
enabled=enable_multistream_mla)
|
| 1100 |
+
decode_k_pe, decode_k_nope, decode_kv = self.exec_kv(
|
| 1101 |
+
hidden_states_or_kv_c_normed, cos, sin, kv_cache,
|
| 1102 |
+
attn_metadata.slot_mapping)
|
| 1103 |
+
# Without explicitly controlling the order, IndexByTensor operations
|
| 1104 |
+
# would be placed after `matmul W_KV_T` hindering the overlapping of
|
| 1105 |
+
# KvRmsNormRopeCache and SingleRope.
|
| 1106 |
+
npu_wait_tensor(decode_hs_or_q_c,
|
| 1107 |
+
cos,
|
| 1108 |
+
enabled=enable_multistream_mla)
|
| 1109 |
+
npu_wait_tensor(decode_hs_or_q_c,
|
| 1110 |
+
sin,
|
| 1111 |
+
enabled=enable_multistream_mla)
|
| 1112 |
+
npu_wait_tensor(decode_hs_or_q_c,
|
| 1113 |
+
decode_kv,
|
| 1114 |
+
enabled=enable_multistream_mla)
|
| 1115 |
+
|
| 1116 |
+
decode_ql_nope, decode_q_pe = \
|
| 1117 |
+
self._q_proj_and_k_up_proj(decode_hs_or_q_c)
|
| 1118 |
+
if self.running_in_graph:
|
| 1119 |
+
with npu_stream_switch("mla_secondary",
|
| 1120 |
+
0,
|
| 1121 |
+
enabled=enable_multistream_mla):
|
| 1122 |
+
npu_wait_tensor(decode_q_pe,
|
| 1123 |
+
decode_k_pe,
|
| 1124 |
+
enabled=enable_multistream_mla)
|
| 1125 |
+
decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
|
| 1126 |
+
else:
|
| 1127 |
+
decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
|
| 1128 |
+
attn_metadata.decode.input_positions,
|
| 1129 |
+
decode_q_pe.contiguous(),
|
| 1130 |
+
decode_k_pe,
|
| 1131 |
+
max_seq_len=attn_metadata.decode.max_seq_lens)
|
| 1132 |
+
if has_prefill:
|
| 1133 |
+
assert attn_metadata.prefill is not None
|
| 1134 |
+
prefill_q = self.q_proj(prefill_hs_or_q_c)[0]\
|
| 1135 |
+
.view(-1, self.num_heads, self.qk_head_dim)
|
| 1136 |
+
prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:]
|
| 1137 |
+
prefill_q_nope = prefill_q[..., :self.qk_nope_head_dim]
|
| 1138 |
+
if self.torchair_graph_enabled:
|
| 1139 |
+
num_tokens = prefill_hs_or_q_c.shape[0]
|
| 1140 |
+
seq_len = self.rotary_emb.max_position_embeddings * \
|
| 1141 |
+
getattr(self.rotary_emb, "scaling_factor", 1)
|
| 1142 |
+
cos = self.rotary_emb.cos_cached[:seq_len].to(
|
| 1143 |
+
dtype=prefill_q_pe.dtype)
|
| 1144 |
+
sin = self.rotary_emb.sin_cached[:seq_len].to(
|
| 1145 |
+
dtype=prefill_q_pe.dtype)
|
| 1146 |
+
cos = cos[attn_metadata.prefill.input_positions]
|
| 1147 |
+
sin = sin[attn_metadata.prefill.input_positions]
|
| 1148 |
+
cos = cos[:, None, None, :]
|
| 1149 |
+
sin = sin[:, None, None, :]
|
| 1150 |
+
|
| 1151 |
+
prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin)
|
| 1152 |
+
prefill_k_pe, prefill_k_nope = self.exec_kv_prefill(
|
| 1153 |
+
hidden_states_or_kv_c_normed, cos, sin, kv_cache,
|
| 1154 |
+
attn_metadata.slot_mapping)
|
| 1155 |
+
|
| 1156 |
+
kv_c_normed = prefill_k_nope[:num_actual_toks, ...]
|
| 1157 |
+
prefill_k_c_normed = prefill_k_nope[num_decode_tokens:]
|
| 1158 |
+
prefill_k_pe = prefill_k_pe.view(num_tokens, self.num_kv_heads,
|
| 1159 |
+
-1)
|
| 1160 |
+
prefill_q = torch.cat([prefill_q_nope, prefill_q_pe], dim=-1)
|
| 1161 |
+
else:
|
| 1162 |
+
prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb(
|
| 1163 |
+
attn_metadata.prefill.input_positions,
|
| 1164 |
+
prefill_q_pe.contiguous(),
|
| 1165 |
+
prefill_k_pe,
|
| 1166 |
+
max_seq_len=attn_metadata.prefill.max_seq_lens)
|
| 1167 |
+
if self.torchair_graph_enabled:
|
| 1168 |
+
if len(kv_cache) > 0 and kv_cache[0].numel(
|
| 1169 |
+
) > 0 and attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
|
| 1170 |
+
slots = attn_metadata.slot_mapping
|
| 1171 |
+
# NOTE: Separate the kv cache in advance to avoid OOM or other issues
|
| 1172 |
+
torch_npu._npu_reshape_and_cache(key=kv_c_normed.view(
|
| 1173 |
+
num_tokens, self.num_kv_heads, -1),
|
| 1174 |
+
value=prefill_k_pe,
|
| 1175 |
+
key_cache=kv_cache[0],
|
| 1176 |
+
value_cache=kv_cache[1],
|
| 1177 |
+
slot_indices=slots)
|
| 1178 |
+
elif kv_cache.numel() > 0:
|
| 1179 |
+
key = torch.cat([
|
| 1180 |
+
kv_c_normed.view([num_actual_toks, self.num_kv_heads, -1]),
|
| 1181 |
+
k_pe
|
| 1182 |
+
],
|
| 1183 |
+
dim=2)
|
| 1184 |
+
torch_npu._npu_reshape_and_cache_siso(
|
| 1185 |
+
key=key,
|
| 1186 |
+
key_cache=kv_cache,
|
| 1187 |
+
slot_indices=attn_metadata.slot_mapping.flatten())
|
| 1188 |
+
if has_prefill:
|
| 1189 |
+
# FIX: aicore move should be also placed on the comm stream in dbo,
|
| 1190 |
+
# otherwise it may affect the accuracy
|
| 1191 |
+
# TODO: use an elegant way to overlap
|
| 1192 |
+
output_prefill = self._forward_prefill(prefill_q,
|
| 1193 |
+
prefill_k_c_normed,
|
| 1194 |
+
prefill_k_pe, kv_cache,
|
| 1195 |
+
attn_metadata)
|
| 1196 |
+
current_ms_metadata = get_multistream_comm_context()
|
| 1197 |
+
if current_ms_metadata is not None:
|
| 1198 |
+
with torch.npu.stream(current_ms_metadata.comm_stream):
|
| 1199 |
+
output[num_decode_tokens:] = output_prefill
|
| 1200 |
+
current_ms_metadata.after_comm_event.record()
|
| 1201 |
+
else:
|
| 1202 |
+
output[num_decode_tokens:] = output_prefill
|
| 1203 |
+
|
| 1204 |
+
if has_decode:
|
| 1205 |
+
if self.running_in_graph:
|
| 1206 |
+
return self._forward_decode(decode_ql_nope, decode_q_pe,
|
| 1207 |
+
decode_k_nope, decode_k_pe,
|
| 1208 |
+
kv_cache, attn_metadata,
|
| 1209 |
+
enable_multistream_mla)
|
| 1210 |
+
else:
|
| 1211 |
+
output_decode = self._forward_decode(decode_ql_nope,
|
| 1212 |
+
decode_q_pe,
|
| 1213 |
+
decode_k_nope,
|
| 1214 |
+
decode_k_pe, kv_cache,
|
| 1215 |
+
attn_metadata)
|
| 1216 |
+
current_ms_metadata = get_multistream_comm_context()
|
| 1217 |
+
if current_ms_metadata is not None:
|
| 1218 |
+
with torch.npu.stream(current_ms_metadata.comm_stream):
|
| 1219 |
+
output[:num_decode_tokens] = output_decode
|
| 1220 |
+
current_ms_metadata.after_comm_event.record()
|
| 1221 |
+
else:
|
| 1222 |
+
output[:num_decode_tokens] = output_decode
|
| 1223 |
+
|
| 1224 |
+
return output_padded
|
inference/vllm_ascend/entrypoints/openai/reasoning_parsers/__init__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
| 2 |
+
from .pangu_reasoning_parser import PanguReasoningParser
|
| 3 |
+
|
| 4 |
+
__all__ = [
|
| 5 |
+
"PanguReasoningParser"
|
| 6 |
+
]
|
inference/vllm_ascend/entrypoints/openai/reasoning_parsers/pangu_reasoning_parser.py
ADDED
|
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
| 3 |
+
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
| 4 |
+
|
| 5 |
+
from collections.abc import Sequence
|
| 6 |
+
from typing import Optional, Union
|
| 7 |
+
|
| 8 |
+
from transformers import PreTrainedTokenizerBase
|
| 9 |
+
|
| 10 |
+
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
| 11 |
+
DeltaMessage)
|
| 12 |
+
from vllm.logger import init_logger
|
| 13 |
+
from vllm.reasoning import ReasoningParser, ReasoningParserManager
|
| 14 |
+
|
| 15 |
+
logger = init_logger(__name__)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
@ReasoningParserManager.register_module("pangu")
|
| 19 |
+
class PanguReasoningParser(ReasoningParser):
|
| 20 |
+
"""
|
| 21 |
+
Reasoning parser for Pangu model.
|
| 22 |
+
|
| 23 |
+
The Pangu model uses [unused16]...[unused17] tokens to denote reasoning
|
| 24 |
+
text. This parser extracts the reasoning content from the model output.
|
| 25 |
+
"""
|
| 26 |
+
|
| 27 |
+
start_token_id: int
|
| 28 |
+
end_token_id: int
|
| 29 |
+
|
| 30 |
+
start_token: str = "[unused16]"
|
| 31 |
+
end_token: str = "[unused17]"
|
| 32 |
+
|
| 33 |
+
def __init__(self, tokenizer: PreTrainedTokenizerBase):
|
| 34 |
+
super().__init__(tokenizer)
|
| 35 |
+
|
| 36 |
+
if not self.model_tokenizer:
|
| 37 |
+
raise ValueError(
|
| 38 |
+
"The model tokenizer must be passed to the ReasoningParser "
|
| 39 |
+
"constructor during construction.")
|
| 40 |
+
|
| 41 |
+
self.start_token_id = self.vocab.get(self.start_token)
|
| 42 |
+
self.end_token_id = self.vocab.get(self.end_token)
|
| 43 |
+
if self.start_token_id is None or self.end_token_id is None:
|
| 44 |
+
raise RuntimeError(
|
| 45 |
+
"Pangu reasoning parser could not locate think start/end "
|
| 46 |
+
"tokens in the tokenizer!")
|
| 47 |
+
|
| 48 |
+
def is_reasoning_end(self, input_ids: list[int]) -> bool:
|
| 49 |
+
return self.end_token_id in input_ids
|
| 50 |
+
|
| 51 |
+
def extract_content_ids(self, input_ids: list[int]) -> list[int]:
|
| 52 |
+
"""
|
| 53 |
+
Extract the content after the end tokens
|
| 54 |
+
"""
|
| 55 |
+
if self.end_token_id not in input_ids[:-1]:
|
| 56 |
+
return []
|
| 57 |
+
else:
|
| 58 |
+
return input_ids[input_ids.index(self.end_token_id) + 1:]
|
| 59 |
+
|
| 60 |
+
def extract_reasoning_content_streaming(
|
| 61 |
+
self,
|
| 62 |
+
previous_text: str,
|
| 63 |
+
current_text: str,
|
| 64 |
+
delta_text: str,
|
| 65 |
+
previous_token_ids: Sequence[int],
|
| 66 |
+
current_token_ids: Sequence[int],
|
| 67 |
+
delta_token_ids: Sequence[int],
|
| 68 |
+
) -> Union[DeltaMessage, None]:
|
| 69 |
+
"""
|
| 70 |
+
Extract reasoning content from a delta message.
|
| 71 |
+
Handles streaming output where previous + delta = current.
|
| 72 |
+
Uses token IDs for faster processing.
|
| 73 |
+
For text [unused16]abc[unused17]xyz:
|
| 74 |
+
- 'abc' goes to reasoning_content
|
| 75 |
+
- 'xyz' goes to content
|
| 76 |
+
"""
|
| 77 |
+
# Skip single special tokens
|
| 78 |
+
if len(delta_token_ids) == 1 and (delta_token_ids[0] in [
|
| 79 |
+
self.start_token_id, self.end_token_id
|
| 80 |
+
]):
|
| 81 |
+
return None
|
| 82 |
+
|
| 83 |
+
# Check if [unused16] is present in previous or delta.
|
| 84 |
+
# Keep compatibility with models that don't generate [unused16] tokens.
|
| 85 |
+
if self.start_token_id in previous_token_ids:
|
| 86 |
+
if self.end_token_id in delta_token_ids:
|
| 87 |
+
# [unused16] in previous, [unused17] in delta,
|
| 88 |
+
# extract reasoning content
|
| 89 |
+
end_index = delta_text.find(self.end_token)
|
| 90 |
+
reasoning_content = delta_text[:end_index]
|
| 91 |
+
content = delta_text[end_index + len(self.end_token):]
|
| 92 |
+
return DeltaMessage(
|
| 93 |
+
reasoning_content=reasoning_content,
|
| 94 |
+
content=content if content else None,
|
| 95 |
+
)
|
| 96 |
+
elif self.end_token_id in previous_token_ids:
|
| 97 |
+
# [unused16] in previous, [unused17] in previous,
|
| 98 |
+
# reasoning content continues
|
| 99 |
+
return DeltaMessage(content=delta_text)
|
| 100 |
+
else:
|
| 101 |
+
# [unused16] in previous, no [unused17] in previous or delta,
|
| 102 |
+
# reasoning content continues
|
| 103 |
+
return DeltaMessage(reasoning_content=delta_text)
|
| 104 |
+
elif self.start_token_id in delta_token_ids:
|
| 105 |
+
if self.end_token_id in delta_token_ids:
|
| 106 |
+
# [unused16] in delta, [unused17] in delta, extract reasoning content
|
| 107 |
+
start_index = delta_text.find(self.start_token)
|
| 108 |
+
end_index = delta_text.find(self.end_token)
|
| 109 |
+
reasoning_content = delta_text[start_index +
|
| 110 |
+
len(self.start_token):end_index]
|
| 111 |
+
content = delta_text[end_index + len(self.end_token):]
|
| 112 |
+
return DeltaMessage(
|
| 113 |
+
reasoning_content=reasoning_content,
|
| 114 |
+
content=content if content else None,
|
| 115 |
+
)
|
| 116 |
+
else:
|
| 117 |
+
# [unused16] in delta, no [unused17] in delta,
|
| 118 |
+
# reasoning content continues
|
| 119 |
+
return DeltaMessage(reasoning_content=delta_text)
|
| 120 |
+
else:
|
| 121 |
+
# No [unused16] in previous or delta, also need to check for [unused17].
|
| 122 |
+
# Because the model may have generated [unused17] without [unused16]
|
| 123 |
+
if self.end_token_id in delta_token_ids:
|
| 124 |
+
# [unused17] in delta with more tokens,
|
| 125 |
+
# extract reasoning content and content
|
| 126 |
+
end_index = delta_text.find(self.end_token)
|
| 127 |
+
reasoning_content = delta_text[:end_index]
|
| 128 |
+
content = delta_text[end_index + len(self.end_token):]
|
| 129 |
+
return DeltaMessage(
|
| 130 |
+
reasoning_content=reasoning_content,
|
| 131 |
+
content=content if content else None,
|
| 132 |
+
)
|
| 133 |
+
elif self.end_token_id in previous_token_ids:
|
| 134 |
+
# [unused17] in previous, thinking content ends
|
| 135 |
+
return DeltaMessage(content=delta_text)
|
| 136 |
+
else:
|
| 137 |
+
# no [unused17] in previous or delta, reasoning content continues
|
| 138 |
+
return DeltaMessage(reasoning_content=delta_text)
|
| 139 |
+
|
| 140 |
+
def extract_reasoning_content(
|
| 141 |
+
self, model_output: str, request: ChatCompletionRequest
|
| 142 |
+
) -> tuple[Optional[str], Optional[str]]:
|
| 143 |
+
"""
|
| 144 |
+
Extract reasoning content from the model output.
|
| 145 |
+
|
| 146 |
+
For text [unused16]abc[unused17]xyz:
|
| 147 |
+
- 'abc' goes to reasoning_content
|
| 148 |
+
- 'xyz' goes to content
|
| 149 |
+
|
| 150 |
+
Returns:
|
| 151 |
+
tuple[Optional[str], Optional[str]]: reasoning content and content
|
| 152 |
+
"""
|
| 153 |
+
|
| 154 |
+
# Check if the start token is present in the model output, remove it
|
| 155 |
+
# if it is present.
|
| 156 |
+
model_output_parts = model_output.partition(self.start_token)
|
| 157 |
+
model_output = model_output_parts[2] if model_output_parts[
|
| 158 |
+
1] else model_output_parts[0]
|
| 159 |
+
|
| 160 |
+
# Thus we assume the reasoning content is always at the start.
|
| 161 |
+
if self.end_token not in model_output:
|
| 162 |
+
return model_output, None
|
| 163 |
+
else:
|
| 164 |
+
reasoning_content, _, content = model_output.partition(
|
| 165 |
+
self.end_token)
|
| 166 |
+
# If the end token is not found, return the model output as is.
|
| 167 |
+
# It should not happen since we already checked for the presence
|
| 168 |
+
# of the end token.
|
| 169 |
+
# If generation stops right after end-of-think, return null content
|
| 170 |
+
final_content = content or None
|
| 171 |
+
return reasoning_content, final_content
|
inference/vllm_ascend/entrypoints/openai/tool_parsers/__init__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
| 2 |
+
from .pangu_tool_parser import PanguToolParser
|
| 3 |
+
|
| 4 |
+
__all__ = [
|
| 5 |
+
"PanguToolParser"
|
| 6 |
+
]
|
inference/vllm_ascend/entrypoints/openai/tool_parsers/pangu_tool_parser.py
ADDED
|
@@ -0,0 +1,300 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
| 2 |
+
# Copyright 2023 The vLLM team.
|
| 3 |
+
|
| 4 |
+
import json
|
| 5 |
+
import re
|
| 6 |
+
from json import JSONDecodeError, JSONDecoder
|
| 7 |
+
from typing import Dict, List, Sequence, Union, Optional
|
| 8 |
+
from pydantic import Field
|
| 9 |
+
import partial_json_parser
|
| 10 |
+
from partial_json_parser.core.options import Allow
|
| 11 |
+
from transformers import PreTrainedTokenizerBase
|
| 12 |
+
|
| 13 |
+
from vllm.entrypoints.chat_utils import random_tool_call_id
|
| 14 |
+
from vllm.entrypoints.openai.tool_parsers.utils import (
|
| 15 |
+
extract_intermediate_diff)
|
| 16 |
+
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
| 17 |
+
DeltaFunctionCall, DeltaMessage,
|
| 18 |
+
DeltaToolCall,
|
| 19 |
+
ExtractedToolCallInformation,
|
| 20 |
+
FunctionCall, ToolCall,
|
| 21 |
+
)
|
| 22 |
+
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
| 23 |
+
ToolParser, ToolParserManager)
|
| 24 |
+
from vllm.entrypoints.openai.tool_parsers.utils import (find_common_prefix,
|
| 25 |
+
is_complete_json)
|
| 26 |
+
from vllm.logger import init_logger
|
| 27 |
+
import os
|
| 28 |
+
|
| 29 |
+
logger = init_logger(__name__)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
@ToolParserManager.register_module("pangu")
|
| 33 |
+
class PanguToolParser(ToolParser):
|
| 34 |
+
|
| 35 |
+
def __init__(self, tokenizer: PreTrainedTokenizerBase, enable_reasoning=False):
|
| 36 |
+
super().__init__(tokenizer)
|
| 37 |
+
|
| 38 |
+
# initialize properties used for state when parsing tool calls in
|
| 39 |
+
# streaming mode
|
| 40 |
+
self.prev_tool_call_arr: List[Dict] = []
|
| 41 |
+
self.current_tool_id: int = -1
|
| 42 |
+
self.current_tool_name_sent: bool = False
|
| 43 |
+
self.streamed_args_for_tool: List[str] = [
|
| 44 |
+
] # map what has been streamed for each tool so far to a list
|
| 45 |
+
|
| 46 |
+
self.tool_call_start_token = "[unused11]"
|
| 47 |
+
self.tool_call_end_token = "[unused12]"
|
| 48 |
+
self.pattern = re.escape(self.tool_call_start_token) \
|
| 49 |
+
+ "(.*?)" + re.escape(self.tool_call_end_token)
|
| 50 |
+
self.tool_call_regex = re.compile(self.pattern, re.DOTALL)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
self.tool_call_start_token_id = self.vocab.get(
|
| 54 |
+
self.tool_call_start_token)
|
| 55 |
+
self.tool_call_end_token_id = self.vocab.get(
|
| 56 |
+
self.tool_call_end_token)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
if (self.tool_call_start_token_id is None
|
| 60 |
+
or self.tool_call_end_token_id is None):
|
| 61 |
+
raise RuntimeError(
|
| 62 |
+
"Pangu Tool parser could not locate tool calls start/end "
|
| 63 |
+
"tokens in the tokenizer!")
|
| 64 |
+
self.is_complete = []
|
| 65 |
+
self.text_after_start_token = ""
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def extract_tool_calls(
|
| 69 |
+
self, model_output: str,
|
| 70 |
+
request: ChatCompletionRequest
|
| 71 |
+
) -> ExtractedToolCallInformation:
|
| 72 |
+
"""
|
| 73 |
+
Extract the tool calls from a complete model response.
|
| 74 |
+
"""
|
| 75 |
+
# case -- if a tool call token is not present, return a text response
|
| 76 |
+
if not (self.tool_call_start_token in model_output and \
|
| 77 |
+
model_output.find(self.tool_call_end_token) != -1):
|
| 78 |
+
return ExtractedToolCallInformation(tools_called=False,
|
| 79 |
+
tool_calls=[],
|
| 80 |
+
content=model_output)
|
| 81 |
+
|
| 82 |
+
try:
|
| 83 |
+
raw_function_calls = []
|
| 84 |
+
# use a regex to find the tool call between the tags
|
| 85 |
+
function_call_tuples = self.tool_call_regex.findall(model_output)
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
# load the JSON, and then use it to build the Function and
|
| 89 |
+
# Tool Call
|
| 90 |
+
for function_call_str in function_call_tuples:
|
| 91 |
+
function_call = json.loads(function_call_str)
|
| 92 |
+
raw_function_calls.extend(function_call)
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
tool_calls: List[ToolCall] = [
|
| 96 |
+
ToolCall(
|
| 97 |
+
type="function",
|
| 98 |
+
function=FunctionCall(
|
| 99 |
+
name=function_call["name"],
|
| 100 |
+
# function call args are JSON but as a string
|
| 101 |
+
arguments=json.dumps(function_call["arguments"] \
|
| 102 |
+
if "arguments" in function_call \
|
| 103 |
+
else function_call["parameters"], ensure_ascii=False)))
|
| 104 |
+
for function_call in raw_function_calls
|
| 105 |
+
]
|
| 106 |
+
content = model_output[:model_output.
|
| 107 |
+
find(self.tool_call_start_token)]
|
| 108 |
+
|
| 109 |
+
# get any content before the tool call
|
| 110 |
+
ret = ExtractedToolCallInformation(tools_called=True,
|
| 111 |
+
tool_calls=tool_calls,
|
| 112 |
+
content=content if content else None)
|
| 113 |
+
|
| 114 |
+
return ret
|
| 115 |
+
|
| 116 |
+
except Exception:
|
| 117 |
+
logger.exception("Error in extracting tool call from response.")
|
| 118 |
+
# return information to just treat the tool call as regular JSON
|
| 119 |
+
return ExtractedToolCallInformation(tools_called=False,
|
| 120 |
+
tool_calls=[],
|
| 121 |
+
content=model_output)
|
| 122 |
+
|
| 123 |
+
def extract_tool_calls_streaming(
|
| 124 |
+
self,
|
| 125 |
+
previous_text: str,
|
| 126 |
+
current_text: str,
|
| 127 |
+
delta_text: str,
|
| 128 |
+
previous_token_ids: Sequence[int],
|
| 129 |
+
current_token_ids: Sequence[int],
|
| 130 |
+
delta_token_ids: Sequence[int],
|
| 131 |
+
request: ChatCompletionRequest,
|
| 132 |
+
) -> Union[DeltaMessage, None]:
|
| 133 |
+
|
| 134 |
+
if (self.tool_call_end_token_id in delta_token_ids
|
| 135 |
+
and len(delta_token_ids) == 1):
|
| 136 |
+
# if it's the only token, return None, so we don't send a chat
|
| 137 |
+
# completion and don't send a control token
|
| 138 |
+
return None
|
| 139 |
+
|
| 140 |
+
if (self.tool_call_end_token in current_text
|
| 141 |
+
and self.tool_call_end_token not in delta_text):
|
| 142 |
+
return DeltaMessage(content=delta_text)
|
| 143 |
+
|
| 144 |
+
if self.tool_call_start_token not in current_text:
|
| 145 |
+
return DeltaMessage(content=delta_text)
|
| 146 |
+
|
| 147 |
+
if self.tool_call_start_token in delta_text:
|
| 148 |
+
texts = delta_text.split(self.tool_call_start_token)
|
| 149 |
+
text_before_start_token = texts[0]
|
| 150 |
+
if text_before_start_token:
|
| 151 |
+
return DeltaMessage(content=text_before_start_token)
|
| 152 |
+
|
| 153 |
+
if (self.tool_call_start_token_id in delta_token_ids
|
| 154 |
+
and len(delta_token_ids) == 1):
|
| 155 |
+
# if it's the only token, return None, so we don't send a chat
|
| 156 |
+
# completion and don't send a control token
|
| 157 |
+
return None
|
| 158 |
+
|
| 159 |
+
# bit mask flags for partial JSON parsing. If the name hasn't been
|
| 160 |
+
# sent yet, don't allow sending
|
| 161 |
+
# an incomplete string since OpenAI only ever (as far as I have
|
| 162 |
+
# seen) allows sending the entire tool/ function name at once.
|
| 163 |
+
flags = Allow.ALL if self.current_tool_name_sent \
|
| 164 |
+
else Allow.ALL & ~Allow.STR
|
| 165 |
+
try:
|
| 166 |
+
|
| 167 |
+
tool_call_portion = current_text.split(
|
| 168 |
+
self.tool_call_start_token)[-1].split(self.tool_call_end_token)[0]
|
| 169 |
+
try:
|
| 170 |
+
tool_call_arr: list[dict] = partial_json_parser.loads(
|
| 171 |
+
tool_call_portion, flags)
|
| 172 |
+
|
| 173 |
+
self.is_complete.append(
|
| 174 |
+
is_complete_json(tool_call_portion))
|
| 175 |
+
except partial_json_parser.core.exceptions.MalformedJSON:
|
| 176 |
+
logger.debug('not enough tokens to parse into JSON yet')
|
| 177 |
+
return None
|
| 178 |
+
|
| 179 |
+
# select as the current tool call the one we're on the state at
|
| 180 |
+
current_tool_call: dict = tool_call_arr[self.current_tool_id] \
|
| 181 |
+
if len(tool_call_arr) > 0 else {}
|
| 182 |
+
|
| 183 |
+
# case -- if no tokens have been streamed for the tool, e.g.
|
| 184 |
+
# only the array brackets, stream nothing
|
| 185 |
+
if len(tool_call_arr) == 0:
|
| 186 |
+
return None
|
| 187 |
+
|
| 188 |
+
# case: we are starting a new tool in the array
|
| 189 |
+
# -> array has > 0 length AND length has moved past cursor
|
| 190 |
+
elif (len(tool_call_arr) > 0
|
| 191 |
+
and len(tool_call_arr) > self.current_tool_id + 1):
|
| 192 |
+
|
| 193 |
+
# if we're moving on to a new call, first make sure we
|
| 194 |
+
# haven't missed anything in the previous one that was
|
| 195 |
+
# auto-generated due to JSON completions, but wasn't
|
| 196 |
+
# streamed to the client yet.
|
| 197 |
+
if self.current_tool_id >= 0:
|
| 198 |
+
cur_arguments = current_tool_call.get("arguments")
|
| 199 |
+
if cur_arguments:
|
| 200 |
+
cur_args_json = json.dumps(cur_arguments,
|
| 201 |
+
ensure_ascii=False)
|
| 202 |
+
sent = len(
|
| 203 |
+
self.streamed_args_for_tool[self.current_tool_id])
|
| 204 |
+
argument_diff = cur_args_json[sent:]
|
| 205 |
+
|
| 206 |
+
logger.debug("got arguments diff: %s", argument_diff)
|
| 207 |
+
delta = DeltaMessage(tool_calls=[
|
| 208 |
+
DeltaToolCall(index=self.current_tool_id,
|
| 209 |
+
function=DeltaFunctionCall(
|
| 210 |
+
arguments=argument_diff).
|
| 211 |
+
model_dump(exclude_none=True))
|
| 212 |
+
])
|
| 213 |
+
self.streamed_args_for_tool[
|
| 214 |
+
self.current_tool_id] += argument_diff
|
| 215 |
+
else:
|
| 216 |
+
delta = None
|
| 217 |
+
else:
|
| 218 |
+
delta = None
|
| 219 |
+
# re-set stuff pertaining to progress in the current tool
|
| 220 |
+
self.current_tool_id = len(tool_call_arr) - 1
|
| 221 |
+
self.current_tool_name_sent = False
|
| 222 |
+
self.streamed_args_for_tool.append("")
|
| 223 |
+
self.is_complete = []
|
| 224 |
+
logger.debug("starting on new tool %d", self.current_tool_id)
|
| 225 |
+
return delta
|
| 226 |
+
|
| 227 |
+
# if the current tool name hasn't been sent, send if available
|
| 228 |
+
# - otherwise send nothing
|
| 229 |
+
elif not self.current_tool_name_sent:
|
| 230 |
+
function_name = current_tool_call.get("name")
|
| 231 |
+
if function_name:
|
| 232 |
+
delta = DeltaMessage(tool_calls=[
|
| 233 |
+
DeltaToolCall(index=self.current_tool_id,
|
| 234 |
+
type="function",
|
| 235 |
+
id=random_tool_call_id(),
|
| 236 |
+
function=DeltaFunctionCall(
|
| 237 |
+
name=function_name).model_dump(
|
| 238 |
+
exclude_none=True))
|
| 239 |
+
])
|
| 240 |
+
self.current_tool_name_sent = True
|
| 241 |
+
else:
|
| 242 |
+
delta = None
|
| 243 |
+
|
| 244 |
+
# now we know we're on the same tool call and we're streaming
|
| 245 |
+
# arguments
|
| 246 |
+
else:
|
| 247 |
+
cur_arguments = current_tool_call.get("arguments")
|
| 248 |
+
delta = None
|
| 249 |
+
if (self.is_complete[-1] and not cur_arguments
|
| 250 |
+
and not self.streamed_args_for_tool[-1]):
|
| 251 |
+
argument_diff = "{}"
|
| 252 |
+
delta = DeltaMessage(tool_calls=[
|
| 253 |
+
DeltaToolCall(index=self.current_tool_id,
|
| 254 |
+
function=DeltaFunctionCall(
|
| 255 |
+
arguments=argument_diff).
|
| 256 |
+
model_dump(exclude_none=True))
|
| 257 |
+
])
|
| 258 |
+
self.streamed_args_for_tool[
|
| 259 |
+
self.current_tool_id] += argument_diff
|
| 260 |
+
|
| 261 |
+
if cur_arguments:
|
| 262 |
+
sent = len(
|
| 263 |
+
self.streamed_args_for_tool[self.current_tool_id])
|
| 264 |
+
cur_args_json = json.dumps(cur_arguments,
|
| 265 |
+
ensure_ascii=False)
|
| 266 |
+
prev_arguments = self.prev_tool_call_arr[
|
| 267 |
+
self.current_tool_id].get("arguments")
|
| 268 |
+
|
| 269 |
+
argument_diff = None
|
| 270 |
+
if self.is_complete[-1]:
|
| 271 |
+
argument_diff = cur_args_json[sent:]
|
| 272 |
+
elif prev_arguments:
|
| 273 |
+
prev_args_json = json.dumps(prev_arguments,
|
| 274 |
+
ensure_ascii=False)
|
| 275 |
+
if cur_args_json != prev_args_json:
|
| 276 |
+
|
| 277 |
+
prefix = find_common_prefix(
|
| 278 |
+
prev_args_json, cur_args_json)
|
| 279 |
+
argument_diff = prefix[sent:]
|
| 280 |
+
|
| 281 |
+
if argument_diff is not None:
|
| 282 |
+
delta = DeltaMessage(tool_calls=[
|
| 283 |
+
DeltaToolCall(index=self.current_tool_id,
|
| 284 |
+
function=DeltaFunctionCall(
|
| 285 |
+
arguments=argument_diff).
|
| 286 |
+
model_dump(exclude_none=True))
|
| 287 |
+
])
|
| 288 |
+
self.streamed_args_for_tool[
|
| 289 |
+
self.current_tool_id] += argument_diff
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
self.prev_tool_call_arr = tool_call_arr
|
| 293 |
+
return delta
|
| 294 |
+
|
| 295 |
+
except Exception:
|
| 296 |
+
logger.exception("Error trying to handle streaming tool call.")
|
| 297 |
+
logger.debug(
|
| 298 |
+
"Skipping chunk as a result of tool streaming extraction "
|
| 299 |
+
"error")
|
| 300 |
+
return None
|
inference/vllm_ascend/envs.py
ADDED
|
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#
|
| 2 |
+
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
| 3 |
+
# This file is a part of the vllm-ascend project.
|
| 4 |
+
#
|
| 5 |
+
# This file is mainly Adapted from vllm-project/vllm/vllm/envs.py
|
| 6 |
+
# Copyright 2023 The vLLM team.
|
| 7 |
+
#
|
| 8 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 9 |
+
# you may not use this file except in compliance with the License.
|
| 10 |
+
# You may obtain a copy of the License at
|
| 11 |
+
#
|
| 12 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 13 |
+
#
|
| 14 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 15 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 16 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 17 |
+
# See the License for the specific language governing permissions and
|
| 18 |
+
# limitations under the License.
|
| 19 |
+
#
|
| 20 |
+
|
| 21 |
+
import os
|
| 22 |
+
from typing import Any, Callable, Dict
|
| 23 |
+
|
| 24 |
+
# The begin-* and end* here are used by the documentation generator
|
| 25 |
+
# to extract the used env vars.
|
| 26 |
+
|
| 27 |
+
# begin-env-vars-definition
|
| 28 |
+
|
| 29 |
+
env_variables: Dict[str, Callable[[], Any]] = {
|
| 30 |
+
# max compile thread number for package building. Usually, it is set to
|
| 31 |
+
# the number of CPU cores. If not set, the default value is None, which
|
| 32 |
+
# means all number of CPU cores will be used.
|
| 33 |
+
"MAX_JOBS":
|
| 34 |
+
lambda: os.getenv("MAX_JOBS", None),
|
| 35 |
+
# The build type of the package. It can be one of the following values:
|
| 36 |
+
# Release, Debug, RelWithDebugInfo. If not set, the default value is Release.
|
| 37 |
+
"CMAKE_BUILD_TYPE":
|
| 38 |
+
lambda: os.getenv("CMAKE_BUILD_TYPE"),
|
| 39 |
+
# Whether to compile custom kernels. If not set, the default value is True.
|
| 40 |
+
# If set to False, the custom kernels will not be compiled. Please note that
|
| 41 |
+
# the sleep mode feature will be disabled as well if custom kernels are not
|
| 42 |
+
# compiled.
|
| 43 |
+
"COMPILE_CUSTOM_KERNELS":
|
| 44 |
+
lambda: bool(int(os.getenv("COMPILE_CUSTOM_KERNELS", "1"))),
|
| 45 |
+
# The CXX compiler used for compiling the package. If not set, the default
|
| 46 |
+
# value is None, which means the system default CXX compiler will be used.
|
| 47 |
+
"CXX_COMPILER":
|
| 48 |
+
lambda: os.getenv("CXX_COMPILER", None),
|
| 49 |
+
# The C compiler used for compiling the package. If not set, the default
|
| 50 |
+
# value is None, which means the system default C compiler will be used.
|
| 51 |
+
"C_COMPILER":
|
| 52 |
+
lambda: os.getenv("C_COMPILER", None),
|
| 53 |
+
# The version of the Ascend chip. If not set, the default value is
|
| 54 |
+
# ASCEND910B1. It's used for package building. Please make sure that the
|
| 55 |
+
# version is correct.
|
| 56 |
+
"SOC_VERSION":
|
| 57 |
+
lambda: os.getenv("SOC_VERSION", "ASCEND910B1"),
|
| 58 |
+
# If set, vllm-ascend will print verbose logs during compilation
|
| 59 |
+
"VERBOSE":
|
| 60 |
+
lambda: bool(int(os.getenv('VERBOSE', '0'))),
|
| 61 |
+
# The home path for CANN toolkit. If not set, the default value is
|
| 62 |
+
# /usr/local/Ascend/ascend-toolkit/latest
|
| 63 |
+
"ASCEND_HOME_PATH":
|
| 64 |
+
lambda: os.getenv("ASCEND_HOME_PATH", None),
|
| 65 |
+
# The path for HCCN Tool, the tool will be called by disaggregated prefilling
|
| 66 |
+
# case.
|
| 67 |
+
"HCCN_PATH":
|
| 68 |
+
lambda: os.getenv("HCCN_PATH", "/usr/local/Ascend/driver/tools/hccn_tool"),
|
| 69 |
+
# The path for HCCL library, it's used by pyhccl communicator backend. If
|
| 70 |
+
# not set, the default value is libhccl.so。
|
| 71 |
+
"HCCL_SO_PATH":
|
| 72 |
+
# The prefill device id for disaggregated prefilling case.
|
| 73 |
+
lambda: os.environ.get("HCCL_SO_PATH", None),
|
| 74 |
+
"PROMPT_DEVICE_ID":
|
| 75 |
+
lambda: os.getenv("PROMPT_DEVICE_ID", None),
|
| 76 |
+
# The decode device id for disaggregated prefilling case.
|
| 77 |
+
"DECODE_DEVICE_ID":
|
| 78 |
+
lambda: os.getenv("DECODE_DEVICE_ID", None),
|
| 79 |
+
# The port number for llmdatadist communication. If not set, the default
|
| 80 |
+
# value is 26000.
|
| 81 |
+
"LLMDATADIST_COMM_PORT":
|
| 82 |
+
lambda: os.getenv("LLMDATADIST_COMM_PORT", "26000"),
|
| 83 |
+
# The wait time for llmdatadist sync cache. If not set, the default value is
|
| 84 |
+
# 5000ms.
|
| 85 |
+
"LLMDATADIST_SYNC_CACHE_WAIT_TIME":
|
| 86 |
+
lambda: os.getenv("LLMDATADIST_SYNC_CACHE_WAIT_TIME", "5000"),
|
| 87 |
+
# The version of vllm is installed. This value is used for developers who
|
| 88 |
+
# installed vllm from source locally. In this case, the version of vllm is
|
| 89 |
+
# usually changed. For example, if the version of vllm is "0.9.0", but when
|
| 90 |
+
# it's installed from source, the version of vllm is usually set to "0.9.1".
|
| 91 |
+
# In this case, developers need to set this value to "0.9.0" to make sure
|
| 92 |
+
# that the correct package is installed.
|
| 93 |
+
"VLLM_VERSION":
|
| 94 |
+
lambda: os.getenv("VLLM_VERSION", None),
|
| 95 |
+
# Whether to enable the trace recompiles from pytorch.
|
| 96 |
+
"VLLM_ASCEND_TRACE_RECOMPILES":
|
| 97 |
+
lambda: bool(int(os.getenv("VLLM_ASCEND_TRACE_RECOMPILES", '0'))),
|
| 98 |
+
# Whether to enable fused_experts_allgather_ep. MoeInitRoutingV3 and
|
| 99 |
+
# GroupedMatmulFinalizeRouting operators are combined to implement EP.
|
| 100 |
+
"VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP":
|
| 101 |
+
lambda: bool(int(os.getenv("VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP", '0'))
|
| 102 |
+
),
|
| 103 |
+
"VLLM_ASCEND_ENABLE_DBO":
|
| 104 |
+
lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_DBO", '0'))),
|
| 105 |
+
# Whether to enable the model execute time observe profile. Disable it when
|
| 106 |
+
# running vllm ascend in production environment.
|
| 107 |
+
"VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE":
|
| 108 |
+
lambda: bool(int(os.getenv("VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE", '0'))
|
| 109 |
+
),
|
| 110 |
+
# MOE_ALL2ALL_BUFFER:
|
| 111 |
+
# 0: default, normal init.
|
| 112 |
+
# 1: enable moe_all2all_buffer.
|
| 113 |
+
"MOE_ALL2ALL_BUFFER":
|
| 114 |
+
lambda: bool(int(os.getenv("MOE_ALL2ALL_BUFFER", '0'))),
|
| 115 |
+
# Some models are optimized by vllm ascend. While in some case, e.g. rlhf
|
| 116 |
+
# training, the optimized model may not be suitable. In this case, set this
|
| 117 |
+
# value to False to disable the optimized model.
|
| 118 |
+
"USE_OPTIMIZED_MODEL":
|
| 119 |
+
lambda: bool(int(os.getenv('USE_OPTIMIZED_MODEL', '1'))),
|
| 120 |
+
# SELECT_GATING_TOPK_SOTFMAX_EXPERTS is the equivalent of select_experts in non-quantized scenarios.
|
| 121 |
+
# In theory, it should have better performance than select_experts.
|
| 122 |
+
# Subsequent versions will remove the SELECT_GATING_TOPK_SOTFMAX_EXPERTS tag and use it as the default mode.
|
| 123 |
+
"SELECT_GATING_TOPK_SOTFMAX_EXPERTS":
|
| 124 |
+
lambda: bool(int(os.getenv("SELECT_GATING_TOPK_SOTFMAX_EXPERTS", '0'))),
|
| 125 |
+
# The tolerance of the kv cache size, if the difference between the
|
| 126 |
+
# actual kv cache size and the cached kv cache size is less than this value,
|
| 127 |
+
# then the cached kv cache size will be used.
|
| 128 |
+
"VLLM_ASCEND_KV_CACHE_MEGABYTES_FLOATING_TOLERANCE":
|
| 129 |
+
lambda: int(
|
| 130 |
+
os.getenv("VLLM_ASCEND_KV_CACHE_MEGABYTES_FLOATING_TOLERANCE", 64)),
|
| 131 |
+
# Whether to enable the topk optimization. It's disabled by default for experimental support
|
| 132 |
+
# We'll make it enabled by default in the future.
|
| 133 |
+
"VLLM_ASCEND_ENABLE_TOPK_TOPP_OPTIMIZATION":
|
| 134 |
+
lambda: bool(
|
| 135 |
+
int(os.getenv("VLLM_ASCEND_ENABLE_TOPK_TOPP_OPTIMIZATION", '0'))),
|
| 136 |
+
# Whether to enable top n sigma sampling
|
| 137 |
+
"VLLM_ASCEND_ENABLE_TOP_N_SIGMA":
|
| 138 |
+
lambda: bool(
|
| 139 |
+
int(os.getenv("VLLM_ASCEND_ENABLE_TOP_N_SIGMA", '0'))),
|
| 140 |
+
}
|
| 141 |
+
|
| 142 |
+
# end-env-vars-definition
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def __getattr__(name: str):
|
| 146 |
+
# lazy evaluation of environment variables
|
| 147 |
+
if name in env_variables:
|
| 148 |
+
return env_variables[name]()
|
| 149 |
+
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def __dir__():
|
| 153 |
+
return list(env_variables.keys())
|
inference/vllm_ascend/models/__init__.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from vllm import ModelRegistry
|
| 2 |
+
|
| 3 |
+
import vllm_ascend.envs as envs
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def register_model():
|
| 7 |
+
from .deepseek_dbo import CustomDeepseekDBOForCausalLM # noqa: F401
|
| 8 |
+
from .deepseek_mtp import CustomDeepSeekMTP # noqa: F401
|
| 9 |
+
from .deepseek_v2 import CustomDeepseekV2ForCausalLM # noqa: F401
|
| 10 |
+
from .deepseek_v2 import CustomDeepseekV3ForCausalLM # noqa: F401
|
| 11 |
+
from .open_pangu import PanguUltraMoEForCausalLM # noqa: F401
|
| 12 |
+
from .open_pangu import PanguEmbeddedForCausalLM # noqa: F401
|
| 13 |
+
from .qwen2_5_vl import \
|
| 14 |
+
AscendQwen2_5_VLForConditionalGeneration # noqa: F401
|
| 15 |
+
from .qwen2_vl import AscendQwen2VLForConditionalGeneration # noqa: F401
|
| 16 |
+
|
| 17 |
+
ModelRegistry.register_model(
|
| 18 |
+
"DeepSeekMTPModel",
|
| 19 |
+
"vllm_ascend.models.deepseek_mtp:CustomDeepSeekMTP")
|
| 20 |
+
|
| 21 |
+
ModelRegistry.register_model(
|
| 22 |
+
"Qwen2VLForConditionalGeneration",
|
| 23 |
+
"vllm_ascend.models.qwen2_vl:AscendQwen2VLForConditionalGeneration")
|
| 24 |
+
|
| 25 |
+
if envs.USE_OPTIMIZED_MODEL:
|
| 26 |
+
ModelRegistry.register_model(
|
| 27 |
+
"Qwen2_5_VLForConditionalGeneration",
|
| 28 |
+
"vllm_ascend.models.qwen2_5_vl:AscendQwen2_5_VLForConditionalGeneration"
|
| 29 |
+
)
|
| 30 |
+
else:
|
| 31 |
+
ModelRegistry.register_model(
|
| 32 |
+
"Qwen2_5_VLForConditionalGeneration",
|
| 33 |
+
"vllm_ascend.models.qwen2_5_vl_without_padding:AscendQwen2_5_VLForConditionalGeneration_Without_Padding"
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
if envs.VLLM_ASCEND_ENABLE_DBO:
|
| 37 |
+
ModelRegistry.register_model(
|
| 38 |
+
"DeepseekV2ForCausalLM",
|
| 39 |
+
"vllm_ascend.models.deepseek_dbo:CustomDeepseekDBOForCausalLM")
|
| 40 |
+
|
| 41 |
+
ModelRegistry.register_model(
|
| 42 |
+
"DeepseekV3ForCausalLM",
|
| 43 |
+
"vllm_ascend.models.deepseek_dbo:CustomDeepseekDBOForCausalLM")
|
| 44 |
+
|
| 45 |
+
else:
|
| 46 |
+
ModelRegistry.register_model(
|
| 47 |
+
"DeepseekV2ForCausalLM",
|
| 48 |
+
"vllm_ascend.models.deepseek_v2:CustomDeepseekV2ForCausalLM")
|
| 49 |
+
|
| 50 |
+
ModelRegistry.register_model(
|
| 51 |
+
"DeepseekV3ForCausalLM",
|
| 52 |
+
"vllm_ascend.models.deepseek_v2:CustomDeepseekV3ForCausalLM")
|
| 53 |
+
|
| 54 |
+
ModelRegistry.register_model(
|
| 55 |
+
"Qwen3MoeForCausalLM",
|
| 56 |
+
"vllm_ascend.models.qwen3_moe:CustomQwen3MoeForCausalLM")
|
| 57 |
+
|
| 58 |
+
ModelRegistry.register_model(
|
| 59 |
+
"PanguProMoEForCausalLM",
|
| 60 |
+
"vllm_ascend.models.pangu_moe:PanguProMoEForCausalLM")
|
| 61 |
+
|
| 62 |
+
ModelRegistry.register_model(
|
| 63 |
+
"PanguUltraMoEForCausalLM",
|
| 64 |
+
"vllm_ascend.models.open_pangu:PanguUltraMoEForCausalLM")
|
| 65 |
+
|
| 66 |
+
ModelRegistry.register_model(
|
| 67 |
+
"PanguEmbeddedForCausalLM",
|
| 68 |
+
"vllm_ascend.models.open_pangu:PanguEmbeddedForCausalLM")
|
inference/vllm_ascend/models/open_pangu.py
ADDED
|
@@ -0,0 +1,1127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
| 3 |
+
# Copyright 2023 The vLLM team.
|
| 4 |
+
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
| 5 |
+
#
|
| 6 |
+
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
| 7 |
+
# and OPT implementations in this library. It has been modified from its
|
| 8 |
+
# original forms to accommodate minor architectural differences compared
|
| 9 |
+
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
| 10 |
+
#
|
| 11 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 12 |
+
# you may not use this file except in compliance with the License.
|
| 13 |
+
# You may obtain a copy of the License at
|
| 14 |
+
#
|
| 15 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 16 |
+
#
|
| 17 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 18 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 19 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 20 |
+
# See the License for the specific language governing permissions and
|
| 21 |
+
# limitations under the License.
|
| 22 |
+
|
| 23 |
+
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
|
| 24 |
+
import torch
|
| 25 |
+
import torch_npu
|
| 26 |
+
import vllm.envs as envs
|
| 27 |
+
from torch import nn
|
| 28 |
+
from transformers import PretrainedConfig
|
| 29 |
+
from vllm.compilation.decorators import support_torch_compile
|
| 30 |
+
from vllm.attention import Attention, AttentionMetadata, AttentionType
|
| 31 |
+
from vllm.config import CacheConfig, ModelConfig, VllmConfig
|
| 32 |
+
from vllm.distributed import (get_tensor_model_parallel_rank,
|
| 33 |
+
get_tensor_model_parallel_world_size,
|
| 34 |
+
get_tp_group, split_tensor_along_last_dim,
|
| 35 |
+
tensor_model_parallel_all_gather,
|
| 36 |
+
tensor_model_parallel_all_reduce,
|
| 37 |
+
tensor_model_parallel_reduce_scatter)
|
| 38 |
+
from vllm.distributed.parallel_state import get_dp_group
|
| 39 |
+
from vllm.forward_context import get_forward_context
|
| 40 |
+
from vllm.model_executor.layers.activation import SiluAndMul
|
| 41 |
+
from vllm.model_executor.layers.layernorm import RMSNorm
|
| 42 |
+
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
| 43 |
+
MergedColumnParallelLinear,
|
| 44 |
+
ReplicatedLinear,
|
| 45 |
+
RowParallelLinear,
|
| 46 |
+
UnquantizedLinearMethod,
|
| 47 |
+
QKVParallelLinear)
|
| 48 |
+
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
| 49 |
+
from vllm.model_executor.layers.quantization import QuantizationConfig
|
| 50 |
+
from vllm.model_executor.layers.rotary_embedding import get_rope, _rotate_gptj
|
| 51 |
+
from vllm.model_executor.layers.sampler import get_sampler
|
| 52 |
+
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
| 53 |
+
ParallelLMHead, VocabParallelEmbedding)
|
| 54 |
+
from vllm.model_executor.model_loader.weight_utils import (
|
| 55 |
+
default_weight_loader, maybe_remap_kv_scale_name)
|
| 56 |
+
from vllm.model_executor.models.utils import (
|
| 57 |
+
make_layers, maybe_prefix, extract_layer_index)
|
| 58 |
+
from vllm_ascend.ascend_config import get_ascend_config
|
| 59 |
+
from vllm_ascend.distributed.parallel_state import get_ep_group
|
| 60 |
+
from vllm_ascend.ops.fused_moe import AscendFusedMoE
|
| 61 |
+
from vllm_ascend.quantization.quant_config import AscendLinearMethod
|
| 62 |
+
from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod
|
| 63 |
+
from vllm_ascend.utils import dispose_tensor, npu_prefetch, get_fused_moe_state, FusedMoEState
|
| 64 |
+
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class OpenPanguMergedReplicatedLinear(ReplicatedLinear):
|
| 68 |
+
|
| 69 |
+
def __init__(
|
| 70 |
+
self,
|
| 71 |
+
input_size: int,
|
| 72 |
+
output_sizes: list[int],
|
| 73 |
+
bias: bool = True,
|
| 74 |
+
quant_config: Optional[QuantizationConfig] = None,
|
| 75 |
+
prefix: str = "",
|
| 76 |
+
):
|
| 77 |
+
self.output_sizes = output_sizes
|
| 78 |
+
super().__init__(input_size,
|
| 79 |
+
sum(output_sizes),
|
| 80 |
+
bias=bias,
|
| 81 |
+
quant_config=quant_config,
|
| 82 |
+
prefix=prefix)
|
| 83 |
+
|
| 84 |
+
def weight_loader(self, param: torch.nn.Parameter,
|
| 85 |
+
loaded_weight: torch.Tensor, loaded_shard_id: int):
|
| 86 |
+
# With no support for GGUF format yet.
|
| 87 |
+
if getattr(param, "is_gguf_weight", False) or getattr(param, "is_gguf_weight_type", False):
|
| 88 |
+
raise ValueError('With no support for GGUF format yet.')
|
| 89 |
+
if loaded_shard_id >= len(self.output_sizes):
|
| 90 |
+
raise ValueError(f'loaded_shard_id {loaded_shard_id} >= len(self.output_sizes) {len(self.output_sizes)}.')
|
| 91 |
+
shard_offset = sum(self.output_sizes[:loaded_shard_id])
|
| 92 |
+
shard_size = self.output_sizes[loaded_shard_id]
|
| 93 |
+
shard = param.data.narrow(param.output_dim, shard_offset, shard_size)
|
| 94 |
+
if shard.size() != loaded_weight.size():
|
| 95 |
+
raise ValueError(f"Tried to load weights of size {loaded_weight.size()} "
|
| 96 |
+
f"to a parameter shard of id {loaded_shard_id} size {shard.size()}.")
|
| 97 |
+
shard.copy_(loaded_weight)
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
class OpenPanguRowParallelLinearReplaceAllreduce(RowParallelLinear):
|
| 101 |
+
|
| 102 |
+
def forward(
|
| 103 |
+
self,
|
| 104 |
+
input_,
|
| 105 |
+
is_prefill=True
|
| 106 |
+
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[nn.Parameter]]]:
|
| 107 |
+
if self.input_is_parallel:
|
| 108 |
+
input_parallel = input_
|
| 109 |
+
else:
|
| 110 |
+
tp_rank = get_tensor_model_parallel_rank()
|
| 111 |
+
splitted_input = split_tensor_along_last_dim(
|
| 112 |
+
input_, num_partitions=self.tp_size)
|
| 113 |
+
input_parallel = splitted_input[tp_rank].contiguous()
|
| 114 |
+
|
| 115 |
+
# Matrix multiply.
|
| 116 |
+
if self.quant_method is None:
|
| 117 |
+
raise ValueError('self.quant_method is None.')
|
| 118 |
+
# Only fuse bias add into GEMM for rank 0 (this ensures that
|
| 119 |
+
# bias will not get added more than once in TP>1 case)
|
| 120 |
+
bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
|
| 121 |
+
output_parallel = self.quant_method.apply(self,
|
| 122 |
+
input_parallel,
|
| 123 |
+
bias=bias_)
|
| 124 |
+
if self.reduce_results and self.tp_size > 1:
|
| 125 |
+
if not is_prefill and output_parallel.shape[0] % self.tp_size == 0:
|
| 126 |
+
output = tensor_model_parallel_reduce_scatter(output_parallel,
|
| 127 |
+
dim=0)
|
| 128 |
+
else:
|
| 129 |
+
output = tensor_model_parallel_all_reduce(output_parallel)
|
| 130 |
+
else:
|
| 131 |
+
output = output_parallel
|
| 132 |
+
|
| 133 |
+
output_bias = self.bias if self.skip_bias_add else None
|
| 134 |
+
|
| 135 |
+
if not self.return_bias:
|
| 136 |
+
return output
|
| 137 |
+
return output, output_bias
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
class OpenPanguRowParallelLinear(RowParallelLinear):
|
| 141 |
+
|
| 142 |
+
def forward(
|
| 143 |
+
self,
|
| 144 |
+
input_,
|
| 145 |
+
is_prefill=True
|
| 146 |
+
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[nn.Parameter]]]:
|
| 147 |
+
return super().forward(input_)
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
class OpenPanguRotaryEmbedding(nn.Module):
|
| 151 |
+
def __init__(self,
|
| 152 |
+
head_size: int,
|
| 153 |
+
rotary_dim: int,
|
| 154 |
+
max_position_embeddings: int,
|
| 155 |
+
base: float,
|
| 156 |
+
):
|
| 157 |
+
super().__init__()
|
| 158 |
+
self.dim = rotary_dim
|
| 159 |
+
self.max_position_embeddings = max_position_embeddings
|
| 160 |
+
self.base = base
|
| 161 |
+
self._set_cos_sin_cache(
|
| 162 |
+
seq_len=max_position_embeddings,
|
| 163 |
+
device='npu',
|
| 164 |
+
dtype=torch.get_default_dtype(),
|
| 165 |
+
)
|
| 166 |
+
|
| 167 |
+
def _set_cos_sin_cache(self,
|
| 168 |
+
seq_len: int,
|
| 169 |
+
device: str,
|
| 170 |
+
dtype: torch.dtype
|
| 171 |
+
):
|
| 172 |
+
self.max_seq_len = seq_len
|
| 173 |
+
inv_freq = 1.0 / (
|
| 174 |
+
self.base
|
| 175 |
+
** (torch.arange(0, self.dim, 2, dtype=torch.float32, device='npu') / self.dim)
|
| 176 |
+
)
|
| 177 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
| 178 |
+
t = torch.arange(seq_len, device='npu', dtype=torch.float32)
|
| 179 |
+
freqs = torch.outer(t, inv_freq)
|
| 180 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
| 181 |
+
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
|
| 182 |
+
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
|
| 183 |
+
|
| 184 |
+
def forward(self,
|
| 185 |
+
positions: torch.Tensor,
|
| 186 |
+
query: torch.Tensor,
|
| 187 |
+
key: torch.Tensor,
|
| 188 |
+
offsets: Optional[torch.Tensor] = None,
|
| 189 |
+
max_seq_len: Optional[int] = None,
|
| 190 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 191 |
+
if max_seq_len is not None and max_seq_len > self.max_seq_len:
|
| 192 |
+
self._set_cos_sin_cache(max_seq_len, query.device, query.dtype)
|
| 193 |
+
idx = torch.add(positions, offsets) if offsets is not None else positions
|
| 194 |
+
cos = self.cos_cached[idx]
|
| 195 |
+
sin = self.sin_cached[idx]
|
| 196 |
+
# Adapt: adapt cos and sin shape
|
| 197 |
+
cos = cos.view(-1, 1, cos.shape[-1])
|
| 198 |
+
sin = sin.view(-1, 1, sin.shape[-1])
|
| 199 |
+
# Adapt end.
|
| 200 |
+
query_rot = query * cos + _rotate_gptj(query) * sin
|
| 201 |
+
if key is not None:
|
| 202 |
+
key_rot = key * cos + _rotate_gptj(key) * sin
|
| 203 |
+
return query_rot, key_rot
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
class OpenPanguSiluAndMul(SiluAndMul):
|
| 207 |
+
|
| 208 |
+
def __init__(self,
|
| 209 |
+
*,
|
| 210 |
+
weight_scale: Optional[Callable[[], torch.Tensor]] = None):
|
| 211 |
+
super().__init__()
|
| 212 |
+
self.weight_scale = weight_scale
|
| 213 |
+
|
| 214 |
+
def forward_oot(self, x: Union[torch.Tensor, Tuple[torch.Tensor,
|
| 215 |
+
torch.Tensor]]):
|
| 216 |
+
if isinstance(x, tuple):
|
| 217 |
+
if self.weight_scale is None:
|
| 218 |
+
raise ValueError('self.weight_scale is None.')
|
| 219 |
+
quantized_x, dynamic_scale = x
|
| 220 |
+
return torch_npu.npu_dequant_swiglu_quant(
|
| 221 |
+
x=quantized_x,
|
| 222 |
+
weight_scale=self.weight_scale(),
|
| 223 |
+
activation_scale=dynamic_scale,
|
| 224 |
+
activate_left=True,
|
| 225 |
+
quant_mode=1)
|
| 226 |
+
else:
|
| 227 |
+
return super().forward_oot(x)
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
def check_ffn_act_fn(act_fn: str):
|
| 231 |
+
if act_fn != "silu":
|
| 232 |
+
raise ValueError(
|
| 233 |
+
f"Unsupported activation: {act_fn}. Only silu is supported for now.")
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
class OpenPanguMLP(nn.Module):
|
| 237 |
+
|
| 238 |
+
def __init__(
|
| 239 |
+
self,
|
| 240 |
+
hidden_size: int,
|
| 241 |
+
intermediate_size: int,
|
| 242 |
+
hidden_act: str,
|
| 243 |
+
quant_config: Optional[QuantizationConfig] = None,
|
| 244 |
+
bias: bool = False,
|
| 245 |
+
reduce_results: bool = True,
|
| 246 |
+
force_replicate: bool = False,
|
| 247 |
+
prefix: str = "",
|
| 248 |
+
) -> None:
|
| 249 |
+
super().__init__()
|
| 250 |
+
if not force_replicate:
|
| 251 |
+
self.gate_up_proj = MergedColumnParallelLinear(
|
| 252 |
+
hidden_size, [intermediate_size] * 2,
|
| 253 |
+
bias=bias,
|
| 254 |
+
quant_config=quant_config,
|
| 255 |
+
prefix=f"{prefix}.gate_up_proj")
|
| 256 |
+
self.down_proj = RowParallelLinear(intermediate_size,
|
| 257 |
+
hidden_size,
|
| 258 |
+
bias=bias,
|
| 259 |
+
quant_config=quant_config,
|
| 260 |
+
reduce_results=reduce_results,
|
| 261 |
+
prefix=f"{prefix}.down_proj")
|
| 262 |
+
else:
|
| 263 |
+
self.gate_up_proj = OpenPanguMergedReplicatedLinear(
|
| 264 |
+
hidden_size, [intermediate_size] * 2,
|
| 265 |
+
bias=bias,
|
| 266 |
+
quant_config=quant_config,
|
| 267 |
+
prefix=f"{prefix}.gate_up_proj")
|
| 268 |
+
self.down_proj = ReplicatedLinear(intermediate_size,
|
| 269 |
+
hidden_size,
|
| 270 |
+
bias=bias,
|
| 271 |
+
quant_config=quant_config,
|
| 272 |
+
prefix=f"{prefix}.down_proj")
|
| 273 |
+
|
| 274 |
+
check_ffn_act_fn(hidden_act)
|
| 275 |
+
|
| 276 |
+
quant_method = self.gate_up_proj.quant_method
|
| 277 |
+
if isinstance(quant_method, UnquantizedLinearMethod):
|
| 278 |
+
self.act_fn = OpenPanguSiluAndMul()
|
| 279 |
+
elif (isinstance(quant_method, AscendLinearMethod) and isinstance(
|
| 280 |
+
quant_method.quant_method, AscendW8A8DynamicLinearMethod)):
|
| 281 |
+
# TODO(sdmyzlp): Currently preserved as before:
|
| 282 |
+
# 1. The only quantization supported for silu is W8A8Dynamic
|
| 283 |
+
# 2. Output dtype of gate_up/down is fixed to be int32/bfloat16
|
| 284 |
+
#
|
| 285 |
+
# Maybe one can implement a better and more general configuration
|
| 286 |
+
# scheme, e.g. by somehow passing around the tweaked `quant_config`
|
| 287 |
+
self.act_fn = OpenPanguSiluAndMul(
|
| 288 |
+
# Use lazy binding, for `weight_scale_fp32` is accessible
|
| 289 |
+
# only after `process_weights_after_loading`.
|
| 290 |
+
weight_scale=lambda: self.gate_up_proj.weight_scale_fp32)
|
| 291 |
+
# To be consumed by AscendW8A8DynamicLinearMethod.apply()
|
| 292 |
+
self.gate_up_proj._ascend_quant_config = {
|
| 293 |
+
"output_dtype": torch.int32,
|
| 294 |
+
"pertoken_scale": False,
|
| 295 |
+
"return_scale": True,
|
| 296 |
+
}
|
| 297 |
+
self.down_proj._ascend_quant_config = {
|
| 298 |
+
"output_dtype": torch.bfloat16,
|
| 299 |
+
"pertoken_scale": True,
|
| 300 |
+
"return_scale": False,
|
| 301 |
+
}
|
| 302 |
+
else:
|
| 303 |
+
raise NotImplementedError(
|
| 304 |
+
f"Quantization with [{type(quant_method)}] is NOT supported")
|
| 305 |
+
|
| 306 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 307 |
+
return self.down_proj(self.act_fn(self.gate_up_proj(x)[0]))[0]
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
class OpenPanguMoE(nn.Module):
|
| 311 |
+
|
| 312 |
+
top_k: int
|
| 313 |
+
|
| 314 |
+
def __init__(
|
| 315 |
+
self,
|
| 316 |
+
config: PretrainedConfig,
|
| 317 |
+
quant_config: Optional[QuantizationConfig] = None,
|
| 318 |
+
prefix: str = "",
|
| 319 |
+
):
|
| 320 |
+
super().__init__()
|
| 321 |
+
ascend_config = get_ascend_config()
|
| 322 |
+
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
|
| 323 |
+
self.enable_multistream_moe = \
|
| 324 |
+
ascend_config.torchair_graph_config.enable_multistream_moe
|
| 325 |
+
self.routed_scaling_factor = config.routed_scaling_factor
|
| 326 |
+
check_ffn_act_fn(config.hidden_act)
|
| 327 |
+
|
| 328 |
+
self.gate = ReplicatedLinear(config.hidden_size,
|
| 329 |
+
config.num_routed_experts,
|
| 330 |
+
bias=False,
|
| 331 |
+
quant_config=None,
|
| 332 |
+
prefix=f"{prefix}.gate")
|
| 333 |
+
|
| 334 |
+
self.experts = AscendFusedMoE(
|
| 335 |
+
num_experts=config.num_routed_experts,
|
| 336 |
+
top_k=config.num_experts_per_tok,
|
| 337 |
+
hidden_size=config.hidden_size,
|
| 338 |
+
intermediate_size=config.moe_intermediate_size,
|
| 339 |
+
reduce_results=False,
|
| 340 |
+
renormalize=config.norm_topk_prob,
|
| 341 |
+
quant_config=quant_config,
|
| 342 |
+
use_grouped_topk=True,
|
| 343 |
+
num_expert_group=1,
|
| 344 |
+
topk_group=1,
|
| 345 |
+
prefix=f"{prefix}.experts",
|
| 346 |
+
scoring_func='sigmoid',
|
| 347 |
+
e_score_correction_bias=None)
|
| 348 |
+
|
| 349 |
+
if config.num_shared_experts is not None:
|
| 350 |
+
self.all_reduce_merge = self.experts.all_reduce_merge
|
| 351 |
+
reduce_results = not self.all_reduce_merge
|
| 352 |
+
intermediate_size = (config.moe_intermediate_size * config.num_shared_experts)
|
| 353 |
+
self.shared_experts = OpenPanguMLP(
|
| 354 |
+
hidden_size=config.hidden_size,
|
| 355 |
+
intermediate_size=intermediate_size,
|
| 356 |
+
hidden_act=config.hidden_act,
|
| 357 |
+
quant_config=quant_config,
|
| 358 |
+
reduce_results=reduce_results,
|
| 359 |
+
force_replicate=self.enable_multistream_moe,
|
| 360 |
+
prefix=f"{prefix}.shared_experts",
|
| 361 |
+
)
|
| 362 |
+
else:
|
| 363 |
+
self.shared_experts = None # type: ignore
|
| 364 |
+
|
| 365 |
+
self.tp_size = get_tensor_model_parallel_world_size()
|
| 366 |
+
self.dp_size = get_dp_group().world_size
|
| 367 |
+
self.tp_group = get_tp_group().device_group
|
| 368 |
+
self.tp_rank = get_tp_group().rank_in_group
|
| 369 |
+
self.ep_group = get_ep_group()
|
| 370 |
+
|
| 371 |
+
self.params_dtype = torch.get_default_dtype()
|
| 372 |
+
self.rm_router_logits = self.experts.rm_router_logits
|
| 373 |
+
|
| 374 |
+
self.__class__.top_k = config.num_experts_per_tok
|
| 375 |
+
|
| 376 |
+
def forward(self,
|
| 377 |
+
hidden_states: torch.Tensor,
|
| 378 |
+
attn_metadata: Optional[AttentionMetadata] = None,
|
| 379 |
+
replace_allreduce: bool = False) -> torch.Tensor:
|
| 380 |
+
|
| 381 |
+
if attn_metadata is None:
|
| 382 |
+
attn_metadata = get_forward_context().attn_metadata
|
| 383 |
+
# when profile runs, force experts to load balanced tokens
|
| 384 |
+
# to avoid high memory consumption on a single rank.
|
| 385 |
+
# TODO: need a better flag to indicate whether in profile run or not.
|
| 386 |
+
if attn_metadata is None:
|
| 387 |
+
# for profile run
|
| 388 |
+
is_prefill = True
|
| 389 |
+
fused_moe_state = get_fused_moe_state(self.ep_group.world_size, is_prefill, True)
|
| 390 |
+
enable_force_load_balance = fused_moe_state != FusedMoEState.AllGatherEP
|
| 391 |
+
else:
|
| 392 |
+
is_prefill = attn_metadata.num_prefills > 0
|
| 393 |
+
enable_force_load_balance = False
|
| 394 |
+
if hasattr(attn_metadata, 'with_prefill_across_dp'):
|
| 395 |
+
is_prefill = is_prefill or attn_metadata.with_prefill_across_dp
|
| 396 |
+
fused_moe_state = get_fused_moe_state(self.ep_group.world_size, is_prefill, True)
|
| 397 |
+
|
| 398 |
+
# router_logits: (num_tokens, n_experts)
|
| 399 |
+
router_logits = None
|
| 400 |
+
if not self.rm_router_logits or fused_moe_state == FusedMoEState.All2All:
|
| 401 |
+
router_logits, _ = self.gate(hidden_states.float())
|
| 402 |
+
|
| 403 |
+
routed_hidden_states, shared_hidden_states = self.experts(
|
| 404 |
+
hidden_states=hidden_states,
|
| 405 |
+
router_logits=router_logits,
|
| 406 |
+
is_prefill=is_prefill,
|
| 407 |
+
top_k=self.__class__.top_k,
|
| 408 |
+
enable_force_load_balance=enable_force_load_balance,
|
| 409 |
+
shared_experts=self.shared_experts,
|
| 410 |
+
gate=self.gate,
|
| 411 |
+
replace_allreduce=replace_allreduce)
|
| 412 |
+
|
| 413 |
+
if self.all_reduce_merge and fused_moe_state == FusedMoEState.All2All:
|
| 414 |
+
shared_hidden_states = tensor_model_parallel_all_reduce(shared_hidden_states)
|
| 415 |
+
hidden_states = routed_hidden_states * self.routed_scaling_factor + shared_hidden_states
|
| 416 |
+
if self.all_reduce_merge and fused_moe_state != FusedMoEState.All2All:
|
| 417 |
+
# When all_reduce_merge is in progress, shared_experts does not do all_reduce in mlp, but waits until shared_experts+router_experts are completed before doing all_reduce
|
| 418 |
+
hidden_states = tensor_model_parallel_all_reduce(hidden_states)
|
| 419 |
+
|
| 420 |
+
return hidden_states
|
| 421 |
+
|
| 422 |
+
|
| 423 |
+
class OpenPanguMLAAttention(nn.Module):
|
| 424 |
+
|
| 425 |
+
def __init__(
|
| 426 |
+
self,
|
| 427 |
+
config: PretrainedConfig,
|
| 428 |
+
hidden_size: int,
|
| 429 |
+
num_heads: int,
|
| 430 |
+
attention_qk_dim: int,
|
| 431 |
+
attention_qk_rope_dim: int,
|
| 432 |
+
attention_v_dim: int,
|
| 433 |
+
attention_q_lora_dim: Optional[int],
|
| 434 |
+
attention_kv_lora_dim: int,
|
| 435 |
+
rope_theta: float = 10000,
|
| 436 |
+
max_position_embeddings: int = 8192,
|
| 437 |
+
cache_config: Optional[CacheConfig] = None,
|
| 438 |
+
quant_config: Optional[QuantizationConfig] = None,
|
| 439 |
+
prefix: str = "",
|
| 440 |
+
) -> None:
|
| 441 |
+
super().__init__()
|
| 442 |
+
ascend_config = get_ascend_config()
|
| 443 |
+
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
|
| 444 |
+
self.enable_multistream_mla = ascend_config.torchair_graph_config.enable_multistream_mla
|
| 445 |
+
|
| 446 |
+
self.hidden_size = hidden_size
|
| 447 |
+
self.num_heads = num_heads
|
| 448 |
+
self.attention_qk_dim = attention_qk_dim
|
| 449 |
+
self.attention_qk_rope_dim = attention_qk_rope_dim
|
| 450 |
+
self.qk_head_dim = attention_qk_dim + attention_qk_rope_dim
|
| 451 |
+
self.attention_v_dim = attention_v_dim
|
| 452 |
+
self.attention_q_lora_dim = attention_q_lora_dim
|
| 453 |
+
self.attention_kv_lora_dim = attention_kv_lora_dim
|
| 454 |
+
self.rope_theta = rope_theta
|
| 455 |
+
|
| 456 |
+
tp_size = get_tensor_model_parallel_world_size()
|
| 457 |
+
if num_heads % tp_size != 0:
|
| 458 |
+
raise ValueError(f'num_heads {num_heads} is not divisible by tp_size {tp_size}.')
|
| 459 |
+
self.num_local_heads = num_heads // tp_size
|
| 460 |
+
|
| 461 |
+
self.scaling = self.qk_head_dim**-0.5
|
| 462 |
+
self.max_position_embeddings = max_position_embeddings
|
| 463 |
+
|
| 464 |
+
self.prefix = prefix
|
| 465 |
+
self.debug_layer_idx = int(self.prefix.split(".")[-2])
|
| 466 |
+
|
| 467 |
+
if self.attention_q_lora_dim is not None:
|
| 468 |
+
self.q_a_proj = ReplicatedLinear(self.hidden_size,
|
| 469 |
+
self.attention_q_lora_dim,
|
| 470 |
+
bias=False,
|
| 471 |
+
quant_config=quant_config,
|
| 472 |
+
prefix=f"{prefix}.q_a_proj")
|
| 473 |
+
self.q_a_layernorm = RMSNorm(self.attention_q_lora_dim, eps=config.rms_norm_eps)
|
| 474 |
+
self.q_b_proj = ColumnParallelLinear(attention_q_lora_dim,
|
| 475 |
+
self.num_heads * self.qk_head_dim,
|
| 476 |
+
bias=False,
|
| 477 |
+
quant_config=quant_config,
|
| 478 |
+
prefix=f"{prefix}.q_b_proj")
|
| 479 |
+
else:
|
| 480 |
+
self.q_proj = ColumnParallelLinear(self.hidden_size,
|
| 481 |
+
self.num_heads * self.qk_head_dim,
|
| 482 |
+
bias=False,
|
| 483 |
+
quant_config=quant_config,
|
| 484 |
+
prefix=f"{prefix}.q_proj")
|
| 485 |
+
|
| 486 |
+
self.kv_a_proj_with_mqa = ReplicatedLinear(
|
| 487 |
+
self.hidden_size,
|
| 488 |
+
self.attention_kv_lora_dim + self.attention_qk_rope_dim,
|
| 489 |
+
bias=False,
|
| 490 |
+
quant_config=quant_config,
|
| 491 |
+
prefix=f"{prefix}.kv_a_proj_with_mqa")
|
| 492 |
+
self.kv_a_layernorm = RMSNorm(self.attention_kv_lora_dim,
|
| 493 |
+
eps=config.rms_norm_eps)
|
| 494 |
+
self.kv_b_proj = ColumnParallelLinear(
|
| 495 |
+
self.attention_kv_lora_dim,
|
| 496 |
+
self.num_heads * (self.attention_qk_dim + self.attention_v_dim),
|
| 497 |
+
bias=False,
|
| 498 |
+
quant_config=quant_config,
|
| 499 |
+
prefix=f"{prefix}.kv_b_proj")
|
| 500 |
+
if (config.num_routed_experts is not None
|
| 501 |
+
and self.debug_layer_idx >= config.num_dense_layers and
|
| 502 |
+
ascend_config.torchair_graph_config.enable_multistream_moe):
|
| 503 |
+
self.o_proj = OpenPanguRowParallelLinearReplaceAllreduce(
|
| 504 |
+
self.num_heads * self.attention_v_dim,
|
| 505 |
+
self.hidden_size,
|
| 506 |
+
bias=False,
|
| 507 |
+
quant_config=quant_config,
|
| 508 |
+
prefix=f"{prefix}.o_proj")
|
| 509 |
+
else:
|
| 510 |
+
self.o_proj = OpenPanguRowParallelLinear(
|
| 511 |
+
self.num_heads * self.attention_v_dim,
|
| 512 |
+
self.hidden_size,
|
| 513 |
+
bias=False,
|
| 514 |
+
quant_config=quant_config,
|
| 515 |
+
prefix=f"{prefix}.o_proj")
|
| 516 |
+
|
| 517 |
+
self.rotary_emb = OpenPanguRotaryEmbedding(attention_qk_rope_dim,
|
| 518 |
+
rotary_dim=attention_qk_rope_dim,
|
| 519 |
+
max_position_embeddings=max_position_embeddings,
|
| 520 |
+
base=rope_theta)
|
| 521 |
+
|
| 522 |
+
self.mla_attn = Attention(
|
| 523 |
+
num_heads=self.num_local_heads,
|
| 524 |
+
head_size=self.attention_kv_lora_dim + self.attention_qk_rope_dim,
|
| 525 |
+
scale=self.scaling,
|
| 526 |
+
num_kv_heads=1,
|
| 527 |
+
cache_config=cache_config,
|
| 528 |
+
quant_config=quant_config,
|
| 529 |
+
prefix=f"{prefix}.attn",
|
| 530 |
+
use_mla=True,
|
| 531 |
+
# MLA Args
|
| 532 |
+
q_lora_rank=self.attention_q_lora_dim,
|
| 533 |
+
kv_lora_rank=self.attention_kv_lora_dim,
|
| 534 |
+
qk_nope_head_dim=self.attention_qk_dim,
|
| 535 |
+
qk_rope_head_dim=self.attention_qk_rope_dim,
|
| 536 |
+
qk_head_dim=self.qk_head_dim,
|
| 537 |
+
v_head_dim=self.attention_v_dim,
|
| 538 |
+
rotary_emb=self.rotary_emb,
|
| 539 |
+
q_proj=self.q_proj if self.attention_q_lora_dim is None else self.q_b_proj,
|
| 540 |
+
kv_a_proj_with_mqa=self.kv_a_proj_with_mqa,
|
| 541 |
+
kv_a_layernorm=self.kv_a_layernorm,
|
| 542 |
+
kv_b_proj=self.kv_b_proj,
|
| 543 |
+
o_proj=self.o_proj,
|
| 544 |
+
)
|
| 545 |
+
|
| 546 |
+
def forward(
|
| 547 |
+
self,
|
| 548 |
+
positions: torch.Tensor,
|
| 549 |
+
hidden_states: torch.Tensor,
|
| 550 |
+
kv_cache: Optional[torch.Tensor] = None,
|
| 551 |
+
attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
|
| 552 |
+
enable_multistream_mla = (self.enable_multistream_mla
|
| 553 |
+
and attn_metadata is not None
|
| 554 |
+
and not attn_metadata.with_prefill_across_dp
|
| 555 |
+
and attn_metadata.num_decodes > 0)
|
| 556 |
+
forward_kwargs = {"enable_multistream_mla": enable_multistream_mla}
|
| 557 |
+
if self.attention_q_lora_dim is not None:
|
| 558 |
+
npu_prefetch(self.q_a_proj.weight,
|
| 559 |
+
hidden_states,
|
| 560 |
+
enabled=enable_multistream_mla)
|
| 561 |
+
ckq = self.q_a_proj(hidden_states)[0]
|
| 562 |
+
hidden_states_or_q_c = self.q_a_layernorm(ckq)
|
| 563 |
+
forward_kwargs['ckq'] = ckq
|
| 564 |
+
else:
|
| 565 |
+
hidden_states_or_q_c = hidden_states
|
| 566 |
+
if self.torchair_graph_enabled:
|
| 567 |
+
if envs.VLLM_USE_V1:
|
| 568 |
+
output_shape = hidden_states.shape
|
| 569 |
+
output = torch.empty(output_shape,
|
| 570 |
+
dtype=hidden_states_or_q_c.dtype,
|
| 571 |
+
device=hidden_states_or_q_c.device)
|
| 572 |
+
forward_kwargs['output'] = output
|
| 573 |
+
|
| 574 |
+
output = self.mla_attn.impl.forward(self.mla_attn,
|
| 575 |
+
hidden_states_or_q_c,
|
| 576 |
+
hidden_states, None, kv_cache,
|
| 577 |
+
attn_metadata,
|
| 578 |
+
**forward_kwargs)
|
| 579 |
+
if envs.VLLM_USE_V1:
|
| 580 |
+
output = output.view(-1, output_shape[-1])
|
| 581 |
+
return output
|
| 582 |
+
else:
|
| 583 |
+
kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split(
|
| 584 |
+
[self.attention_kv_lora_dim, self.attention_qk_rope_dim], dim=-1)
|
| 585 |
+
kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
|
| 586 |
+
return self.mla_attn(hidden_states_or_q_c,
|
| 587 |
+
kv_c_normed,
|
| 588 |
+
k_pe,
|
| 589 |
+
output_shape=hidden_states.shape)
|
| 590 |
+
|
| 591 |
+
|
| 592 |
+
class OpenPanguEmbeddedAttention(nn.Module):
|
| 593 |
+
|
| 594 |
+
def __init__(
|
| 595 |
+
self,
|
| 596 |
+
config: PretrainedConfig,
|
| 597 |
+
hidden_size: int,
|
| 598 |
+
num_heads: int,
|
| 599 |
+
num_kv_heads: int,
|
| 600 |
+
rope_theta: float = 10000,
|
| 601 |
+
rope_scaling: Optional[dict[str, Any]] = None,
|
| 602 |
+
max_position_embeddings: int = 8192,
|
| 603 |
+
quant_config: Optional[QuantizationConfig] = None,
|
| 604 |
+
bias: bool = False,
|
| 605 |
+
bias_o_proj: bool = False,
|
| 606 |
+
cache_config: Optional[CacheConfig] = None,
|
| 607 |
+
prefix: str = "",
|
| 608 |
+
attn_type: str = AttentionType.DECODER,
|
| 609 |
+
) -> None:
|
| 610 |
+
super().__init__()
|
| 611 |
+
layer_idx = extract_layer_index(prefix)
|
| 612 |
+
self.hidden_size = hidden_size
|
| 613 |
+
tp_size = get_tensor_model_parallel_world_size()
|
| 614 |
+
self.total_num_heads = num_heads
|
| 615 |
+
if self.total_num_heads % tp_size != 0:
|
| 616 |
+
raise ValueError(f'total_num_heads {total_num_heads} is not divisible by tp_size {tp_size}.')
|
| 617 |
+
self.num_heads = self.total_num_heads // tp_size
|
| 618 |
+
self.total_num_kv_heads = num_kv_heads
|
| 619 |
+
if self.total_num_kv_heads >= tp_size and self.total_num_kv_heads % tp_size != 0:
|
| 620 |
+
# Number of KV heads is greater than TP size, so we partition
|
| 621 |
+
# the KV heads across multiple tensor parallel NPUs.
|
| 622 |
+
raise ValueError(f'Number of KV heads is less than TP size, but total_num_kv_heads {self.total_num_kv_heads} '
|
| 623 |
+
f'is not divisible by tp_size {tp_size}.')
|
| 624 |
+
elif self.total_num_kv_heads < tp_size and tp_size % self.total_num_kv_heads != 0:
|
| 625 |
+
# Number of KV heads is less than TP size, so we replicate
|
| 626 |
+
# the KV heads across multiple tensor parallel NPUs.
|
| 627 |
+
raise ValueError(f'Number of KV heads is less than TP size, but tp_size {tp_size} '
|
| 628 |
+
f'is not divisible by total_num_kv_heads {self.total_num_kv_heads}.')
|
| 629 |
+
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
|
| 630 |
+
# MistralConfig has an optional head_dim introduced by Mistral-Nemo
|
| 631 |
+
head_dim = getattr(config, "head_dim", None)
|
| 632 |
+
if head_dim is None:
|
| 633 |
+
head_dim = self.hidden_size // self.total_num_heads
|
| 634 |
+
self.head_dim = head_dim
|
| 635 |
+
# Phi models introduced a partial_rotary_factor parameter in the config
|
| 636 |
+
self.partial_rotary_factor = getattr(config, "partial_rotary_factor", 1)
|
| 637 |
+
self.q_size = self.num_heads * self.head_dim
|
| 638 |
+
self.kv_size = self.num_kv_heads * self.head_dim
|
| 639 |
+
self.scaling = self.head_dim**-0.5
|
| 640 |
+
self.rope_theta = rope_theta
|
| 641 |
+
self.max_position_embeddings = max_position_embeddings
|
| 642 |
+
|
| 643 |
+
self.qkv_proj = QKVParallelLinear(
|
| 644 |
+
hidden_size=hidden_size,
|
| 645 |
+
head_size=self.head_dim,
|
| 646 |
+
total_num_heads=self.total_num_heads,
|
| 647 |
+
total_num_kv_heads=self.total_num_kv_heads,
|
| 648 |
+
bias=bias,
|
| 649 |
+
quant_config=quant_config,
|
| 650 |
+
prefix=f"{prefix}.qkv_proj",
|
| 651 |
+
)
|
| 652 |
+
|
| 653 |
+
self.o_proj = RowParallelLinear(
|
| 654 |
+
input_size=self.total_num_heads * self.head_dim,
|
| 655 |
+
output_size=hidden_size,
|
| 656 |
+
bias=bias_o_proj,
|
| 657 |
+
quant_config=quant_config,
|
| 658 |
+
prefix=f"{prefix}.o_proj",
|
| 659 |
+
)
|
| 660 |
+
|
| 661 |
+
self._init_rotary_emb(config,
|
| 662 |
+
rope_scaling=rope_scaling,
|
| 663 |
+
quant_config=quant_config)
|
| 664 |
+
|
| 665 |
+
if hasattr(config, "interleaved_sliding_window"):
|
| 666 |
+
interleaved_sliding_window = config.interleaved_sliding_window
|
| 667 |
+
if isinstance(interleaved_sliding_window, int):
|
| 668 |
+
sliding_window = interleaved_sliding_window
|
| 669 |
+
elif isinstance(interleaved_sliding_window, list):
|
| 670 |
+
sw_idx = layer_idx % len(interleaved_sliding_window)
|
| 671 |
+
sliding_window = interleaved_sliding_window[sw_idx]
|
| 672 |
+
else:
|
| 673 |
+
raise ValueError(
|
| 674 |
+
f"{type(interleaved_sliding_window)} is not supported.")
|
| 675 |
+
else:
|
| 676 |
+
sliding_window = None
|
| 677 |
+
|
| 678 |
+
self.attn = Attention(
|
| 679 |
+
self.num_heads,
|
| 680 |
+
self.head_dim,
|
| 681 |
+
self.scaling,
|
| 682 |
+
num_kv_heads=self.num_kv_heads,
|
| 683 |
+
cache_config=cache_config,
|
| 684 |
+
quant_config=quant_config,
|
| 685 |
+
per_layer_sliding_window=sliding_window,
|
| 686 |
+
attn_type=attn_type,
|
| 687 |
+
prefix=f"{prefix}.attn",
|
| 688 |
+
)
|
| 689 |
+
|
| 690 |
+
def forward(
|
| 691 |
+
self,
|
| 692 |
+
positions: torch.Tensor,
|
| 693 |
+
hidden_states: torch.Tensor,
|
| 694 |
+
kv_cache: Optional[torch.Tensor] = None,
|
| 695 |
+
attn_metadata: Optional[AttentionMetadata] = None
|
| 696 |
+
) -> torch.Tensor:
|
| 697 |
+
qkv, _ = self.qkv_proj(hidden_states)
|
| 698 |
+
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
| 699 |
+
q, k = self.rotary_emb(positions, q, k)
|
| 700 |
+
attn_output = self.attn(q, k, v)
|
| 701 |
+
output, _ = self.o_proj(attn_output)
|
| 702 |
+
return output
|
| 703 |
+
|
| 704 |
+
def _init_rotary_emb(self, config: PretrainedConfig,
|
| 705 |
+
rope_scaling: Optional[dict[str, Any]],
|
| 706 |
+
quant_config: Optional[QuantizationConfig]) -> None:
|
| 707 |
+
is_neox_style = True
|
| 708 |
+
is_gguf = quant_config and quant_config.get_name() == "gguf"
|
| 709 |
+
if is_gguf and config.model_type == "Pangu":
|
| 710 |
+
is_neox_style = False
|
| 711 |
+
|
| 712 |
+
self.rotary_emb = get_rope(
|
| 713 |
+
self.head_dim,
|
| 714 |
+
rotary_dim=self.head_dim,
|
| 715 |
+
max_position=self.max_position_embeddings,
|
| 716 |
+
base=self.rope_theta,
|
| 717 |
+
rope_scaling=rope_scaling,
|
| 718 |
+
is_neox_style=is_neox_style,
|
| 719 |
+
#partial_rotary_factor=self.partial_rotary_factor,
|
| 720 |
+
)
|
| 721 |
+
|
| 722 |
+
|
| 723 |
+
class OpenPanguDecoderLayer(nn.Module):
|
| 724 |
+
|
| 725 |
+
def __init__(
|
| 726 |
+
self,
|
| 727 |
+
config: PretrainedConfig,
|
| 728 |
+
prefix: str,
|
| 729 |
+
model_config: ModelConfig,
|
| 730 |
+
cache_config: Optional[CacheConfig] = None,
|
| 731 |
+
quant_config: Optional[QuantizationConfig] = None,
|
| 732 |
+
) -> None:
|
| 733 |
+
super().__init__()
|
| 734 |
+
self.hidden_size = config.hidden_size
|
| 735 |
+
rope_theta = getattr(config, "rope_theta", 10000)
|
| 736 |
+
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
|
| 737 |
+
|
| 738 |
+
layer_idx = int(prefix.split(sep='.')[-1])
|
| 739 |
+
self.layer_idx = layer_idx
|
| 740 |
+
self.layers = config.num_hidden_layers
|
| 741 |
+
self.tp_size = get_tensor_model_parallel_world_size()
|
| 742 |
+
self.tp_rank = get_tp_group().rank_in_group
|
| 743 |
+
ascend_config = get_ascend_config()
|
| 744 |
+
|
| 745 |
+
self.use_mla = hasattr(config, 'attention_qk_dim') and hasattr(config, 'attention_qk_rope_dim') \
|
| 746 |
+
and hasattr(config, 'attention_v_dim') and hasattr(config, 'attention_kv_lora_dim')
|
| 747 |
+
if self.use_mla:
|
| 748 |
+
self.self_attn = OpenPanguMLAAttention(
|
| 749 |
+
config=config,
|
| 750 |
+
hidden_size=self.hidden_size,
|
| 751 |
+
num_heads=config.num_attention_heads,
|
| 752 |
+
attention_qk_dim=config.attention_qk_dim,
|
| 753 |
+
attention_qk_rope_dim=config.attention_qk_rope_dim,
|
| 754 |
+
attention_v_dim=config.attention_v_dim,
|
| 755 |
+
attention_q_lora_dim=config.attention_q_lora_dim
|
| 756 |
+
if hasattr(config, "attention_q_lora_dim") else None,
|
| 757 |
+
attention_kv_lora_dim=config.attention_kv_lora_dim,
|
| 758 |
+
rope_theta=rope_theta,
|
| 759 |
+
max_position_embeddings=max_position_embeddings,
|
| 760 |
+
cache_config=cache_config,
|
| 761 |
+
quant_config=quant_config,
|
| 762 |
+
prefix=f"{prefix}.self_attn",
|
| 763 |
+
)
|
| 764 |
+
else:
|
| 765 |
+
attention_bias = getattr(config, "attention_bias", False) or getattr(
|
| 766 |
+
config, "bias", False)
|
| 767 |
+
bias_o_proj = attention_bias
|
| 768 |
+
if hasattr(config, 'qkv_bias'):
|
| 769 |
+
attention_bias = config.qkv_bias
|
| 770 |
+
# By default, PanguEmbedded uses causal attention as it is a decoder-only model.
|
| 771 |
+
# You can override the HF config with `is_causal=False` to enable
|
| 772 |
+
# bidirectional attention, which is used in some embedding models
|
| 773 |
+
if getattr(config, "is_causal", True):
|
| 774 |
+
attn_type = AttentionType.DECODER
|
| 775 |
+
else:
|
| 776 |
+
attn_type = AttentionType.ENCODER_ONLY
|
| 777 |
+
self.self_attn = OpenPanguEmbeddedAttention(
|
| 778 |
+
config=config,
|
| 779 |
+
hidden_size=self.hidden_size,
|
| 780 |
+
num_heads=config.num_attention_heads,
|
| 781 |
+
num_kv_heads=getattr(config, "num_key_value_heads", config.num_attention_heads),
|
| 782 |
+
rope_theta=rope_theta,
|
| 783 |
+
rope_scaling=getattr(config, "rope_scaling", None),
|
| 784 |
+
max_position_embeddings=max_position_embeddings,
|
| 785 |
+
quant_config=quant_config,
|
| 786 |
+
bias=attention_bias,
|
| 787 |
+
bias_o_proj=bias_o_proj,
|
| 788 |
+
cache_config=cache_config,
|
| 789 |
+
prefix=f"{prefix}.self_attn",
|
| 790 |
+
attn_type=attn_type,
|
| 791 |
+
)
|
| 792 |
+
|
| 793 |
+
if getattr(config, 'num_routed_experts', None) is not None and layer_idx >= config.num_dense_layers:
|
| 794 |
+
self.mlp = OpenPanguMoE(
|
| 795 |
+
config=config,
|
| 796 |
+
quant_config=quant_config,
|
| 797 |
+
prefix=f"{prefix}.mlp",
|
| 798 |
+
)
|
| 799 |
+
self.mla_moe_communication = ascend_config.torchair_graph_config.enable_multistream_moe \
|
| 800 |
+
and model_config.use_mla and envs.VLLM_USE_V1 and self.tp_size > 1
|
| 801 |
+
else:
|
| 802 |
+
self.mlp = OpenPanguMLP(
|
| 803 |
+
hidden_size=self.hidden_size,
|
| 804 |
+
intermediate_size=config.intermediate_size,
|
| 805 |
+
hidden_act=config.hidden_act,
|
| 806 |
+
quant_config=quant_config,
|
| 807 |
+
bias=getattr(config, "mlp_bias", False),
|
| 808 |
+
prefix=f"{prefix}.mlp",
|
| 809 |
+
)
|
| 810 |
+
self.mla_moe_communication = False
|
| 811 |
+
self.routed_scaling_factor = getattr(config, 'routed_scaling_factor', None)
|
| 812 |
+
self.num_dense_layers = getattr(config, 'num_dense_layers', None)
|
| 813 |
+
|
| 814 |
+
self.input_layernorm = RMSNorm(config.hidden_size,
|
| 815 |
+
eps=config.rms_norm_eps)
|
| 816 |
+
self.post_attention_layernorm = RMSNorm(config.hidden_size,
|
| 817 |
+
eps=config.rms_norm_eps)
|
| 818 |
+
if getattr(config, 'sandwich_norm', False):
|
| 819 |
+
self.sandwich_norm = True
|
| 820 |
+
self.pre_mlp_layernorm = RMSNorm(config.hidden_size,
|
| 821 |
+
eps=config.rms_norm_eps)
|
| 822 |
+
self.post_mlp_layernorm = RMSNorm(config.hidden_size,
|
| 823 |
+
eps=config.rms_norm_eps)
|
| 824 |
+
else:
|
| 825 |
+
self.sandwich_norm = False
|
| 826 |
+
|
| 827 |
+
def forward(
|
| 828 |
+
self,
|
| 829 |
+
positions: torch.Tensor,
|
| 830 |
+
hidden_states: torch.Tensor,
|
| 831 |
+
residual: Optional[torch.Tensor],
|
| 832 |
+
kv_cache: Optional[torch.Tensor] = None,
|
| 833 |
+
attn_metadata: Optional[AttentionMetadata] = None,
|
| 834 |
+
replace_allreduce: bool = False,
|
| 835 |
+
) -> torch.Tensor:
|
| 836 |
+
# Self Attention
|
| 837 |
+
if self.use_mla and attn_metadata is not None and attn_metadata.num_decodes > 0:
|
| 838 |
+
mla_moe_communication = self.mla_moe_communication and replace_allreduce
|
| 839 |
+
else:
|
| 840 |
+
mla_moe_communication = False
|
| 841 |
+
if residual is None:
|
| 842 |
+
residual = hidden_states
|
| 843 |
+
hidden_states = self.input_layernorm(hidden_states)
|
| 844 |
+
else:
|
| 845 |
+
previous_hidden_states, previous_residual = hidden_states, residual
|
| 846 |
+
hidden_states, residual = self.input_layernorm(
|
| 847 |
+
hidden_states, residual)
|
| 848 |
+
# Dispose hidden_states and residual from the previous layer
|
| 849 |
+
# to save npu memory because they're no longer used.
|
| 850 |
+
dispose_tensor(previous_hidden_states)
|
| 851 |
+
dispose_tensor(previous_residual)
|
| 852 |
+
if mla_moe_communication and self.layer_idx > self.num_dense_layers:
|
| 853 |
+
hidden_states = tensor_model_parallel_all_gather(hidden_states,
|
| 854 |
+
dim=0)
|
| 855 |
+
|
| 856 |
+
hidden_states = self.self_attn(
|
| 857 |
+
positions=positions,
|
| 858 |
+
hidden_states=hidden_states,
|
| 859 |
+
kv_cache=kv_cache,
|
| 860 |
+
attn_metadata=attn_metadata,
|
| 861 |
+
)
|
| 862 |
+
|
| 863 |
+
if mla_moe_communication and residual.shape[0] != hidden_states.shape[0]:
|
| 864 |
+
chunk_hidden_states = torch.tensor_split(residual,
|
| 865 |
+
self.tp_size,
|
| 866 |
+
dim=0)
|
| 867 |
+
residual = chunk_hidden_states[self.tp_rank]
|
| 868 |
+
|
| 869 |
+
if self.routed_scaling_factor is not None and hidden_states.dtype == torch.float16:
|
| 870 |
+
# Fix FP16 overflow
|
| 871 |
+
# We scale both hidden_states and residual before
|
| 872 |
+
# rmsnorm, and rmsnorm result would not affect by scale.
|
| 873 |
+
hidden_states *= 1. / self.routed_scaling_factor
|
| 874 |
+
if self.layer_idx == 0:
|
| 875 |
+
# The residual is shared by all layers, we only scale it on
|
| 876 |
+
# first layer.
|
| 877 |
+
residual *= 1. / self.routed_scaling_factor
|
| 878 |
+
|
| 879 |
+
if self.sandwich_norm:
|
| 880 |
+
hidden_states = self.post_attention_layernorm(
|
| 881 |
+
hidden_states)
|
| 882 |
+
hidden_states, residual = self.pre_mlp_layernorm(
|
| 883 |
+
hidden_states, residual)
|
| 884 |
+
else:
|
| 885 |
+
hidden_states, residual = self.post_attention_layernorm(
|
| 886 |
+
hidden_states, residual)
|
| 887 |
+
|
| 888 |
+
# Fully Connected
|
| 889 |
+
if isinstance(self.mlp, OpenPanguMoE):
|
| 890 |
+
hidden_states = self.mlp(hidden_states,
|
| 891 |
+
attn_metadata,
|
| 892 |
+
replace_allreduce=mla_moe_communication)
|
| 893 |
+
else:
|
| 894 |
+
hidden_states = self.mlp(hidden_states)
|
| 895 |
+
|
| 896 |
+
if self.routed_scaling_factor is not None and isinstance(self.mlp, OpenPanguMLP) \
|
| 897 |
+
and hidden_states.dtype == torch.float16:
|
| 898 |
+
hidden_states *= 1. / self.routed_scaling_factor
|
| 899 |
+
|
| 900 |
+
if self.sandwich_norm:
|
| 901 |
+
hidden_states = self.post_mlp_layernorm(hidden_states)
|
| 902 |
+
|
| 903 |
+
if mla_moe_communication and self.layer_idx == self.layers - 1:
|
| 904 |
+
hidden_states = tensor_model_parallel_all_gather(hidden_states,
|
| 905 |
+
dim=0)
|
| 906 |
+
residual = tensor_model_parallel_all_gather(residual, dim=0)
|
| 907 |
+
|
| 908 |
+
return hidden_states, residual
|
| 909 |
+
|
| 910 |
+
|
| 911 |
+
@support_torch_compile
|
| 912 |
+
class OpenPanguModel(nn.Module):
|
| 913 |
+
|
| 914 |
+
fall_back_to_pt_during_load = False
|
| 915 |
+
|
| 916 |
+
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
| 917 |
+
super().__init__()
|
| 918 |
+
|
| 919 |
+
config = vllm_config.model_config.hf_config
|
| 920 |
+
model_config = vllm_config.model_config
|
| 921 |
+
cache_config = vllm_config.cache_config
|
| 922 |
+
quant_config = vllm_config.quant_config
|
| 923 |
+
|
| 924 |
+
self.padding_idx = config.pad_token_id
|
| 925 |
+
self.vocab_size = config.vocab_size
|
| 926 |
+
self.tp_size = get_tensor_model_parallel_world_size()
|
| 927 |
+
|
| 928 |
+
self.embed_tokens = VocabParallelEmbedding(
|
| 929 |
+
config.vocab_size,
|
| 930 |
+
config.hidden_size,
|
| 931 |
+
quant_config=quant_config,
|
| 932 |
+
prefix=f"{prefix}.embed_tokens")
|
| 933 |
+
|
| 934 |
+
self.start_layer, self.end_layer, self.layers = make_layers(
|
| 935 |
+
config.num_hidden_layers,
|
| 936 |
+
lambda prefix: OpenPanguDecoderLayer(
|
| 937 |
+
config,
|
| 938 |
+
prefix,
|
| 939 |
+
model_config=model_config,
|
| 940 |
+
cache_config=cache_config,
|
| 941 |
+
quant_config=quant_config,
|
| 942 |
+
),
|
| 943 |
+
prefix=f"{prefix}.layers")
|
| 944 |
+
|
| 945 |
+
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 946 |
+
|
| 947 |
+
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
| 948 |
+
return self.embed_tokens(input_ids)
|
| 949 |
+
|
| 950 |
+
def forward(
|
| 951 |
+
self,
|
| 952 |
+
input_ids: torch.Tensor,
|
| 953 |
+
positions: torch.Tensor,
|
| 954 |
+
kv_caches: Optional[List[torch.Tensor]] = None,
|
| 955 |
+
attn_metadata: Optional[AttentionMetadata] = None,
|
| 956 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
| 957 |
+
**kwargs,
|
| 958 |
+
) -> torch.Tensor:
|
| 959 |
+
if inputs_embeds is not None:
|
| 960 |
+
hidden_states = inputs_embeds
|
| 961 |
+
else:
|
| 962 |
+
hidden_states = self.get_input_embeddings(input_ids)
|
| 963 |
+
residual = None
|
| 964 |
+
|
| 965 |
+
replace_allreduce = hidden_states.shape[0] % self.tp_size == 0
|
| 966 |
+
|
| 967 |
+
for i in range(self.start_layer, self.end_layer):
|
| 968 |
+
layer = self.layers[i]
|
| 969 |
+
hidden_states, residual = layer(
|
| 970 |
+
positions,
|
| 971 |
+
hidden_states,
|
| 972 |
+
residual,
|
| 973 |
+
kv_caches[i -
|
| 974 |
+
self.start_layer] if kv_caches is not None else None,
|
| 975 |
+
attn_metadata,
|
| 976 |
+
replace_allreduce=replace_allreduce)
|
| 977 |
+
|
| 978 |
+
hidden_states, _ = self.norm(hidden_states, residual)
|
| 979 |
+
return hidden_states
|
| 980 |
+
|
| 981 |
+
|
| 982 |
+
class OpenPanguForCausalLM(nn.Module):
|
| 983 |
+
packed_modules_mapping = {
|
| 984 |
+
"gate_up_proj": ["gate_proj", "up_proj"],
|
| 985 |
+
"experts":
|
| 986 |
+
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"]
|
| 987 |
+
}
|
| 988 |
+
|
| 989 |
+
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
| 990 |
+
super().__init__()
|
| 991 |
+
config = vllm_config.model_config.hf_config
|
| 992 |
+
quant_config = vllm_config.quant_config
|
| 993 |
+
self.config = config
|
| 994 |
+
self.quant_config = quant_config
|
| 995 |
+
self.model = OpenPanguModel(vllm_config=vllm_config,
|
| 996 |
+
prefix=maybe_prefix(prefix, "model"))
|
| 997 |
+
self.lm_head = ParallelLMHead(config.vocab_size,
|
| 998 |
+
config.hidden_size,
|
| 999 |
+
quant_config=quant_config,
|
| 1000 |
+
prefix=maybe_prefix(prefix, "lm_head"))
|
| 1001 |
+
self.logits_processor = LogitsProcessor(config.vocab_size)
|
| 1002 |
+
self.sampler = get_sampler()
|
| 1003 |
+
|
| 1004 |
+
def load_attn_mlp_weight(self,
|
| 1005 |
+
attn_mlp_replace_mapping: List[Tuple[str, str, int]],
|
| 1006 |
+
params_dict: Dict[str, Any],
|
| 1007 |
+
weight_name: str,
|
| 1008 |
+
loaded_weight: torch.Tensor,
|
| 1009 |
+
loaded_params: set[str]) -> bool:
|
| 1010 |
+
for (param_name, origin_name, shard_id) in attn_mlp_replace_mapping:
|
| 1011 |
+
if origin_name not in weight_name or \
|
| 1012 |
+
(("mlp.experts." in weight_name) and weight_name not in params_dict):
|
| 1013 |
+
continue
|
| 1014 |
+
weight_name = weight_name.replace(origin_name, param_name)
|
| 1015 |
+
if weight_name.endswith(".bias") and weight_name not in params_dict:
|
| 1016 |
+
continue
|
| 1017 |
+
param = params_dict[weight_name]
|
| 1018 |
+
weight_loader = param.weight_loader
|
| 1019 |
+
weight_loader(param, loaded_weight, shard_id)
|
| 1020 |
+
loaded_params.add(weight_name)
|
| 1021 |
+
return True
|
| 1022 |
+
return False
|
| 1023 |
+
|
| 1024 |
+
def load_expert_weight(self,
|
| 1025 |
+
expert_merge_mapping: List[Tuple[str, str, int, str]],
|
| 1026 |
+
params_dict: Dict[str, Any],
|
| 1027 |
+
weight_name: str,
|
| 1028 |
+
loaded_weight: torch.Tensor,
|
| 1029 |
+
loaded_params: set[str]) -> bool:
|
| 1030 |
+
for mapping in expert_merge_mapping:
|
| 1031 |
+
param_name, origin_name, expert_id, shard_id = mapping
|
| 1032 |
+
if origin_name not in weight_name:
|
| 1033 |
+
continue
|
| 1034 |
+
weight_name = weight_name.replace(origin_name, param_name)
|
| 1035 |
+
param = params_dict[weight_name]
|
| 1036 |
+
weight_loader = param.weight_loader
|
| 1037 |
+
weight_loader(param,
|
| 1038 |
+
loaded_weight,
|
| 1039 |
+
weight_name,
|
| 1040 |
+
shard_id=shard_id,
|
| 1041 |
+
expert_id=expert_id,
|
| 1042 |
+
return_success=False)
|
| 1043 |
+
loaded_params.add(weight_name)
|
| 1044 |
+
return True
|
| 1045 |
+
return False
|
| 1046 |
+
|
| 1047 |
+
def load_weights(self, weights: Iterable[tuple[str,
|
| 1048 |
+
torch.Tensor]]) -> set[str]:
|
| 1049 |
+
# (param_name, shard_name, shard_id)
|
| 1050 |
+
attn_mlp_replace_mapping = [
|
| 1051 |
+
(".qkv_proj", ".q_proj", "q"),
|
| 1052 |
+
(".qkv_proj", ".k_proj", "k"),
|
| 1053 |
+
(".qkv_proj", ".v_proj", "v"),
|
| 1054 |
+
(".gate_up_proj", ".gate_proj", 0),
|
| 1055 |
+
(".gate_up_proj", ".up_proj", 1),
|
| 1056 |
+
]
|
| 1057 |
+
has_experts = hasattr(self.config, 'num_routed_experts')
|
| 1058 |
+
if has_experts:
|
| 1059 |
+
expert_merge_mapping = AscendFusedMoE.make_expert_params_mapping(
|
| 1060 |
+
ckpt_gate_proj_name="gate_proj",
|
| 1061 |
+
ckpt_down_proj_name="down_proj",
|
| 1062 |
+
ckpt_up_proj_name="up_proj",
|
| 1063 |
+
num_experts=self.config.num_routed_experts)
|
| 1064 |
+
|
| 1065 |
+
params_dict = dict(self.named_parameters())
|
| 1066 |
+
loaded_params: set[str] = set()
|
| 1067 |
+
for name, loaded_weight in weights:
|
| 1068 |
+
if "rotary_emb.inv_freq" in name:
|
| 1069 |
+
continue
|
| 1070 |
+
if 'layers' in name: # skip spec decode layers for main model
|
| 1071 |
+
layer_idx = int(name.split('layers.')[-1].split('.')[0])
|
| 1072 |
+
if layer_idx > self.config.num_hidden_layers:
|
| 1073 |
+
continue
|
| 1074 |
+
|
| 1075 |
+
if 'layers' in name and hasattr(self.config, "num_mtp_layers") \
|
| 1076 |
+
and (self.config.num_mtp_layers > 0):
|
| 1077 |
+
layer_idx = int(name.split('layers.')[-1].split('.')[0])
|
| 1078 |
+
mtp_idx = layer_idx - self.config.num_hidden_layers
|
| 1079 |
+
if mtp_idx >= 0 and mtp_idx < self.config.num_mtp_layers:
|
| 1080 |
+
continue # skip spec decode layers for main model
|
| 1081 |
+
if self.load_attn_mlp_weight(attn_mlp_replace_mapping, params_dict, name, loaded_weight, loaded_params):
|
| 1082 |
+
continue
|
| 1083 |
+
elif has_experts and self.load_expert_weight(expert_merge_mapping, params_dict, name, loaded_weight, loaded_params):
|
| 1084 |
+
continue
|
| 1085 |
+
else:
|
| 1086 |
+
if name.endswith(".bias") and name not in params_dict:
|
| 1087 |
+
continue
|
| 1088 |
+
name = maybe_remap_kv_scale_name(name, params_dict)
|
| 1089 |
+
if name is None:
|
| 1090 |
+
continue
|
| 1091 |
+
param = params_dict[name]
|
| 1092 |
+
weight_loader = getattr(param, "weight_loader",
|
| 1093 |
+
default_weight_loader)
|
| 1094 |
+
weight_loader(param, loaded_weight)
|
| 1095 |
+
loaded_params.add(name)
|
| 1096 |
+
if self.config.tie_word_embeddings:
|
| 1097 |
+
self.lm_head.weight = self.model.embed_tokens.weight
|
| 1098 |
+
return loaded_params
|
| 1099 |
+
|
| 1100 |
+
def forward(
|
| 1101 |
+
self,
|
| 1102 |
+
input_ids: torch.Tensor,
|
| 1103 |
+
positions: torch.Tensor,
|
| 1104 |
+
kv_caches: Optional[List[torch.Tensor]] = None,
|
| 1105 |
+
attn_metadata: Optional[AttentionMetadata] = None,
|
| 1106 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
| 1107 |
+
**kwargs,
|
| 1108 |
+
) -> torch.Tensor:
|
| 1109 |
+
hidden_states = self.model(input_ids, positions, kv_caches,
|
| 1110 |
+
attn_metadata, inputs_embeds)
|
| 1111 |
+
return hidden_states
|
| 1112 |
+
|
| 1113 |
+
def compute_logits(
|
| 1114 |
+
self,
|
| 1115 |
+
hidden_states: torch.Tensor,
|
| 1116 |
+
sampling_metadata: SamplingMetadata,
|
| 1117 |
+
) -> Optional[torch.Tensor]:
|
| 1118 |
+
logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata)
|
| 1119 |
+
return logits
|
| 1120 |
+
|
| 1121 |
+
|
| 1122 |
+
class PanguUltraMoEForCausalLM(OpenPanguForCausalLM):
|
| 1123 |
+
pass
|
| 1124 |
+
|
| 1125 |
+
|
| 1126 |
+
class PanguEmbeddedForCausalLM(OpenPanguForCausalLM):
|
| 1127 |
+
pass
|
inference/vllm_ascend/ops/fused_moe.py
ADDED
|
@@ -0,0 +1,1530 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
| 2 |
+
# Copyright 2023 The vLLM team.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
# This file is a part of the vllm-ascend project.
|
| 16 |
+
# Adapted from vllm/tests/kernels/test_moe.py
|
| 17 |
+
|
| 18 |
+
import os
|
| 19 |
+
from typing import Any, Callable, List, Optional, Tuple, Union
|
| 20 |
+
|
| 21 |
+
import torch
|
| 22 |
+
import torch.distributed as dist
|
| 23 |
+
import torch_npu
|
| 24 |
+
from torch import nn
|
| 25 |
+
from vllm.config import get_current_vllm_config
|
| 26 |
+
from vllm.distributed import (GroupCoordinator, get_tensor_model_parallel_rank,
|
| 27 |
+
get_tensor_model_parallel_world_size,
|
| 28 |
+
tensor_model_parallel_all_reduce)
|
| 29 |
+
from vllm.distributed.parallel_state import get_dp_group, get_tp_group
|
| 30 |
+
from vllm.forward_context import get_forward_context
|
| 31 |
+
from vllm.model_executor.layers.fused_moe.config import \
|
| 32 |
+
FusedMoEConfig # isort: skip
|
| 33 |
+
from vllm.model_executor.layers.fused_moe.config import \
|
| 34 |
+
FusedMoEParallelConfig # isort: skip
|
| 35 |
+
from vllm.model_executor.layers.fused_moe.layer import (
|
| 36 |
+
FusedMoE, UnquantizedFusedMoEMethod, determine_expert_map)
|
| 37 |
+
from vllm.model_executor.layers.quantization.base_config import \
|
| 38 |
+
QuantizationConfig
|
| 39 |
+
|
| 40 |
+
import vllm_ascend.envs as envs_ascend
|
| 41 |
+
from vllm_ascend.ascend_config import get_ascend_config
|
| 42 |
+
from vllm_ascend.distributed.communication_op import \
|
| 43 |
+
data_parallel_reduce_scatter
|
| 44 |
+
from vllm_ascend.distributed.parallel_state import get_ep_group, get_etp_group
|
| 45 |
+
from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
|
| 46 |
+
from vllm_ascend.utils import (FusedMoEState, dispose_tensor,
|
| 47 |
+
get_all_reduce_merge_state, get_fused_moe_state,
|
| 48 |
+
get_rm_router_logits_state, is_310p,
|
| 49 |
+
npu_stream_switch, npu_wait_tensor)
|
| 50 |
+
|
| 51 |
+
MOE_ALL2ALL_BUFFER: bool = envs_ascend.MOE_ALL2ALL_BUFFER
|
| 52 |
+
SELECT_GATING_TOPK_SOTFMAX_EXPERTS: bool = envs_ascend.SELECT_GATING_TOPK_SOTFMAX_EXPERTS
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def process_topk_ids(topk_ids: torch.Tensor, expert_num: int, ep_size: int,
|
| 56 |
+
max_row_per_ep_rank: int, num_tokens: int,
|
| 57 |
+
top_k: int) -> tuple[torch.Tensor, torch.Tensor]:
|
| 58 |
+
original_total_elements = num_tokens * top_k
|
| 59 |
+
device = topk_ids.device
|
| 60 |
+
original_dtype = topk_ids.dtype
|
| 61 |
+
|
| 62 |
+
if original_total_elements == 0:
|
| 63 |
+
output_len = ep_size * max_row_per_ep_rank
|
| 64 |
+
topk_ids_pad = torch.full((output_len, ),
|
| 65 |
+
expert_num,
|
| 66 |
+
dtype=original_dtype,
|
| 67 |
+
device=device)
|
| 68 |
+
unpad_indices = torch.full((original_total_elements, ),
|
| 69 |
+
-1,
|
| 70 |
+
dtype=torch.long,
|
| 71 |
+
device=device)
|
| 72 |
+
return topk_ids_pad, unpad_indices
|
| 73 |
+
|
| 74 |
+
experts_per_ep_rank_val = expert_num // ep_size
|
| 75 |
+
if experts_per_ep_rank_val == 0:
|
| 76 |
+
raise ValueError(
|
| 77 |
+
"expert_num // ep_size is 0, which leads to division by zero in ep_rank calculation. "
|
| 78 |
+
"Ensure expert_num >= ep_size.")
|
| 79 |
+
|
| 80 |
+
assigned_ep_rank = (topk_ids.float() /
|
| 81 |
+
experts_per_ep_rank_val).to(original_dtype)
|
| 82 |
+
indices_arange = torch.arange(topk_ids.shape[0], device=device)
|
| 83 |
+
|
| 84 |
+
is_new_segment = torch.cat(
|
| 85 |
+
(torch.tensor([True], device=device), assigned_ep_rank[1:]
|
| 86 |
+
!= assigned_ep_rank[:-1]))
|
| 87 |
+
temp_start_markers = torch.full_like(indices_arange,
|
| 88 |
+
-1,
|
| 89 |
+
dtype=indices_arange.dtype)
|
| 90 |
+
temp_start_markers[is_new_segment] = indices_arange[is_new_segment]
|
| 91 |
+
start_offset_for_each_token = torch.cummax(temp_start_markers, dim=0)[0]
|
| 92 |
+
token_intra_ep_rank_idx = indices_arange - start_offset_for_each_token
|
| 93 |
+
is_kept_mask = token_intra_ep_rank_idx < max_row_per_ep_rank
|
| 94 |
+
cumsum_kept = torch.cumsum(is_kept_mask.float(), dim=0).to(torch.long)
|
| 95 |
+
indices_in_rec_cond_list_for_all = cumsum_kept - 1
|
| 96 |
+
unpad_indices = torch.where(
|
| 97 |
+
is_kept_mask, indices_in_rec_cond_list_for_all,
|
| 98 |
+
torch.tensor(-1, device=device, dtype=torch.long))
|
| 99 |
+
output_len = ep_size * max_row_per_ep_rank
|
| 100 |
+
topk_ids_pad = torch.full((output_len, ),
|
| 101 |
+
expert_num,
|
| 102 |
+
dtype=original_dtype,
|
| 103 |
+
device=device)
|
| 104 |
+
if topk_ids.shape[0] > 0:
|
| 105 |
+
all_destination_indices = assigned_ep_rank * max_row_per_ep_rank + token_intra_ep_rank_idx
|
| 106 |
+
temp_pad_buffer = torch.full((output_len + 1, ),
|
| 107 |
+
expert_num,
|
| 108 |
+
dtype=original_dtype,
|
| 109 |
+
device=device)
|
| 110 |
+
output_len_tensor = torch.tensor(output_len,
|
| 111 |
+
dtype=torch.long,
|
| 112 |
+
device=device)
|
| 113 |
+
scatter_indices = torch.where(is_kept_mask, all_destination_indices,
|
| 114 |
+
output_len_tensor)
|
| 115 |
+
temp_pad_buffer.scatter_(0, scatter_indices, topk_ids)
|
| 116 |
+
topk_ids_pad = temp_pad_buffer[:output_len]
|
| 117 |
+
return topk_ids_pad, unpad_indices
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def fused_experts_with_mc2(
|
| 121 |
+
hidden_states: torch.Tensor,
|
| 122 |
+
w1: torch.Tensor,
|
| 123 |
+
w2: torch.Tensor,
|
| 124 |
+
topk_weights: torch.Tensor,
|
| 125 |
+
topk_ids: torch.Tensor,
|
| 126 |
+
top_k: int,
|
| 127 |
+
expert_map: torch.Tensor = None,
|
| 128 |
+
moe_all_to_all_group_name: Optional[str] = None,
|
| 129 |
+
shared_experts: Optional[Any] = None
|
| 130 |
+
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
| 131 |
+
global_bs = 0
|
| 132 |
+
moe_expert_num = len(expert_map)
|
| 133 |
+
kwargs_mc2 = {
|
| 134 |
+
"x": hidden_states,
|
| 135 |
+
"expert_ids": topk_ids,
|
| 136 |
+
"expert_shard_type": 0,
|
| 137 |
+
"shared_expert_rank_num": 0,
|
| 138 |
+
"moe_expert_num": moe_expert_num,
|
| 139 |
+
"global_bs": global_bs,
|
| 140 |
+
}
|
| 141 |
+
|
| 142 |
+
rank = torch.distributed.get_rank()
|
| 143 |
+
|
| 144 |
+
quant_mode = 0
|
| 145 |
+
ep_group = get_ep_group().device_group
|
| 146 |
+
local_rank = torch.distributed.get_rank(group=ep_group)
|
| 147 |
+
all_to_all_group_size = torch.distributed.get_world_size(ep_group)
|
| 148 |
+
|
| 149 |
+
tp_size = get_etp_group().world_size
|
| 150 |
+
tp_rank = rank % tp_size
|
| 151 |
+
|
| 152 |
+
stage1_kwargs = {
|
| 153 |
+
"scales": None,
|
| 154 |
+
"quant_mode": quant_mode,
|
| 155 |
+
"group_ep": moe_all_to_all_group_name,
|
| 156 |
+
"ep_world_size": all_to_all_group_size,
|
| 157 |
+
"ep_rank_id": local_rank,
|
| 158 |
+
# "group_tp": self.moe_rs_group_name,
|
| 159 |
+
"group_tp": moe_all_to_all_group_name,
|
| 160 |
+
"tp_world_size": tp_size,
|
| 161 |
+
"tp_rank_id": tp_rank,
|
| 162 |
+
}
|
| 163 |
+
kwargs_mc2.update(stage1_kwargs)
|
| 164 |
+
|
| 165 |
+
output = torch_npu.npu_moe_distribute_dispatch(**kwargs_mc2)
|
| 166 |
+
expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts = output[
|
| 167 |
+
0:5]
|
| 168 |
+
|
| 169 |
+
if shared_experts is not None:
|
| 170 |
+
with npu_stream_switch("moe_secondary", 0):
|
| 171 |
+
npu_wait_tensor(hidden_states, topk_weights)
|
| 172 |
+
shared_gate_up, _ = shared_experts.gate_up_proj(hidden_states)
|
| 173 |
+
npu_wait_tensor(shared_gate_up, expand_x)
|
| 174 |
+
shared_act = shared_experts.act_fn(shared_gate_up)
|
| 175 |
+
|
| 176 |
+
w1 = w1.transpose(1, 2)
|
| 177 |
+
|
| 178 |
+
group_list = expert_token_nums.to(torch.int64)
|
| 179 |
+
gate_up_out_list = torch_npu.npu_grouped_matmul(
|
| 180 |
+
x=[expand_x],
|
| 181 |
+
weight=[w1],
|
| 182 |
+
split_item=2,
|
| 183 |
+
# 1 means count mode, to avoid cumulative operation of the group list
|
| 184 |
+
group_list_type=1,
|
| 185 |
+
group_type=0,
|
| 186 |
+
group_list=group_list,
|
| 187 |
+
)
|
| 188 |
+
|
| 189 |
+
# TODO: Remove this in the future.
|
| 190 |
+
gate_up_out = torch.cat(gate_up_out_list, dim=0)
|
| 191 |
+
gate_up_out = torch_npu.npu_swiglu(gate_up_out)
|
| 192 |
+
|
| 193 |
+
w2 = w2.transpose(1, 2)
|
| 194 |
+
down_out_list = torch_npu.npu_grouped_matmul(
|
| 195 |
+
x=[gate_up_out],
|
| 196 |
+
weight=[w2],
|
| 197 |
+
split_item=2,
|
| 198 |
+
group_list_type=1,
|
| 199 |
+
group_type=0,
|
| 200 |
+
group_list=group_list,
|
| 201 |
+
)
|
| 202 |
+
|
| 203 |
+
down_out_list = torch.cat(down_out_list, dim=0)
|
| 204 |
+
|
| 205 |
+
# moeCombine
|
| 206 |
+
kwargs_mc2 = {
|
| 207 |
+
"expand_x": down_out_list,
|
| 208 |
+
"expert_ids": topk_ids,
|
| 209 |
+
"expand_idx": expand_idx,
|
| 210 |
+
"expert_scales": topk_weights.to(torch.float32),
|
| 211 |
+
"expert_shard_type": 0,
|
| 212 |
+
"shared_expert_rank_num": 0,
|
| 213 |
+
"moe_expert_num": moe_expert_num,
|
| 214 |
+
"global_bs": 0,
|
| 215 |
+
}
|
| 216 |
+
tp_recv_counts = output[5]
|
| 217 |
+
stage3_kwargs = {
|
| 218 |
+
"ep_send_counts": ep_recv_counts,
|
| 219 |
+
"group_ep": moe_all_to_all_group_name,
|
| 220 |
+
"ep_world_size": all_to_all_group_size,
|
| 221 |
+
"ep_rank_id": local_rank,
|
| 222 |
+
"tp_send_counts": tp_recv_counts,
|
| 223 |
+
# "group_tp": self.moe_rs_group_name,
|
| 224 |
+
"group_tp": moe_all_to_all_group_name,
|
| 225 |
+
"tp_world_size": tp_size,
|
| 226 |
+
"tp_rank_id": tp_rank,
|
| 227 |
+
}
|
| 228 |
+
kwargs_mc2.update(stage3_kwargs)
|
| 229 |
+
|
| 230 |
+
hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs_mc2)
|
| 231 |
+
|
| 232 |
+
if shared_experts is None:
|
| 233 |
+
return hidden_states
|
| 234 |
+
else:
|
| 235 |
+
with npu_stream_switch("moe_secondary", 0):
|
| 236 |
+
npu_wait_tensor(shared_act, down_out_list)
|
| 237 |
+
shared_hidden_states, _ = shared_experts.down_proj(shared_act)
|
| 238 |
+
return hidden_states, shared_hidden_states
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
def apply_mlp(hidden_states_wrapper: List[torch.Tensor],
|
| 242 |
+
w1: torch.Tensor,
|
| 243 |
+
w2: torch.Tensor,
|
| 244 |
+
group_list: torch.Tensor,
|
| 245 |
+
group_list_type: int = 1) -> torch.Tensor:
|
| 246 |
+
"""
|
| 247 |
+
apply MLP: gate_up_proj -> swiglu -> down_proj
|
| 248 |
+
|
| 249 |
+
Args:
|
| 250 |
+
hidden_states_wrapper: wrapper of input hidden states with shape (num_tokens, hidden_size).
|
| 251 |
+
w1: expert weights1 with shape
|
| 252 |
+
(num_experts, hidden_size, intermediate_size * 2)
|
| 253 |
+
w2: expert weights2 with shape
|
| 254 |
+
(num_experts, intermediate_size, hidden_size)
|
| 255 |
+
group_list: number of tokens for each expert, follow cumsum mode, and
|
| 256 |
+
with shape (num_experts).
|
| 257 |
+
transpose_weight:
|
| 258 |
+
w1: (num_experts, intermediate_size * 2, hidden_size) ->
|
| 259 |
+
(num_experts, hidden_size, intermediate_size * 2)
|
| 260 |
+
w2: (num_experts, hidden_size, intermediate_size) ->
|
| 261 |
+
(num_experts, intermediate_size, hidden_size)
|
| 262 |
+
|
| 263 |
+
Returns:
|
| 264 |
+
hidden_states: output hidden states after MLP.
|
| 265 |
+
"""
|
| 266 |
+
|
| 267 |
+
assert len(hidden_states_wrapper) == 1
|
| 268 |
+
hidden_states = hidden_states_wrapper.pop()
|
| 269 |
+
|
| 270 |
+
w1 = w1.transpose(1, 2)
|
| 271 |
+
hidden_states = torch_npu.npu_grouped_matmul(
|
| 272 |
+
x=[hidden_states],
|
| 273 |
+
weight=[w1],
|
| 274 |
+
split_item=2,
|
| 275 |
+
group_list_type=group_list_type,
|
| 276 |
+
group_type=0,
|
| 277 |
+
group_list=group_list,
|
| 278 |
+
)
|
| 279 |
+
|
| 280 |
+
hidden_states = torch.cat(hidden_states, dim=0)
|
| 281 |
+
hidden_states = torch_npu.npu_swiglu(hidden_states)
|
| 282 |
+
|
| 283 |
+
w2 = w2.transpose(1, 2)
|
| 284 |
+
hidden_states = torch_npu.npu_grouped_matmul(
|
| 285 |
+
x=[hidden_states],
|
| 286 |
+
weight=[w2],
|
| 287 |
+
split_item=2,
|
| 288 |
+
group_list_type=group_list_type,
|
| 289 |
+
group_type=0,
|
| 290 |
+
group_list=group_list,
|
| 291 |
+
)
|
| 292 |
+
|
| 293 |
+
hidden_states = torch.cat(hidden_states, dim=0)
|
| 294 |
+
return hidden_states
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
def fused_experts_with_all2all(
|
| 298 |
+
hidden_states: torch.Tensor,
|
| 299 |
+
w1: torch.Tensor,
|
| 300 |
+
w2: torch.Tensor,
|
| 301 |
+
topk_weights: torch.Tensor,
|
| 302 |
+
topk_ids: torch.Tensor,
|
| 303 |
+
top_k: int,
|
| 304 |
+
expert_map: torch.Tensor = None,
|
| 305 |
+
ep_group: GroupCoordinator = None,
|
| 306 |
+
):
|
| 307 |
+
original_shape = hidden_states.shape
|
| 308 |
+
if len(original_shape) == 3:
|
| 309 |
+
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
|
| 310 |
+
|
| 311 |
+
num_tokens, _ = hidden_states.shape
|
| 312 |
+
num_experts = w1.shape[0]
|
| 313 |
+
device = hidden_states.device
|
| 314 |
+
|
| 315 |
+
if expert_map is not None:
|
| 316 |
+
global_num_experts = len(expert_map)
|
| 317 |
+
local_num_experts = global_num_experts // ep_group.world_size
|
| 318 |
+
row_idx_len = num_tokens * top_k
|
| 319 |
+
row_idx = (torch.arange(0,
|
| 320 |
+
row_idx_len,
|
| 321 |
+
dtype=torch.int32,
|
| 322 |
+
device=device).view(top_k, -1).permute(
|
| 323 |
+
1, 0).contiguous())
|
| 324 |
+
hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
|
| 325 |
+
hidden_states,
|
| 326 |
+
row_idx=row_idx,
|
| 327 |
+
expert_idx=topk_ids,
|
| 328 |
+
active_num=num_tokens)
|
| 329 |
+
|
| 330 |
+
global_expert_tokens = torch.bincount(expanded_expert_idx,
|
| 331 |
+
minlength=global_num_experts)
|
| 332 |
+
scatter_sizes = global_expert_tokens.view(ep_group.world_size,
|
| 333 |
+
-1).sum(-1)
|
| 334 |
+
|
| 335 |
+
gather_sizes = torch.empty_like(scatter_sizes)
|
| 336 |
+
dist.all_to_all_single(gather_sizes,
|
| 337 |
+
scatter_sizes,
|
| 338 |
+
group=ep_group.device_group)
|
| 339 |
+
scatter_size_list = scatter_sizes.cpu().tolist()
|
| 340 |
+
gather_size_list = gather_sizes.cpu().tolist()
|
| 341 |
+
|
| 342 |
+
expanded_expert_idx = expanded_expert_idx % local_num_experts
|
| 343 |
+
hidden_states = ep_group.all_to_all(hidden_states, 0, 0,
|
| 344 |
+
scatter_size_list,
|
| 345 |
+
gather_size_list)
|
| 346 |
+
local_expert_idx = ep_group.all_to_all(expanded_expert_idx, 0, 0,
|
| 347 |
+
scatter_size_list,
|
| 348 |
+
gather_size_list)
|
| 349 |
+
|
| 350 |
+
sorted_local_expert_idx, sorted_idx = torch.sort(local_expert_idx)
|
| 351 |
+
|
| 352 |
+
expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
|
| 353 |
+
sorted_local_expert_idx, local_num_experts).to(torch.int64)
|
| 354 |
+
|
| 355 |
+
hidden_states = hidden_states[sorted_idx]
|
| 356 |
+
else:
|
| 357 |
+
row_idx_len = num_tokens * top_k
|
| 358 |
+
row_idx = torch.arange(0,
|
| 359 |
+
row_idx_len,
|
| 360 |
+
dtype=torch.int32,
|
| 361 |
+
device=topk_weights.device).view(
|
| 362 |
+
top_k, -1).permute(1, 0).contiguous()
|
| 363 |
+
hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
|
| 364 |
+
hidden_states,
|
| 365 |
+
row_idx=row_idx,
|
| 366 |
+
expert_idx=topk_ids,
|
| 367 |
+
active_num=num_tokens)
|
| 368 |
+
|
| 369 |
+
expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
|
| 370 |
+
expanded_expert_idx, num_experts)
|
| 371 |
+
expert_tokens = expert_tokens.to(torch.int64)
|
| 372 |
+
|
| 373 |
+
w1 = w1.transpose(1, 2)
|
| 374 |
+
gate_up_out_list = torch_npu.npu_grouped_matmul(
|
| 375 |
+
x=[hidden_states],
|
| 376 |
+
weight=[w1],
|
| 377 |
+
split_item=2,
|
| 378 |
+
group_list_type=0,
|
| 379 |
+
group_type=0,
|
| 380 |
+
group_list=expert_tokens,
|
| 381 |
+
)
|
| 382 |
+
|
| 383 |
+
# TODO: Remove this in the future.
|
| 384 |
+
hidden_states = torch.cat(gate_up_out_list, dim=0)
|
| 385 |
+
hidden_states = torch_npu.npu_swiglu(hidden_states)
|
| 386 |
+
|
| 387 |
+
w2 = w2.transpose(1, 2)
|
| 388 |
+
down_out_list = torch_npu.npu_grouped_matmul(
|
| 389 |
+
x=[hidden_states],
|
| 390 |
+
weight=[w2],
|
| 391 |
+
split_item=2,
|
| 392 |
+
group_list_type=0,
|
| 393 |
+
group_type=0,
|
| 394 |
+
group_list=expert_tokens,
|
| 395 |
+
)
|
| 396 |
+
|
| 397 |
+
hidden_states = torch.cat(down_out_list, dim=0)
|
| 398 |
+
|
| 399 |
+
if expert_map is not None:
|
| 400 |
+
resorted_idx = torch.argsort(sorted_idx)
|
| 401 |
+
hidden_states = hidden_states[resorted_idx]
|
| 402 |
+
hidden_states = ep_group.all_to_all(hidden_states, 0, 0,
|
| 403 |
+
gather_size_list,
|
| 404 |
+
scatter_size_list)
|
| 405 |
+
|
| 406 |
+
final_hidden_states = torch_npu.npu_moe_finalize_routing(
|
| 407 |
+
hidden_states,
|
| 408 |
+
skip1=None,
|
| 409 |
+
skip2=None,
|
| 410 |
+
bias=None,
|
| 411 |
+
scales=topk_weights,
|
| 412 |
+
expanded_src_to_dst_row=expanded_row_idx,
|
| 413 |
+
export_for_source_row=topk_ids,
|
| 414 |
+
)
|
| 415 |
+
else:
|
| 416 |
+
# TODO: Reorder device memory 2 times here, replace the current
|
| 417 |
+
# implementation here when suitable operators become available.
|
| 418 |
+
final_hidden_states = torch_npu.npu_moe_finalize_routing(
|
| 419 |
+
hidden_states,
|
| 420 |
+
skip1=None,
|
| 421 |
+
skip2=None,
|
| 422 |
+
bias=None,
|
| 423 |
+
scales=topk_weights,
|
| 424 |
+
expanded_src_to_dst_row=expanded_row_idx,
|
| 425 |
+
export_for_source_row=topk_ids,
|
| 426 |
+
)
|
| 427 |
+
if len(original_shape) == 3:
|
| 428 |
+
final_hidden_states = final_hidden_states.view(original_shape)
|
| 429 |
+
return final_hidden_states
|
| 430 |
+
|
| 431 |
+
|
| 432 |
+
# currently expert parallelism implemented with all2all
|
| 433 |
+
# is under-optimized.
|
| 434 |
+
def fused_experts_with_all2all_buffer(
|
| 435 |
+
hidden_states: torch.Tensor,
|
| 436 |
+
w1: torch.Tensor,
|
| 437 |
+
w2: torch.Tensor,
|
| 438 |
+
topk_weights: torch.Tensor,
|
| 439 |
+
topk_ids: torch.Tensor,
|
| 440 |
+
top_k: int,
|
| 441 |
+
max_model_len: int,
|
| 442 |
+
global_batch_size: int,
|
| 443 |
+
expert_map: torch.Tensor = None,
|
| 444 |
+
ep_group: GroupCoordinator = None,
|
| 445 |
+
):
|
| 446 |
+
original_shape = hidden_states.shape
|
| 447 |
+
if len(original_shape) == 3:
|
| 448 |
+
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
|
| 449 |
+
|
| 450 |
+
num_tokens, _ = hidden_states.shape
|
| 451 |
+
device = hidden_states.device
|
| 452 |
+
|
| 453 |
+
global_num_experts = len(expert_map)
|
| 454 |
+
local_num_experts = global_num_experts // ep_group.world_size
|
| 455 |
+
row_idx_len = num_tokens * top_k
|
| 456 |
+
row_idx = (torch.arange(0, row_idx_len, dtype=torch.int32,
|
| 457 |
+
device=device).view(top_k,
|
| 458 |
+
-1).permute(1, 0).contiguous())
|
| 459 |
+
hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
|
| 460 |
+
hidden_states,
|
| 461 |
+
row_idx=row_idx,
|
| 462 |
+
expert_idx=topk_ids,
|
| 463 |
+
active_num=num_tokens)
|
| 464 |
+
|
| 465 |
+
max_row_per_ep_rank = (-(-global_batch_size // ep_group.world_size) *
|
| 466 |
+
max_model_len // ep_group.world_size +
|
| 467 |
+
1) * top_k * 2
|
| 468 |
+
expert_idx_buffer_scatter, unpad_indices = process_topk_ids(
|
| 469 |
+
expanded_expert_idx, global_num_experts, ep_group.world_size,
|
| 470 |
+
max_row_per_ep_rank, num_tokens, top_k)
|
| 471 |
+
hidden_states_pad_idx = torch.zeros(
|
| 472 |
+
expert_idx_buffer_scatter.shape,
|
| 473 |
+
dtype=expert_idx_buffer_scatter.dtype,
|
| 474 |
+
device=expert_idx_buffer_scatter.device)
|
| 475 |
+
non_pad_len = torch.sum((expert_idx_buffer_scatter
|
| 476 |
+
!= global_num_experts).to(torch.int32))
|
| 477 |
+
hidden_states_pad_idx[expert_idx_buffer_scatter !=
|
| 478 |
+
global_num_experts] = torch.arange(
|
| 479 |
+
non_pad_len,
|
| 480 |
+
dtype=expert_idx_buffer_scatter.dtype,
|
| 481 |
+
device=hidden_states.device)
|
| 482 |
+
|
| 483 |
+
hidden_states_buffer_scatter = hidden_states[hidden_states_pad_idx]
|
| 484 |
+
expert_idx_buffer_gather = torch.empty_like(
|
| 485 |
+
expert_idx_buffer_scatter,
|
| 486 |
+
dtype=expert_idx_buffer_scatter.dtype,
|
| 487 |
+
device=expert_idx_buffer_scatter.device)
|
| 488 |
+
hidden_states_buffer_gather = torch.empty_like(
|
| 489 |
+
hidden_states_buffer_scatter,
|
| 490 |
+
dtype=hidden_states_buffer_scatter.dtype,
|
| 491 |
+
device=hidden_states_buffer_scatter.device)
|
| 492 |
+
dist.all_to_all_single(expert_idx_buffer_gather,
|
| 493 |
+
expert_idx_buffer_scatter,
|
| 494 |
+
group=ep_group.device_group)
|
| 495 |
+
dist.all_to_all_single(hidden_states_buffer_gather,
|
| 496 |
+
hidden_states_buffer_scatter,
|
| 497 |
+
group=ep_group.device_group)
|
| 498 |
+
mask = expert_idx_buffer_gather != global_num_experts
|
| 499 |
+
local_expert_idx = expert_idx_buffer_gather[mask] - ep_group.rank * (
|
| 500 |
+
global_num_experts // ep_group.world_size)
|
| 501 |
+
hidden_states = hidden_states_buffer_gather[mask]
|
| 502 |
+
idx_type = local_expert_idx.dtype
|
| 503 |
+
sorted_local_expert_idx, sorted_idx = torch.sort(local_expert_idx.float())
|
| 504 |
+
sorted_local_expert_idx = sorted_local_expert_idx.to(idx_type)
|
| 505 |
+
|
| 506 |
+
expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
|
| 507 |
+
sorted_local_expert_idx, local_num_experts).to(torch.int64)
|
| 508 |
+
hidden_states = hidden_states[sorted_idx]
|
| 509 |
+
group_list_type = 0
|
| 510 |
+
|
| 511 |
+
hidden_states_wrapper = [hidden_states]
|
| 512 |
+
del hidden_states
|
| 513 |
+
|
| 514 |
+
hidden_states = apply_mlp(hidden_states_wrapper,
|
| 515 |
+
w1,
|
| 516 |
+
w2,
|
| 517 |
+
expert_tokens,
|
| 518 |
+
group_list_type=group_list_type)
|
| 519 |
+
|
| 520 |
+
resorted_idx = torch.argsort(sorted_idx.float()).to(sorted_idx.dtype)
|
| 521 |
+
hidden_states = hidden_states[resorted_idx]
|
| 522 |
+
hidden_states_scatter = torch.zeros(
|
| 523 |
+
(mask.shape[0], hidden_states.shape[1]),
|
| 524 |
+
dtype=hidden_states.dtype,
|
| 525 |
+
device=hidden_states.device)
|
| 526 |
+
hidden_states_scatter[mask] = hidden_states
|
| 527 |
+
hidden_states_gatter = torch.empty_like(
|
| 528 |
+
hidden_states_scatter,
|
| 529 |
+
dtype=hidden_states_scatter.dtype,
|
| 530 |
+
device=hidden_states_scatter.device)
|
| 531 |
+
dist.all_to_all_single(hidden_states_gatter,
|
| 532 |
+
hidden_states_scatter,
|
| 533 |
+
group=ep_group.device_group)
|
| 534 |
+
hidden_states_gatter = hidden_states_gatter[expert_idx_buffer_scatter !=
|
| 535 |
+
global_num_experts]
|
| 536 |
+
if hidden_states_gatter.shape[0] != row_idx_len:
|
| 537 |
+
hidden_states = torch.zeros((row_idx_len, hidden_states.shape[1]),
|
| 538 |
+
dtype=hidden_states.dtype,
|
| 539 |
+
device=hidden_states.device)
|
| 540 |
+
hidden_states[unpad_indices != -1] = hidden_states_gatter
|
| 541 |
+
else:
|
| 542 |
+
# TODO: Reorder device memory 2 times here, replace the current
|
| 543 |
+
hidden_states = hidden_states_gatter
|
| 544 |
+
final_hidden_states = torch_npu.npu_moe_finalize_routing(
|
| 545 |
+
hidden_states,
|
| 546 |
+
skip1=None,
|
| 547 |
+
skip2=None,
|
| 548 |
+
bias=None,
|
| 549 |
+
scales=topk_weights,
|
| 550 |
+
expanded_src_to_dst_row=expanded_row_idx,
|
| 551 |
+
export_for_source_row=topk_ids,
|
| 552 |
+
)
|
| 553 |
+
|
| 554 |
+
if len(original_shape) == 3:
|
| 555 |
+
final_hidden_states = final_hidden_states.view(original_shape)
|
| 556 |
+
return final_hidden_states
|
| 557 |
+
|
| 558 |
+
|
| 559 |
+
def fused_experts_moge(
|
| 560 |
+
hidden_states: torch.Tensor,
|
| 561 |
+
w1: torch.Tensor,
|
| 562 |
+
w2: torch.Tensor,
|
| 563 |
+
topk_weights: torch.Tensor,
|
| 564 |
+
topk_ids: torch.Tensor,
|
| 565 |
+
top_k: int,
|
| 566 |
+
global_num_experts: int,
|
| 567 |
+
expert_map: torch.Tensor = None,
|
| 568 |
+
apply_router_weight_on_input: bool = False,
|
| 569 |
+
) -> torch.Tensor:
|
| 570 |
+
"""
|
| 571 |
+
|
| 572 |
+
Args:
|
| 573 |
+
hidden_states: Hidden states of shape (num_tokens, hidden_size).
|
| 574 |
+
w1: Expert weights1 of shape (num_experts, intermediate_size * 2, hidden_size).
|
| 575 |
+
w2: Expert weights2 of shape (num_experts, hidden_size, intermediate_size).
|
| 576 |
+
topk_weights: Routing weights of shape (num_tokens, top_k).
|
| 577 |
+
topk_ids: Selected expert IDs of shape (num_tokens, top_k).
|
| 578 |
+
top_k: Number of experts to select.
|
| 579 |
+
expert_map: Expert mapping of shape (num_experts,).
|
| 580 |
+
|
| 581 |
+
Returns:
|
| 582 |
+
hidden_states: Hidden states after routing.
|
| 583 |
+
"""
|
| 584 |
+
ep_size = get_ep_group().world_size
|
| 585 |
+
local_num_experts = global_num_experts // ep_size
|
| 586 |
+
local_num_group = top_k // ep_size
|
| 587 |
+
|
| 588 |
+
if apply_router_weight_on_input:
|
| 589 |
+
assert (topk_weights.dim() == 2
|
| 590 |
+
), "`topk_weights` should be in shape (num_tokens, topk)"
|
| 591 |
+
_, topk = topk_weights.shape
|
| 592 |
+
assert (
|
| 593 |
+
topk == 1
|
| 594 |
+
), "Only support topk=1 when `apply_router_weight_on_input` is True"
|
| 595 |
+
hidden_states = hidden_states * topk_weights.to(hidden_states.dtype)
|
| 596 |
+
|
| 597 |
+
bsz, _ = hidden_states.shape
|
| 598 |
+
flatten_topk_ids = topk_ids.view(-1)
|
| 599 |
+
sorted_topk_ids = torch.argsort(flatten_topk_ids.float())
|
| 600 |
+
sorted_topk_ids = sorted_topk_ids.to(torch.int32)
|
| 601 |
+
sorted_hidden_states = hidden_states.index_select(
|
| 602 |
+
0, sorted_topk_ids // local_num_group)
|
| 603 |
+
|
| 604 |
+
experts_id = torch.arange(0,
|
| 605 |
+
local_num_experts,
|
| 606 |
+
dtype=topk_ids.dtype,
|
| 607 |
+
device=topk_ids.device)
|
| 608 |
+
num_tokens_per_expert = (flatten_topk_ids.unsqueeze(-1) == experts_id).to(
|
| 609 |
+
torch.float32).sum(0)
|
| 610 |
+
topk_scales = topk_weights.view(-1).index_select(
|
| 611 |
+
0, sorted_topk_ids).unsqueeze(-1)
|
| 612 |
+
group_list = num_tokens_per_expert.cumsum(dim=0).to(torch.int64)
|
| 613 |
+
|
| 614 |
+
w1 = w1.transpose(1, 2)
|
| 615 |
+
gate_up_out = torch_npu.npu_grouped_matmul(
|
| 616 |
+
x=[sorted_hidden_states],
|
| 617 |
+
weight=[w1],
|
| 618 |
+
split_item=2,
|
| 619 |
+
group_list_type=0,
|
| 620 |
+
group_type=0,
|
| 621 |
+
group_list=group_list,
|
| 622 |
+
)[0]
|
| 623 |
+
|
| 624 |
+
if is_310p():
|
| 625 |
+
gate_up_out = torch_npu.npu_swiglu(gate_up_out.to(torch.float32)).to(
|
| 626 |
+
torch.float16)
|
| 627 |
+
else:
|
| 628 |
+
gate_up_out = torch_npu.npu_swiglu(gate_up_out)
|
| 629 |
+
gate_up_out *= topk_scales
|
| 630 |
+
|
| 631 |
+
w2 = w2.transpose(1, 2)
|
| 632 |
+
down_out_list = torch_npu.npu_grouped_matmul(
|
| 633 |
+
x=[gate_up_out],
|
| 634 |
+
weight=[w2],
|
| 635 |
+
split_item=2,
|
| 636 |
+
group_list_type=0,
|
| 637 |
+
group_type=0,
|
| 638 |
+
group_list=group_list,
|
| 639 |
+
)[0]
|
| 640 |
+
|
| 641 |
+
unsorted_topk_ids = torch.argsort(sorted_topk_ids.float()).to(torch.int32)
|
| 642 |
+
unsorted_hidden_states = down_out_list.index_select(0, unsorted_topk_ids)
|
| 643 |
+
final_hidden_states = unsorted_hidden_states.reshape(
|
| 644 |
+
bsz, top_k // ep_size, -1).sum(1)
|
| 645 |
+
|
| 646 |
+
return final_hidden_states
|
| 647 |
+
|
| 648 |
+
|
| 649 |
+
def fused_experts(
|
| 650 |
+
hidden_states: torch.Tensor,
|
| 651 |
+
w1: torch.Tensor,
|
| 652 |
+
w2: torch.Tensor,
|
| 653 |
+
topk_weights: torch.Tensor,
|
| 654 |
+
topk_ids: torch.Tensor,
|
| 655 |
+
top_k: int,
|
| 656 |
+
expert_map: torch.Tensor = None,
|
| 657 |
+
apply_router_weight_on_input: bool = False,
|
| 658 |
+
max_num_tokens: Optional[int] = None,
|
| 659 |
+
) -> torch.Tensor:
|
| 660 |
+
"""
|
| 661 |
+
Fused experts with top-k routing.
|
| 662 |
+
|
| 663 |
+
Args:
|
| 664 |
+
hidden_states: Hidden states of shape (num_tokens, hidden_size).
|
| 665 |
+
w1: Expert weights1 of shape (num_experts, intermediate_size * 2, hidden_size).
|
| 666 |
+
w2: Expert weights2 of shape (num_experts, hidden_size, intermediate_size).
|
| 667 |
+
topk_weights: Routing weights of shape (num_tokens, top_k).
|
| 668 |
+
topk_ids: Selected expert IDs of shape (num_tokens, top_k).
|
| 669 |
+
top_k: Number of experts to select.
|
| 670 |
+
expert_map: Expert mapping of shape (num_experts,).
|
| 671 |
+
|
| 672 |
+
Returns:
|
| 673 |
+
hidden_states: Hidden states after routing.
|
| 674 |
+
"""
|
| 675 |
+
"""
|
| 676 |
+
# Check constraints.
|
| 677 |
+
assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
|
| 678 |
+
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
|
| 679 |
+
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
|
| 680 |
+
assert w1.is_contiguous(), "Expert weights1 must be contiguous"
|
| 681 |
+
assert w2.is_contiguous(), "Expert weights2 must be contiguous"
|
| 682 |
+
"""
|
| 683 |
+
# if torch.distributed.get_rank() == 0:
|
| 684 |
+
# print(w1.shape)
|
| 685 |
+
# print(hidden_states.shape)
|
| 686 |
+
|
| 687 |
+
original_shape = hidden_states.shape
|
| 688 |
+
# assert len(original_shape) == 2
|
| 689 |
+
|
| 690 |
+
num_tokens = hidden_states.shape[:-1].numel()
|
| 691 |
+
num_experts = w1.shape[0]
|
| 692 |
+
dtype = hidden_states.dtype
|
| 693 |
+
device = hidden_states.device
|
| 694 |
+
# assert dtype in [torch.float32, torch.float16, torch.bfloat16
|
| 695 |
+
# ], "Only float32, float16, and bfloat16 are supported"
|
| 696 |
+
|
| 697 |
+
if apply_router_weight_on_input:
|
| 698 |
+
assert (topk_weights.dim() == 2
|
| 699 |
+
), "`topk_weights` should be in shape (num_tokens, topk)"
|
| 700 |
+
_, topk = topk_weights.shape
|
| 701 |
+
assert (
|
| 702 |
+
topk == 1
|
| 703 |
+
), "Only support topk=1 when `apply_router_weight_on_input` is True"
|
| 704 |
+
hidden_states = hidden_states * topk_weights.to(hidden_states.dtype)
|
| 705 |
+
|
| 706 |
+
if expert_map is not None:
|
| 707 |
+
# Generate token indices and flatten
|
| 708 |
+
token_indices = (torch.arange(num_tokens,
|
| 709 |
+
device=device,
|
| 710 |
+
dtype=torch.int64).unsqueeze(1).expand(
|
| 711 |
+
-1, top_k).reshape(-1))
|
| 712 |
+
|
| 713 |
+
# Flatten token-to-expert mappings and map to local experts
|
| 714 |
+
weights_flat = topk_weights.view(-1)
|
| 715 |
+
experts_flat = topk_ids.view(-1)
|
| 716 |
+
local_experts_flat = expert_map[experts_flat]
|
| 717 |
+
|
| 718 |
+
# Filter valid token-expert pairs
|
| 719 |
+
mask = local_experts_flat != -1
|
| 720 |
+
filtered_weights = torch.where(
|
| 721 |
+
mask, weights_flat, torch.zeros_like(weights_flat)).to(dtype)
|
| 722 |
+
filtered_experts = torch.where(
|
| 723 |
+
mask, local_experts_flat,
|
| 724 |
+
torch.full_like(local_experts_flat,
|
| 725 |
+
num_experts)).to(topk_ids.dtype)
|
| 726 |
+
|
| 727 |
+
# Sort by local expert IDs
|
| 728 |
+
sort_indices = torch.argsort(filtered_experts.view(torch.float32))
|
| 729 |
+
sorted_token_indices = token_indices[sort_indices]
|
| 730 |
+
sorted_weights = filtered_weights[sort_indices]
|
| 731 |
+
|
| 732 |
+
# Compute token counts with minlength of num_experts
|
| 733 |
+
# This is equivalent to but faster than:
|
| 734 |
+
# >>> token_counts = torch.bincount(filtered_experts, minlength=num_experts)[:-1]
|
| 735 |
+
token_counts = torch.zeros(num_experts + 1,
|
| 736 |
+
device=device,
|
| 737 |
+
dtype=torch.int64)
|
| 738 |
+
ones = torch.ones_like(filtered_experts, dtype=torch.int64)
|
| 739 |
+
token_counts.scatter_add_(0, filtered_experts.to(torch.int64), ones)
|
| 740 |
+
token_counts = token_counts[:num_experts]
|
| 741 |
+
expert_tokens = torch.cumsum(token_counts, dim=0, dtype=torch.int64)
|
| 742 |
+
|
| 743 |
+
# Rearrange hidden_states
|
| 744 |
+
sorted_hidden_states = hidden_states[sorted_token_indices]
|
| 745 |
+
else:
|
| 746 |
+
row_idx_len = num_tokens * top_k
|
| 747 |
+
row_idx = (torch.arange(0,
|
| 748 |
+
row_idx_len,
|
| 749 |
+
dtype=torch.int32,
|
| 750 |
+
device=device).view(top_k, -1).permute(
|
| 751 |
+
1, 0).contiguous())
|
| 752 |
+
active_num = max_num_tokens if max_num_tokens is not None else num_tokens
|
| 753 |
+
sorted_hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
|
| 754 |
+
hidden_states,
|
| 755 |
+
row_idx=row_idx,
|
| 756 |
+
expert_idx=topk_ids,
|
| 757 |
+
active_num=active_num)
|
| 758 |
+
|
| 759 |
+
expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
|
| 760 |
+
expanded_expert_idx, num_experts)
|
| 761 |
+
expert_tokens = expert_tokens.to(torch.int64)
|
| 762 |
+
|
| 763 |
+
w1 = w1.transpose(1, 2)
|
| 764 |
+
gate_up_out_list = torch_npu.npu_grouped_matmul(
|
| 765 |
+
x=[sorted_hidden_states],
|
| 766 |
+
weight=[w1],
|
| 767 |
+
split_item=2,
|
| 768 |
+
group_list_type=0,
|
| 769 |
+
group_type=0,
|
| 770 |
+
group_list=expert_tokens,
|
| 771 |
+
)
|
| 772 |
+
|
| 773 |
+
# TODO: Remove this in the future.
|
| 774 |
+
gate_up_out = torch.cat(gate_up_out_list, dim=0)
|
| 775 |
+
gate_up_out = torch_npu.npu_swiglu(gate_up_out)
|
| 776 |
+
|
| 777 |
+
w2 = w2.transpose(1, 2)
|
| 778 |
+
down_out_list = torch_npu.npu_grouped_matmul(
|
| 779 |
+
x=[gate_up_out],
|
| 780 |
+
weight=[w2],
|
| 781 |
+
split_item=2,
|
| 782 |
+
group_list_type=0,
|
| 783 |
+
group_type=0,
|
| 784 |
+
group_list=expert_tokens,
|
| 785 |
+
)
|
| 786 |
+
|
| 787 |
+
down_out_list = torch.cat(down_out_list, dim=0)
|
| 788 |
+
|
| 789 |
+
if expert_map is not None:
|
| 790 |
+
weighted_down_out = down_out_list * sorted_weights.unsqueeze(1)
|
| 791 |
+
|
| 792 |
+
final_hidden_states = torch.zeros(*original_shape,
|
| 793 |
+
device=hidden_states.device,
|
| 794 |
+
dtype=dtype)
|
| 795 |
+
|
| 796 |
+
# TODO: npu_grouped_matmul output random values at [num_valid_tokens:, ...]
|
| 797 |
+
# This created multiple NaN and index_add_ will mix them up which harms accuracy
|
| 798 |
+
# remove this mask and filter after it being fixed
|
| 799 |
+
num_valid_tokens = mask.sum()
|
| 800 |
+
valid_token_mask = torch.arange(
|
| 801 |
+
0, sorted_token_indices.shape[0],
|
| 802 |
+
device=device).unsqueeze(1) < num_valid_tokens
|
| 803 |
+
valid_output = torch.where(
|
| 804 |
+
valid_token_mask, weighted_down_out,
|
| 805 |
+
torch.zeros_like(weighted_down_out)).to(dtype)
|
| 806 |
+
final_hidden_states.index_add_(0, sorted_token_indices, valid_output)
|
| 807 |
+
else:
|
| 808 |
+
scales = torch.ones_like(
|
| 809 |
+
topk_weights) if apply_router_weight_on_input else topk_weights
|
| 810 |
+
# TODO: Reorder device memory 2 times here, replace the current
|
| 811 |
+
# implementation here when suitable operators become available.
|
| 812 |
+
final_hidden_states = torch_npu.npu_moe_finalize_routing(
|
| 813 |
+
down_out_list,
|
| 814 |
+
skip1=None,
|
| 815 |
+
skip2=None,
|
| 816 |
+
bias=None,
|
| 817 |
+
scales=scales,
|
| 818 |
+
expanded_src_to_dst_row=expanded_row_idx,
|
| 819 |
+
export_for_source_row=topk_ids,
|
| 820 |
+
)
|
| 821 |
+
|
| 822 |
+
return final_hidden_states
|
| 823 |
+
|
| 824 |
+
|
| 825 |
+
def fused_experts_allgather_ep(
|
| 826 |
+
hidden_states: torch.Tensor,
|
| 827 |
+
w1: torch.Tensor,
|
| 828 |
+
w2: torch.Tensor,
|
| 829 |
+
topk_weights: torch.Tensor,
|
| 830 |
+
topk_ids: torch.Tensor,
|
| 831 |
+
is_prefill: bool
|
| 832 |
+
):
|
| 833 |
+
local_rank = torch.distributed.get_rank(group=get_ep_group().device_group)
|
| 834 |
+
num_experts_per_ep = w1.shape[0]
|
| 835 |
+
local_expert_indices_offset = local_rank * num_experts_per_ep
|
| 836 |
+
global_local_mask = (topk_ids >= local_expert_indices_offset) & \
|
| 837 |
+
(topk_ids <= local_expert_indices_offset + num_experts_per_ep - 1)
|
| 838 |
+
non_global_local_mask = (~global_local_mask).to(torch.int32)
|
| 839 |
+
global_local_mask = global_local_mask.to(torch.int32)
|
| 840 |
+
row_idx = torch.arange(topk_ids.numel(), device=topk_ids.device, dtype=torch.int32).view(
|
| 841 |
+
-1, topk_ids.shape[0]).transpose(0, 1).contiguous()
|
| 842 |
+
|
| 843 |
+
topk_ids -= local_expert_indices_offset
|
| 844 |
+
local_topk_ids_mask_with_max = topk_ids * global_local_mask + non_global_local_mask * num_experts_per_ep
|
| 845 |
+
sorted_tokens, expanded_src_to_dst_row, expanded_expert_idx = torch_npu.npu_moe_init_routing(
|
| 846 |
+
x=hidden_states,
|
| 847 |
+
row_idx=row_idx,
|
| 848 |
+
expert_idx=local_topk_ids_mask_with_max,
|
| 849 |
+
active_num=topk_ids.shape[0]*topk_ids.shape[1]
|
| 850 |
+
)
|
| 851 |
+
if expanded_expert_idx.shape[0] > 8192:
|
| 852 |
+
expert_tokens = torch_npu.npu_moe_compute_expert_tokens(expanded_expert_idx, num_experts_per_ep + 1)
|
| 853 |
+
expert_tokens = expert_tokens[:-1]
|
| 854 |
+
else:
|
| 855 |
+
expert_tokens = torch_npu.npu_moe_compute_expert_tokens(expanded_expert_idx, num_experts_per_ep)
|
| 856 |
+
expert_tokens = expert_tokens.to(torch.int64)
|
| 857 |
+
|
| 858 |
+
w1 = w1.transpose(1, 2)
|
| 859 |
+
gate_up_out = torch_npu.npu_grouped_matmul(
|
| 860 |
+
x=[sorted_tokens],
|
| 861 |
+
weight=[w1],
|
| 862 |
+
group_list=expert_tokens,
|
| 863 |
+
split_item=3,
|
| 864 |
+
group_type=0
|
| 865 |
+
)[0]
|
| 866 |
+
gate_up_out = torch_npu.npu_swiglu(gate_up_out)
|
| 867 |
+
|
| 868 |
+
w2 = w2.transpose(1, 2)
|
| 869 |
+
down_out = torch_npu.npu_grouped_matmul(
|
| 870 |
+
x=[gate_up_out],
|
| 871 |
+
weight=[w2],
|
| 872 |
+
group_list=expert_tokens,
|
| 873 |
+
split_item=3,
|
| 874 |
+
group_type=0
|
| 875 |
+
)[0]
|
| 876 |
+
|
| 877 |
+
if is_prefill:
|
| 878 |
+
down_out[expert_tokens[-1]:] = 0
|
| 879 |
+
else:
|
| 880 |
+
sorted_tokens_mask = expanded_expert_idx != num_experts_per_ep
|
| 881 |
+
down_out *= sorted_tokens_mask.unsqueeze(1)
|
| 882 |
+
|
| 883 |
+
final_hidden_states = torch_npu.npu_moe_finalize_routing(
|
| 884 |
+
expanded_permuted_rows=down_out,
|
| 885 |
+
skip1=None,
|
| 886 |
+
skip2=None,
|
| 887 |
+
bias=None,
|
| 888 |
+
scales=topk_weights.to(down_out.dtype),
|
| 889 |
+
expanded_src_to_dst_row=expanded_src_to_dst_row,
|
| 890 |
+
export_for_source_row=topk_ids
|
| 891 |
+
)
|
| 892 |
+
return final_hidden_states
|
| 893 |
+
|
| 894 |
+
|
| 895 |
+
def select_gating_top_k_softmax_experts(
|
| 896 |
+
hidden_states: torch.Tensor, router_logits: torch.Tensor, top_k: int,
|
| 897 |
+
renormalize: bool) -> tuple[torch.Tensor, torch.Tensor]:
|
| 898 |
+
"""
|
| 899 |
+
Select top-k experts based on router logits.
|
| 900 |
+
only supports float16、bfloat16、float32
|
| 901 |
+
|
| 902 |
+
Args:
|
| 903 |
+
hidden_states: Hidden states of shape (num_tokens, hidden_size).
|
| 904 |
+
router_logits: Router logits of shape (num_tokens, num_experts).
|
| 905 |
+
top_k: Number of experts to select.
|
| 906 |
+
renormalize: Whether to renormalize the routing weights.
|
| 907 |
+
|
| 908 |
+
Returns:
|
| 909 |
+
topk_weights: Routing weights of shape (num_tokens, top_k).
|
| 910 |
+
topk_ids: Selected expert IDs of shape (num_tokens, top_k).
|
| 911 |
+
|
| 912 |
+
Raises:
|
| 913 |
+
ValueError: If an unsupported scoring function is provided.
|
| 914 |
+
"""
|
| 915 |
+
topk_weights, topk_ids, row_idx = torch_npu.npu_moe_gating_top_k_softmax(
|
| 916 |
+
router_logits, None, k=top_k)
|
| 917 |
+
|
| 918 |
+
# # Required by npu_moe_init_routing
|
| 919 |
+
# topk_weights = topk_weights.to(hidden_states.dtype)
|
| 920 |
+
# topk_ids = topk_ids.to(torch.int32)
|
| 921 |
+
|
| 922 |
+
if renormalize:
|
| 923 |
+
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
|
| 924 |
+
|
| 925 |
+
return topk_weights, topk_ids
|
| 926 |
+
|
| 927 |
+
|
| 928 |
+
def native_grouped_topk(
|
| 929 |
+
topk_weights: torch.Tensor,
|
| 930 |
+
num_expert_group: Optional[int],
|
| 931 |
+
topk_group: Optional[int],
|
| 932 |
+
):
|
| 933 |
+
topk_group = 0 if topk_group is None else topk_group
|
| 934 |
+
num_expert_group = 0 if num_expert_group is None else num_expert_group
|
| 935 |
+
|
| 936 |
+
num_token = topk_weights.shape[0]
|
| 937 |
+
grouped_weights = topk_weights.view(num_token, num_expert_group,
|
| 938 |
+
-1).max(dim=-1).values
|
| 939 |
+
topk_group_indices = torch.topk(grouped_weights.to(torch.float32),
|
| 940 |
+
k=topk_group,
|
| 941 |
+
dim=-1,
|
| 942 |
+
sorted=False)[1]
|
| 943 |
+
topk_group_mask = torch.zeros_like(grouped_weights)
|
| 944 |
+
topk_group_mask.scatter_(1, topk_group_indices, 1)
|
| 945 |
+
topk_weight_mask = (topk_group_mask.unsqueeze(-1).expand(
|
| 946 |
+
num_token, num_expert_group,
|
| 947 |
+
topk_weights.shape[-1] // num_expert_group).reshape(num_token, -1))
|
| 948 |
+
topk_weights = topk_weights.masked_fill(~topk_weight_mask.bool(), 0.0)
|
| 949 |
+
|
| 950 |
+
return topk_weights
|
| 951 |
+
|
| 952 |
+
|
| 953 |
+
def select_experts(
|
| 954 |
+
hidden_states: torch.Tensor,
|
| 955 |
+
router_logits: torch.Tensor,
|
| 956 |
+
top_k: int,
|
| 957 |
+
use_grouped_topk: bool,
|
| 958 |
+
renormalize: bool,
|
| 959 |
+
topk_group: Optional[int] = None,
|
| 960 |
+
num_expert_group: Optional[int] = None,
|
| 961 |
+
custom_routing_function: Optional[Callable] = None,
|
| 962 |
+
scoring_func: str = "softmax",
|
| 963 |
+
e_score_correction_bias: Optional[torch.Tensor] = None,
|
| 964 |
+
global_num_experts: Optional[torch.Tensor] = None
|
| 965 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 966 |
+
"""
|
| 967 |
+
Select top-k experts based on router logits.
|
| 968 |
+
|
| 969 |
+
Args:
|
| 970 |
+
hidden_states: Hidden states of shape (num_tokens, hidden_size).
|
| 971 |
+
router_logits: Router logits of shape (num_tokens, num_experts).
|
| 972 |
+
top_k: Number of experts to select.
|
| 973 |
+
use_grouped_topk: Whether to group experts before selecting top-k.
|
| 974 |
+
renormalize: Whether to renormalize the routing weights.
|
| 975 |
+
topk_group: Number of expert groups to select from.
|
| 976 |
+
num_expert_group: Number of experts in each group.
|
| 977 |
+
custom_routing_function: Custom routing function.
|
| 978 |
+
scoring_func: Scoring function to use.
|
| 979 |
+
e_score_correction_bias: Correction bias to apply to expert scores.
|
| 980 |
+
|
| 981 |
+
Returns:
|
| 982 |
+
topk_weights: Routing weights of shape (num_tokens, top_k).
|
| 983 |
+
topk_ids: Selected expert IDs of shape (num_tokens, top_k).
|
| 984 |
+
|
| 985 |
+
Raises:
|
| 986 |
+
ValueError: If an unsupported scoring function is provided.
|
| 987 |
+
"""
|
| 988 |
+
|
| 989 |
+
if scoring_func == "softmax":
|
| 990 |
+
# NOTE: vLLM use dtype=torch.float here
|
| 991 |
+
topk_weights = router_logits.softmax(dim=-1)
|
| 992 |
+
elif scoring_func == "sigmoid":
|
| 993 |
+
topk_weights = router_logits.sigmoid()
|
| 994 |
+
else:
|
| 995 |
+
raise ValueError(f"Unsupported scoring function: {scoring_func}")
|
| 996 |
+
|
| 997 |
+
if use_grouped_topk:
|
| 998 |
+
assert topk_group is not None
|
| 999 |
+
assert num_expert_group is not None
|
| 1000 |
+
|
| 1001 |
+
if e_score_correction_bias is not None:
|
| 1002 |
+
# Store original scores before applying correction bias. We use biased
|
| 1003 |
+
# scores for expert selection but original scores for routing weights
|
| 1004 |
+
original_weights = topk_weights
|
| 1005 |
+
topk_weights = topk_weights + e_score_correction_bias.unsqueeze(0)
|
| 1006 |
+
|
| 1007 |
+
# TODO: Change to npu_group_topk when the latest CANN and NNAL is available
|
| 1008 |
+
# >>> torch_npu._npu_group_topk(topk_weights, group_num=num_expert_group, k=topk_group)
|
| 1009 |
+
topk_weights = native_grouped_topk(topk_weights, num_expert_group,
|
| 1010 |
+
topk_group)
|
| 1011 |
+
# TODO bfloat16 is not supported in torch.topk with ge graph.
|
| 1012 |
+
if e_score_correction_bias is not None:
|
| 1013 |
+
topk_ids = torch.topk(topk_weights.to(torch.float32),
|
| 1014 |
+
k=top_k,
|
| 1015 |
+
dim=-1,
|
| 1016 |
+
sorted=False)[1]
|
| 1017 |
+
# Use original unbiased scores for the routing weights
|
| 1018 |
+
topk_weights = original_weights.gather(1, topk_ids)
|
| 1019 |
+
else:
|
| 1020 |
+
topk_weights, topk_ids = torch.topk(topk_weights.to(torch.float32),
|
| 1021 |
+
k=top_k,
|
| 1022 |
+
dim=-1,
|
| 1023 |
+
sorted=False)
|
| 1024 |
+
elif custom_routing_function is None:
|
| 1025 |
+
topk_weights, topk_ids = topk_weights.topk(top_k, dim=-1)
|
| 1026 |
+
else:
|
| 1027 |
+
topk_weights, topk_ids = custom_routing_function(
|
| 1028 |
+
hidden_states=hidden_states,
|
| 1029 |
+
gating_output=router_logits,
|
| 1030 |
+
topk=top_k,
|
| 1031 |
+
renormalize=renormalize,
|
| 1032 |
+
global_num_experts=global_num_experts)
|
| 1033 |
+
# Required by npu_moe_init_routing
|
| 1034 |
+
topk_ids = topk_ids.to(torch.int32)
|
| 1035 |
+
return topk_weights, topk_ids
|
| 1036 |
+
|
| 1037 |
+
# Required by npu_moe_init_routing
|
| 1038 |
+
topk_ids = topk_ids.to(torch.int32)
|
| 1039 |
+
|
| 1040 |
+
if renormalize:
|
| 1041 |
+
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
|
| 1042 |
+
|
| 1043 |
+
return topk_weights, topk_ids
|
| 1044 |
+
|
| 1045 |
+
|
| 1046 |
+
class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
|
| 1047 |
+
|
| 1048 |
+
def __init__(self, moe: FusedMoEConfig = None):
|
| 1049 |
+
|
| 1050 |
+
super().__init__(moe=moe)
|
| 1051 |
+
vllm_config = get_current_vllm_config()
|
| 1052 |
+
|
| 1053 |
+
self.ep_group = get_ep_group()
|
| 1054 |
+
self.ep_size = self.ep_group.world_size
|
| 1055 |
+
self.global_batch_size = vllm_config.scheduler_config.max_num_seqs
|
| 1056 |
+
self.local_batch_size = self.global_batch_size // self.ep_size
|
| 1057 |
+
self.max_model_len = vllm_config.model_config.max_model_len
|
| 1058 |
+
|
| 1059 |
+
ascend_config = get_ascend_config()
|
| 1060 |
+
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
|
| 1061 |
+
|
| 1062 |
+
try:
|
| 1063 |
+
device_group = self.ep_group.device_group
|
| 1064 |
+
# TODO: Try local_rank = ep_group.rank_in_group
|
| 1065 |
+
local_rank = torch.distributed.get_rank(group=device_group)
|
| 1066 |
+
backend = device_group._get_backend(torch.device("npu"))
|
| 1067 |
+
self.moe_all_to_all_group_name = backend.get_hccl_comm_name(
|
| 1068 |
+
local_rank)
|
| 1069 |
+
except AttributeError:
|
| 1070 |
+
self.moe_all_to_all_group_name = None
|
| 1071 |
+
|
| 1072 |
+
def process_weights_after_loading(self, layer):
|
| 1073 |
+
super(UnquantizedFusedMoEMethod,
|
| 1074 |
+
self).process_weights_after_loading(layer)
|
| 1075 |
+
layer.w13_weight = torch.nn.Parameter(self._maybe_pad_weight(
|
| 1076 |
+
layer.w13_weight.data),
|
| 1077 |
+
requires_grad=False)
|
| 1078 |
+
layer.w2_weight = torch.nn.Parameter(self._maybe_pad_weight(
|
| 1079 |
+
layer.w2_weight.data),
|
| 1080 |
+
requires_grad=False)
|
| 1081 |
+
|
| 1082 |
+
def apply(
|
| 1083 |
+
self,
|
| 1084 |
+
layer: torch.nn.Module,
|
| 1085 |
+
x: torch.Tensor,
|
| 1086 |
+
router_logits: torch.Tensor,
|
| 1087 |
+
top_k: int,
|
| 1088 |
+
renormalize: bool,
|
| 1089 |
+
use_grouped_topk: bool = False,
|
| 1090 |
+
global_num_experts: int = -1,
|
| 1091 |
+
expert_map: Optional[torch.Tensor] = None,
|
| 1092 |
+
topk_group: Optional[int] = None,
|
| 1093 |
+
num_expert_group: Optional[int] = None,
|
| 1094 |
+
custom_routing_function: Optional[Callable] = None,
|
| 1095 |
+
scoring_func: str = "softmax",
|
| 1096 |
+
e_score_correction_bias: Optional[torch.Tensor] = None,
|
| 1097 |
+
is_prefill: bool = False,
|
| 1098 |
+
enable_force_load_balance: bool = False,
|
| 1099 |
+
shared_experts: Optional[Any] = None,
|
| 1100 |
+
**kwargs,
|
| 1101 |
+
) -> torch.Tensor:
|
| 1102 |
+
use_grouped_topk = (topk_group > 1 or num_expert_group > 1)
|
| 1103 |
+
is_deepseek_v3_r1 = global_num_experts == 256
|
| 1104 |
+
# NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
|
| 1105 |
+
if use_grouped_topk and is_deepseek_v3_r1:
|
| 1106 |
+
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
|
| 1107 |
+
router_logits,
|
| 1108 |
+
k=top_k, # topk当前写8
|
| 1109 |
+
bias=e_score_correction_bias,
|
| 1110 |
+
k_group=topk_group, # fix: 4
|
| 1111 |
+
group_count=num_expert_group, # fix 8
|
| 1112 |
+
group_select_mode=1, # 0: group中的最大; 1: topk2.sum(fix)
|
| 1113 |
+
renorm=0, # 0: softmax->topk(fix); 1: topk->softmax
|
| 1114 |
+
norm_type=1, # 0: softmax; 1: sigmoid(fix)
|
| 1115 |
+
# out_flag=False, # todo new api; 第三个输出是否输出
|
| 1116 |
+
# y2_flag=False, # old api; 第三个输出是否输出
|
| 1117 |
+
routed_scaling_factor=1,
|
| 1118 |
+
eps=float(1e-20))
|
| 1119 |
+
elif use_grouped_topk and SELECT_GATING_TOPK_SOTFMAX_EXPERTS:
|
| 1120 |
+
topk_weights, topk_ids = select_gating_top_k_softmax_experts(
|
| 1121 |
+
hidden_states=x,
|
| 1122 |
+
router_logits=router_logits,
|
| 1123 |
+
top_k=top_k,
|
| 1124 |
+
renormalize=renormalize)
|
| 1125 |
+
else:
|
| 1126 |
+
topk_weights, topk_ids = select_experts(
|
| 1127 |
+
hidden_states=x,
|
| 1128 |
+
router_logits=router_logits,
|
| 1129 |
+
top_k=top_k,
|
| 1130 |
+
use_grouped_topk=use_grouped_topk,
|
| 1131 |
+
renormalize=renormalize,
|
| 1132 |
+
topk_group=topk_group,
|
| 1133 |
+
num_expert_group=num_expert_group,
|
| 1134 |
+
custom_routing_function=custom_routing_function,
|
| 1135 |
+
scoring_func=scoring_func,
|
| 1136 |
+
e_score_correction_bias=e_score_correction_bias,
|
| 1137 |
+
)
|
| 1138 |
+
|
| 1139 |
+
topk_weights = topk_weights.to(x.dtype)
|
| 1140 |
+
# this is a naive implementation for experts load balance so as
|
| 1141 |
+
# to avoid accumulating too much tokens on a single rank.
|
| 1142 |
+
# currently it is only activated when doing profile runs.
|
| 1143 |
+
if enable_force_load_balance:
|
| 1144 |
+
topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
|
| 1145 |
+
|
| 1146 |
+
fused_moe_state = get_fused_moe_state(self.ep_group.world_size,
|
| 1147 |
+
is_prefill, is_deepseek_v3_r1)
|
| 1148 |
+
if fused_moe_state == FusedMoEState.MC2:
|
| 1149 |
+
return fused_experts_with_mc2(
|
| 1150 |
+
hidden_states=x,
|
| 1151 |
+
w1=layer.w13_weight,
|
| 1152 |
+
w2=layer.w2_weight,
|
| 1153 |
+
topk_weights=topk_weights,
|
| 1154 |
+
topk_ids=topk_ids,
|
| 1155 |
+
top_k=top_k,
|
| 1156 |
+
expert_map=expert_map,
|
| 1157 |
+
moe_all_to_all_group_name=self.moe_all_to_all_group_name,
|
| 1158 |
+
shared_experts=shared_experts)
|
| 1159 |
+
elif fused_moe_state == FusedMoEState.AllGatherEP:
|
| 1160 |
+
return fused_experts_allgather_ep(
|
| 1161 |
+
hidden_states=x,
|
| 1162 |
+
w1=layer.w13_weight,
|
| 1163 |
+
w2=layer.w2_weight,
|
| 1164 |
+
topk_weights=topk_weights,
|
| 1165 |
+
topk_ids=topk_ids,
|
| 1166 |
+
is_prefill=is_prefill)
|
| 1167 |
+
elif fused_moe_state in [
|
| 1168 |
+
FusedMoEState.AllGather, FusedMoEState.NaiveMulticast
|
| 1169 |
+
]:
|
| 1170 |
+
return fused_experts(hidden_states=x,
|
| 1171 |
+
w1=layer.w13_weight,
|
| 1172 |
+
w2=layer.w2_weight,
|
| 1173 |
+
topk_weights=topk_weights,
|
| 1174 |
+
topk_ids=topk_ids,
|
| 1175 |
+
top_k=top_k,
|
| 1176 |
+
expert_map=expert_map)
|
| 1177 |
+
elif MOE_ALL2ALL_BUFFER:
|
| 1178 |
+
return fused_experts_with_all2all_buffer(
|
| 1179 |
+
hidden_states=x,
|
| 1180 |
+
w1=layer.w13_weight,
|
| 1181 |
+
w2=layer.w2_weight,
|
| 1182 |
+
topk_weights=topk_weights,
|
| 1183 |
+
topk_ids=topk_ids,
|
| 1184 |
+
top_k=top_k,
|
| 1185 |
+
max_model_len=self.max_model_len,
|
| 1186 |
+
global_batch_size=self.global_batch_size,
|
| 1187 |
+
expert_map=expert_map,
|
| 1188 |
+
ep_group=get_ep_group())
|
| 1189 |
+
else:
|
| 1190 |
+
return fused_experts_with_all2all(hidden_states=x,
|
| 1191 |
+
w1=layer.w13_weight,
|
| 1192 |
+
w2=layer.w2_weight,
|
| 1193 |
+
topk_weights=topk_weights,
|
| 1194 |
+
topk_ids=topk_ids,
|
| 1195 |
+
top_k=top_k,
|
| 1196 |
+
expert_map=expert_map,
|
| 1197 |
+
ep_group=get_ep_group())
|
| 1198 |
+
|
| 1199 |
+
|
| 1200 |
+
class AscendFusedMoE(FusedMoE):
|
| 1201 |
+
|
| 1202 |
+
# The moe_counter parameter is required during the initialization of EPLB
|
| 1203 |
+
# to identify the current layer index within the MOE model.
|
| 1204 |
+
moe_counter = -1
|
| 1205 |
+
|
| 1206 |
+
def __init__(
|
| 1207 |
+
self,
|
| 1208 |
+
num_experts: int, # Global number of experts
|
| 1209 |
+
top_k: int,
|
| 1210 |
+
hidden_size: int,
|
| 1211 |
+
intermediate_size: int,
|
| 1212 |
+
params_dtype: Optional[torch.dtype] = None,
|
| 1213 |
+
reduce_results: bool = False,
|
| 1214 |
+
renormalize: bool = True,
|
| 1215 |
+
use_grouped_topk: bool = False,
|
| 1216 |
+
num_expert_group: Optional[int] = None,
|
| 1217 |
+
topk_group: Optional[int] = None,
|
| 1218 |
+
quant_config: Optional[QuantizationConfig] = None,
|
| 1219 |
+
tp_size: Optional[int] = None,
|
| 1220 |
+
ep_size: Optional[int] = None,
|
| 1221 |
+
dp_size: Optional[int] = None,
|
| 1222 |
+
prefix: str = "",
|
| 1223 |
+
custom_routing_function: Optional[Callable] = None,
|
| 1224 |
+
scoring_func: str = "softmax",
|
| 1225 |
+
e_score_correction_bias: Optional[torch.Tensor] = None,
|
| 1226 |
+
activation: str = "silu",
|
| 1227 |
+
apply_router_weight_on_input: bool = False,
|
| 1228 |
+
):
|
| 1229 |
+
# TODO: This could not initialize FusedMoE baseclass,
|
| 1230 |
+
# fixme and make __init__() of AscendFusedMoE more clear
|
| 1231 |
+
super(FusedMoE, self).__init__()
|
| 1232 |
+
|
| 1233 |
+
AscendFusedMoE.moe_counter += 1
|
| 1234 |
+
self.moe_instance_id = AscendFusedMoE.moe_counter
|
| 1235 |
+
|
| 1236 |
+
if params_dtype is None:
|
| 1237 |
+
params_dtype = torch.get_default_dtype()
|
| 1238 |
+
|
| 1239 |
+
vllm_config = get_current_vllm_config()
|
| 1240 |
+
|
| 1241 |
+
self.moe_parallel_config = FusedMoEParallelConfig.make(
|
| 1242 |
+
tp_size_=(tp_size if tp_size is not None else
|
| 1243 |
+
get_tensor_model_parallel_world_size()),
|
| 1244 |
+
dp_size_=(dp_size
|
| 1245 |
+
if dp_size is not None else get_dp_group().world_size),
|
| 1246 |
+
vllm_parallel_config=vllm_config.parallel_config)
|
| 1247 |
+
|
| 1248 |
+
self.top_k = top_k
|
| 1249 |
+
self.num_experts = num_experts
|
| 1250 |
+
self.global_num_experts = num_experts
|
| 1251 |
+
assert intermediate_size % self.tp_size == 0
|
| 1252 |
+
self.intermediate_size_per_partition = intermediate_size // self.tp_size
|
| 1253 |
+
self.reduce_results = reduce_results
|
| 1254 |
+
self.renormalize = renormalize
|
| 1255 |
+
self.use_grouped_topk = use_grouped_topk
|
| 1256 |
+
if self.use_grouped_topk:
|
| 1257 |
+
assert num_expert_group is not None and topk_group is not None
|
| 1258 |
+
self.num_expert_group = num_expert_group
|
| 1259 |
+
self.topk_group = topk_group
|
| 1260 |
+
self.custom_routing_function = custom_routing_function
|
| 1261 |
+
self.scoring_func = scoring_func
|
| 1262 |
+
self.e_score_correction_bias = e_score_correction_bias
|
| 1263 |
+
self.expert_map = None
|
| 1264 |
+
self.activation = activation
|
| 1265 |
+
self.log2phy = None
|
| 1266 |
+
self.global_redundant_expert_num = 0
|
| 1267 |
+
|
| 1268 |
+
is_deepseek_v3_r1 = self.global_num_experts == 256
|
| 1269 |
+
self.rm_router_logits = get_rm_router_logits_state(
|
| 1270 |
+
self.moe_parallel_config.ep_size, self.dp_size, is_deepseek_v3_r1)
|
| 1271 |
+
self.all_reduce_merge = get_all_reduce_merge_state(
|
| 1272 |
+
self.moe_parallel_config.ep_size, is_deepseek_v3_r1)
|
| 1273 |
+
|
| 1274 |
+
ascend_config = get_ascend_config()
|
| 1275 |
+
expert_map_path = ascend_config.expert_map_path
|
| 1276 |
+
if expert_map_path and os.path.exists(expert_map_path):
|
| 1277 |
+
# moe expert load balance
|
| 1278 |
+
expert_load_balancer = ExpertLoadBalancer(expert_map_path,
|
| 1279 |
+
self.global_num_experts)
|
| 1280 |
+
self.local_num_experts, self.expert_map = \
|
| 1281 |
+
expert_load_balancer.get_rank_placement_map(
|
| 1282 |
+
self.moe_instance_id,
|
| 1283 |
+
get_ep_group().rank_in_group)
|
| 1284 |
+
self.log2phy = expert_load_balancer.get_rank_log2phy_map(
|
| 1285 |
+
self.moe_instance_id,
|
| 1286 |
+
get_ep_group().rank_in_group)
|
| 1287 |
+
self.global_redundant_expert_num = \
|
| 1288 |
+
expert_load_balancer.get_global_redundant_expert_num()
|
| 1289 |
+
else:
|
| 1290 |
+
# Create a tensor of size num_experts filled with -1
|
| 1291 |
+
self.local_num_experts, self.expert_map = determine_expert_map(
|
| 1292 |
+
self.ep_size,
|
| 1293 |
+
get_ep_group().rank_in_group, self.global_num_experts)
|
| 1294 |
+
|
| 1295 |
+
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
|
| 1296 |
+
self.enable_multistream_moe = \
|
| 1297 |
+
ascend_config.torchair_graph_config.enable_multistream_moe
|
| 1298 |
+
|
| 1299 |
+
if self.scoring_func != "softmax" and not self.use_grouped_topk:
|
| 1300 |
+
raise ValueError("Only softmax scoring function is supported for "
|
| 1301 |
+
"non-grouped topk.")
|
| 1302 |
+
moe = FusedMoEConfig.make(
|
| 1303 |
+
num_experts=self.global_num_experts,
|
| 1304 |
+
experts_per_token=top_k,
|
| 1305 |
+
hidden_dim=hidden_size,
|
| 1306 |
+
num_local_experts=self.local_num_experts,
|
| 1307 |
+
moe_parallel_config=self.moe_parallel_config,
|
| 1308 |
+
# TODO (bnell): this needs to be fixed for quantized types.
|
| 1309 |
+
in_dtype=params_dtype,
|
| 1310 |
+
quant_config=quant_config)
|
| 1311 |
+
|
| 1312 |
+
if quant_config is None:
|
| 1313 |
+
self.quant_method = AscendUnquantizedFusedMoEMethod(moe)
|
| 1314 |
+
else:
|
| 1315 |
+
self.quant_method = quant_config.get_quant_method(self, prefix)
|
| 1316 |
+
|
| 1317 |
+
assert self.quant_method is not None
|
| 1318 |
+
|
| 1319 |
+
local_num_experts = torch.sum(self.expert_map != -1) \
|
| 1320 |
+
if self.expert_map is not None else num_experts
|
| 1321 |
+
|
| 1322 |
+
moe_quant_params = {
|
| 1323 |
+
"num_experts": local_num_experts,
|
| 1324 |
+
"hidden_size": hidden_size,
|
| 1325 |
+
"intermediate_size_per_partition":
|
| 1326 |
+
self.intermediate_size_per_partition,
|
| 1327 |
+
"params_dtype": params_dtype,
|
| 1328 |
+
"weight_loader": self.weight_loader,
|
| 1329 |
+
}
|
| 1330 |
+
# need full intermediate size pre-sharding for WNA16 act order
|
| 1331 |
+
if (self.quant_method.__class__.__name__
|
| 1332 |
+
in ("GPTQMarlinMoEMethod", "CompressedTensorsWNA16MoEMethod")):
|
| 1333 |
+
moe_quant_params["intermediate_size_full"] = intermediate_size
|
| 1334 |
+
|
| 1335 |
+
self.ep_group = get_ep_group()
|
| 1336 |
+
# NOTE: self.tp_group is not expert_tp_group
|
| 1337 |
+
self.tp_group = get_tp_group().device_group
|
| 1338 |
+
self.quant_method.create_weights(layer=self, **moe_quant_params)
|
| 1339 |
+
|
| 1340 |
+
def naive_multicast(self, x: torch.Tensor,
|
| 1341 |
+
cu_tokens_across_dp_cpu: torch.Tensor):
|
| 1342 |
+
assert (len(x.shape) == 2)
|
| 1343 |
+
buffer = torch.empty((cu_tokens_across_dp_cpu[-1], x.size(1)),
|
| 1344 |
+
device=x.device,
|
| 1345 |
+
dtype=x.dtype)
|
| 1346 |
+
start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
|
| 1347 |
+
self.dp_rank - 1]
|
| 1348 |
+
end = cu_tokens_across_dp_cpu[self.dp_rank]
|
| 1349 |
+
buffer[start:end, :].copy_(x)
|
| 1350 |
+
for idx in range(self.dp_size):
|
| 1351 |
+
start = 0 if idx == 0 else cu_tokens_across_dp_cpu[idx - 1]
|
| 1352 |
+
end = cu_tokens_across_dp_cpu[idx]
|
| 1353 |
+
get_dp_group().broadcast(buffer[start:end, :], idx)
|
| 1354 |
+
return buffer
|
| 1355 |
+
|
| 1356 |
+
def forward(self,
|
| 1357 |
+
hidden_states: torch.Tensor,
|
| 1358 |
+
router_logits: torch.Tensor,
|
| 1359 |
+
is_prefill: bool,
|
| 1360 |
+
enable_force_load_balance: bool = False,
|
| 1361 |
+
top_k: Optional[int] = None,
|
| 1362 |
+
shared_experts: Optional[Any] = None,
|
| 1363 |
+
gate=None,
|
| 1364 |
+
replace_allreduce: bool = False):
|
| 1365 |
+
|
| 1366 |
+
assert self.quant_method is not None
|
| 1367 |
+
|
| 1368 |
+
if top_k:
|
| 1369 |
+
real_top_k = top_k
|
| 1370 |
+
else:
|
| 1371 |
+
real_top_k = self.top_k
|
| 1372 |
+
|
| 1373 |
+
num_tokens, hidden_size = hidden_states.shape
|
| 1374 |
+
is_deepseek_v3_r1 = self.global_num_experts == 256
|
| 1375 |
+
|
| 1376 |
+
fused_moe_state = get_fused_moe_state(self.moe_parallel_config.ep_size,
|
| 1377 |
+
is_prefill, is_deepseek_v3_r1)
|
| 1378 |
+
if shared_experts:
|
| 1379 |
+
if not self.enable_multistream_moe or fused_moe_state != FusedMoEState.MC2:
|
| 1380 |
+
# When all_reduce_merge is in progress, shared_experts does not do all_reduce in mlp, but waits until shared_experts+router_experts are completed before doing all_reduce
|
| 1381 |
+
shared_hidden_states = shared_experts(hidden_states)
|
| 1382 |
+
|
| 1383 |
+
tp_size = get_tensor_model_parallel_world_size()
|
| 1384 |
+
if (tp_size > 1 and fused_moe_state not in [
|
| 1385 |
+
FusedMoEState.AllGather, FusedMoEState.AllGatherEP,
|
| 1386 |
+
FusedMoEState.NaiveMulticast
|
| 1387 |
+
] and not replace_allreduce):
|
| 1388 |
+
if num_tokens < tp_size:
|
| 1389 |
+
hidden_states = nn.functional.pad(
|
| 1390 |
+
hidden_states, (0, 0, 0, tp_size - num_tokens))
|
| 1391 |
+
router_logits = nn.functional.pad(
|
| 1392 |
+
router_logits, (0, 0, 0, tp_size - num_tokens))
|
| 1393 |
+
chunk_hidden_states = torch.tensor_split(hidden_states,
|
| 1394 |
+
tp_size,
|
| 1395 |
+
dim=0)
|
| 1396 |
+
chunk_router_logits = torch.tensor_split(router_logits,
|
| 1397 |
+
tp_size,
|
| 1398 |
+
dim=0)
|
| 1399 |
+
tp_rank = get_tensor_model_parallel_rank()
|
| 1400 |
+
hidden_states = chunk_hidden_states[tp_rank]
|
| 1401 |
+
router_logits = chunk_router_logits[tp_rank]
|
| 1402 |
+
|
| 1403 |
+
if self.dp_size > 1:
|
| 1404 |
+
if fused_moe_state in (FusedMoEState.AllGather, FusedMoEState.AllGatherEP):
|
| 1405 |
+
# NOTE: When in torchair graph, it has been padded in model_runner_v1
|
| 1406 |
+
if not self.torchair_graph_enabled or is_prefill:
|
| 1407 |
+
attn_metadata = get_forward_context().attn_metadata
|
| 1408 |
+
if attn_metadata is not None:
|
| 1409 |
+
max_num_tokens_across_dp = attn_metadata.max_num_tokens_across_dp
|
| 1410 |
+
if num_tokens < max_num_tokens_across_dp:
|
| 1411 |
+
hidden_states = nn.functional.pad(
|
| 1412 |
+
hidden_states,
|
| 1413 |
+
(0, 0, 0,
|
| 1414 |
+
max_num_tokens_across_dp - num_tokens))
|
| 1415 |
+
if not self.rm_router_logits:
|
| 1416 |
+
router_logits = nn.functional.pad(
|
| 1417 |
+
router_logits,
|
| 1418 |
+
(0, 0, 0,
|
| 1419 |
+
max_num_tokens_across_dp - num_tokens))
|
| 1420 |
+
hidden_states = get_dp_group().all_gather(hidden_states, 0)
|
| 1421 |
+
if self.rm_router_logits:
|
| 1422 |
+
router_logits, _ = gate(hidden_states.float())
|
| 1423 |
+
else:
|
| 1424 |
+
router_logits = get_dp_group().all_gather(router_logits, 0)
|
| 1425 |
+
|
| 1426 |
+
elif fused_moe_state == FusedMoEState.NaiveMulticast:
|
| 1427 |
+
cu_tokens_across_dp_cpu = get_forward_context(
|
| 1428 |
+
).dp_metadata.cu_tokens_across_dp_cpu
|
| 1429 |
+
hidden_states = self.naive_multicast(hidden_states,
|
| 1430 |
+
cu_tokens_across_dp_cpu)
|
| 1431 |
+
if self.rm_router_logits:
|
| 1432 |
+
router_logits, _ = gate(hidden_states.float())
|
| 1433 |
+
else:
|
| 1434 |
+
router_logits = self.naive_multicast(
|
| 1435 |
+
router_logits, cu_tokens_across_dp_cpu)
|
| 1436 |
+
|
| 1437 |
+
# Matrix multiply.
|
| 1438 |
+
e_hidden_states = self.quant_method.apply(
|
| 1439 |
+
layer=self,
|
| 1440 |
+
x=hidden_states,
|
| 1441 |
+
router_logits=router_logits,
|
| 1442 |
+
top_k=real_top_k,
|
| 1443 |
+
renormalize=self.renormalize,
|
| 1444 |
+
use_grouped_topk=self.use_grouped_topk,
|
| 1445 |
+
global_num_experts=self.global_num_experts,
|
| 1446 |
+
expert_map=self.expert_map,
|
| 1447 |
+
topk_group=self.topk_group,
|
| 1448 |
+
num_expert_group=self.num_expert_group,
|
| 1449 |
+
custom_routing_function=self.custom_routing_function,
|
| 1450 |
+
scoring_func=self.scoring_func,
|
| 1451 |
+
e_score_correction_bias=self.e_score_correction_bias,
|
| 1452 |
+
is_prefill=is_prefill,
|
| 1453 |
+
enable_force_load_balance=enable_force_load_balance,
|
| 1454 |
+
log2phy=self.log2phy,
|
| 1455 |
+
global_redundant_expert_num=self.global_redundant_expert_num,
|
| 1456 |
+
shared_experts=shared_experts if self.torchair_graph_enabled
|
| 1457 |
+
and self.enable_multistream_moe and not is_prefill else None,
|
| 1458 |
+
)
|
| 1459 |
+
|
| 1460 |
+
if shared_experts:
|
| 1461 |
+
if isinstance(e_hidden_states, tuple):
|
| 1462 |
+
e_hidden_states, shared_hidden_states = e_hidden_states
|
| 1463 |
+
|
| 1464 |
+
if (tp_size > 1 and fused_moe_state not in [
|
| 1465 |
+
FusedMoEState.AllGather, FusedMoEState.AllGatherEP,
|
| 1466 |
+
FusedMoEState.NaiveMulticast
|
| 1467 |
+
] and not replace_allreduce):
|
| 1468 |
+
dist.all_gather(list(chunk_hidden_states), e_hidden_states,
|
| 1469 |
+
self.tp_group)
|
| 1470 |
+
final_hidden_states = torch.cat(chunk_hidden_states, dim=0)
|
| 1471 |
+
if num_tokens < tp_size:
|
| 1472 |
+
final_hidden_states = final_hidden_states[:num_tokens]
|
| 1473 |
+
dispose_tensor(e_hidden_states)
|
| 1474 |
+
elif self.dp_size > 1:
|
| 1475 |
+
if fused_moe_state == FusedMoEState.NaiveMulticast:
|
| 1476 |
+
start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
|
| 1477 |
+
self.dp_rank - 1]
|
| 1478 |
+
end = cu_tokens_across_dp_cpu[self.dp_rank]
|
| 1479 |
+
final_hidden_states = get_dp_group().all_reduce(
|
| 1480 |
+
e_hidden_states)
|
| 1481 |
+
final_hidden_states = final_hidden_states[start:end, :]
|
| 1482 |
+
dispose_tensor(e_hidden_states)
|
| 1483 |
+
elif fused_moe_state in (FusedMoEState.AllGather, FusedMoEState.AllGatherEP):
|
| 1484 |
+
final_hidden_states = data_parallel_reduce_scatter(
|
| 1485 |
+
e_hidden_states, dim=0)
|
| 1486 |
+
final_hidden_states = final_hidden_states[:num_tokens]
|
| 1487 |
+
dispose_tensor(e_hidden_states)
|
| 1488 |
+
else:
|
| 1489 |
+
final_hidden_states = e_hidden_states
|
| 1490 |
+
|
| 1491 |
+
if tp_size > 1 and not self.all_reduce_merge and fused_moe_state in [
|
| 1492 |
+
FusedMoEState.AllGather, FusedMoEState.AllGatherEP,
|
| 1493 |
+
FusedMoEState.NaiveMulticast
|
| 1494 |
+
]:
|
| 1495 |
+
final_hidden_states = tensor_model_parallel_all_reduce(
|
| 1496 |
+
final_hidden_states)
|
| 1497 |
+
|
| 1498 |
+
if shared_experts:
|
| 1499 |
+
return final_hidden_states, shared_hidden_states
|
| 1500 |
+
else:
|
| 1501 |
+
return final_hidden_states
|
| 1502 |
+
|
| 1503 |
+
# ----------------------------------------- TBO-related --------------------------------------------
|
| 1504 |
+
|
| 1505 |
+
def _forward_ms_fused_moe_comp(
|
| 1506 |
+
self,
|
| 1507 |
+
hidden_states: torch.Tensor,
|
| 1508 |
+
router_logits: torch.Tensor,
|
| 1509 |
+
is_prefill: bool,
|
| 1510 |
+
real_top_k,
|
| 1511 |
+
enable_force_load_balance: bool = False,
|
| 1512 |
+
):
|
| 1513 |
+
hidden_states = self.quant_method.apply(
|
| 1514 |
+
layer=self,
|
| 1515 |
+
x=hidden_states,
|
| 1516 |
+
router_logits=router_logits,
|
| 1517 |
+
top_k=real_top_k,
|
| 1518 |
+
renormalize=self.renormalize,
|
| 1519 |
+
use_grouped_topk=self.use_grouped_topk,
|
| 1520 |
+
global_num_experts=self.global_num_experts,
|
| 1521 |
+
expert_map=self.expert_map,
|
| 1522 |
+
topk_group=self.topk_group,
|
| 1523 |
+
num_expert_group=self.num_expert_group,
|
| 1524 |
+
custom_routing_function=self.custom_routing_function,
|
| 1525 |
+
scoring_func=self.scoring_func,
|
| 1526 |
+
e_score_correction_bias=self.e_score_correction_bias,
|
| 1527 |
+
is_prefill=is_prefill,
|
| 1528 |
+
enable_force_load_balance=enable_force_load_balance)
|
| 1529 |
+
|
| 1530 |
+
return hidden_states
|
inference/vllm_ascend/patch/worker/patch_common/__init__.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#
|
| 2 |
+
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
| 3 |
+
# This file is a part of the vllm-ascend project.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
#
|
| 17 |
+
|
| 18 |
+
# patch_utils should be the first import, because it will be used by other
|
| 19 |
+
# patch files.
|
| 20 |
+
import vllm_ascend.patch.worker.patch_common.patch_utils # noqa isort:skip
|
| 21 |
+
import vllm_ascend.patch.worker.patch_common.patch_distributed # noqa
|
| 22 |
+
import vllm_ascend.patch.worker.patch_common.patch_minicpm # noqa
|
| 23 |
+
import vllm_ascend.patch.worker.patch_common.patch_multi_step_worker # noqa
|
| 24 |
+
import vllm_ascend.patch.worker.patch_common.patch_sampler # noqa
|
| 25 |
+
import vllm_ascend.patch.worker.patch_common.patch_spec_decode_worker # noqa
|
| 26 |
+
import vllm_ascend.patch.worker.patch_common.patch_config # noqa
|
| 27 |
+
import vllm_ascend.patch.worker.patch_common.patch_parsers # noqa
|
inference/vllm_ascend/patch/worker/patch_common/patch_config.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#
|
| 2 |
+
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
| 3 |
+
# This file is a part of the vllm-ascend project.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
#
|
| 17 |
+
from vllm.config import ModelConfig
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def get_attr_by_names(src_config, attrs, default_value):
|
| 21 |
+
for attr in attrs:
|
| 22 |
+
value = getattr(src_config, attr, 0)
|
| 23 |
+
if value > 0:
|
| 24 |
+
return value
|
| 25 |
+
return default_value
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def _verify_with_expert_parallelism(self) -> None:
|
| 29 |
+
num_expert_names = [
|
| 30 |
+
"moe_num_experts", # Dbrx
|
| 31 |
+
"num_experts", # Jamba
|
| 32 |
+
"n_routed_experts", # DeepSeek
|
| 33 |
+
"num_local_experts", # Mixtral
|
| 34 |
+
"num_routed_experts", # Pangu
|
| 35 |
+
]
|
| 36 |
+
num_experts = 0
|
| 37 |
+
for name in num_expert_names:
|
| 38 |
+
num_experts = getattr(self.hf_text_config, name, 0)
|
| 39 |
+
if num_experts > 0:
|
| 40 |
+
break
|
| 41 |
+
if num_experts < 1:
|
| 42 |
+
raise ValueError(
|
| 43 |
+
"Number of experts in the model must be greater than 0 "
|
| 44 |
+
"when expert parallelism is enabled.")
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
@property
|
| 48 |
+
def is_deepseek_mla(self) -> bool:
|
| 49 |
+
kv_lora_dim_names = ['attention_kv_lora_dim', 'kv_lora_rank']
|
| 50 |
+
kv_lora_dim = get_attr_by_names(self.hf_text_config, kv_lora_dim_names, None)
|
| 51 |
+
if not hasattr(self.hf_text_config, "model_type"):
|
| 52 |
+
return False
|
| 53 |
+
elif self.hf_text_config.model_type in \
|
| 54 |
+
('deepseek_v2', 'deepseek_v3', 'deepseek_mtp', 'pangu_ultra_moe'):
|
| 55 |
+
return kv_lora_dim is not None
|
| 56 |
+
elif self.hf_text_config.model_type == 'eagle':
|
| 57 |
+
# if the model is an EAGLE module, check for the
|
| 58 |
+
# underlying architecture
|
| 59 |
+
return self.hf_text_config.model.model_type in \
|
| 60 |
+
('deepseek_v2', 'deepseek_v3', 'pangu_ultra_moe') \
|
| 61 |
+
and kv_lora_dim is not None
|
| 62 |
+
return False
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def get_head_size(self) -> int:
|
| 66 |
+
if self.is_deepseek_mla:
|
| 67 |
+
qk_rope_dim_names = ['attention_qk_rope_dim', 'qk_rope_head_dim']
|
| 68 |
+
kv_lora_dim_names = ['attention_kv_lora_dim', 'kv_lora_rank']
|
| 69 |
+
qk_rope_dim = get_attr_by_names(self.hf_text_config, qk_rope_dim_names, 0)
|
| 70 |
+
kv_lora_dim = get_attr_by_names(self.hf_text_config, kv_lora_dim_names, 0)
|
| 71 |
+
if self.use_mla:
|
| 72 |
+
return kv_lora_dim + qk_rope_dim
|
| 73 |
+
else:
|
| 74 |
+
qk_dim_names = ['attention_qk_dim', 'qk_nope_head_dim']
|
| 75 |
+
qk_dim = get_attr_by_names(self.hf_text_config, qk_dim_names, 0)
|
| 76 |
+
if qk_rope_dim and qk_dim:
|
| 77 |
+
return qk_rope_dim + qk_dim
|
| 78 |
+
if hasattr(self.hf_text_config,
|
| 79 |
+
"model_type") and (self.hf_text_config.model_type
|
| 80 |
+
== "zamba2"):
|
| 81 |
+
return self.hf_text_config.attention_head_dim
|
| 82 |
+
|
| 83 |
+
if self.is_attention_free:
|
| 84 |
+
return 0
|
| 85 |
+
|
| 86 |
+
# NOTE: Some configs may set head_dim=None in the config
|
| 87 |
+
if getattr(self.hf_text_config, "head_dim", None) is not None:
|
| 88 |
+
return self.hf_text_config.head_dim
|
| 89 |
+
|
| 90 |
+
# FIXME(woosuk): This may not be true for all models.
|
| 91 |
+
return (self.hf_text_config.hidden_size //
|
| 92 |
+
self.hf_text_config.num_attention_heads)
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
ModelConfig._verify_with_expert_parallelism = _verify_with_expert_parallelism
|
| 96 |
+
ModelConfig.is_deepseek_mla = is_deepseek_mla
|
| 97 |
+
ModelConfig.get_head_size = get_head_size
|
inference/vllm_ascend/patch/worker/patch_common/patch_parsers.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#
|
| 2 |
+
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
| 3 |
+
# This file is a part of the vllm-ascend project.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
#
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
from vllm.entrypoints.openai import tool_parsers
|
| 20 |
+
from vllm_ascend.entrypoints.openai.tool_parsers import PanguToolParser
|
| 21 |
+
tool_parsers.__all__.append("PanguToolParser")
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
from vllm import reasoning
|
| 25 |
+
from vllm_ascend.entrypoints.openai.reasoning_parsers import PanguReasoningParser
|
| 26 |
+
reasoning.__all__.append("PanguReasoningParser")
|
inference/vllm_ascend/patch/worker/patch_common/patch_sampler.py
ADDED
|
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#
|
| 2 |
+
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
| 3 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
# This file is a part of the vllm-ascend project.
|
| 5 |
+
#
|
| 6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 7 |
+
# you may not use this file except in compliance with the License.
|
| 8 |
+
# You may obtain a copy of the License at
|
| 9 |
+
#
|
| 10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 11 |
+
#
|
| 12 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 15 |
+
# See the License for the specific language governing permissions and
|
| 16 |
+
# limitations under the License.
|
| 17 |
+
#
|
| 18 |
+
|
| 19 |
+
from typing import Optional
|
| 20 |
+
|
| 21 |
+
import torch
|
| 22 |
+
import torch_npu
|
| 23 |
+
from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler, random_sample
|
| 24 |
+
from vllm.v1.sample.sampler import Sampler, _SAMPLING_EPS
|
| 25 |
+
from vllm.v1.sample.metadata import SamplingMetadata
|
| 26 |
+
from vllm_ascend import envs
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def apply_top_k_top_p(
|
| 30 |
+
logits: torch.Tensor,
|
| 31 |
+
k: torch.Tensor,
|
| 32 |
+
p: torch.Tensor,
|
| 33 |
+
) -> torch.Tensor:
|
| 34 |
+
if p is not None and k is not None:
|
| 35 |
+
# npu_top_k_top_p's parameter order is (logits, p, k), not (logits, k, p)
|
| 36 |
+
return torch_npu.npu_top_k_top_p(logits, p, k)
|
| 37 |
+
|
| 38 |
+
probs = logits.softmax(dim=-1)
|
| 39 |
+
probs_sort, _ = probs.sort(dim=-1, descending=False)
|
| 40 |
+
|
| 41 |
+
if k is not None:
|
| 42 |
+
top_k_count = probs_sort.size(1) - k.to(torch.long) # shape: (batch, )
|
| 43 |
+
top_k_count = top_k_count.unsqueeze(dim=1)
|
| 44 |
+
top_k_cutoff = probs_sort.gather(-1, top_k_count)
|
| 45 |
+
|
| 46 |
+
# Make sure the no top-k rows are no-op.
|
| 47 |
+
no_top_k_mask = (k == logits.shape[1]).unsqueeze(dim=1)
|
| 48 |
+
top_k_cutoff.masked_fill_(no_top_k_mask, -float("inf"))
|
| 49 |
+
|
| 50 |
+
elements_to_discard = probs < top_k_cutoff
|
| 51 |
+
logits.masked_fill_(elements_to_discard, -float("inf"))
|
| 52 |
+
|
| 53 |
+
if p is not None:
|
| 54 |
+
cumprob = torch.cumsum(probs_sort, dim=-1)
|
| 55 |
+
top_p_mask = cumprob <= 1 - p.unsqueeze(dim=1)
|
| 56 |
+
top_p_mask[:, -1] = False # at least one
|
| 57 |
+
|
| 58 |
+
top_p_count = top_p_mask.sum(dim=-1).unsqueeze(1)
|
| 59 |
+
top_p_cutoff = probs_sort.gather(-1, top_p_count)
|
| 60 |
+
elements_to_discard = probs < top_p_cutoff
|
| 61 |
+
logits.masked_fill_(elements_to_discard, -float("inf"))
|
| 62 |
+
|
| 63 |
+
return logits
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def topk_topp_forward_native(
|
| 67 |
+
self,
|
| 68 |
+
logits: torch.Tensor,
|
| 69 |
+
generators: dict[int, torch.Generator],
|
| 70 |
+
k: Optional[torch.Tensor],
|
| 71 |
+
p: Optional[torch.Tensor],
|
| 72 |
+
) -> torch.Tensor:
|
| 73 |
+
"""
|
| 74 |
+
PyTorch-native implementation of top-k and top-p sampling.
|
| 75 |
+
|
| 76 |
+
The logits tensor may be updated in-place.
|
| 77 |
+
"""
|
| 78 |
+
logits = apply_top_k_top_p(logits, k, p)
|
| 79 |
+
probs = logits.softmax(dim=-1, dtype=torch.float32)
|
| 80 |
+
return random_sample(probs, generators)
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def apply_top_n_sigma(
|
| 84 |
+
logits: torch.Tensor,
|
| 85 |
+
sampling_metadata: SamplingMetadata,
|
| 86 |
+
):
|
| 87 |
+
if sampling_metadata.no_top_n_sigma:
|
| 88 |
+
return logits
|
| 89 |
+
|
| 90 |
+
top_n_sigma = sampling_metadata.top_n_sigma[:, None]
|
| 91 |
+
top_n_sigma_mask = (top_n_sigma != -1)
|
| 92 |
+
filter_value = -3.4028e+38
|
| 93 |
+
max_vals, _ = logits.max(dim=-1, keepdim=True)
|
| 94 |
+
std_vals = logits.std(dim=-1, keepdim=True)
|
| 95 |
+
threshold = max_vals - top_n_sigma * std_vals
|
| 96 |
+
threshold[~top_n_sigma_mask] = filter_value
|
| 97 |
+
mask = (logits < threshold)
|
| 98 |
+
logits = torch.where(mask, filter_value, logits)
|
| 99 |
+
return logits
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def sample(
|
| 103 |
+
self,
|
| 104 |
+
logits: torch.Tensor,
|
| 105 |
+
sampling_metadata: SamplingMetadata,
|
| 106 |
+
) -> torch.Tensor:
|
| 107 |
+
"""Sample logits based on sampling metadata.
|
| 108 |
+
|
| 109 |
+
The various logits processing functions called in this method
|
| 110 |
+
may update the logits tensor in-place.
|
| 111 |
+
"""
|
| 112 |
+
|
| 113 |
+
assert not (sampling_metadata.all_greedy
|
| 114 |
+
and sampling_metadata.all_random)
|
| 115 |
+
if sampling_metadata.all_random:
|
| 116 |
+
greedy_sampled = None
|
| 117 |
+
else:
|
| 118 |
+
greedy_sampled = self.greedy_sample(logits)
|
| 119 |
+
if sampling_metadata.all_greedy:
|
| 120 |
+
return greedy_sampled
|
| 121 |
+
|
| 122 |
+
assert sampling_metadata.temperature is not None
|
| 123 |
+
|
| 124 |
+
# Apply temperature.
|
| 125 |
+
logits = self.apply_temperature(logits, sampling_metadata.temperature)
|
| 126 |
+
|
| 127 |
+
# Apply logits processors that only apply to random sampling
|
| 128 |
+
# (argmax invariant)
|
| 129 |
+
for processor in sampling_metadata.logitsprocs.argmax_invariant:
|
| 130 |
+
logits = processor.apply(logits)
|
| 131 |
+
|
| 132 |
+
# Apply top_n_sigma
|
| 133 |
+
logits = apply_top_n_sigma(logits, sampling_metadata)
|
| 134 |
+
|
| 135 |
+
# Apply top_k and/or top_p.
|
| 136 |
+
random_sampled = self.topk_topp_sampler(
|
| 137 |
+
logits,
|
| 138 |
+
sampling_metadata.generators,
|
| 139 |
+
sampling_metadata.top_k,
|
| 140 |
+
sampling_metadata.top_p,
|
| 141 |
+
)
|
| 142 |
+
|
| 143 |
+
if greedy_sampled is None:
|
| 144 |
+
return random_sampled
|
| 145 |
+
|
| 146 |
+
sampled = torch.where(
|
| 147 |
+
sampling_metadata.temperature < _SAMPLING_EPS,
|
| 148 |
+
greedy_sampled,
|
| 149 |
+
random_sampled,
|
| 150 |
+
out=greedy_sampled, # Reuse tensor
|
| 151 |
+
)
|
| 152 |
+
return sampled
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
if envs.VLLM_ASCEND_ENABLE_TOPK_TOPP_OPTIMIZATION:
|
| 156 |
+
TopKTopPSampler.forward_native = topk_topp_forward_native
|
| 157 |
+
|
| 158 |
+
if envs.VLLM_ASCEND_ENABLE_TOP_N_SIGMA:
|
| 159 |
+
Sampler.sample = sample
|
inference/vllm_ascend/quantization/w8a8.py
ADDED
|
@@ -0,0 +1,757 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#
|
| 2 |
+
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
| 3 |
+
# This file is a part of the vllm-ascend project.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
#
|
| 17 |
+
|
| 18 |
+
from typing import Any, Callable, Dict, Optional
|
| 19 |
+
|
| 20 |
+
import torch
|
| 21 |
+
import torch_npu
|
| 22 |
+
from vllm.attention.backends.abstract import AttentionType
|
| 23 |
+
|
| 24 |
+
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
| 25 |
+
from vllm_ascend.distributed.parallel_state import get_ep_group
|
| 26 |
+
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def quant_per_tensor(in_tensor: torch.Tensor,
|
| 30 |
+
input_scale: torch.Tensor,
|
| 31 |
+
input_offset: torch.Tensor,
|
| 32 |
+
function=False):
|
| 33 |
+
return torch_npu.npu_quantize(in_tensor, input_scale, input_offset,
|
| 34 |
+
torch.qint8, -1, function)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class AscendW8A8LinearMethod:
|
| 38 |
+
"""Linear method for Ascend W8A8.
|
| 39 |
+
|
| 40 |
+
Args:
|
| 41 |
+
w_sym: whether the linear weight is symmetrically quantized.
|
| 42 |
+
"""
|
| 43 |
+
|
| 44 |
+
def __init__(self) -> None:
|
| 45 |
+
# aclnn quant matmul requires to transpose matrix B, set to true by default.
|
| 46 |
+
self.transpose_weight = not is_310p()
|
| 47 |
+
|
| 48 |
+
@staticmethod
|
| 49 |
+
def get_weight(
|
| 50 |
+
input_size: int,
|
| 51 |
+
output_size: int,
|
| 52 |
+
params_dtype: torch.dtype = torch.bfloat16,
|
| 53 |
+
) -> Dict[str, Any]:
|
| 54 |
+
params_dict = {
|
| 55 |
+
"weight": torch.empty(output_size, input_size, dtype=torch.int8)
|
| 56 |
+
}
|
| 57 |
+
return params_dict
|
| 58 |
+
|
| 59 |
+
@staticmethod
|
| 60 |
+
def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]:
|
| 61 |
+
params_dict = {}
|
| 62 |
+
params_dict["input_scale"] = torch.empty(1, dtype=params_dtype)
|
| 63 |
+
params_dict["input_offset"] = torch.empty(1, dtype=torch.int8)
|
| 64 |
+
return params_dict
|
| 65 |
+
|
| 66 |
+
@staticmethod
|
| 67 |
+
def get_perchannel_param(
|
| 68 |
+
output_size: int,
|
| 69 |
+
params_dtype: torch.dtype,
|
| 70 |
+
) -> Dict[str, Any]:
|
| 71 |
+
params_dict = {}
|
| 72 |
+
params_dict["quant_bias"] = torch.empty(output_size, dtype=torch.int32)
|
| 73 |
+
if params_dtype == torch.bfloat16:
|
| 74 |
+
params_dict["deq_scale"] = torch.empty(output_size,
|
| 75 |
+
dtype=torch.float32)
|
| 76 |
+
elif params_dtype == torch.float16:
|
| 77 |
+
params_dict["deq_scale"] = torch.empty(output_size,
|
| 78 |
+
dtype=torch.int64)
|
| 79 |
+
params_dict["weight_scale"] = torch.empty(output_size,
|
| 80 |
+
1,
|
| 81 |
+
dtype=params_dtype)
|
| 82 |
+
params_dict["weight_offset"] = torch.empty(output_size,
|
| 83 |
+
1,
|
| 84 |
+
dtype=params_dtype)
|
| 85 |
+
return params_dict
|
| 86 |
+
|
| 87 |
+
@staticmethod
|
| 88 |
+
def apply(
|
| 89 |
+
layer: torch.nn.Module,
|
| 90 |
+
x: torch.Tensor,
|
| 91 |
+
bias: Optional[torch.Tensor] = None,
|
| 92 |
+
tp_rank: Optional[int] = 0,
|
| 93 |
+
) -> torch.Tensor:
|
| 94 |
+
original_dtype = x.dtype
|
| 95 |
+
if original_dtype != torch.int8:
|
| 96 |
+
x = quant_per_tensor(x, layer.aclnn_input_scale,
|
| 97 |
+
layer.aclnn_input_offset)
|
| 98 |
+
quant_bias = layer.quant_bias if tp_rank == 0 else None
|
| 99 |
+
if is_310p():
|
| 100 |
+
# On 300I Duo platform, we need transpose again if
|
| 101 |
+
# using nz. This transpose can be skipped in torchair.
|
| 102 |
+
output = torch_npu.npu_quant_matmul(
|
| 103 |
+
x,
|
| 104 |
+
layer.weight.data.transpose(1, 0),
|
| 105 |
+
layer.deq_scale,
|
| 106 |
+
bias=quant_bias,
|
| 107 |
+
output_dtype=original_dtype,
|
| 108 |
+
)
|
| 109 |
+
else:
|
| 110 |
+
output = torch_npu.npu_quant_matmul(
|
| 111 |
+
x,
|
| 112 |
+
layer.weight,
|
| 113 |
+
layer.deq_scale,
|
| 114 |
+
bias=quant_bias,
|
| 115 |
+
output_dtype=original_dtype,
|
| 116 |
+
)
|
| 117 |
+
return output
|
| 118 |
+
|
| 119 |
+
def process_weights_after_loading(self, layer):
|
| 120 |
+
expanding_factor = layer.weight.data.shape[1]
|
| 121 |
+
layer.aclnn_input_scale = 1 / torch.nn.Parameter(
|
| 122 |
+
layer.input_scale.data.repeat(expanding_factor),
|
| 123 |
+
requires_grad=False)
|
| 124 |
+
layer.aclnn_input_offset = torch.nn.Parameter(
|
| 125 |
+
layer.input_offset.data.repeat(expanding_factor),
|
| 126 |
+
requires_grad=False).to(layer.aclnn_input_scale.dtype)
|
| 127 |
+
if self.transpose_weight:
|
| 128 |
+
layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
|
| 129 |
+
layer.weight.data = torch_npu.npu_format_cast(layer.weight.data,
|
| 130 |
+
ACL_FORMAT_FRACTAL_NZ)
|
| 131 |
+
layer.weight_scale.data = torch.flatten(layer.weight_scale.data)
|
| 132 |
+
layer.weight_offset.data = torch.flatten(layer.weight_offset.data)
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
class AscendW8A8FusedMoEMethod:
|
| 136 |
+
"""FusedMoe method for Ascend W8A8.
|
| 137 |
+
"""
|
| 138 |
+
|
| 139 |
+
def __init__(self):
|
| 140 |
+
self.transpose_weight = True
|
| 141 |
+
|
| 142 |
+
@staticmethod
|
| 143 |
+
def get_weight(num_experts: int, intermediate_size_per_partition: int,
|
| 144 |
+
hidden_sizes: int,
|
| 145 |
+
params_dtype: torch.dtype) -> Dict[str, Any]:
|
| 146 |
+
param_dict = {}
|
| 147 |
+
param_dict["w13_weight"] = torch.empty(num_experts,
|
| 148 |
+
2 *
|
| 149 |
+
intermediate_size_per_partition,
|
| 150 |
+
hidden_sizes,
|
| 151 |
+
dtype=torch.int8,
|
| 152 |
+
requires_grad=False)
|
| 153 |
+
param_dict["w2_weight"] = torch.empty(num_experts,
|
| 154 |
+
hidden_sizes,
|
| 155 |
+
intermediate_size_per_partition,
|
| 156 |
+
dtype=torch.int8,
|
| 157 |
+
requires_grad=False)
|
| 158 |
+
return param_dict
|
| 159 |
+
|
| 160 |
+
@staticmethod
|
| 161 |
+
def get_dynamic_quant_param(num_experts: int,
|
| 162 |
+
intermediate_size_per_partition: int,
|
| 163 |
+
hidden_sizes: int,
|
| 164 |
+
params_dtype: torch.dtype) -> Dict[str, Any]:
|
| 165 |
+
param_dict = {}
|
| 166 |
+
param_dict["w13_weight_scale"] = torch.empty(
|
| 167 |
+
num_experts,
|
| 168 |
+
2 * intermediate_size_per_partition,
|
| 169 |
+
1,
|
| 170 |
+
dtype=torch.float32)
|
| 171 |
+
param_dict["w13_weight_offset"] = torch.empty(
|
| 172 |
+
num_experts,
|
| 173 |
+
2 * intermediate_size_per_partition,
|
| 174 |
+
1,
|
| 175 |
+
dtype=torch.float16)
|
| 176 |
+
param_dict["w2_weight_scale"] = torch.empty(num_experts,
|
| 177 |
+
hidden_sizes,
|
| 178 |
+
1,
|
| 179 |
+
dtype=torch.float32)
|
| 180 |
+
param_dict["w2_weight_offset"] = torch.empty(num_experts,
|
| 181 |
+
hidden_sizes,
|
| 182 |
+
1,
|
| 183 |
+
dtype=torch.float16)
|
| 184 |
+
param_dict["w2_deq_scale"] = torch.empty(num_experts,
|
| 185 |
+
hidden_sizes,
|
| 186 |
+
dtype=torch.float32)
|
| 187 |
+
param_dict["w13_deq_scale"] = torch.empty(
|
| 188 |
+
num_experts,
|
| 189 |
+
2 * intermediate_size_per_partition,
|
| 190 |
+
dtype=torch.float32)
|
| 191 |
+
param_dict["w2_input_scale"] = torch.empty(num_experts,
|
| 192 |
+
1,
|
| 193 |
+
dtype=torch.float32)
|
| 194 |
+
param_dict["w13_input_scale"] = torch.empty(num_experts,
|
| 195 |
+
1,
|
| 196 |
+
dtype=torch.float32)
|
| 197 |
+
param_dict["w2_input_offset"] = torch.empty(num_experts,
|
| 198 |
+
1,
|
| 199 |
+
dtype=torch.int8)
|
| 200 |
+
param_dict["w13_input_offset"] = torch.empty(num_experts,
|
| 201 |
+
1,
|
| 202 |
+
dtype=torch.int8)
|
| 203 |
+
param_dict["quant_bias"] = torch.empty(num_experts,
|
| 204 |
+
hidden_sizes,
|
| 205 |
+
dtype=torch.int32)
|
| 206 |
+
|
| 207 |
+
return param_dict
|
| 208 |
+
|
| 209 |
+
def apply(
|
| 210 |
+
self,
|
| 211 |
+
layer: torch.nn.Module,
|
| 212 |
+
x: torch.Tensor,
|
| 213 |
+
router_logits: torch.Tensor,
|
| 214 |
+
top_k: int,
|
| 215 |
+
renormalize: bool,
|
| 216 |
+
use_grouped_topk: bool = False,
|
| 217 |
+
global_num_experts: int = -1,
|
| 218 |
+
expert_map: Optional[torch.Tensor] = None,
|
| 219 |
+
topk_group: Optional[int] = None,
|
| 220 |
+
num_expert_group: Optional[int] = None,
|
| 221 |
+
custom_routing_function: Optional[Callable] = None,
|
| 222 |
+
scoring_func: str = "softmax",
|
| 223 |
+
e_score_correction_bias: Optional[torch.Tensor] = None,
|
| 224 |
+
is_prefill: bool = True,
|
| 225 |
+
enable_force_load_balance: bool = False,
|
| 226 |
+
log2phy: torch.Tensor = None,
|
| 227 |
+
global_redundant_expert_num: int = 0,
|
| 228 |
+
shared_experts: Optional[Any] = None,
|
| 229 |
+
**kwargs,
|
| 230 |
+
) -> torch.Tensor:
|
| 231 |
+
assert router_logits.shape[
|
| 232 |
+
1] == global_num_experts, "Number of global experts mismatch"
|
| 233 |
+
|
| 234 |
+
topk_weights, topk_ids = select_experts(
|
| 235 |
+
hidden_states=x,
|
| 236 |
+
router_logits=router_logits,
|
| 237 |
+
top_k=top_k,
|
| 238 |
+
use_grouped_topk=use_grouped_topk,
|
| 239 |
+
renormalize=renormalize,
|
| 240 |
+
topk_group=topk_group,
|
| 241 |
+
num_expert_group=num_expert_group,
|
| 242 |
+
custom_routing_function=custom_routing_function,
|
| 243 |
+
scoring_func=scoring_func,
|
| 244 |
+
e_score_correction_bias=e_score_correction_bias,
|
| 245 |
+
global_num_experts=global_num_experts,
|
| 246 |
+
)
|
| 247 |
+
|
| 248 |
+
if is_310p():
|
| 249 |
+
return fused_experts_310p(hidden_states=x,
|
| 250 |
+
w1=layer.w13_weight,
|
| 251 |
+
w1_scale=layer.w13_weight_scale,
|
| 252 |
+
w1_input_scale=layer.w13_input_scale,
|
| 253 |
+
w2=layer.w2_weight,
|
| 254 |
+
w2_scale=layer.w2_weight_scale,
|
| 255 |
+
w2_input_scale=layer.w2_input_scale,
|
| 256 |
+
topk_weights=topk_weights,
|
| 257 |
+
topk_ids=topk_ids,
|
| 258 |
+
top_k=top_k,
|
| 259 |
+
global_num_experts=global_num_experts,
|
| 260 |
+
expert_map=expert_map)
|
| 261 |
+
return fused_experts(hidden_states=x,
|
| 262 |
+
w1=layer.w13_weight,
|
| 263 |
+
w1_scale=layer.w13_weight_scale,
|
| 264 |
+
w1_input_scale=layer.w13_input_scale,
|
| 265 |
+
w1_input_offset=layer.w13_input_offset,
|
| 266 |
+
w2=layer.w2_weight,
|
| 267 |
+
w2_scale=layer.w2_weight_scale,
|
| 268 |
+
w2_input_scale=layer.w2_input_scale,
|
| 269 |
+
w2_input_offset=layer.w2_input_offset,
|
| 270 |
+
topk_weights=topk_weights,
|
| 271 |
+
topk_ids=topk_ids,
|
| 272 |
+
top_k=top_k,
|
| 273 |
+
global_num_experts=global_num_experts,
|
| 274 |
+
expert_map=expert_map)
|
| 275 |
+
|
| 276 |
+
def process_weights_after_loading(self, layer):
|
| 277 |
+
if not is_310p():
|
| 278 |
+
layer.w13_weight.data = layer.w13_weight.data.transpose(
|
| 279 |
+
1, 2).contiguous()
|
| 280 |
+
layer.w2_weight.data = layer.w2_weight.data.transpose(
|
| 281 |
+
1, 2).contiguous()
|
| 282 |
+
layer.w13_weight_scale.data = layer.w13_weight_scale.data.view(
|
| 283 |
+
layer.w13_weight_scale.data.shape[0], -1)
|
| 284 |
+
|
| 285 |
+
layer.w13_weight_offset.data = layer.w13_weight_offset.data.view(
|
| 286 |
+
layer.w13_weight_offset.data.shape[0], -1)
|
| 287 |
+
layer.w2_weight_scale.data = layer.w2_weight_scale.data.view(
|
| 288 |
+
layer.w2_weight_scale.data.shape[0], -1)
|
| 289 |
+
layer.w2_weight_offset.data = layer.w2_weight_offset.data.view(
|
| 290 |
+
layer.w2_weight_offset.data.shape[0], -1)
|
| 291 |
+
expanding_factor_w13 = layer.w13_weight.data.shape[1]
|
| 292 |
+
expanding_factor_w2 = layer.w2_weight.data.shape[1]
|
| 293 |
+
|
| 294 |
+
if is_310p():
|
| 295 |
+
layer.w13_input_scale.data = torch.nn.Parameter(
|
| 296 |
+
layer.w13_input_scale.data.max())
|
| 297 |
+
layer.w2_input_scale.data = torch.nn.Parameter(
|
| 298 |
+
layer.w2_input_scale.data.max())
|
| 299 |
+
else:
|
| 300 |
+
layer.w13_input_scale.data = torch.nn.Parameter(
|
| 301 |
+
layer.w13_input_scale.data.repeat(1,
|
| 302 |
+
expanding_factor_w13)[0:1])
|
| 303 |
+
layer.w2_input_scale.data = torch.nn.Parameter(
|
| 304 |
+
layer.w2_input_scale.data.repeat(1, expanding_factor_w2)[0:1])
|
| 305 |
+
|
| 306 |
+
layer.w13_input_offset.data = torch.nn.Parameter(
|
| 307 |
+
layer.w13_input_scale.data.repeat(1, expanding_factor_w13)[0:1])
|
| 308 |
+
layer.w2_input_offset.data = torch.nn.Parameter(
|
| 309 |
+
layer.w2_input_scale.data.repeat(1, expanding_factor_w2)[0:1])
|
| 310 |
+
|
| 311 |
+
# converting ACL_FORMAT_FRACTAL_NZ.
|
| 312 |
+
# npu_quant_grouped_matmul_dequant in eager mode does not accept
|
| 313 |
+
# ACL_FORMAT_FRACTAL_NZ.
|
| 314 |
+
if not is_310p():
|
| 315 |
+
layer.w13_weight.data = torch_npu.npu_format_cast(
|
| 316 |
+
layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ).contiguous()
|
| 317 |
+
layer.w2_weight.data = torch_npu.npu_format_cast(
|
| 318 |
+
layer.w2_weight.data, ACL_FORMAT_FRACTAL_NZ).contiguous()
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
class AscendC8KVCacheMethod:
|
| 322 |
+
|
| 323 |
+
def __init__(self) -> None:
|
| 324 |
+
self.antiquant_scale_comb = None
|
| 325 |
+
|
| 326 |
+
@staticmethod
|
| 327 |
+
def create_weights(layer) -> None:
|
| 328 |
+
param_dict = {} # num_kv_heads * head_size
|
| 329 |
+
param_dict["key_antiquant_scale"] = torch.empty(layer.num_kv_heads *
|
| 330 |
+
layer.head_size,
|
| 331 |
+
dtype=torch.float16,
|
| 332 |
+
requires_grad=False)
|
| 333 |
+
param_dict["value_antiquant_scale"] = torch.empty(layer.num_kv_heads *
|
| 334 |
+
layer.head_size,
|
| 335 |
+
dtype=torch.float16,
|
| 336 |
+
requires_grad=False)
|
| 337 |
+
for weight_name, weight_param in param_dict.items():
|
| 338 |
+
param = torch.nn.Parameter(weight_param, requires_grad=False)
|
| 339 |
+
layer.register_parameter(weight_name, param)
|
| 340 |
+
|
| 341 |
+
def process_weights_after_loading(self, layer):
|
| 342 |
+
self.antiquant_scale_comb = torch.cat(
|
| 343 |
+
(layer.key_antiquant_scale.data.unsqueeze(0),
|
| 344 |
+
layer.value_antiquant_scale.data.unsqueeze(0)),
|
| 345 |
+
dim=0).to(torch.float16).contiguous()
|
| 346 |
+
|
| 347 |
+
def apply(self, layer, query, key, value, kv_cache, attn_metadata,
|
| 348 |
+
attn_type, scale, output) -> torch.Tensor:
|
| 349 |
+
num_tokens = query.shape[0]
|
| 350 |
+
if attn_metadata is None:
|
| 351 |
+
return output.view(num_tokens, layer.num_heads * layer.head_size)
|
| 352 |
+
assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
|
| 353 |
+
if attn_type != AttentionType.DECODER:
|
| 354 |
+
raise NotImplementedError("Encoder self-attention and "
|
| 355 |
+
"encoder/decoder cross-attention "
|
| 356 |
+
"are not implemented for "
|
| 357 |
+
"PallasAttentionBackendImpl")
|
| 358 |
+
|
| 359 |
+
# C8
|
| 360 |
+
quant_key = quant_per_tensor(
|
| 361 |
+
key.view(-1, layer.num_kv_heads * layer.head_size),
|
| 362 |
+
layer.key_antiquant_scale.data.view(-1), None, True)
|
| 363 |
+
quant_value = quant_per_tensor(
|
| 364 |
+
value.view(-1, layer.num_kv_heads * layer.head_size),
|
| 365 |
+
layer.value_antiquant_scale.data.view(-1), None, True)
|
| 366 |
+
|
| 367 |
+
# View q k v to BSH.
|
| 368 |
+
query = query.view(-1, layer.num_heads, layer.head_size)
|
| 369 |
+
key = key.view(-1, layer.num_kv_heads, layer.head_size)
|
| 370 |
+
value = value.view(-1, layer.num_kv_heads, layer.head_size)
|
| 371 |
+
# TODO: Remove this contiguous in the future.
|
| 372 |
+
value = value.contiguous()
|
| 373 |
+
|
| 374 |
+
if kv_cache[0].numel() > 0:
|
| 375 |
+
# if key_cache is None:
|
| 376 |
+
key_cache, value_cache = kv_cache[0], kv_cache[1]
|
| 377 |
+
slots = attn_metadata.slot_mapping
|
| 378 |
+
|
| 379 |
+
block_size = key_cache.shape[1]
|
| 380 |
+
slots_indices = slots.reshape(-1, 1)
|
| 381 |
+
block_indices = slots_indices // block_size
|
| 382 |
+
slots_indices = slots_indices % block_size
|
| 383 |
+
indices = torch.cat((block_indices, slots_indices), dim=1)
|
| 384 |
+
|
| 385 |
+
# C8
|
| 386 |
+
torch_npu.npu_scatter_nd_update_(key_cache, indices, quant_key)
|
| 387 |
+
torch_npu.npu_scatter_nd_update_(value_cache, indices, quant_value)
|
| 388 |
+
|
| 389 |
+
# V0-Style scheduler situation.
|
| 390 |
+
if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
|
| 391 |
+
assert attn_metadata is not None
|
| 392 |
+
assert attn_metadata.attn_mask is not None
|
| 393 |
+
mask = attn_metadata.attn_mask
|
| 394 |
+
torch_npu._npu_flash_attention(query=query,
|
| 395 |
+
key=key,
|
| 396 |
+
value=value,
|
| 397 |
+
mask=mask,
|
| 398 |
+
seq_len=attn_metadata.seq_lens,
|
| 399 |
+
scale_value=scale,
|
| 400 |
+
num_heads=layer.num_heads,
|
| 401 |
+
num_kv_heads=layer.num_kv_heads,
|
| 402 |
+
out=output.reshape(query.shape))
|
| 403 |
+
|
| 404 |
+
elif attn_metadata.attn_state == AscendAttentionState.PrefillCacheHit:
|
| 405 |
+
raise NotImplementedError("kv cache int8 are not "
|
| 406 |
+
"implemented for "
|
| 407 |
+
"PrefillCacheHit")
|
| 408 |
+
elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly: # changed attn_metadata.attn_state == AscendAttentionState.DecodeOnly
|
| 409 |
+
if hasattr(attn_metadata, "decode"):
|
| 410 |
+
# torch_air
|
| 411 |
+
decode_meta = attn_metadata.decode
|
| 412 |
+
seq_lens = decode_meta.seq_lens_list
|
| 413 |
+
else:
|
| 414 |
+
seq_lens = attn_metadata.seq_lens
|
| 415 |
+
block_size = key_cache.shape[1]
|
| 416 |
+
query = query.view(num_tokens, 1, layer.num_heads *
|
| 417 |
+
layer.head_size).contiguous() # changed
|
| 418 |
+
|
| 419 |
+
# [num_blocks, block_size, N, D] --> [num_blocks, N, block_size, D]
|
| 420 |
+
key = key_cache
|
| 421 |
+
value = value_cache
|
| 422 |
+
|
| 423 |
+
output = torch_npu.npu_incre_flash_attention(
|
| 424 |
+
query,
|
| 425 |
+
key,
|
| 426 |
+
value,
|
| 427 |
+
num_key_value_heads=layer.num_kv_heads,
|
| 428 |
+
num_heads=layer.num_heads,
|
| 429 |
+
actual_seq_lengths=seq_lens,
|
| 430 |
+
scale_value=scale,
|
| 431 |
+
input_layout='BSH',
|
| 432 |
+
block_size=block_size,
|
| 433 |
+
block_table=attn_metadata.block_tables,
|
| 434 |
+
antiquant_scale=self.antiquant_scale_comb,
|
| 435 |
+
)
|
| 436 |
+
|
| 437 |
+
# Normal V1 situation.
|
| 438 |
+
else:
|
| 439 |
+
raise NotImplementedError("kv cache int8 are not "
|
| 440 |
+
"implemented for "
|
| 441 |
+
"other case")
|
| 442 |
+
return output
|
| 443 |
+
|
| 444 |
+
|
| 445 |
+
def fused_experts_310p(
|
| 446 |
+
hidden_states: torch.Tensor,
|
| 447 |
+
w1: torch.Tensor,
|
| 448 |
+
w1_scale: torch.Tensor,
|
| 449 |
+
w1_input_scale: torch.Tensor,
|
| 450 |
+
w2: torch.Tensor,
|
| 451 |
+
w2_scale: torch.Tensor,
|
| 452 |
+
w2_input_scale: torch.Tensor,
|
| 453 |
+
topk_weights: torch.Tensor,
|
| 454 |
+
topk_ids: torch.Tensor,
|
| 455 |
+
top_k: int,
|
| 456 |
+
global_num_experts: int,
|
| 457 |
+
expert_map: torch.Tensor = None,
|
| 458 |
+
) -> torch.Tensor:
|
| 459 |
+
ep_size = get_ep_group().world_size
|
| 460 |
+
local_num_experts = global_num_experts // ep_size
|
| 461 |
+
local_num_group = top_k // ep_size
|
| 462 |
+
|
| 463 |
+
bsz, _ = hidden_states.shape
|
| 464 |
+
flatten_topk_ids = topk_ids.view(-1)
|
| 465 |
+
sorted_topk_ids = torch.argsort(flatten_topk_ids.float())
|
| 466 |
+
sorted_topk_ids = sorted_topk_ids.to(torch.int32)
|
| 467 |
+
sorted_hidden_states = hidden_states.index_select(
|
| 468 |
+
0, sorted_topk_ids // local_num_group)
|
| 469 |
+
|
| 470 |
+
experts_id = torch.arange(0,
|
| 471 |
+
local_num_experts,
|
| 472 |
+
dtype=topk_ids.dtype,
|
| 473 |
+
device=topk_ids.device)
|
| 474 |
+
num_tokens_per_expert = (flatten_topk_ids.unsqueeze(-1) == experts_id).to(
|
| 475 |
+
torch.float32).sum(0)
|
| 476 |
+
topk_scales = topk_weights.view(-1).index_select(
|
| 477 |
+
0, sorted_topk_ids).unsqueeze(-1)
|
| 478 |
+
group_list = num_tokens_per_expert.cumsum(dim=0).to(torch.int64)
|
| 479 |
+
|
| 480 |
+
gate_up_out = torch_npu.npu_quant_grouped_matmul_dequant(
|
| 481 |
+
x=sorted_hidden_states,
|
| 482 |
+
quantized_weight=w1,
|
| 483 |
+
weight_scale=w1_scale,
|
| 484 |
+
group_list=group_list,
|
| 485 |
+
x_scale=w1_input_scale,
|
| 486 |
+
quant_mode="pertensor")
|
| 487 |
+
|
| 488 |
+
gate_up_out = torch_npu.npu_swiglu(gate_up_out.to(torch.float32)).to(
|
| 489 |
+
torch.float16)
|
| 490 |
+
gate_up_out *= topk_scales
|
| 491 |
+
|
| 492 |
+
down_out = torch_npu.npu_quant_grouped_matmul_dequant(
|
| 493 |
+
x=gate_up_out,
|
| 494 |
+
quantized_weight=w2,
|
| 495 |
+
weight_scale=w2_scale,
|
| 496 |
+
group_list=group_list,
|
| 497 |
+
x_scale=w2_input_scale,
|
| 498 |
+
quant_mode="pertensor")
|
| 499 |
+
|
| 500 |
+
unsorted_topk_ids = torch.argsort(sorted_topk_ids.float()).to(torch.int32)
|
| 501 |
+
unsorted_hidden_states = down_out.index_select(0, unsorted_topk_ids)
|
| 502 |
+
final_hidden_states = unsorted_hidden_states.reshape(
|
| 503 |
+
bsz, top_k // ep_size, -1).sum(1)
|
| 504 |
+
|
| 505 |
+
return final_hidden_states
|
| 506 |
+
|
| 507 |
+
|
| 508 |
+
def fused_experts(
|
| 509 |
+
hidden_states: torch.Tensor,
|
| 510 |
+
w1: torch.Tensor,
|
| 511 |
+
w1_scale: torch.Tensor,
|
| 512 |
+
w1_input_scale: torch.Tensor,
|
| 513 |
+
w1_input_offset: torch.Tensor,
|
| 514 |
+
w2: torch.Tensor,
|
| 515 |
+
w2_scale: torch.Tensor,
|
| 516 |
+
w2_input_scale: torch.Tensor,
|
| 517 |
+
w2_input_offset: torch.Tensor,
|
| 518 |
+
topk_weights: torch.Tensor,
|
| 519 |
+
topk_ids: torch.Tensor,
|
| 520 |
+
top_k: int,
|
| 521 |
+
global_num_experts: int,
|
| 522 |
+
expert_map: torch.Tensor = None,
|
| 523 |
+
) -> torch.Tensor:
|
| 524 |
+
"""
|
| 525 |
+
Fused experts with top-k routing.
|
| 526 |
+
|
| 527 |
+
Args:
|
| 528 |
+
hidden_states: Hidden states of shape (num_tokens, hidden_size).
|
| 529 |
+
w1: Expert weights1 of shape (num_experts, intermediate_size * 2, hidden_size).
|
| 530 |
+
w2: Expert weights2 of shape (num_experts, hidden_size, intermediate_size).
|
| 531 |
+
topk_weights: Routing weights of shape (num_tokens, top_k).
|
| 532 |
+
topk_ids: Selected expert IDs of shape (num_tokens, top_k).
|
| 533 |
+
top_k: Number of experts to select.
|
| 534 |
+
expert_map: Expert mapping of shape (num_experts,).
|
| 535 |
+
|
| 536 |
+
Returns:
|
| 537 |
+
hidden_states: Hidden states after routing.
|
| 538 |
+
"""
|
| 539 |
+
"""
|
| 540 |
+
# Check constraints.
|
| 541 |
+
assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
|
| 542 |
+
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
|
| 543 |
+
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
|
| 544 |
+
assert w1.is_contiguous(), "Expert weights1 must be contiguous"
|
| 545 |
+
assert w2.is_contiguous(), "Expert weights2 must be contiguous"
|
| 546 |
+
"""
|
| 547 |
+
|
| 548 |
+
original_dtype = hidden_states.dtype
|
| 549 |
+
ep_size = get_ep_group().world_size
|
| 550 |
+
local_num_experts = global_num_experts // ep_size
|
| 551 |
+
w1_input_scale, _ = w1_input_scale.max(0)
|
| 552 |
+
quant_sorted_hidden_states = quant_per_tensor(
|
| 553 |
+
hidden_states,
|
| 554 |
+
w1_input_scale,
|
| 555 |
+
None,
|
| 556 |
+
True,
|
| 557 |
+
)
|
| 558 |
+
if expert_map is not None:
|
| 559 |
+
expanded_x, expanded_row_idx, expert_token_count, expanded_scale = torch_npu.npu_moe_init_routing_v2(
|
| 560 |
+
quant_sorted_hidden_states,
|
| 561 |
+
topk_ids,
|
| 562 |
+
scale=None,
|
| 563 |
+
active_num=topk_ids.numel(),
|
| 564 |
+
expert_capacity=-1,
|
| 565 |
+
expert_num=local_num_experts,
|
| 566 |
+
drop_pad_mode=0,
|
| 567 |
+
expert_tokens_num_type=1,
|
| 568 |
+
expert_tokens_num_flag=True,
|
| 569 |
+
quant_mode=-1,
|
| 570 |
+
active_expert_range=[0, local_num_experts],
|
| 571 |
+
row_idx_type=0,
|
| 572 |
+
)
|
| 573 |
+
|
| 574 |
+
else:
|
| 575 |
+
raise NotImplementedError(
|
| 576 |
+
"The quantified version of MOE class models "
|
| 577 |
+
"currently does not support tensor parallelism")
|
| 578 |
+
if expanded_x.dtype != w1.dtype:
|
| 579 |
+
w1_input_scale, _ = w1_input_scale.max(0)
|
| 580 |
+
quant_sorted_hidden_states = quant_per_tensor(
|
| 581 |
+
expanded_x,
|
| 582 |
+
w1_input_scale,
|
| 583 |
+
None,
|
| 584 |
+
True,
|
| 585 |
+
)
|
| 586 |
+
else:
|
| 587 |
+
quant_sorted_hidden_states = expanded_x
|
| 588 |
+
gate_up_out = torch_npu.npu_grouped_matmul(
|
| 589 |
+
x=[quant_sorted_hidden_states],
|
| 590 |
+
weight=[w1],
|
| 591 |
+
scale=[w1_scale * w1_input_scale[0]],
|
| 592 |
+
split_item=2,
|
| 593 |
+
group_list_type=1,
|
| 594 |
+
group_type=0,
|
| 595 |
+
group_list=expert_token_count,
|
| 596 |
+
output_dtype=original_dtype,
|
| 597 |
+
)[0]
|
| 598 |
+
gate_up_out = torch_npu.npu_swiglu(gate_up_out)
|
| 599 |
+
|
| 600 |
+
if gate_up_out.dtype != w2.dtype:
|
| 601 |
+
w2_input_scale, _ = w2_input_scale.max(0)
|
| 602 |
+
quant_gate_up_out = quant_per_tensor(
|
| 603 |
+
gate_up_out,
|
| 604 |
+
w2_input_scale,
|
| 605 |
+
None,
|
| 606 |
+
True,
|
| 607 |
+
)
|
| 608 |
+
else:
|
| 609 |
+
quant_gate_up_out = gate_up_out
|
| 610 |
+
|
| 611 |
+
down_out = torch_npu.npu_grouped_matmul(
|
| 612 |
+
x=[quant_gate_up_out],
|
| 613 |
+
weight=[w2],
|
| 614 |
+
scale=[w2_scale * w2_input_scale[0]],
|
| 615 |
+
split_item=2,
|
| 616 |
+
group_list_type=1,
|
| 617 |
+
group_type=0,
|
| 618 |
+
group_list=expert_token_count,
|
| 619 |
+
output_dtype=original_dtype,
|
| 620 |
+
)[0]
|
| 621 |
+
|
| 622 |
+
if expert_map is not None:
|
| 623 |
+
final_hidden_states = torch_npu.npu_moe_finalize_routing(
|
| 624 |
+
down_out,
|
| 625 |
+
skip1=None,
|
| 626 |
+
skip2=None,
|
| 627 |
+
bias=None,
|
| 628 |
+
scales=topk_weights.to(down_out.dtype),
|
| 629 |
+
expanded_src_to_dst_row=expanded_row_idx,
|
| 630 |
+
export_for_source_row=topk_ids,
|
| 631 |
+
drop_pad_mode=2,
|
| 632 |
+
)
|
| 633 |
+
else:
|
| 634 |
+
raise NotImplementedError(
|
| 635 |
+
"The quantified version of MOE class models "
|
| 636 |
+
"currently does not support tensor parallelism")
|
| 637 |
+
|
| 638 |
+
return final_hidden_states
|
| 639 |
+
|
| 640 |
+
|
| 641 |
+
def select_experts(
|
| 642 |
+
hidden_states: torch.Tensor,
|
| 643 |
+
router_logits: torch.Tensor,
|
| 644 |
+
top_k: int,
|
| 645 |
+
use_grouped_topk: bool,
|
| 646 |
+
renormalize: bool,
|
| 647 |
+
topk_group: Optional[int] = None,
|
| 648 |
+
num_expert_group: Optional[int] = None,
|
| 649 |
+
custom_routing_function: Optional[Callable] = None,
|
| 650 |
+
scoring_func: str = "softmax",
|
| 651 |
+
e_score_correction_bias: Optional[torch.Tensor] = None,
|
| 652 |
+
global_num_experts=-1,
|
| 653 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 654 |
+
"""
|
| 655 |
+
Select top-k experts based on router logits.
|
| 656 |
+
|
| 657 |
+
Args:
|
| 658 |
+
hidden_states: Hidden states of shape (num_tokens, hidden_size).
|
| 659 |
+
router_logits: Router logits of shape (num_tokens, num_experts).
|
| 660 |
+
top_k: Number of experts to select.
|
| 661 |
+
use_grouped_topk: Whether to group experts before selecting top-k.
|
| 662 |
+
renormalize: Whether to renormalize the routing weights.
|
| 663 |
+
topk_group: Number of expert groups to select from.
|
| 664 |
+
num_expert_group: Number of experts in each group.
|
| 665 |
+
custom_routing_function: Custom routing function.
|
| 666 |
+
scoring_func: Scoring function to use.
|
| 667 |
+
e_score_correction_bias: Correction bias to apply to expert scores.
|
| 668 |
+
|
| 669 |
+
Returns:
|
| 670 |
+
topk_weights: Routing weights of shape (num_tokens, top_k).
|
| 671 |
+
topk_ids: Selected expert IDs of shape (num_tokens, top_k).
|
| 672 |
+
|
| 673 |
+
Raises:
|
| 674 |
+
ValueError: If an unsupported scoring function is provided.
|
| 675 |
+
"""
|
| 676 |
+
|
| 677 |
+
if scoring_func == "softmax":
|
| 678 |
+
# NOTE: vLLM use dtype=torch.float here
|
| 679 |
+
topk_weights = router_logits.softmax(dim=-1)
|
| 680 |
+
elif scoring_func == "sigmoid":
|
| 681 |
+
topk_weights = router_logits.sigmoid()
|
| 682 |
+
else:
|
| 683 |
+
raise ValueError(f"Unsupported scoring function: {scoring_func}")
|
| 684 |
+
|
| 685 |
+
if use_grouped_topk:
|
| 686 |
+
assert topk_group is not None
|
| 687 |
+
assert num_expert_group is not None
|
| 688 |
+
|
| 689 |
+
if e_score_correction_bias is not None:
|
| 690 |
+
# Store original scores before applying correction bias. We use biased
|
| 691 |
+
# scores for expert selection but original scores for routing weights
|
| 692 |
+
original_weights = topk_weights
|
| 693 |
+
topk_weights = topk_weights + e_score_correction_bias.unsqueeze(0)
|
| 694 |
+
|
| 695 |
+
# TODO: Change to npu_group_topk when the latest CANN and NNAL is available
|
| 696 |
+
# >>> torch_npu._npu_group_topk(topk_weights, group_num=num_expert_group, k=topk_group)
|
| 697 |
+
topk_weights = native_grouped_topk(topk_weights, num_expert_group,
|
| 698 |
+
topk_group)
|
| 699 |
+
# TODO bfloat16 is not supported in torch.topk with ge graph.
|
| 700 |
+
if e_score_correction_bias is not None:
|
| 701 |
+
topk_ids = torch.topk(topk_weights.to(torch.float32),
|
| 702 |
+
k=top_k,
|
| 703 |
+
dim=-1,
|
| 704 |
+
sorted=False)[1]
|
| 705 |
+
# Use original unbiased scores for the routing weights
|
| 706 |
+
topk_weights = original_weights.gather(1, topk_ids)
|
| 707 |
+
else:
|
| 708 |
+
topk_weights, topk_ids = torch.topk(topk_weights.to(torch.float32),
|
| 709 |
+
k=top_k,
|
| 710 |
+
dim=-1,
|
| 711 |
+
sorted=False)
|
| 712 |
+
elif custom_routing_function is None:
|
| 713 |
+
topk_weights, topk_ids = topk_weights.topk(top_k, dim=-1)
|
| 714 |
+
else:
|
| 715 |
+
topk_weights, topk_ids = custom_routing_function(
|
| 716 |
+
hidden_states=hidden_states,
|
| 717 |
+
gating_output=router_logits,
|
| 718 |
+
topk=top_k,
|
| 719 |
+
renormalize=renormalize,
|
| 720 |
+
global_num_experts=global_num_experts,
|
| 721 |
+
)
|
| 722 |
+
# Required by npu_moe_init_routing
|
| 723 |
+
topk_ids = topk_ids.to(torch.int32)
|
| 724 |
+
return topk_weights, topk_ids
|
| 725 |
+
|
| 726 |
+
# Required by npu_moe_init_routing
|
| 727 |
+
topk_ids = topk_ids.to(torch.int32)
|
| 728 |
+
|
| 729 |
+
if renormalize:
|
| 730 |
+
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
|
| 731 |
+
|
| 732 |
+
return topk_weights, topk_ids
|
| 733 |
+
|
| 734 |
+
|
| 735 |
+
def native_grouped_topk(
|
| 736 |
+
topk_weights: torch.Tensor,
|
| 737 |
+
num_expert_group: Optional[int],
|
| 738 |
+
topk_group: Optional[int],
|
| 739 |
+
):
|
| 740 |
+
topk_group = 0 if topk_group is None else topk_group
|
| 741 |
+
num_expert_group = 0 if num_expert_group is None else num_expert_group
|
| 742 |
+
|
| 743 |
+
num_token = topk_weights.shape[0]
|
| 744 |
+
grouped_weights = topk_weights.view(num_token, num_expert_group,
|
| 745 |
+
-1).max(dim=-1).values
|
| 746 |
+
topk_group_indices = torch.topk(grouped_weights.to(torch.float32),
|
| 747 |
+
k=topk_group,
|
| 748 |
+
dim=-1,
|
| 749 |
+
sorted=False)[1]
|
| 750 |
+
topk_group_mask = torch.zeros_like(grouped_weights)
|
| 751 |
+
topk_group_mask.scatter_(1, topk_group_indices, 1)
|
| 752 |
+
topk_weight_mask = (topk_group_mask.unsqueeze(-1).expand(
|
| 753 |
+
num_token, num_expert_group,
|
| 754 |
+
topk_weights.shape[-1] // num_expert_group).reshape(num_token, -1))
|
| 755 |
+
topk_weights = topk_weights.masked_fill(~topk_weight_mask.bool(), 0.0)
|
| 756 |
+
|
| 757 |
+
return topk_weights
|
inference/vllm_ascend/quantization/w8a8_dynamic.py
ADDED
|
@@ -0,0 +1,831 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#
|
| 2 |
+
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
| 3 |
+
# This file is a part of the vllm-ascend project.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
#
|
| 17 |
+
|
| 18 |
+
from typing import Any, Callable, Dict, Optional, Tuple, Union
|
| 19 |
+
|
| 20 |
+
import torch
|
| 21 |
+
import torch.distributed as dist
|
| 22 |
+
import torch_npu
|
| 23 |
+
from vllm.distributed import GroupCoordinator
|
| 24 |
+
|
| 25 |
+
import vllm_ascend.envs as envs
|
| 26 |
+
from vllm_ascend.ascend_config import get_ascend_config
|
| 27 |
+
from vllm_ascend.distributed.parallel_state import get_ep_group
|
| 28 |
+
from vllm_ascend.ops.fused_moe import select_experts
|
| 29 |
+
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, FusedMoEState,
|
| 30 |
+
dispose_tensor, get_fused_moe_state,
|
| 31 |
+
npu_stream_switch, npu_wait_tensor)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def apply_mlp(hidden_states: torch.Tensor,
|
| 35 |
+
w1: torch.Tensor,
|
| 36 |
+
w1_scale: torch.Tensor,
|
| 37 |
+
w2: torch.Tensor,
|
| 38 |
+
w2_scale: torch.Tensor,
|
| 39 |
+
group_list: torch.Tensor,
|
| 40 |
+
dynamic_scale: torch.Tensor = None,
|
| 41 |
+
group_list_type: int = 1) -> torch.Tensor:
|
| 42 |
+
"""
|
| 43 |
+
apply MLP: gate_up_proj -> swiglu -> down_proj
|
| 44 |
+
|
| 45 |
+
Args:
|
| 46 |
+
hidden_states: input hidden states with shape (num_tokens, hidden_size).
|
| 47 |
+
w1: expert weights1 with shape
|
| 48 |
+
(num_experts, hidden_size, intermediate_size * 2)
|
| 49 |
+
w1_scale: weights1 scale with shape (num_experts, intermediate_size * 2)
|
| 50 |
+
w2: expert weights2 with shape
|
| 51 |
+
(num_experts, intermediate_size, hidden_size)
|
| 52 |
+
w2_scale: weights2 scale with shape (num_experts, hidden_size)
|
| 53 |
+
group_list: number of tokens for each expert, follow cumsum mode, and
|
| 54 |
+
with shape (num_experts).
|
| 55 |
+
transpose_weight:
|
| 56 |
+
w1: (num_experts, intermediate_size * 2, hidden_size) ->
|
| 57 |
+
(num_experts, hidden_size, intermediate_size * 2)
|
| 58 |
+
w2: (num_experts, hidden_size, intermediate_size) ->
|
| 59 |
+
(num_experts, intermediate_size, hidden_size)
|
| 60 |
+
|
| 61 |
+
Returns:
|
| 62 |
+
hidden_states: output hidden states after MLP.
|
| 63 |
+
"""
|
| 64 |
+
|
| 65 |
+
if dynamic_scale is None:
|
| 66 |
+
unquantized_hidden_states = hidden_states
|
| 67 |
+
hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(
|
| 68 |
+
hidden_states)
|
| 69 |
+
# Dispose the original unquantized hidden states
|
| 70 |
+
# to save npu memory because they're no longer used.
|
| 71 |
+
dispose_tensor(unquantized_hidden_states)
|
| 72 |
+
else:
|
| 73 |
+
pertoken_scale = dynamic_scale
|
| 74 |
+
|
| 75 |
+
# gmm1: gate_up_proj
|
| 76 |
+
hidden_states = torch_npu.npu_grouped_matmul(
|
| 77 |
+
x=[hidden_states],
|
| 78 |
+
weight=[w1],
|
| 79 |
+
scale=[w1_scale],
|
| 80 |
+
per_token_scale=[pertoken_scale],
|
| 81 |
+
split_item=2,
|
| 82 |
+
group_list_type=group_list_type,
|
| 83 |
+
group_type=0,
|
| 84 |
+
group_list=group_list,
|
| 85 |
+
output_dtype=w2_scale.dtype)[0]
|
| 86 |
+
|
| 87 |
+
# act_fn: swiglu
|
| 88 |
+
hidden_states = torch_npu.npu_swiglu(hidden_states)
|
| 89 |
+
hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(
|
| 90 |
+
hidden_states)
|
| 91 |
+
|
| 92 |
+
# gmm2: down_proj
|
| 93 |
+
hidden_states = torch_npu.npu_grouped_matmul(
|
| 94 |
+
x=[hidden_states],
|
| 95 |
+
weight=[w2],
|
| 96 |
+
scale=[w2_scale],
|
| 97 |
+
per_token_scale=[swiglu_out_scale],
|
| 98 |
+
split_item=2,
|
| 99 |
+
group_list_type=group_list_type,
|
| 100 |
+
group_type=0,
|
| 101 |
+
group_list=group_list,
|
| 102 |
+
output_dtype=w2_scale.dtype)[0]
|
| 103 |
+
|
| 104 |
+
return hidden_states
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def fused_experts_with_mc2(
|
| 108 |
+
hidden_states: torch.Tensor,
|
| 109 |
+
w1: torch.Tensor,
|
| 110 |
+
w2: torch.Tensor,
|
| 111 |
+
w1_scale: torch.Tensor,
|
| 112 |
+
w2_scale: torch.Tensor,
|
| 113 |
+
topk_weights: torch.Tensor,
|
| 114 |
+
topk_ids: torch.Tensor,
|
| 115 |
+
top_k: int,
|
| 116 |
+
expert_map: torch.Tensor = None,
|
| 117 |
+
moe_all_to_all_group_name: str = "",
|
| 118 |
+
log2phy: torch.Tensor = None,
|
| 119 |
+
global_redundant_expert_num: int = 0,
|
| 120 |
+
shared_experts: Optional[Any] = None,
|
| 121 |
+
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
| 122 |
+
if log2phy is not None:
|
| 123 |
+
topk_ids = log2phy[topk_ids]
|
| 124 |
+
global_bs = 0
|
| 125 |
+
moe_expert_num = len(expert_map) + global_redundant_expert_num
|
| 126 |
+
# hidden_states = hidden_states.bfloat16()
|
| 127 |
+
kwargs_mc2 = {
|
| 128 |
+
"x": hidden_states,
|
| 129 |
+
"expert_ids": topk_ids,
|
| 130 |
+
"expert_shard_type": 0,
|
| 131 |
+
"shared_expert_rank_num": 0,
|
| 132 |
+
"moe_expert_num": moe_expert_num,
|
| 133 |
+
"global_bs": global_bs,
|
| 134 |
+
"expert_scales": topk_weights.to(torch.float32),
|
| 135 |
+
}
|
| 136 |
+
|
| 137 |
+
rank = torch.distributed.get_rank()
|
| 138 |
+
|
| 139 |
+
quant_mode = 2
|
| 140 |
+
ep_group = get_ep_group().device_group
|
| 141 |
+
local_rank = torch.distributed.get_rank(group=ep_group)
|
| 142 |
+
all_to_all_group_size = torch.distributed.get_world_size(ep_group)
|
| 143 |
+
|
| 144 |
+
world_size = torch.distributed.get_world_size()
|
| 145 |
+
tp_size = world_size // all_to_all_group_size
|
| 146 |
+
tp_rank = rank % tp_size
|
| 147 |
+
|
| 148 |
+
stage1_kwargs = {
|
| 149 |
+
"scales": None,
|
| 150 |
+
"quant_mode": quant_mode,
|
| 151 |
+
"group_ep": moe_all_to_all_group_name,
|
| 152 |
+
"ep_world_size": all_to_all_group_size,
|
| 153 |
+
"ep_rank_id": local_rank,
|
| 154 |
+
# "group_tp": self.moe_rs_group_name,
|
| 155 |
+
"group_tp": moe_all_to_all_group_name,
|
| 156 |
+
"tp_world_size": tp_size,
|
| 157 |
+
"tp_rank_id": tp_rank,
|
| 158 |
+
}
|
| 159 |
+
kwargs_mc2.update(stage1_kwargs)
|
| 160 |
+
|
| 161 |
+
output = torch_npu.npu_moe_distribute_dispatch(**kwargs_mc2)
|
| 162 |
+
# comm_stream.wait_stream(torch.npu.current_stream())
|
| 163 |
+
expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts, _, expand_scales = output[
|
| 164 |
+
0:7]
|
| 165 |
+
|
| 166 |
+
if shared_experts is not None:
|
| 167 |
+
with npu_stream_switch("moe_secondary", 0):
|
| 168 |
+
npu_wait_tensor(hidden_states, topk_weights)
|
| 169 |
+
shared_gate_up, _ = shared_experts.gate_up_proj(hidden_states)
|
| 170 |
+
npu_wait_tensor(shared_gate_up[0], expand_x)
|
| 171 |
+
shared_act = shared_experts.act_fn(shared_gate_up)
|
| 172 |
+
|
| 173 |
+
# `expand_x` will be disposed in the `apply_mlp` function
|
| 174 |
+
down_out_list = apply_mlp(expand_x,
|
| 175 |
+
w1,
|
| 176 |
+
w1_scale,
|
| 177 |
+
w2,
|
| 178 |
+
w2_scale,
|
| 179 |
+
expert_token_nums,
|
| 180 |
+
dynamic_scale=dynamic_scale)
|
| 181 |
+
|
| 182 |
+
# moeCombine
|
| 183 |
+
kwargs_mc2 = {
|
| 184 |
+
"expand_x": down_out_list,
|
| 185 |
+
"expert_ids": topk_ids,
|
| 186 |
+
"expand_idx": expand_idx,
|
| 187 |
+
"expert_scales": topk_weights.to(torch.float32),
|
| 188 |
+
"expert_shard_type": 0,
|
| 189 |
+
"shared_expert_rank_num": 0,
|
| 190 |
+
"moe_expert_num": moe_expert_num,
|
| 191 |
+
"global_bs": 0,
|
| 192 |
+
"expand_scales": expand_scales,
|
| 193 |
+
}
|
| 194 |
+
tp_recv_counts = torch.empty(1,
|
| 195 |
+
dtype=torch.int32,
|
| 196 |
+
device=hidden_states.device)
|
| 197 |
+
stage3_kwargs = {
|
| 198 |
+
"ep_send_counts": ep_recv_counts,
|
| 199 |
+
"group_ep": moe_all_to_all_group_name,
|
| 200 |
+
"ep_world_size": all_to_all_group_size,
|
| 201 |
+
"ep_rank_id": local_rank,
|
| 202 |
+
"tp_send_counts": tp_recv_counts,
|
| 203 |
+
# "group_tp": self.moe_rs_group_name,
|
| 204 |
+
"group_tp": moe_all_to_all_group_name,
|
| 205 |
+
"tp_world_size": tp_size,
|
| 206 |
+
"tp_rank_id": tp_rank,
|
| 207 |
+
}
|
| 208 |
+
kwargs_mc2.update(stage3_kwargs)
|
| 209 |
+
|
| 210 |
+
hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs_mc2)
|
| 211 |
+
|
| 212 |
+
if shared_experts is None:
|
| 213 |
+
return hidden_states
|
| 214 |
+
else:
|
| 215 |
+
with npu_stream_switch("moe_secondary", 0):
|
| 216 |
+
npu_wait_tensor(shared_act[0], down_out_list)
|
| 217 |
+
shared_output, _ = shared_experts.down_proj(shared_act)
|
| 218 |
+
return hidden_states, shared_output
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
# currently expert parallelism implemented with all2all
|
| 222 |
+
# is under-optimized.
|
| 223 |
+
def fused_experts_with_all2all(
|
| 224 |
+
hidden_states: torch.Tensor,
|
| 225 |
+
w1: torch.Tensor,
|
| 226 |
+
w1_scale: torch.Tensor,
|
| 227 |
+
w2: torch.Tensor,
|
| 228 |
+
w2_scale: torch.Tensor,
|
| 229 |
+
topk_weights: torch.Tensor,
|
| 230 |
+
topk_ids: torch.Tensor,
|
| 231 |
+
top_k: int,
|
| 232 |
+
expert_map: torch.Tensor = None,
|
| 233 |
+
ep_group: GroupCoordinator = None,
|
| 234 |
+
log2phy: torch.Tensor = None,
|
| 235 |
+
global_redundant_expert_num: int = 0,
|
| 236 |
+
):
|
| 237 |
+
if log2phy is not None:
|
| 238 |
+
topk_ids = log2phy[topk_ids]
|
| 239 |
+
original_shape = hidden_states.shape
|
| 240 |
+
if len(original_shape) == 3:
|
| 241 |
+
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
|
| 242 |
+
|
| 243 |
+
num_tokens, _ = hidden_states.shape
|
| 244 |
+
num_experts = w1.shape[0]
|
| 245 |
+
device = hidden_states.device
|
| 246 |
+
|
| 247 |
+
if expert_map is not None:
|
| 248 |
+
global_num_experts = len(expert_map) + global_redundant_expert_num
|
| 249 |
+
local_num_experts = global_num_experts // ep_group.world_size
|
| 250 |
+
row_idx_len = num_tokens * top_k
|
| 251 |
+
row_idx = (torch.arange(0,
|
| 252 |
+
row_idx_len,
|
| 253 |
+
dtype=torch.int32,
|
| 254 |
+
device=device).view(top_k, -1).permute(
|
| 255 |
+
1, 0).contiguous())
|
| 256 |
+
hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
|
| 257 |
+
hidden_states,
|
| 258 |
+
row_idx=row_idx,
|
| 259 |
+
expert_idx=topk_ids,
|
| 260 |
+
active_num=num_tokens)
|
| 261 |
+
|
| 262 |
+
global_expert_tokens = torch.bincount(expanded_expert_idx,
|
| 263 |
+
minlength=global_num_experts)
|
| 264 |
+
scatter_sizes = global_expert_tokens.view(ep_group.world_size,
|
| 265 |
+
-1).sum(-1)
|
| 266 |
+
|
| 267 |
+
gather_sizes = torch.empty_like(scatter_sizes)
|
| 268 |
+
dist.all_to_all_single(gather_sizes,
|
| 269 |
+
scatter_sizes,
|
| 270 |
+
group=ep_group.device_group)
|
| 271 |
+
scatter_size_list = scatter_sizes.cpu().tolist()
|
| 272 |
+
gather_size_list = gather_sizes.cpu().tolist()
|
| 273 |
+
|
| 274 |
+
expanded_expert_idx = expanded_expert_idx % local_num_experts
|
| 275 |
+
hidden_states = ep_group.all_to_all(hidden_states, 0, 0,
|
| 276 |
+
scatter_size_list,
|
| 277 |
+
gather_size_list)
|
| 278 |
+
local_expert_idx = ep_group.all_to_all(expanded_expert_idx, 0, 0,
|
| 279 |
+
scatter_size_list,
|
| 280 |
+
gather_size_list)
|
| 281 |
+
|
| 282 |
+
sorted_local_expert_idx, sorted_idx = torch.sort(local_expert_idx)
|
| 283 |
+
|
| 284 |
+
expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
|
| 285 |
+
sorted_local_expert_idx, local_num_experts).to(torch.int64)
|
| 286 |
+
|
| 287 |
+
hidden_states = hidden_states[sorted_idx]
|
| 288 |
+
group_list_type = 0
|
| 289 |
+
else:
|
| 290 |
+
row_idx_len = num_tokens * top_k
|
| 291 |
+
row_idx = torch.arange(0,
|
| 292 |
+
row_idx_len,
|
| 293 |
+
dtype=torch.int32,
|
| 294 |
+
device=topk_weights.device).view(
|
| 295 |
+
top_k, -1).permute(1, 0).contiguous()
|
| 296 |
+
hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
|
| 297 |
+
hidden_states,
|
| 298 |
+
row_idx=row_idx,
|
| 299 |
+
expert_idx=topk_ids,
|
| 300 |
+
active_num=num_tokens)
|
| 301 |
+
|
| 302 |
+
expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
|
| 303 |
+
expanded_expert_idx, num_experts)
|
| 304 |
+
expert_tokens = expert_tokens.to(torch.int64)
|
| 305 |
+
group_list_type = 0
|
| 306 |
+
|
| 307 |
+
# `hidden_states` will be disposed in the `apply_mlp` function
|
| 308 |
+
hidden_states = apply_mlp(
|
| 309 |
+
hidden_states,
|
| 310 |
+
w1,
|
| 311 |
+
w1_scale, #17
|
| 312 |
+
w2,
|
| 313 |
+
w2_scale,
|
| 314 |
+
expert_tokens, #16
|
| 315 |
+
group_list_type=group_list_type)
|
| 316 |
+
|
| 317 |
+
if expert_map is not None:
|
| 318 |
+
resorted_idx = torch.argsort(sorted_idx)
|
| 319 |
+
hidden_states = hidden_states[resorted_idx]
|
| 320 |
+
hidden_states = ep_group.all_to_all(hidden_states, 0, 0,
|
| 321 |
+
gather_size_list,
|
| 322 |
+
scatter_size_list)
|
| 323 |
+
|
| 324 |
+
final_hidden_states = torch_npu.npu_moe_finalize_routing(
|
| 325 |
+
hidden_states,
|
| 326 |
+
skip1=None,
|
| 327 |
+
skip2=None,
|
| 328 |
+
bias=None,
|
| 329 |
+
scales=topk_weights,
|
| 330 |
+
expanded_src_to_dst_row=expanded_row_idx,
|
| 331 |
+
export_for_source_row=topk_ids,
|
| 332 |
+
)
|
| 333 |
+
else:
|
| 334 |
+
# TODO: Reorder device memory 2 times here, replace the current
|
| 335 |
+
# implementation here when suitable operators become available.
|
| 336 |
+
final_hidden_states = torch_npu.npu_moe_finalize_routing(
|
| 337 |
+
hidden_states,
|
| 338 |
+
skip1=None,
|
| 339 |
+
skip2=None,
|
| 340 |
+
bias=None,
|
| 341 |
+
scales=topk_weights,
|
| 342 |
+
expanded_src_to_dst_row=expanded_row_idx,
|
| 343 |
+
export_for_source_row=topk_ids,
|
| 344 |
+
)
|
| 345 |
+
if len(original_shape) == 3:
|
| 346 |
+
final_hidden_states = final_hidden_states.view(original_shape)
|
| 347 |
+
return final_hidden_states
|
| 348 |
+
|
| 349 |
+
|
| 350 |
+
def fused_experts_with_allgather(hidden_states: torch.Tensor,
|
| 351 |
+
w1: torch.Tensor,
|
| 352 |
+
w1_scale: torch.Tensor,
|
| 353 |
+
w2: torch.Tensor,
|
| 354 |
+
w2_scale: torch.Tensor,
|
| 355 |
+
topk_weights: torch.Tensor,
|
| 356 |
+
topk_ids: torch.Tensor,
|
| 357 |
+
top_k: int,
|
| 358 |
+
expert_map: torch.Tensor = None):
|
| 359 |
+
original_shape = hidden_states.shape
|
| 360 |
+
if len(original_shape) == 3:
|
| 361 |
+
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
|
| 362 |
+
num_tokens = hidden_states.shape[0]
|
| 363 |
+
batch_size, hidden_size = hidden_states.shape
|
| 364 |
+
|
| 365 |
+
ep_group = get_ep_group().device_group
|
| 366 |
+
ep_rank = torch.distributed.get_rank(group=ep_group)
|
| 367 |
+
ep_size = torch.distributed.get_world_size(ep_group)
|
| 368 |
+
|
| 369 |
+
global_num_experts = len(expert_map)
|
| 370 |
+
local_num_experts = global_num_experts // ep_size
|
| 371 |
+
|
| 372 |
+
hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(hidden_states)
|
| 373 |
+
|
| 374 |
+
hidden_states, expanded_x_idx, expert_tokens, pertoken_scale = torch_npu.npu_moe_init_routing_v2(
|
| 375 |
+
hidden_states,
|
| 376 |
+
topk_ids,
|
| 377 |
+
scale=pertoken_scale,
|
| 378 |
+
offset=None,
|
| 379 |
+
active_num=num_tokens * top_k,
|
| 380 |
+
expert_num=global_num_experts,
|
| 381 |
+
expert_tokens_num_type=1,
|
| 382 |
+
expert_tokens_num_flag=True,
|
| 383 |
+
active_expert_range=[
|
| 384 |
+
ep_rank * local_num_experts, (ep_rank + 1) * local_num_experts
|
| 385 |
+
],
|
| 386 |
+
quant_mode=-1,
|
| 387 |
+
row_idx_type=0)
|
| 388 |
+
group_list_type = 1
|
| 389 |
+
|
| 390 |
+
|
| 391 |
+
hidden_states = torch_npu.npu_grouped_matmul(
|
| 392 |
+
x=[hidden_states],
|
| 393 |
+
weight=[w1],
|
| 394 |
+
split_item=3,
|
| 395 |
+
group_list_type=group_list_type,
|
| 396 |
+
group_type=0,
|
| 397 |
+
group_list=expert_tokens,
|
| 398 |
+
output_dtype=torch.int32)[0]
|
| 399 |
+
|
| 400 |
+
# act_fn: swiglu
|
| 401 |
+
hidden_states, pertoken_scale = torch_npu.npu_dequant_swiglu_quant(
|
| 402 |
+
x=hidden_states,
|
| 403 |
+
weight_scale=w1_scale.to(torch.float32),
|
| 404 |
+
activation_scale=pertoken_scale,
|
| 405 |
+
bias=None,
|
| 406 |
+
quant_scale=None,
|
| 407 |
+
quant_offset=None,
|
| 408 |
+
group_index=expert_tokens,
|
| 409 |
+
activate_left=True,
|
| 410 |
+
quant_mode=1,
|
| 411 |
+
)
|
| 412 |
+
|
| 413 |
+
hidden_states = torch_npu.npu_grouped_matmul(
|
| 414 |
+
x=[hidden_states],
|
| 415 |
+
weight=[w2],
|
| 416 |
+
scale=[w2_scale.to(torch.bfloat16)],
|
| 417 |
+
per_token_scale=[pertoken_scale.view(-1)],
|
| 418 |
+
split_item=3,
|
| 419 |
+
group_list_type=group_list_type,
|
| 420 |
+
group_type=0,
|
| 421 |
+
group_list=expert_tokens,
|
| 422 |
+
output_dtype=torch.bfloat16)[0]
|
| 423 |
+
|
| 424 |
+
final_hidden_states = torch_npu.npu_moe_finalize_routing(
|
| 425 |
+
expanded_permuted_rows=hidden_states.unsqueeze(1),
|
| 426 |
+
skip1=None,
|
| 427 |
+
skip2=None,
|
| 428 |
+
bias=None,
|
| 429 |
+
scales=topk_weights.to(torch.bfloat16),
|
| 430 |
+
expanded_src_to_dst_row=expanded_x_idx.to(torch.int32),
|
| 431 |
+
export_for_source_row=topk_ids,
|
| 432 |
+
drop_pad_mode=3
|
| 433 |
+
).to(torch.bfloat16)
|
| 434 |
+
|
| 435 |
+
if len(original_shape) == 3:
|
| 436 |
+
final_hidden_states = final_hidden_states.view(original_shape)
|
| 437 |
+
|
| 438 |
+
return final_hidden_states
|
| 439 |
+
|
| 440 |
+
|
| 441 |
+
def fused_experts(hidden_states: torch.Tensor,
|
| 442 |
+
w1: torch.Tensor,
|
| 443 |
+
w1_scale: torch.Tensor,
|
| 444 |
+
w2: torch.Tensor,
|
| 445 |
+
w2_scale: torch.Tensor,
|
| 446 |
+
topk_weights: torch.Tensor,
|
| 447 |
+
topk_ids: torch.Tensor,
|
| 448 |
+
top_k: int,
|
| 449 |
+
expert_map: torch.Tensor = None):
|
| 450 |
+
original_shape = hidden_states.shape
|
| 451 |
+
if len(original_shape) == 3:
|
| 452 |
+
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
|
| 453 |
+
|
| 454 |
+
num_tokens, _ = hidden_states.shape
|
| 455 |
+
num_experts = w1.shape[0]
|
| 456 |
+
dtype = hidden_states.dtype
|
| 457 |
+
device = hidden_states.device
|
| 458 |
+
|
| 459 |
+
if expert_map is not None:
|
| 460 |
+
# Generate token indices and flatten
|
| 461 |
+
token_indices = (torch.arange(num_tokens,
|
| 462 |
+
device=device,
|
| 463 |
+
dtype=torch.int64).unsqueeze(1).expand(
|
| 464 |
+
-1, top_k).reshape(-1))
|
| 465 |
+
|
| 466 |
+
# Flatten token-to-expert mappings and map to local experts
|
| 467 |
+
weights_flat = topk_weights.view(-1)
|
| 468 |
+
experts_flat = topk_ids.view(-1)
|
| 469 |
+
local_experts_flat = expert_map[experts_flat]
|
| 470 |
+
|
| 471 |
+
# Filter valid token-expert pairs
|
| 472 |
+
mask = local_experts_flat != -1
|
| 473 |
+
filtered_weights = torch.where(
|
| 474 |
+
mask, weights_flat, torch.zeros_like(weights_flat)).to(dtype)
|
| 475 |
+
filtered_experts = torch.where(
|
| 476 |
+
mask, local_experts_flat,
|
| 477 |
+
torch.full_like(local_experts_flat,
|
| 478 |
+
num_experts)).to(topk_ids.dtype)
|
| 479 |
+
|
| 480 |
+
# Sort by local expert IDs
|
| 481 |
+
sort_indices = torch.argsort(filtered_experts)
|
| 482 |
+
sorted_token_indices = token_indices[sort_indices]
|
| 483 |
+
sorted_weights = filtered_weights[sort_indices]
|
| 484 |
+
|
| 485 |
+
# Compute token counts with minlength of num_experts
|
| 486 |
+
# This is equivalent to but faster than:
|
| 487 |
+
# >>> token_counts = torch.bincount(filtered_experts, minlength=num_experts)[:-1]
|
| 488 |
+
token_counts = torch.zeros(num_experts + 1,
|
| 489 |
+
device=device,
|
| 490 |
+
dtype=torch.int64)
|
| 491 |
+
ones = torch.ones_like(filtered_experts, dtype=torch.int64)
|
| 492 |
+
token_counts.scatter_add_(0, filtered_experts.to(torch.int64), ones)
|
| 493 |
+
expert_tokens = token_counts[:num_experts]
|
| 494 |
+
# Rearrange hidden_states
|
| 495 |
+
hidden_states = hidden_states[sorted_token_indices]
|
| 496 |
+
group_list_type = 1
|
| 497 |
+
else:
|
| 498 |
+
row_idx_len = num_tokens * top_k
|
| 499 |
+
row_idx = torch.arange(0,
|
| 500 |
+
row_idx_len,
|
| 501 |
+
dtype=torch.int32,
|
| 502 |
+
device=topk_weights.device).view(
|
| 503 |
+
top_k, -1).permute(1, 0).contiguous()
|
| 504 |
+
hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
|
| 505 |
+
hidden_states,
|
| 506 |
+
row_idx=row_idx,
|
| 507 |
+
expert_idx=topk_ids,
|
| 508 |
+
active_num=num_tokens)
|
| 509 |
+
|
| 510 |
+
expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
|
| 511 |
+
expanded_expert_idx, num_experts)
|
| 512 |
+
expert_tokens = expert_tokens.to(torch.int64)
|
| 513 |
+
group_list_type = 0
|
| 514 |
+
|
| 515 |
+
# `hidden_states` will be disposed in the `apply_mlp` function
|
| 516 |
+
hidden_states = apply_mlp(hidden_states,
|
| 517 |
+
w1,
|
| 518 |
+
w1_scale,
|
| 519 |
+
w2,
|
| 520 |
+
w2_scale,
|
| 521 |
+
expert_tokens,
|
| 522 |
+
group_list_type=group_list_type)
|
| 523 |
+
|
| 524 |
+
if expert_map is not None:
|
| 525 |
+
hidden_states.mul_(sorted_weights.unsqueeze(1))
|
| 526 |
+
final_hidden_states = torch.zeros(*original_shape,
|
| 527 |
+
device=device,
|
| 528 |
+
dtype=dtype)
|
| 529 |
+
|
| 530 |
+
num_valid_tokens = mask.sum()
|
| 531 |
+
valid_token_mask = torch.arange(
|
| 532 |
+
0, sorted_token_indices.shape[0],
|
| 533 |
+
device=device).unsqueeze(1) < num_valid_tokens
|
| 534 |
+
hidden_states = hidden_states.masked_fill_(~valid_token_mask,
|
| 535 |
+
0).to(dtype)
|
| 536 |
+
final_hidden_states.index_add_(0, sorted_token_indices, hidden_states)
|
| 537 |
+
else:
|
| 538 |
+
# TODO: Reorder device memory 2 times here, replace the current
|
| 539 |
+
# implementation here when suitable operators become available.
|
| 540 |
+
final_hidden_states = torch_npu.npu_moe_finalize_routing(
|
| 541 |
+
hidden_states,
|
| 542 |
+
skip1=None,
|
| 543 |
+
skip2=None,
|
| 544 |
+
bias=None,
|
| 545 |
+
scales=topk_weights,
|
| 546 |
+
expanded_src_to_dst_row=expanded_row_idx,
|
| 547 |
+
export_for_source_row=topk_ids,
|
| 548 |
+
)
|
| 549 |
+
|
| 550 |
+
if len(original_shape) == 3:
|
| 551 |
+
final_hidden_states = final_hidden_states.view(original_shape)
|
| 552 |
+
return final_hidden_states
|
| 553 |
+
|
| 554 |
+
|
| 555 |
+
class AscendW8A8DynamicLinearMethod:
|
| 556 |
+
"""Linear method for Ascend W8A8_DYNAMIC.
|
| 557 |
+
"""
|
| 558 |
+
|
| 559 |
+
def __init__(self):
|
| 560 |
+
self.transpose_weight = True
|
| 561 |
+
|
| 562 |
+
@staticmethod
|
| 563 |
+
def get_weight(input_size: int, output_size: int,
|
| 564 |
+
params_dtype: torch.dtype) -> Dict[str, Any]:
|
| 565 |
+
params_dict = {
|
| 566 |
+
"weight": torch.empty(output_size, input_size, dtype=torch.int8)
|
| 567 |
+
}
|
| 568 |
+
return params_dict
|
| 569 |
+
|
| 570 |
+
@staticmethod
|
| 571 |
+
def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]:
|
| 572 |
+
return {}
|
| 573 |
+
|
| 574 |
+
@staticmethod
|
| 575 |
+
def get_perchannel_param(
|
| 576 |
+
output_size: int,
|
| 577 |
+
params_dtype: torch.dtype,
|
| 578 |
+
) -> Dict[str, Any]:
|
| 579 |
+
params_dict = {}
|
| 580 |
+
params_dict["weight_scale"] = torch.empty(output_size,
|
| 581 |
+
1,
|
| 582 |
+
dtype=params_dtype)
|
| 583 |
+
params_dict["weight_offset"] = torch.empty(output_size,
|
| 584 |
+
1,
|
| 585 |
+
dtype=params_dtype)
|
| 586 |
+
return params_dict
|
| 587 |
+
|
| 588 |
+
@staticmethod
|
| 589 |
+
def apply(
|
| 590 |
+
layer: torch.nn.Module,
|
| 591 |
+
x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
|
| 592 |
+
bias: Optional[torch.Tensor] = None,
|
| 593 |
+
tp_rank: Optional[int] = 0,
|
| 594 |
+
) -> torch.Tensor:
|
| 595 |
+
config = getattr(layer, "_ascend_quant_config", {})
|
| 596 |
+
if not isinstance(x, tuple):
|
| 597 |
+
output_dtype = config.get("output_dtype", x.dtype)
|
| 598 |
+
quantized_x, dynamic_scale = torch_npu.npu_dynamic_quant(x)
|
| 599 |
+
else:
|
| 600 |
+
assert "output_dtype" in config.keys(), (
|
| 601 |
+
f"DynamicLinearMethod needs explicitly specified `output_dtype`"
|
| 602 |
+
f"for pre-quantized input, got config [{config}]")
|
| 603 |
+
output_dtype = config["output_dtype"]
|
| 604 |
+
quantized_x, dynamic_scale = x
|
| 605 |
+
pertoken_scale = (dynamic_scale
|
| 606 |
+
if config.get("pertoken_scale", True) else None)
|
| 607 |
+
|
| 608 |
+
output = torch_npu.npu_quant_matmul(
|
| 609 |
+
quantized_x,
|
| 610 |
+
layer.weight,
|
| 611 |
+
layer.weight_scale,
|
| 612 |
+
pertoken_scale=pertoken_scale,
|
| 613 |
+
bias=bias,
|
| 614 |
+
output_dtype=output_dtype,
|
| 615 |
+
)
|
| 616 |
+
return ((output, dynamic_scale)
|
| 617 |
+
if config.get("return_scale", False) else output)
|
| 618 |
+
|
| 619 |
+
def process_weights_after_loading(self, layer):
|
| 620 |
+
if self.transpose_weight:
|
| 621 |
+
layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
|
| 622 |
+
# cast quantized weight tensors in NZ format (29) for higher inference speed
|
| 623 |
+
layer.weight.data = torch_npu.npu_format_cast(layer.weight.data, 29)
|
| 624 |
+
layer.weight_scale.data = layer.weight_scale.data.flatten()
|
| 625 |
+
layer.weight_scale_fp32 = layer.weight_scale.data.to(torch.float32)
|
| 626 |
+
layer.weight_offset.data = layer.weight_offset.data.flatten()
|
| 627 |
+
|
| 628 |
+
|
| 629 |
+
class AscendW8A8DynamicFusedMoEMethod:
|
| 630 |
+
"""FusedMoe method for Ascend W8A8_DYNAMIC.
|
| 631 |
+
"""
|
| 632 |
+
|
| 633 |
+
def __init__(self):
|
| 634 |
+
self.transpose_weight = True
|
| 635 |
+
|
| 636 |
+
self.ep_group = get_ep_group()
|
| 637 |
+
|
| 638 |
+
ascend_config = get_ascend_config()
|
| 639 |
+
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
|
| 640 |
+
|
| 641 |
+
try:
|
| 642 |
+
device_group = self.ep_group.device_group
|
| 643 |
+
# TODO: Try local_rank = ep_group.rank_in_group
|
| 644 |
+
local_rank = torch.distributed.get_rank(group=device_group)
|
| 645 |
+
backend = device_group._get_backend(torch.device("npu"))
|
| 646 |
+
self.moe_all_to_all_group_name = backend.get_hccl_comm_name(
|
| 647 |
+
local_rank)
|
| 648 |
+
except AttributeError:
|
| 649 |
+
self.moe_all_to_all_group_name = ""
|
| 650 |
+
|
| 651 |
+
@staticmethod
|
| 652 |
+
def get_weight(num_experts: int, intermediate_size_per_partition: int,
|
| 653 |
+
hidden_sizes: int,
|
| 654 |
+
params_dtype: torch.dtype) -> Dict[str, Any]:
|
| 655 |
+
param_dict = {}
|
| 656 |
+
param_dict["w13_weight"] = torch.empty(num_experts,
|
| 657 |
+
2 *
|
| 658 |
+
intermediate_size_per_partition,
|
| 659 |
+
hidden_sizes,
|
| 660 |
+
dtype=torch.int8)
|
| 661 |
+
param_dict["w2_weight"] = torch.empty(num_experts,
|
| 662 |
+
hidden_sizes,
|
| 663 |
+
intermediate_size_per_partition,
|
| 664 |
+
dtype=torch.int8)
|
| 665 |
+
return param_dict
|
| 666 |
+
|
| 667 |
+
@staticmethod
|
| 668 |
+
def get_dynamic_quant_param(num_experts: int,
|
| 669 |
+
intermediate_size_per_partition: int,
|
| 670 |
+
hidden_sizes: int,
|
| 671 |
+
params_dtype: torch.dtype) -> Dict[str, Any]:
|
| 672 |
+
param_dict = {}
|
| 673 |
+
param_dict["w13_weight_scale"] = torch.empty(
|
| 674 |
+
num_experts,
|
| 675 |
+
2 * intermediate_size_per_partition,
|
| 676 |
+
1,
|
| 677 |
+
dtype=params_dtype)
|
| 678 |
+
param_dict["w13_weight_offset"] = torch.empty(
|
| 679 |
+
num_experts,
|
| 680 |
+
2 * intermediate_size_per_partition,
|
| 681 |
+
1,
|
| 682 |
+
dtype=params_dtype)
|
| 683 |
+
param_dict["w2_weight_scale"] = torch.empty(num_experts,
|
| 684 |
+
hidden_sizes,
|
| 685 |
+
1,
|
| 686 |
+
dtype=params_dtype)
|
| 687 |
+
param_dict["w2_weight_offset"] = torch.empty(num_experts,
|
| 688 |
+
hidden_sizes,
|
| 689 |
+
1,
|
| 690 |
+
dtype=params_dtype)
|
| 691 |
+
return param_dict
|
| 692 |
+
|
| 693 |
+
def apply(
|
| 694 |
+
self,
|
| 695 |
+
layer: torch.nn.Module,
|
| 696 |
+
x: torch.Tensor,
|
| 697 |
+
router_logits: torch.Tensor,
|
| 698 |
+
top_k: int,
|
| 699 |
+
renormalize: bool,
|
| 700 |
+
use_grouped_topk: bool = False,
|
| 701 |
+
global_num_experts: int = -1,
|
| 702 |
+
expert_map: Optional[torch.Tensor] = None,
|
| 703 |
+
topk_group: Optional[int] = None,
|
| 704 |
+
num_expert_group: Optional[int] = None,
|
| 705 |
+
custom_routing_function: Optional[Callable] = None,
|
| 706 |
+
scoring_func: str = "softmax",
|
| 707 |
+
e_score_correction_bias: Optional[torch.Tensor] = None,
|
| 708 |
+
is_prefill: bool = True,
|
| 709 |
+
enable_force_load_balance: bool = True,
|
| 710 |
+
log2phy: torch.Tensor = None,
|
| 711 |
+
global_redundant_expert_num: int = 0,
|
| 712 |
+
shared_experts: Optional[Any] = None,
|
| 713 |
+
**kwargs,
|
| 714 |
+
) -> torch.Tensor:
|
| 715 |
+
assert router_logits.shape[
|
| 716 |
+
1] == global_num_experts, "Number of global experts mismatch"
|
| 717 |
+
|
| 718 |
+
is_deepseek_v3_r1 = global_num_experts == 256
|
| 719 |
+
use_grouped_topk = (topk_group > 1 or num_expert_group > 1)
|
| 720 |
+
|
| 721 |
+
# NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
|
| 722 |
+
if use_grouped_topk and is_deepseek_v3_r1:
|
| 723 |
+
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
|
| 724 |
+
router_logits,
|
| 725 |
+
k=top_k, # topk当前写8
|
| 726 |
+
bias=e_score_correction_bias,
|
| 727 |
+
k_group=topk_group, # fix: 4
|
| 728 |
+
group_count=num_expert_group, # fix 8
|
| 729 |
+
group_select_mode=1, # 0: group中的最大; 1: topk2.sum(fix)
|
| 730 |
+
renorm=0, # 0: softmax->topk(fix); 1: topk->softmax
|
| 731 |
+
norm_type=1, # 0: softmax; 1: sigmoid(fix)
|
| 732 |
+
# out_flag=False, # todo new api; 第三个输出是否输出
|
| 733 |
+
# y2_flag=False, # old api; 第三个输出是否输出
|
| 734 |
+
routed_scaling_factor=1,
|
| 735 |
+
eps=float(1e-20))
|
| 736 |
+
else:
|
| 737 |
+
topk_weights, topk_ids = select_experts(
|
| 738 |
+
hidden_states=x,
|
| 739 |
+
router_logits=router_logits,
|
| 740 |
+
top_k=top_k,
|
| 741 |
+
use_grouped_topk=use_grouped_topk,
|
| 742 |
+
renormalize=renormalize,
|
| 743 |
+
topk_group=topk_group,
|
| 744 |
+
num_expert_group=num_expert_group,
|
| 745 |
+
custom_routing_function=custom_routing_function,
|
| 746 |
+
scoring_func=scoring_func,
|
| 747 |
+
e_score_correction_bias=e_score_correction_bias,
|
| 748 |
+
)
|
| 749 |
+
|
| 750 |
+
# this is a naive implementation for experts load balance so as
|
| 751 |
+
# to avoid accumulating too much tokens on a single rank.
|
| 752 |
+
# currently it is only activated when doing profile runs.
|
| 753 |
+
if enable_force_load_balance:
|
| 754 |
+
topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
|
| 755 |
+
|
| 756 |
+
topk_weights = topk_weights.to(x.dtype)
|
| 757 |
+
|
| 758 |
+
fused_moe_state = get_fused_moe_state(self.ep_group.world_size,
|
| 759 |
+
is_prefill, is_deepseek_v3_r1)
|
| 760 |
+
if fused_moe_state == FusedMoEState.AllGatherEP:
|
| 761 |
+
return fused_experts_with_allgather(
|
| 762 |
+
hidden_states=x,
|
| 763 |
+
w1=layer.w13_weight,
|
| 764 |
+
w1_scale=layer.w13_weight_scale,
|
| 765 |
+
w2=layer.w2_weight,
|
| 766 |
+
w2_scale=layer.w2_weight_scale,
|
| 767 |
+
topk_weights=topk_weights,
|
| 768 |
+
topk_ids=topk_ids,
|
| 769 |
+
top_k=top_k,
|
| 770 |
+
expert_map=expert_map)
|
| 771 |
+
elif fused_moe_state == FusedMoEState.MC2:
|
| 772 |
+
return fused_experts_with_mc2(
|
| 773 |
+
hidden_states=x,
|
| 774 |
+
w1=layer.w13_weight,
|
| 775 |
+
w2=layer.w2_weight,
|
| 776 |
+
w1_scale=layer.w13_weight_scale,
|
| 777 |
+
w2_scale=layer.w2_weight_scale,
|
| 778 |
+
topk_weights=topk_weights,
|
| 779 |
+
topk_ids=topk_ids,
|
| 780 |
+
top_k=top_k,
|
| 781 |
+
expert_map=expert_map,
|
| 782 |
+
moe_all_to_all_group_name=self.moe_all_to_all_group_name,
|
| 783 |
+
log2phy=log2phy,
|
| 784 |
+
global_redundant_expert_num=global_redundant_expert_num,
|
| 785 |
+
shared_experts=shared_experts)
|
| 786 |
+
elif fused_moe_state in [
|
| 787 |
+
FusedMoEState.AllGather, FusedMoEState.NaiveMulticast
|
| 788 |
+
]:
|
| 789 |
+
return fused_experts(hidden_states=x,
|
| 790 |
+
w1=layer.w13_weight,
|
| 791 |
+
w1_scale=layer.w13_weight_scale,
|
| 792 |
+
w2=layer.w2_weight,
|
| 793 |
+
w2_scale=layer.w2_weight_scale,
|
| 794 |
+
topk_weights=topk_weights,
|
| 795 |
+
topk_ids=topk_ids,
|
| 796 |
+
top_k=top_k,
|
| 797 |
+
expert_map=expert_map)
|
| 798 |
+
else:
|
| 799 |
+
# The current implementation of deepseek moe splits hidden_states
|
| 800 |
+
# according to tp_size before they are feed into fused_moe module.
|
| 801 |
+
# Therefore, all2all is needed no matter how dp/tp is set so as to
|
| 802 |
+
# dispatch/combine tokens.
|
| 803 |
+
return fused_experts_with_all2all(
|
| 804 |
+
hidden_states=x,
|
| 805 |
+
w1=layer.w13_weight,
|
| 806 |
+
w1_scale=layer.w13_weight_scale,
|
| 807 |
+
w2=layer.w2_weight,
|
| 808 |
+
w2_scale=layer.w2_weight_scale,
|
| 809 |
+
topk_weights=topk_weights,
|
| 810 |
+
topk_ids=topk_ids,
|
| 811 |
+
top_k=top_k,
|
| 812 |
+
expert_map=expert_map,
|
| 813 |
+
ep_group=self.ep_group,
|
| 814 |
+
log2phy=log2phy,
|
| 815 |
+
global_redundant_expert_num=global_redundant_expert_num,
|
| 816 |
+
)
|
| 817 |
+
|
| 818 |
+
def process_weights_after_loading(self, layer):
|
| 819 |
+
if self.transpose_weight:
|
| 820 |
+
layer.w13_weight.data = layer.w13_weight.data.transpose(
|
| 821 |
+
1, 2).contiguous()
|
| 822 |
+
layer.w2_weight.data = layer.w2_weight.data.transpose(
|
| 823 |
+
1, 2).contiguous()
|
| 824 |
+
layer.w13_weight_scale.data = layer.w13_weight_scale.data.view(
|
| 825 |
+
layer.w13_weight_scale.data.shape[0], -1)
|
| 826 |
+
layer.w13_weight_offset.data = layer.w13_weight_offset.data.view(
|
| 827 |
+
layer.w13_weight_offset.data.shape[0], -1)
|
| 828 |
+
layer.w2_weight_scale.data = layer.w2_weight_scale.data.view(
|
| 829 |
+
layer.w2_weight_scale.data.shape[0], -1)
|
| 830 |
+
layer.w2_weight_offset.data = layer.w2_weight_offset.data.view(
|
| 831 |
+
layer.w2_weight_offset.data.shape[0], -1)
|
inference/vllm_ascend/tokenizer_config.json
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
{"add_bos_token": true, "add_eos_token": false, "add_prefix_space": true, "added_tokens_decoder": {"0": {"content": "<unk>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "1": {"content": "<s>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "2": {"content": "</s>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45806": {"content": "<|User|>:", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45813": {"content": "<|Bot|>:", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45830": {"content": "[unused0]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45840": {"content": "[unused1]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45846": {"content": "[unused2]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45849": {"content": "[unused3]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45861": {"content": "[unused4]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45866": {"content": "[unused5]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45874": {"content": "[unused6]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45883": {"content": "[unused7]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45884": {"content": "[unused8]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45887": {"content": "[unused9]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45892": {"content": "[unused10]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45920": {"content": "[unused11]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45932": {"content": "[unused12]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45938": {"content": "[unused13]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45953": {"content": "[unused14]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45968": {"content": "[unused15]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45974": {"content": "[unused16]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45982": {"content": "[unused17]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45986": {"content": "[unused18]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "46005": {"content": "[unused19]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "46007": {"content": "[unused20]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "46014": {"content": "[unused21]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "46017": {"content": "[unused22]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "46028": {"content": "[unused23]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "46032": {"content": "[unused24]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "46081": {"content": "[unused25]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "46086": {"content": "[unused26]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "46101": {"content": "[unused27]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "46183": {"content": "[unused28]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "46230": {"content": "[unused29]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "46245": {"content": "[unused30]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "46257": {"content": "[unused31]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "144208": {"content": "[unused32]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "144209": {"content": "[unused33]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}}, "auto_map": {"AutoTokenizer": ["tokenization_openpangu.PanguTokenizer", null]}, "bos_token": "<s>", "clean_up_tokenization_spaces": false, "eos_token": "[unused10]", "legacy": true, "model_max_length": 1000000000000000019884624838656, "pad_token": "<unk>", "sp_model_kwargs": {}, "spaces_between_special_tokens": false, "tokenizer_class": "PanguTokenizer", "unk_token": "<unk>", "use_default_system_prompt": false, "chat_template": "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<s>[unused9]系统:[unused10]' }}{% endif %}{% if message['role'] == 'system' %}{{ '<s>[unused9]系统:' + message['content'] + '[unused10]' }}{% endif %}{% if message['role'] == 'assistant' %}{{'[unused9]助手:' + message['content'] + '[unused10]'}}{% endif %}{% if message['role'] == 'tool' %}{{'[unused9]工具:' + message['content'] + '[unused10]'}}{% endif %}{% if message['role'] == 'function' %}{{'[unused9]方法:' + message['content'] + '[unused10]'}}{% endif %}{% if message['role'] == 'user' %}{{'[unused9]用户:' + message['content'] + '[unused10]'}}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '[unused9]助手:' }}{% endif %}"}
|
inference/vllm_ascend/utils.py
ADDED
|
@@ -0,0 +1,563 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#
|
| 2 |
+
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
| 3 |
+
# Copyright 2023 The vLLM team.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
# This file is a part of the vllm-ascend project.
|
| 17 |
+
# Adapted from vllm-project/vllm/vllm/worker/worker.py
|
| 18 |
+
#
|
| 19 |
+
|
| 20 |
+
import atexit
|
| 21 |
+
import fcntl
|
| 22 |
+
import math
|
| 23 |
+
import os
|
| 24 |
+
import shutil
|
| 25 |
+
from contextlib import contextmanager, nullcontext
|
| 26 |
+
from enum import Enum
|
| 27 |
+
from threading import Lock
|
| 28 |
+
from typing import TYPE_CHECKING, List, Tuple
|
| 29 |
+
|
| 30 |
+
import torch
|
| 31 |
+
import torch_npu # noqa: F401 # noqa: F401
|
| 32 |
+
from packaging.version import InvalidVersion, Version
|
| 33 |
+
from torch_npu.npu.streams import Event
|
| 34 |
+
from vllm.logger import logger
|
| 35 |
+
|
| 36 |
+
import vllm_ascend.envs as envs
|
| 37 |
+
from vllm_ascend.ascend_config import get_ascend_config
|
| 38 |
+
|
| 39 |
+
try:
|
| 40 |
+
# Recent release of torchair has moved these ops to `.scope`.
|
| 41 |
+
from torchair.scope import npu_stream_switch as _npu_stream_switch
|
| 42 |
+
from torchair.scope import npu_wait_tensor as _npu_wait_tensor
|
| 43 |
+
except ImportError:
|
| 44 |
+
from torchair.ops import NpuStreamSwitch as _npu_stream_switch
|
| 45 |
+
from torchair.ops import npu_wait_tensor as _npu_wait_tensor
|
| 46 |
+
|
| 47 |
+
if TYPE_CHECKING:
|
| 48 |
+
from vllm.config import VllmConfig
|
| 49 |
+
else:
|
| 50 |
+
VllmConfig = None
|
| 51 |
+
|
| 52 |
+
# NOTE: Currently, we can only capture 1920 graphs at most,
|
| 53 |
+
# due to the limitation of ACL graph. This number is bounded by
|
| 54 |
+
# the number of streams, which is 2048, we save 128 streams
|
| 55 |
+
# as a buffer.
|
| 56 |
+
# Maximum number of graphs that can be captured by ACL Graph
|
| 57 |
+
MAX_CAPTURE_SIZE = 1920
|
| 58 |
+
|
| 59 |
+
ASCEND_QUATIZATION_METHOD = "ascend"
|
| 60 |
+
SOC_VERSION_INFERENCE_SERIES = ["Ascend310P3"]
|
| 61 |
+
|
| 62 |
+
ACL_FORMAT_FRACTAL_ND = 2
|
| 63 |
+
ACL_FORMAT_FRACTAL_NZ = 29
|
| 64 |
+
|
| 65 |
+
_CUSTOM_OP_ENABLED = None
|
| 66 |
+
_IS_310P = None
|
| 67 |
+
_SLEEP_MODE_ENABLED = None
|
| 68 |
+
_CURRENT_STREAM = None
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def is_310p():
|
| 72 |
+
global _IS_310P
|
| 73 |
+
if _IS_310P is None:
|
| 74 |
+
from vllm_ascend import _build_info # type: ignore
|
| 75 |
+
_IS_310P = _build_info.__soc_version__.lower().startswith("ascend310p")
|
| 76 |
+
return _IS_310P
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def sleep_mode_enabled():
|
| 80 |
+
global _SLEEP_MODE_ENABLED
|
| 81 |
+
if _SLEEP_MODE_ENABLED is None:
|
| 82 |
+
from vllm_ascend import _build_info # type: ignore
|
| 83 |
+
_SLEEP_MODE_ENABLED = _build_info.__sleep_mode_enabled__
|
| 84 |
+
return _SLEEP_MODE_ENABLED
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def _round_up(x: int, align: int):
|
| 88 |
+
# round up x to align, for example, if align is 16, x will be rounded up to 16, 32, 48, etc.
|
| 89 |
+
# input: 15, 16 -> output: 16
|
| 90 |
+
# input: 17, 16 -> output: 32
|
| 91 |
+
# input: 30, 16 -> output: 32
|
| 92 |
+
# input: 33, 16 -> output: 48
|
| 93 |
+
# ...
|
| 94 |
+
return (x + align - 1) // align * align
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def _custom_pad(x, pad_dims):
|
| 98 |
+
# pad the input tensor to the shape of pad_dims
|
| 99 |
+
# input: (13, 30), pad_dims: [0, 2, 0, 3]
|
| 100 |
+
# output: (16, 32)
|
| 101 |
+
return torch.nn.functional.pad(x, pad_dims)
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def _custom_reshape(x, target_shape):
|
| 105 |
+
# reshape the input tensor to the shape of target_shape
|
| 106 |
+
# input: (16, 32), target_shape: [1, 16, 2, 16]
|
| 107 |
+
# output: (1, 16, 2, 16)
|
| 108 |
+
return x.reshape(target_shape)
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def _custom_transpose(x, dim1, dim2):
|
| 112 |
+
# transpose the input tensor
|
| 113 |
+
# input: (1, 16, 2, 16), dim1: 1, dim2: 2
|
| 114 |
+
# output: (1, 2, 16, 16)
|
| 115 |
+
return x.transpose(dim1, dim2)
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def nd_to_nz_2d(in_tensor: torch.Tensor) -> torch.Tensor:
|
| 119 |
+
# in_tensor: (13, 30)
|
| 120 |
+
aux_dims = [1, 0, 0, 16]
|
| 121 |
+
# aux_dims[1]: 16
|
| 122 |
+
aux_dims[1] = _round_up(in_tensor.size(0), 16)
|
| 123 |
+
# aux_dims[2]: 2
|
| 124 |
+
aux_dims[2] = _round_up(in_tensor.size(1), 16) // 16
|
| 125 |
+
|
| 126 |
+
# after: aux_dims: [1, 16, 2, 16]
|
| 127 |
+
|
| 128 |
+
pad_dims = [0, 0, 0, 0]
|
| 129 |
+
# pad_dims[1]: 2
|
| 130 |
+
pad_dims[1] = _round_up(in_tensor.size(1), 16) - in_tensor.size(1)
|
| 131 |
+
# pad_dims[3]: 3
|
| 132 |
+
pad_dims[3] = _round_up(in_tensor.size(0), 16) - in_tensor.size(0)
|
| 133 |
+
|
| 134 |
+
# after: pad_dims: [0, 2, 0, 3]
|
| 135 |
+
|
| 136 |
+
# return: (1, 2, 16, 16)
|
| 137 |
+
return _custom_transpose(
|
| 138 |
+
_custom_reshape(_custom_pad(in_tensor, pad_dims), aux_dims), 1,
|
| 139 |
+
2).contiguous()
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def nd_to_nz_spec(mask_tensor: torch.Tensor) -> torch.Tensor:
|
| 143 |
+
num_tokens = mask_tensor.shape[0]
|
| 144 |
+
max_seq_len = mask_tensor.shape[1]
|
| 145 |
+
|
| 146 |
+
tokens_pad = (num_tokens + 15) // 16 * 16
|
| 147 |
+
max_seq_len_pad = (max_seq_len + 15) // 16 * 16
|
| 148 |
+
|
| 149 |
+
mask_tensor_pad = \
|
| 150 |
+
torch.zeros((1, tokens_pad, max_seq_len_pad), dtype=mask_tensor.dtype, device=mask_tensor.device)
|
| 151 |
+
mask_tensor_pad[0][:num_tokens, :max_seq_len] = mask_tensor
|
| 152 |
+
mask = mask_tensor_pad.reshape(
|
| 153 |
+
(1, tokens_pad, max_seq_len_pad // 16, 16)).permute(0, 2, 1, 3)
|
| 154 |
+
return mask
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def aligned_16(tensor: torch.Tensor):
|
| 158 |
+
"""Aligned tensor for 310P"""
|
| 159 |
+
|
| 160 |
+
# Get the size of the current 0th dimension
|
| 161 |
+
n = tensor.size(0)
|
| 162 |
+
|
| 163 |
+
# Calculate the aligned size
|
| 164 |
+
n_aligned = ((n + 15) // 16) * 16
|
| 165 |
+
|
| 166 |
+
# If already aligned, return the original tensor
|
| 167 |
+
if n == n_aligned:
|
| 168 |
+
return tensor
|
| 169 |
+
|
| 170 |
+
# Create a new tensor with shape (n_aligned, H, W) and fill it with zeros
|
| 171 |
+
new_tensor = torch.zeros(n_aligned,
|
| 172 |
+
*tensor.shape[1:],
|
| 173 |
+
dtype=tensor.dtype,
|
| 174 |
+
device=tensor.device)
|
| 175 |
+
|
| 176 |
+
# Copy the original tensor to the first N positions of the new tensor
|
| 177 |
+
new_tensor[:n] = tensor
|
| 178 |
+
|
| 179 |
+
return new_tensor
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
def maybe_converting_weight_acl_format(model, format=ACL_FORMAT_FRACTAL_NZ):
|
| 183 |
+
# currently, there are some operations which do not support ACL_FORMAT_FRACTAL_NZ
|
| 184 |
+
# in eager mode but support it in torchair graph mode. since ACL_FORMAT_FRACTAL_NZ
|
| 185 |
+
# is much more preferred than ACL_FORMAT_FRACTAL_ND on 300I Duo, we add this
|
| 186 |
+
# conversion when using torchair graph mode on 300I Duo platform.
|
| 187 |
+
# TODO: we will remove this conversion if npu_quant_grouped_matmul_dequant
|
| 188 |
+
# accepts weight format of ACL_FORMAT_FRACTAL_NZ in eager mode.
|
| 189 |
+
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
|
| 190 |
+
|
| 191 |
+
use_torchair = get_ascend_config().torchair_graph_config.enabled
|
| 192 |
+
if not is_310p() or not use_torchair:
|
| 193 |
+
return
|
| 194 |
+
for module in model.modules():
|
| 195 |
+
if isinstance(module, FusedMoE):
|
| 196 |
+
if torch_npu.get_npu_format(module.w13_weight.data) == format:
|
| 197 |
+
return
|
| 198 |
+
module.w13_weight.data = torch_npu.npu_format_cast(
|
| 199 |
+
module.w13_weight.data, format)
|
| 200 |
+
module.w2_weight.data = torch_npu.npu_format_cast(
|
| 201 |
+
module.w2_weight.data, format)
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
def try_register_lib(lib_name: str, lib_info: str = ""):
|
| 205 |
+
import importlib
|
| 206 |
+
import importlib.util
|
| 207 |
+
try:
|
| 208 |
+
module_spec = importlib.util.find_spec(lib_name)
|
| 209 |
+
if module_spec is not None:
|
| 210 |
+
importlib.import_module(lib_name)
|
| 211 |
+
if lib_info:
|
| 212 |
+
logger.info(lib_info)
|
| 213 |
+
except Exception:
|
| 214 |
+
pass
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
def enable_custom_op():
|
| 218 |
+
"""
|
| 219 |
+
Enable lazy init for vllm_ascend_C to avoid early initialization of CANN's RTS component.
|
| 220 |
+
Ensure that ASCEND_RT_VISIBLE_DEVICES can be dynamically modified before torch.npu.set_device().
|
| 221 |
+
"""
|
| 222 |
+
global _CUSTOM_OP_ENABLED
|
| 223 |
+
if _CUSTOM_OP_ENABLED is not None:
|
| 224 |
+
return _CUSTOM_OP_ENABLED
|
| 225 |
+
try:
|
| 226 |
+
# register custom ops into torch_library here
|
| 227 |
+
import vllm_ascend.vllm_ascend_C # type: ignore # noqa: F401
|
| 228 |
+
_CUSTOM_OP_ENABLED = True
|
| 229 |
+
except ImportError:
|
| 230 |
+
_CUSTOM_OP_ENABLED = False
|
| 231 |
+
logger.warning(
|
| 232 |
+
"Warning: Failed to register custom ops, all custom ops will be disabled"
|
| 233 |
+
)
|
| 234 |
+
return _CUSTOM_OP_ENABLED
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
def find_hccl_library() -> str:
|
| 238 |
+
"""
|
| 239 |
+
We either use the library file specified by the `HCCL_SO_PATH`
|
| 240 |
+
environment variable, or we find the library file brought by PyTorch.
|
| 241 |
+
After importing `torch`, `libhccl.so` can be
|
| 242 |
+
found by `ctypes` automatically.
|
| 243 |
+
"""
|
| 244 |
+
so_file = envs.HCCL_SO_PATH
|
| 245 |
+
|
| 246 |
+
# manually load the hccl library
|
| 247 |
+
if so_file:
|
| 248 |
+
logger.info("Found hccl from environment variable HCCL_SO_PATH=%s",
|
| 249 |
+
so_file)
|
| 250 |
+
else:
|
| 251 |
+
if torch.version.cann is not None:
|
| 252 |
+
so_file = "libhccl.so"
|
| 253 |
+
else:
|
| 254 |
+
raise ValueError("HCCL only supports Ascend NPU backends.")
|
| 255 |
+
logger.info("Found hccl from library %s", so_file)
|
| 256 |
+
return so_file
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
def current_stream() -> torch.npu.Stream:
|
| 260 |
+
"""
|
| 261 |
+
replace `torch.npu.current_stream()` with `vllm.utils.current_stream()`.
|
| 262 |
+
it turns out that `torch.npu.current_stream()` is quite expensive,
|
| 263 |
+
as it will construct a new stream object at each call.
|
| 264 |
+
here we patch `torch.npu.set_stream` to keep track of the current stream
|
| 265 |
+
directly, so that we can avoid calling `torch.npu.current_stream()`.
|
| 266 |
+
|
| 267 |
+
"""
|
| 268 |
+
global _CURRENT_STREAM
|
| 269 |
+
if _CURRENT_STREAM is None:
|
| 270 |
+
# when this function is called before any stream is set,
|
| 271 |
+
# we return the default stream.
|
| 272 |
+
_CURRENT_STREAM = torch.npu.current_stream()
|
| 273 |
+
return _CURRENT_STREAM
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
def adapt_patch(is_global_patch: bool = False):
|
| 277 |
+
if is_global_patch:
|
| 278 |
+
from vllm_ascend.patch import platform # noqa: F401
|
| 279 |
+
else:
|
| 280 |
+
from vllm_ascend.patch import worker # noqa: F401
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
def vllm_version_is(target_vllm_version: str):
|
| 284 |
+
if envs.VLLM_VERSION is not None:
|
| 285 |
+
vllm_version = envs.VLLM_VERSION
|
| 286 |
+
else:
|
| 287 |
+
import vllm
|
| 288 |
+
vllm_version = vllm.__version__
|
| 289 |
+
try:
|
| 290 |
+
return Version(vllm_version) == Version(target_vllm_version)
|
| 291 |
+
except InvalidVersion:
|
| 292 |
+
raise ValueError(
|
| 293 |
+
f"Invalid vllm version {vllm_version} found. A dev version of vllm "
|
| 294 |
+
"is installed probably. Set the environment variable VLLM_VERSION "
|
| 295 |
+
"to control it by hand. And please make sure the value follows the "
|
| 296 |
+
"format of x.y.z.")
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:
|
| 300 |
+
"""Update ACL graph capture sizes based on hardware limitations"""
|
| 301 |
+
# Store original configuration and temporarily clear it
|
| 302 |
+
compilation_config = vllm_config.compilation_config
|
| 303 |
+
original_sizes, compilation_config.cudagraph_capture_sizes = \
|
| 304 |
+
compilation_config.cudagraph_capture_sizes, None
|
| 305 |
+
|
| 306 |
+
# Calculate parallel configuration factor
|
| 307 |
+
num_hidden_layers = vllm_config.model_config.hf_config.num_hidden_layers
|
| 308 |
+
parallel_config = vllm_config.parallel_config
|
| 309 |
+
|
| 310 |
+
# TODO: Find out whether we need to take into account the pp_size
|
| 311 |
+
parallel_factor = 1 + sum(size > 1 for size in [
|
| 312 |
+
parallel_config.data_parallel_size_local,
|
| 313 |
+
parallel_config.tensor_parallel_size,
|
| 314 |
+
parallel_config.expert_parallel_size,
|
| 315 |
+
parallel_config.expert_tensor_parallel_size,
|
| 316 |
+
])
|
| 317 |
+
|
| 318 |
+
# Calculate maximum supported batch sizes considering model architecture
|
| 319 |
+
max_num_batch_sizes = math.floor(MAX_CAPTURE_SIZE /
|
| 320 |
+
(num_hidden_layers + 1) / parallel_factor)
|
| 321 |
+
logger.info("Calculated maximum supported batch sizes for ACL graph: %s",
|
| 322 |
+
max_num_batch_sizes)
|
| 323 |
+
|
| 324 |
+
# If original sizes exceed maximum, sample a representative subset
|
| 325 |
+
if max_num_batch_sizes < len(original_sizes):
|
| 326 |
+
# Sample uniformly from original sizes
|
| 327 |
+
step = (len(original_sizes) - 1) / (max_num_batch_sizes - 1)
|
| 328 |
+
indices = [round(i * step) for i in range(max_num_batch_sizes)]
|
| 329 |
+
|
| 330 |
+
# Ensure first and last elements are preserved
|
| 331 |
+
indices[0], indices[-1] = 0, len(original_sizes) - 1
|
| 332 |
+
|
| 333 |
+
sampled_sizes = [original_sizes[i] for i in indices]
|
| 334 |
+
compilation_config.init_with_cudagraph_sizes(sampled_sizes)
|
| 335 |
+
|
| 336 |
+
logger.info(
|
| 337 |
+
"Adjusted ACL graph batch sizes for %s model (layers: %d): %d → %d sizes",
|
| 338 |
+
vllm_config.model_config.architectures[0],
|
| 339 |
+
num_hidden_layers,
|
| 340 |
+
len(original_sizes),
|
| 341 |
+
len(compilation_config.
|
| 342 |
+
cudagraph_capture_sizes # type: ignore[arg-type]
|
| 343 |
+
))
|
| 344 |
+
else:
|
| 345 |
+
# No adjustment needed
|
| 346 |
+
compilation_config.cudagraph_capture_sizes = original_sizes
|
| 347 |
+
logger.info(
|
| 348 |
+
"No adjustment needed for ACL graph batch sizes: %s model (layers: %d) with %d sizes",
|
| 349 |
+
vllm_config.model_config.architectures[0], num_hidden_layers,
|
| 350 |
+
len(original_sizes))
|
| 351 |
+
|
| 352 |
+
|
| 353 |
+
# TODO(wxy): Move to ops module
|
| 354 |
+
def dispose_tensor(x: torch.Tensor):
|
| 355 |
+
x.set_(torch.empty((0, ), device=x.device, dtype=x.dtype))
|
| 356 |
+
|
| 357 |
+
|
| 358 |
+
class ProfileExecuteDuration:
|
| 359 |
+
_instance = None
|
| 360 |
+
_observations: List[Tuple[str, Event, Event]] = []
|
| 361 |
+
_lock = Lock()
|
| 362 |
+
|
| 363 |
+
def __new__(cls):
|
| 364 |
+
with cls._lock:
|
| 365 |
+
if cls._instance is None:
|
| 366 |
+
cls._instance = super().__new__(cls)
|
| 367 |
+
atexit.register(cls._instance.destroy)
|
| 368 |
+
return cls._instance
|
| 369 |
+
|
| 370 |
+
def destroy(self):
|
| 371 |
+
with self._lock:
|
| 372 |
+
self._observations.clear()
|
| 373 |
+
|
| 374 |
+
@contextmanager
|
| 375 |
+
def capture_async(self, duration_tag: str):
|
| 376 |
+
if not envs.VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE:
|
| 377 |
+
yield
|
| 378 |
+
return
|
| 379 |
+
|
| 380 |
+
observe_start = Event(enable_timing=True)
|
| 381 |
+
observe_start.record()
|
| 382 |
+
try:
|
| 383 |
+
yield
|
| 384 |
+
finally:
|
| 385 |
+
observe_end = Event(enable_timing=True)
|
| 386 |
+
observe_end.record()
|
| 387 |
+
with self._lock:
|
| 388 |
+
self._observations.append(
|
| 389 |
+
(duration_tag, observe_start, observe_end))
|
| 390 |
+
|
| 391 |
+
def pop_captured_sync(self) -> dict:
|
| 392 |
+
"""Pop and synchronize all events in the observation list"""
|
| 393 |
+
durations: dict[str, float] = {}
|
| 394 |
+
if not envs.VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE:
|
| 395 |
+
return durations
|
| 396 |
+
|
| 397 |
+
while self._observations:
|
| 398 |
+
with self._lock:
|
| 399 |
+
tag, observe_start, observe_end = self._observations.pop()
|
| 400 |
+
observe_end.synchronize()
|
| 401 |
+
durations[tag] = observe_start.elapsed_time(observe_end)
|
| 402 |
+
|
| 403 |
+
return durations
|
| 404 |
+
|
| 405 |
+
|
| 406 |
+
# TODO(wxy): Move to ops module
|
| 407 |
+
def npu_stream_switch(tag: str, priority: int, *, enabled: bool = True):
|
| 408 |
+
return _npu_stream_switch(tag, priority) if enabled else nullcontext()
|
| 409 |
+
|
| 410 |
+
|
| 411 |
+
# TODO(wxy): Move to ops module
|
| 412 |
+
def npu_wait_tensor(self: torch.Tensor,
|
| 413 |
+
dependency: torch.Tensor,
|
| 414 |
+
*,
|
| 415 |
+
enabled: bool = True):
|
| 416 |
+
return _npu_wait_tensor(self, dependency) if enabled else self
|
| 417 |
+
|
| 418 |
+
|
| 419 |
+
# TODO(wxy): Move to ops module
|
| 420 |
+
def npu_prefetch(input: torch.Tensor,
|
| 421 |
+
dependency: torch.Tensor,
|
| 422 |
+
max_size: int = 0,
|
| 423 |
+
*,
|
| 424 |
+
enabled: bool = True):
|
| 425 |
+
if not enabled:
|
| 426 |
+
return
|
| 427 |
+
input_size = input.element_size() * input.numel()
|
| 428 |
+
if max_size <= 0 or max_size > input_size:
|
| 429 |
+
max_size = input_size
|
| 430 |
+
torch_npu.npu_prefetch(input, dependency, max_size)
|
| 431 |
+
|
| 432 |
+
|
| 433 |
+
# TODO(zzzzwwjj): move this into forward_context
|
| 434 |
+
class FusedMoEState(Enum):
|
| 435 |
+
AllGather = 0
|
| 436 |
+
All2All = 1
|
| 437 |
+
MC2 = 2
|
| 438 |
+
AllGatherEP = 3
|
| 439 |
+
NaiveMulticast = 4
|
| 440 |
+
|
| 441 |
+
|
| 442 |
+
# TODO(ttanzhiqiang): rm_router_logits
|
| 443 |
+
# dp>1 will trigger
|
| 444 |
+
# In theory, this solution is only applicable to AllGather and AllGatherEP, because in the dp scenario, the previous operation was gate + two communications, and now it is changed to one communication + gate operation, which can save some communication time. In theory, all moe AllGather and AllGatherEP solutions can follow this logic, but now other moe models (qwen3-235b) dp solutions are not adjusted, so use the switch to control it to prevent code errors.
|
| 445 |
+
def get_rm_router_logits_state(ep_size: int, dp_size: int,
|
| 446 |
+
is_deepseek_v3_r1: bool):
|
| 447 |
+
# the fusion operator torch_npu.npu_grouped_matmul_finalize_routing called by allgather ep
|
| 448 |
+
# only supports deepseek v3/r1
|
| 449 |
+
if dp_size > 1:
|
| 450 |
+
if (envs.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP and ep_size > 1
|
| 451 |
+
and is_deepseek_v3_r1):
|
| 452 |
+
return True
|
| 453 |
+
elif ep_size == 1 and is_deepseek_v3_r1:
|
| 454 |
+
return True
|
| 455 |
+
return False
|
| 456 |
+
|
| 457 |
+
|
| 458 |
+
# TODO(ttanzhiqiang): all_reduce merge
|
| 459 |
+
# When all_reduce_merge is in progress, shared_experts does not do all_reduce in mlp, but waits until shared_experts+router_experts are completed before doing all_reduce
|
| 460 |
+
# Currently, all_reduce_merge is enabled by default in the AllGather, AllGatherEP and NaiveMulticast scenarios of the deepseek model.
|
| 461 |
+
def get_all_reduce_merge_state(ep_size: int, is_deepseek_v3_r1: bool):
|
| 462 |
+
# the fusion operator torch_npu.npu_grouped_matmul_finalize_routing called by allgather ep
|
| 463 |
+
# only supports deepseek v3/r1
|
| 464 |
+
if (envs.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP and ep_size > 1
|
| 465 |
+
and is_deepseek_v3_r1):
|
| 466 |
+
return True
|
| 467 |
+
elif ep_size == 1 and is_deepseek_v3_r1:
|
| 468 |
+
return True
|
| 469 |
+
return False
|
| 470 |
+
|
| 471 |
+
|
| 472 |
+
# TODO(zzzzwwjj): add soc_version to choose branch
|
| 473 |
+
def get_fused_moe_state(ep_size: int, with_prefill: bool,
|
| 474 |
+
is_deepseek_v3_r1: bool):
|
| 475 |
+
# the fusion operator torch_npu.npu_grouped_matmul_finalize_routing called by allgather ep
|
| 476 |
+
# only supports deepseek v3/r1
|
| 477 |
+
if (envs.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP and ep_size > 1
|
| 478 |
+
and is_deepseek_v3_r1 and not with_prefill):
|
| 479 |
+
return FusedMoEState.AllGatherEP
|
| 480 |
+
elif ep_size == 1:
|
| 481 |
+
if with_prefill:
|
| 482 |
+
return FusedMoEState.NaiveMulticast
|
| 483 |
+
else:
|
| 484 |
+
return FusedMoEState.AllGather
|
| 485 |
+
# NOTE: mc2 need ep_size >= 16 & all2all can't use in torchair graph.
|
| 486 |
+
elif ep_size < 16 or with_prefill:
|
| 487 |
+
return FusedMoEState.All2All
|
| 488 |
+
else:
|
| 489 |
+
return FusedMoEState.MC2
|
| 490 |
+
|
| 491 |
+
|
| 492 |
+
KV_CACHE_BYTES_CACHE_PATH_NAME = ".kv_cache_bytes"
|
| 493 |
+
KV_CACHE_BYTES_CACHE_FILE_NAME = "kv_cache_bytes"
|
| 494 |
+
TORCHAIR_CACHE_PATH_NAME = ".torchair_cache"
|
| 495 |
+
TORCHAIR_CACHE_DIR = os.getenv(
|
| 496 |
+
'TORCHAIR_CACHE_HOME', os.path.join(os.getcwd(), TORCHAIR_CACHE_PATH_NAME))
|
| 497 |
+
|
| 498 |
+
|
| 499 |
+
def get_torchair_current_work_dir(file_name=None):
|
| 500 |
+
if file_name is None:
|
| 501 |
+
return TORCHAIR_CACHE_DIR
|
| 502 |
+
return os.path.join(TORCHAIR_CACHE_DIR, file_name)
|
| 503 |
+
|
| 504 |
+
|
| 505 |
+
def check_torchair_cache_exist():
|
| 506 |
+
res = False
|
| 507 |
+
torch_air_abs_path = get_torchair_current_work_dir()
|
| 508 |
+
if os.path.exists(torch_air_abs_path):
|
| 509 |
+
file_list = os.listdir(torch_air_abs_path)
|
| 510 |
+
if len(file_list) != 0:
|
| 511 |
+
res = True
|
| 512 |
+
return res
|
| 513 |
+
|
| 514 |
+
|
| 515 |
+
def check_kv_cache_bytes_cache_exist():
|
| 516 |
+
res = False
|
| 517 |
+
kv_cache_bytes_cache_abs_path = get_torchair_current_work_dir(
|
| 518 |
+
KV_CACHE_BYTES_CACHE_PATH_NAME)
|
| 519 |
+
if os.path.exists(kv_cache_bytes_cache_abs_path):
|
| 520 |
+
file_list = os.listdir(kv_cache_bytes_cache_abs_path)
|
| 521 |
+
if len(file_list) != 0:
|
| 522 |
+
res = True
|
| 523 |
+
return res
|
| 524 |
+
|
| 525 |
+
|
| 526 |
+
def read_kv_cache_bytes_from_file(rank) -> int:
|
| 527 |
+
kv_cache_bytes = -1
|
| 528 |
+
kv_cache_bytes_cache_abs_path = get_torchair_current_work_dir(
|
| 529 |
+
KV_CACHE_BYTES_CACHE_PATH_NAME)
|
| 530 |
+
kv_cache_bytes_file = os.path.join(
|
| 531 |
+
kv_cache_bytes_cache_abs_path,
|
| 532 |
+
f"{rank}_{KV_CACHE_BYTES_CACHE_FILE_NAME}")
|
| 533 |
+
with open(kv_cache_bytes_file, "r", encoding="utf-8") as f:
|
| 534 |
+
with file_lock(f, fcntl.LOCK_SH):
|
| 535 |
+
kv_cache_bytes = int(f.readline())
|
| 536 |
+
return kv_cache_bytes
|
| 537 |
+
|
| 538 |
+
|
| 539 |
+
@contextmanager
|
| 540 |
+
def file_lock(file_descriptor, lock_type):
|
| 541 |
+
fcntl.flock(file_descriptor, lock_type)
|
| 542 |
+
try:
|
| 543 |
+
yield
|
| 544 |
+
finally:
|
| 545 |
+
fcntl.flock(file_descriptor, fcntl.LOCK_UN)
|
| 546 |
+
|
| 547 |
+
|
| 548 |
+
def write_kv_cache_bytes_to_file(rank, kv_cache_bytes):
|
| 549 |
+
kv_cache_bytes_cache_abs_path = get_torchair_current_work_dir(
|
| 550 |
+
KV_CACHE_BYTES_CACHE_PATH_NAME)
|
| 551 |
+
os.makedirs(kv_cache_bytes_cache_abs_path, exist_ok=True)
|
| 552 |
+
kv_cache_bytes_file = os.path.join(
|
| 553 |
+
kv_cache_bytes_cache_abs_path,
|
| 554 |
+
f"{rank}_{KV_CACHE_BYTES_CACHE_FILE_NAME}")
|
| 555 |
+
with open(kv_cache_bytes_file, "w", encoding="utf-8") as f:
|
| 556 |
+
with file_lock(f, fcntl.LOCK_EX):
|
| 557 |
+
f.write(f"{kv_cache_bytes}")
|
| 558 |
+
|
| 559 |
+
|
| 560 |
+
def delete_torchair_cache_file():
|
| 561 |
+
torch_air_abs_path = get_torchair_current_work_dir()
|
| 562 |
+
if os.path.exists(torch_air_abs_path):
|
| 563 |
+
shutil.rmtree(torch_air_abs_path)
|
inference/vllm_ascend/worker/model_runner_v1.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
inference/vllm_ascend/worker/npu_input_batch.py
ADDED
|
@@ -0,0 +1,796 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#
|
| 2 |
+
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
| 3 |
+
# Copyright 2023 The vLLM team.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
# This file is a part of the vllm-ascend project.
|
| 17 |
+
# Adapted from vllm-project/vllm/vllm/worker/gpu_input_batch.py
|
| 18 |
+
#
|
| 19 |
+
|
| 20 |
+
from dataclasses import dataclass
|
| 21 |
+
from typing import Optional, cast, Union
|
| 22 |
+
|
| 23 |
+
import numpy as np
|
| 24 |
+
import torch
|
| 25 |
+
from vllm.lora.request import LoRARequest
|
| 26 |
+
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
|
| 27 |
+
from vllm.pooling_params import PoolingParams
|
| 28 |
+
from vllm.sampling_params import SamplingParams, SamplingType
|
| 29 |
+
from vllm.utils import swap_dict_values
|
| 30 |
+
from vllm.v1.outputs import LogprobsTensors
|
| 31 |
+
from vllm.v1.sample.logits_processor import init_builtin_logitsprocs
|
| 32 |
+
from vllm.v1.sample.metadata import SamplingMetadata
|
| 33 |
+
from vllm.v1.spec_decode.utils import is_spec_decode_unsupported
|
| 34 |
+
from vllm.v1.utils import copy_slice
|
| 35 |
+
from vllm.v1.worker.block_table import MultiGroupBlockTable
|
| 36 |
+
|
| 37 |
+
from vllm_ascend.pool.metadata import PoolingMetadata
|
| 38 |
+
|
| 39 |
+
_SAMPLING_EPS = 1e-5
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
@dataclass
|
| 43 |
+
class CachedRequestState:
|
| 44 |
+
|
| 45 |
+
req_id: str
|
| 46 |
+
prompt_token_ids: list[int]
|
| 47 |
+
mm_inputs: list[MultiModalKwargs]
|
| 48 |
+
mm_positions: list[PlaceholderRange]
|
| 49 |
+
sampling_params: Optional[SamplingParams]
|
| 50 |
+
pooling_params: Optional[PoolingParams]
|
| 51 |
+
generator: Optional[torch.Generator]
|
| 52 |
+
|
| 53 |
+
block_ids: tuple[list[int], ...]
|
| 54 |
+
num_computed_tokens: int
|
| 55 |
+
output_token_ids: list[int]
|
| 56 |
+
|
| 57 |
+
mrope_positions: Optional[torch.Tensor] = None
|
| 58 |
+
mrope_position_delta: Optional[int] = None
|
| 59 |
+
|
| 60 |
+
lora_request: Optional[LoRARequest] = None
|
| 61 |
+
|
| 62 |
+
def __post_init__(self):
|
| 63 |
+
self.num_prompt_tokens = len(self.prompt_token_ids)
|
| 64 |
+
|
| 65 |
+
@property
|
| 66 |
+
def num_tokens(self) -> int:
|
| 67 |
+
return self.num_prompt_tokens + len(self.output_token_ids)
|
| 68 |
+
|
| 69 |
+
def get_token_id(self, idx: int) -> int:
|
| 70 |
+
if idx < self.num_prompt_tokens:
|
| 71 |
+
return self.prompt_token_ids[idx]
|
| 72 |
+
else:
|
| 73 |
+
return self.output_token_ids[idx - self.num_prompt_tokens]
|
| 74 |
+
|
| 75 |
+
@dataclass
|
| 76 |
+
class SamplingMetadataTopNSigma(SamplingMetadata):
|
| 77 |
+
top_n_sigma: torch.Tensor
|
| 78 |
+
no_top_n_sigma: bool
|
| 79 |
+
|
| 80 |
+
class InputBatch:
|
| 81 |
+
|
| 82 |
+
def __init__(
|
| 83 |
+
self,
|
| 84 |
+
max_num_reqs: int,
|
| 85 |
+
max_model_len: int,
|
| 86 |
+
max_num_batched_tokens: int,
|
| 87 |
+
device: torch.device,
|
| 88 |
+
pin_memory: bool,
|
| 89 |
+
vocab_size: int,
|
| 90 |
+
block_sizes: list[int], # The block_size of each kv cache group
|
| 91 |
+
logits_processing_needs_token_ids: bool = False,
|
| 92 |
+
is_spec_decode: bool = False,
|
| 93 |
+
):
|
| 94 |
+
self.is_spec_decode = is_spec_decode
|
| 95 |
+
self.max_num_reqs = max_num_reqs
|
| 96 |
+
self.max_model_len = max_model_len
|
| 97 |
+
self.max_num_batched_tokens = max_num_batched_tokens
|
| 98 |
+
self.device = device
|
| 99 |
+
self.pin_memory = pin_memory
|
| 100 |
+
self.vocab_size = vocab_size
|
| 101 |
+
self.logits_processing_needs_token_ids = (
|
| 102 |
+
logits_processing_needs_token_ids)
|
| 103 |
+
|
| 104 |
+
self._req_ids: list[Optional[str]] = []
|
| 105 |
+
self.req_id_to_index: dict[str, int] = {}
|
| 106 |
+
|
| 107 |
+
# TODO(woosuk): This buffer could be too large if max_model_len is big.
|
| 108 |
+
# Find a way to reduce the CPU memory usage.
|
| 109 |
+
# This buffer is not directly transferred to the NPU, so it does not
|
| 110 |
+
# need to be pinned.
|
| 111 |
+
self.token_ids_cpu_tensor = torch.zeros(
|
| 112 |
+
(max_num_reqs, max_model_len),
|
| 113 |
+
device="cpu",
|
| 114 |
+
dtype=torch.int32,
|
| 115 |
+
pin_memory=False,
|
| 116 |
+
)
|
| 117 |
+
self.token_ids_cpu = self.token_ids_cpu_tensor.numpy()
|
| 118 |
+
self.num_tokens = np.zeros(max_num_reqs, dtype=np.int32)
|
| 119 |
+
self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int32)
|
| 120 |
+
self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32)
|
| 121 |
+
self.num_computed_tokens_cpu_tensor = torch.zeros(
|
| 122 |
+
(max_num_reqs, ),
|
| 123 |
+
device="cpu",
|
| 124 |
+
dtype=torch.int32,
|
| 125 |
+
pin_memory=pin_memory,
|
| 126 |
+
)
|
| 127 |
+
self.num_computed_tokens_cpu = \
|
| 128 |
+
self.num_computed_tokens_cpu_tensor.numpy()
|
| 129 |
+
|
| 130 |
+
# Block table.
|
| 131 |
+
self.block_table = MultiGroupBlockTable(
|
| 132 |
+
max_num_reqs=max_num_reqs,
|
| 133 |
+
max_model_len=max_model_len,
|
| 134 |
+
max_num_batched_tokens=max_num_batched_tokens,
|
| 135 |
+
pin_memory=pin_memory,
|
| 136 |
+
device=device,
|
| 137 |
+
block_sizes=block_sizes,
|
| 138 |
+
)
|
| 139 |
+
|
| 140 |
+
# Sampling-related.
|
| 141 |
+
self.temperature = torch.empty((max_num_reqs, ),
|
| 142 |
+
dtype=torch.float32,
|
| 143 |
+
device=device)
|
| 144 |
+
self.temperature_cpu_tensor = torch.empty((max_num_reqs, ),
|
| 145 |
+
dtype=torch.float32,
|
| 146 |
+
device="cpu",
|
| 147 |
+
pin_memory=pin_memory)
|
| 148 |
+
self.temperature_cpu = self.temperature_cpu_tensor.numpy()
|
| 149 |
+
self.greedy_reqs: set[str] = set()
|
| 150 |
+
self.random_reqs: set[str] = set()
|
| 151 |
+
|
| 152 |
+
self.top_p = torch.empty((max_num_reqs, ),
|
| 153 |
+
dtype=torch.float32,
|
| 154 |
+
device=device)
|
| 155 |
+
self.top_p_cpu_tensor = torch.empty((max_num_reqs, ),
|
| 156 |
+
dtype=torch.float32,
|
| 157 |
+
device="cpu",
|
| 158 |
+
pin_memory=pin_memory)
|
| 159 |
+
self.top_p_cpu = self.top_p_cpu_tensor.numpy()
|
| 160 |
+
self.top_p_reqs: set[str] = set()
|
| 161 |
+
|
| 162 |
+
self.top_k = torch.empty((max_num_reqs, ),
|
| 163 |
+
dtype=torch.int32,
|
| 164 |
+
device=device)
|
| 165 |
+
self.top_k_cpu_tensor = torch.empty((max_num_reqs, ),
|
| 166 |
+
dtype=torch.int32,
|
| 167 |
+
device="cpu",
|
| 168 |
+
pin_memory=pin_memory)
|
| 169 |
+
self.top_k_cpu = self.top_k_cpu_tensor.numpy()
|
| 170 |
+
self.top_k_reqs: set[str] = set()
|
| 171 |
+
|
| 172 |
+
# IDs of requests which do not support spec decoding
|
| 173 |
+
self.spec_decode_unsupported_reqs: set[str] = set()
|
| 174 |
+
|
| 175 |
+
self.min_p = torch.empty((max_num_reqs, ),
|
| 176 |
+
dtype=torch.float32,
|
| 177 |
+
device=device)
|
| 178 |
+
self.min_p_cpu_tensor = torch.empty((max_num_reqs, ),
|
| 179 |
+
dtype=torch.float32,
|
| 180 |
+
device="cpu",
|
| 181 |
+
pin_memory=pin_memory)
|
| 182 |
+
self.min_p_cpu = self.min_p_cpu_tensor.numpy()
|
| 183 |
+
self.min_p_reqs: set[str] = set()
|
| 184 |
+
|
| 185 |
+
# topnsigma penalty
|
| 186 |
+
self.top_n_sigma = torch.empty((max_num_reqs, ),
|
| 187 |
+
dtype=torch.float,
|
| 188 |
+
device=device)
|
| 189 |
+
self.top_n_sigma_cpu_tensor = torch.empty(
|
| 190 |
+
(max_num_reqs, ),
|
| 191 |
+
dtype=torch.float,
|
| 192 |
+
device="cpu",
|
| 193 |
+
pin_memory=pin_memory)
|
| 194 |
+
self.top_n_sigma_cpu = \
|
| 195 |
+
self.top_n_sigma_cpu_tensor.numpy()
|
| 196 |
+
self.top_n_sigma_reqs: set[str] = set()
|
| 197 |
+
|
| 198 |
+
# Frequency penalty related data structures
|
| 199 |
+
self.frequency_penalties = torch.empty((max_num_reqs, ),
|
| 200 |
+
dtype=torch.float,
|
| 201 |
+
device=device)
|
| 202 |
+
self.frequency_penalties_cpu_tensor = torch.empty(
|
| 203 |
+
(max_num_reqs, ),
|
| 204 |
+
dtype=torch.float,
|
| 205 |
+
device="cpu",
|
| 206 |
+
pin_memory=pin_memory)
|
| 207 |
+
self.frequency_penalties_cpu = \
|
| 208 |
+
self.frequency_penalties_cpu_tensor.numpy()
|
| 209 |
+
self.frequency_penalties_reqs: set[str] = set()
|
| 210 |
+
|
| 211 |
+
# Presence penalty related data structures
|
| 212 |
+
self.presence_penalties = torch.empty((max_num_reqs, ),
|
| 213 |
+
dtype=torch.float,
|
| 214 |
+
device=device)
|
| 215 |
+
self.presence_penalties_cpu_tensor = torch.empty((max_num_reqs, ),
|
| 216 |
+
dtype=torch.float,
|
| 217 |
+
device="cpu",
|
| 218 |
+
pin_memory=pin_memory)
|
| 219 |
+
self.presence_penalties_cpu = self.presence_penalties_cpu_tensor.numpy(
|
| 220 |
+
)
|
| 221 |
+
self.presence_penalties_reqs: set[str] = set()
|
| 222 |
+
|
| 223 |
+
# Repetition penalty related data structures
|
| 224 |
+
self.repetition_penalties = torch.empty((max_num_reqs, ),
|
| 225 |
+
dtype=torch.float,
|
| 226 |
+
device=device)
|
| 227 |
+
self.repetition_penalties_cpu_tensor = torch.empty(
|
| 228 |
+
(max_num_reqs, ),
|
| 229 |
+
dtype=torch.float,
|
| 230 |
+
device="cpu",
|
| 231 |
+
pin_memory=pin_memory)
|
| 232 |
+
self.repetition_penalties_cpu = \
|
| 233 |
+
self.repetition_penalties_cpu_tensor.numpy()
|
| 234 |
+
self.repetition_penalties_reqs: set[str] = set()
|
| 235 |
+
|
| 236 |
+
# req_index -> (min_tokens, stop_token_ids)
|
| 237 |
+
self.min_tokens: dict[int, tuple[int, set[int]]] = {}
|
| 238 |
+
|
| 239 |
+
# lora related
|
| 240 |
+
self.request_lora_mapping = np.zeros((self.max_num_reqs, ),
|
| 241 |
+
dtype=np.int32)
|
| 242 |
+
self.lora_id_to_request_ids: dict[int, set[str]] = {}
|
| 243 |
+
self.lora_id_to_lora_request: dict[int, LoRARequest] = {}
|
| 244 |
+
|
| 245 |
+
# req_index -> generator
|
| 246 |
+
# NOTE(woosuk): The indices of the requests that do not have their own
|
| 247 |
+
# generator should not be included in the dictionary.
|
| 248 |
+
self.generators: dict[int, torch.Generator] = {}
|
| 249 |
+
|
| 250 |
+
self.num_logprobs: dict[str, int] = {}
|
| 251 |
+
# NOTE(rob): num_prompt_logprobs only includes reqs
|
| 252 |
+
# that are currently in the prefill phase.
|
| 253 |
+
self.num_prompt_logprobs: dict[str, int] = {}
|
| 254 |
+
|
| 255 |
+
# To accumulate prompt logprobs tensor chunks across prefill steps.
|
| 256 |
+
self.in_progress_prompt_logprobs_cpu: dict[str, LogprobsTensors] = {}
|
| 257 |
+
|
| 258 |
+
self.logit_bias: list[Optional[dict[int,
|
| 259 |
+
float]]] = [None] * max_num_reqs
|
| 260 |
+
self.has_allowed_token_ids: set[str] = set()
|
| 261 |
+
# NOTE(lufang): In the mask tensor, if the corresponding token allowed,
|
| 262 |
+
# the value is False. Since we use masked_fill_ to set -inf.
|
| 263 |
+
self.allowed_token_ids_mask: Optional[torch.Tensor] = None
|
| 264 |
+
self.allowed_token_ids_mask_cpu_tensor: Optional[torch.Tensor] = None
|
| 265 |
+
|
| 266 |
+
# req_index -> bad_words_token_ids
|
| 267 |
+
self.bad_words_token_ids: dict[int, list[list[int]]] = {}
|
| 268 |
+
|
| 269 |
+
self.req_output_token_ids: list[Optional[list[int]]] = []
|
| 270 |
+
|
| 271 |
+
# Define logits processors.
|
| 272 |
+
# TODO(andy): logits processor list should be extensible via engine
|
| 273 |
+
# constructor argument; for now the list is fixed.
|
| 274 |
+
self.logitsprocs = init_builtin_logitsprocs(
|
| 275 |
+
pin_memory_available=pin_memory,
|
| 276 |
+
max_num_reqs=max_num_reqs + 1,
|
| 277 |
+
device=device)
|
| 278 |
+
|
| 279 |
+
# This is updated each time the batch constituents change.
|
| 280 |
+
self.sampling_metadata = self._make_sampling_metadata()
|
| 281 |
+
|
| 282 |
+
self.pooling_params: dict[str, PoolingParams] = {}
|
| 283 |
+
|
| 284 |
+
@property
|
| 285 |
+
def req_ids(self) -> list[str]:
|
| 286 |
+
# None elements should only be present transiently
|
| 287 |
+
# while performing state updates to the batch.
|
| 288 |
+
return cast(list[str], self._req_ids)
|
| 289 |
+
|
| 290 |
+
def add_request(
|
| 291 |
+
self,
|
| 292 |
+
request: "CachedRequestState",
|
| 293 |
+
req_index: Optional[int] = None,
|
| 294 |
+
) -> None:
|
| 295 |
+
if req_index is None:
|
| 296 |
+
req_index = self.num_reqs
|
| 297 |
+
assert req_index < self.max_num_reqs
|
| 298 |
+
|
| 299 |
+
req_id = request.req_id
|
| 300 |
+
if req_index == len(self._req_ids):
|
| 301 |
+
self._req_ids.append(req_id)
|
| 302 |
+
self.req_output_token_ids.append(request.output_token_ids)
|
| 303 |
+
else:
|
| 304 |
+
self._req_ids[req_index] = req_id
|
| 305 |
+
self.req_output_token_ids[req_index] = request.output_token_ids
|
| 306 |
+
|
| 307 |
+
self.req_id_to_index[req_id] = req_index
|
| 308 |
+
|
| 309 |
+
# Copy the prompt token ids and output token ids.
|
| 310 |
+
num_prompt_tokens = len(request.prompt_token_ids)
|
| 311 |
+
self.num_prompt_tokens[req_index] = num_prompt_tokens
|
| 312 |
+
self.token_ids_cpu[
|
| 313 |
+
req_index, :num_prompt_tokens] = request.prompt_token_ids
|
| 314 |
+
start_idx = num_prompt_tokens
|
| 315 |
+
end_idx = start_idx + len(request.output_token_ids)
|
| 316 |
+
self.token_ids_cpu[req_index,
|
| 317 |
+
start_idx:end_idx] = request.output_token_ids
|
| 318 |
+
# Number of token ids in token_ids_cpu.
|
| 319 |
+
# NOTE(woosuk): This may include spec decode tokens.
|
| 320 |
+
self.num_tokens[req_index] = request.num_tokens
|
| 321 |
+
# Number of tokens without spec decode tokens.
|
| 322 |
+
self.num_tokens_no_spec[req_index] = request.num_tokens
|
| 323 |
+
|
| 324 |
+
self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
|
| 325 |
+
self.block_table.add_row(request.block_ids, req_index)
|
| 326 |
+
|
| 327 |
+
if sampling_params := request.sampling_params:
|
| 328 |
+
if self.is_spec_decode and is_spec_decode_unsupported(
|
| 329 |
+
sampling_params):
|
| 330 |
+
self.spec_decode_unsupported_reqs.add(req_id)
|
| 331 |
+
if sampling_params.sampling_type == SamplingType.GREEDY:
|
| 332 |
+
# Avoid later division by zero.
|
| 333 |
+
self.temperature_cpu[req_index] = -1.0
|
| 334 |
+
self.greedy_reqs.add(req_id)
|
| 335 |
+
else:
|
| 336 |
+
self.temperature_cpu[req_index] = sampling_params.temperature
|
| 337 |
+
self.random_reqs.add(req_id)
|
| 338 |
+
|
| 339 |
+
self.top_p_cpu[req_index] = sampling_params.top_p
|
| 340 |
+
if sampling_params.top_p < 1:
|
| 341 |
+
self.top_p_reqs.add(req_id)
|
| 342 |
+
top_k = sampling_params.top_k
|
| 343 |
+
if 0 < top_k < self.vocab_size:
|
| 344 |
+
self.top_k_reqs.add(req_id)
|
| 345 |
+
else:
|
| 346 |
+
top_k = self.vocab_size
|
| 347 |
+
self.top_k_cpu[req_index] = top_k
|
| 348 |
+
self.min_p_cpu[req_index] = sampling_params.min_p
|
| 349 |
+
self.frequency_penalties_cpu[
|
| 350 |
+
req_index] = sampling_params.frequency_penalty
|
| 351 |
+
if sampling_params.min_p > _SAMPLING_EPS:
|
| 352 |
+
self.min_p_reqs.add(req_id)
|
| 353 |
+
if sampling_params.frequency_penalty != 0.0:
|
| 354 |
+
self.frequency_penalties_reqs.add(req_id)
|
| 355 |
+
self.presence_penalties_cpu[
|
| 356 |
+
req_index] = sampling_params.presence_penalty
|
| 357 |
+
if sampling_params.presence_penalty != 0.0:
|
| 358 |
+
self.presence_penalties_reqs.add(req_id)
|
| 359 |
+
self.repetition_penalties_cpu[
|
| 360 |
+
req_index] = sampling_params.repetition_penalty
|
| 361 |
+
if sampling_params.repetition_penalty != 1.0:
|
| 362 |
+
self.repetition_penalties_reqs.add(req_id)
|
| 363 |
+
if sampling_params.min_tokens:
|
| 364 |
+
self.min_tokens[req_index] = (
|
| 365 |
+
sampling_params.min_tokens,
|
| 366 |
+
sampling_params.all_stop_token_ids)
|
| 367 |
+
|
| 368 |
+
if sampling_params.extra_args and "top_n_sigma" in sampling_params.extra_args:
|
| 369 |
+
self.top_n_sigma_cpu[
|
| 370 |
+
req_index] = sampling_params.extra_args["top_n_sigma"]
|
| 371 |
+
self.top_n_sigma_reqs.add(req_id)
|
| 372 |
+
else:
|
| 373 |
+
self.top_n_sigma_cpu[req_index] = -1
|
| 374 |
+
|
| 375 |
+
# NOTE(woosuk): self.generators should not include the requests that
|
| 376 |
+
# do not have their own generator.
|
| 377 |
+
if request.generator is not None:
|
| 378 |
+
self.generators[req_index] = request.generator
|
| 379 |
+
|
| 380 |
+
if sampling_params.logprobs is not None:
|
| 381 |
+
self.num_logprobs[req_id] = sampling_params.logprobs
|
| 382 |
+
if sampling_params.prompt_logprobs is not None:
|
| 383 |
+
self.num_prompt_logprobs[
|
| 384 |
+
req_id] = sampling_params.prompt_logprobs
|
| 385 |
+
if sampling_params.logit_bias is not None:
|
| 386 |
+
self.logit_bias[req_index] = sampling_params.logit_bias
|
| 387 |
+
|
| 388 |
+
if sampling_params.allowed_token_ids:
|
| 389 |
+
self.has_allowed_token_ids.add(req_id)
|
| 390 |
+
if self.allowed_token_ids_mask_cpu_tensor is None:
|
| 391 |
+
# Lazy allocation for this tensor, which can be large.
|
| 392 |
+
# False means we don't fill with -inf.
|
| 393 |
+
self.allowed_token_ids_mask = torch.zeros(
|
| 394 |
+
self.max_num_reqs,
|
| 395 |
+
self.vocab_size,
|
| 396 |
+
dtype=torch.bool,
|
| 397 |
+
device=self.device)
|
| 398 |
+
self.allowed_token_ids_mask_cpu_tensor = torch.zeros(
|
| 399 |
+
self.max_num_reqs,
|
| 400 |
+
self.vocab_size,
|
| 401 |
+
dtype=torch.bool,
|
| 402 |
+
device="cpu")
|
| 403 |
+
self.allowed_token_ids_mask_cpu_tensor[req_index] = True
|
| 404 |
+
# False means we don't fill with -inf.
|
| 405 |
+
self.allowed_token_ids_mask_cpu_tensor[req_index][
|
| 406 |
+
sampling_params.allowed_token_ids] = False
|
| 407 |
+
|
| 408 |
+
if sampling_params.bad_words_token_ids:
|
| 409 |
+
self.bad_words_token_ids[
|
| 410 |
+
req_index] = sampling_params.bad_words_token_ids
|
| 411 |
+
else:
|
| 412 |
+
assert request.pooling_params is not None
|
| 413 |
+
self.pooling_params[req_id] = request.pooling_params
|
| 414 |
+
|
| 415 |
+
# Add request lora ID
|
| 416 |
+
if request.lora_request:
|
| 417 |
+
lora_id = request.lora_request.lora_int_id
|
| 418 |
+
if lora_id not in self.lora_id_to_request_ids:
|
| 419 |
+
self.lora_id_to_request_ids[lora_id] = set()
|
| 420 |
+
|
| 421 |
+
self.request_lora_mapping[req_index] = lora_id
|
| 422 |
+
self.lora_id_to_request_ids[lora_id].add(request.req_id)
|
| 423 |
+
self.lora_id_to_lora_request[lora_id] = request.lora_request
|
| 424 |
+
else:
|
| 425 |
+
# No LoRA
|
| 426 |
+
self.request_lora_mapping[req_index] = 0
|
| 427 |
+
|
| 428 |
+
def remove_request(self, req_id: str) -> Optional[int]:
|
| 429 |
+
"""This method must always be followed by a call to condense()."""
|
| 430 |
+
|
| 431 |
+
req_index = self.req_id_to_index.pop(req_id, None)
|
| 432 |
+
if req_index is None:
|
| 433 |
+
return None
|
| 434 |
+
self._req_ids[req_index] = None
|
| 435 |
+
self.req_output_token_ids[req_index] = None
|
| 436 |
+
|
| 437 |
+
self.greedy_reqs.discard(req_id)
|
| 438 |
+
self.random_reqs.discard(req_id)
|
| 439 |
+
self.top_p_reqs.discard(req_id)
|
| 440 |
+
self.top_k_reqs.discard(req_id)
|
| 441 |
+
self.min_p_reqs.discard(req_id)
|
| 442 |
+
self.min_tokens.pop(req_index, None)
|
| 443 |
+
self.frequency_penalties_reqs.discard(req_id)
|
| 444 |
+
self.presence_penalties_reqs.discard(req_id)
|
| 445 |
+
self.repetition_penalties_reqs.discard(req_id)
|
| 446 |
+
self.spec_decode_unsupported_reqs.discard(req_id)
|
| 447 |
+
self.top_n_sigma_reqs.discard(req_id)
|
| 448 |
+
self.generators.pop(req_index, None)
|
| 449 |
+
self.num_logprobs.pop(req_id, None)
|
| 450 |
+
self.num_prompt_logprobs.pop(req_id, None)
|
| 451 |
+
self.in_progress_prompt_logprobs_cpu.pop(req_id, None)
|
| 452 |
+
|
| 453 |
+
# LoRA
|
| 454 |
+
lora_id = self.request_lora_mapping[req_index]
|
| 455 |
+
if lora_id != 0:
|
| 456 |
+
self.lora_id_to_request_ids[lora_id].discard(req_id)
|
| 457 |
+
if len(self.lora_id_to_request_ids[lora_id]) == 0:
|
| 458 |
+
self.lora_id_to_request_ids.pop(lora_id)
|
| 459 |
+
self.lora_id_to_lora_request.pop(lora_id)
|
| 460 |
+
self.request_lora_mapping[req_index] = 0
|
| 461 |
+
|
| 462 |
+
self.logit_bias[req_index] = None
|
| 463 |
+
self.has_allowed_token_ids.discard(req_id)
|
| 464 |
+
if self.allowed_token_ids_mask_cpu_tensor is not None:
|
| 465 |
+
# False means we don't fill with -inf.
|
| 466 |
+
self.allowed_token_ids_mask_cpu_tensor[req_index].fill_(False)
|
| 467 |
+
self.bad_words_token_ids.pop(req_index, None)
|
| 468 |
+
self.pooling_params.pop(req_id, None)
|
| 469 |
+
return req_index
|
| 470 |
+
|
| 471 |
+
def swap_states(self, i1: int, i2: int) -> None:
|
| 472 |
+
old_id_i1 = self._req_ids[i1]
|
| 473 |
+
old_id_i2 = self._req_ids[i2]
|
| 474 |
+
self._req_ids[i1], self._req_ids[i2] =\
|
| 475 |
+
self._req_ids[i2], self._req_ids[i1] # noqa
|
| 476 |
+
self.req_output_token_ids[i1], self.req_output_token_ids[i2] =\
|
| 477 |
+
self.req_output_token_ids[i2], self.req_output_token_ids[i1]
|
| 478 |
+
assert old_id_i1 is not None and old_id_i2 is not None
|
| 479 |
+
self.req_id_to_index[old_id_i1], self.req_id_to_index[old_id_i2] =\
|
| 480 |
+
self.req_id_to_index[old_id_i2], self.req_id_to_index[old_id_i1]
|
| 481 |
+
self.num_tokens[i1], self.num_tokens[i2] =\
|
| 482 |
+
self.num_tokens[i2], self.num_tokens[i1]
|
| 483 |
+
self.num_tokens_no_spec[i1], self.num_tokens_no_spec[i2] =\
|
| 484 |
+
self.num_tokens_no_spec[i2], self.num_tokens_no_spec[i1]
|
| 485 |
+
self.num_prompt_tokens[i1], self.num_prompt_tokens[i2] =\
|
| 486 |
+
self.num_prompt_tokens[i2], self.num_prompt_tokens[i1]
|
| 487 |
+
self.num_computed_tokens_cpu[i1], self.num_computed_tokens_cpu[i2] =\
|
| 488 |
+
self.num_computed_tokens_cpu[i2], self.num_computed_tokens_cpu[i1]
|
| 489 |
+
self.temperature_cpu[i1], self.temperature_cpu[i2] =\
|
| 490 |
+
self.temperature_cpu[i2], self.temperature_cpu[i1]
|
| 491 |
+
self.top_p_cpu[i1], self.top_p_cpu[i2] =\
|
| 492 |
+
self.top_p_cpu[i2], self.top_p_cpu[i1]
|
| 493 |
+
self.top_k_cpu[i1], self.top_k_cpu[i2] =\
|
| 494 |
+
self.top_k_cpu[i2], self.top_k_cpu[i1]
|
| 495 |
+
self.frequency_penalties_cpu[i1], self.frequency_penalties_cpu[i2] =\
|
| 496 |
+
self.frequency_penalties_cpu[i2], self.frequency_penalties_cpu[i1]
|
| 497 |
+
self.presence_penalties_cpu[i1], self.presence_penalties_cpu[i2] =\
|
| 498 |
+
self.presence_penalties_cpu[i2], self.presence_penalties_cpu[i1]
|
| 499 |
+
self.repetition_penalties_cpu[i1], self.repetition_penalties_cpu[i2] =\
|
| 500 |
+
self.repetition_penalties_cpu[i2], self.repetition_penalties_cpu[i1]
|
| 501 |
+
self.min_p_cpu[i1], self.min_p_cpu[i2] =\
|
| 502 |
+
self.min_p_cpu[i2], self.min_p_cpu[i1]
|
| 503 |
+
self.top_n_sigma_cpu[i1], self.top_n_sigma_cpu[i2] =\
|
| 504 |
+
self.top_n_sigma_cpu[i2], self.top_n_sigma_cpu[i1]
|
| 505 |
+
|
| 506 |
+
# NOTE: the following is unsafe
|
| 507 |
+
# self.token_ids_cpu[i1, ...], self.token_ids_cpu[i2, ...], =\
|
| 508 |
+
# self.token_ids_cpu[i2, ...], self.token_ids_cpu[i1, ...]
|
| 509 |
+
# instead, we need to temporiarily copy the data for one of the indices
|
| 510 |
+
# TODO(lucas): optimize this by only copying valid indices
|
| 511 |
+
tmp = self.token_ids_cpu[i1, ...].copy()
|
| 512 |
+
self.token_ids_cpu[i1, ...] = self.token_ids_cpu[i2, ...]
|
| 513 |
+
self.token_ids_cpu[i2, ...] = tmp
|
| 514 |
+
|
| 515 |
+
swap_dict_values(self.generators, i1, i2)
|
| 516 |
+
swap_dict_values(self.min_tokens, i1, i2)
|
| 517 |
+
swap_dict_values(self.bad_words_token_ids, i1, i2)
|
| 518 |
+
|
| 519 |
+
self.request_lora_mapping[i1], self.request_lora_mapping[i2] =\
|
| 520 |
+
self.request_lora_mapping[i2], self.request_lora_mapping[i1]
|
| 521 |
+
self.logit_bias[i1], self.logit_bias[i2] =\
|
| 522 |
+
self.logit_bias[i2], self.logit_bias[i1]
|
| 523 |
+
|
| 524 |
+
if self.allowed_token_ids_mask_cpu_tensor is not None:
|
| 525 |
+
self.allowed_token_ids_mask_cpu_tensor[i1], \
|
| 526 |
+
self.allowed_token_ids_mask_cpu_tensor[i2] =\
|
| 527 |
+
self.allowed_token_ids_mask_cpu_tensor[i2], \
|
| 528 |
+
self.allowed_token_ids_mask_cpu_tensor[i1]
|
| 529 |
+
self.block_table.swap_row(i1, i2)
|
| 530 |
+
|
| 531 |
+
def condense(self, empty_req_indices: list[int]) -> None:
|
| 532 |
+
"""Move non-empty requests down into lower, empty indices.
|
| 533 |
+
|
| 534 |
+
Args:
|
| 535 |
+
empty_req_indices: empty batch indices, sorted descending.
|
| 536 |
+
"""
|
| 537 |
+
num_reqs = self.num_reqs
|
| 538 |
+
if num_reqs == 0:
|
| 539 |
+
# The batched states are empty.
|
| 540 |
+
self._req_ids.clear()
|
| 541 |
+
self.req_output_token_ids.clear()
|
| 542 |
+
return
|
| 543 |
+
|
| 544 |
+
# NOTE(woosuk): This function assumes that the empty_req_indices
|
| 545 |
+
# is sorted in descending order.
|
| 546 |
+
last_req_index = num_reqs + len(empty_req_indices) - 1
|
| 547 |
+
while empty_req_indices:
|
| 548 |
+
# Find the largest non-empty index.
|
| 549 |
+
while last_req_index in empty_req_indices:
|
| 550 |
+
last_req_index -= 1
|
| 551 |
+
|
| 552 |
+
# Find the smallest empty index.
|
| 553 |
+
empty_index = empty_req_indices.pop()
|
| 554 |
+
if empty_index >= last_req_index:
|
| 555 |
+
break
|
| 556 |
+
|
| 557 |
+
# Swap the states.
|
| 558 |
+
req_id = self._req_ids[last_req_index]
|
| 559 |
+
output_token_ids = self.req_output_token_ids[last_req_index]
|
| 560 |
+
assert req_id is not None
|
| 561 |
+
self._req_ids[empty_index] = req_id
|
| 562 |
+
self._req_ids[last_req_index] = None
|
| 563 |
+
self.req_output_token_ids[empty_index] = output_token_ids
|
| 564 |
+
self.req_output_token_ids[last_req_index] = None
|
| 565 |
+
self.req_id_to_index[req_id] = empty_index
|
| 566 |
+
|
| 567 |
+
num_tokens = self.num_tokens[last_req_index]
|
| 568 |
+
self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[
|
| 569 |
+
last_req_index, :num_tokens]
|
| 570 |
+
self.num_tokens[empty_index] = num_tokens
|
| 571 |
+
self.num_tokens_no_spec[empty_index] = self.num_tokens_no_spec[
|
| 572 |
+
last_req_index]
|
| 573 |
+
self.num_prompt_tokens[empty_index] = self.num_prompt_tokens[
|
| 574 |
+
last_req_index]
|
| 575 |
+
self.num_computed_tokens_cpu[
|
| 576 |
+
empty_index] = self.num_computed_tokens_cpu[last_req_index]
|
| 577 |
+
self.block_table.move_row(last_req_index, empty_index)
|
| 578 |
+
self.temperature_cpu[empty_index] = self.temperature_cpu[
|
| 579 |
+
last_req_index]
|
| 580 |
+
self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index]
|
| 581 |
+
self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index]
|
| 582 |
+
self.frequency_penalties_cpu[
|
| 583 |
+
empty_index] = self.frequency_penalties_cpu[last_req_index]
|
| 584 |
+
self.presence_penalties_cpu[
|
| 585 |
+
empty_index] = self.presence_penalties_cpu[last_req_index]
|
| 586 |
+
self.repetition_penalties_cpu[
|
| 587 |
+
empty_index] = self.repetition_penalties_cpu[last_req_index]
|
| 588 |
+
self.min_p_cpu[empty_index] = self.min_p_cpu[last_req_index]
|
| 589 |
+
self.top_n_sigma_cpu[
|
| 590 |
+
empty_index] = self.top_n_sigma_cpu[last_req_index]
|
| 591 |
+
generator = self.generators.pop(last_req_index, None)
|
| 592 |
+
if generator is not None:
|
| 593 |
+
self.generators[empty_index] = generator
|
| 594 |
+
|
| 595 |
+
min_token = self.min_tokens.pop(last_req_index, None)
|
| 596 |
+
if min_token is not None:
|
| 597 |
+
self.min_tokens[empty_index] = min_token
|
| 598 |
+
|
| 599 |
+
self.request_lora_mapping[empty_index] = self.request_lora_mapping[
|
| 600 |
+
last_req_index]
|
| 601 |
+
|
| 602 |
+
self.logit_bias[empty_index] = self.logit_bias[last_req_index]
|
| 603 |
+
|
| 604 |
+
if self.allowed_token_ids_mask_cpu_tensor is not None:
|
| 605 |
+
self.allowed_token_ids_mask_cpu_tensor[
|
| 606 |
+
empty_index] = self.allowed_token_ids_mask_cpu_tensor[
|
| 607 |
+
last_req_index]
|
| 608 |
+
|
| 609 |
+
bad_words_token_ids = self.bad_words_token_ids.pop(
|
| 610 |
+
last_req_index, None)
|
| 611 |
+
if bad_words_token_ids is not None:
|
| 612 |
+
self.bad_words_token_ids[empty_index] = bad_words_token_ids
|
| 613 |
+
# Decrement last_req_index since it is now empty.
|
| 614 |
+
last_req_index -= 1
|
| 615 |
+
|
| 616 |
+
# Trim lists to the batch size.
|
| 617 |
+
del self._req_ids[self.num_reqs:]
|
| 618 |
+
del self.req_output_token_ids[self.num_reqs:]
|
| 619 |
+
|
| 620 |
+
def refresh_sampling_metadata(self):
|
| 621 |
+
self.sampling_metadata = self._make_sampling_metadata()
|
| 622 |
+
|
| 623 |
+
def _make_sampling_metadata(self) -> Union[SamplingMetadata, SamplingMetadataTopNSigma]:
|
| 624 |
+
num_reqs = self.num_reqs
|
| 625 |
+
if not self.all_greedy:
|
| 626 |
+
temperature = copy_slice(self.temperature_cpu_tensor,
|
| 627 |
+
self.temperature, num_reqs)
|
| 628 |
+
else:
|
| 629 |
+
temperature = None
|
| 630 |
+
if not self.no_top_p:
|
| 631 |
+
copy_slice(self.top_p_cpu_tensor, self.top_p, num_reqs)
|
| 632 |
+
if not self.no_top_k:
|
| 633 |
+
copy_slice(self.top_k_cpu_tensor, self.top_k, num_reqs)
|
| 634 |
+
if not self.no_min_p:
|
| 635 |
+
copy_slice(self.min_p_cpu_tensor, self.min_p, num_reqs)
|
| 636 |
+
|
| 637 |
+
if not self.no_penalties:
|
| 638 |
+
# Since syncing these tensors is expensive only copy them
|
| 639 |
+
# if necessary i.e. if there are requests which require
|
| 640 |
+
# penalties to be applied during sampling.
|
| 641 |
+
copy_slice(self.frequency_penalties_cpu_tensor,
|
| 642 |
+
self.frequency_penalties, num_reqs)
|
| 643 |
+
copy_slice(self.presence_penalties_cpu_tensor,
|
| 644 |
+
self.presence_penalties, num_reqs)
|
| 645 |
+
copy_slice(self.repetition_penalties_cpu_tensor,
|
| 646 |
+
self.repetition_penalties, num_reqs)
|
| 647 |
+
|
| 648 |
+
if not self.no_top_n_sigma:
|
| 649 |
+
copy_slice(self.top_n_sigma_cpu_tensor,
|
| 650 |
+
self.top_n_sigma, num_reqs)
|
| 651 |
+
|
| 652 |
+
|
| 653 |
+
needs_prompt_token_ids = (not self.no_penalties or
|
| 654 |
+
(self.num_reqs > 0
|
| 655 |
+
and self.logits_processing_needs_token_ids))
|
| 656 |
+
if needs_prompt_token_ids:
|
| 657 |
+
# The prompt tokens are used only for applying penalties or
|
| 658 |
+
# step pooling during the sampling/pooling process.
|
| 659 |
+
# Hence copy these tensors only when there are requests which
|
| 660 |
+
# need penalties/step_pooler to be applied.
|
| 661 |
+
prompt_token_ids = self._make_prompt_token_ids_tensor()
|
| 662 |
+
else:
|
| 663 |
+
prompt_token_ids = None
|
| 664 |
+
|
| 665 |
+
allowed_token_ids_mask: Optional[torch.Tensor] = None
|
| 666 |
+
if not self.no_allowed_token_ids:
|
| 667 |
+
assert self.allowed_token_ids_mask is not None
|
| 668 |
+
copy_slice(self.allowed_token_ids_mask_cpu_tensor,
|
| 669 |
+
self.allowed_token_ids_mask, num_reqs)
|
| 670 |
+
allowed_token_ids_mask = self.allowed_token_ids_mask[:num_reqs]
|
| 671 |
+
|
| 672 |
+
return SamplingMetadataTopNSigma(
|
| 673 |
+
temperature=temperature,
|
| 674 |
+
all_greedy=self.all_greedy,
|
| 675 |
+
all_random=self.all_random,
|
| 676 |
+
top_p=None if self.no_top_p else self.top_p[:num_reqs],
|
| 677 |
+
top_k=None if self.no_top_k else self.top_k[:num_reqs],
|
| 678 |
+
generators=self.generators,
|
| 679 |
+
max_num_logprobs=self.max_num_logprobs,
|
| 680 |
+
prompt_token_ids=prompt_token_ids,
|
| 681 |
+
frequency_penalties=self.frequency_penalties[:num_reqs],
|
| 682 |
+
presence_penalties=self.presence_penalties[:num_reqs],
|
| 683 |
+
repetition_penalties=self.repetition_penalties[:num_reqs],
|
| 684 |
+
top_n_sigma=self.top_n_sigma[:num_reqs],
|
| 685 |
+
output_token_ids=cast(list[list[int]], self.req_output_token_ids),
|
| 686 |
+
no_penalties=self.no_penalties,
|
| 687 |
+
no_top_n_sigma=self.no_top_n_sigma,
|
| 688 |
+
allowed_token_ids_mask=allowed_token_ids_mask,
|
| 689 |
+
bad_words_token_ids=self.bad_words_token_ids,
|
| 690 |
+
logitsprocs=self.logitsprocs,
|
| 691 |
+
)
|
| 692 |
+
|
| 693 |
+
@property
|
| 694 |
+
def pooling_metadata(self) -> PoolingMetadata:
|
| 695 |
+
if len(self.pooling_params) == 0:
|
| 696 |
+
pooling_params = []
|
| 697 |
+
else:
|
| 698 |
+
# Note, for now this assumes that all request in the batch
|
| 699 |
+
# are either sampling or pooling requests
|
| 700 |
+
assert len(self.req_ids) == len(self.pooling_params)
|
| 701 |
+
pooling_params = [
|
| 702 |
+
self.pooling_params[req_id] for req_id in self.req_ids
|
| 703 |
+
]
|
| 704 |
+
|
| 705 |
+
return PoolingMetadata(
|
| 706 |
+
prompt_lens=torch.from_numpy(
|
| 707 |
+
self.num_prompt_tokens[:self.num_reqs]).to(self.device),
|
| 708 |
+
prompt_token_ids=self.sampling_metadata.prompt_token_ids,
|
| 709 |
+
pooling_params=pooling_params,
|
| 710 |
+
)
|
| 711 |
+
|
| 712 |
+
def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
|
| 713 |
+
max_prompt_len = self.num_prompt_tokens[:self.num_reqs].max()
|
| 714 |
+
prompt_token_ids_cpu_tensor = torch.empty(
|
| 715 |
+
(self.num_reqs, max_prompt_len),
|
| 716 |
+
device="cpu",
|
| 717 |
+
dtype=torch.int64,
|
| 718 |
+
pin_memory=self.pin_memory,
|
| 719 |
+
)
|
| 720 |
+
prompt_token_ids = prompt_token_ids_cpu_tensor.numpy()
|
| 721 |
+
prompt_token_ids[:] = self.token_ids_cpu[:self.
|
| 722 |
+
num_reqs, :max_prompt_len]
|
| 723 |
+
# Use the value of vocab_size as a pad since we don't have a
|
| 724 |
+
# token_id of this value.
|
| 725 |
+
for i in range(self.num_reqs):
|
| 726 |
+
prompt_token_ids[i, self.num_prompt_tokens[i]:] = self.vocab_size
|
| 727 |
+
return prompt_token_ids_cpu_tensor.to(device=self.device,
|
| 728 |
+
non_blocking=True)
|
| 729 |
+
|
| 730 |
+
def make_lora_inputs(
|
| 731 |
+
self, num_scheduled_tokens: np.ndarray
|
| 732 |
+
) -> tuple[tuple[int, ...], tuple[int, ...], set[LoRARequest]]:
|
| 733 |
+
"""
|
| 734 |
+
Given the num_scheduled_tokens for each request in the batch, return
|
| 735 |
+
datastructures used to activate the current LoRAs.
|
| 736 |
+
Returns:
|
| 737 |
+
1. prompt_lora_mapping: A tuple of size self.num_reqs where,
|
| 738 |
+
prompt_lora_mapping[i] is the LoRA id to use for the ith prompt.
|
| 739 |
+
2. token_lora_mapping: A tuple of size np.sum(num_scheduled_tokens)
|
| 740 |
+
where, token_lora_mapping[i] is the LoRA id to use for ith token.
|
| 741 |
+
3. lora_requests: Set of relevant LoRA requests.
|
| 742 |
+
"""
|
| 743 |
+
|
| 744 |
+
req_lora_mapping = self.request_lora_mapping[:self.num_reqs]
|
| 745 |
+
prompt_lora_mapping = tuple(req_lora_mapping)
|
| 746 |
+
token_lora_mapping = tuple(
|
| 747 |
+
req_lora_mapping.repeat(num_scheduled_tokens))
|
| 748 |
+
active_lora_requests: set[LoRARequest] = set(
|
| 749 |
+
self.lora_id_to_lora_request.values())
|
| 750 |
+
|
| 751 |
+
return prompt_lora_mapping, token_lora_mapping, active_lora_requests
|
| 752 |
+
|
| 753 |
+
@property
|
| 754 |
+
def num_reqs(self) -> int:
|
| 755 |
+
return len(self.req_id_to_index)
|
| 756 |
+
|
| 757 |
+
@property
|
| 758 |
+
def all_greedy(self) -> bool:
|
| 759 |
+
return len(self.random_reqs) == 0
|
| 760 |
+
|
| 761 |
+
@property
|
| 762 |
+
def all_random(self) -> bool:
|
| 763 |
+
return len(self.greedy_reqs) == 0
|
| 764 |
+
|
| 765 |
+
@property
|
| 766 |
+
def no_top_p(self) -> bool:
|
| 767 |
+
return len(self.top_p_reqs) == 0
|
| 768 |
+
|
| 769 |
+
@property
|
| 770 |
+
def no_top_k(self) -> bool:
|
| 771 |
+
return len(self.top_k_reqs) == 0
|
| 772 |
+
|
| 773 |
+
@property
|
| 774 |
+
def no_min_p(self) -> bool:
|
| 775 |
+
return len(self.min_p_reqs) == 0
|
| 776 |
+
|
| 777 |
+
@property
|
| 778 |
+
def no_penalties(self) -> bool:
|
| 779 |
+
return (len(self.presence_penalties_reqs) == 0
|
| 780 |
+
and len(self.frequency_penalties_reqs) == 0
|
| 781 |
+
and len(self.repetition_penalties_reqs) == 0)
|
| 782 |
+
@property
|
| 783 |
+
def no_top_n_sigma(self) -> bool:
|
| 784 |
+
return len(self.top_n_sigma_reqs) == 0
|
| 785 |
+
|
| 786 |
+
@property
|
| 787 |
+
def max_num_logprobs(self) -> Optional[int]:
|
| 788 |
+
return max(self.num_logprobs.values()) if self.num_logprobs else None
|
| 789 |
+
|
| 790 |
+
@property
|
| 791 |
+
def no_prompt_logprob(self) -> bool:
|
| 792 |
+
return not self.num_prompt_logprobs
|
| 793 |
+
|
| 794 |
+
@property
|
| 795 |
+
def no_allowed_token_ids(self) -> bool:
|
| 796 |
+
return len(self.has_allowed_token_ids) == 0
|
inference/vllm_ascend_for_openpangu_embedded_7b.md
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
## Deployment Guide of openPangu Embedded 7B Based on [vllm-ascend](https://github.com/vllm-project/vllm-ascend)
|
| 2 |
+
|
| 3 |
+
### Deployment Environment Description
|
| 4 |
+
|
| 5 |
+
The Atlas 800T A2 (64 GB) supports the deployment of openPangu Embedded 7B (bf16) with 4 cards. The vllm-ascend community image v0.9.1-dev is used and needs to be pulled on multiple nodes.
|
| 6 |
+
```bash
|
| 7 |
+
docker pull quay.io/ascend/vllm-ascend:v0.9.1-dev
|
| 8 |
+
```
|
| 9 |
+
|
| 10 |
+
### Docker Boot and Inference Code
|
| 11 |
+
|
| 12 |
+
Perform the following operations on all nodes.
|
| 13 |
+
|
| 14 |
+
Run the following command to start the docker:
|
| 15 |
+
```bash
|
| 16 |
+
# Update the vllm-ascend image
|
| 17 |
+
export IMAGE=quay.io/ascend/vllm-ascend:v0.9.1-dev # Use correct image id
|
| 18 |
+
export NAME=vllm-ascend # Custom docker name
|
| 19 |
+
|
| 20 |
+
# Run the container using the defined variables
|
| 21 |
+
# Note if you are running bridge network with docker, Please expose available ports for multiple nodes communication in advance
|
| 22 |
+
# To prevent device interference from other docker containers, add the argument "--privileged"
|
| 23 |
+
docker run --rm \
|
| 24 |
+
--name $NAME \
|
| 25 |
+
--network host \
|
| 26 |
+
--device /dev/davinci0 \
|
| 27 |
+
--device /dev/davinci1 \
|
| 28 |
+
--device /dev/davinci2 \
|
| 29 |
+
--device /dev/davinci3 \
|
| 30 |
+
--device /dev/davinci4 \
|
| 31 |
+
--device /dev/davinci5 \
|
| 32 |
+
--device /dev/davinci6 \
|
| 33 |
+
--device /dev/davinci7 \
|
| 34 |
+
--device /dev/davinci_manager \
|
| 35 |
+
--device /dev/devmm_svm \
|
| 36 |
+
--device /dev/hisi_hdc \
|
| 37 |
+
-v /usr/local/dcmi:/usr/local/dcmi \
|
| 38 |
+
-v /usr/local/Ascend/driver/tools/hccn_tool:/usr/local/Ascend/driver/tools/hccn_tool \
|
| 39 |
+
-v /usr/local/bin/npu-smi:/usr/local/bin/npu-smi \
|
| 40 |
+
-v /usr/local/Ascend/driver/lib64/:/usr/local/Ascend/driver/lib64/ \
|
| 41 |
+
-v /usr/local/Ascend/driver/version.info:/usr/local/Ascend/driver/version.info \
|
| 42 |
+
-v /etc/ascend_install.info:/etc/ascend_install.info \
|
| 43 |
+
-v /mnt/sfs_turbo/.cache:/root/.cache \
|
| 44 |
+
-it $IMAGE bash
|
| 45 |
+
```
|
| 46 |
+
If not inside the container, enter the container as the root user:
|
| 47 |
+
```
|
| 48 |
+
docker exec -itu root $NAME /bin/bash
|
| 49 |
+
```
|
| 50 |
+
|
| 51 |
+
Download vllm (v0.9.2) to replace the built-in vllm code of the image.
|
| 52 |
+
```bash
|
| 53 |
+
pip install --no-deps vllm==0.9.2 pybase64==1.4.1
|
| 54 |
+
```
|
| 55 |
+
|
| 56 |
+
Download [vllm-ascend (v0.9.2rc1)](https://github.com/vllm-project/vllm-ascend/releases/tag/v0.9.2rc1) and replace the built-in vllm-ascend code in the image (/vllm-workspace/vllm-ascend/). For example, download [Source code (tar.gz)](https://github.com/vllm-project/vllm-ascend/archive/refs/tags/v0.9.2rc1.tar.gz) from Assets to get v0.9.2rc1.tar.gz, then extract and replace:
|
| 57 |
+
|
| 58 |
+
```bash
|
| 59 |
+
tar -zxvf vllm-ascend-0.9.2rc1.tar.gz -C /vllm-workspace/vllm-ascend/ --strip-components=1
|
| 60 |
+
export PYTHONPATH=/vllm-workspace/vllm-ascend/:${PYTHONPATH}
|
| 61 |
+
```
|
| 62 |
+
|
| 63 |
+
Use the Pangu model-adapted vllm-ascend code from the current repository to replace parts of the code in `/vllm-workspace/vllm-ascend/vllm_ascend/`:
|
| 64 |
+
|
| 65 |
+
```bash
|
| 66 |
+
yes | cp -r inference/vllm_ascend/* /vllm-workspace/vllm-ascend/vllm_ascend/
|
| 67 |
+
```
|
| 68 |
+
|
| 69 |
+
Replace the `tokenizer_config.json` file with the one that has added `special token`, [old file](../tokenizer_config.json) -> [new file](./vllm_ascend/tokenizer_config.json)
|
| 70 |
+
```bash
|
| 71 |
+
cp ./vllm_ascend/tokenizer_config.json ../tokenizer_config.json
|
| 72 |
+
```
|
| 73 |
+
|
| 74 |
+
### openPangu Embedded 7B Inference
|
| 75 |
+
|
| 76 |
+
Perform the following operations on all nodes.
|
| 77 |
+
|
| 78 |
+
Configuration:
|
| 79 |
+
```bash
|
| 80 |
+
export VLLM_USE_V1=1
|
| 81 |
+
# Specifying HOST=127.0.0.1 (localhost) means the server can only be accessed from the master device.
|
| 82 |
+
# Specifying HOST=0.0.0.0 allows the vLLM server to be accessed from other devices on the same network or even from the internet, provided proper network configuration (e.g., firewall rules, port forwarding) is in place.
|
| 83 |
+
HOST=xxx.xxx.xxx.xxx
|
| 84 |
+
PORT=8080
|
| 85 |
+
```
|
| 86 |
+
|
| 87 |
+
openPangu Embedded 7B running command:
|
| 88 |
+
```bash
|
| 89 |
+
export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3
|
| 90 |
+
LOCAL_CKPT_DIR=/root/.cache/pangu_embedded_7b # The pangu_embedded_7b bf16 weight
|
| 91 |
+
SERVED_MODEL_NAME=pangu_embedded_7b
|
| 92 |
+
|
| 93 |
+
vllm serve $LOCAL_CKPT_DIR \
|
| 94 |
+
--served-model-name $SERVED_MODEL_NAME \
|
| 95 |
+
--tensor-parallel-size 4 \
|
| 96 |
+
--trust-remote-code \
|
| 97 |
+
--host $HOST \
|
| 98 |
+
--port $PORT \
|
| 99 |
+
--max-num-seqs 32 \
|
| 100 |
+
--max-model-len 32768 \
|
| 101 |
+
--max-num-batched-tokens 4096 \
|
| 102 |
+
--tokenizer-mode "slow" \
|
| 103 |
+
--dtype bfloat16 \
|
| 104 |
+
--distributed-executor-backend mp \
|
| 105 |
+
--gpu-memory-utilization 0.93 \
|
| 106 |
+
--no-enable-prefix-caching \
|
| 107 |
+
--no-enable-chunked-prefill \
|
| 108 |
+
```
|
| 109 |
+
|
| 110 |
+
### Test Request
|
| 111 |
+
|
| 112 |
+
After server launched, send test request from master node or other nodes:
|
| 113 |
+
|
| 114 |
+
```bash
|
| 115 |
+
MASTER_NODE_IP=xxx.xxx.xxx.xxx # server node ip
|
| 116 |
+
curl http://${MASTER_NODE_IP}:${PORT}/v1/chat/completions \
|
| 117 |
+
-H "Content-Type: application/json" \
|
| 118 |
+
-d '{
|
| 119 |
+
"model": "'$SERVED_MODEL_NAME'",
|
| 120 |
+
"messages": [
|
| 121 |
+
{
|
| 122 |
+
"role": "user",
|
| 123 |
+
"content": "Who are you?"
|
| 124 |
+
}
|
| 125 |
+
],
|
| 126 |
+
"max_tokens": 512,
|
| 127 |
+
"temperature": 0
|
| 128 |
+
}'
|
| 129 |
+
```
|
inference/vllm_ascend_for_openpangu_embedded_7b.zh.md
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
## openPangu Embedded 7B 在[vllm-ascend](https://github.com/vllm-project/vllm-ascend)部署指导文档
|
| 2 |
+
|
| 3 |
+
### 部署环境说明
|
| 4 |
+
|
| 5 |
+
Atlas 800T A2(64GB) 4卡可部署openPangu Embedded 7B (bf16),选用vllm-ascend社区镜像v0.9.1-dev。
|
| 6 |
+
```bash
|
| 7 |
+
docker pull quay.io/ascend/vllm-ascend:v0.9.1-dev
|
| 8 |
+
```
|
| 9 |
+
|
| 10 |
+
### 镜像启动和推理代码适配
|
| 11 |
+
|
| 12 |
+
以下操作需在每个节点都执行。
|
| 13 |
+
|
| 14 |
+
启动镜像。
|
| 15 |
+
```bash
|
| 16 |
+
# Update the vllm-ascend image
|
| 17 |
+
export IMAGE=quay.io/ascend/vllm-ascend:v0.9.1-dev # Use correct image id
|
| 18 |
+
export NAME=vllm-ascend # Custom docker name
|
| 19 |
+
|
| 20 |
+
# Run the container using the defined variables
|
| 21 |
+
# Note if you are running bridge network with docker, Please expose available ports for multiple nodes communication in advance
|
| 22 |
+
# To prevent device interference from other docker containers, add the argument "--privileged"
|
| 23 |
+
docker run --rm \
|
| 24 |
+
--name $NAME \
|
| 25 |
+
--network host \
|
| 26 |
+
--device /dev/davinci0 \
|
| 27 |
+
--device /dev/davinci1 \
|
| 28 |
+
--device /dev/davinci2 \
|
| 29 |
+
--device /dev/davinci3 \
|
| 30 |
+
--device /dev/davinci4 \
|
| 31 |
+
--device /dev/davinci5 \
|
| 32 |
+
--device /dev/davinci6 \
|
| 33 |
+
--device /dev/davinci7 \
|
| 34 |
+
--device /dev/davinci_manager \
|
| 35 |
+
--device /dev/devmm_svm \
|
| 36 |
+
--device /dev/hisi_hdc \
|
| 37 |
+
-v /usr/local/dcmi:/usr/local/dcmi \
|
| 38 |
+
-v /usr/local/Ascend/driver/tools/hccn_tool:/usr/local/Ascend/driver/tools/hccn_tool \
|
| 39 |
+
-v /usr/local/bin/npu-smi:/usr/local/bin/npu-smi \
|
| 40 |
+
-v /usr/local/Ascend/driver/lib64/:/usr/local/Ascend/driver/lib64/ \
|
| 41 |
+
-v /usr/local/Ascend/driver/version.info:/usr/local/Ascend/driver/version.info \
|
| 42 |
+
-v /etc/ascend_install.info:/etc/ascend_install.info \
|
| 43 |
+
-v /mnt/sfs_turbo/.cache:/root/.cache \
|
| 44 |
+
-it $IMAGE bash
|
| 45 |
+
```
|
| 46 |
+
|
| 47 |
+
如果未进入容器,需以root用户进入容器。
|
| 48 |
+
```
|
| 49 |
+
docker exec -itu root $NAME /bin/bash
|
| 50 |
+
```
|
| 51 |
+
|
| 52 |
+
下载vllm (v0.9.2),替换镜像内置的vllm代码。
|
| 53 |
+
```bash
|
| 54 |
+
pip install --no-deps vllm==0.9.2 pybase64==1.4.1
|
| 55 |
+
```
|
| 56 |
+
|
| 57 |
+
下载[vllm-ascend (v0.9.2rc1)](https://github.com/vllm-project/vllm-ascend/releases/tag/v0.9.2rc1),替换镜像内置的vllm-ascend代码(`/vllm-workspace/vllm-ascend/`)。例如下载Assets中的[Source code
|
| 58 |
+
(tar.gz)](https://github.com/vllm-project/vllm-ascend/archive/refs/tags/v0.9.2rc1.tar.gz)得到v0.9.2rc1.tar.gz,然后解压并替换:
|
| 59 |
+
```bash
|
| 60 |
+
tar -zxvf vllm-ascend-0.9.2rc1.tar.gz -C /vllm-workspace/vllm-ascend/ --strip-components=1
|
| 61 |
+
export PYTHONPATH=/vllm-workspace/vllm-ascend/:${PYTHONPATH}
|
| 62 |
+
```
|
| 63 |
+
|
| 64 |
+
使用当前代码仓中适配盘古模型的vllm-ascend代码替换`/vllm-workspace/vllm-ascend/vllm_ascend/`中的部分代码。
|
| 65 |
+
```bash
|
| 66 |
+
yes | cp -r inference/vllm_ascend/* /vllm-workspace/vllm-ascend/vllm_ascend/
|
| 67 |
+
```
|
| 68 |
+
|
| 69 |
+
替换增加`special token`后的`tokenizer_config.json`文件,[旧文件](../tokenizer_config.json) -> [新文件](./vllm_ascend/tokenizer_config.json)
|
| 70 |
+
```bash
|
| 71 |
+
cp ./vllm_ascend/tokenizer_config.json ../tokenizer_config.json
|
| 72 |
+
```
|
| 73 |
+
|
| 74 |
+
### openPangu Embedded 7B推理
|
| 75 |
+
|
| 76 |
+
以下操作需在每个节点都执行。
|
| 77 |
+
|
| 78 |
+
配置:
|
| 79 |
+
```bash
|
| 80 |
+
export VLLM_USE_V1=1
|
| 81 |
+
# Specifying HOST=127.0.0.1 (localhost) means the server can only be accessed from the master device.
|
| 82 |
+
# Specifying HOST=0.0.0.0 allows the vLLM server to be accessed from other devices on the same network or even from the internet, provided proper network configuration (e.g., firewall rules, port forwarding) is in place.
|
| 83 |
+
HOST=xxx.xxx.xxx.xxx
|
| 84 |
+
PORT=8080
|
| 85 |
+
```
|
| 86 |
+
|
| 87 |
+
openPangu Embedded 7B 运行命令:
|
| 88 |
+
```bash
|
| 89 |
+
export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3
|
| 90 |
+
LOCAL_CKPT_DIR=/root/.cache/pangu_embedded_7b # The pangu_embedded_7b bf16 weight
|
| 91 |
+
SERVED_MODEL_NAME=pangu_embedded_7b
|
| 92 |
+
|
| 93 |
+
vllm serve $LOCAL_CKPT_DIR \
|
| 94 |
+
--served-model-name $SERVED_MODEL_NAME \
|
| 95 |
+
--tensor-parallel-size 4 \
|
| 96 |
+
--trust-remote-code \
|
| 97 |
+
--host $HOST \
|
| 98 |
+
--port $PORT \
|
| 99 |
+
--max-num-seqs 32 \
|
| 100 |
+
--max-model-len 32768 \
|
| 101 |
+
--max-num-batched-tokens 4096 \
|
| 102 |
+
--tokenizer-mode "slow" \
|
| 103 |
+
--dtype bfloat16 \
|
| 104 |
+
--distributed-executor-backend mp \
|
| 105 |
+
--gpu-memory-utilization 0.93 \
|
| 106 |
+
--no-enable-prefix-caching \
|
| 107 |
+
--no-enable-chunked-prefill \
|
| 108 |
+
```
|
| 109 |
+
|
| 110 |
+
### 发请求测试
|
| 111 |
+
|
| 112 |
+
服务启动后,在主节点或者其他节点向主节点发送测试请求:
|
| 113 |
+
|
| 114 |
+
```bash
|
| 115 |
+
MASTER_NODE_IP=xxx.xxx.xxx.xxx # server node ip
|
| 116 |
+
curl http://${MASTER_NODE_IP}:${PORT}/v1/chat/completions \
|
| 117 |
+
-H "Content-Type: application/json" \
|
| 118 |
+
-d '{
|
| 119 |
+
"model": "'$SERVED_MODEL_NAME'",
|
| 120 |
+
"messages": [
|
| 121 |
+
{
|
| 122 |
+
"role": "user",
|
| 123 |
+
"content": "Who are you?"
|
| 124 |
+
}
|
| 125 |
+
],
|
| 126 |
+
"max_tokens": 512,
|
| 127 |
+
"temperature": 0
|
| 128 |
+
}'
|
| 129 |
+
```
|
model-00001-of-00004.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8af395f1ed0b7b4ca70828c4d87db0cc72e3e5720c3d462265dcc87a977fd1dd
|
| 3 |
+
size 4926842416
|
model-00002-of-00004.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:fb8081a9578a22fb500f851dcfd5af07029af73516663b589724d09215dc7edc
|
| 3 |
+
size 4991686576
|
model-00003-of-00004.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ea5b8e7ac544c08954a47e1a98ea7b5c6fc26f3c12c39469159896158fa43e9c
|
| 3 |
+
size 4886853760
|
model-00004-of-00004.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b65417ef013c072961910a91db017836bbfbf300422fd3985652c7807f45adc9
|
| 3 |
+
size 1256456320
|
model.safetensors.index.json
ADDED
|
@@ -0,0 +1,486 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"metadata": {
|
| 3 |
+
"total_size": 16061784576
|
| 4 |
+
},
|
| 5 |
+
"weight_map": {
|
| 6 |
+
"lm_head.weight": "model-00004-of-00004.safetensors",
|
| 7 |
+
"model.embed_tokens.weight": "model-00001-of-00004.safetensors",
|
| 8 |
+
"model.layers.0.input_layernorm.weight": "model-00001-of-00004.safetensors",
|
| 9 |
+
"model.layers.0.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
|
| 10 |
+
"model.layers.0.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
|
| 11 |
+
"model.layers.0.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
|
| 12 |
+
"model.layers.0.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
|
| 13 |
+
"model.layers.0.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
|
| 14 |
+
"model.layers.0.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
|
| 15 |
+
"model.layers.0.self_attn.o_proj.bias": "model-00001-of-00004.safetensors",
|
| 16 |
+
"model.layers.0.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
|
| 17 |
+
"model.layers.0.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
|
| 18 |
+
"model.layers.0.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
|
| 19 |
+
"model.layers.0.self_attn.rotary_emb.inv_freq": "model-00001-of-00004.safetensors",
|
| 20 |
+
"model.layers.0.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
|
| 21 |
+
"model.layers.0.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
|
| 22 |
+
"model.layers.1.input_layernorm.weight": "model-00001-of-00004.safetensors",
|
| 23 |
+
"model.layers.1.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
|
| 24 |
+
"model.layers.1.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
|
| 25 |
+
"model.layers.1.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
|
| 26 |
+
"model.layers.1.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
|
| 27 |
+
"model.layers.1.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
|
| 28 |
+
"model.layers.1.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
|
| 29 |
+
"model.layers.1.self_attn.o_proj.bias": "model-00001-of-00004.safetensors",
|
| 30 |
+
"model.layers.1.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
|
| 31 |
+
"model.layers.1.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
|
| 32 |
+
"model.layers.1.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
|
| 33 |
+
"model.layers.1.self_attn.rotary_emb.inv_freq": "model-00001-of-00004.safetensors",
|
| 34 |
+
"model.layers.1.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
|
| 35 |
+
"model.layers.1.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
|
| 36 |
+
"model.layers.10.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 37 |
+
"model.layers.10.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
| 38 |
+
"model.layers.10.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
| 39 |
+
"model.layers.10.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
| 40 |
+
"model.layers.10.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 41 |
+
"model.layers.10.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
|
| 42 |
+
"model.layers.10.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
| 43 |
+
"model.layers.10.self_attn.o_proj.bias": "model-00002-of-00004.safetensors",
|
| 44 |
+
"model.layers.10.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
| 45 |
+
"model.layers.10.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
|
| 46 |
+
"model.layers.10.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
|
| 47 |
+
"model.layers.10.self_attn.rotary_emb.inv_freq": "model-00002-of-00004.safetensors",
|
| 48 |
+
"model.layers.10.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
|
| 49 |
+
"model.layers.10.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
| 50 |
+
"model.layers.11.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 51 |
+
"model.layers.11.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
| 52 |
+
"model.layers.11.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
| 53 |
+
"model.layers.11.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
| 54 |
+
"model.layers.11.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 55 |
+
"model.layers.11.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
|
| 56 |
+
"model.layers.11.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
| 57 |
+
"model.layers.11.self_attn.o_proj.bias": "model-00002-of-00004.safetensors",
|
| 58 |
+
"model.layers.11.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
| 59 |
+
"model.layers.11.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
|
| 60 |
+
"model.layers.11.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
|
| 61 |
+
"model.layers.11.self_attn.rotary_emb.inv_freq": "model-00002-of-00004.safetensors",
|
| 62 |
+
"model.layers.11.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
|
| 63 |
+
"model.layers.11.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
| 64 |
+
"model.layers.12.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 65 |
+
"model.layers.12.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
| 66 |
+
"model.layers.12.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
| 67 |
+
"model.layers.12.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
| 68 |
+
"model.layers.12.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 69 |
+
"model.layers.12.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
|
| 70 |
+
"model.layers.12.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
| 71 |
+
"model.layers.12.self_attn.o_proj.bias": "model-00002-of-00004.safetensors",
|
| 72 |
+
"model.layers.12.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
| 73 |
+
"model.layers.12.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
|
| 74 |
+
"model.layers.12.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
|
| 75 |
+
"model.layers.12.self_attn.rotary_emb.inv_freq": "model-00002-of-00004.safetensors",
|
| 76 |
+
"model.layers.12.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
|
| 77 |
+
"model.layers.12.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
| 78 |
+
"model.layers.13.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 79 |
+
"model.layers.13.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
| 80 |
+
"model.layers.13.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
| 81 |
+
"model.layers.13.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
| 82 |
+
"model.layers.13.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 83 |
+
"model.layers.13.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
|
| 84 |
+
"model.layers.13.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
| 85 |
+
"model.layers.13.self_attn.o_proj.bias": "model-00002-of-00004.safetensors",
|
| 86 |
+
"model.layers.13.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
| 87 |
+
"model.layers.13.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
|
| 88 |
+
"model.layers.13.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
|
| 89 |
+
"model.layers.13.self_attn.rotary_emb.inv_freq": "model-00002-of-00004.safetensors",
|
| 90 |
+
"model.layers.13.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
|
| 91 |
+
"model.layers.13.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
| 92 |
+
"model.layers.14.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 93 |
+
"model.layers.14.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
| 94 |
+
"model.layers.14.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
| 95 |
+
"model.layers.14.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
| 96 |
+
"model.layers.14.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 97 |
+
"model.layers.14.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
|
| 98 |
+
"model.layers.14.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
| 99 |
+
"model.layers.14.self_attn.o_proj.bias": "model-00002-of-00004.safetensors",
|
| 100 |
+
"model.layers.14.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
| 101 |
+
"model.layers.14.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
|
| 102 |
+
"model.layers.14.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
|
| 103 |
+
"model.layers.14.self_attn.rotary_emb.inv_freq": "model-00002-of-00004.safetensors",
|
| 104 |
+
"model.layers.14.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
|
| 105 |
+
"model.layers.14.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
| 106 |
+
"model.layers.15.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 107 |
+
"model.layers.15.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
| 108 |
+
"model.layers.15.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
| 109 |
+
"model.layers.15.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
| 110 |
+
"model.layers.15.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 111 |
+
"model.layers.15.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
|
| 112 |
+
"model.layers.15.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
| 113 |
+
"model.layers.15.self_attn.o_proj.bias": "model-00002-of-00004.safetensors",
|
| 114 |
+
"model.layers.15.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
| 115 |
+
"model.layers.15.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
|
| 116 |
+
"model.layers.15.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
|
| 117 |
+
"model.layers.15.self_attn.rotary_emb.inv_freq": "model-00002-of-00004.safetensors",
|
| 118 |
+
"model.layers.15.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
|
| 119 |
+
"model.layers.15.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
| 120 |
+
"model.layers.16.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 121 |
+
"model.layers.16.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
| 122 |
+
"model.layers.16.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
| 123 |
+
"model.layers.16.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
| 124 |
+
"model.layers.16.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 125 |
+
"model.layers.16.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
|
| 126 |
+
"model.layers.16.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
| 127 |
+
"model.layers.16.self_attn.o_proj.bias": "model-00002-of-00004.safetensors",
|
| 128 |
+
"model.layers.16.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
| 129 |
+
"model.layers.16.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
|
| 130 |
+
"model.layers.16.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
|
| 131 |
+
"model.layers.16.self_attn.rotary_emb.inv_freq": "model-00002-of-00004.safetensors",
|
| 132 |
+
"model.layers.16.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
|
| 133 |
+
"model.layers.16.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
| 134 |
+
"model.layers.17.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 135 |
+
"model.layers.17.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
| 136 |
+
"model.layers.17.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
| 137 |
+
"model.layers.17.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
| 138 |
+
"model.layers.17.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 139 |
+
"model.layers.17.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
|
| 140 |
+
"model.layers.17.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
| 141 |
+
"model.layers.17.self_attn.o_proj.bias": "model-00002-of-00004.safetensors",
|
| 142 |
+
"model.layers.17.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
| 143 |
+
"model.layers.17.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
|
| 144 |
+
"model.layers.17.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
|
| 145 |
+
"model.layers.17.self_attn.rotary_emb.inv_freq": "model-00002-of-00004.safetensors",
|
| 146 |
+
"model.layers.17.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
|
| 147 |
+
"model.layers.17.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
| 148 |
+
"model.layers.18.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 149 |
+
"model.layers.18.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
| 150 |
+
"model.layers.18.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
| 151 |
+
"model.layers.18.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
| 152 |
+
"model.layers.18.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 153 |
+
"model.layers.18.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
|
| 154 |
+
"model.layers.18.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
| 155 |
+
"model.layers.18.self_attn.o_proj.bias": "model-00002-of-00004.safetensors",
|
| 156 |
+
"model.layers.18.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
| 157 |
+
"model.layers.18.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
|
| 158 |
+
"model.layers.18.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
|
| 159 |
+
"model.layers.18.self_attn.rotary_emb.inv_freq": "model-00002-of-00004.safetensors",
|
| 160 |
+
"model.layers.18.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
|
| 161 |
+
"model.layers.18.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
| 162 |
+
"model.layers.19.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 163 |
+
"model.layers.19.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
| 164 |
+
"model.layers.19.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
| 165 |
+
"model.layers.19.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
| 166 |
+
"model.layers.19.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 167 |
+
"model.layers.19.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
|
| 168 |
+
"model.layers.19.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
| 169 |
+
"model.layers.19.self_attn.o_proj.bias": "model-00002-of-00004.safetensors",
|
| 170 |
+
"model.layers.19.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
| 171 |
+
"model.layers.19.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
|
| 172 |
+
"model.layers.19.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
|
| 173 |
+
"model.layers.19.self_attn.rotary_emb.inv_freq": "model-00002-of-00004.safetensors",
|
| 174 |
+
"model.layers.19.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
|
| 175 |
+
"model.layers.19.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
| 176 |
+
"model.layers.2.input_layernorm.weight": "model-00001-of-00004.safetensors",
|
| 177 |
+
"model.layers.2.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
|
| 178 |
+
"model.layers.2.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
|
| 179 |
+
"model.layers.2.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
|
| 180 |
+
"model.layers.2.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
|
| 181 |
+
"model.layers.2.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
|
| 182 |
+
"model.layers.2.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
|
| 183 |
+
"model.layers.2.self_attn.o_proj.bias": "model-00001-of-00004.safetensors",
|
| 184 |
+
"model.layers.2.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
|
| 185 |
+
"model.layers.2.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
|
| 186 |
+
"model.layers.2.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
|
| 187 |
+
"model.layers.2.self_attn.rotary_emb.inv_freq": "model-00001-of-00004.safetensors",
|
| 188 |
+
"model.layers.2.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
|
| 189 |
+
"model.layers.2.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
|
| 190 |
+
"model.layers.20.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 191 |
+
"model.layers.20.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
| 192 |
+
"model.layers.20.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
| 193 |
+
"model.layers.20.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
| 194 |
+
"model.layers.20.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 195 |
+
"model.layers.20.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
|
| 196 |
+
"model.layers.20.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
| 197 |
+
"model.layers.20.self_attn.o_proj.bias": "model-00002-of-00004.safetensors",
|
| 198 |
+
"model.layers.20.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
| 199 |
+
"model.layers.20.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
|
| 200 |
+
"model.layers.20.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
|
| 201 |
+
"model.layers.20.self_attn.rotary_emb.inv_freq": "model-00002-of-00004.safetensors",
|
| 202 |
+
"model.layers.20.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
|
| 203 |
+
"model.layers.20.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
| 204 |
+
"model.layers.21.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 205 |
+
"model.layers.21.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
| 206 |
+
"model.layers.21.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
| 207 |
+
"model.layers.21.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
| 208 |
+
"model.layers.21.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 209 |
+
"model.layers.21.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
|
| 210 |
+
"model.layers.21.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
| 211 |
+
"model.layers.21.self_attn.o_proj.bias": "model-00002-of-00004.safetensors",
|
| 212 |
+
"model.layers.21.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
| 213 |
+
"model.layers.21.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
|
| 214 |
+
"model.layers.21.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
|
| 215 |
+
"model.layers.21.self_attn.rotary_emb.inv_freq": "model-00002-of-00004.safetensors",
|
| 216 |
+
"model.layers.21.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
|
| 217 |
+
"model.layers.21.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
| 218 |
+
"model.layers.22.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 219 |
+
"model.layers.22.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
| 220 |
+
"model.layers.22.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
|
| 221 |
+
"model.layers.22.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
| 222 |
+
"model.layers.22.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 223 |
+
"model.layers.22.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
|
| 224 |
+
"model.layers.22.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
| 225 |
+
"model.layers.22.self_attn.o_proj.bias": "model-00003-of-00004.safetensors",
|
| 226 |
+
"model.layers.22.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
| 227 |
+
"model.layers.22.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
|
| 228 |
+
"model.layers.22.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
| 229 |
+
"model.layers.22.self_attn.rotary_emb.inv_freq": "model-00003-of-00004.safetensors",
|
| 230 |
+
"model.layers.22.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
|
| 231 |
+
"model.layers.22.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
| 232 |
+
"model.layers.23.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 233 |
+
"model.layers.23.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
| 234 |
+
"model.layers.23.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
|
| 235 |
+
"model.layers.23.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
| 236 |
+
"model.layers.23.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 237 |
+
"model.layers.23.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
|
| 238 |
+
"model.layers.23.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
| 239 |
+
"model.layers.23.self_attn.o_proj.bias": "model-00003-of-00004.safetensors",
|
| 240 |
+
"model.layers.23.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
| 241 |
+
"model.layers.23.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
|
| 242 |
+
"model.layers.23.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
| 243 |
+
"model.layers.23.self_attn.rotary_emb.inv_freq": "model-00003-of-00004.safetensors",
|
| 244 |
+
"model.layers.23.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
|
| 245 |
+
"model.layers.23.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
| 246 |
+
"model.layers.24.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 247 |
+
"model.layers.24.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
| 248 |
+
"model.layers.24.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
|
| 249 |
+
"model.layers.24.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
| 250 |
+
"model.layers.24.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 251 |
+
"model.layers.24.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
|
| 252 |
+
"model.layers.24.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
| 253 |
+
"model.layers.24.self_attn.o_proj.bias": "model-00003-of-00004.safetensors",
|
| 254 |
+
"model.layers.24.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
| 255 |
+
"model.layers.24.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
|
| 256 |
+
"model.layers.24.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
| 257 |
+
"model.layers.24.self_attn.rotary_emb.inv_freq": "model-00003-of-00004.safetensors",
|
| 258 |
+
"model.layers.24.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
|
| 259 |
+
"model.layers.24.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
| 260 |
+
"model.layers.25.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 261 |
+
"model.layers.25.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
| 262 |
+
"model.layers.25.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
|
| 263 |
+
"model.layers.25.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
| 264 |
+
"model.layers.25.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 265 |
+
"model.layers.25.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
|
| 266 |
+
"model.layers.25.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
| 267 |
+
"model.layers.25.self_attn.o_proj.bias": "model-00003-of-00004.safetensors",
|
| 268 |
+
"model.layers.25.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
| 269 |
+
"model.layers.25.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
|
| 270 |
+
"model.layers.25.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
| 271 |
+
"model.layers.25.self_attn.rotary_emb.inv_freq": "model-00003-of-00004.safetensors",
|
| 272 |
+
"model.layers.25.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
|
| 273 |
+
"model.layers.25.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
| 274 |
+
"model.layers.26.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 275 |
+
"model.layers.26.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
| 276 |
+
"model.layers.26.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
|
| 277 |
+
"model.layers.26.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
| 278 |
+
"model.layers.26.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 279 |
+
"model.layers.26.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
|
| 280 |
+
"model.layers.26.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
| 281 |
+
"model.layers.26.self_attn.o_proj.bias": "model-00003-of-00004.safetensors",
|
| 282 |
+
"model.layers.26.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
| 283 |
+
"model.layers.26.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
|
| 284 |
+
"model.layers.26.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
| 285 |
+
"model.layers.26.self_attn.rotary_emb.inv_freq": "model-00003-of-00004.safetensors",
|
| 286 |
+
"model.layers.26.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
|
| 287 |
+
"model.layers.26.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
| 288 |
+
"model.layers.27.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 289 |
+
"model.layers.27.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
| 290 |
+
"model.layers.27.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
|
| 291 |
+
"model.layers.27.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
| 292 |
+
"model.layers.27.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 293 |
+
"model.layers.27.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
|
| 294 |
+
"model.layers.27.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
| 295 |
+
"model.layers.27.self_attn.o_proj.bias": "model-00003-of-00004.safetensors",
|
| 296 |
+
"model.layers.27.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
| 297 |
+
"model.layers.27.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
|
| 298 |
+
"model.layers.27.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
| 299 |
+
"model.layers.27.self_attn.rotary_emb.inv_freq": "model-00003-of-00004.safetensors",
|
| 300 |
+
"model.layers.27.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
|
| 301 |
+
"model.layers.27.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
| 302 |
+
"model.layers.28.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 303 |
+
"model.layers.28.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
| 304 |
+
"model.layers.28.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
|
| 305 |
+
"model.layers.28.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
| 306 |
+
"model.layers.28.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 307 |
+
"model.layers.28.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
|
| 308 |
+
"model.layers.28.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
| 309 |
+
"model.layers.28.self_attn.o_proj.bias": "model-00003-of-00004.safetensors",
|
| 310 |
+
"model.layers.28.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
| 311 |
+
"model.layers.28.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
|
| 312 |
+
"model.layers.28.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
| 313 |
+
"model.layers.28.self_attn.rotary_emb.inv_freq": "model-00003-of-00004.safetensors",
|
| 314 |
+
"model.layers.28.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
|
| 315 |
+
"model.layers.28.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
| 316 |
+
"model.layers.29.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 317 |
+
"model.layers.29.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
| 318 |
+
"model.layers.29.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
|
| 319 |
+
"model.layers.29.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
| 320 |
+
"model.layers.29.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 321 |
+
"model.layers.29.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
|
| 322 |
+
"model.layers.29.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
| 323 |
+
"model.layers.29.self_attn.o_proj.bias": "model-00003-of-00004.safetensors",
|
| 324 |
+
"model.layers.29.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
| 325 |
+
"model.layers.29.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
|
| 326 |
+
"model.layers.29.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
| 327 |
+
"model.layers.29.self_attn.rotary_emb.inv_freq": "model-00003-of-00004.safetensors",
|
| 328 |
+
"model.layers.29.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
|
| 329 |
+
"model.layers.29.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
| 330 |
+
"model.layers.3.input_layernorm.weight": "model-00001-of-00004.safetensors",
|
| 331 |
+
"model.layers.3.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
|
| 332 |
+
"model.layers.3.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
|
| 333 |
+
"model.layers.3.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
|
| 334 |
+
"model.layers.3.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
|
| 335 |
+
"model.layers.3.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
|
| 336 |
+
"model.layers.3.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
|
| 337 |
+
"model.layers.3.self_attn.o_proj.bias": "model-00001-of-00004.safetensors",
|
| 338 |
+
"model.layers.3.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
|
| 339 |
+
"model.layers.3.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
|
| 340 |
+
"model.layers.3.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
|
| 341 |
+
"model.layers.3.self_attn.rotary_emb.inv_freq": "model-00001-of-00004.safetensors",
|
| 342 |
+
"model.layers.3.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
|
| 343 |
+
"model.layers.3.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
|
| 344 |
+
"model.layers.30.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 345 |
+
"model.layers.30.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
| 346 |
+
"model.layers.30.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
|
| 347 |
+
"model.layers.30.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
| 348 |
+
"model.layers.30.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 349 |
+
"model.layers.30.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
|
| 350 |
+
"model.layers.30.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
| 351 |
+
"model.layers.30.self_attn.o_proj.bias": "model-00003-of-00004.safetensors",
|
| 352 |
+
"model.layers.30.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
| 353 |
+
"model.layers.30.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
|
| 354 |
+
"model.layers.30.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
| 355 |
+
"model.layers.30.self_attn.rotary_emb.inv_freq": "model-00003-of-00004.safetensors",
|
| 356 |
+
"model.layers.30.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
|
| 357 |
+
"model.layers.30.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
| 358 |
+
"model.layers.31.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 359 |
+
"model.layers.31.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
| 360 |
+
"model.layers.31.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
|
| 361 |
+
"model.layers.31.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
| 362 |
+
"model.layers.31.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 363 |
+
"model.layers.31.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
|
| 364 |
+
"model.layers.31.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
| 365 |
+
"model.layers.31.self_attn.o_proj.bias": "model-00003-of-00004.safetensors",
|
| 366 |
+
"model.layers.31.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
| 367 |
+
"model.layers.31.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
|
| 368 |
+
"model.layers.31.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
| 369 |
+
"model.layers.31.self_attn.rotary_emb.inv_freq": "model-00003-of-00004.safetensors",
|
| 370 |
+
"model.layers.31.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
|
| 371 |
+
"model.layers.31.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
| 372 |
+
"model.layers.32.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 373 |
+
"model.layers.32.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
| 374 |
+
"model.layers.32.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
|
| 375 |
+
"model.layers.32.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
| 376 |
+
"model.layers.32.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 377 |
+
"model.layers.32.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
|
| 378 |
+
"model.layers.32.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
| 379 |
+
"model.layers.32.self_attn.o_proj.bias": "model-00003-of-00004.safetensors",
|
| 380 |
+
"model.layers.32.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
| 381 |
+
"model.layers.32.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
|
| 382 |
+
"model.layers.32.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
| 383 |
+
"model.layers.32.self_attn.rotary_emb.inv_freq": "model-00003-of-00004.safetensors",
|
| 384 |
+
"model.layers.32.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
|
| 385 |
+
"model.layers.32.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
| 386 |
+
"model.layers.33.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 387 |
+
"model.layers.33.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
| 388 |
+
"model.layers.33.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
|
| 389 |
+
"model.layers.33.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
| 390 |
+
"model.layers.33.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 391 |
+
"model.layers.33.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
|
| 392 |
+
"model.layers.33.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
| 393 |
+
"model.layers.33.self_attn.o_proj.bias": "model-00003-of-00004.safetensors",
|
| 394 |
+
"model.layers.33.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
| 395 |
+
"model.layers.33.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
|
| 396 |
+
"model.layers.33.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
| 397 |
+
"model.layers.33.self_attn.rotary_emb.inv_freq": "model-00003-of-00004.safetensors",
|
| 398 |
+
"model.layers.33.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
|
| 399 |
+
"model.layers.33.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
| 400 |
+
"model.layers.4.input_layernorm.weight": "model-00001-of-00004.safetensors",
|
| 401 |
+
"model.layers.4.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
|
| 402 |
+
"model.layers.4.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
|
| 403 |
+
"model.layers.4.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
|
| 404 |
+
"model.layers.4.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
|
| 405 |
+
"model.layers.4.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
|
| 406 |
+
"model.layers.4.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
|
| 407 |
+
"model.layers.4.self_attn.o_proj.bias": "model-00001-of-00004.safetensors",
|
| 408 |
+
"model.layers.4.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
|
| 409 |
+
"model.layers.4.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
|
| 410 |
+
"model.layers.4.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
|
| 411 |
+
"model.layers.4.self_attn.rotary_emb.inv_freq": "model-00001-of-00004.safetensors",
|
| 412 |
+
"model.layers.4.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
|
| 413 |
+
"model.layers.4.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
|
| 414 |
+
"model.layers.5.input_layernorm.weight": "model-00001-of-00004.safetensors",
|
| 415 |
+
"model.layers.5.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
|
| 416 |
+
"model.layers.5.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
|
| 417 |
+
"model.layers.5.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
|
| 418 |
+
"model.layers.5.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
|
| 419 |
+
"model.layers.5.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
|
| 420 |
+
"model.layers.5.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
|
| 421 |
+
"model.layers.5.self_attn.o_proj.bias": "model-00001-of-00004.safetensors",
|
| 422 |
+
"model.layers.5.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
|
| 423 |
+
"model.layers.5.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
|
| 424 |
+
"model.layers.5.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
|
| 425 |
+
"model.layers.5.self_attn.rotary_emb.inv_freq": "model-00001-of-00004.safetensors",
|
| 426 |
+
"model.layers.5.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
|
| 427 |
+
"model.layers.5.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
|
| 428 |
+
"model.layers.6.input_layernorm.weight": "model-00001-of-00004.safetensors",
|
| 429 |
+
"model.layers.6.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
|
| 430 |
+
"model.layers.6.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
|
| 431 |
+
"model.layers.6.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
|
| 432 |
+
"model.layers.6.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
|
| 433 |
+
"model.layers.6.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
|
| 434 |
+
"model.layers.6.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
|
| 435 |
+
"model.layers.6.self_attn.o_proj.bias": "model-00001-of-00004.safetensors",
|
| 436 |
+
"model.layers.6.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
|
| 437 |
+
"model.layers.6.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
|
| 438 |
+
"model.layers.6.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
|
| 439 |
+
"model.layers.6.self_attn.rotary_emb.inv_freq": "model-00001-of-00004.safetensors",
|
| 440 |
+
"model.layers.6.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
|
| 441 |
+
"model.layers.6.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
|
| 442 |
+
"model.layers.7.input_layernorm.weight": "model-00001-of-00004.safetensors",
|
| 443 |
+
"model.layers.7.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
|
| 444 |
+
"model.layers.7.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
|
| 445 |
+
"model.layers.7.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
|
| 446 |
+
"model.layers.7.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
|
| 447 |
+
"model.layers.7.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
|
| 448 |
+
"model.layers.7.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
|
| 449 |
+
"model.layers.7.self_attn.o_proj.bias": "model-00001-of-00004.safetensors",
|
| 450 |
+
"model.layers.7.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
|
| 451 |
+
"model.layers.7.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
|
| 452 |
+
"model.layers.7.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
|
| 453 |
+
"model.layers.7.self_attn.rotary_emb.inv_freq": "model-00001-of-00004.safetensors",
|
| 454 |
+
"model.layers.7.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
|
| 455 |
+
"model.layers.7.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
|
| 456 |
+
"model.layers.8.input_layernorm.weight": "model-00001-of-00004.safetensors",
|
| 457 |
+
"model.layers.8.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
|
| 458 |
+
"model.layers.8.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
|
| 459 |
+
"model.layers.8.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
|
| 460 |
+
"model.layers.8.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
|
| 461 |
+
"model.layers.8.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
|
| 462 |
+
"model.layers.8.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
|
| 463 |
+
"model.layers.8.self_attn.o_proj.bias": "model-00001-of-00004.safetensors",
|
| 464 |
+
"model.layers.8.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
|
| 465 |
+
"model.layers.8.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
|
| 466 |
+
"model.layers.8.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
|
| 467 |
+
"model.layers.8.self_attn.rotary_emb.inv_freq": "model-00001-of-00004.safetensors",
|
| 468 |
+
"model.layers.8.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
|
| 469 |
+
"model.layers.8.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
|
| 470 |
+
"model.layers.9.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 471 |
+
"model.layers.9.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
| 472 |
+
"model.layers.9.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
| 473 |
+
"model.layers.9.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
| 474 |
+
"model.layers.9.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 475 |
+
"model.layers.9.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
|
| 476 |
+
"model.layers.9.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
|
| 477 |
+
"model.layers.9.self_attn.o_proj.bias": "model-00001-of-00004.safetensors",
|
| 478 |
+
"model.layers.9.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
|
| 479 |
+
"model.layers.9.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
|
| 480 |
+
"model.layers.9.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
|
| 481 |
+
"model.layers.9.self_attn.rotary_emb.inv_freq": "model-00001-of-00004.safetensors",
|
| 482 |
+
"model.layers.9.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
|
| 483 |
+
"model.layers.9.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
|
| 484 |
+
"model.norm.weight": "model-00003-of-00004.safetensors"
|
| 485 |
+
}
|
| 486 |
+
}
|
modeling_openpangu_dense.py
ADDED
|
@@ -0,0 +1,585 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
| 2 |
+
# This file was automatically generated from modular_openpangu_dense.py.
|
| 3 |
+
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
| 4 |
+
# the file from the modular. If any change should be done, please apply the change to the
|
| 5 |
+
# modular_openpangu_dense.py file directly. One of our CI enforces this.
|
| 6 |
+
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
| 7 |
+
|
| 8 |
+
# coding=utf-8
|
| 9 |
+
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
|
| 10 |
+
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
| 11 |
+
#
|
| 12 |
+
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
| 13 |
+
# and OPT implementations in this library. It has been modified from its
|
| 14 |
+
# original forms to accommodate minor architectural differences compared
|
| 15 |
+
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
| 16 |
+
#
|
| 17 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 18 |
+
# you may not use this file except in compliance with the License.
|
| 19 |
+
# You may obtain a copy of the License at
|
| 20 |
+
#
|
| 21 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 22 |
+
#
|
| 23 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 24 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 25 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 26 |
+
# See the License for the specific language governing permissions and
|
| 27 |
+
# limitations under the License.
|
| 28 |
+
|
| 29 |
+
from typing import Callable, Optional, Union
|
| 30 |
+
|
| 31 |
+
import torch
|
| 32 |
+
from torch import nn
|
| 33 |
+
|
| 34 |
+
import torch_npu
|
| 35 |
+
from torch_npu.contrib import transfer_to_npu
|
| 36 |
+
if "910" in torch.npu.get_device_name():
|
| 37 |
+
NPU_ATTN_INFR = True
|
| 38 |
+
print("[INFO] torch_npu detected. Using NPU fused infer attention.")
|
| 39 |
+
else:
|
| 40 |
+
NPU_ATTN_INFR = False
|
| 41 |
+
|
| 42 |
+
from transformers.activations import ACT2FN
|
| 43 |
+
from transformers.cache_utils import Cache, DynamicCache
|
| 44 |
+
from transformers.generation import GenerationMixin
|
| 45 |
+
from transformers.masking_utils import create_causal_mask
|
| 46 |
+
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
|
| 47 |
+
from transformers.modeling_layers import GradientCheckpointingLayer
|
| 48 |
+
from transformers.modeling_outputs import (
|
| 49 |
+
BaseModelOutputWithPast,
|
| 50 |
+
CausalLMOutputWithPast,
|
| 51 |
+
SequenceClassifierOutputWithPast,
|
| 52 |
+
)
|
| 53 |
+
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
| 54 |
+
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
| 55 |
+
from transformers.processing_utils import Unpack
|
| 56 |
+
from transformers.utils import LossKwargs, auto_docstring, can_return_tuple, logging
|
| 57 |
+
from .configuration_openpangu_dense import PanguEmbeddedConfig
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
logger = logging.get_logger(__name__)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class PanguEmbeddedRMSNorm(nn.Module):
|
| 64 |
+
def __init__(self, hidden_size, eps=1e-6):
|
| 65 |
+
"""
|
| 66 |
+
PanguEmbeddedRMSNorm is equivalent to T5LayerNorm
|
| 67 |
+
"""
|
| 68 |
+
super().__init__()
|
| 69 |
+
self.weight = nn.Parameter(torch.ones(hidden_size))
|
| 70 |
+
self.variance_epsilon = eps
|
| 71 |
+
|
| 72 |
+
def forward(self, hidden_states):
|
| 73 |
+
input_dtype = hidden_states.dtype
|
| 74 |
+
hidden_states = hidden_states.to(torch.float32)
|
| 75 |
+
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
| 76 |
+
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
| 77 |
+
return self.weight * hidden_states.to(input_dtype)
|
| 78 |
+
|
| 79 |
+
def extra_repr(self):
|
| 80 |
+
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
class PanguEmbeddedRotaryEmbedding(nn.Module):
|
| 84 |
+
def __init__(self, config: PanguEmbeddedConfig, device=None):
|
| 85 |
+
super().__init__()
|
| 86 |
+
# BC: "rope_type" was originally "type"
|
| 87 |
+
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
|
| 88 |
+
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
| 89 |
+
else:
|
| 90 |
+
self.rope_type = "default"
|
| 91 |
+
self.max_seq_len_cached = config.max_position_embeddings
|
| 92 |
+
self.original_max_seq_len = config.max_position_embeddings
|
| 93 |
+
|
| 94 |
+
self.config = config
|
| 95 |
+
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
| 96 |
+
|
| 97 |
+
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
|
| 98 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
| 99 |
+
self.original_inv_freq = self.inv_freq
|
| 100 |
+
|
| 101 |
+
@torch.no_grad()
|
| 102 |
+
@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
|
| 103 |
+
def forward(self, x, position_ids):
|
| 104 |
+
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
|
| 105 |
+
position_ids_expanded = position_ids[:, None, :].float()
|
| 106 |
+
|
| 107 |
+
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
|
| 108 |
+
with torch.autocast(device_type=device_type, enabled=False): # Force float32
|
| 109 |
+
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
|
| 110 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
| 111 |
+
cos = emb.cos() * self.attention_scaling
|
| 112 |
+
sin = emb.sin() * self.attention_scaling
|
| 113 |
+
|
| 114 |
+
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def rotate_half(x):
|
| 118 |
+
"""Rotates half the hidden dims of the input."""
|
| 119 |
+
x1 = x[..., : x.shape[-1] // 2]
|
| 120 |
+
x2 = x[..., x.shape[-1] // 2 :]
|
| 121 |
+
return torch.cat((-x2, x1), dim=-1)
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
| 125 |
+
"""Applies Rotary Position Embedding to the query and key tensors.
|
| 126 |
+
|
| 127 |
+
Args:
|
| 128 |
+
q (`torch.Tensor`): The query tensor.
|
| 129 |
+
k (`torch.Tensor`): The key tensor.
|
| 130 |
+
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
| 131 |
+
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
| 132 |
+
position_ids (`torch.Tensor`, *optional*):
|
| 133 |
+
Deprecated and unused.
|
| 134 |
+
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
| 135 |
+
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
| 136 |
+
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
| 137 |
+
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
|
| 138 |
+
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
|
| 139 |
+
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
|
| 140 |
+
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
|
| 141 |
+
Returns:
|
| 142 |
+
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
| 143 |
+
"""
|
| 144 |
+
cos = cos.unsqueeze(unsqueeze_dim)
|
| 145 |
+
sin = sin.unsqueeze(unsqueeze_dim)
|
| 146 |
+
q_embed = (q * cos) + (rotate_half(q) * sin)
|
| 147 |
+
k_embed = (k * cos) + (rotate_half(k) * sin)
|
| 148 |
+
return q_embed, k_embed
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
class PanguEmbeddedMLP(nn.Module):
|
| 152 |
+
def __init__(self, config):
|
| 153 |
+
super().__init__()
|
| 154 |
+
self.config = config
|
| 155 |
+
self.hidden_size = config.hidden_size
|
| 156 |
+
self.intermediate_size = config.intermediate_size
|
| 157 |
+
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
| 158 |
+
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
| 159 |
+
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
| 160 |
+
self.act_fn = ACT2FN[config.hidden_act]
|
| 161 |
+
|
| 162 |
+
def forward(self, x):
|
| 163 |
+
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
| 164 |
+
return down_proj
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
| 168 |
+
"""
|
| 169 |
+
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
| 170 |
+
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
| 171 |
+
"""
|
| 172 |
+
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
| 173 |
+
if n_rep == 1:
|
| 174 |
+
return hidden_states
|
| 175 |
+
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
|
| 176 |
+
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
def eager_attention_forward(
|
| 180 |
+
module: nn.Module,
|
| 181 |
+
query: torch.Tensor,
|
| 182 |
+
key: torch.Tensor,
|
| 183 |
+
value: torch.Tensor,
|
| 184 |
+
attention_mask: Optional[torch.Tensor],
|
| 185 |
+
scaling: float,
|
| 186 |
+
dropout: float = 0.0,
|
| 187 |
+
**kwargs,
|
| 188 |
+
):
|
| 189 |
+
key_states = repeat_kv(key, module.num_key_value_groups)
|
| 190 |
+
value_states = repeat_kv(value, module.num_key_value_groups)
|
| 191 |
+
|
| 192 |
+
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
|
| 193 |
+
if attention_mask is not None:
|
| 194 |
+
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
| 195 |
+
attn_weights = attn_weights + causal_mask
|
| 196 |
+
|
| 197 |
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
| 198 |
+
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
| 199 |
+
attn_output = torch.matmul(attn_weights, value_states)
|
| 200 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
| 201 |
+
|
| 202 |
+
return attn_output, attn_weights
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
class PanguEmbeddedAttention(nn.Module):
|
| 206 |
+
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
| 207 |
+
|
| 208 |
+
def __init__(self, config: PanguEmbeddedConfig, layer_idx: int):
|
| 209 |
+
super().__init__()
|
| 210 |
+
self.config = config
|
| 211 |
+
self.layer_idx = layer_idx
|
| 212 |
+
self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
|
| 213 |
+
self.num_heads = config.num_attention_heads
|
| 214 |
+
self.num_key_value_heads = config.num_key_value_heads
|
| 215 |
+
self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
|
| 216 |
+
self.scaling = self.head_dim**-0.5
|
| 217 |
+
self.attention_dropout = config.attention_dropout
|
| 218 |
+
self.is_causal = True
|
| 219 |
+
|
| 220 |
+
self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.bias)
|
| 221 |
+
self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.bias)
|
| 222 |
+
self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.bias)
|
| 223 |
+
self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.bias)
|
| 224 |
+
|
| 225 |
+
def forward(
|
| 226 |
+
self,
|
| 227 |
+
hidden_states: torch.Tensor,
|
| 228 |
+
position_embeddings: tuple[torch.Tensor, torch.Tensor],
|
| 229 |
+
attention_mask: Optional[torch.Tensor],
|
| 230 |
+
past_key_value: Optional[Cache] = None,
|
| 231 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 232 |
+
**kwargs: Unpack[FlashAttentionKwargs],
|
| 233 |
+
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
|
| 234 |
+
input_shape = hidden_states.shape[:-1]
|
| 235 |
+
hidden_shape = (*input_shape, -1, self.head_dim)
|
| 236 |
+
|
| 237 |
+
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
| 238 |
+
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
| 239 |
+
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
| 240 |
+
|
| 241 |
+
cos, sin = position_embeddings
|
| 242 |
+
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
| 243 |
+
|
| 244 |
+
if past_key_value is not None:
|
| 245 |
+
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
| 246 |
+
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
| 247 |
+
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
| 248 |
+
|
| 249 |
+
attention_interface: Callable = eager_attention_forward
|
| 250 |
+
if self.config._attn_implementation != "eager":
|
| 251 |
+
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
| 252 |
+
|
| 253 |
+
if not self.training and NPU_ATTN_INFR:
|
| 254 |
+
q_len = input_shape[1]
|
| 255 |
+
if attention_mask is not None:
|
| 256 |
+
attention_mask = ~attention_mask.bool()
|
| 257 |
+
elif q_len > 1:
|
| 258 |
+
attention_mask = torch.triu(torch.ones([q_len, q_len]), diagonal=1).bool().unsqueeze(0).unsqueeze(0).to(query_states.device)
|
| 259 |
+
|
| 260 |
+
attn_output, _ = torch_npu.npu_fused_infer_attention_score(
|
| 261 |
+
query_states, key_states, value_states,
|
| 262 |
+
num_heads=self.num_heads, num_key_value_heads=self.num_key_value_heads,
|
| 263 |
+
input_layout="BNSD", atten_mask=attention_mask, scale=self.scaling)
|
| 264 |
+
attn_output = attn_output.transpose(1, 2)
|
| 265 |
+
attn_weights = None
|
| 266 |
+
else:
|
| 267 |
+
attn_output, attn_weights = attention_interface(
|
| 268 |
+
self,
|
| 269 |
+
query_states,
|
| 270 |
+
key_states,
|
| 271 |
+
value_states,
|
| 272 |
+
attention_mask,
|
| 273 |
+
dropout=0.0 if not self.training else self.attention_dropout,
|
| 274 |
+
scaling=self.scaling,
|
| 275 |
+
**kwargs,
|
| 276 |
+
)
|
| 277 |
+
|
| 278 |
+
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
|
| 279 |
+
attn_output = self.o_proj(attn_output)
|
| 280 |
+
return attn_output, attn_weights
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
class PanguEmbeddedDecoderLayer(GradientCheckpointingLayer):
|
| 284 |
+
def __init__(self, config: PanguEmbeddedConfig, layer_idx: int):
|
| 285 |
+
super().__init__()
|
| 286 |
+
self.hidden_size = config.hidden_size
|
| 287 |
+
self.self_attn = PanguEmbeddedAttention(config=config, layer_idx=layer_idx)
|
| 288 |
+
self.mlp = PanguEmbeddedMLP(config)
|
| 289 |
+
self.input_layernorm = PanguEmbeddedRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 290 |
+
self.post_attention_layernorm = PanguEmbeddedRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 291 |
+
|
| 292 |
+
def forward(
|
| 293 |
+
self,
|
| 294 |
+
hidden_states: torch.Tensor,
|
| 295 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 296 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 297 |
+
past_key_value: Optional[Cache] = None,
|
| 298 |
+
output_attentions: Optional[bool] = False,
|
| 299 |
+
use_cache: Optional[bool] = False,
|
| 300 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 301 |
+
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
|
| 302 |
+
**kwargs: Unpack[FlashAttentionKwargs],
|
| 303 |
+
) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
| 304 |
+
residual = hidden_states
|
| 305 |
+
hidden_states = self.input_layernorm(hidden_states)
|
| 306 |
+
|
| 307 |
+
# Self Attention
|
| 308 |
+
hidden_states, self_attn_weights = self.self_attn(
|
| 309 |
+
hidden_states=hidden_states,
|
| 310 |
+
attention_mask=attention_mask,
|
| 311 |
+
position_ids=position_ids,
|
| 312 |
+
past_key_value=past_key_value,
|
| 313 |
+
output_attentions=output_attentions,
|
| 314 |
+
use_cache=use_cache,
|
| 315 |
+
cache_position=cache_position,
|
| 316 |
+
position_embeddings=position_embeddings,
|
| 317 |
+
**kwargs,
|
| 318 |
+
)
|
| 319 |
+
hidden_states = residual + hidden_states
|
| 320 |
+
|
| 321 |
+
# Fully Connected
|
| 322 |
+
residual = hidden_states
|
| 323 |
+
hidden_states = self.post_attention_layernorm(hidden_states)
|
| 324 |
+
hidden_states = self.mlp(hidden_states)
|
| 325 |
+
hidden_states = residual + hidden_states
|
| 326 |
+
|
| 327 |
+
outputs = (hidden_states,)
|
| 328 |
+
if output_attentions:
|
| 329 |
+
outputs += (self_attn_weights,)
|
| 330 |
+
|
| 331 |
+
return outputs
|
| 332 |
+
|
| 333 |
+
|
| 334 |
+
@auto_docstring
|
| 335 |
+
class PanguEmbeddedPreTrainedModel(PreTrainedModel):
|
| 336 |
+
config_class = PanguEmbeddedConfig
|
| 337 |
+
base_model_prefix = "model"
|
| 338 |
+
supports_gradient_checkpointing = True
|
| 339 |
+
_no_split_modules = ["PanguEmbeddedDecoderLayer"]
|
| 340 |
+
_skip_keys_device_placement = ["past_key_values"]
|
| 341 |
+
_supports_flash_attn_3 = True
|
| 342 |
+
_supports_flash_attn_2 = True
|
| 343 |
+
_supports_sdpa = True
|
| 344 |
+
_supports_flex_attn = True
|
| 345 |
+
_supports_cache_class = True
|
| 346 |
+
_supports_quantized_cache = True
|
| 347 |
+
_supports_static_cache = True
|
| 348 |
+
_supports_attention_backend = True
|
| 349 |
+
|
| 350 |
+
def _init_weights(self, module):
|
| 351 |
+
std = self.config.initializer_range
|
| 352 |
+
if isinstance(module, nn.Linear):
|
| 353 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
| 354 |
+
if module.bias is not None:
|
| 355 |
+
module.bias.data.zero_()
|
| 356 |
+
elif isinstance(module, nn.Embedding):
|
| 357 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
| 358 |
+
if module.padding_idx is not None:
|
| 359 |
+
module.weight.data[module.padding_idx].zero_()
|
| 360 |
+
elif isinstance(module, PanguEmbeddedRMSNorm):
|
| 361 |
+
module.weight.data.fill_(1.0)
|
| 362 |
+
|
| 363 |
+
|
| 364 |
+
@auto_docstring
|
| 365 |
+
class PanguEmbeddedModel(PanguEmbeddedPreTrainedModel):
|
| 366 |
+
def __init__(self, config: PanguEmbeddedConfig):
|
| 367 |
+
super().__init__(config)
|
| 368 |
+
self.padding_idx = config.pad_token_id
|
| 369 |
+
self.vocab_size = config.vocab_size
|
| 370 |
+
|
| 371 |
+
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
| 372 |
+
self.layers = nn.ModuleList(
|
| 373 |
+
[PanguEmbeddedDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
| 374 |
+
)
|
| 375 |
+
self.norm = PanguEmbeddedRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 376 |
+
self.rotary_emb = PanguEmbeddedRotaryEmbedding(config=config)
|
| 377 |
+
self.gradient_checkpointing = False
|
| 378 |
+
|
| 379 |
+
# Initialize weights and apply final processing
|
| 380 |
+
self.post_init()
|
| 381 |
+
|
| 382 |
+
def get_input_embeddings(self):
|
| 383 |
+
return self.embed_tokens
|
| 384 |
+
|
| 385 |
+
def set_input_embeddings(self, value):
|
| 386 |
+
self.embed_tokens = value
|
| 387 |
+
|
| 388 |
+
@can_return_tuple
|
| 389 |
+
@auto_docstring
|
| 390 |
+
def forward(
|
| 391 |
+
self,
|
| 392 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 393 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 394 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 395 |
+
past_key_values: Optional[Cache] = None,
|
| 396 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 397 |
+
use_cache: Optional[bool] = None,
|
| 398 |
+
output_attentions: Optional[bool] = None,
|
| 399 |
+
output_hidden_states: Optional[bool] = None,
|
| 400 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 401 |
+
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
|
| 402 |
+
) -> BaseModelOutputWithPast:
|
| 403 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 404 |
+
output_hidden_states = (
|
| 405 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 406 |
+
)
|
| 407 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
| 408 |
+
|
| 409 |
+
if (input_ids is None) ^ (inputs_embeds is not None):
|
| 410 |
+
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
| 411 |
+
|
| 412 |
+
if self.gradient_checkpointing and self.training and use_cache:
|
| 413 |
+
logger.warning_once(
|
| 414 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
|
| 415 |
+
)
|
| 416 |
+
use_cache = False
|
| 417 |
+
|
| 418 |
+
# TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache
|
| 419 |
+
if not isinstance(past_key_values, (type(None), Cache)):
|
| 420 |
+
raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.")
|
| 421 |
+
|
| 422 |
+
if inputs_embeds is None:
|
| 423 |
+
inputs_embeds = self.embed_tokens(input_ids)
|
| 424 |
+
|
| 425 |
+
if use_cache and past_key_values is None:
|
| 426 |
+
past_key_values = DynamicCache()
|
| 427 |
+
|
| 428 |
+
if cache_position is None:
|
| 429 |
+
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
| 430 |
+
cache_position = torch.arange(
|
| 431 |
+
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
|
| 432 |
+
)
|
| 433 |
+
|
| 434 |
+
if position_ids is None:
|
| 435 |
+
position_ids = cache_position.unsqueeze(0)
|
| 436 |
+
|
| 437 |
+
causal_mask = create_causal_mask(
|
| 438 |
+
config=self.config,
|
| 439 |
+
input_embeds=inputs_embeds,
|
| 440 |
+
attention_mask=attention_mask,
|
| 441 |
+
cache_position=cache_position,
|
| 442 |
+
past_key_values=past_key_values,
|
| 443 |
+
position_ids=position_ids,
|
| 444 |
+
)
|
| 445 |
+
|
| 446 |
+
hidden_states = inputs_embeds
|
| 447 |
+
|
| 448 |
+
# create position embeddings to be shared across the decoder layers
|
| 449 |
+
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
| 450 |
+
|
| 451 |
+
# decoder layers
|
| 452 |
+
all_hidden_states = () if output_hidden_states else None
|
| 453 |
+
all_self_attns = () if output_attentions else None
|
| 454 |
+
|
| 455 |
+
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
|
| 456 |
+
if output_hidden_states:
|
| 457 |
+
all_hidden_states += (hidden_states,)
|
| 458 |
+
|
| 459 |
+
layer_outputs = decoder_layer(
|
| 460 |
+
hidden_states,
|
| 461 |
+
attention_mask=causal_mask,
|
| 462 |
+
position_ids=position_ids,
|
| 463 |
+
past_key_value=past_key_values,
|
| 464 |
+
output_attentions=output_attentions,
|
| 465 |
+
use_cache=use_cache,
|
| 466 |
+
cache_position=cache_position,
|
| 467 |
+
position_embeddings=position_embeddings,
|
| 468 |
+
**flash_attn_kwargs,
|
| 469 |
+
)
|
| 470 |
+
|
| 471 |
+
hidden_states = layer_outputs[0]
|
| 472 |
+
|
| 473 |
+
if output_attentions:
|
| 474 |
+
all_self_attns += (layer_outputs[1],)
|
| 475 |
+
|
| 476 |
+
hidden_states = self.norm(hidden_states)
|
| 477 |
+
|
| 478 |
+
# add hidden states from the last decoder layer
|
| 479 |
+
if output_hidden_states:
|
| 480 |
+
all_hidden_states += (hidden_states,)
|
| 481 |
+
|
| 482 |
+
return BaseModelOutputWithPast(
|
| 483 |
+
last_hidden_state=hidden_states,
|
| 484 |
+
past_key_values=past_key_values if use_cache else None,
|
| 485 |
+
hidden_states=all_hidden_states,
|
| 486 |
+
attentions=all_self_attns,
|
| 487 |
+
)
|
| 488 |
+
|
| 489 |
+
|
| 490 |
+
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
| 491 |
+
|
| 492 |
+
|
| 493 |
+
@auto_docstring
|
| 494 |
+
class PanguEmbeddedForCausalLM(PanguEmbeddedPreTrainedModel, GenerationMixin):
|
| 495 |
+
_tied_weights_keys = ["lm_head.weight"]
|
| 496 |
+
_tp_plan = {"lm_head": "colwise_rep"}
|
| 497 |
+
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
| 498 |
+
|
| 499 |
+
def __init__(self, config):
|
| 500 |
+
super().__init__(config)
|
| 501 |
+
self.model = PanguEmbeddedModel(config)
|
| 502 |
+
self.vocab_size = config.vocab_size
|
| 503 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 504 |
+
|
| 505 |
+
# Initialize weights and apply final processing
|
| 506 |
+
self.post_init()
|
| 507 |
+
|
| 508 |
+
def get_input_embeddings(self):
|
| 509 |
+
return self.model.embed_tokens
|
| 510 |
+
|
| 511 |
+
def set_input_embeddings(self, value):
|
| 512 |
+
self.model.embed_tokens = value
|
| 513 |
+
|
| 514 |
+
def get_output_embeddings(self):
|
| 515 |
+
return self.lm_head
|
| 516 |
+
|
| 517 |
+
def set_output_embeddings(self, new_embeddings):
|
| 518 |
+
self.lm_head = new_embeddings
|
| 519 |
+
|
| 520 |
+
def set_decoder(self, decoder):
|
| 521 |
+
self.model = decoder
|
| 522 |
+
|
| 523 |
+
def get_decoder(self):
|
| 524 |
+
return self.model
|
| 525 |
+
|
| 526 |
+
@can_return_tuple
|
| 527 |
+
@auto_docstring
|
| 528 |
+
def forward(
|
| 529 |
+
self,
|
| 530 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 531 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 532 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 533 |
+
past_key_values: Optional[Cache] = None,
|
| 534 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 535 |
+
labels: Optional[torch.LongTensor] = None,
|
| 536 |
+
use_cache: Optional[bool] = None,
|
| 537 |
+
output_attentions: Optional[bool] = None,
|
| 538 |
+
output_hidden_states: Optional[bool] = None,
|
| 539 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 540 |
+
logits_to_keep: Union[int, torch.Tensor] = 0,
|
| 541 |
+
**kwargs: Unpack[KwargsForCausalLM],
|
| 542 |
+
) -> CausalLMOutputWithPast:
|
| 543 |
+
|
| 544 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 545 |
+
output_hidden_states = (
|
| 546 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 547 |
+
)
|
| 548 |
+
|
| 549 |
+
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
| 550 |
+
outputs: BaseModelOutputWithPast = self.model(
|
| 551 |
+
input_ids=input_ids,
|
| 552 |
+
attention_mask=attention_mask,
|
| 553 |
+
position_ids=position_ids,
|
| 554 |
+
past_key_values=past_key_values,
|
| 555 |
+
inputs_embeds=inputs_embeds,
|
| 556 |
+
use_cache=use_cache,
|
| 557 |
+
output_attentions=output_attentions,
|
| 558 |
+
output_hidden_states=output_hidden_states,
|
| 559 |
+
cache_position=cache_position,
|
| 560 |
+
**kwargs,
|
| 561 |
+
)
|
| 562 |
+
|
| 563 |
+
hidden_states = outputs.last_hidden_state
|
| 564 |
+
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
| 565 |
+
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
| 566 |
+
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
| 567 |
+
|
| 568 |
+
loss = None
|
| 569 |
+
if labels is not None:
|
| 570 |
+
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
|
| 571 |
+
|
| 572 |
+
return CausalLMOutputWithPast(
|
| 573 |
+
loss=loss,
|
| 574 |
+
logits=logits,
|
| 575 |
+
past_key_values=outputs.past_key_values,
|
| 576 |
+
hidden_states=outputs.hidden_states,
|
| 577 |
+
attentions=outputs.attentions,
|
| 578 |
+
)
|
| 579 |
+
|
| 580 |
+
|
| 581 |
+
__all__ = [
|
| 582 |
+
"PanguEmbeddedForCausalLM",
|
| 583 |
+
"PanguEmbeddedModel",
|
| 584 |
+
"PanguEmbeddedPreTrainedModel",
|
| 585 |
+
]
|
modular_openpangu_dense.py
ADDED
|
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
|
| 3 |
+
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
| 6 |
+
# and OPT implementations in this library. It has been modified from its
|
| 7 |
+
# original forms to accommodate minor architectural differences compared
|
| 8 |
+
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
| 9 |
+
#
|
| 10 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 11 |
+
# you may not use this file except in compliance with the License.
|
| 12 |
+
# You may obtain a copy of the License at
|
| 13 |
+
#
|
| 14 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 15 |
+
#
|
| 16 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 17 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 18 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 19 |
+
# See the License for the specific language governing permissions and
|
| 20 |
+
# limitations under the License.
|
| 21 |
+
|
| 22 |
+
from typing import Callable, Optional, Tuple
|
| 23 |
+
|
| 24 |
+
import torch
|
| 25 |
+
from torch import nn
|
| 26 |
+
|
| 27 |
+
import torch_npu
|
| 28 |
+
from torch_npu.contrib import transfer_to_npu
|
| 29 |
+
if "910" in torch.npu.get_device_name():
|
| 30 |
+
NPU_ATTN_INFR = True
|
| 31 |
+
print("[INFO] torch_npu detected. Using NPU fused infer attention.")
|
| 32 |
+
else:
|
| 33 |
+
NPU_ATTN_INFR = False
|
| 34 |
+
|
| 35 |
+
from transformers.cache_utils import Cache
|
| 36 |
+
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
|
| 37 |
+
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
|
| 38 |
+
from transformers.processing_utils import Unpack
|
| 39 |
+
from transformers.utils import logging
|
| 40 |
+
from transformers.models.llama.modeling_llama import (
|
| 41 |
+
LlamaAttention,
|
| 42 |
+
LlamaDecoderLayer,
|
| 43 |
+
LlamaForCausalLM,
|
| 44 |
+
LlamaForSequenceClassification,
|
| 45 |
+
LlamaMLP,
|
| 46 |
+
LlamaModel,
|
| 47 |
+
apply_rotary_pos_emb,
|
| 48 |
+
eager_attention_forward,
|
| 49 |
+
)
|
| 50 |
+
from .configuration_openpangu_dense import PanguEmbeddedConfig
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
logger = logging.get_logger(__name__)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
class PanguEmbeddedMLP(LlamaMLP):
|
| 57 |
+
def __init__(self, config):
|
| 58 |
+
super().__init__(config)
|
| 59 |
+
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
| 60 |
+
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
| 61 |
+
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class PanguEmbeddedAttention(LlamaAttention):
|
| 65 |
+
def __init__(self, config: PanguEmbeddedConfig, layer_idx: int):
|
| 66 |
+
super().__init__()
|
| 67 |
+
self.config = config
|
| 68 |
+
self.layer_idx = layer_idx
|
| 69 |
+
self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
|
| 70 |
+
self.num_heads = config.num_attention_heads
|
| 71 |
+
self.num_key_value_heads = config.num_key_value_heads
|
| 72 |
+
self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
|
| 73 |
+
self.scaling = self.head_dim**-0.5
|
| 74 |
+
self.attention_dropout = config.attention_dropout
|
| 75 |
+
self.is_causal = True
|
| 76 |
+
|
| 77 |
+
self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.bias)
|
| 78 |
+
self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.bias)
|
| 79 |
+
self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.bias)
|
| 80 |
+
self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.bias)
|
| 81 |
+
|
| 82 |
+
def forward(
|
| 83 |
+
self,
|
| 84 |
+
hidden_states: torch.Tensor,
|
| 85 |
+
position_embeddings: tuple[torch.Tensor, torch.Tensor],
|
| 86 |
+
attention_mask: Optional[torch.Tensor],
|
| 87 |
+
past_key_value: Optional[Cache] = None,
|
| 88 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 89 |
+
**kwargs: Unpack[FlashAttentionKwargs],
|
| 90 |
+
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
|
| 91 |
+
input_shape = hidden_states.shape[:-1]
|
| 92 |
+
hidden_shape = (*input_shape, -1, self.head_dim)
|
| 93 |
+
|
| 94 |
+
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
| 95 |
+
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
| 96 |
+
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
| 97 |
+
|
| 98 |
+
cos, sin = position_embeddings
|
| 99 |
+
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
| 100 |
+
|
| 101 |
+
if past_key_value is not None:
|
| 102 |
+
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
| 103 |
+
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
| 104 |
+
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
| 105 |
+
|
| 106 |
+
attention_interface: Callable = eager_attention_forward
|
| 107 |
+
if self.config._attn_implementation != "eager":
|
| 108 |
+
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
| 109 |
+
|
| 110 |
+
if not self.training and NPU_ATTN_INFR:
|
| 111 |
+
q_len = input_shape[1]
|
| 112 |
+
if attention_mask is not None:
|
| 113 |
+
attention_mask = ~attention_mask.bool()
|
| 114 |
+
elif q_len > 1:
|
| 115 |
+
attention_mask = torch.triu(torch.ones([q_len, q_len]), diagonal=1).bool().unsqueeze(0).unsqueeze(0).to(query_states.device)
|
| 116 |
+
|
| 117 |
+
attn_output, _ = torch_npu.npu_fused_infer_attention_score(
|
| 118 |
+
query_states, key_states, value_states,
|
| 119 |
+
num_heads=self.num_heads, num_key_value_heads=self.num_key_value_heads,
|
| 120 |
+
input_layout="BNSD", atten_mask=attention_mask, scale=self.scaling)
|
| 121 |
+
attn_output = attn_output.transpose(1, 2)
|
| 122 |
+
attn_weights = None
|
| 123 |
+
else:
|
| 124 |
+
attn_output, attn_weights = attention_interface(
|
| 125 |
+
self,
|
| 126 |
+
query_states,
|
| 127 |
+
key_states,
|
| 128 |
+
value_states,
|
| 129 |
+
attention_mask,
|
| 130 |
+
dropout=0.0 if not self.training else self.attention_dropout,
|
| 131 |
+
scaling=self.scaling,
|
| 132 |
+
**kwargs,
|
| 133 |
+
)
|
| 134 |
+
|
| 135 |
+
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
|
| 136 |
+
attn_output = self.o_proj(attn_output)
|
| 137 |
+
return attn_output, attn_weights
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
class PanguEmbeddedDecoderLayer(LlamaDecoderLayer):
|
| 141 |
+
pass
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
class PanguEmbeddedModel(LlamaModel):
|
| 145 |
+
pass
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
class PanguEmbeddedForCausalLM(LlamaForCausalLM):
|
| 149 |
+
pass
|
special_tokens_map.json
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"bos_token": {
|
| 3 |
+
"content": "<s>",
|
| 4 |
+
"lstrip": false,
|
| 5 |
+
"normalized": false,
|
| 6 |
+
"rstrip": false,
|
| 7 |
+
"single_word": false
|
| 8 |
+
},
|
| 9 |
+
"eos_token": {
|
| 10 |
+
"content": "[unused10]",
|
| 11 |
+
"lstrip": false,
|
| 12 |
+
"normalized": false,
|
| 13 |
+
"rstrip": false,
|
| 14 |
+
"single_word": false
|
| 15 |
+
},
|
| 16 |
+
"pad_token": {
|
| 17 |
+
"content": "<unk>",
|
| 18 |
+
"lstrip": false,
|
| 19 |
+
"normalized": false,
|
| 20 |
+
"rstrip": false,
|
| 21 |
+
"single_word": false
|
| 22 |
+
},
|
| 23 |
+
"unk_token": {
|
| 24 |
+
"content": "<unk>",
|
| 25 |
+
"lstrip": false,
|
| 26 |
+
"normalized": false,
|
| 27 |
+
"rstrip": false,
|
| 28 |
+
"single_word": false
|
| 29 |
+
}
|
| 30 |
+
}
|
tokenization_openpangu.py
ADDED
|
@@ -0,0 +1,273 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
|
| 3 |
+
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
| 6 |
+
# and OPT implementations in this library. It has been modified from its
|
| 7 |
+
# original forms to accommodate minor architectural differences compared
|
| 8 |
+
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
| 9 |
+
#
|
| 10 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 11 |
+
# you may not use this file except in compliance with the License.
|
| 12 |
+
# You may obtain a copy of the License at
|
| 13 |
+
#
|
| 14 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 15 |
+
#
|
| 16 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 17 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 18 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 19 |
+
# See the License for the specific language governing permissions and
|
| 20 |
+
# limitations under the License.
|
| 21 |
+
|
| 22 |
+
import os
|
| 23 |
+
from shutil import copyfile
|
| 24 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 25 |
+
|
| 26 |
+
import sentencepiece as spm
|
| 27 |
+
|
| 28 |
+
from transformers.tokenization_utils import PreTrainedTokenizer
|
| 29 |
+
from transformers.utils import logging
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
logger = logging.get_logger(__name__)
|
| 33 |
+
|
| 34 |
+
VOCAB_FILES_NAMES = {"vocab_file": "./tokenizer.model"}
|
| 35 |
+
|
| 36 |
+
PRETRAINED_VOCAB_FILES_MAP = {}
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def convert_bool(string):
|
| 40 |
+
if isinstance(string, str):
|
| 41 |
+
if string.lower() == "true":
|
| 42 |
+
return True
|
| 43 |
+
elif string.lower() == "false":
|
| 44 |
+
return False
|
| 45 |
+
else:
|
| 46 |
+
return string
|
| 47 |
+
else:
|
| 48 |
+
return string
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class PanguTokenizer(PreTrainedTokenizer):
|
| 52 |
+
"""
|
| 53 |
+
Construct a tokenizer. Based on byte-level Byte-Pair-Encoding.
|
| 54 |
+
|
| 55 |
+
Args:
|
| 56 |
+
vocab_file (`str`):
|
| 57 |
+
Path to the vocabulary file.
|
| 58 |
+
"""
|
| 59 |
+
|
| 60 |
+
vocab_files_names = VOCAB_FILES_NAMES
|
| 61 |
+
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
|
| 62 |
+
model_input_names = ["input_ids", "attention_mask"]
|
| 63 |
+
_auto_class = "AutoTokenizer"
|
| 64 |
+
|
| 65 |
+
def __init__(
|
| 66 |
+
self,
|
| 67 |
+
vocab_file,
|
| 68 |
+
unk_token="<unk>",
|
| 69 |
+
bos_token="<s>",
|
| 70 |
+
eos_token="</s>",
|
| 71 |
+
pad_token="</s>",
|
| 72 |
+
sp_model_kwargs: Optional[Dict[str, Any]] = None,
|
| 73 |
+
add_bos_token=True,
|
| 74 |
+
add_eos_token=False,
|
| 75 |
+
decode_with_prefix_space=False,
|
| 76 |
+
clean_up_tokenization_spaces=False,
|
| 77 |
+
**kwargs,
|
| 78 |
+
):
|
| 79 |
+
self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
|
| 80 |
+
self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
|
| 81 |
+
self.sp_model.Load(vocab_file)
|
| 82 |
+
super().__init__(
|
| 83 |
+
bos_token=bos_token,
|
| 84 |
+
eos_token=eos_token,
|
| 85 |
+
unk_token=unk_token,
|
| 86 |
+
pad_token=pad_token,
|
| 87 |
+
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
|
| 88 |
+
**kwargs,
|
| 89 |
+
)
|
| 90 |
+
self.vocab_file = vocab_file
|
| 91 |
+
self.add_bos_token = convert_bool(add_bos_token)
|
| 92 |
+
self.add_eos_token = add_eos_token
|
| 93 |
+
self.decode_with_prefix_space = decode_with_prefix_space
|
| 94 |
+
self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
|
| 95 |
+
self.sp_model.Load(vocab_file)
|
| 96 |
+
self._no_prefix_space_tokens = None
|
| 97 |
+
|
| 98 |
+
""" Initialisation"""
|
| 99 |
+
|
| 100 |
+
@property
|
| 101 |
+
def no_prefix_space_tokens(self):
|
| 102 |
+
if self._no_prefix_space_tokens is None:
|
| 103 |
+
vocab = self.convert_ids_to_tokens(list(range(self.vocab_size)))
|
| 104 |
+
self._no_prefix_space_tokens = {i for i, tok in enumerate(vocab) if not tok.startswith("▁")}
|
| 105 |
+
return self._no_prefix_space_tokens
|
| 106 |
+
|
| 107 |
+
@property
|
| 108 |
+
def vocab_size(self):
|
| 109 |
+
"""Returns vocab size"""
|
| 110 |
+
return self.sp_model.get_piece_size()
|
| 111 |
+
|
| 112 |
+
@property
|
| 113 |
+
def bos_token_id(self) -> Optional[int]:
|
| 114 |
+
return self.sp_model.bos_id()
|
| 115 |
+
|
| 116 |
+
@property
|
| 117 |
+
def eos_token_id(self) -> Optional[int]:
|
| 118 |
+
return super().eos_token_id
|
| 119 |
+
|
| 120 |
+
def get_vocab(self):
|
| 121 |
+
"""Returns vocab as a dict"""
|
| 122 |
+
vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
|
| 123 |
+
vocab.update(self.added_tokens_encoder)
|
| 124 |
+
return vocab
|
| 125 |
+
|
| 126 |
+
def _tokenize(self, text):
|
| 127 |
+
"""Returns a tokenized string."""
|
| 128 |
+
return self.sp_model.encode(text, out_type=str)
|
| 129 |
+
|
| 130 |
+
def _convert_token_to_id(self, token):
|
| 131 |
+
"""Converts a token (str) in an id using the vocab."""
|
| 132 |
+
return self.sp_model.piece_to_id(token)
|
| 133 |
+
|
| 134 |
+
def _convert_id_to_token(self, index):
|
| 135 |
+
"""Converts an index (integer) in a token (str) using the vocab."""
|
| 136 |
+
token = self.sp_model.IdToPiece(index)
|
| 137 |
+
return token
|
| 138 |
+
|
| 139 |
+
def _maybe_add_prefix_space(self, tokens, decoded):
|
| 140 |
+
if tokens and tokens[0] not in self.no_prefix_space_tokens:
|
| 141 |
+
return " " + decoded
|
| 142 |
+
else:
|
| 143 |
+
return decoded
|
| 144 |
+
|
| 145 |
+
def convert_tokens_to_string(self, tokens):
|
| 146 |
+
"""Converts a sequence of tokens (string) in a single string."""
|
| 147 |
+
current_sub_tokens = []
|
| 148 |
+
out_string = ""
|
| 149 |
+
prev_is_special = False
|
| 150 |
+
for token in tokens:
|
| 151 |
+
# make sure that special tokens are not decoded using sentencepiece model
|
| 152 |
+
if token in self.all_special_tokens:
|
| 153 |
+
# Decode the current sub-tokens first
|
| 154 |
+
if current_sub_tokens:
|
| 155 |
+
out_string += self.sp_model.decode(current_sub_tokens)
|
| 156 |
+
current_sub_tokens = []
|
| 157 |
+
# Append the special token without adding extra spaces
|
| 158 |
+
out_string += token
|
| 159 |
+
prev_is_special = True
|
| 160 |
+
else:
|
| 161 |
+
current_sub_tokens.append(token)
|
| 162 |
+
prev_is_special = False
|
| 163 |
+
# Decode any remaining sub-tokens
|
| 164 |
+
if current_sub_tokens:
|
| 165 |
+
out_string += self.sp_model.decode(current_sub_tokens)
|
| 166 |
+
# Clean up leading and trailing spaces
|
| 167 |
+
if self.clean_up_tokenization_spaces:
|
| 168 |
+
out_string = self.clean_up_tokenization(out_string)
|
| 169 |
+
out_string = self._maybe_add_prefix_space(tokens=tokens, decoded=out_string)
|
| 170 |
+
return out_string[1:]
|
| 171 |
+
|
| 172 |
+
# Override decode to set spaces_between_special_tokens to True as default
|
| 173 |
+
def decode(self,
|
| 174 |
+
token_ids,
|
| 175 |
+
spaces_between_special_tokens: bool = False,
|
| 176 |
+
**kwargs):
|
| 177 |
+
return super().decode(
|
| 178 |
+
token_ids=token_ids,
|
| 179 |
+
spaces_between_special_tokens=spaces_between_special_tokens,
|
| 180 |
+
**kwargs,
|
| 181 |
+
)
|
| 182 |
+
|
| 183 |
+
def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
| 184 |
+
"""
|
| 185 |
+
Save the vocabulary and special tokens file to a directory.
|
| 186 |
+
|
| 187 |
+
Args:
|
| 188 |
+
save_directory (`str`):
|
| 189 |
+
The directory in which to save the vocabulary.
|
| 190 |
+
|
| 191 |
+
Returns:
|
| 192 |
+
`Tuple(str)`: Paths to the files saved.
|
| 193 |
+
"""
|
| 194 |
+
if not os.path.isdir(save_directory):
|
| 195 |
+
logger.error(f"Vocabulary path ({save_directory}) should be a directory")
|
| 196 |
+
return ("",)
|
| 197 |
+
out_vocab_file = os.path.join(
|
| 198 |
+
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
|
| 199 |
+
)
|
| 200 |
+
|
| 201 |
+
if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
|
| 202 |
+
copyfile(self.vocab_file, out_vocab_file)
|
| 203 |
+
elif not os.path.isfile(self.vocab_file):
|
| 204 |
+
with open(out_vocab_file, "wb") as fi:
|
| 205 |
+
content_spiece_model = self.sp_model.serialized_model_proto()
|
| 206 |
+
fi.write(content_spiece_model)
|
| 207 |
+
|
| 208 |
+
return (out_vocab_file,)
|
| 209 |
+
|
| 210 |
+
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
|
| 211 |
+
if self.add_bos_token:
|
| 212 |
+
bos_token_ids = [self.bos_token_id]
|
| 213 |
+
else:
|
| 214 |
+
bos_token_ids = []
|
| 215 |
+
|
| 216 |
+
output = bos_token_ids + token_ids_0
|
| 217 |
+
|
| 218 |
+
if token_ids_1 is not None:
|
| 219 |
+
output = output + token_ids_1
|
| 220 |
+
|
| 221 |
+
if self.add_eos_token:
|
| 222 |
+
output = output + [self.eos_token_id]
|
| 223 |
+
|
| 224 |
+
return output
|
| 225 |
+
|
| 226 |
+
def get_special_tokens_mask(
|
| 227 |
+
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
|
| 228 |
+
) -> List[int]:
|
| 229 |
+
"""
|
| 230 |
+
Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
|
| 231 |
+
special tokens using the tokenizer `prepare_for_model` method.
|
| 232 |
+
|
| 233 |
+
Args:
|
| 234 |
+
token_ids_0 (`List[int]`):
|
| 235 |
+
List of IDs.
|
| 236 |
+
token_ids_1 (`List[int]`, *optional*):
|
| 237 |
+
Optional second list of IDs for sequence pairs.
|
| 238 |
+
already_has_special_tokens (`bool`, *optional*, defaults to `False`):
|
| 239 |
+
Whether or not the token list is already formatted with special tokens for the model.
|
| 240 |
+
|
| 241 |
+
Returns:
|
| 242 |
+
`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
|
| 243 |
+
"""
|
| 244 |
+
if already_has_special_tokens:
|
| 245 |
+
return super().get_special_tokens_mask(
|
| 246 |
+
token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
|
| 247 |
+
)
|
| 248 |
+
|
| 249 |
+
if token_ids_1 is None:
|
| 250 |
+
return [1] + ([0] * len(token_ids_0)) + [1]
|
| 251 |
+
return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1]
|
| 252 |
+
|
| 253 |
+
def create_token_type_ids_from_sequences(
|
| 254 |
+
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
| 255 |
+
) -> List[int]:
|
| 256 |
+
"""
|
| 257 |
+
Create a mask from the two sequences passed to be used in a sequence-pair classification task. T5 does not make
|
| 258 |
+
use of token type ids, therefore a list of zeros is returned.
|
| 259 |
+
|
| 260 |
+
Args:
|
| 261 |
+
token_ids_0 (`List[int]`):
|
| 262 |
+
List of IDs.
|
| 263 |
+
token_ids_1 (`List[int]`, *optional*):
|
| 264 |
+
Optional second list of IDs for sequence pairs.
|
| 265 |
+
|
| 266 |
+
Returns:
|
| 267 |
+
`List[int]`: List of zeros.
|
| 268 |
+
"""
|
| 269 |
+
eos = [self.eos_token_id]
|
| 270 |
+
|
| 271 |
+
if token_ids_1 is None:
|
| 272 |
+
return len(token_ids_0 + eos) * [0]
|
| 273 |
+
return len(token_ids_0 + eos + token_ids_1 + eos) * [0]
|
tokenizer.model
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6b16f1558c0cd4ae6ef1a2c605713be0a514f50e1ce2d2c878979ce988c148ec
|
| 3 |
+
size 2477809
|
tokenizer_config.json
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
{"add_bos_token": true, "add_eos_token": false, "add_prefix_space": true, "added_tokens_decoder": {"0": {"content": "<unk>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "1": {"content": "<s>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "2": {"content": "</s>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45806": {"content": "<|User|>:", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45813": {"content": "<|Bot|>:", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45830": {"content": "[unused0]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45840": {"content": "[unused1]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45846": {"content": "[unused2]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45849": {"content": "[unused3]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45861": {"content": "[unused4]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45866": {"content": "[unused5]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45874": {"content": "[unused6]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45883": {"content": "[unused7]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45884": {"content": "[unused8]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45887": {"content": "[unused9]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45892": {"content": "[unused10]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45920": {"content": "[unused11]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45932": {"content": "[unused12]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45938": {"content": "[unused13]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45953": {"content": "[unused14]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45968": {"content": "[unused15]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45974": {"content": "[unused16]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45982": {"content": "[unused17]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45986": {"content": "[unused18]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "46005": {"content": "[unused19]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "46007": {"content": "[unused20]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "46014": {"content": "[unused21]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "46017": {"content": "[unused22]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "46028": {"content": "[unused23]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "46032": {"content": "[unused24]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "46081": {"content": "[unused25]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "46086": {"content": "[unused26]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "46101": {"content": "[unused27]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "46183": {"content": "[unused28]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "46230": {"content": "[unused29]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "46245": {"content": "[unused30]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "46257": {"content": "[unused31]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "144208": {"content": "[unused32]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "144209": {"content": "[unused33]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}}, "auto_map": {"AutoTokenizer": ["tokenization_openpangu.PanguTokenizer", null]}, "bos_token": "<s>", "clean_up_tokenization_spaces": false, "eos_token": "[unused10]", "legacy": true, "model_max_length": 1000000000000000019884624838656, "pad_token": "<unk>", "sp_model_kwargs": {}, "spaces_between_special_tokens": false, "tokenizer_class": "PanguTokenizer", "unk_token": "<unk>", "use_default_system_prompt": false, "chat_template": "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '[unused9]系统:[unused10]' }}{% endif %}{% if message['role'] == 'system' %}{{ '[unused9]系统:' + message['content'] + '[unused10]' }}{% endif %}{% if message['role'] == 'assistant' %}{{'[unused9]助手:' + message['content'] + '[unused10]'}}{% endif %}{% if message['role'] == 'tool' %}{{'[unused9]工具:' + message['content'] + '[unused10]'}}{% endif %}{% if message['role'] == 'function' %}{{'[unused9]方法:' + message['content'] + '[unused10]'}}{% endif %}{% if message['role'] == 'user' %}{{'[unused9]用户:' + message['content'] + '[unused10]'}}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '[unused9]助手:' }}{% endif %}"}
|