Commit: 75ba0e0
Parent(s): 7438ed6

update
Note: this view is limited to 50 files because the commit contains too many changes; the raw diff has the full change set.
- LICENSE +201 -0
- checkpoints.md +10 -0
- colab.md +7 -0
- criterions/__init__.py +2 -0
- criterions/label_smoothed_cross_entropy.py +343 -0
- criterions/scst_loss.py +280 -0
- data/__init__.py +0 -0
- data/data_utils.py +601 -0
- data/file_dataset.py +102 -0
- data/mm_data/__init__.py +0 -0
- data/mm_data/caption_dataset.py +164 -0
- data/ofa_dataset.py +25 -0
- datasets.md +7 -0
- evaluate.py +152 -0
- models/__init__.py +1 -0
- models/ofa/__init__.py +1 -0
- models/ofa/ofa.py +410 -0
- models/ofa/resnet.py +225 -0
- models/ofa/unify_multihead_attention.py +518 -0
- models/ofa/unify_transformer.py +1510 -0
- models/ofa/unify_transformer_layer.py +542 -0
- models/search.py +814 -0
- models/sequence_generator.py +1053 -0
- notebooks/caption_infer.ipynb +0 -0
- ofa_module/__init__.py +5 -0
- run_scripts/caption/coco_eval.py +42 -0
- run_scripts/caption/evaluate_caption.sh +29 -0
- run_scripts/caption/train_caption_stage1.sh +104 -0
- run_scripts/caption/train_caption_stage2.sh +101 -0
- tasks/__init__.py +2 -0
- tasks/mm_tasks/__init__.py +1 -0
- tasks/mm_tasks/caption.py +249 -0
- tasks/ofa_task.py +338 -0
- train.py +523 -0
- trainer.py +1531 -0
- utils/BPE/__init__.py +0 -0
- utils/BPE/dict.txt +0 -0
- utils/BPE/encoder.json +0 -0
- utils/BPE/vocab.bpe +0 -0
- utils/__init__.py +0 -0
- utils/checkpoint_utils.py +875 -0
- utils/cider/pyciderevalcap/__init__.py +1 -0
- utils/cider/pyciderevalcap/cider/__init__.py +1 -0
- utils/cider/pyciderevalcap/cider/cider.py +65 -0
- utils/cider/pyciderevalcap/cider/cider_scorer.py +207 -0
- utils/cider/pyciderevalcap/ciderD/__init__.py +1 -0
- utils/cider/pyciderevalcap/ciderD/ciderD.py +58 -0
- utils/cider/pyciderevalcap/ciderD/ciderD_scorer.py +222 -0
- utils/eval_utils.py +39 -0
- utils/transforms.py +508 -0
LICENSE
ADDED
@@ -0,0 +1,201 @@
+                                 Apache License
+                           Version 2.0, January 2004
+                        http://www.apache.org/licenses/
+
+   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+   1. Definitions.
+
+      "License" shall mean the terms and conditions for use, reproduction,
+      and distribution as defined by Sections 1 through 9 of this document.
+
+      "Licensor" shall mean the copyright owner or entity authorized by
+      the copyright owner that is granting the License.
+
+      "Legal Entity" shall mean the union of the acting entity and all
+      other entities that control, are controlled by, or are under common
+      control with that entity. For the purposes of this definition,
+      "control" means (i) the power, direct or indirect, to cause the
+      direction or management of such entity, whether by contract or
+      otherwise, or (ii) ownership of fifty percent (50%) or more of the
+      outstanding shares, or (iii) beneficial ownership of such entity.
+
+      "You" (or "Your") shall mean an individual or Legal Entity
+      exercising permissions granted by this License.
+
+      "Source" form shall mean the preferred form for making modifications,
+      including but not limited to software source code, documentation
+      source, and configuration files.
+
+      "Object" form shall mean any form resulting from mechanical
+      transformation or translation of a Source form, including but
+      not limited to compiled object code, generated documentation,
+      and conversions to other media types.
+
+      "Work" shall mean the work of authorship, whether in Source or
+      Object form, made available under the License, as indicated by a
+      copyright notice that is included in or attached to the work
+      (an example is provided in the Appendix below).
+
+      "Derivative Works" shall mean any work, whether in Source or Object
+      form, that is based on (or derived from) the Work and for which the
+      editorial revisions, annotations, elaborations, or other modifications
+      represent, as a whole, an original work of authorship. For the purposes
+      of this License, Derivative Works shall not include works that remain
+      separable from, or merely link (or bind by name) to the interfaces of,
+      the Work and Derivative Works thereof.
+
+      "Contribution" shall mean any work of authorship, including
+      the original version of the Work and any modifications or additions
+      to that Work or Derivative Works thereof, that is intentionally
+      submitted to Licensor for inclusion in the Work by the copyright owner
+      or by an individual or Legal Entity authorized to submit on behalf of
+      the copyright owner. For the purposes of this definition, "submitted"
+      means any form of electronic, verbal, or written communication sent
+      to the Licensor or its representatives, including but not limited to
+      communication on electronic mailing lists, source code control systems,
+      and issue tracking systems that are managed by, or on behalf of, the
+      Licensor for the purpose of discussing and improving the Work, but
+      excluding communication that is conspicuously marked or otherwise
+      designated in writing by the copyright owner as "Not a Contribution."
+
+      "Contributor" shall mean Licensor and any individual or Legal Entity
+      on behalf of whom a Contribution has been received by Licensor and
+      subsequently incorporated within the Work.
+
+   2. Grant of Copyright License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      copyright license to reproduce, prepare Derivative Works of,
+      publicly display, publicly perform, sublicense, and distribute the
+      Work and such Derivative Works in Source or Object form.
+
+   3. Grant of Patent License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      (except as stated in this section) patent license to make, have made,
+      use, offer to sell, sell, import, and otherwise transfer the Work,
+      where such license applies only to those patent claims licensable
+      by such Contributor that are necessarily infringed by their
+      Contribution(s) alone or by combination of their Contribution(s)
+      with the Work to which such Contribution(s) was submitted. If You
+      institute patent litigation against any entity (including a
+      cross-claim or counterclaim in a lawsuit) alleging that the Work
+      or a Contribution incorporated within the Work constitutes direct
+      or contributory patent infringement, then any patent licenses
+      granted to You under this License for that Work shall terminate
+      as of the date such litigation is filed.
+
+   4. Redistribution. You may reproduce and distribute copies of the
+      Work or Derivative Works thereof in any medium, with or without
+      modifications, and in Source or Object form, provided that You
+      meet the following conditions:
+
+      (a) You must give any other recipients of the Work or
+          Derivative Works a copy of this License; and
+
+      (b) You must cause any modified files to carry prominent notices
+          stating that You changed the files; and
+
+      (c) You must retain, in the Source form of any Derivative Works
+          that You distribute, all copyright, patent, trademark, and
+          attribution notices from the Source form of the Work,
+          excluding those notices that do not pertain to any part of
+          the Derivative Works; and
+
+      (d) If the Work includes a "NOTICE" text file as part of its
+          distribution, then any Derivative Works that You distribute must
+          include a readable copy of the attribution notices contained
+          within such NOTICE file, excluding those notices that do not
+          pertain to any part of the Derivative Works, in at least one
+          of the following places: within a NOTICE text file distributed
+          as part of the Derivative Works; within the Source form or
+          documentation, if provided along with the Derivative Works; or,
+          within a display generated by the Derivative Works, if and
+          wherever such third-party notices normally appear. The contents
+          of the NOTICE file are for informational purposes only and
+          do not modify the License. You may add Your own attribution
+          notices within Derivative Works that You distribute, alongside
+          or as an addendum to the NOTICE text from the Work, provided
+          that such additional attribution notices cannot be construed
+          as modifying the License.
+
+      You may add Your own copyright statement to Your modifications and
+      may provide additional or different license terms and conditions
+      for use, reproduction, or distribution of Your modifications, or
+      for any such Derivative Works as a whole, provided Your use,
+      reproduction, and distribution of the Work otherwise complies with
+      the conditions stated in this License.
+
+   5. Submission of Contributions. Unless You explicitly state otherwise,
+      any Contribution intentionally submitted for inclusion in the Work
+      by You to the Licensor shall be under the terms and conditions of
+      this License, without any additional terms or conditions.
+      Notwithstanding the above, nothing herein shall supersede or modify
+      the terms of any separate license agreement you may have executed
+      with Licensor regarding such Contributions.
+
+   6. Trademarks. This License does not grant permission to use the trade
+      names, trademarks, service marks, or product names of the Licensor,
+      except as required for reasonable and customary use in describing the
+      origin of the Work and reproducing the content of the NOTICE file.
+
+   7. Disclaimer of Warranty. Unless required by applicable law or
+      agreed to in writing, Licensor provides the Work (and each
+      Contributor provides its Contributions) on an "AS IS" BASIS,
+      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+      implied, including, without limitation, any warranties or conditions
+      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+      PARTICULAR PURPOSE. You are solely responsible for determining the
+      appropriateness of using or redistributing the Work and assume any
+      risks associated with Your exercise of permissions under this License.
+
+   8. Limitation of Liability. In no event and under no legal theory,
+      whether in tort (including negligence), contract, or otherwise,
+      unless required by applicable law (such as deliberate and grossly
+      negligent acts) or agreed to in writing, shall any Contributor be
+      liable to You for damages, including any direct, indirect, special,
+      incidental, or consequential damages of any character arising as a
+      result of this License or out of the use or inability to use the
+      Work (including but not limited to damages for loss of goodwill,
+      work stoppage, computer failure or malfunction, or any and all
+      other commercial damages or losses), even if such Contributor
+      has been advised of the possibility of such damages.
+
+   9. Accepting Warranty or Additional Liability. While redistributing
+      the Work or Derivative Works thereof, You may choose to offer,
+      and charge a fee for, acceptance of support, warranty, indemnity,
+      or other liability obligations and/or rights consistent with this
+      License. However, in accepting such obligations, You may act only
+      on Your own behalf and on Your sole responsibility, not on behalf
+      of any other Contributor, and only if You agree to indemnify,
+      defend, and hold each Contributor harmless for any liability
+      incurred by, or claims asserted against, such Contributor by reason
+      of your accepting any such warranty or additional liability.
+
+   END OF TERMS AND CONDITIONS
+
+   APPENDIX: How to apply the Apache License to your work.
+
+      To apply the Apache License to your work, attach the following
+      boilerplate notice, with the fields enclosed by brackets "[]"
+      replaced with your own identifying information. (Don't include
+      the brackets!) The text should be enclosed in the appropriate
+      comment syntax for the file format. We also recommend that a
+      file or class name and description of purpose be included on the
+      same "printed page" as the copyright notice for easier
+      identification within third-party archives.
+
+   Copyright 1999-2022 Alibaba Group Holding Ltd.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
checkpoints.md
ADDED
@@ -0,0 +1,10 @@
+# Checkpoints
+
+We provide links for downloading our checkpoints. We will release all checkpoints, including the pretrained model and models finetuned on different tasks.
+
+## Pretraining
+* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/ofa_large.pt"> Pre-trained checkpoint (OFA-Large) </a>
+
+## Finetuning
+
+* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/caption_large_best_clean.pt"> Finetuned checkpoint for Caption on COCO </a>
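Not part of the commit: a minimal sketch of fetching one of the checkpoints listed in checkpoints.md and inspecting it. The local filename and the assumption that the `.pt` file deserializes into a plain dict via `torch.load` are illustrative only.

```python
# Hypothetical helper (not from the commit): download a checkpoint from
# checkpoints.md and peek at its top-level keys.
import urllib.request

import torch

URL = "https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/caption_large_best_clean.pt"
LOCAL_PATH = "caption_large_best_clean.pt"  # assumed local filename

urllib.request.urlretrieve(URL, LOCAL_PATH)          # plain HTTP download
state = torch.load(LOCAL_PATH, map_location="cpu")   # keep it off the GPU
print(sorted(state.keys()))                          # top-level keys; exact contents not verified here
```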
colab.md
ADDED
@@ -0,0 +1,7 @@
+# Colab Notebooks
+
+We provide Colab notebooks for the different downstream tasks so you can try OFA directly. See below.
+
+[Image Captioning](https://colab.research.google.com/drive/1Q4eNhhhLcgOP4hHqwZwU1ijOlabgve1W?usp=sharing) [![][colab]](https://colab.research.google.com/drive/1Q4eNhhhLcgOP4hHqwZwU1ijOlabgve1W?usp=sharing)
+
+[colab]: <https://colab.research.google.com/assets/colab-badge.svg>
criterions/__init__.py
ADDED
@@ -0,0 +1,2 @@
+from .scst_loss import ScstRewardCriterion
+from .label_smoothed_cross_entropy import AjustLabelSmoothedCrossEntropyCriterion
criterions/label_smoothed_cross_entropy.py
ADDED
@@ -0,0 +1,343 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+from dataclasses import dataclass, field
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+import numpy as np
+from fairseq import metrics, utils
+from fairseq.criterions import FairseqCriterion, register_criterion
+from fairseq.dataclass import FairseqDataclass
+from omegaconf import II
+
+
+@dataclass
+class AjustLabelSmoothedCrossEntropyCriterionConfig(FairseqDataclass):
+    label_smoothing: float = field(
+        default=0.0,
+        metadata={"help": "epsilon for label smoothing, 0 means no label smoothing"},
+    )
+    report_accuracy: bool = field(
+        default=False,
+        metadata={"help": "report accuracy metric"},
+    )
+    ignore_prefix_size: int = field(
+        default=0,
+        metadata={"help": "Ignore first N tokens"},
+    )
+    ignore_eos: bool = field(
+        default=False,
+        metadata={"help": "Ignore eos token"},
+    )
+    sentence_avg: bool = II("optimization.sentence_avg")
+    drop_worst_ratio: float = field(
+        default=0.0,
+        metadata={"help": "ratio for discarding bad samples"},
+    )
+    drop_worst_after: int = field(
+        default=0,
+        metadata={"help": "steps for discarding bad samples"},
+    )
+    use_rdrop: bool = field(
+        default=False, metadata={"help": "use R-Drop"}
+    )
+    reg_alpha: float = field(
+        default=1.0, metadata={"help": "weight for R-Drop"}
+    )
+    sample_patch_num: int = field(
+        default=196, metadata={"help": "sample patches for v1"}
+    )
+    constraint_range: Optional[str] = field(
+        default=None,
+        metadata={"help": "constraint range"}
+    )
+
+
+def construct_rdrop_sample(x):
+    if isinstance(x, dict):
+        for key in x:
+            x[key] = construct_rdrop_sample(x[key])
+        return x
+    elif isinstance(x, torch.Tensor):
+        return x.repeat(2, *([1] * (x.dim()-1)))
+    elif isinstance(x, int):
+        return x * 2
+    elif isinstance(x, np.ndarray):
+        return x.repeat(2)
+    else:
+        raise NotImplementedError
+
+
+def kl_loss(p, q):
+    p_loss = F.kl_div(p, torch.exp(q), reduction='sum')
+    q_loss = F.kl_div(q, torch.exp(p), reduction='sum')
+    loss = (p_loss + q_loss) / 2
+    return loss
+
+
+def label_smoothed_nll_loss(
+        lprobs, target, epsilon, update_num, reduce=True,
+        drop_worst_ratio=0.0, drop_worst_after=0, use_rdrop=False, reg_alpha=1.0,
+        constraint_masks=None, constraint_start=None, constraint_end=None
+):
+    if target.dim() == lprobs.dim() - 1:
+        target = target.unsqueeze(-1)
+    nll_loss = -lprobs.gather(dim=-1, index=target).squeeze(-1)
+    if constraint_masks is not None:
+        smooth_loss = -lprobs.masked_fill(~constraint_masks, 0).sum(dim=-1, keepdim=True).squeeze(-1)
+        eps_i = epsilon / (constraint_masks.sum(1) - 1 + 1e-6)
+    elif constraint_start is not None and constraint_end is not None:
+        constraint_range = [0, 1, 2, 3] + list(range(constraint_start, constraint_end))
+        smooth_loss = -lprobs[:, constraint_range].sum(dim=-1, keepdim=True).squeeze(-1)
+        eps_i = epsilon / (len(constraint_range) - 1 + 1e-6)
+    else:
+        smooth_loss = -lprobs.sum(dim=-1, keepdim=True).squeeze(-1)
+        eps_i = epsilon / (lprobs.size(-1) - 1)
+    loss = (1.0 - epsilon - eps_i) * nll_loss + eps_i * smooth_loss
+    if drop_worst_ratio > 0 and update_num > drop_worst_after:
+        if use_rdrop:
+            true_batch_size = loss.size(0) // 2
+            _, indices = torch.topk(loss[:true_batch_size], k=int(true_batch_size * (1 - drop_worst_ratio)), largest=False)
+            loss = torch.cat([loss[indices], loss[indices+true_batch_size]])
+            nll_loss = torch.cat([nll_loss[indices], nll_loss[indices+true_batch_size]])
+            lprobs = torch.cat([lprobs[indices], lprobs[indices+true_batch_size]])
+        else:
+            loss, indices = torch.topk(loss, k=int(loss.shape[0] * (1 - drop_worst_ratio)), largest=False)
+            nll_loss = nll_loss[indices]
+            lprobs = lprobs[indices]
+
+    ntokens = loss.numel()
+    nll_loss = nll_loss.sum()
+    loss = loss.sum()
+    if use_rdrop:
+        true_batch_size = lprobs.size(0) // 2
+        p = lprobs[:true_batch_size]
+        q = lprobs[true_batch_size:]
+        if constraint_start is not None and constraint_end is not None:
+            constraint_range = [0, 1, 2, 3] + list(range(constraint_start, constraint_end))
+            p = p[:, constraint_range]
+            q = q[:, constraint_range]
+        loss += kl_loss(p, q) * reg_alpha
+
+    return loss, nll_loss, ntokens
+
+
+@register_criterion(
+    "ajust_label_smoothed_cross_entropy", dataclass=AjustLabelSmoothedCrossEntropyCriterionConfig
+)
+class AjustLabelSmoothedCrossEntropyCriterion(FairseqCriterion):
+    def __init__(
+        self,
+        task,
+        sentence_avg,
+        label_smoothing,
+        ignore_prefix_size=0,
+        ignore_eos=False,
+        report_accuracy=False,
+        drop_worst_ratio=0,
+        drop_worst_after=0,
+        use_rdrop=False,
+        reg_alpha=1.0,
+        sample_patch_num=196,
+        constraint_range=None
+    ):
+        super().__init__(task)
+        self.sentence_avg = sentence_avg
+        self.eps = label_smoothing
+        self.ignore_prefix_size = ignore_prefix_size
+        self.ignore_eos = ignore_eos
+        self.report_accuracy = report_accuracy
+        self.drop_worst_ratio = drop_worst_ratio
+        self.drop_worst_after = drop_worst_after
+        self.use_rdrop = use_rdrop
+        self.reg_alpha = reg_alpha
+        self.sample_patch_num = sample_patch_num
+
+        self.constraint_start = None
+        self.constraint_end = None
+        if constraint_range is not None:
+            constraint_start, constraint_end = constraint_range.split(',')
+            self.constraint_start = int(constraint_start)
+            self.constraint_end = int(constraint_end)
+
+    def forward(self, model, sample, update_num=0, reduce=True):
+        """Compute the loss for the given sample.
+
+        Returns a tuple with three elements:
+        1) the loss
+        2) the sample size, which is used as the denominator for the gradient
+        3) logging outputs to display while training
+        """
+        if isinstance(sample, list):
+            if self.sample_patch_num > 0:
+                sample[0]['net_input']['sample_patch_num'] = self.sample_patch_num
+            loss_v1, sample_size_v1, logging_output_v1 = self.forward(model, sample[0], update_num, reduce)
+            loss_v2, sample_size_v2, logging_output_v2 = self.forward(model, sample[1], update_num, reduce)
+            loss = loss_v1 / sample_size_v1 + loss_v2 / sample_size_v2
+            sample_size = 1
+            logging_output = {
+                "loss": loss.data,
+                "loss_v1": loss_v1.data,
+                "loss_v2": loss_v2.data,
+                "nll_loss": logging_output_v1["nll_loss"].data / sample_size_v1 + logging_output_v2["nll_loss"].data / sample_size_v2,
+                "ntokens": logging_output_v1["ntokens"] + logging_output_v2["ntokens"],
+                "nsentences": logging_output_v1["nsentences"] + logging_output_v2["nsentences"],
+                "sample_size": 1,
+                "sample_size_v1": sample_size_v1,
+                "sample_size_v2": sample_size_v2,
+            }
+            return loss, sample_size, logging_output
+
+        if self.use_rdrop:
+            construct_rdrop_sample(sample)
+
+        net_output = model(**sample["net_input"])
+        loss, nll_loss, ntokens = self.compute_loss(model, net_output, sample, update_num, reduce=reduce)
+        sample_size = (
+            sample["target"].size(0) if self.sentence_avg else ntokens
+        )
+        logging_output = {
+            "loss": loss.data,
+            "nll_loss": nll_loss.data,
+            "ntokens": sample["ntokens"],
+            "nsentences": sample["nsentences"],
+            "sample_size": sample_size,
+        }
+        if self.report_accuracy:
+            n_correct, total = self.compute_accuracy(model, net_output, sample)
+            logging_output["n_correct"] = utils.item(n_correct.data)
+            logging_output["total"] = utils.item(total.data)
+        return loss, sample_size, logging_output
+
+    def get_lprobs_and_target(self, model, net_output, sample):
+        conf = sample['conf'][:, None, None] if 'conf' in sample and sample['conf'] is not None else 1
+        constraint_masks = None
+        if "constraint_masks" in sample and sample["constraint_masks"] is not None:
+            constraint_masks = sample["constraint_masks"]
+            net_output[0].masked_fill_(~constraint_masks, -math.inf)
+        if self.constraint_start is not None and self.constraint_end is not None:
+            net_output[0][:, :, 4:self.constraint_start] = -math.inf
+            net_output[0][:, :, self.constraint_end:] = -math.inf
+        lprobs = model.get_normalized_probs(net_output, log_probs=True) * conf
+        target = model.get_targets(sample, net_output)
+        if self.ignore_prefix_size > 0:
+            lprobs = lprobs[:, self.ignore_prefix_size :, :].contiguous()
+            target = target[:, self.ignore_prefix_size :].contiguous()
+            if constraint_masks is not None:
+                constraint_masks = constraint_masks[:, self.ignore_prefix_size :, :].contiguous()
+        if self.ignore_eos:
+            bsz, seq_len, embed_dim = lprobs.size()
+            eos_indices = target.eq(self.task.tgt_dict.eos())
+            lprobs = lprobs[~eos_indices].reshape(bsz, seq_len-1, embed_dim)
+            target = target[~eos_indices].reshape(bsz, seq_len-1)
+            if constraint_masks is not None:
+                constraint_masks = constraint_masks[~eos_indices].reshape(bsz, seq_len-1, embed_dim)
+        if constraint_masks is not None:
+            constraint_masks = constraint_masks.view(-1, constraint_masks.size(-1))
+        return lprobs.view(-1, lprobs.size(-1)), target.view(-1), constraint_masks
+
+    def compute_loss(self, model, net_output, sample, update_num, reduce=True):
+        lprobs, target, constraint_masks = self.get_lprobs_and_target(model, net_output, sample)
+        if constraint_masks is not None:
+            constraint_masks = constraint_masks[target != self.padding_idx]
+        lprobs = lprobs[target != self.padding_idx]
+        target = target[target != self.padding_idx]
+        loss, nll_loss, ntokens = label_smoothed_nll_loss(
+            lprobs,
+            target,
+            self.eps,
+            update_num,
+            reduce=reduce,
+            drop_worst_ratio=self.drop_worst_ratio,
+            drop_worst_after=self.drop_worst_after,
+            use_rdrop=self.use_rdrop,
+            reg_alpha=self.reg_alpha,
+            constraint_masks=constraint_masks,
+            constraint_start=self.constraint_start,
+            constraint_end=self.constraint_end
+        )
+        return loss, nll_loss, ntokens
+
+    def compute_accuracy(self, model, net_output, sample):
+        lprobs, target = self.get_lprobs_and_target(model, net_output, sample)
+        mask = target.ne(self.padding_idx)
+        n_correct = torch.sum(
+            lprobs.argmax(1).masked_select(mask).eq(target.masked_select(mask))
+        )
+        total = torch.sum(mask)
+        return n_correct, total
+
+    @classmethod
+    def reduce_metrics(cls, logging_outputs) -> None:
+        """Aggregate logging outputs from data parallel training."""
+        loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
+        loss_sum_v1 = sum(log.get("loss_v1", 0) for log in logging_outputs)
+        loss_sum_v2 = sum(log.get("loss_v2", 0) for log in logging_outputs)
+        nll_loss_sum = sum(log.get("nll_loss", 0) for log in logging_outputs)
+        ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
+        nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
+        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
+        sample_size_v1 = sum(log.get("sample_size_v1", 0) for log in logging_outputs)
+        sample_size_v2 = sum(log.get("sample_size_v2", 0) for log in logging_outputs)
+
+        metrics.log_scalar(
+            "loss", loss_sum / sample_size, sample_size, round=3
+        )
+        metrics.log_scalar(
+            "loss_v1", loss_sum_v1 / max(sample_size_v1, 1), max(sample_size_v1, 1), round=3
+        )
+        metrics.log_scalar(
+            "loss_v2", loss_sum_v2 / max(sample_size_v2, 1), max(sample_size_v2, 1), round=3
+        )
+        metrics.log_scalar(
+            "nll_loss", nll_loss_sum / sample_size, ntokens, round=3
+        )
+        metrics.log_derived(
+            "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
+        )
+
+        metrics.log_scalar(
+            "ntokens", ntokens, 1, round=3
+        )
+        metrics.log_scalar(
+            "nsentences", nsentences, 1, round=3
+        )
+        metrics.log_scalar(
+            "sample_size", sample_size, 1, round=3
+        )
+        metrics.log_scalar(
+            "sample_size_v1", sample_size_v1, 1, round=3
+        )
+        metrics.log_scalar(
+            "sample_size_v2", sample_size_v2, 1, round=3
+        )
+
+        total = utils.item(sum(log.get("total", 0) for log in logging_outputs))
+        if total > 0:
+            metrics.log_scalar("total", total)
+            n_correct = utils.item(
+                sum(log.get("n_correct", 0) for log in logging_outputs)
+            )
+            metrics.log_scalar("n_correct", n_correct)
+            metrics.log_derived(
+                "accuracy",
+                lambda meters: round(
+                    meters["n_correct"].sum * 100.0 / meters["total"].sum, 3
+                )
+                if meters["total"].sum > 0
+                else float("nan"),
+            )
+
+    @staticmethod
+    def logging_outputs_can_be_summed() -> bool:
+        """
+        Whether the logging outputs returned by `forward` can be summed
+        across workers prior to calling `reduce_metrics`. Setting this
+        to True improves distributed training speed.
+        """
+        return True
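For reference, a small self-contained sketch of the core computation that `label_smoothed_nll_loss` above performs in its default branch (no constraint masks, no R-Drop, no worst-sample dropping). The toy batch size, vocabulary size, and epsilon are made up for illustration.

```python
# Standalone illustration of the default branch of label_smoothed_nll_loss:
# loss = (1 - eps - eps_i) * nll + eps_i * smooth, with eps_i = eps / (V - 1).
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, vocab, eps = 4, 10, 0.1                      # toy sizes, illustrative only
logits = torch.randn(batch, vocab)
target = torch.randint(vocab, (batch,))

lprobs = F.log_softmax(logits, dim=-1)
nll_loss = -lprobs.gather(dim=-1, index=target.unsqueeze(-1)).squeeze(-1)
smooth_loss = -lprobs.sum(dim=-1)                   # log-prob mass summed over the whole vocab
eps_i = eps / (vocab - 1)
loss = ((1.0 - eps - eps_i) * nll_loss + eps_i * smooth_loss).sum()

# Cross-check against the textbook formulation: NLL against a smoothed
# one-hot target with (1 - eps) on the gold token and eps/(V-1) elsewhere.
smoothed = torch.full((batch, vocab), eps / (vocab - 1))
smoothed.scatter_(1, target.unsqueeze(-1), 1.0 - eps)
reference = -(smoothed * lprobs).sum()
assert torch.allclose(loss, reference, atol=1e-5)
```

The two forms are algebraically identical, which is why the criterion can compute the smoothed loss from just the gathered NLL term and the row-sum of log-probabilities.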
criterions/scst_loss.py
ADDED
@@ -0,0 +1,280 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+import string
+from dataclasses import dataclass, field
+from collections import OrderedDict
+from typing import Optional
+
+import torch
+from fairseq import metrics, utils
+from fairseq.criterions import FairseqCriterion, register_criterion
+from fairseq.dataclass import FairseqDataclass
+from omegaconf import II
+
+from data import data_utils
+from utils.cider.pyciderevalcap.ciderD.ciderD import CiderD
+
+
+def scst_loss(lprobs, target, reward, ignore_index=None, reduce=True):
+    loss = -lprobs.gather(dim=-1, index=target.unsqueeze(-1)).squeeze() * reward.unsqueeze(-1)
+    if ignore_index is not None:
+        pad_mask = target.eq(ignore_index)
+        loss.masked_fill_(pad_mask, 0.0)
+        ntokens = (~pad_mask).sum()
+    else:
+        loss = loss.squeeze(-1)
+        ntokens = target.numel()
+    if reduce:
+        loss = loss.sum()
+    return loss, ntokens
+
+@dataclass
+class ScstRewardCriterionConfig(FairseqDataclass):
+    scst_cider_cached_tokens: str = field(
+        default="coco-train-words.p",
+        metadata={"help": "path to cached cPickle file used to calculate CIDEr scores"},
+    )
+    ignore_prefix_size: int = field(
+        default=0,
+        metadata={"help": "Ignore first N tokens"},
+    )
+    sentence_avg: bool = II("optimization.sentence_avg")
+    constraint_range: Optional[str] = field(
+        default=None,
+        metadata={"help": "constraint range"}
+    )
+
+
+@register_criterion(
+    "scst_reward_criterion", dataclass=ScstRewardCriterionConfig
+)
+class ScstRewardCriterion(FairseqCriterion):
+    CIDER_REWARD_WEIGHT = 1
+
+    def __init__(
+        self,
+        task,
+        scst_cider_cached_tokens,
+        sentence_avg,
+        ignore_prefix_size=0,
+        constraint_range=None
+    ):
+        super().__init__(task)
+        self.scst_cider_scorer = CiderD(df=scst_cider_cached_tokens)
+        self.sentence_avg = sentence_avg
+        self.ignore_prefix_size = ignore_prefix_size
+        self.transtab = str.maketrans({key: None for key in string.punctuation})
+
+        self.constraint_start = None
+        self.constraint_end = None
+        if constraint_range is not None:
+            constraint_start, constraint_end = constraint_range.split(',')
+            self.constraint_start = int(constraint_start)
+            self.constraint_end = int(constraint_end)
+
+    def forward(self, model, sample, reduce=True):
+        """Compute the loss for the given sample.
+
+        Returns a tuple with three elements:
+        1) the loss
+        2) the sample size, which is used as the denominator for the gradient
+        3) logging outputs to display while training
+        """
+        loss, score, ntokens, nsentences = self.compute_loss(model, sample, reduce=reduce)
+
+        sample_size = (
+            nsentences if self.sentence_avg else ntokens
+        )
+        logging_output = {
+            "loss": loss.data,
+            "score": score,
+            "ntokens": ntokens,
+            "nsentences": nsentences,
+            "sample_size": sample_size,
+        }
+        return loss, sample_size, logging_output
+
+    def _calculate_eval_scores(self, gen_res, gt_idx, gt_res):
+        '''
+        gen_res: generated captions, list of str
+        gt_idx: list of int, of the same length as gen_res
+        gt_res: ground truth captions, list of list of str.
+            gen_res[i] corresponds to gt_res[gt_idx[i]]
+            Each image can have multiple ground truth captions
+        '''
+        gen_res_size = len(gen_res)
+
+        res = OrderedDict()
+        for i in range(gen_res_size):
+            res[i] = [self._wrap_sentence(gen_res[i].strip().translate(self.transtab))]
+
+        gts = OrderedDict()
+        gt_res_ = [
+            [self._wrap_sentence(gt_res[i][j].strip().translate(self.transtab)) for j in range(len(gt_res[i]))]
+            for i in range(len(gt_res))
+        ]
+        for i in range(gen_res_size):
+            gts[i] = gt_res_[gt_idx[i]]
+
+        res_ = [{'image_id':i, 'caption': res[i]} for i in range(len(res))]
+        _, batch_cider_scores = self.scst_cider_scorer.compute_score(gts, res_)
+        scores = self.CIDER_REWARD_WEIGHT * batch_cider_scores
+        return scores
+
+    @classmethod
+    def _wrap_sentence(self, s):
+        # ensure the sentence ends with the <eos> token
+        # in order to keep consistent with cider_cached_tokens
+        r = s.strip()
+        if r.endswith('.'):
+            r = r[:-1]
+        r += ' <eos>'
+        return r
+
+    def get_generator_out(self, model, sample):
+        def decode(toks):
+            hypo = toks.int().cpu()
+            hypo_str = self.task.tgt_dict.string(hypo)
+            hypo_str = self.task.bpe.decode(hypo_str).strip()
+            return hypo, hypo_str
+
+        model.eval()
+        with torch.no_grad():
+            self.task.scst_generator.model.eval()
+            gen_out = self.task.scst_generator.generate([model], sample)
+
+        gen_target = []
+        gen_res = []
+        gt_res = []
+        for i in range(len(gen_out)):
+            for j in range(len(gen_out[i])):
+                hypo, hypo_str = decode(gen_out[i][j]["tokens"])
+                gen_target.append(hypo)
+                gen_res.append(hypo_str)
+            gt_res.append(
+                decode(utils.strip_pad(sample["target"][i], self.padding_idx))[1].split('&&')
+            )
+
+        return gen_target, gen_res, gt_res
+
+    def get_reward_and_scores(self, gen_res, gt_res, device):
+        batch_size = len(gt_res)
+        gen_res_size = len(gen_res)
+        seq_per_img = gen_res_size // batch_size
+
+        gt_idx = [i // seq_per_img for i in range(gen_res_size)]
+        scores = self._calculate_eval_scores(gen_res, gt_idx, gt_res)
+        sc_ = scores.reshape(batch_size, seq_per_img)
+        baseline = (sc_.sum(1, keepdims=True) - sc_) / (sc_.shape[1] - 1)
+        # sample - baseline
+        reward = scores.reshape(batch_size, seq_per_img)
+        reward = reward - baseline
+        reward = reward.reshape(gen_res_size)
+        reward = torch.as_tensor(reward, device=device, dtype=torch.float64)
+
+        return reward, scores
+
+    def get_net_output(self, model, sample, gen_target):
+        def merge(sample_list, eos=self.task.tgt_dict.eos(), move_eos_to_beginning=False):
+            return data_utils.collate_tokens(
+                sample_list,
+                pad_idx=self.padding_idx,
+                eos_idx=eos,
+                left_pad=False,
+                move_eos_to_beginning=move_eos_to_beginning,
+            )
+
+        batch_size = len(sample["target"])
+        gen_target_size = len(gen_target)
+        seq_per_img = gen_target_size // batch_size
+
+        model.train()
+        sample_src_tokens = torch.repeat_interleave(
+            sample['net_input']['src_tokens'], seq_per_img, dim=0
+        )
+        sample_src_lengths = torch.repeat_interleave(
+            sample['net_input']['src_lengths'], seq_per_img, dim=0
+        )
+        sample_patch_images = torch.repeat_interleave(
+            sample['net_input']['patch_images'], seq_per_img, dim=0
+        )
+        sample_patch_masks = torch.repeat_interleave(
+            sample['net_input']['patch_masks'], seq_per_img, dim=0
+        )
+        gen_prev_output_tokens = torch.as_tensor(
+            merge(gen_target, eos=self.task.tgt_dict.bos(), move_eos_to_beginning=True),
+            device=sample["target"].device, dtype=torch.int64
+        )
+        gen_target_tokens = torch.as_tensor(
+            merge(gen_target), device=sample["target"].device, dtype=torch.int64
+        )
+        net_output = model(
+            src_tokens=sample_src_tokens, src_lengths=sample_src_lengths,
+            patch_images=sample_patch_images, patch_masks=sample_patch_masks,
+            prev_output_tokens=gen_prev_output_tokens
+        )
+
+        return net_output, gen_target_tokens
+
+    def get_lprobs_and_target(self, model, net_output, gen_target):
+        if self.constraint_start is not None and self.constraint_end is not None:
+            net_output[0][:, :, 4:self.constraint_start] = -math.inf
+            net_output[0][:, :, self.constraint_end:] = -math.inf
+        lprobs = model.get_normalized_probs(net_output, log_probs=True)
+        if self.ignore_prefix_size > 0:
+            if getattr(lprobs, "batch_first", False):
+                lprobs = lprobs[:, self.ignore_prefix_size :, :].contiguous()
+                gen_target = gen_target[:, self.ignore_prefix_size :].contiguous()
+            else:
+                lprobs = lprobs[self.ignore_prefix_size :, :, :].contiguous()
+                gen_target = gen_target[self.ignore_prefix_size :, :].contiguous()
+        return lprobs, gen_target
+
+    def compute_loss(self, model, sample, reduce=True):
+        gen_target, gen_res, gt_res = self.get_generator_out(model, sample)
+        reward, scores = self.get_reward_and_scores(gen_res, gt_res, device=sample["target"].device)
+        net_output, gen_target_tokens = self.get_net_output(model, sample, gen_target)
+        gen_lprobs, gen_target_tokens = self.get_lprobs_and_target(model, net_output, gen_target_tokens)
+        loss, ntokens = scst_loss(gen_lprobs, gen_target_tokens, reward, ignore_index=self.padding_idx, reduce=reduce)
+        nsentences = gen_target_tokens.size(0)
+
+        return loss, scores.sum(), ntokens, nsentences
+
+    @classmethod
+    def reduce_metrics(cls, logging_outputs) -> None:
+        """Aggregate logging outputs from data parallel training."""
+        loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
+        score_sum = sum(log.get("score", 0) for log in logging_outputs)
+        ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
+        nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
+        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
+
+        metrics.log_scalar(
+            "loss", loss_sum / sample_size, sample_size, round=3
+        )
+        metrics.log_scalar(
+            "score", score_sum / nsentences, nsentences, round=3
+        )
+
+        metrics.log_scalar(
+            "ntokens", ntokens, 1, round=3
+        )
+        metrics.log_scalar(
+            "nsentences", nsentences, 1, round=3
+        )
+        metrics.log_scalar(
+            "sample_size", sample_size, 1, round=3
+        )
+
+    @staticmethod
+    def logging_outputs_can_be_summed() -> bool:
+        """
+        Whether the logging outputs returned by `forward` can be summed
+        across workers prior to calling `reduce_metrics`. Setting this
+        to True improves distributed training speed.
+        """
+        return True
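For orientation, a small sketch of the self-critical baseline used in `get_reward_and_scores` above: each sampled caption's score is baselined against the mean score of the other samples drawn for the same image. The random numbers stand in for per-caption CIDEr-D scores; the toy sizes are illustrative only.

```python
# Standalone illustration of the leave-one-out baseline used above:
# baseline_j = mean of the other samples' scores for the same image,
# reward_j = score_j - baseline_j.
import numpy as np

rng = np.random.default_rng(0)
batch_size, seq_per_img = 2, 5                      # toy sizes, illustrative only
scores = rng.random(batch_size * seq_per_img)       # stand-in for per-caption CIDEr-D scores

sc_ = scores.reshape(batch_size, seq_per_img)
baseline = (sc_.sum(1, keepdims=True) - sc_) / (sc_.shape[1] - 1)
reward = (sc_ - baseline).reshape(batch_size * seq_per_img)

# The rewards for each image sum to ~0, so captions scoring above the
# group average are reinforced and those below it are penalized.
print(reward.reshape(batch_size, seq_per_img).sum(1))
```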
data/__init__.py
ADDED
File without changes
data/data_utils.py
ADDED
@@ -0,0 +1,601 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+try:
+    from collections.abc import Iterable
+except ImportError:
+    from collections import Iterable
+import contextlib
+import itertools
+import logging
+import re
+import warnings
+from typing import Optional, Tuple
+
+import numpy as np
+import torch
+
+from fairseq.file_io import PathManager
+from fairseq import utils
+import os
+
+logger = logging.getLogger(__name__)
+
+
+def infer_language_pair(path):
+    """Infer language pair from filename: <split>.<lang1>-<lang2>.(...).idx"""
+    src, dst = None, None
+    for filename in PathManager.ls(path):
+        parts = filename.split(".")
+        if len(parts) >= 3 and len(parts[1].split("-")) == 2:
+            return parts[1].split("-")
+    return src, dst
+
+
+def collate_tokens(
+    values,
+    pad_idx,
+    eos_idx=None,
+    left_pad=False,
+    move_eos_to_beginning=False,
+    pad_to_length=None,
+    pad_to_multiple=1,
+    pad_to_bsz=None,
+):
+    """Convert a list of 1d tensors into a padded 2d tensor."""
+    size = max(v.size(0) for v in values)
+    size = size if pad_to_length is None else max(size, pad_to_length)
+    if pad_to_multiple != 1 and size % pad_to_multiple != 0:
+        size = int(((size - 0.1) // pad_to_multiple + 1) * pad_to_multiple)
+
+    def copy_tensor(src, dst):
+        assert dst.numel() == src.numel()
+        if move_eos_to_beginning:
+            if eos_idx is None:
+                # if no eos_idx is specified, then use the last token in src
+                dst[0] = src[-1]
+            else:
+                dst[0] = eos_idx
+            dst[1:] = src[:-1]
+        else:
+            dst.copy_(src)
+
+    if values[0].dim() == 1:
+        res = values[0].new(len(values), size).fill_(pad_idx)
+    elif values[0].dim() == 2:
+        assert move_eos_to_beginning is False
+        res = values[0].new(len(values), size, values[0].size(1)).fill_(pad_idx)
+    else:
+        raise NotImplementedError
+
+    for i, v in enumerate(values):
+        copy_tensor(v, res[i][size - len(v) :] if left_pad else res[i][: len(v)])
+    return res
+
+
+def load_indexed_dataset(
+    path, dictionary=None, dataset_impl=None, combine=False, default="cached"
+):
+    """A helper function for loading indexed datasets.
+
+    Args:
+        path (str): path to indexed dataset (e.g., 'data-bin/train')
+        dictionary (~fairseq.data.Dictionary): data dictionary
+        dataset_impl (str, optional): which dataset implementation to use. If
+            not provided, it will be inferred automatically. For legacy indexed
+            data we use the 'cached' implementation by default.
+        combine (bool, optional): automatically load and combine multiple
+            datasets. For example, if *path* is 'data-bin/train', then we will
+            combine 'data-bin/train', 'data-bin/train1', ... and return a
+            single ConcatDataset instance.
+    """
+    import fairseq.data.indexed_dataset as indexed_dataset
+    from fairseq.data.concat_dataset import ConcatDataset
+
+    datasets = []
+    for k in itertools.count():
+        path_k = path + (str(k) if k > 0 else "")
+        try:
+            path_k = indexed_dataset.get_indexed_dataset_to_local(path_k)
+        except Exception as e:
+            if "StorageException: [404] Path not found" in str(e):
+                logger.warning(f"path_k: {e} not found")
+            else:
+                raise e
+
+        dataset_impl_k = dataset_impl
+        if dataset_impl_k is None:
+            dataset_impl_k = indexed_dataset.infer_dataset_impl(path_k)
+        dataset = indexed_dataset.make_dataset(
+            path_k,
+            impl=dataset_impl_k or default,
+            fix_lua_indexing=True,
+            dictionary=dictionary,
+        )
+        if dataset is None:
+            break
+        logger.info("loaded {:,} examples from: {}".format(len(dataset), path_k))
+        datasets.append(dataset)
+        if not combine:
+            break
+    if len(datasets) == 0:
+        return None
+    elif len(datasets) == 1:
+        return datasets[0]
+    else:
+        return ConcatDataset(datasets)
+
+
+@contextlib.contextmanager
+def numpy_seed(seed, *addl_seeds):
+    """Context manager which seeds the NumPy PRNG with the specified seed and
+    restores the state afterward"""
+    if seed is None:
+        yield
+        return
+    if len(addl_seeds) > 0:
+        seed = int(hash((seed, *addl_seeds)) % 1e6)
+    state = np.random.get_state()
+    np.random.seed(seed)
+    try:
+        yield
+    finally:
+        np.random.set_state(state)
+
+
+def collect_filtered(function, iterable, filtered):
+    """
+    Similar to :func:`filter` but collects filtered elements in ``filtered``.
+
+    Args:
+        function (callable): function that returns ``False`` for elements that
+            should be filtered
+        iterable (iterable): iterable to filter
+        filtered (list): list to store filtered elements
+    """
+    for el in iterable:
+        if function(el):
+            yield el
+        else:
+            filtered.append(el)
+
+
+def _filter_by_size_dynamic(indices, size_fn, max_positions, raise_exception=False):
+    def compare_leq(a, b):
+        return a <= b if not isinstance(a, tuple) else max(a) <= b
+
+    def check_size(idx):
+        if isinstance(max_positions, float) or isinstance(max_positions, int):
+            return size_fn(idx) <= max_positions
+        elif isinstance(max_positions, dict):
+            idx_size = size_fn(idx)
+            assert isinstance(idx_size, dict)
+            intersect_keys = set(max_positions.keys()) & set(idx_size.keys())
+            return all(
+                all(
+                    a is None or b is None or a <= b
+                    for a, b in zip(idx_size[key], max_positions[key])
+                )
+                for key in intersect_keys
+            )
+        else:
+            # For MultiCorpusSampledDataset, will generalize it later
+            if not isinstance(size_fn(idx), Iterable):
+                return all(size_fn(idx) <= b for b in max_positions)
+            return all(
+                a is None or b is None or a <= b
+                for a, b in zip(size_fn(idx), max_positions)
+            )
+
+    ignored = []
+    itr = collect_filtered(check_size, indices, ignored)
+    indices = np.fromiter(itr, dtype=np.int64, count=-1)
+    return indices, ignored
+
+
+def filter_by_size(indices, dataset, max_positions, raise_exception=False):
+    """
+    [deprecated] Filter indices based on their size.
+    Use `FairseqDataset::filter_indices_by_size` instead.
+
+    Args:
+        indices (List[int]): ordered list of dataset indices
+        dataset (FairseqDataset): fairseq dataset instance
+        max_positions (tuple): filter elements larger than this size.
+            Comparisons are done component-wise.
+        raise_exception (bool, optional): if ``True``, raise an exception if
+            any elements are filtered (default: False).
+    """
+    warnings.warn(
+        "data_utils.filter_by_size is deprecated. "
|
| 213 |
+
"Use `FairseqDataset::filter_indices_by_size` instead.",
|
| 214 |
+
stacklevel=2,
|
| 215 |
+
)
|
| 216 |
+
if isinstance(max_positions, float) or isinstance(max_positions, int):
|
| 217 |
+
if hasattr(dataset, "sizes") and isinstance(dataset.sizes, np.ndarray):
|
| 218 |
+
ignored = indices[dataset.sizes[indices] > max_positions].tolist()
|
| 219 |
+
indices = indices[dataset.sizes[indices] <= max_positions]
|
| 220 |
+
elif (
|
| 221 |
+
hasattr(dataset, "sizes")
|
| 222 |
+
and isinstance(dataset.sizes, list)
|
| 223 |
+
and len(dataset.sizes) == 1
|
| 224 |
+
):
|
| 225 |
+
ignored = indices[dataset.sizes[0][indices] > max_positions].tolist()
|
| 226 |
+
indices = indices[dataset.sizes[0][indices] <= max_positions]
|
| 227 |
+
else:
|
| 228 |
+
indices, ignored = _filter_by_size_dynamic(
|
| 229 |
+
indices, dataset.size, max_positions
|
| 230 |
+
)
|
| 231 |
+
else:
|
| 232 |
+
indices, ignored = _filter_by_size_dynamic(indices, dataset.size, max_positions)
|
| 233 |
+
|
| 234 |
+
if len(ignored) > 0 and raise_exception:
|
| 235 |
+
raise Exception(
|
| 236 |
+
(
|
| 237 |
+
"Size of sample #{} is invalid (={}) since max_positions={}, "
|
| 238 |
+
"skip this example with --skip-invalid-size-inputs-valid-test"
|
| 239 |
+
).format(ignored[0], dataset.size(ignored[0]), max_positions)
|
| 240 |
+
)
|
| 241 |
+
if len(ignored) > 0:
|
| 242 |
+
logger.warning(
|
| 243 |
+
(
|
| 244 |
+
"{} samples have invalid sizes and will be skipped, "
|
| 245 |
+
"max_positions={}, first few sample ids={}"
|
| 246 |
+
).format(len(ignored), max_positions, ignored[:10])
|
| 247 |
+
)
|
| 248 |
+
return indices
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
def filter_paired_dataset_indices_by_size(src_sizes, tgt_sizes, indices, max_sizes):
|
| 252 |
+
"""Filter a list of sample indices. Remove those that are longer
|
| 253 |
+
than specified in max_sizes.
|
| 254 |
+
|
| 255 |
+
Args:
|
| 256 |
+
indices (np.array): original array of sample indices
|
| 257 |
+
max_sizes (int or list[int] or tuple[int]): max sample size,
|
| 258 |
+
can be defined separately for src and tgt (then list or tuple)
|
| 259 |
+
|
| 260 |
+
Returns:
|
| 261 |
+
np.array: filtered sample array
|
| 262 |
+
list: list of removed indices
|
| 263 |
+
"""
|
| 264 |
+
if max_sizes is None:
|
| 265 |
+
return indices, []
|
| 266 |
+
if type(max_sizes) in (int, float):
|
| 267 |
+
max_src_size, max_tgt_size = max_sizes, max_sizes
|
| 268 |
+
else:
|
| 269 |
+
max_src_size, max_tgt_size = max_sizes
|
| 270 |
+
if tgt_sizes is None:
|
| 271 |
+
ignored = indices[src_sizes[indices] > max_src_size]
|
| 272 |
+
else:
|
| 273 |
+
ignored = indices[
|
| 274 |
+
(src_sizes[indices] > max_src_size) | (tgt_sizes[indices] > max_tgt_size)
|
| 275 |
+
]
|
| 276 |
+
if len(ignored) > 0:
|
| 277 |
+
if tgt_sizes is None:
|
| 278 |
+
indices = indices[src_sizes[indices] <= max_src_size]
|
| 279 |
+
else:
|
| 280 |
+
indices = indices[
|
| 281 |
+
(src_sizes[indices] <= max_src_size)
|
| 282 |
+
& (tgt_sizes[indices] <= max_tgt_size)
|
| 283 |
+
]
|
| 284 |
+
return indices, ignored.tolist()
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
def batch_by_size(
|
| 288 |
+
indices,
|
| 289 |
+
num_tokens_fn,
|
| 290 |
+
num_tokens_vec=None,
|
| 291 |
+
max_tokens=None,
|
| 292 |
+
max_sentences=None,
|
| 293 |
+
required_batch_size_multiple=1,
|
| 294 |
+
fixed_shapes=None,
|
| 295 |
+
):
|
| 296 |
+
"""
|
| 297 |
+
Yield mini-batches of indices bucketed by size. Batches may contain
|
| 298 |
+
sequences of different lengths.
|
| 299 |
+
|
| 300 |
+
Args:
|
| 301 |
+
indices (List[int]): ordered list of dataset indices
|
| 302 |
+
num_tokens_fn (callable): function that returns the number of tokens at
|
| 303 |
+
a given index
|
| 304 |
+
num_tokens_vec (List[int], optional): precomputed vector of the number
|
| 305 |
+
of tokens for each index in indices (to enable faster batch generation)
|
| 306 |
+
max_tokens (int, optional): max number of tokens in each batch
|
| 307 |
+
(default: None).
|
| 308 |
+
max_sentences (int, optional): max number of sentences in each
|
| 309 |
+
batch (default: None).
|
| 310 |
+
required_batch_size_multiple (int, optional): require batch size to
|
| 311 |
+
be less than N or a multiple of N (default: 1).
|
| 312 |
+
fixed_shapes (List[Tuple[int, int]], optional): if given, batches will
|
| 313 |
+
only be created with the given shapes. *max_sentences* and
|
| 314 |
+
*required_batch_size_multiple* will be ignored (default: None).
|
| 315 |
+
"""
|
| 316 |
+
try:
|
| 317 |
+
from fairseq.data.data_utils_fast import (
|
| 318 |
+
batch_by_size_fn,
|
| 319 |
+
batch_by_size_vec,
|
| 320 |
+
batch_fixed_shapes_fast,
|
| 321 |
+
)
|
| 322 |
+
except ImportError:
|
| 323 |
+
raise ImportError(
|
| 324 |
+
"Please build Cython components with: "
|
| 325 |
+
"`python setup.py build_ext --inplace`"
|
| 326 |
+
)
|
| 327 |
+
except ValueError:
|
| 328 |
+
raise ValueError(
|
| 329 |
+
"Please build (or rebuild) Cython components with `python setup.py build_ext --inplace`."
|
| 330 |
+
)
|
| 331 |
+
|
| 332 |
+
# added int() to avoid TypeError: an integer is required
|
| 333 |
+
max_tokens = (
|
| 334 |
+
int(max_tokens) if max_tokens is not None else -1
|
| 335 |
+
)
|
| 336 |
+
max_sentences = max_sentences if max_sentences is not None else -1
|
| 337 |
+
bsz_mult = required_batch_size_multiple
|
| 338 |
+
|
| 339 |
+
if not isinstance(indices, np.ndarray):
|
| 340 |
+
indices = np.fromiter(indices, dtype=np.int64, count=-1)
|
| 341 |
+
|
| 342 |
+
if num_tokens_vec is not None and not isinstance(num_tokens_vec, np.ndarray):
|
| 343 |
+
num_tokens_vec = np.fromiter(num_tokens_vec, dtype=np.int64, count=-1)
|
| 344 |
+
|
| 345 |
+
if fixed_shapes is None:
|
| 346 |
+
if num_tokens_vec is None:
|
| 347 |
+
return batch_by_size_fn(
|
| 348 |
+
indices,
|
| 349 |
+
num_tokens_fn,
|
| 350 |
+
max_tokens,
|
| 351 |
+
max_sentences,
|
| 352 |
+
bsz_mult,
|
| 353 |
+
)
|
| 354 |
+
else:
|
| 355 |
+
return batch_by_size_vec(
|
| 356 |
+
indices,
|
| 357 |
+
num_tokens_vec,
|
| 358 |
+
max_tokens,
|
| 359 |
+
max_sentences,
|
| 360 |
+
bsz_mult,
|
| 361 |
+
)
|
| 362 |
+
|
| 363 |
+
else:
|
| 364 |
+
fixed_shapes = np.array(fixed_shapes, dtype=np.int64)
|
| 365 |
+
sort_order = np.lexsort(
|
| 366 |
+
[
|
| 367 |
+
fixed_shapes[:, 1].argsort(), # length
|
| 368 |
+
fixed_shapes[:, 0].argsort(), # bsz
|
| 369 |
+
]
|
| 370 |
+
)
|
| 371 |
+
fixed_shapes_sorted = fixed_shapes[sort_order]
|
| 372 |
+
return batch_fixed_shapes_fast(indices, num_tokens_fn, fixed_shapes_sorted)
|
| 373 |
+
|
| 374 |
+
|
| 375 |
+
def post_process(sentence: str, symbol: str):
|
| 376 |
+
if symbol == "sentencepiece":
|
| 377 |
+
sentence = sentence.replace(" ", "").replace("\u2581", " ").strip()
|
| 378 |
+
elif symbol == "wordpiece":
|
| 379 |
+
sentence = sentence.replace(" ", "").replace("_", " ").strip()
|
| 380 |
+
elif symbol == "letter":
|
| 381 |
+
sentence = sentence.replace(" ", "").replace("|", " ").strip()
|
| 382 |
+
elif symbol == "silence":
|
| 383 |
+
import re
|
| 384 |
+
sentence = sentence.replace("<SIL>", "")
|
| 385 |
+
sentence = re.sub(' +', ' ', sentence).strip()
|
| 386 |
+
elif symbol == "_EOW":
|
| 387 |
+
sentence = sentence.replace(" ", "").replace("_EOW", " ").strip()
|
| 388 |
+
elif symbol in {"subword_nmt", "@@ ", "@@"}:
|
| 389 |
+
if symbol == "subword_nmt":
|
| 390 |
+
symbol = "@@ "
|
| 391 |
+
sentence = (sentence + " ").replace(symbol, "").rstrip()
|
| 392 |
+
elif symbol == "none":
|
| 393 |
+
pass
|
| 394 |
+
elif symbol is not None:
|
| 395 |
+
raise NotImplementedError(f"Unknown post_process option: {symbol}")
|
| 396 |
+
return sentence
|
| 397 |
+
|
| 398 |
+
|
| 399 |
+
def compute_mask_indices(
|
| 400 |
+
shape: Tuple[int, int],
|
| 401 |
+
padding_mask: Optional[torch.Tensor],
|
| 402 |
+
mask_prob: float,
|
| 403 |
+
mask_length: int,
|
| 404 |
+
mask_type: str = "static",
|
| 405 |
+
mask_other: float = 0.0,
|
| 406 |
+
min_masks: int = 0,
|
| 407 |
+
no_overlap: bool = False,
|
| 408 |
+
min_space: int = 0,
|
| 409 |
+
) -> np.ndarray:
|
| 410 |
+
"""
|
| 411 |
+
Computes random mask spans for a given shape
|
| 412 |
+
|
| 413 |
+
Args:
|
| 414 |
+
shape: the the shape for which to compute masks.
|
| 415 |
+
should be of size 2 where first element is batch size and 2nd is timesteps
|
| 416 |
+
padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
|
| 417 |
+
mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by
|
| 418 |
+
number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
|
| 419 |
+
however due to overlaps, the actual number will be smaller (unless no_overlap is True)
|
| 420 |
+
mask_type: how to compute mask lengths
|
| 421 |
+
static = fixed size
|
| 422 |
+
uniform = sample from uniform distribution [mask_other, mask_length*2]
|
| 423 |
+
normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element
|
| 424 |
+
poisson = sample from possion distribution with lambda = mask length
|
| 425 |
+
min_masks: minimum number of masked spans
|
| 426 |
+
no_overlap: if false, will switch to an alternative recursive algorithm that prevents spans from overlapping
|
| 427 |
+
min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans
|
| 428 |
+
"""
|
| 429 |
+
|
| 430 |
+
bsz, all_sz = shape
|
| 431 |
+
mask = np.full((bsz, all_sz), False)
|
| 432 |
+
|
| 433 |
+
all_num_mask = int(
|
| 434 |
+
# add a random number for probabilistic rounding
|
| 435 |
+
mask_prob * all_sz / float(mask_length)
|
| 436 |
+
+ np.random.rand()
|
| 437 |
+
)
|
| 438 |
+
|
| 439 |
+
all_num_mask = max(min_masks, all_num_mask)
|
| 440 |
+
|
| 441 |
+
mask_idcs = []
|
| 442 |
+
for i in range(bsz):
|
| 443 |
+
if padding_mask is not None:
|
| 444 |
+
sz = all_sz - padding_mask[i].long().sum().item()
|
| 445 |
+
num_mask = int(
|
| 446 |
+
# add a random number for probabilistic rounding
|
| 447 |
+
mask_prob * sz / float(mask_length)
|
| 448 |
+
+ np.random.rand()
|
| 449 |
+
)
|
| 450 |
+
num_mask = max(min_masks, num_mask)
|
| 451 |
+
else:
|
| 452 |
+
sz = all_sz
|
| 453 |
+
num_mask = all_num_mask
|
| 454 |
+
|
| 455 |
+
if mask_type == "static":
|
| 456 |
+
lengths = np.full(num_mask, mask_length)
|
| 457 |
+
elif mask_type == "uniform":
|
| 458 |
+
lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask)
|
| 459 |
+
elif mask_type == "normal":
|
| 460 |
+
lengths = np.random.normal(mask_length, mask_other, size=num_mask)
|
| 461 |
+
lengths = [max(1, int(round(x))) for x in lengths]
|
| 462 |
+
elif mask_type == "poisson":
|
| 463 |
+
lengths = np.random.poisson(mask_length, size=num_mask)
|
| 464 |
+
lengths = [int(round(x)) for x in lengths]
|
| 465 |
+
else:
|
| 466 |
+
raise Exception("unknown mask selection " + mask_type)
|
| 467 |
+
|
| 468 |
+
if sum(lengths) == 0:
|
| 469 |
+
lengths[0] = min(mask_length, sz - 1)
|
| 470 |
+
|
| 471 |
+
if no_overlap:
|
| 472 |
+
mask_idc = []
|
| 473 |
+
|
| 474 |
+
def arrange(s, e, length, keep_length):
|
| 475 |
+
span_start = np.random.randint(s, e - length)
|
| 476 |
+
mask_idc.extend(span_start + i for i in range(length))
|
| 477 |
+
|
| 478 |
+
new_parts = []
|
| 479 |
+
if span_start - s - min_space >= keep_length:
|
| 480 |
+
new_parts.append((s, span_start - min_space + 1))
|
| 481 |
+
if e - span_start - keep_length - min_space > keep_length:
|
| 482 |
+
new_parts.append((span_start + length + min_space, e))
|
| 483 |
+
return new_parts
|
| 484 |
+
|
| 485 |
+
parts = [(0, sz)]
|
| 486 |
+
min_length = min(lengths)
|
| 487 |
+
for length in sorted(lengths, reverse=True):
|
| 488 |
+
lens = np.fromiter(
|
| 489 |
+
(e - s if e - s >= length + min_space else 0 for s, e in parts),
|
| 490 |
+
np.int,
|
| 491 |
+
)
|
| 492 |
+
l_sum = np.sum(lens)
|
| 493 |
+
if l_sum == 0:
|
| 494 |
+
break
|
| 495 |
+
probs = lens / np.sum(lens)
|
| 496 |
+
c = np.random.choice(len(parts), p=probs)
|
| 497 |
+
s, e = parts.pop(c)
|
| 498 |
+
parts.extend(arrange(s, e, length, min_length))
|
| 499 |
+
mask_idc = np.asarray(mask_idc)
|
| 500 |
+
else:
|
| 501 |
+
min_len = min(lengths)
|
| 502 |
+
if sz - min_len <= num_mask:
|
| 503 |
+
min_len = sz - num_mask - 1
|
| 504 |
+
|
| 505 |
+
mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)
|
| 506 |
+
|
| 507 |
+
mask_idc = np.asarray(
|
| 508 |
+
[
|
| 509 |
+
mask_idc[j] + offset
|
| 510 |
+
for j in range(len(mask_idc))
|
| 511 |
+
for offset in range(lengths[j])
|
| 512 |
+
]
|
| 513 |
+
)
|
| 514 |
+
|
| 515 |
+
mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))
|
| 516 |
+
|
| 517 |
+
min_len = min([len(m) for m in mask_idcs])
|
| 518 |
+
for i, mask_idc in enumerate(mask_idcs):
|
| 519 |
+
if len(mask_idc) > min_len:
|
| 520 |
+
mask_idc = np.random.choice(mask_idc, min_len, replace=False)
|
| 521 |
+
mask[i, mask_idc] = True
|
| 522 |
+
|
| 523 |
+
return mask
|
| 524 |
+
|
| 525 |
+
|
| 526 |
+
def get_mem_usage():
|
| 527 |
+
try:
|
| 528 |
+
import psutil
|
| 529 |
+
|
| 530 |
+
mb = 1024 * 1024
|
| 531 |
+
return f"used={psutil.virtual_memory().used / mb}Mb; avail={psutil.virtual_memory().available / mb}Mb"
|
| 532 |
+
except ImportError:
|
| 533 |
+
return "N/A"
|
| 534 |
+
|
| 535 |
+
|
| 536 |
+
# lens: torch.LongTensor
|
| 537 |
+
# returns: torch.BoolTensor
|
| 538 |
+
def lengths_to_padding_mask(lens):
|
| 539 |
+
bsz, max_lens = lens.size(0), torch.max(lens).item()
|
| 540 |
+
mask = torch.arange(max_lens).to(lens.device).view(1, max_lens)
|
| 541 |
+
mask = mask.expand(bsz, -1) >= lens.view(bsz, 1).expand(-1, max_lens)
|
| 542 |
+
return mask
|
| 543 |
+
|
| 544 |
+
|
| 545 |
+
# lens: torch.LongTensor
|
| 546 |
+
# returns: torch.BoolTensor
|
| 547 |
+
def lengths_to_mask(lens):
|
| 548 |
+
return ~lengths_to_padding_mask(lens)
|
| 549 |
+
|
| 550 |
+
|
| 551 |
+
def get_buckets(sizes, num_buckets):
|
| 552 |
+
buckets = np.unique(
|
| 553 |
+
np.percentile(
|
| 554 |
+
sizes,
|
| 555 |
+
np.linspace(0, 100, num_buckets + 1),
|
| 556 |
+
interpolation='lower',
|
| 557 |
+
)[1:]
|
| 558 |
+
)
|
| 559 |
+
return buckets
|
| 560 |
+
|
| 561 |
+
|
| 562 |
+
def get_bucketed_sizes(orig_sizes, buckets):
|
| 563 |
+
sizes = np.copy(orig_sizes)
|
| 564 |
+
assert np.min(sizes) >= 0
|
| 565 |
+
start_val = -1
|
| 566 |
+
for end_val in buckets:
|
| 567 |
+
mask = (sizes > start_val) & (sizes <= end_val)
|
| 568 |
+
sizes[mask] = end_val
|
| 569 |
+
start_val = end_val
|
| 570 |
+
return sizes
|
| 571 |
+
|
| 572 |
+
|
| 573 |
+
|
| 574 |
+
def _find_extra_valid_paths(dataset_path: str) -> set:
|
| 575 |
+
paths = utils.split_paths(dataset_path)
|
| 576 |
+
all_valid_paths = set()
|
| 577 |
+
for sub_dir in paths:
|
| 578 |
+
contents = PathManager.ls(sub_dir)
|
| 579 |
+
valid_paths = [c for c in contents if re.match("valid*[0-9].*", c) is not None]
|
| 580 |
+
all_valid_paths |= {os.path.basename(p) for p in valid_paths}
|
| 581 |
+
# Remove .bin, .idx etc
|
| 582 |
+
roots = {os.path.splitext(p)[0] for p in all_valid_paths}
|
| 583 |
+
return roots
|
| 584 |
+
|
| 585 |
+
|
| 586 |
+
def raise_if_valid_subsets_unintentionally_ignored(train_cfg) -> None:
|
| 587 |
+
"""Raises if there are paths matching 'valid*[0-9].*' which are not combined or ignored."""
|
| 588 |
+
if (
|
| 589 |
+
train_cfg.dataset.ignore_unused_valid_subsets
|
| 590 |
+
or train_cfg.dataset.combine_valid_subsets
|
| 591 |
+
or train_cfg.dataset.disable_validation
|
| 592 |
+
or not hasattr(train_cfg.task, "data")
|
| 593 |
+
):
|
| 594 |
+
return
|
| 595 |
+
other_paths = _find_extra_valid_paths(train_cfg.task.data)
|
| 596 |
+
specified_subsets = train_cfg.dataset.valid_subset.split(",")
|
| 597 |
+
ignored_paths = [p for p in other_paths if p not in specified_subsets]
|
| 598 |
+
if ignored_paths:
|
| 599 |
+
advice = "Set --combine-val to combine them or --ignore-unused-valid-subsets to ignore them."
|
| 600 |
+
msg = f"Valid paths {ignored_paths} will be ignored. {advice}"
|
| 601 |
+
raise ValueError(msg)
|
data/file_dataset.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import pickle
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class FileDataset:
|
| 7 |
+
def __init__(self, file_path, selected_col_ids=None, dtypes=None, separator="\t", cached_index=False):
|
| 8 |
+
self.file_path = file_path
|
| 9 |
+
assert os.path.exists(self.file_path), "Error: The local datafile {} not exists!".format(self.file_path)
|
| 10 |
+
|
| 11 |
+
self.separator = separator
|
| 12 |
+
if selected_col_ids is None:
|
| 13 |
+
# default to all fields
|
| 14 |
+
self.selected_col_ids = list(
|
| 15 |
+
range(len(open(self.file_path).readline().rstrip("\n").split(self.separator))))
|
| 16 |
+
else:
|
| 17 |
+
self.selected_col_ids = [int(col_id) for col_id in selected_col_ids.split(",")]
|
| 18 |
+
if dtypes is None:
|
| 19 |
+
# default to str
|
| 20 |
+
self.dtypes = [str for col_id in self.selected_col_ids]
|
| 21 |
+
else:
|
| 22 |
+
self.dtypes = [eval(col_dtype) for col_dtype in dtypes.split(",")]
|
| 23 |
+
assert len(self.dtypes) == len(self.selected_col_ids)
|
| 24 |
+
|
| 25 |
+
self.data_cnt = 0
|
| 26 |
+
try:
|
| 27 |
+
self.slice_id = torch.distributed.get_rank()
|
| 28 |
+
self.slice_count = torch.distributed.get_world_size()
|
| 29 |
+
except Exception:
|
| 30 |
+
self.slice_id = 0
|
| 31 |
+
self.slice_count = 1
|
| 32 |
+
self.cached_index = cached_index
|
| 33 |
+
self._init_seek_index()
|
| 34 |
+
self._reader = self._get_reader()
|
| 35 |
+
print("file {} slice_id {} row count {} total row count {}".format(
|
| 36 |
+
self.file_path, self.slice_id, self.row_count, self.total_row_count)
|
| 37 |
+
)
|
| 38 |
+
|
| 39 |
+
def _init_seek_index(self):
|
| 40 |
+
if self.cached_index:
|
| 41 |
+
cache_path = "{}.index".format(self.file_path)
|
| 42 |
+
assert os.path.exists(cache_path), "cache file {} not exists!".format(cache_path)
|
| 43 |
+
self.total_row_count, self.lineid_to_offset = pickle.load(open(cache_path, "rb"))
|
| 44 |
+
print("local datafile {} slice_id {} use cached row_count and line_idx-to-offset mapping".format(
|
| 45 |
+
self.file_path, self.slice_id))
|
| 46 |
+
else:
|
| 47 |
+
# make an iteration over the file to get row_count and line_idx-to-offset mapping
|
| 48 |
+
fp = open(self.file_path, "r")
|
| 49 |
+
print("local datafile {} slice_id {} begin to initialize row_count and line_idx-to-offset mapping".format(
|
| 50 |
+
self.file_path, self.slice_id))
|
| 51 |
+
self.total_row_count = 0
|
| 52 |
+
offset = 0
|
| 53 |
+
self.lineid_to_offset = []
|
| 54 |
+
for line in fp:
|
| 55 |
+
self.lineid_to_offset.append(offset)
|
| 56 |
+
self.total_row_count += 1
|
| 57 |
+
offset += len(line)
|
| 58 |
+
self._compute_start_pos_and_row_count()
|
| 59 |
+
print("local datafile {} slice_id {} finished initializing row_count and line_idx-to-offset mapping".format(
|
| 60 |
+
self.file_path, self.slice_id))
|
| 61 |
+
|
| 62 |
+
def _compute_start_pos_and_row_count(self):
|
| 63 |
+
self.row_count = self.total_row_count // self.slice_count
|
| 64 |
+
if self.slice_id < self.total_row_count - self.row_count * self.slice_count:
|
| 65 |
+
self.row_count += 1
|
| 66 |
+
self.start_pos = self.row_count * self.slice_id
|
| 67 |
+
else:
|
| 68 |
+
self.start_pos = self.row_count * self.slice_id + (self.total_row_count - self.row_count * self.slice_count)
|
| 69 |
+
|
| 70 |
+
def _get_reader(self):
|
| 71 |
+
fp = open(self.file_path, "r")
|
| 72 |
+
fp.seek(self.lineid_to_offset[self.start_pos])
|
| 73 |
+
return fp
|
| 74 |
+
|
| 75 |
+
def _seek(self, offset=0):
|
| 76 |
+
try:
|
| 77 |
+
print("slice_id {} seek offset {}".format(self.slice_id, self.start_pos + offset))
|
| 78 |
+
self._reader.seek(self.lineid_to_offset[self.start_pos + offset])
|
| 79 |
+
self.data_cnt = offset
|
| 80 |
+
except Exception:
|
| 81 |
+
print("slice_id {} seek offset {}".format(self.slice_id, offset))
|
| 82 |
+
self._reader.seek(self.lineid_to_offset[offset])
|
| 83 |
+
self.data_cnt = offset
|
| 84 |
+
|
| 85 |
+
def __del__(self):
|
| 86 |
+
self._reader.close()
|
| 87 |
+
|
| 88 |
+
def __len__(self):
|
| 89 |
+
return self.row_count
|
| 90 |
+
|
| 91 |
+
def get_total_row_count(self):
|
| 92 |
+
return self.total_row_count
|
| 93 |
+
|
| 94 |
+
def __getitem__(self, index):
|
| 95 |
+
if self.data_cnt == self.row_count:
|
| 96 |
+
print("reach the end of datafile, start a new reader")
|
| 97 |
+
self.data_cnt = 0
|
| 98 |
+
self._reader = self._get_reader()
|
| 99 |
+
column_l = self._reader.readline().rstrip("\n").split(self.separator)
|
| 100 |
+
self.data_cnt += 1
|
| 101 |
+
column_l = [dtype(column_l[col_id]) for col_id, dtype in zip(self.selected_col_ids, self.dtypes)]
|
| 102 |
+
return column_l
|
data/mm_data/__init__.py
ADDED
|
File without changes
|
data/mm_data/caption_dataset.py
ADDED
|
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
from io import BytesIO
|
| 6 |
+
|
| 7 |
+
import logging
|
| 8 |
+
import warnings
|
| 9 |
+
import string
|
| 10 |
+
|
| 11 |
+
import numpy as np
|
| 12 |
+
import torch
|
| 13 |
+
import base64
|
| 14 |
+
from torchvision import transforms
|
| 15 |
+
|
| 16 |
+
from PIL import Image, ImageFile
|
| 17 |
+
|
| 18 |
+
from data import data_utils
|
| 19 |
+
from data.ofa_dataset import OFADataset
|
| 20 |
+
|
| 21 |
+
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
| 22 |
+
ImageFile.MAX_IMAGE_PIXELS = None
|
| 23 |
+
Image.MAX_IMAGE_PIXELS = None
|
| 24 |
+
|
| 25 |
+
logger = logging.getLogger(__name__)
|
| 26 |
+
warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning)
|
| 27 |
+
|
| 28 |
+
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
|
| 29 |
+
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def collate(samples, pad_idx, eos_idx):
|
| 33 |
+
if len(samples) == 0:
|
| 34 |
+
return {}
|
| 35 |
+
|
| 36 |
+
def merge(key):
|
| 37 |
+
return data_utils.collate_tokens(
|
| 38 |
+
[s[key] for s in samples],
|
| 39 |
+
pad_idx,
|
| 40 |
+
eos_idx=eos_idx,
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
id = np.array([s["id"] for s in samples])
|
| 44 |
+
src_tokens = merge("source")
|
| 45 |
+
src_lengths = torch.LongTensor([s["source"].ne(pad_idx).long().sum() for s in samples])
|
| 46 |
+
|
| 47 |
+
patch_images = torch.stack([sample['patch_image'] for sample in samples], dim=0)
|
| 48 |
+
patch_masks = torch.cat([sample['patch_mask'] for sample in samples])
|
| 49 |
+
|
| 50 |
+
prev_output_tokens = None
|
| 51 |
+
target = None
|
| 52 |
+
if samples[0].get("target", None) is not None:
|
| 53 |
+
target = merge("target")
|
| 54 |
+
tgt_lengths = torch.LongTensor([s["target"].ne(pad_idx).long().sum() for s in samples])
|
| 55 |
+
ntokens = tgt_lengths.sum().item()
|
| 56 |
+
|
| 57 |
+
if samples[0].get("prev_output_tokens", None) is not None:
|
| 58 |
+
prev_output_tokens = merge("prev_output_tokens")
|
| 59 |
+
else:
|
| 60 |
+
ntokens = src_lengths.sum().item()
|
| 61 |
+
|
| 62 |
+
batch = {
|
| 63 |
+
"id": id,
|
| 64 |
+
"nsentences": len(samples),
|
| 65 |
+
"ntokens": ntokens,
|
| 66 |
+
"net_input": {
|
| 67 |
+
"src_tokens": src_tokens,
|
| 68 |
+
"src_lengths": src_lengths,
|
| 69 |
+
"patch_images": patch_images,
|
| 70 |
+
"patch_masks": patch_masks,
|
| 71 |
+
"prev_output_tokens": prev_output_tokens
|
| 72 |
+
},
|
| 73 |
+
"target": target,
|
| 74 |
+
}
|
| 75 |
+
|
| 76 |
+
return batch
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
class CaptionDataset(OFADataset):
|
| 80 |
+
def __init__(
|
| 81 |
+
self,
|
| 82 |
+
split,
|
| 83 |
+
dataset,
|
| 84 |
+
bpe,
|
| 85 |
+
src_dict,
|
| 86 |
+
tgt_dict=None,
|
| 87 |
+
max_src_length=128,
|
| 88 |
+
max_tgt_length=30,
|
| 89 |
+
patch_image_size=224,
|
| 90 |
+
imagenet_default_mean_and_std=False,
|
| 91 |
+
scst=False
|
| 92 |
+
):
|
| 93 |
+
self.split = split
|
| 94 |
+
self.dataset = dataset
|
| 95 |
+
self.bpe = bpe
|
| 96 |
+
self.src_dict = src_dict
|
| 97 |
+
self.tgt_dict = tgt_dict
|
| 98 |
+
self.max_src_length = max_src_length
|
| 99 |
+
self.max_tgt_length = max_tgt_length
|
| 100 |
+
|
| 101 |
+
self.patch_image_size = patch_image_size
|
| 102 |
+
self.scst = scst
|
| 103 |
+
|
| 104 |
+
self.bos = src_dict.bos()
|
| 105 |
+
self.eos = src_dict.eos()
|
| 106 |
+
self.pad = src_dict.pad()
|
| 107 |
+
self.bos_item = torch.LongTensor([self.bos])
|
| 108 |
+
self.eos_item = torch.LongTensor([self.eos])
|
| 109 |
+
self.transtab = str.maketrans({key: None for key in string.punctuation})
|
| 110 |
+
|
| 111 |
+
if imagenet_default_mean_and_std:
|
| 112 |
+
mean = IMAGENET_DEFAULT_MEAN
|
| 113 |
+
std = IMAGENET_DEFAULT_STD
|
| 114 |
+
else:
|
| 115 |
+
mean = [0.5, 0.5, 0.5]
|
| 116 |
+
std = [0.5, 0.5, 0.5]
|
| 117 |
+
|
| 118 |
+
self.patch_resize_transform = transforms.Compose([
|
| 119 |
+
lambda image: image.convert("RGB"),
|
| 120 |
+
transforms.Resize((patch_image_size, patch_image_size), interpolation=Image.BICUBIC),
|
| 121 |
+
transforms.ToTensor(),
|
| 122 |
+
transforms.Normalize(mean=mean, std=std),
|
| 123 |
+
])
|
| 124 |
+
|
| 125 |
+
def __getitem__(self, index):
|
| 126 |
+
uniq_id, image, caption = self.dataset[index]
|
| 127 |
+
|
| 128 |
+
image = Image.open(BytesIO(base64.urlsafe_b64decode(image)))
|
| 129 |
+
patch_image = self.patch_resize_transform(image)
|
| 130 |
+
patch_mask = torch.tensor([True])
|
| 131 |
+
|
| 132 |
+
if self.split == 'train' and not self.scst:
|
| 133 |
+
caption = caption.translate(self.transtab).strip()
|
| 134 |
+
caption_token_list = caption.strip().split()
|
| 135 |
+
tgt_caption = ' '.join(caption_token_list[:self.max_tgt_length])
|
| 136 |
+
else:
|
| 137 |
+
caption = ' '.join(caption.strip().split())
|
| 138 |
+
caption_list = [cap.translate(self.transtab).strip() for cap in caption.strip().split('&&')]
|
| 139 |
+
tgt_caption = '&&'.join(caption_list)
|
| 140 |
+
src_item = self.encode_text(" what does the image describe?")
|
| 141 |
+
tgt_item = self.encode_text(" {}".format(tgt_caption))
|
| 142 |
+
|
| 143 |
+
src_item = torch.cat([self.bos_item, src_item, self.eos_item])
|
| 144 |
+
target_item = torch.cat([tgt_item, self.eos_item])
|
| 145 |
+
prev_output_item = torch.cat([self.bos_item, tgt_item])
|
| 146 |
+
|
| 147 |
+
example = {
|
| 148 |
+
"id": uniq_id,
|
| 149 |
+
"source": src_item,
|
| 150 |
+
"patch_image": patch_image,
|
| 151 |
+
"patch_mask": patch_mask,
|
| 152 |
+
"target": target_item,
|
| 153 |
+
"prev_output_tokens": prev_output_item
|
| 154 |
+
}
|
| 155 |
+
return example
|
| 156 |
+
|
| 157 |
+
def collater(self, samples, pad_to_length=None):
|
| 158 |
+
"""Merge a list of samples to form a mini-batch.
|
| 159 |
+
Args:
|
| 160 |
+
samples (List[dict]): samples to collate
|
| 161 |
+
Returns:
|
| 162 |
+
dict: a mini-batch with the following keys:
|
| 163 |
+
"""
|
| 164 |
+
return collate(samples, pad_idx=self.pad, eos_idx=self.eos)
|
data/ofa_dataset.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import torch.utils.data
|
| 3 |
+
from fairseq.data import FairseqDataset
|
| 4 |
+
|
| 5 |
+
logger = logging.getLogger(__name__)
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class OFADataset(FairseqDataset):
|
| 9 |
+
|
| 10 |
+
def __len__(self):
|
| 11 |
+
return len(self.dataset)
|
| 12 |
+
|
| 13 |
+
def encode_text(self, text, length=None, append_bos=False, append_eos=False):
|
| 14 |
+
s = self.tgt_dict.encode_line(
|
| 15 |
+
line=self.bpe.encode(text),
|
| 16 |
+
add_if_not_exist=False,
|
| 17 |
+
append_eos=False
|
| 18 |
+
).long()
|
| 19 |
+
if length is not None:
|
| 20 |
+
s = s[:length]
|
| 21 |
+
if append_bos:
|
| 22 |
+
s = torch.cat([self.bos_item, s])
|
| 23 |
+
if append_eos:
|
| 24 |
+
s = torch.cat([s, self.eos_item])
|
| 25 |
+
return s
|
datasets.md
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Datasets
|
| 2 |
+
|
| 3 |
+
We provide links to download our preprocessed dataset. If you would like to process the data on your own, we will soon provide scripts for you to do so.
|
| 4 |
+
|
| 5 |
+
## Finetuning
|
| 6 |
+
|
| 7 |
+
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/datasets/caption_data/caption_data.zip"> Dataset for Caption </a>
|
evaluate.py
ADDED
|
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3 -u
|
| 2 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the MIT license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import logging
|
| 8 |
+
import os
|
| 9 |
+
import sys
|
| 10 |
+
import json
|
| 11 |
+
from itertools import chain
|
| 12 |
+
|
| 13 |
+
import numpy as np
|
| 14 |
+
import torch
|
| 15 |
+
import torch.distributed as dist
|
| 16 |
+
from fairseq import distributed_utils, options, tasks, utils
|
| 17 |
+
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
|
| 18 |
+
from fairseq.logging import progress_bar
|
| 19 |
+
from fairseq.utils import reset_logging
|
| 20 |
+
from omegaconf import DictConfig
|
| 21 |
+
|
| 22 |
+
from utils import checkpoint_utils
|
| 23 |
+
from utils.eval_utils import eval_step
|
| 24 |
+
|
| 25 |
+
logging.basicConfig(
|
| 26 |
+
format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
|
| 27 |
+
datefmt="%Y-%m-%d %H:%M:%S",
|
| 28 |
+
level=os.environ.get("LOGLEVEL", "INFO").upper(),
|
| 29 |
+
stream=sys.stdout,
|
| 30 |
+
)
|
| 31 |
+
logger = logging.getLogger("ofa.evaluate")
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def apply_half(t):
|
| 35 |
+
if t.dtype is torch.float32:
|
| 36 |
+
return t.to(dtype=torch.half)
|
| 37 |
+
return t
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def main(cfg: DictConfig):
|
| 41 |
+
utils.import_user_module(cfg.common)
|
| 42 |
+
|
| 43 |
+
reset_logging()
|
| 44 |
+
logger.info(cfg)
|
| 45 |
+
|
| 46 |
+
assert (
|
| 47 |
+
cfg.dataset.max_tokens is not None or cfg.dataset.batch_size is not None
|
| 48 |
+
), "Must specify batch size either with --max-tokens or --batch-size"
|
| 49 |
+
|
| 50 |
+
# Fix seed for stochastic decoding
|
| 51 |
+
if cfg.common.seed is not None and not cfg.generation.no_seed_provided:
|
| 52 |
+
np.random.seed(cfg.common.seed)
|
| 53 |
+
utils.set_torch_seed(cfg.common.seed)
|
| 54 |
+
|
| 55 |
+
use_fp16 = cfg.common.fp16
|
| 56 |
+
use_cuda = torch.cuda.is_available() and not cfg.common.cpu
|
| 57 |
+
|
| 58 |
+
if use_cuda:
|
| 59 |
+
torch.cuda.set_device(cfg.distributed_training.device_id)
|
| 60 |
+
|
| 61 |
+
# Load ensemble
|
| 62 |
+
overrides = eval(cfg.common_eval.model_overrides)
|
| 63 |
+
logger.info("loading model(s) from {}".format(cfg.common_eval.path))
|
| 64 |
+
models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
|
| 65 |
+
utils.split_paths(cfg.common_eval.path),
|
| 66 |
+
arg_overrides=overrides,
|
| 67 |
+
suffix=cfg.checkpoint.checkpoint_suffix,
|
| 68 |
+
strict=(cfg.checkpoint.checkpoint_shard_count == 1),
|
| 69 |
+
num_shards=cfg.checkpoint.checkpoint_shard_count,
|
| 70 |
+
)
|
| 71 |
+
|
| 72 |
+
# loading the dataset should happen after the checkpoint has been loaded so we can give it the saved task config
|
| 73 |
+
task.load_dataset(cfg.dataset.gen_subset, task_cfg=saved_cfg.task)
|
| 74 |
+
|
| 75 |
+
# Move models to GPU
|
| 76 |
+
for model in models:
|
| 77 |
+
model.eval()
|
| 78 |
+
if use_fp16:
|
| 79 |
+
model.half()
|
| 80 |
+
if use_cuda and not cfg.distributed_training.pipeline_model_parallel:
|
| 81 |
+
model.cuda()
|
| 82 |
+
model.prepare_for_inference_(cfg)
|
| 83 |
+
|
| 84 |
+
# Load dataset (possibly sharded)
|
| 85 |
+
itr = task.get_batch_iterator(
|
| 86 |
+
dataset=task.dataset(cfg.dataset.gen_subset),
|
| 87 |
+
max_tokens=cfg.dataset.max_tokens,
|
| 88 |
+
max_sentences=cfg.dataset.batch_size,
|
| 89 |
+
max_positions=utils.resolve_max_positions(
|
| 90 |
+
task.max_positions(), *[m.max_positions() for m in models]
|
| 91 |
+
),
|
| 92 |
+
ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test,
|
| 93 |
+
required_batch_size_multiple=cfg.dataset.required_batch_size_multiple,
|
| 94 |
+
seed=cfg.common.seed,
|
| 95 |
+
num_shards=cfg.distributed_training.distributed_world_size,
|
| 96 |
+
shard_id=cfg.distributed_training.distributed_rank,
|
| 97 |
+
num_workers=cfg.dataset.num_workers,
|
| 98 |
+
data_buffer_size=cfg.dataset.data_buffer_size,
|
| 99 |
+
).next_epoch_itr(shuffle=False)
|
| 100 |
+
progress = progress_bar.progress_bar(
|
| 101 |
+
itr,
|
| 102 |
+
log_format=cfg.common.log_format,
|
| 103 |
+
log_interval=cfg.common.log_interval,
|
| 104 |
+
default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
|
| 105 |
+
)
|
| 106 |
+
|
| 107 |
+
# Initialize generator
|
| 108 |
+
generator = task.build_generator(models, cfg.generation)
|
| 109 |
+
|
| 110 |
+
results = []
|
| 111 |
+
score_sum = torch.FloatTensor([0]).cuda()
|
| 112 |
+
score_cnt = torch.FloatTensor([0]).cuda()
|
| 113 |
+
for sample in progress:
|
| 114 |
+
if "net_input" not in sample:
|
| 115 |
+
continue
|
| 116 |
+
sample = utils.move_to_cuda(sample) if use_cuda else sample
|
| 117 |
+
sample = utils.apply_to_sample(apply_half, sample) if cfg.common.fp16 else sample
|
| 118 |
+
with torch.no_grad():
|
| 119 |
+
result, scores = eval_step(task, generator, models, sample)
|
| 120 |
+
results += result
|
| 121 |
+
score_sum += sum(scores) if scores is not None else 0
|
| 122 |
+
score_cnt += len(scores) if scores is not None else 0
|
| 123 |
+
progress.log({"sentences": sample["nsentences"]})
|
| 124 |
+
|
| 125 |
+
gather_results = None
|
| 126 |
+
if cfg.distributed_training.distributed_world_size > 1:
|
| 127 |
+
gather_results = [None for _ in range(dist.get_world_size())]
|
| 128 |
+
dist.all_gather_object(gather_results, results)
|
| 129 |
+
dist.all_reduce(score_sum.data)
|
| 130 |
+
dist.all_reduce(score_cnt.data)
|
| 131 |
+
if score_cnt.item() > 0:
|
| 132 |
+
logger.info("score_sum: {}, score_cnt: {}, score: {}".format(
|
| 133 |
+
score_sum, score_cnt, round(score_sum.item() / score_cnt.item(), 4)
|
| 134 |
+
))
|
| 135 |
+
|
| 136 |
+
if cfg.distributed_training.distributed_world_size == 1 or dist.get_rank() == 0:
|
| 137 |
+
os.makedirs(cfg.common_eval.results_path, exist_ok=True)
|
| 138 |
+
output_path = os.path.join(cfg.common_eval.results_path, "{}_predict.json".format(cfg.dataset.gen_subset))
|
| 139 |
+
gather_results = list(chain(*gather_results)) if gather_results is not None else results
|
| 140 |
+
with open(output_path, 'w') as fw:
|
| 141 |
+
json.dump(gather_results, fw)
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def cli_main():
|
| 145 |
+
parser = options.get_generation_parser()
|
| 146 |
+
args = options.parse_args_and_arch(parser)
|
| 147 |
+
cfg = convert_namespace_to_omegaconf(args)
|
| 148 |
+
distributed_utils.call_main(cfg, main)
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
if __name__ == "__main__":
|
| 152 |
+
cli_main()
|
models/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .ofa import OFAModel, ofa_base_architecture, ofa_large_architecture, ofa_huge_architecture
|
models/ofa/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .ofa import OFAModel, ofa_base_architecture, ofa_large_architecture, ofa_huge_architecture
|
models/ofa/ofa.py
ADDED
|
@@ -0,0 +1,410 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
"""
|
| 6 |
+
OFA
|
| 7 |
+
"""
|
| 8 |
+
from typing import Optional
|
| 9 |
+
|
| 10 |
+
import logging
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
import torch.nn as nn
|
| 14 |
+
import torch.nn.functional as F
|
| 15 |
+
from fairseq import utils
|
| 16 |
+
from fairseq.models import register_model, register_model_architecture
|
| 17 |
+
from fairseq.modules.transformer_sentence_encoder import init_bert_params
|
| 18 |
+
|
| 19 |
+
from .unify_transformer import TransformerModel
|
| 20 |
+
|
| 21 |
+
logger = logging.getLogger(__name__)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
@register_model("ofa")
|
| 25 |
+
class OFAModel(TransformerModel):
|
| 26 |
+
__jit_unused_properties__ = ["supported_targets"]
|
| 27 |
+
|
| 28 |
+
def __init__(self, args, encoder, decoder):
|
| 29 |
+
super().__init__(args, encoder, decoder)
|
| 30 |
+
|
| 31 |
+
# We follow BERT's random weight initialization
|
| 32 |
+
self.apply(init_bert_params)
|
| 33 |
+
|
| 34 |
+
self.classification_heads = nn.ModuleDict()
|
| 35 |
+
if hasattr(self.encoder, "dictionary"):
|
| 36 |
+
self.eos: int = self.encoder.dictionary.eos()
|
| 37 |
+
|
| 38 |
+
@staticmethod
|
| 39 |
+
def add_args(parser):
|
| 40 |
+
super(OFAModel, OFAModel).add_args(parser)
|
| 41 |
+
parser.add_argument(
|
| 42 |
+
"--pooler-dropout",
|
| 43 |
+
type=float,
|
| 44 |
+
metavar="D",
|
| 45 |
+
help="dropout probability in the masked_lm pooler layers",
|
| 46 |
+
)
|
| 47 |
+
parser.add_argument(
|
| 48 |
+
"--pooler-classifier",
|
| 49 |
+
type=str,
|
| 50 |
+
choices=['mlp', 'linear'],
|
| 51 |
+
help="type of pooler classifier",
|
| 52 |
+
)
|
| 53 |
+
parser.add_argument(
|
| 54 |
+
"--pooler-activation-fn",
|
| 55 |
+
choices=utils.get_available_activation_fns(),
|
| 56 |
+
help="activation function to use for pooler layer",
|
| 57 |
+
)
|
| 58 |
+
parser.add_argument(
|
| 59 |
+
"--spectral-norm-classification-head",
|
| 60 |
+
action="store_true",
|
| 61 |
+
help="Apply spectral normalization on the classification head",
|
| 62 |
+
)
|
| 63 |
+
|
| 64 |
+
@property
|
| 65 |
+
def supported_targets(self):
|
| 66 |
+
return {"self"}
|
| 67 |
+
|
| 68 |
+
def forward(
|
| 69 |
+
self,
|
| 70 |
+
src_tokens,
|
| 71 |
+
src_lengths,
|
| 72 |
+
prev_output_tokens,
|
| 73 |
+
patch_images: Optional[torch.Tensor] = None,
|
| 74 |
+
patch_images_2: Optional[torch.Tensor] = None,
|
| 75 |
+
patch_masks: Optional[torch.Tensor] = None,
|
| 76 |
+
code_masks: Optional[torch.Tensor] = None,
|
| 77 |
+
sample_patch_num: Optional[int] = None,
|
| 78 |
+
features_only: bool = False,
|
| 79 |
+
classification_head_name: Optional[str] = None,
|
| 80 |
+
token_embeddings: Optional[torch.Tensor] = None,
|
| 81 |
+
return_all_hiddens: bool = False,
|
| 82 |
+
alignment_layer: Optional[int] = None,
|
| 83 |
+
alignment_heads: Optional[int] = None,
|
| 84 |
+
):
|
| 85 |
+
if classification_head_name is not None:
|
| 86 |
+
features_only = True
|
| 87 |
+
|
| 88 |
+
encoder_out = self.encoder(
|
| 89 |
+
src_tokens,
|
| 90 |
+
src_lengths=src_lengths,
|
| 91 |
+
patch_images=patch_images,
|
| 92 |
+
patch_masks=patch_masks,
|
| 93 |
+
patch_images_2=patch_images_2,
|
| 94 |
+
token_embeddings=token_embeddings,
|
| 95 |
+
return_all_hiddens=return_all_hiddens,
|
| 96 |
+
sample_patch_num=sample_patch_num
|
| 97 |
+
)
|
| 98 |
+
x, extra = self.decoder(
|
| 99 |
+
prev_output_tokens,
|
| 100 |
+
code_masks=code_masks,
|
| 101 |
+
encoder_out=encoder_out,
|
| 102 |
+
features_only=features_only,
|
| 103 |
+
alignment_layer=alignment_layer,
|
| 104 |
+
alignment_heads=alignment_heads,
|
| 105 |
+
src_lengths=src_lengths,
|
| 106 |
+
return_all_hiddens=return_all_hiddens,
|
| 107 |
+
)
|
| 108 |
+
|
| 109 |
+
pad = self.encoder.padding_idx
|
| 110 |
+
if classification_head_name is not None:
|
| 111 |
+
prev_lengths = prev_output_tokens.ne(pad).sum(1)
|
| 112 |
+
gather_index = prev_lengths[:, None, None].expand(x.size(0), 1, x.size(2)) - 1
|
| 113 |
+
sentence_representation = x.gather(1, gather_index).squeeze()
|
| 114 |
+
if self.classification_heads[classification_head_name].use_two_images:
|
| 115 |
+
hidden_size = sentence_representation.size(1)
|
| 116 |
+
sentence_representation = sentence_representation.view(-1, hidden_size * 2)
|
| 117 |
+
for k, head in self.classification_heads.items():
|
| 118 |
+
# for torch script only supports iteration
|
| 119 |
+
if k == classification_head_name:
|
| 120 |
+
x = head(sentence_representation)
|
| 121 |
+
break
|
| 122 |
+
|
| 123 |
+
return x, extra
|
| 124 |
+
|
| 125 |
+
def register_embedding_tokens(self, ans2label_dict, src_dict, bpe):
|
| 126 |
+
"""Register embedding tokens"""
|
| 127 |
+
logger.info("Registering embedding tokens")
|
| 128 |
+
self.ans_tensor_list = []
|
| 129 |
+
for i in range(len(ans2label_dict)):
|
| 130 |
+
ans = src_dict[-len(ans2label_dict)+i]
|
| 131 |
+
ans = ans[5:-1].replace('_', ' ')
|
| 132 |
+
ans_tensor = src_dict.encode_line(
|
| 133 |
+
line=bpe.encode(' {}'.format(ans.lower())),
|
| 134 |
+
add_if_not_exist=False,
|
| 135 |
+
append_eos=False
|
| 136 |
+
).long()
|
| 137 |
+
self.ans_tensor_list.append(ans_tensor)
|
| 138 |
+
|
| 139 |
+
def register_classification_head(
|
| 140 |
+
self, name, num_classes=None, inner_dim=None, use_two_images=False, **kwargs
|
| 141 |
+
):
|
| 142 |
+
"""Register a classification head."""
|
| 143 |
+
logger.info("Registering classification head: {0}".format(name))
|
| 144 |
+
if name in self.classification_heads:
|
| 145 |
+
prev_num_classes = self.classification_heads[name].out_proj.out_features
|
| 146 |
+
prev_inner_dim = self.classification_heads[name].dense.out_features
|
| 147 |
+
if num_classes != prev_num_classes or inner_dim != prev_inner_dim:
|
| 148 |
+
logger.warning(
|
| 149 |
+
're-registering head "{}" with num_classes {} (prev: {}) '
|
| 150 |
+
"and inner_dim {} (prev: {})".format(
|
| 151 |
+
name, num_classes, prev_num_classes, inner_dim, prev_inner_dim
|
| 152 |
+
)
|
| 153 |
+
)
|
| 154 |
+
self.classification_heads[name] = OFAClassificationHead(
|
| 155 |
+
input_dim=self.args.encoder_embed_dim,
|
| 156 |
+
inner_dim=inner_dim or self.args.encoder_embed_dim,
|
| 157 |
+
num_classes=num_classes,
|
| 158 |
+
activation_fn=self.args.pooler_activation_fn,
|
| 159 |
+
pooler_dropout=self.args.pooler_dropout,
|
| 160 |
+
pooler_classifier=self.args.pooler_classifier,
|
| 161 |
+
use_two_images=use_two_images,
|
| 162 |
+
do_spectral_norm=getattr(
|
| 163 |
+
self.args, "spectral_norm_classification_head", False
|
| 164 |
+
),
|
| 165 |
+
)
|
| 166 |
+
|
| 167 |
+
def upgrade_state_dict_named(self, state_dict, name):
|
| 168 |
+
super().upgrade_state_dict_named(state_dict, name)
|
| 169 |
+
|
| 170 |
+
prefix = name + "." if name != "" else ""
|
| 171 |
+
current_head_names = (
|
| 172 |
+
[]
|
| 173 |
+
if not hasattr(self, "classification_heads")
|
| 174 |
+
else self.classification_heads.keys()
|
| 175 |
+
)
|
| 176 |
+
|
| 177 |
+
# Handle new classification heads present in the state dict.
|
| 178 |
+
keys_to_delete = []
|
| 179 |
+
for k in state_dict.keys():
|
| 180 |
+
if not k.startswith(prefix + "classification_heads."):
|
| 181 |
+
continue
|
| 182 |
+
|
| 183 |
+
head_name = k[len(prefix + "classification_heads.") :].split(".")[0]
|
| 184 |
+
num_classes = state_dict[
|
| 185 |
+
prefix + "classification_heads." + head_name + ".out_proj.weight"
|
| 186 |
+
].size(0)
|
| 187 |
+
inner_dim = state_dict[
|
| 188 |
+
prefix + "classification_heads." + head_name + ".dense.weight"
|
| 189 |
+
].size(0)
|
| 190 |
+
|
| 191 |
+
if getattr(self.args, "load_checkpoint_heads", False):
|
| 192 |
+
if head_name not in current_head_names:
|
| 193 |
+
self.register_classification_head(head_name, num_classes, inner_dim)
|
| 194 |
+
else:
|
| 195 |
+
if head_name not in current_head_names:
|
| 196 |
+
logger.warning(
|
| 197 |
+
"deleting classification head ({}) from checkpoint "
|
| 198 |
+
"not present in current model: {}".format(head_name, k)
|
| 199 |
+
)
|
| 200 |
+
keys_to_delete.append(k)
|
| 201 |
+
elif (
|
| 202 |
+
num_classes
|
| 203 |
+
!= self.classification_heads[head_name].out_proj.out_features
|
| 204 |
+
or inner_dim
|
| 205 |
+
!= self.classification_heads[head_name].dense.out_features
|
| 206 |
+
):
|
| 207 |
+
logger.warning(
|
| 208 |
+
"deleting classification head ({}) from checkpoint "
|
| 209 |
+
"with different dimensions than current model: {}".format(
|
| 210 |
+
head_name, k
|
| 211 |
+
)
|
| 212 |
+
)
|
| 213 |
+
keys_to_delete.append(k)
|
| 214 |
+
for k in keys_to_delete:
|
| 215 |
+
del state_dict[k]
|
| 216 |
+
|
| 217 |
+
def truncate_emb(key):
|
| 218 |
+
if key in state_dict:
|
| 219 |
+
state_dict[key] = state_dict[key][:-1, :]
|
| 220 |
+
|
| 221 |
+
# When finetuning on translation task, remove last row of
|
| 222 |
+
# embedding matrix that corresponds to mask_idx token.
|
| 223 |
+
loaded_dict_size = state_dict["encoder.embed_tokens.weight"].size(0)
|
| 224 |
+
if (
|
| 225 |
+
loaded_dict_size == len(self.encoder.dictionary) + 1
|
| 226 |
+
and "<mask>" not in self.encoder.dictionary
|
| 227 |
+
):
|
| 228 |
+
truncate_emb("encoder.embed_tokens.weight")
|
| 229 |
+
truncate_emb("decoder.embed_tokens.weight")
|
| 230 |
+
truncate_emb("encoder.output_projection.weight")
|
| 231 |
+
truncate_emb("decoder.output_projection.weight")
|
| 232 |
+
|
| 233 |
+
if loaded_dict_size < len(self.encoder.dictionary):
|
| 234 |
+
num_langids_to_add = len(self.encoder.dictionary) - loaded_dict_size
|
| 235 |
+
embed_dim = state_dict["encoder.embed_tokens.weight"].size(1)
|
| 236 |
+
|
| 237 |
+
new_lang_embed_to_add = torch.zeros(num_langids_to_add, embed_dim)
|
| 238 |
+
if getattr(self, "ans_tensor_list", None):
|
| 239 |
+
assert len(new_lang_embed_to_add) == len(self.ans_tensor_list)
|
| 240 |
+
for i, ans_tensor in enumerate(self.ans_tensor_list):
|
| 241 |
+
ans_embed = F.embedding(ans_tensor, state_dict["encoder.embed_tokens.weight"])
|
| 242 |
+
ans_embed = ans_embed.sum(0) / ans_embed.size(0)
|
| 243 |
+
new_lang_embed_to_add[i] = ans_embed
|
| 244 |
+
else:
|
| 245 |
+
nn.init.normal_(new_lang_embed_to_add, mean=0, std=embed_dim ** -0.5)
|
| 246 |
+
new_lang_embed_to_add = new_lang_embed_to_add.to(
|
| 247 |
+
dtype=state_dict["encoder.embed_tokens.weight"].dtype,
|
| 248 |
+
)
|
| 249 |
+
|
| 250 |
+
state_dict["encoder.embed_tokens.weight"] = torch.cat(
|
| 251 |
+
[state_dict["encoder.embed_tokens.weight"], new_lang_embed_to_add]
|
| 252 |
+
)
|
| 253 |
+
state_dict["decoder.embed_tokens.weight"] = torch.cat(
|
| 254 |
+
[state_dict["decoder.embed_tokens.weight"], new_lang_embed_to_add]
|
| 255 |
+
)
|
| 256 |
+
state_dict["decoder.output_projection.weight"] = torch.cat(
|
| 257 |
+
[state_dict["decoder.output_projection.weight"], new_lang_embed_to_add]
|
| 258 |
+
)
|
| 259 |
+
|
| 260 |
+
# Copy any newly-added classification heads into the state dict
|
| 261 |
+
# with their current weights.
|
| 262 |
+
if hasattr(self, "classification_heads"):
|
| 263 |
+
cur_state = self.classification_heads.state_dict()
|
| 264 |
+
for k, v in cur_state.items():
|
| 265 |
+
if prefix + "classification_heads." + k not in state_dict:
|
| 266 |
+
logger.info("Overwriting " + prefix + "classification_heads." + k)
|
| 267 |
+
state_dict[prefix + "classification_heads." + k] = v
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
class OFAClassificationHead(nn.Module):
|
| 271 |
+
"""Head for sentence-level classification tasks."""
|
| 272 |
+
|
| 273 |
+
def __init__(
|
| 274 |
+
self,
|
| 275 |
+
input_dim,
|
| 276 |
+
inner_dim,
|
| 277 |
+
num_classes,
|
| 278 |
+
activation_fn,
|
| 279 |
+
pooler_dropout,
|
| 280 |
+
pooler_classifier,
|
| 281 |
+
use_two_images=False,
|
| 282 |
+
do_spectral_norm=False,
|
| 283 |
+
):
|
| 284 |
+
super().__init__()
|
| 285 |
+
self.pooler_classifier = pooler_classifier
|
| 286 |
+
self.use_two_images = use_two_images
|
| 287 |
+
input_dim = input_dim * 2 if use_two_images else input_dim
|
| 288 |
+
if pooler_classifier == "mlp":
|
| 289 |
+
self.dense = nn.Linear(input_dim, inner_dim)
|
| 290 |
+
self.activation_fn = utils.get_activation_fn(activation_fn)
|
| 291 |
+
self.dropout = nn.Dropout(p=pooler_dropout)
|
| 292 |
+
self.out_proj = nn.Linear(inner_dim, num_classes)
|
| 293 |
+
elif pooler_classifier == "linear":
|
| 294 |
+
self.dropout = nn.Dropout(p=pooler_dropout)
|
| 295 |
+
self.out_proj = nn.Linear(input_dim, num_classes)
|
| 296 |
+
else:
|
| 297 |
+
raise NotImplementedError
|
| 298 |
+
|
| 299 |
+
if do_spectral_norm:
|
| 300 |
+
self.out_proj = torch.nn.utils.spectral_norm(self.out_proj)
|
| 301 |
+
|
| 302 |
+
def forward(self, features, **kwargs):
|
| 303 |
+
if self.pooler_classifier == 'mlp':
|
| 304 |
+
x = features
|
| 305 |
+
x = self.dropout(x)
|
| 306 |
+
x = self.dense(x)
|
| 307 |
+
x = self.activation_fn(x)
|
| 308 |
+
x = self.dropout(x)
|
| 309 |
+
x = self.out_proj(x)
|
| 310 |
+
elif self.pooler_classifier == 'linear':
|
| 311 |
+
x = features
|
| 312 |
+
x = self.dropout(x)
|
| 313 |
+
x = self.out_proj(x)
|
| 314 |
+
else:
|
| 315 |
+
raise NotImplementedError
|
| 316 |
+
return x
|
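For reference, a minimal usage sketch of the head defined above; the import path follows this repository's layout and the dimensions are placeholders, not values taken from a checkpoint:

import torch
from models.ofa.ofa import OFAClassificationHead  # path assumed from this repo layout

head = OFAClassificationHead(
    input_dim=768,
    inner_dim=768,
    num_classes=3,
    activation_fn="tanh",
    pooler_dropout=0.0,
    pooler_classifier="mlp",
)
pooled = torch.randn(8, 768)   # sentence-level features, one vector per example
logits = head(pooled)          # -> tensor of shape (8, 3)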
| 317 |
+
|
| 318 |
+
|
| 319 |
+
@register_model_architecture("ofa", "ofa_large")
|
| 320 |
+
def ofa_large_architecture(args):
|
| 321 |
+
args.encoder_embed_path = getattr(args, "encoder_embed_path", None)
|
| 322 |
+
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024)
|
| 323 |
+
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4 * 1024)
|
| 324 |
+
args.encoder_layers = getattr(args, "encoder_layers", 12)
|
| 325 |
+
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16)
|
| 326 |
+
args.encoder_normalize_before = getattr(args, "encoder_normalize_before", True)
|
| 327 |
+
args.encoder_learned_pos = getattr(args, "encoder_learned_pos", True)
|
| 328 |
+
args.decoder_embed_path = getattr(args, "decoder_embed_path", None)
|
| 329 |
+
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim)
|
| 330 |
+
args.decoder_ffn_embed_dim = getattr(
|
| 331 |
+
args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim
|
| 332 |
+
)
|
| 333 |
+
args.decoder_layers = getattr(args, "decoder_layers", 12)
|
| 334 |
+
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16)
|
| 335 |
+
args.decoder_normalize_before = getattr(args, "decoder_normalize_before", True)
|
| 336 |
+
args.decoder_learned_pos = getattr(args, "decoder_learned_pos", True)
|
| 337 |
+
args.attention_dropout = getattr(args, "attention_dropout", 0.0)
|
| 338 |
+
args.relu_dropout = getattr(args, "relu_dropout", 0.0)
|
| 339 |
+
args.dropout = getattr(args, "dropout", 0.0)
|
| 340 |
+
args.max_target_positions = getattr(args, "max_target_positions", 1024)
|
| 341 |
+
args.max_source_positions = getattr(args, "max_source_positions", 1024)
|
| 342 |
+
args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
|
| 343 |
+
args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
|
| 344 |
+
args.share_decoder_input_output_embed = getattr(
|
| 345 |
+
args, "share_decoder_input_output_embed", True
|
| 346 |
+
)
|
| 347 |
+
args.share_all_embeddings = getattr(args, "share_all_embeddings", True)
|
| 348 |
+
|
| 349 |
+
args.decoder_output_dim = getattr(
|
| 350 |
+
args, "decoder_output_dim", args.decoder_embed_dim
|
| 351 |
+
)
|
| 352 |
+
args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim)
|
| 353 |
+
|
| 354 |
+
args.no_scale_embedding = getattr(args, "no_scale_embedding", True)
|
| 355 |
+
args.layernorm_embedding = getattr(args, "layernorm_embedding", True)
|
| 356 |
+
|
| 357 |
+
args.activation_fn = getattr(args, "activation_fn", "gelu")
|
| 358 |
+
args.pooler_activation_fn = getattr(args, "pooler_activation_fn", "tanh")
|
| 359 |
+
args.pooler_dropout = getattr(args, "pooler_dropout", 0.0)
|
| 360 |
+
args.pooler_classifier = getattr(args, "pooler_classifier", "mlp")
|
| 361 |
+
|
| 362 |
+
args.resnet_drop_path_rate = getattr(args, "resnet_drop_path_rate", 0.0)
|
| 363 |
+
args.encoder_drop_path_rate = getattr(args, "encoder_drop_path_rate", 0.0)
|
| 364 |
+
args.decoder_drop_path_rate = getattr(args, "decoder_drop_path_rate", 0.0)
|
| 365 |
+
|
| 366 |
+
args.resnet_type = getattr(args, "resnet_type", "resnet152")
|
| 367 |
+
args.token_bucket_size = getattr(args, "token_bucket_size", 256)
|
| 368 |
+
args.image_bucket_size = getattr(args, "image_bucket_size", 42)
|
| 369 |
+
|
| 370 |
+
args.freeze_encoder_embedding = getattr(args, "freeze_encoder_embedding", False)
|
| 371 |
+
args.freeze_decoder_embedding = getattr(args, "freeze_decoder_embedding", False)
|
| 372 |
+
args.add_type_embedding = getattr(args, "add_type_embedding", True)
|
| 373 |
+
args.attn_scale_factor = getattr(args, "attn_scale_factor", 2)
|
| 374 |
+
|
| 375 |
+
args.code_image_size = getattr(args, "code_image_size", 128)
|
| 376 |
+
args.patch_layernorm_embedding = getattr(args, "patch_layernorm_embedding", True)
|
| 377 |
+
args.code_layernorm_embedding = getattr(args, "code_layernorm_embedding", True)
|
| 378 |
+
args.entangle_position_embedding = getattr(args, "entangle_position_embedding", False)
|
| 379 |
+
args.disable_entangle = getattr(args, "disable_entangle", False)
|
| 380 |
+
args.sync_bn = getattr(args, "sync_bn", False)
|
| 381 |
+
|
| 382 |
+
args.scale_attn = getattr(args, "scale_attn", False)
|
| 383 |
+
args.scale_fc = getattr(args, "scale_fc", False)
|
| 384 |
+
args.scale_heads = getattr(args, "scale_heads", False)
|
| 385 |
+
args.scale_resids = getattr(args, "scale_resids", False)
|
| 386 |
+
|
| 387 |
+
|
| 388 |
+
@register_model_architecture("ofa", "ofa_base")
|
| 389 |
+
def ofa_base_architecture(args):
|
| 390 |
+
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 768)
|
| 391 |
+
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4 * 768)
|
| 392 |
+
args.encoder_layers = getattr(args, "encoder_layers", 6)
|
| 393 |
+
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 12)
|
| 394 |
+
args.decoder_layers = getattr(args, "decoder_layers", 6)
|
| 395 |
+
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 12)
|
| 396 |
+
args.resnet_type = getattr(args, "resnet_type", "resnet101")
|
| 397 |
+
ofa_large_architecture(args)
|
| 398 |
+
|
| 399 |
+
|
| 400 |
+
@register_model_architecture("ofa", "ofa_huge")
|
| 401 |
+
def ofa_huge_architecture(args):
|
| 402 |
+
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1280)
|
| 403 |
+
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4 * 1280)
|
| 404 |
+
args.encoder_layers = getattr(args, "encoder_layers", 24)
|
| 405 |
+
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16)
|
| 406 |
+
args.decoder_layers = getattr(args, "decoder_layers", 12)
|
| 407 |
+
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16)
|
| 408 |
+
args.resnet_type = getattr(args, "resnet_type", "resnet152")
|
| 409 |
+
ofa_large_architecture(args)
|
| 410 |
+
|
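All three architecture functions fill attributes with `getattr(args, name, default)`, so any value already set on `args` (typically from the command line) takes precedence, and `ofa_base`/`ofa_huge` only pre-set their overrides before delegating to `ofa_large_architecture`. A small illustration of that cascade, independent of fairseq's argument parsing (the import path is assumed from this repo layout):

from argparse import Namespace
from models.ofa.ofa import ofa_base_architecture  # path assumed

args = Namespace(encoder_layers=8)   # pretend this came from the CLI
ofa_base_architecture(args)

print(args.encoder_layers)     # 8   -- the pre-existing value is kept
print(args.encoder_embed_dim)  # 768 -- set by ofa_base before ofa_large runs
print(args.decoder_layers)     # 6   -- ofa_base default
print(args.resnet_type)        # 'resnet101'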
models/ofa/resnet.py
ADDED
|
@@ -0,0 +1,225 @@
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def drop_path(x, drop_prob: float = 0., training: bool = False):
|
| 6 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
| 7 |
+
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
|
| 8 |
+
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
|
| 9 |
+
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
|
| 10 |
+
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
|
| 11 |
+
'survival rate' as the argument.
|
| 12 |
+
"""
|
| 13 |
+
if drop_prob == 0. or not training:
|
| 14 |
+
return x
|
| 15 |
+
keep_prob = 1 - drop_prob
|
| 16 |
+
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
|
| 17 |
+
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
|
| 18 |
+
random_tensor.floor_() # binarize
|
| 19 |
+
output = x.div(keep_prob) * random_tensor
|
| 20 |
+
return output
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class DropPath(nn.Module):
|
| 24 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
| 25 |
+
"""
|
| 26 |
+
def __init__(self, drop_prob=None):
|
| 27 |
+
super(DropPath, self).__init__()
|
| 28 |
+
self.drop_prob = drop_prob
|
| 29 |
+
|
| 30 |
+
def forward(self, x):
|
| 31 |
+
return drop_path(x, self.drop_prob, self.training)
|
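A quick sanity check of the stochastic-depth helper above: each sample's residual branch is either zeroed or rescaled by 1/keep_prob, so its expectation is preserved, and in eval mode the input passes through untouched. This is an illustrative snippet, not part of the committed module; the import path is assumed from this repo layout:

import torch
from models.ofa.resnet import drop_path  # path assumed

x = torch.ones(4, 3)   # 4 samples in the batch
out = drop_path(x, drop_prob=0.5, training=True)

# Each row is either all zeros (path dropped for that sample) or all 2.0
# (scaled by 1 / keep_prob = 2), so E[out] equals x over many draws.
print(out)

# In eval mode (or with drop_prob == 0) the tensor is returned unchanged.
assert torch.equal(drop_path(x, drop_prob=0.5, training=False), x)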
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
|
| 35 |
+
"""3x3 convolution with padding"""
|
| 36 |
+
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
|
| 37 |
+
padding=dilation, groups=groups, bias=False, dilation=dilation)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def conv1x1(in_planes, out_planes, stride=1):
|
| 41 |
+
"""1x1 convolution"""
|
| 42 |
+
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class BasicBlock(nn.Module):
|
| 46 |
+
expansion = 1
|
| 47 |
+
|
| 48 |
+
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
|
| 49 |
+
base_width=64, dilation=1, norm_layer=None):
|
| 50 |
+
super(BasicBlock, self).__init__()
|
| 51 |
+
if norm_layer is None:
|
| 52 |
+
norm_layer = nn.BatchNorm2d
|
| 53 |
+
if groups != 1 or base_width != 64:
|
| 54 |
+
raise ValueError('BasicBlock only supports groups=1 and base_width=64')
|
| 55 |
+
if dilation > 1:
|
| 56 |
+
raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
|
| 57 |
+
# Both self.conv1 and self.downsample layers downsample the input when stride != 1
|
| 58 |
+
self.conv1 = conv3x3(inplanes, planes, stride)
|
| 59 |
+
self.bn1 = norm_layer(planes)
|
| 60 |
+
self.relu = nn.ReLU(inplace=True)
|
| 61 |
+
self.conv2 = conv3x3(planes, planes)
|
| 62 |
+
self.bn2 = norm_layer(planes)
|
| 63 |
+
self.downsample = downsample
|
| 64 |
+
self.stride = stride
|
| 65 |
+
|
| 66 |
+
def forward(self, x):
|
| 67 |
+
assert False
|
| 68 |
+
identity = x
|
| 69 |
+
|
| 70 |
+
out = self.conv1(x)
|
| 71 |
+
out = self.bn1(out)
|
| 72 |
+
out = self.relu(out)
|
| 73 |
+
|
| 74 |
+
out = self.conv2(out)
|
| 75 |
+
out = self.bn2(out)
|
| 76 |
+
|
| 77 |
+
if self.downsample is not None:
|
| 78 |
+
identity = self.downsample(x)
|
| 79 |
+
|
| 80 |
+
out += identity
|
| 81 |
+
out = self.relu(out)
|
| 82 |
+
|
| 83 |
+
return out
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
class Bottleneck(nn.Module):
|
| 87 |
+
# Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
|
| 88 |
+
# while original implementation places the stride at the first 1x1 convolution(self.conv1)
|
| 89 |
+
# according to "Deep residual learning for image recognition" https://arxiv.org/abs/1512.03385.
|
| 90 |
+
# This variant is also known as ResNet V1.5 and improves accuracy according to
|
| 91 |
+
# https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
|
| 92 |
+
|
| 93 |
+
expansion = 4
|
| 94 |
+
|
| 95 |
+
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
|
| 96 |
+
base_width=64, dilation=1, norm_layer=None, drop_path_rate=0.0):
|
| 97 |
+
super(Bottleneck, self).__init__()
|
| 98 |
+
if norm_layer is None:
|
| 99 |
+
norm_layer = nn.BatchNorm2d
|
| 100 |
+
width = int(planes * (base_width / 64.)) * groups
|
| 101 |
+
# Both self.conv2 and self.downsample layers downsample the input when stride != 1
|
| 102 |
+
self.conv1 = conv1x1(inplanes, width)
|
| 103 |
+
self.bn1 = norm_layer(width)
|
| 104 |
+
self.conv2 = conv3x3(width, width, stride, groups, dilation)
|
| 105 |
+
self.bn2 = norm_layer(width)
|
| 106 |
+
self.conv3 = conv1x1(width, planes * self.expansion)
|
| 107 |
+
self.bn3 = norm_layer(planes * self.expansion)
|
| 108 |
+
self.relu = nn.ReLU(inplace=True)
|
| 109 |
+
self.downsample = downsample
|
| 110 |
+
self.stride = stride
|
| 111 |
+
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
|
| 112 |
+
|
| 113 |
+
def forward(self, x):
|
| 114 |
+
identity = x
|
| 115 |
+
|
| 116 |
+
out = self.conv1(x)
|
| 117 |
+
out = self.bn1(out)
|
| 118 |
+
out = self.relu(out)
|
| 119 |
+
|
| 120 |
+
out = self.conv2(out)
|
| 121 |
+
out = self.bn2(out)
|
| 122 |
+
out = self.relu(out)
|
| 123 |
+
|
| 124 |
+
out = self.conv3(out)
|
| 125 |
+
out = self.bn3(out)
|
| 126 |
+
|
| 127 |
+
if self.downsample is not None:
|
| 128 |
+
identity = self.downsample(x)
|
| 129 |
+
|
| 130 |
+
out = identity + self.drop_path(out)
|
| 131 |
+
out = self.relu(out)
|
| 132 |
+
|
| 133 |
+
return out
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
class ResNet(nn.Module):
|
| 137 |
+
|
| 138 |
+
def __init__(self, layers, zero_init_residual=False,
|
| 139 |
+
groups=1, width_per_group=64, replace_stride_with_dilation=None,
|
| 140 |
+
norm_layer=None, drop_path_rate=0.0):
|
| 141 |
+
super(ResNet, self).__init__()
|
| 142 |
+
if norm_layer is None:
|
| 143 |
+
norm_layer = nn.BatchNorm2d
|
| 144 |
+
self._norm_layer = norm_layer
|
| 145 |
+
|
| 146 |
+
self.inplanes = 64
|
| 147 |
+
self.dilation = 1
|
| 148 |
+
if replace_stride_with_dilation is None:
|
| 149 |
+
# each element in the tuple indicates if we should replace
|
| 150 |
+
# the 2x2 stride with a dilated convolution instead
|
| 151 |
+
replace_stride_with_dilation = [False, False, False]
|
| 152 |
+
if len(replace_stride_with_dilation) != 3:
|
| 153 |
+
raise ValueError("replace_stride_with_dilation should be None "
|
| 154 |
+
"or a 3-element tuple, got {}".format(replace_stride_with_dilation))
|
| 155 |
+
self.groups = groups
|
| 156 |
+
self.base_width = width_per_group
|
| 157 |
+
self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
|
| 158 |
+
bias=False)
|
| 159 |
+
self.bn1 = norm_layer(self.inplanes)
|
| 160 |
+
self.relu = nn.ReLU(inplace=True)
|
| 161 |
+
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
| 162 |
+
self.layer1 = self._make_layer(Bottleneck, 64, layers[0], drop_path_rate=drop_path_rate)
|
| 163 |
+
self.layer2 = self._make_layer(Bottleneck, 128, layers[1], stride=2,
|
| 164 |
+
dilate=replace_stride_with_dilation[0], drop_path_rate=drop_path_rate)
|
| 165 |
+
self.layer3 = self._make_layer(Bottleneck, 256, layers[2], stride=2,
|
| 166 |
+
dilate=replace_stride_with_dilation[1], drop_path_rate=drop_path_rate)
|
| 167 |
+
|
| 168 |
+
for m in self.modules():
|
| 169 |
+
if isinstance(m, nn.Conv2d):
|
| 170 |
+
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
| 171 |
+
elif isinstance(m, (nn.SyncBatchNorm, nn.BatchNorm2d, nn.GroupNorm)):
|
| 172 |
+
nn.init.constant_(m.weight, 1)
|
| 173 |
+
nn.init.constant_(m.bias, 0)
|
| 174 |
+
|
| 175 |
+
# Zero-initialize the last BN in each residual branch,
|
| 176 |
+
# so that the residual branch starts with zeros, and each residual block behaves like an identity.
|
| 177 |
+
# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
|
| 178 |
+
if zero_init_residual:
|
| 179 |
+
for m in self.modules():
|
| 180 |
+
if isinstance(m, Bottleneck):
|
| 181 |
+
nn.init.constant_(m.bn3.weight, 0)
|
| 182 |
+
elif isinstance(m, BasicBlock):
|
| 183 |
+
nn.init.constant_(m.bn2.weight, 0)
|
| 184 |
+
|
| 185 |
+
def _make_layer(self, block, planes, blocks, stride=1, dilate=False, drop_path_rate=0.0):
|
| 186 |
+
norm_layer = self._norm_layer
|
| 187 |
+
downsample = None
|
| 188 |
+
previous_dilation = self.dilation
|
| 189 |
+
if dilate:
|
| 190 |
+
self.dilation *= stride
|
| 191 |
+
stride = 1
|
| 192 |
+
if stride != 1 or self.inplanes != planes * block.expansion:
|
| 193 |
+
downsample = nn.Sequential(
|
| 194 |
+
conv1x1(self.inplanes, planes * block.expansion, stride),
|
| 195 |
+
norm_layer(planes * block.expansion),
|
| 196 |
+
)
|
| 197 |
+
|
| 198 |
+
layers = []
|
| 199 |
+
layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
|
| 200 |
+
self.base_width, previous_dilation, norm_layer))
|
| 201 |
+
self.inplanes = planes * block.expansion
|
| 202 |
+
|
| 203 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, blocks)]
|
| 204 |
+
for i in range(1, blocks):
|
| 205 |
+
layers.append(block(self.inplanes, planes, groups=self.groups,
|
| 206 |
+
base_width=self.base_width, dilation=self.dilation,
|
| 207 |
+
norm_layer=norm_layer, drop_path_rate=dpr[i]))
|
| 208 |
+
|
| 209 |
+
return nn.Sequential(*layers)
|
| 210 |
+
|
| 211 |
+
def _forward_impl(self, x):
|
| 212 |
+
# See note [TorchScript super()]
|
| 213 |
+
x = self.conv1(x)
|
| 214 |
+
x = self.bn1(x)
|
| 215 |
+
x = self.relu(x)
|
| 216 |
+
x = self.maxpool(x)
|
| 217 |
+
|
| 218 |
+
x = self.layer1(x)
|
| 219 |
+
x = self.layer2(x)
|
| 220 |
+
x = self.layer3(x)
|
| 221 |
+
|
| 222 |
+
return x
|
| 223 |
+
|
| 224 |
+
def forward(self, x):
|
| 225 |
+
return self._forward_impl(x)
|
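Note that this ResNet stops after `layer3`: there is no `layer4`, pooling, or fully connected head, so it acts as a feature extractor with a total stride of 16 and 1024 output channels. A rough shape check; the `[3, 4, 23]` block counts mirror a resnet101-style configuration and are an assumption here, not values taken from the training scripts:

import torch
from models.ofa.resnet import ResNet  # path assumed from this repo layout

backbone = ResNet(layers=[3, 4, 23])   # block counts for layer1/2/3
images = torch.randn(2, 3, 224, 224)
feats = backbone(images)

# conv1 (/2) -> maxpool (/2) -> layer2 (/2) -> layer3 (/2) gives stride 16;
# Bottleneck expansion 4 on 256 planes gives 1024 channels.
print(feats.shape)   # torch.Size([2, 1024, 14, 14])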
models/ofa/unify_multihead_attention.py
ADDED
|
@@ -0,0 +1,518 @@
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import math
|
| 7 |
+
from typing import Dict, Optional, Tuple
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
from fairseq import utils
|
| 12 |
+
from fairseq.incremental_decoding_utils import with_incremental_state
|
| 13 |
+
from fairseq.modules.fairseq_dropout import FairseqDropout
|
| 14 |
+
from fairseq.modules.quant_noise import quant_noise
|
| 15 |
+
from torch import Tensor, nn
|
| 16 |
+
from torch.nn import Parameter
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
@with_incremental_state
|
| 20 |
+
class MultiheadAttention(nn.Module):
|
| 21 |
+
"""Multi-headed attention.
|
| 22 |
+
|
| 23 |
+
See "Attention Is All You Need" for more details.
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
def __init__(
|
| 27 |
+
self,
|
| 28 |
+
embed_dim,
|
| 29 |
+
num_heads,
|
| 30 |
+
kdim=None,
|
| 31 |
+
vdim=None,
|
| 32 |
+
dropout=0.0,
|
| 33 |
+
bias=True,
|
| 34 |
+
add_bias_kv=False,
|
| 35 |
+
add_zero_attn=False,
|
| 36 |
+
self_attention=False,
|
| 37 |
+
encoder_decoder_attention=False,
|
| 38 |
+
q_noise=0.0,
|
| 39 |
+
qn_block_size=8,
|
| 40 |
+
scale_factor=2,
|
| 41 |
+
scale_heads=False
|
| 42 |
+
):
|
| 43 |
+
super().__init__()
|
| 44 |
+
self.embed_dim = embed_dim
|
| 45 |
+
self.kdim = kdim if kdim is not None else embed_dim
|
| 46 |
+
self.vdim = vdim if vdim is not None else embed_dim
|
| 47 |
+
self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
|
| 48 |
+
|
| 49 |
+
self.num_heads = num_heads
|
| 50 |
+
self.dropout_module = FairseqDropout(
|
| 51 |
+
dropout, module_name=self.__class__.__name__
|
| 52 |
+
)
|
| 53 |
+
|
| 54 |
+
self.head_dim = embed_dim // num_heads
|
| 55 |
+
assert (
|
| 56 |
+
self.head_dim * num_heads == self.embed_dim
|
| 57 |
+
), "embed_dim must be divisible by num_heads"
|
| 58 |
+
self.scaling = float(self.head_dim * scale_factor) ** -0.5
|
| 59 |
+
|
| 60 |
+
self.self_attention = self_attention
|
| 61 |
+
self.encoder_decoder_attention = encoder_decoder_attention
|
| 62 |
+
self.c_attn = nn.Parameter(torch.ones((self.num_heads,)), requires_grad=True) if scale_heads else None
|
| 63 |
+
|
| 64 |
+
assert not self.self_attention or self.qkv_same_dim, (
|
| 65 |
+
"Self-attention requires query, key and " "value to be of the same size"
|
| 66 |
+
)
|
| 67 |
+
|
| 68 |
+
self.k_proj = quant_noise(
|
| 69 |
+
nn.Linear(self.kdim, embed_dim, bias=bias), q_noise, qn_block_size
|
| 70 |
+
)
|
| 71 |
+
self.v_proj = quant_noise(
|
| 72 |
+
nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size
|
| 73 |
+
)
|
| 74 |
+
self.q_proj = quant_noise(
|
| 75 |
+
nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
|
| 76 |
+
)
|
| 77 |
+
|
| 78 |
+
self.out_proj = quant_noise(
|
| 79 |
+
nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
if add_bias_kv:
|
| 83 |
+
self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
|
| 84 |
+
self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
|
| 85 |
+
else:
|
| 86 |
+
self.bias_k = self.bias_v = None
|
| 87 |
+
|
| 88 |
+
self.add_zero_attn = add_zero_attn
|
| 89 |
+
|
| 90 |
+
self.reset_parameters()
|
| 91 |
+
|
| 92 |
+
self.onnx_trace = False
|
| 93 |
+
|
| 94 |
+
def prepare_for_onnx_export_(self):
|
| 95 |
+
self.onnx_trace = True
|
| 96 |
+
|
| 97 |
+
def reset_parameters(self):
|
| 98 |
+
if self.qkv_same_dim:
|
| 99 |
+
# Empirically observed the convergence to be much better with
|
| 100 |
+
# the scaled initialization
|
| 101 |
+
nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
|
| 102 |
+
nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
|
| 103 |
+
nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
|
| 104 |
+
else:
|
| 105 |
+
nn.init.xavier_uniform_(self.k_proj.weight)
|
| 106 |
+
nn.init.xavier_uniform_(self.v_proj.weight)
|
| 107 |
+
nn.init.xavier_uniform_(self.q_proj.weight)
|
| 108 |
+
|
| 109 |
+
nn.init.xavier_uniform_(self.out_proj.weight)
|
| 110 |
+
if self.out_proj.bias is not None:
|
| 111 |
+
nn.init.constant_(self.out_proj.bias, 0.0)
|
| 112 |
+
if self.bias_k is not None:
|
| 113 |
+
nn.init.xavier_normal_(self.bias_k)
|
| 114 |
+
if self.bias_v is not None:
|
| 115 |
+
nn.init.xavier_normal_(self.bias_v)
|
| 116 |
+
|
| 117 |
+
def forward(
|
| 118 |
+
self,
|
| 119 |
+
query,
|
| 120 |
+
key: Optional[Tensor],
|
| 121 |
+
value: Optional[Tensor],
|
| 122 |
+
key_padding_mask: Optional[Tensor] = None,
|
| 123 |
+
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
|
| 124 |
+
need_weights: bool = True,
|
| 125 |
+
static_kv: bool = False,
|
| 126 |
+
attn_mask: Optional[Tensor] = None,
|
| 127 |
+
self_attn_mask: Optional[Tensor] = None,
|
| 128 |
+
before_softmax: bool = False,
|
| 129 |
+
need_head_weights: bool = False,
|
| 130 |
+
attn_bias: Optional[Tensor] = None
|
| 131 |
+
) -> Tuple[Tensor, Optional[Tensor]]:
|
| 132 |
+
"""Input shape: Time x Batch x Channel
|
| 133 |
+
|
| 134 |
+
Args:
|
| 135 |
+
key_padding_mask (ByteTensor, optional): mask to exclude
|
| 136 |
+
keys that are pads, of shape `(batch, src_len)`, where
|
| 137 |
+
padding elements are indicated by 1s.
|
| 138 |
+
need_weights (bool, optional): return the attention weights,
|
| 139 |
+
averaged over heads (default: True).
|
| 140 |
+
attn_mask (ByteTensor, optional): typically used to
|
| 141 |
+
implement causal attention, where the mask prevents the
|
| 142 |
+
attention from looking forward in time (default: None).
|
| 143 |
+
before_softmax (bool, optional): return the raw attention
|
| 144 |
+
weights and values before the attention softmax.
|
| 145 |
+
need_head_weights (bool, optional): return the attention
|
| 146 |
+
weights for each head. Implies *need_weights*. Default:
|
| 147 |
+
return the average attention weights over all heads.
|
| 148 |
+
"""
|
| 149 |
+
if need_head_weights:
|
| 150 |
+
need_weights = True
|
| 151 |
+
|
| 152 |
+
is_tpu = query.device.type == "xla"
|
| 153 |
+
|
| 154 |
+
tgt_len, bsz, embed_dim = query.size()
|
| 155 |
+
src_len = tgt_len
|
| 156 |
+
assert embed_dim == self.embed_dim, f"query dim {embed_dim} != {self.embed_dim}"
|
| 157 |
+
assert list(query.size()) == [tgt_len, bsz, embed_dim]
|
| 158 |
+
if key is not None:
|
| 159 |
+
src_len, key_bsz, _ = key.size()
|
| 160 |
+
if not torch.jit.is_scripting():
|
| 161 |
+
assert key_bsz == bsz
|
| 162 |
+
assert value is not None
|
| 163 |
+
assert (src_len, bsz) == value.shape[:2]
|
| 164 |
+
|
| 165 |
+
if (
|
| 166 |
+
not self.onnx_trace
|
| 167 |
+
and not is_tpu # don't use PyTorch version on TPUs
|
| 168 |
+
and incremental_state is None
|
| 169 |
+
and not static_kv
|
| 170 |
+
# A workaround for quantization to work. Otherwise JIT compilation
|
| 171 |
+
# treats bias in linear module as method.
|
| 172 |
+
and not torch.jit.is_scripting()
|
| 173 |
+
and self_attn_mask is None
|
| 174 |
+
and attn_bias is None
|
| 175 |
+
):
|
| 176 |
+
assert key is not None and value is not None
|
| 177 |
+
return F.multi_head_attention_forward(
|
| 178 |
+
query,
|
| 179 |
+
key,
|
| 180 |
+
value,
|
| 181 |
+
self.embed_dim,
|
| 182 |
+
self.num_heads,
|
| 183 |
+
torch.empty([0]),
|
| 184 |
+
torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
|
| 185 |
+
self.bias_k,
|
| 186 |
+
self.bias_v,
|
| 187 |
+
self.add_zero_attn,
|
| 188 |
+
self.dropout_module.p,
|
| 189 |
+
self.out_proj.weight,
|
| 190 |
+
self.out_proj.bias,
|
| 191 |
+
self.training or self.dropout_module.apply_during_inference,
|
| 192 |
+
key_padding_mask,
|
| 193 |
+
need_weights,
|
| 194 |
+
attn_mask,
|
| 195 |
+
use_separate_proj_weight=True,
|
| 196 |
+
q_proj_weight=self.q_proj.weight,
|
| 197 |
+
k_proj_weight=self.k_proj.weight,
|
| 198 |
+
v_proj_weight=self.v_proj.weight,
|
| 199 |
+
)
|
| 200 |
+
|
| 201 |
+
if incremental_state is not None:
|
| 202 |
+
saved_state = self._get_input_buffer(incremental_state)
|
| 203 |
+
if saved_state is not None and "prev_key" in saved_state:
|
| 204 |
+
# previous time steps are cached - no need to recompute
|
| 205 |
+
# key and value if they are static
|
| 206 |
+
if static_kv:
|
| 207 |
+
assert self.encoder_decoder_attention and not self.self_attention
|
| 208 |
+
key = value = None
|
| 209 |
+
else:
|
| 210 |
+
saved_state = None
|
| 211 |
+
|
| 212 |
+
if self.self_attention and self_attn_mask is None:
|
| 213 |
+
q = self.q_proj(query)
|
| 214 |
+
k = self.k_proj(query)
|
| 215 |
+
v = self.v_proj(query)
|
| 216 |
+
elif self.encoder_decoder_attention:
|
| 217 |
+
# encoder-decoder attention
|
| 218 |
+
q = self.q_proj(query)
|
| 219 |
+
if key is None:
|
| 220 |
+
assert value is None
|
| 221 |
+
k = v = None
|
| 222 |
+
else:
|
| 223 |
+
k = self.k_proj(key)
|
| 224 |
+
v = self.v_proj(key)
|
| 225 |
+
|
| 226 |
+
else:
|
| 227 |
+
assert key is not None and value is not None
|
| 228 |
+
q = self.q_proj(query)
|
| 229 |
+
k = self.k_proj(key)
|
| 230 |
+
v = self.v_proj(value)
|
| 231 |
+
q *= self.scaling
|
| 232 |
+
|
| 233 |
+
if self.bias_k is not None:
|
| 234 |
+
assert self.bias_v is not None
|
| 235 |
+
k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
|
| 236 |
+
v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
|
| 237 |
+
if attn_mask is not None:
|
| 238 |
+
attn_mask = torch.cat(
|
| 239 |
+
[attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
|
| 240 |
+
)
|
| 241 |
+
if key_padding_mask is not None:
|
| 242 |
+
key_padding_mask = torch.cat(
|
| 243 |
+
[
|
| 244 |
+
key_padding_mask,
|
| 245 |
+
key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
|
| 246 |
+
],
|
| 247 |
+
dim=1,
|
| 248 |
+
)
|
| 249 |
+
|
| 250 |
+
q = (
|
| 251 |
+
q.contiguous()
|
| 252 |
+
.view(tgt_len, bsz * self.num_heads, self.head_dim)
|
| 253 |
+
.transpose(0, 1)
|
| 254 |
+
)
|
| 255 |
+
if k is not None:
|
| 256 |
+
k = (
|
| 257 |
+
k.contiguous()
|
| 258 |
+
.view(-1, bsz * self.num_heads, self.head_dim)
|
| 259 |
+
.transpose(0, 1)
|
| 260 |
+
)
|
| 261 |
+
if v is not None:
|
| 262 |
+
v = (
|
| 263 |
+
v.contiguous()
|
| 264 |
+
.view(-1, bsz * self.num_heads, self.head_dim)
|
| 265 |
+
.transpose(0, 1)
|
| 266 |
+
)
|
| 267 |
+
|
| 268 |
+
if saved_state is not None:
|
| 269 |
+
# saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
|
| 270 |
+
if "prev_key" in saved_state:
|
| 271 |
+
_prev_key = saved_state["prev_key"]
|
| 272 |
+
assert _prev_key is not None
|
| 273 |
+
prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
|
| 274 |
+
if static_kv:
|
| 275 |
+
k = prev_key
|
| 276 |
+
else:
|
| 277 |
+
assert k is not None
|
| 278 |
+
k = torch.cat([prev_key, k], dim=1)
|
| 279 |
+
src_len = k.size(1)
|
| 280 |
+
if "prev_value" in saved_state:
|
| 281 |
+
_prev_value = saved_state["prev_value"]
|
| 282 |
+
assert _prev_value is not None
|
| 283 |
+
prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
|
| 284 |
+
if static_kv:
|
| 285 |
+
v = prev_value
|
| 286 |
+
else:
|
| 287 |
+
assert v is not None
|
| 288 |
+
v = torch.cat([prev_value, v], dim=1)
|
| 289 |
+
prev_key_padding_mask: Optional[Tensor] = None
|
| 290 |
+
if "prev_key_padding_mask" in saved_state:
|
| 291 |
+
prev_key_padding_mask = saved_state["prev_key_padding_mask"]
|
| 292 |
+
assert k is not None and v is not None
|
| 293 |
+
key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
|
| 294 |
+
key_padding_mask=key_padding_mask,
|
| 295 |
+
prev_key_padding_mask=prev_key_padding_mask,
|
| 296 |
+
batch_size=bsz,
|
| 297 |
+
src_len=k.size(1),
|
| 298 |
+
static_kv=static_kv,
|
| 299 |
+
)
|
| 300 |
+
|
| 301 |
+
saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
|
| 302 |
+
saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
|
| 303 |
+
saved_state["prev_key_padding_mask"] = key_padding_mask
|
| 304 |
+
# In this branch incremental_state is never None
|
| 305 |
+
assert incremental_state is not None
|
| 306 |
+
incremental_state = self._set_input_buffer(incremental_state, saved_state)
|
| 307 |
+
assert k is not None
|
| 308 |
+
assert k.size(1) == src_len
|
| 309 |
+
|
| 310 |
+
# This is part of a workaround to get around fork/join parallelism
|
| 311 |
+
# not supporting Optional types.
|
| 312 |
+
if key_padding_mask is not None and key_padding_mask.dim() == 0:
|
| 313 |
+
key_padding_mask = None
|
| 314 |
+
|
| 315 |
+
if key_padding_mask is not None:
|
| 316 |
+
assert key_padding_mask.size(0) == bsz
|
| 317 |
+
assert key_padding_mask.size(1) == src_len
|
| 318 |
+
|
| 319 |
+
if self.add_zero_attn:
|
| 320 |
+
assert v is not None
|
| 321 |
+
src_len += 1
|
| 322 |
+
k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
|
| 323 |
+
v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
|
| 324 |
+
if attn_mask is not None:
|
| 325 |
+
attn_mask = torch.cat(
|
| 326 |
+
[attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
|
| 327 |
+
)
|
| 328 |
+
if key_padding_mask is not None:
|
| 329 |
+
key_padding_mask = torch.cat(
|
| 330 |
+
[
|
| 331 |
+
key_padding_mask,
|
| 332 |
+
torch.zeros(key_padding_mask.size(0), 1).type_as(
|
| 333 |
+
key_padding_mask
|
| 334 |
+
),
|
| 335 |
+
],
|
| 336 |
+
dim=1,
|
| 337 |
+
)
|
| 338 |
+
|
| 339 |
+
attn_weights = torch.bmm(q, k.transpose(1, 2))
|
| 340 |
+
attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
|
| 341 |
+
|
| 342 |
+
assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
|
| 343 |
+
|
| 344 |
+
if attn_bias is not None:
|
| 345 |
+
attn_weights += attn_bias
|
| 346 |
+
|
| 347 |
+
if attn_mask is not None:
|
| 348 |
+
attn_mask = attn_mask.unsqueeze(0)
|
| 349 |
+
if self.onnx_trace:
|
| 350 |
+
attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1)
|
| 351 |
+
attn_weights += attn_mask
|
| 352 |
+
|
| 353 |
+
if self_attn_mask is not None:
|
| 354 |
+
self_attn_mask = self_attn_mask.unsqueeze(1).expand(bsz, self.num_heads, tgt_len, src_len)
|
| 355 |
+
attn_weights += self_attn_mask.contiguous().view(bsz * self.num_heads, tgt_len, src_len)
|
| 356 |
+
|
| 357 |
+
if key_padding_mask is not None:
|
| 358 |
+
# don't attend to padding symbols
|
| 359 |
+
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
| 360 |
+
if not is_tpu:
|
| 361 |
+
attn_weights = attn_weights.masked_fill(
|
| 362 |
+
key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
|
| 363 |
+
float("-inf"),
|
| 364 |
+
)
|
| 365 |
+
else:
|
| 366 |
+
attn_weights = attn_weights.transpose(0, 2)
|
| 367 |
+
attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf"))
|
| 368 |
+
attn_weights = attn_weights.transpose(0, 2)
|
| 369 |
+
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
| 370 |
+
|
| 371 |
+
if before_softmax:
|
| 372 |
+
return attn_weights, v
|
| 373 |
+
|
| 374 |
+
attn_weights_float = utils.softmax(
|
| 375 |
+
attn_weights, dim=-1, onnx_trace=self.onnx_trace
|
| 376 |
+
)
|
| 377 |
+
attn_weights = attn_weights_float.type_as(attn_weights)
|
| 378 |
+
attn_probs = self.dropout_module(attn_weights)
|
| 379 |
+
|
| 380 |
+
assert v is not None
|
| 381 |
+
attn = torch.bmm(attn_probs, v)
|
| 382 |
+
assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
|
| 383 |
+
if self.onnx_trace and attn.size(1) == 1:
|
| 384 |
+
# when ONNX tracing a single decoder step (sequence length == 1)
|
| 385 |
+
# the transpose is a no-op copy before view, thus unnecessary
|
| 386 |
+
attn = attn.contiguous().view(tgt_len, bsz, embed_dim)
|
| 387 |
+
else:
|
| 388 |
+
attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
|
| 389 |
+
if self.c_attn is not None:
|
| 390 |
+
attn = attn.view(tgt_len, bsz, self.num_heads, self.head_dim)
|
| 391 |
+
attn = torch.einsum('tbhd,h->tbhd', attn, self.c_attn)
|
| 392 |
+
attn = attn.reshape(tgt_len, bsz, self.embed_dim)
|
| 393 |
+
attn = self.out_proj(attn)
|
| 394 |
+
attn_weights: Optional[Tensor] = None
|
| 395 |
+
if need_weights:
|
| 396 |
+
attn_weights = attn_weights_float.view(
|
| 397 |
+
bsz, self.num_heads, tgt_len, src_len
|
| 398 |
+
).transpose(1, 0)
|
| 399 |
+
if not need_head_weights:
|
| 400 |
+
# average attention weights over heads
|
| 401 |
+
attn_weights = attn_weights.mean(dim=0)
|
| 402 |
+
|
| 403 |
+
return attn, attn_weights
|
| 404 |
+
|
| 405 |
+
@staticmethod
|
| 406 |
+
def _append_prev_key_padding_mask(
|
| 407 |
+
key_padding_mask: Optional[Tensor],
|
| 408 |
+
prev_key_padding_mask: Optional[Tensor],
|
| 409 |
+
batch_size: int,
|
| 410 |
+
src_len: int,
|
| 411 |
+
static_kv: bool,
|
| 412 |
+
) -> Optional[Tensor]:
|
| 413 |
+
# saved key padding masks have shape (bsz, seq_len)
|
| 414 |
+
if prev_key_padding_mask is not None and static_kv:
|
| 415 |
+
new_key_padding_mask = prev_key_padding_mask
|
| 416 |
+
elif prev_key_padding_mask is not None and key_padding_mask is not None:
|
| 417 |
+
new_key_padding_mask = torch.cat(
|
| 418 |
+
[prev_key_padding_mask.float(), key_padding_mask.float()], dim=1
|
| 419 |
+
)
|
| 420 |
+
# During incremental decoding, as the padding token enters and
|
| 421 |
+
# leaves the frame, there will be a time when prev or current
|
| 422 |
+
# is None
|
| 423 |
+
elif prev_key_padding_mask is not None:
|
| 424 |
+
if src_len > prev_key_padding_mask.size(1):
|
| 425 |
+
filler = torch.zeros(
|
| 426 |
+
(batch_size, src_len - prev_key_padding_mask.size(1)),
|
| 427 |
+
device=prev_key_padding_mask.device,
|
| 428 |
+
)
|
| 429 |
+
new_key_padding_mask = torch.cat(
|
| 430 |
+
[prev_key_padding_mask.float(), filler.float()], dim=1
|
| 431 |
+
)
|
| 432 |
+
else:
|
| 433 |
+
new_key_padding_mask = prev_key_padding_mask.float()
|
| 434 |
+
elif key_padding_mask is not None:
|
| 435 |
+
if src_len > key_padding_mask.size(1):
|
| 436 |
+
filler = torch.zeros(
|
| 437 |
+
(batch_size, src_len - key_padding_mask.size(1)),
|
| 438 |
+
device=key_padding_mask.device,
|
| 439 |
+
)
|
| 440 |
+
new_key_padding_mask = torch.cat(
|
| 441 |
+
[filler.float(), key_padding_mask.float()], dim=1
|
| 442 |
+
)
|
| 443 |
+
else:
|
| 444 |
+
new_key_padding_mask = key_padding_mask.float()
|
| 445 |
+
else:
|
| 446 |
+
new_key_padding_mask = prev_key_padding_mask
|
| 447 |
+
return new_key_padding_mask
|
| 448 |
+
|
| 449 |
+
@torch.jit.export
|
| 450 |
+
def reorder_incremental_state(
|
| 451 |
+
self,
|
| 452 |
+
incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
|
| 453 |
+
new_order: Tensor,
|
| 454 |
+
):
|
| 455 |
+
"""Reorder buffered internal state (for incremental generation)."""
|
| 456 |
+
input_buffer = self._get_input_buffer(incremental_state)
|
| 457 |
+
if input_buffer is not None:
|
| 458 |
+
for k in input_buffer.keys():
|
| 459 |
+
input_buffer_k = input_buffer[k]
|
| 460 |
+
if input_buffer_k is not None:
|
| 461 |
+
if self.encoder_decoder_attention and input_buffer_k.size(
|
| 462 |
+
0
|
| 463 |
+
) == new_order.size(0):
|
| 464 |
+
break
|
| 465 |
+
input_buffer[k] = input_buffer_k.index_select(0, new_order)
|
| 466 |
+
incremental_state = self._set_input_buffer(incremental_state, input_buffer)
|
| 467 |
+
return incremental_state
|
| 468 |
+
|
| 469 |
+
def _get_input_buffer(
|
| 470 |
+
self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
|
| 471 |
+
) -> Dict[str, Optional[Tensor]]:
|
| 472 |
+
result = self.get_incremental_state(incremental_state, "attn_state")
|
| 473 |
+
if result is not None:
|
| 474 |
+
return result
|
| 475 |
+
else:
|
| 476 |
+
empty_result: Dict[str, Optional[Tensor]] = {}
|
| 477 |
+
return empty_result
|
| 478 |
+
|
| 479 |
+
def _set_input_buffer(
|
| 480 |
+
self,
|
| 481 |
+
incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
|
| 482 |
+
buffer: Dict[str, Optional[Tensor]],
|
| 483 |
+
):
|
| 484 |
+
return self.set_incremental_state(incremental_state, "attn_state", buffer)
|
| 485 |
+
|
| 486 |
+
def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int):
|
| 487 |
+
return attn_weights
|
| 488 |
+
|
| 489 |
+
def upgrade_state_dict_named(self, state_dict, name):
|
| 490 |
+
prefix = name + "." if name != "" else ""
|
| 491 |
+
items_to_add = {}
|
| 492 |
+
keys_to_remove = []
|
| 493 |
+
for k in state_dict.keys():
|
| 494 |
+
if k.endswith(prefix + "in_proj_weight"):
|
| 495 |
+
# in_proj_weight used to be q + k + v with same dimensions
|
| 496 |
+
dim = int(state_dict[k].shape[0] / 3)
|
| 497 |
+
items_to_add[prefix + "q_proj.weight"] = state_dict[k][:dim]
|
| 498 |
+
items_to_add[prefix + "k_proj.weight"] = state_dict[k][dim : 2 * dim]
|
| 499 |
+
items_to_add[prefix + "v_proj.weight"] = state_dict[k][2 * dim :]
|
| 500 |
+
|
| 501 |
+
keys_to_remove.append(k)
|
| 502 |
+
|
| 503 |
+
k_bias = prefix + "in_proj_bias"
|
| 504 |
+
if k_bias in state_dict.keys():
|
| 505 |
+
dim = int(state_dict[k].shape[0] / 3)
|
| 506 |
+
items_to_add[prefix + "q_proj.bias"] = state_dict[k_bias][:dim]
|
| 507 |
+
items_to_add[prefix + "k_proj.bias"] = state_dict[k_bias][
|
| 508 |
+
dim : 2 * dim
|
| 509 |
+
]
|
| 510 |
+
items_to_add[prefix + "v_proj.bias"] = state_dict[k_bias][2 * dim :]
|
| 511 |
+
|
| 512 |
+
keys_to_remove.append(prefix + "in_proj_bias")
|
| 513 |
+
|
| 514 |
+
for k in keys_to_remove:
|
| 515 |
+
del state_dict[k]
|
| 516 |
+
|
| 517 |
+
for key, value in items_to_add.items():
|
| 518 |
+
state_dict[key] = value
|
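`upgrade_state_dict_named` above splits legacy fused `in_proj_weight`/`in_proj_bias` entries into separate q/k/v projections. The slicing convention in isolation, with a toy tensor standing in for a real checkpoint key:

import torch

embed_dim = 8
fused = torch.randn(3 * embed_dim, embed_dim)   # old-style in_proj_weight

dim = fused.shape[0] // 3
q_w, k_w, v_w = fused[:dim], fused[dim:2 * dim], fused[2 * dim:]

# Row order assumed by the upgrade code: query first, then key, then value.
assert torch.equal(torch.cat([q_w, k_w, v_w]), fused)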
models/ofa/unify_transformer.py
ADDED
|
@@ -0,0 +1,1510 @@
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import math
|
| 7 |
+
import random
|
| 8 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
from fairseq import utils
|
| 14 |
+
from fairseq.distributed import fsdp_wrap
|
| 15 |
+
from fairseq.models import (
|
| 16 |
+
FairseqEncoder,
|
| 17 |
+
FairseqEncoderDecoderModel,
|
| 18 |
+
FairseqIncrementalDecoder,
|
| 19 |
+
register_model,
|
| 20 |
+
register_model_architecture,
|
| 21 |
+
)
|
| 22 |
+
from fairseq.modules import (
|
| 23 |
+
AdaptiveSoftmax,
|
| 24 |
+
BaseLayer,
|
| 25 |
+
FairseqDropout,
|
| 26 |
+
LayerDropModuleList,
|
| 27 |
+
LayerNorm,
|
| 28 |
+
SinusoidalPositionalEmbedding,
|
| 29 |
+
GradMultiply
|
| 30 |
+
)
|
| 31 |
+
from fairseq.modules.checkpoint_activations import checkpoint_wrapper
|
| 32 |
+
from fairseq.modules.quant_noise import quant_noise as apply_quant_noise_
|
| 33 |
+
from torch import Tensor
|
| 34 |
+
|
| 35 |
+
from .unify_transformer_layer import TransformerEncoderLayer, TransformerDecoderLayer
|
| 36 |
+
from .resnet import ResNet
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
DEFAULT_MAX_SOURCE_POSITIONS = 1024
|
| 40 |
+
DEFAULT_MAX_TARGET_POSITIONS = 1024
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
DEFAULT_MIN_PARAMS_TO_WRAP = int(1e8)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def BatchNorm2d(out_chan, momentum=0.1, eps=1e-3):
|
| 47 |
+
return nn.SyncBatchNorm.convert_sync_batchnorm(
|
| 48 |
+
nn.BatchNorm2d(out_chan, momentum=momentum, eps=eps)
|
| 49 |
+
)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def make_token_bucket_position(bucket_size, max_position=DEFAULT_MAX_SOURCE_POSITIONS):
|
| 53 |
+
context_pos = torch.arange(max_position, dtype=torch.long)[:, None]
|
| 54 |
+
memory_pos = torch.arange(max_position, dtype=torch.long)[None, :]
|
| 55 |
+
relative_pos = context_pos - memory_pos
|
| 56 |
+
sign = torch.sign(relative_pos)
|
| 57 |
+
mid = bucket_size // 2
|
| 58 |
+
abs_pos = torch.where((relative_pos<mid) & (relative_pos > -mid), mid-1, torch.abs(relative_pos))
|
| 59 |
+
log_pos = torch.ceil(torch.log(abs_pos/mid)/math.log((max_position-1)/mid) * (mid-1)) + mid
|
| 60 |
+
log_pos = log_pos.int()
|
| 61 |
+
bucket_pos = torch.where(abs_pos.le(mid), relative_pos, log_pos*sign).long()
|
| 62 |
+
return bucket_pos + bucket_size - 1
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def make_image_bucket_position(bucket_size, num_relative_distance):
|
| 66 |
+
coords_h = torch.arange(bucket_size)
|
| 67 |
+
coords_w = torch.arange(bucket_size)
|
| 68 |
+
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
|
| 69 |
+
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
| 70 |
+
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
|
| 71 |
+
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
|
| 72 |
+
relative_coords[:, :, 0] += bucket_size - 1 # shift to start from 0
|
| 73 |
+
relative_coords[:, :, 1] += bucket_size - 1
|
| 74 |
+
relative_coords[:, :, 0] *= 2 * bucket_size - 1
|
| 75 |
+
relative_position_index = torch.zeros(size=(bucket_size * bucket_size + 1,) * 2, dtype=relative_coords.dtype)
|
| 76 |
+
relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
|
| 77 |
+
relative_position_index[0, 0:] = num_relative_distance - 3
|
| 78 |
+
relative_position_index[0:, 0] = num_relative_distance - 2
|
| 79 |
+
relative_position_index[0, 0] = num_relative_distance - 1
|
| 80 |
+
return relative_position_index
|
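Both helpers above precompute relative-position bucket indices from which relative attention bias embeddings are later gathered: nearby offsets keep their own bucket while distant offsets are merged on a logarithmic scale. A small probe of the token variant (sizes are arbitrary; the import path is assumed from this repo layout):

from models.ofa.unify_transformer import make_token_bucket_position  # path assumed

buckets = make_token_bucket_position(bucket_size=8, max_position=32)

print(buckets.shape)   # torch.Size([32, 32]); one bucket id per (query, key) offset
print(buckets.min().item(), buckets.max().item())
# Indices stay within [0, 2 * bucket_size - 2], i.e. at most 15 distinct buckets here.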
| 81 |
+
|
| 82 |
+
|
| 83 |
+
@register_model("unify_transformer")
|
| 84 |
+
class TransformerModel(FairseqEncoderDecoderModel):
|
| 85 |
+
"""
|
| 86 |
+
Transformer model from `"Attention Is All You Need" (Vaswani, et al, 2017)
|
| 87 |
+
<https://arxiv.org/abs/1706.03762>`_.
|
| 88 |
+
|
| 89 |
+
Args:
|
| 90 |
+
encoder (TransformerEncoder): the encoder
|
| 91 |
+
decoder (TransformerDecoder): the decoder
|
| 92 |
+
|
| 93 |
+
The Transformer model provides the following named architectures and
|
| 94 |
+
command-line arguments:
|
| 95 |
+
|
| 96 |
+
.. argparse::
|
| 97 |
+
:ref: fairseq.models.transformer_parser
|
| 98 |
+
:prog:
|
| 99 |
+
"""
|
| 100 |
+
|
| 101 |
+
def __init__(self, args, encoder, decoder):
|
| 102 |
+
super().__init__(encoder, decoder)
|
| 103 |
+
self.args = args
|
| 104 |
+
self.supports_align_args = True
|
| 105 |
+
|
| 106 |
+
@staticmethod
|
| 107 |
+
def add_args(parser):
|
| 108 |
+
"""Add model-specific arguments to the parser."""
|
| 109 |
+
# fmt: off
|
| 110 |
+
parser.add_argument('--activation-fn',
|
| 111 |
+
choices=utils.get_available_activation_fns(),
|
| 112 |
+
help='activation function to use')
|
| 113 |
+
parser.add_argument('--dropout', type=float, metavar='D',
|
| 114 |
+
help='dropout probability')
|
| 115 |
+
parser.add_argument('--attention-dropout', type=float, metavar='D',
|
| 116 |
+
help='dropout probability for attention weights')
|
| 117 |
+
parser.add_argument('--activation-dropout', '--relu-dropout', type=float, metavar='D',
|
| 118 |
+
help='dropout probability after activation in FFN.')
|
| 119 |
+
parser.add_argument('--encoder-embed-path', type=str, metavar='STR',
|
| 120 |
+
help='path to pre-trained encoder embedding')
|
| 121 |
+
parser.add_argument('--encoder-embed-dim', type=int, metavar='N',
|
| 122 |
+
help='encoder embedding dimension')
|
| 123 |
+
parser.add_argument('--encoder-ffn-embed-dim', type=int, metavar='N',
|
| 124 |
+
help='encoder embedding dimension for FFN')
|
| 125 |
+
parser.add_argument('--encoder-layers', type=int, metavar='N',
|
| 126 |
+
help='num encoder layers')
|
| 127 |
+
parser.add_argument('--encoder-attention-heads', type=int, metavar='N',
|
| 128 |
+
help='num encoder attention heads')
|
| 129 |
+
parser.add_argument('--encoder-normalize-before', action='store_true',
|
| 130 |
+
help='apply layernorm before each encoder block')
|
| 131 |
+
parser.add_argument('--encoder-learned-pos', action='store_true',
|
| 132 |
+
help='use learned positional embeddings in the encoder')
|
| 133 |
+
parser.add_argument('--decoder-embed-path', type=str, metavar='STR',
|
| 134 |
+
help='path to pre-trained decoder embedding')
|
| 135 |
+
parser.add_argument('--decoder-embed-dim', type=int, metavar='N',
|
| 136 |
+
help='decoder embedding dimension')
|
| 137 |
+
parser.add_argument('--decoder-ffn-embed-dim', type=int, metavar='N',
|
| 138 |
+
help='decoder embedding dimension for FFN')
|
| 139 |
+
parser.add_argument('--decoder-layers', type=int, metavar='N',
|
| 140 |
+
help='num decoder layers')
|
| 141 |
+
parser.add_argument('--decoder-attention-heads', type=int, metavar='N',
|
| 142 |
+
help='num decoder attention heads')
|
| 143 |
+
parser.add_argument('--decoder-learned-pos', action='store_true',
|
| 144 |
+
help='use learned positional embeddings in the decoder')
|
| 145 |
+
parser.add_argument('--decoder-normalize-before', action='store_true',
|
| 146 |
+
help='apply layernorm before each decoder block')
|
| 147 |
+
parser.add_argument('--decoder-output-dim', type=int, metavar='N',
|
| 148 |
+
help='decoder output dimension (extra linear layer '
|
| 149 |
+
'if different from decoder embed dim)')
|
| 150 |
+
parser.add_argument('--share-decoder-input-output-embed', action='store_true',
|
| 151 |
+
help='share decoder input and output embeddings')
|
| 152 |
+
parser.add_argument('--share-all-embeddings', action='store_true',
|
| 153 |
+
help='share encoder, decoder and output embeddings'
|
| 154 |
+
' (requires shared dictionary and embed dim)')
|
| 155 |
+
parser.add_argument('--no-token-positional-embeddings', default=False, action='store_true',
|
| 156 |
+
help='if set, disables positional embeddings (outside self attention)')
|
| 157 |
+
parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR',
|
| 158 |
+
help='comma separated list of adaptive softmax cutoff points. '
|
| 159 |
+
'Must be used with adaptive_loss criterion'),
|
| 160 |
+
parser.add_argument('--adaptive-softmax-dropout', type=float, metavar='D',
|
| 161 |
+
help='sets adaptive softmax dropout for the tail projections')
|
| 162 |
+
parser.add_argument('--layernorm-embedding', action='store_true',
|
| 163 |
+
help='add layernorm to embedding')
|
| 164 |
+
parser.add_argument('--no-scale-embedding', action='store_true',
|
| 165 |
+
help='if True, dont scale embeddings')
|
| 166 |
+
parser.add_argument('--checkpoint-activations', action='store_true',
|
| 167 |
+
help='checkpoint activations at each layer, which saves GPU '
|
| 168 |
+
'memory usage at the cost of some additional compute')
|
| 169 |
+
parser.add_argument('--offload-activations', action='store_true',
|
| 170 |
+
help='checkpoint activations at each layer, then offload them to CPU. Sets --checkpoint-activations.')
|
| 171 |
+
# args for "Cross+Self-Attention for Transformer Models" (Peitz et al., 2019)
|
| 172 |
+
parser.add_argument('--no-cross-attention', default=False, action='store_true',
|
| 173 |
+
help='do not perform cross-attention')
|
| 174 |
+
parser.add_argument('--cross-self-attention', default=False, action='store_true',
|
| 175 |
+
help='perform cross+self-attention')
|
| 176 |
+
# args for "Reducing Transformer Depth on Demand with Structured Dropout" (Fan et al., 2019)
|
| 177 |
+
parser.add_argument('--encoder-layerdrop', type=float, metavar='D', default=0,
|
| 178 |
+
help='LayerDrop probability for encoder')
|
| 179 |
+
parser.add_argument('--decoder-layerdrop', type=float, metavar='D', default=0,
|
| 180 |
+
help='LayerDrop probability for decoder')
|
| 181 |
+
parser.add_argument('--encoder-layers-to-keep', default=None,
|
| 182 |
+
help='which layers to *keep* when pruning as a comma-separated list')
|
| 183 |
+
parser.add_argument('--decoder-layers-to-keep', default=None,
|
| 184 |
+
help='which layers to *keep* when pruning as a comma-separated list')
|
| 185 |
+
# args for Training with Quantization Noise for Extreme Model Compression ({Fan*, Stock*} et al., 2020)
|
| 186 |
+
parser.add_argument('--quant-noise-pq', type=float, metavar='D', default=0,
|
| 187 |
+
help='iterative PQ quantization noise at training time')
|
| 188 |
+
parser.add_argument('--quant-noise-pq-block-size', type=int, metavar='D', default=8,
|
| 189 |
+
help='block size of quantization noise at training time')
|
| 190 |
+
parser.add_argument('--quant-noise-scalar', type=float, metavar='D', default=0,
|
| 191 |
+
help='scalar quantization noise and scalar quantization at training time')
|
| 192 |
+
# args for Fully Sharded Data Parallel (FSDP) training
|
| 193 |
+
parser.add_argument(
|
| 194 |
+
'--min-params-to-wrap', type=int, metavar='D', default=DEFAULT_MIN_PARAMS_TO_WRAP,
|
| 195 |
+
help=(
|
| 196 |
+
'minimum number of params for a layer to be wrapped with FSDP() when '
|
| 197 |
+
'training with --ddp-backend=fully_sharded. Smaller values will '
|
| 198 |
+
'improve memory efficiency, but may make torch.distributed '
|
| 199 |
+
'communication less efficient due to smaller input sizes. This option '
|
| 200 |
+
'is set to 0 (i.e., always wrap) when --checkpoint-activations or '
|
| 201 |
+
'--offload-activations are passed.'
|
| 202 |
+
)
|
| 203 |
+
)
|
| 204 |
+
|
| 205 |
+
parser.add_argument('--resnet-drop-path-rate', type=float,
|
| 206 |
+
help='resnet drop path rate')
|
| 207 |
+
parser.add_argument('--encoder-drop-path-rate', type=float,
|
| 208 |
+
help='encoder drop path rate')
|
| 209 |
+
parser.add_argument('--decoder-drop-path-rate', type=float,
|
| 210 |
+
help='decoder drop path rate')
|
| 211 |
+
|
| 212 |
+
parser.add_argument('--token-bucket-size', type=int,
|
| 213 |
+
help='token bucket size')
|
| 214 |
+
parser.add_argument('--image-bucket-size', type=int,
|
| 215 |
+
help='image bucket size')
|
| 216 |
+
|
| 217 |
+
parser.add_argument('--attn-scale-factor', type=float,
|
| 218 |
+
help='attention scale factor')
|
| 219 |
+
parser.add_argument('--freeze-resnet', action='store_true',
|
| 220 |
+
help='freeze resnet')
|
| 221 |
+
parser.add_argument('--freeze-encoder-embedding', action='store_true',
|
| 222 |
+
help='freeze encoder token embedding')
|
| 223 |
+
parser.add_argument('--freeze-decoder-embedding', action='store_true',
|
| 224 |
+
help='freeze decoder token embedding')
|
| 225 |
+
parser.add_argument('--add-type-embedding', action='store_true',
|
| 226 |
+
help='add source/region/patch type embedding')
|
| 227 |
+
|
| 228 |
+
parser.add_argument('--resnet-type', choices=['resnet50', 'resnet101', 'resnet152'],
|
| 229 |
+
help='resnet type')
|
| 230 |
+
parser.add_argument('--resnet-model-path', type=str, metavar='STR',
|
| 231 |
+
help='path to load resnet')
|
| 232 |
+
parser.add_argument('--code-image-size', type=int,
|
| 233 |
+
help='code image size')
|
| 234 |
+
parser.add_argument('--patch-layernorm-embedding', action='store_true',
|
| 235 |
+
help='add layernorm to patch embedding')
|
| 236 |
+
parser.add_argument('--code-layernorm-embedding', action='store_true',
|
| 237 |
+
help='add layernorm to code embedding')
|
| 238 |
+
parser.add_argument('--entangle-position-embedding', action='store_true',
|
| 239 |
+
help='entangle position embedding')
|
| 240 |
+
parser.add_argument('--disable-entangle', action='store_true',
|
| 241 |
+
help='disable entangle')
|
| 242 |
+
parser.add_argument('--sync-bn', action='store_true',
|
| 243 |
+
help='sync batchnorm')
|
| 244 |
+
|
| 245 |
+
parser.add_argument('--scale-attn', action='store_true',
|
| 246 |
+
help='scale attn')
|
| 247 |
+
parser.add_argument('--scale-fc', action='store_true',
|
| 248 |
+
help='scale fc')
|
| 249 |
+
parser.add_argument('--scale-heads', action='store_true',
|
| 250 |
+
help='scale heads')
|
| 251 |
+
parser.add_argument('--scale-resids', action='store_true',
|
| 252 |
+
help='scale resids')
|
| 253 |
+
# fmt: on
|
| 254 |
+
|
| 255 |
+
@classmethod
|
| 256 |
+
def build_model(cls, args, task):
|
| 257 |
+
"""Build a new model instance."""
|
| 258 |
+
|
| 259 |
+
# make sure all arguments are present in older models
|
| 260 |
+
base_architecture(args)
|
| 261 |
+
|
| 262 |
+
if args.encoder_layers_to_keep:
|
| 263 |
+
args.encoder_layers = len(args.encoder_layers_to_keep.split(","))
|
| 264 |
+
if args.decoder_layers_to_keep:
|
| 265 |
+
args.decoder_layers = len(args.decoder_layers_to_keep.split(","))
|
| 266 |
+
|
| 267 |
+
if getattr(args, "max_source_positions", None) is None:
|
| 268 |
+
args.max_source_positions = DEFAULT_MAX_SOURCE_POSITIONS
|
| 269 |
+
if getattr(args, "max_target_positions", None) is None:
|
| 270 |
+
args.max_target_positions = DEFAULT_MAX_TARGET_POSITIONS
|
| 271 |
+
|
| 272 |
+
src_dict, tgt_dict = task.source_dictionary, task.target_dictionary
|
| 273 |
+
|
| 274 |
+
if args.share_all_embeddings:
|
| 275 |
+
if src_dict != tgt_dict:
|
| 276 |
+
raise ValueError("--share-all-embeddings requires a joined dictionary")
|
| 277 |
+
if args.encoder_embed_dim != args.decoder_embed_dim:
|
| 278 |
+
raise ValueError(
|
| 279 |
+
"--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim"
|
| 280 |
+
)
|
| 281 |
+
if args.decoder_embed_path and (
|
| 282 |
+
args.decoder_embed_path != args.encoder_embed_path
|
| 283 |
+
):
|
| 284 |
+
raise ValueError(
|
| 285 |
+
"--share-all-embeddings not compatible with --decoder-embed-path"
|
| 286 |
+
)
|
| 287 |
+
encoder_embed_tokens = cls.build_embedding(
|
| 288 |
+
args, src_dict, args.encoder_embed_dim, args.encoder_embed_path
|
| 289 |
+
)
|
| 290 |
+
decoder_embed_tokens = encoder_embed_tokens
|
| 291 |
+
args.share_decoder_input_output_embed = True
|
| 292 |
+
else:
|
| 293 |
+
encoder_embed_tokens = cls.build_embedding(
|
| 294 |
+
args, src_dict, args.encoder_embed_dim, args.encoder_embed_path
|
| 295 |
+
)
|
| 296 |
+
decoder_embed_tokens = cls.build_embedding(
|
| 297 |
+
args, tgt_dict, args.decoder_embed_dim, args.decoder_embed_path
|
| 298 |
+
)
|
| 299 |
+
if getattr(args, "freeze_encoder_embedding", False):
|
| 300 |
+
encoder_embed_tokens.weight.requires_grad = False
|
| 301 |
+
if getattr(args, "freeze_decoder_embedding", False):
|
| 302 |
+
decoder_embed_tokens.weight.requires_grad = False
|
| 303 |
+
if getattr(args, "offload_activations", False):
|
| 304 |
+
args.checkpoint_activations = True # offloading implies checkpointing
|
| 305 |
+
encoder = cls.build_encoder(args, src_dict, encoder_embed_tokens)
|
| 306 |
+
decoder = cls.build_decoder(args, tgt_dict, decoder_embed_tokens)
|
| 307 |
+
if not args.share_all_embeddings:
|
| 308 |
+
min_params_to_wrap = getattr(
|
| 309 |
+
args, "min_params_to_wrap", DEFAULT_MIN_PARAMS_TO_WRAP
|
| 310 |
+
)
|
| 311 |
+
# fsdp_wrap is a no-op when --ddp-backend != fully_sharded
|
| 312 |
+
encoder = fsdp_wrap(encoder, min_num_params=min_params_to_wrap)
|
| 313 |
+
decoder = fsdp_wrap(decoder, min_num_params=min_params_to_wrap)
|
| 314 |
+
return cls(args, encoder, decoder)
|
| 315 |
+
|
| 316 |
+
@classmethod
|
| 317 |
+
def build_embedding(cls, args, dictionary, embed_dim, path=None):
|
| 318 |
+
num_embeddings = len(dictionary)
|
| 319 |
+
padding_idx = dictionary.pad()
|
| 320 |
+
|
| 321 |
+
emb = Embedding(num_embeddings, embed_dim, padding_idx)
|
| 322 |
+
# if provided, load from preloaded dictionaries
|
| 323 |
+
if path:
|
| 324 |
+
embed_dict = utils.parse_embedding(path)
|
| 325 |
+
utils.load_embedding(embed_dict, dictionary, emb)
|
| 326 |
+
return emb
|
| 327 |
+
|
| 328 |
+
@classmethod
|
| 329 |
+
def build_encoder(cls, args, src_dict, embed_tokens):
|
| 330 |
+
return TransformerEncoder(args, src_dict, embed_tokens)
|
| 331 |
+
|
| 332 |
+
@classmethod
|
| 333 |
+
def build_decoder(cls, args, tgt_dict, embed_tokens):
|
| 334 |
+
return TransformerDecoder(
|
| 335 |
+
args,
|
| 336 |
+
tgt_dict,
|
| 337 |
+
embed_tokens,
|
| 338 |
+
no_encoder_attn=getattr(args, "no_cross_attention", False),
|
| 339 |
+
)
|
| 340 |
+
|
| 341 |
+
# TorchScript doesn't support optional arguments with variable length (**kwargs).
|
| 342 |
+
# Current workaround is to add union of all arguments in child classes.
|
| 343 |
+
def forward(
|
| 344 |
+
self,
|
| 345 |
+
src_tokens,
|
| 346 |
+
src_lengths,
|
| 347 |
+
prev_output_tokens,
|
| 348 |
+
return_all_hiddens: bool = True,
|
| 349 |
+
features_only: bool = False,
|
| 350 |
+
alignment_layer: Optional[int] = None,
|
| 351 |
+
alignment_heads: Optional[int] = None,
|
| 352 |
+
):
|
| 353 |
+
"""
|
| 354 |
+
Run the forward pass for an encoder-decoder model.
|
| 355 |
+
|
| 356 |
+
Copied from the base class, but without ``**kwargs``,
|
| 357 |
+
which are not supported by TorchScript.
|
| 358 |
+
"""
|
| 359 |
+
encoder_out = self.encoder(
|
| 360 |
+
src_tokens, src_lengths=src_lengths, return_all_hiddens=return_all_hiddens
|
| 361 |
+
)
|
| 362 |
+
decoder_out = self.decoder(
|
| 363 |
+
prev_output_tokens,
|
| 364 |
+
encoder_out=encoder_out,
|
| 365 |
+
features_only=features_only,
|
| 366 |
+
alignment_layer=alignment_layer,
|
| 367 |
+
alignment_heads=alignment_heads,
|
| 368 |
+
src_lengths=src_lengths,
|
| 369 |
+
return_all_hiddens=return_all_hiddens,
|
| 370 |
+
)
|
| 371 |
+
return decoder_out
|
| 372 |
+
|
| 373 |
+
# Since get_normalized_probs is in the Fairseq model class, which is not scriptable,
|
| 374 |
+
# we override get_normalized_probs from the Base Class to call the
|
| 375 |
+
# helper function in the Base Class.
|
| 376 |
+
@torch.jit.export
|
| 377 |
+
def get_normalized_probs(
|
| 378 |
+
self,
|
| 379 |
+
net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
|
| 380 |
+
log_probs: bool,
|
| 381 |
+
sample: Optional[Dict[str, Tensor]] = None,
|
| 382 |
+
):
|
| 383 |
+
"""Get normalized probabilities (or log probs) from a net's output."""
|
| 384 |
+
return self.get_normalized_probs_scriptable(net_output, log_probs, sample)
|
| 385 |
+
|
| 386 |
+
|
| 387 |
+
class TransformerEncoder(FairseqEncoder):
|
| 388 |
+
"""
|
| 389 |
+
Transformer encoder consisting of *args.encoder_layers* layers. Each layer
|
| 390 |
+
is a :class:`TransformerEncoderLayer`.
|
| 391 |
+
|
| 392 |
+
Args:
|
| 393 |
+
args (argparse.Namespace): parsed command-line arguments
|
| 394 |
+
dictionary (~fairseq.data.Dictionary): encoding dictionary
|
| 395 |
+
embed_tokens (torch.nn.Embedding): input embedding
|
| 396 |
+
"""
|
| 397 |
+
|
| 398 |
+
def __init__(self, args, dictionary, embed_tokens):
|
| 399 |
+
self.args = args
|
| 400 |
+
super().__init__(dictionary)
|
| 401 |
+
self.register_buffer("version", torch.Tensor([3]))
|
| 402 |
+
|
| 403 |
+
self.dropout_module = FairseqDropout(
|
| 404 |
+
args.dropout, module_name=self.__class__.__name__
|
| 405 |
+
)
|
| 406 |
+
self.encoder_layerdrop = args.encoder_layerdrop
|
| 407 |
+
|
| 408 |
+
embed_dim = embed_tokens.embedding_dim
|
| 409 |
+
self.padding_idx = embed_tokens.padding_idx
|
| 410 |
+
self.max_source_positions = args.max_source_positions
|
| 411 |
+
self.num_attention_heads = args.encoder_attention_heads
|
| 412 |
+
|
| 413 |
+
self.embed_tokens = embed_tokens
|
| 414 |
+
|
| 415 |
+
self.embed_scale = 1.0 if args.no_scale_embedding else math.sqrt(embed_dim)
|
| 416 |
+
|
| 417 |
+
if getattr(args, "layernorm_embedding", False):
|
| 418 |
+
self.layernorm_embedding = LayerNorm(embed_dim)
|
| 419 |
+
else:
|
| 420 |
+
self.layernorm_embedding = None
|
| 421 |
+
|
| 422 |
+
if getattr(args, "add_type_embedding", False):
|
| 423 |
+
self.type_embedding = Embedding(2, embed_dim, padding_idx=None)
|
| 424 |
+
else:
|
| 425 |
+
self.type_embedding = None
|
| 426 |
+
|
| 427 |
+
if getattr(args, "sync_bn", False):
|
| 428 |
+
norm_layer = BatchNorm2d
|
| 429 |
+
else:
|
| 430 |
+
norm_layer = None
|
| 431 |
+
|
| 432 |
+
if args.resnet_type == 'resnet101':
|
| 433 |
+
self.embed_images = ResNet([3, 4, 23], norm_layer=norm_layer, drop_path_rate=args.resnet_drop_path_rate)
|
| 434 |
+
elif args.resnet_type == 'resnet152':
|
| 435 |
+
self.embed_images = ResNet([3, 8, 36], norm_layer=norm_layer, drop_path_rate=args.resnet_drop_path_rate)
|
| 436 |
+
else:
|
| 437 |
+
raise NotImplementedError
|
| 438 |
+
self.image_proj = Linear(1024, embed_dim)
|
| 439 |
+
if getattr(args, "resnet_model_path", None):
|
| 440 |
+
print("load resnet {}".format(args.resnet_model_path))
|
| 441 |
+
resnet_state_dict = torch.load(self.args.resnet_model_path)
|
| 442 |
+
self.embed_images.load_state_dict(resnet_state_dict)
|
| 443 |
+
if getattr(args, "patch_layernorm_embedding", False):
|
| 444 |
+
self.patch_layernorm_embedding = LayerNorm(embed_dim)
|
| 445 |
+
else:
|
| 446 |
+
self.patch_layernorm_embedding = None
|
| 447 |
+
|
| 448 |
+
self.embed_positions = Embedding(args.max_source_positions + 2, embed_dim)
|
| 449 |
+
self.embed_image_positions = Embedding(args.image_bucket_size ** 2 + 1, embed_dim)
|
| 450 |
+
self.pos_ln = LayerNorm(embed_dim)
|
| 451 |
+
self.image_pos_ln = LayerNorm(embed_dim)
|
| 452 |
+
self.pos_scaling = float(embed_dim / args.encoder_attention_heads * args.attn_scale_factor) ** -0.5
|
| 453 |
+
self.pos_q_linear = nn.Linear(embed_dim, embed_dim)
|
| 454 |
+
self.pos_k_linear = nn.Linear(embed_dim, embed_dim)
|
| 455 |
+
|
| 456 |
+
if not args.adaptive_input and args.quant_noise_pq > 0:
|
| 457 |
+
self.quant_noise = apply_quant_noise_(
|
| 458 |
+
nn.Linear(embed_dim, embed_dim, bias=False),
|
| 459 |
+
args.quant_noise_pq,
|
| 460 |
+
args.quant_noise_pq_block_size,
|
| 461 |
+
)
|
| 462 |
+
else:
|
| 463 |
+
self.quant_noise = None
|
| 464 |
+
|
| 465 |
+
if self.encoder_layerdrop > 0.0:
|
| 466 |
+
self.layers = LayerDropModuleList(p=self.encoder_layerdrop)
|
| 467 |
+
else:
|
| 468 |
+
self.layers = nn.ModuleList([])
|
| 469 |
+
|
| 470 |
+
dpr = [x.item() for x in torch.linspace(0, args.encoder_drop_path_rate, args.encoder_layers)]
|
| 471 |
+
self.layers.extend(
|
| 472 |
+
[self.build_encoder_layer(args, drop_path_rate=dpr[i]) for i in range(args.encoder_layers)]
|
| 473 |
+
)
|
| 474 |
+
self.num_layers = len(self.layers)
|
| 475 |
+
|
| 476 |
+
if args.encoder_normalize_before:
|
| 477 |
+
self.layer_norm = LayerNorm(embed_dim)
|
| 478 |
+
else:
|
| 479 |
+
self.layer_norm = None
|
| 480 |
+
|
| 481 |
+
token_bucket_size = args.token_bucket_size
|
| 482 |
+
token_num_rel_dis = 2 * token_bucket_size - 1
|
| 483 |
+
token_rp_bucket = make_token_bucket_position(token_bucket_size)
|
| 484 |
+
self.token_rel_pos_table_list = nn.ModuleList(
|
| 485 |
+
[Embedding(token_num_rel_dis, self.num_attention_heads, zero_init=True) for _ in range(args.encoder_layers)]
|
| 486 |
+
)
|
| 487 |
+
|
| 488 |
+
image_bucket_size = args.image_bucket_size
|
| 489 |
+
image_num_rel_dis = (2 * image_bucket_size - 1) * (2 * image_bucket_size - 1) + 3
|
| 490 |
+
image_rp_bucket = make_image_bucket_position(image_bucket_size, image_num_rel_dis)
|
| 491 |
+
self.image_rel_pos_table_list = nn.ModuleList(
|
| 492 |
+
[Embedding(image_num_rel_dis, self.num_attention_heads, zero_init=True) for _ in range(args.encoder_layers)]
|
| 493 |
+
)
|
| 494 |
+
|
| 495 |
+
self.register_buffer("token_rp_bucket", token_rp_bucket)
|
| 496 |
+
self.register_buffer("image_rp_bucket", image_rp_bucket)
|
| 497 |
+
self.entangle_position_embedding = args.entangle_position_embedding
|
| 498 |
+
|
| 499 |
+
def train(self, mode=True):
|
| 500 |
+
super(TransformerEncoder, self).train(mode)
|
| 501 |
+
if getattr(self.args, "freeze_resnet", False):
|
| 502 |
+
for m in self.embed_images.modules():
|
| 503 |
+
if isinstance(m, nn.BatchNorm2d):
|
| 504 |
+
m.eval()
|
| 505 |
+
m.weight.requires_grad = False
|
| 506 |
+
m.bias.requires_grad = False
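    # With --freeze-resnet, putting the BatchNorm modules back into eval() mode
    # keeps their running mean/var fixed even though the encoder is in train()
    # mode, and disabling grads on weight/bias freezes the affine parameters;
    # the ResNet's convolutional weights are not frozen here.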
|
| 507 |
+
|
| 508 |
+
def build_encoder_layer(self, args, drop_path_rate=0.0):
|
| 509 |
+
layer = TransformerEncoderLayer(args, drop_path_rate=drop_path_rate)
|
| 510 |
+
checkpoint = getattr(args, "checkpoint_activations", False)
|
| 511 |
+
if checkpoint:
|
| 512 |
+
offload_to_cpu = getattr(args, "offload_activations", False)
|
| 513 |
+
layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu)
|
| 514 |
+
# if we are checkpointing, enforce that FSDP always wraps the
|
| 515 |
+
# checkpointed layer, regardless of layer size
|
| 516 |
+
min_params_to_wrap = (
|
| 517 |
+
getattr(args, "min_params_to_wrap", DEFAULT_MIN_PARAMS_TO_WRAP)
|
| 518 |
+
if not checkpoint else 0
|
| 519 |
+
)
|
| 520 |
+
layer = fsdp_wrap(layer, min_num_params=min_params_to_wrap)
|
| 521 |
+
return layer
|
| 522 |
+
|
| 523 |
+
    def get_rel_pos_bias(self, x, idx):
        seq_len = x.size(1)
        rp_bucket = self.token_rp_bucket[:seq_len, :seq_len]
        values = F.embedding(rp_bucket, self.token_rel_pos_table_list[idx].weight)
        values = values.unsqueeze(0).expand(x.size(0), -1, -1, -1)
        values = values.permute([0, 3, 1, 2])
        return values.contiguous()

    def get_image_rel_pos_bias(self, image_position_ids, idx):
        bsz, seq_len = image_position_ids.shape
        rp_bucket_size = self.image_rp_bucket.size(1)

        rp_bucket = self.image_rp_bucket.unsqueeze(0).expand(
            bsz, rp_bucket_size, rp_bucket_size
        ).gather(1, image_position_ids[:, :, None].expand(bsz, seq_len, rp_bucket_size)
        ).gather(2, image_position_ids[:, None, :].expand(bsz, seq_len, seq_len))
        values = F.embedding(rp_bucket, self.image_rel_pos_table_list[idx].weight)
        values = values.permute(0, 3, 1, 2)
        return values

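    # get_patch_images_info (below) runs the ResNet trunk over the raw patches and
    # returns the flattened feature map as (batch, h * w, channels), the patch
    # count, an all-False padding mask, row-major position ids laid out on an
    # image_bucket_size x image_bucket_size grid (offset by 1 to skip the reserved
    # slot 0), and the matching learned image position embeddings. When
    # sample_patch_num is set, a random subset of patches is kept per example.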
def get_patch_images_info(self, patch_images, sample_patch_num, device):
|
| 544 |
+
image_embed = self.embed_images(patch_images)
|
| 545 |
+
h, w = image_embed.shape[-2:]
|
| 546 |
+
image_num_patches = h * w
|
| 547 |
+
image_padding_mask = patch_images.new_zeros((patch_images.size(0), image_num_patches)).bool()
|
| 548 |
+
image_position_idx = torch.arange(w).unsqueeze(0).expand(h, w) + \
|
| 549 |
+
torch.arange(h).unsqueeze(1) * self.args.image_bucket_size + 1
|
| 550 |
+
image_position_idx = image_position_idx.view(-1).to(device)
|
| 551 |
+
image_position_ids = image_position_idx[None, :].expand(patch_images.size(0), image_num_patches)
|
| 552 |
+
|
| 553 |
+
image_embed = image_embed.flatten(2).transpose(1, 2)
|
| 554 |
+
if sample_patch_num is not None:
|
| 555 |
+
patch_orders = [
|
| 556 |
+
random.sample(range(image_num_patches), k=sample_patch_num)
|
| 557 |
+
for _ in range(patch_images.size(0))
|
| 558 |
+
]
|
| 559 |
+
patch_orders = torch.LongTensor(patch_orders).to(device)
|
| 560 |
+
image_embed = image_embed.gather(
|
| 561 |
+
1, patch_orders.unsqueeze(2).expand(-1, -1, image_embed.size(2))
|
| 562 |
+
)
|
| 563 |
+
image_num_patches = sample_patch_num
|
| 564 |
+
image_padding_mask = image_padding_mask.gather(1, patch_orders)
|
| 565 |
+
image_position_ids = image_position_ids.gather(1, patch_orders)
|
| 566 |
+
image_pos_embed = self.embed_image_positions(image_position_ids)
|
| 567 |
+
|
| 568 |
+
return image_embed, image_num_patches, image_padding_mask, image_position_ids, image_pos_embed
|
| 569 |
+
|
| 570 |
+
def forward_embedding(
|
| 571 |
+
self,
|
| 572 |
+
src_tokens,
|
| 573 |
+
image_embed: Optional[torch.Tensor] = None,
|
| 574 |
+
image_embed_2: Optional[torch.Tensor] = None,
|
| 575 |
+
token_embedding: Optional[torch.Tensor] = None,
|
| 576 |
+
pos_embed: Optional[torch.Tensor] = None,
|
| 577 |
+
image_pos_embed: Optional[torch.Tensor] = None,
|
| 578 |
+
image_pos_embed_2: Optional[torch.Tensor] = None
|
| 579 |
+
):
|
| 580 |
+
# embed tokens and positions
|
| 581 |
+
if token_embedding is None:
|
| 582 |
+
token_embedding = self.embed_tokens(src_tokens)
|
| 583 |
+
x = embed = self.embed_scale * token_embedding
|
| 584 |
+
if self.entangle_position_embedding and pos_embed is not None:
|
| 585 |
+
x += pos_embed
|
| 586 |
+
if self.type_embedding is not None:
|
| 587 |
+
x += self.type_embedding(src_tokens.new_zeros(x.size()[:2]))
|
| 588 |
+
if self.layernorm_embedding is not None:
|
| 589 |
+
x = self.layernorm_embedding(x)
|
| 590 |
+
x = self.dropout_module(x)
|
| 591 |
+
if self.quant_noise is not None:
|
| 592 |
+
x = self.quant_noise(x)
|
| 593 |
+
|
| 594 |
+
# embed raw images
|
| 595 |
+
if image_embed is not None:
|
| 596 |
+
image_embed = self.image_proj(image_embed)
|
| 597 |
+
image_x = image_embed = self.embed_scale * image_embed
|
| 598 |
+
if self.entangle_position_embedding and image_pos_embed is not None:
|
| 599 |
+
image_x += image_pos_embed
|
| 600 |
+
if self.type_embedding is not None:
|
| 601 |
+
image_x += self.type_embedding(src_tokens.new_ones(image_x.size()[:2]))
|
| 602 |
+
if self.patch_layernorm_embedding is not None:
|
| 603 |
+
image_x = self.patch_layernorm_embedding(image_x)
|
| 604 |
+
image_x = self.dropout_module(image_x)
|
| 605 |
+
if self.quant_noise is not None:
|
| 606 |
+
image_x = self.quant_noise(image_x)
|
| 607 |
+
x = torch.cat([image_x, x], dim=1)
|
| 608 |
+
embed = torch.cat([image_embed, embed], dim=1)
|
| 609 |
+
|
| 610 |
+
if image_embed_2 is not None:
|
| 611 |
+
assert self.type_embedding is not None
|
| 612 |
+
image_embed_2 = self.image_proj(image_embed_2)
|
| 613 |
+
image_x_2 = image_embed_2 = self.embed_scale * image_embed_2
|
| 614 |
+
if self.entangle_position_embedding and image_pos_embed_2 is not None:
|
| 615 |
+
image_x_2 += image_pos_embed_2
|
| 616 |
+
if self.type_embedding is not None:
|
| 617 |
+
image_x_2 += self.type_embedding(src_tokens.new_full(image_x_2.size()[:2], fill_value=2))
|
| 618 |
+
if self.patch_layernorm_embedding is not None:
|
| 619 |
+
image_x_2 = self.patch_layernorm_embedding(image_x_2)
|
| 620 |
+
image_x_2 = self.dropout_module(image_x_2)
|
| 621 |
+
if self.quant_noise is not None:
|
| 622 |
+
image_x_2 = self.quant_noise(image_x_2)
|
| 623 |
+
x = torch.cat([image_x_2, x], dim=1)
|
| 624 |
+
embed = torch.cat([image_embed_2, embed], dim=1)
|
| 625 |
+
|
| 626 |
+
return x, embed
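        # Ordering note: with images present, the encoder sequence is
        # [patch_images_2 tokens, patch_images tokens, text tokens]; the padding
        # mask and position embeddings built in forward_scriptable are concatenated
        # in the same order so the positional biases stay aligned.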
|
| 627 |
+
|
| 628 |
+
def forward(
|
| 629 |
+
self,
|
| 630 |
+
src_tokens,
|
| 631 |
+
src_lengths,
|
| 632 |
+
patch_images: Optional[torch.Tensor] = None,
|
| 633 |
+
patch_images_2: Optional[torch.Tensor] = None,
|
| 634 |
+
patch_masks: Optional[torch.Tensor] = None,
|
| 635 |
+
code_masks: Optional[torch.Tensor] = None,
|
| 636 |
+
return_all_hiddens: bool = False,
|
| 637 |
+
token_embeddings: Optional[torch.Tensor] = None,
|
| 638 |
+
sample_patch_num: Optional[int] = None
|
| 639 |
+
):
|
| 640 |
+
"""
|
| 641 |
+
Args:
|
| 642 |
+
src_tokens (LongTensor): tokens in the source language of shape
|
| 643 |
+
`(batch, src_len)`
|
| 644 |
+
src_lengths (torch.LongTensor): lengths of each source sentence of
|
| 645 |
+
shape `(batch)`
|
| 646 |
+
return_all_hiddens (bool, optional): also return all of the
|
| 647 |
+
intermediate hidden states (default: False).
|
| 648 |
+
token_embeddings (torch.Tensor, optional): precomputed embeddings
|
| 649 |
+
default `None` will recompute embeddings
|
| 650 |
+
|
| 651 |
+
Returns:
|
| 652 |
+
dict:
|
| 653 |
+
- **encoder_out** (Tensor): the last encoder layer's output of
|
| 654 |
+
shape `(src_len, batch, embed_dim)`
|
| 655 |
+
- **encoder_padding_mask** (ByteTensor): the positions of
|
| 656 |
+
padding elements of shape `(batch, src_len)`
|
| 657 |
+
- **encoder_embedding** (Tensor): the (scaled) embedding lookup
|
| 658 |
+
of shape `(batch, src_len, embed_dim)`
|
| 659 |
+
- **encoder_states** (List[Tensor]): all intermediate
|
| 660 |
+
hidden states of shape `(src_len, batch, embed_dim)`.
|
| 661 |
+
Only populated if *return_all_hiddens* is True.
|
| 662 |
+
"""
|
| 663 |
+
return self.forward_scriptable(src_tokens,
|
| 664 |
+
src_lengths,
|
| 665 |
+
patch_images,
|
| 666 |
+
patch_images_2,
|
| 667 |
+
patch_masks,
|
| 668 |
+
return_all_hiddens,
|
| 669 |
+
token_embeddings,
|
| 670 |
+
sample_patch_num)
|
| 671 |
+
|
| 672 |
+
# TorchScript doesn't support the super() method, so the scriptable subclass
|
| 673 |
+
# can't access the base class model in TorchScript.
|
| 674 |
+
# Current workaround is to add a helper function with different name and
|
| 675 |
+
# call the helper function from scriptable Subclass.
|
| 676 |
+
def forward_scriptable(
|
| 677 |
+
self,
|
| 678 |
+
src_tokens,
|
| 679 |
+
src_lengths,
|
| 680 |
+
patch_images: Optional[torch.Tensor] = None,
|
| 681 |
+
patch_images_2: Optional[torch.Tensor] = None,
|
| 682 |
+
patch_masks: Optional[torch.Tensor] = None,
|
| 683 |
+
return_all_hiddens: bool = False,
|
| 684 |
+
token_embeddings: Optional[torch.Tensor] = None,
|
| 685 |
+
sample_patch_num: Optional[int] = None
|
| 686 |
+
):
|
| 687 |
+
"""
|
| 688 |
+
Args:
|
| 689 |
+
src_tokens (LongTensor): tokens in the source language of shape
|
| 690 |
+
`(batch, src_len)`
|
| 691 |
+
src_lengths (torch.LongTensor): lengths of each source sentence of
|
| 692 |
+
shape `(batch)`
|
| 693 |
+
return_all_hiddens (bool, optional): also return all of the
|
| 694 |
+
intermediate hidden states (default: False).
|
| 695 |
+
token_embeddings (torch.Tensor, optional): precomputed embeddings
|
| 696 |
+
default `None` will recompute embeddings
|
| 697 |
+
|
| 698 |
+
Returns:
|
| 699 |
+
dict:
|
| 700 |
+
- **encoder_out** (Tensor): the last encoder layer's output of
|
| 701 |
+
shape `(src_len, batch, embed_dim)`
|
| 702 |
+
- **encoder_padding_mask** (ByteTensor): the positions of
|
| 703 |
+
padding elements of shape `(batch, src_len)`
|
| 704 |
+
- **encoder_embedding** (Tensor): the (scaled) embedding lookup
|
| 705 |
+
of shape `(batch, src_len, embed_dim)`
|
| 706 |
+
- **encoder_states** (List[Tensor]): all intermediate
|
| 707 |
+
hidden states of shape `(src_len, batch, embed_dim)`.
|
| 708 |
+
Only populated if *return_all_hiddens* is True.
|
| 709 |
+
"""
|
| 710 |
+
image_embed = None
|
| 711 |
+
image_embed_2 = None
|
| 712 |
+
image_pos_embed = None
|
| 713 |
+
image_pos_embed_2 = None
|
| 714 |
+
if patch_images is not None:
|
| 715 |
+
image_embed, image_num_patches, image_padding_mask, image_position_ids, image_pos_embed = \
|
| 716 |
+
self.get_patch_images_info(patch_images, sample_patch_num, src_tokens.device)
|
| 717 |
+
image_padding_mask[~patch_masks] = True
|
| 718 |
+
if patch_images_2 is not None:
|
| 719 |
+
image_embed_2, image_num_patches_2, image_padding_mask_2, image_position_ids_2, image_pos_embed_2 = \
|
| 720 |
+
self.get_patch_images_info(patch_images_2, sample_patch_num, src_tokens.device)
|
| 721 |
+
image_padding_mask_2[~patch_masks] = True
|
| 722 |
+
|
| 723 |
+
encoder_padding_mask = src_tokens.eq(self.padding_idx)
|
| 724 |
+
if patch_images is not None:
|
| 725 |
+
encoder_padding_mask = torch.cat([image_padding_mask, encoder_padding_mask], dim=1)
|
| 726 |
+
if patch_images_2 is not None:
|
| 727 |
+
encoder_padding_mask = torch.cat([image_padding_mask_2, encoder_padding_mask], dim=1)
|
| 728 |
+
has_pads = (src_tokens.device.type == "xla" or encoder_padding_mask.any())
|
| 729 |
+
|
| 730 |
+
pos_embed = self.embed_positions(utils.new_arange(src_tokens))
|
| 731 |
+
x, encoder_embedding = self.forward_embedding(
|
| 732 |
+
src_tokens, image_embed, image_embed_2, token_embeddings,
|
| 733 |
+
pos_embed, image_pos_embed, image_pos_embed_2
|
| 734 |
+
)
|
| 735 |
+
|
| 736 |
+
# account for padding while computing the representation
|
| 737 |
+
if has_pads:
|
| 738 |
+
x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x))
|
| 739 |
+
|
| 740 |
+
# B x T x C -> T x B x C
|
| 741 |
+
x = x.transpose(0, 1)
|
| 742 |
+
|
| 743 |
+
pos_embed = self.pos_ln(pos_embed)
|
| 744 |
+
if patch_images is not None:
|
| 745 |
+
image_pos_embed = self.image_pos_ln(image_pos_embed)
|
| 746 |
+
pos_embed = torch.cat([image_pos_embed, pos_embed], dim=1)
|
| 747 |
+
if patch_images_2 is not None:
|
| 748 |
+
image_pos_embed_2 = self.image_pos_ln(image_pos_embed_2)
|
| 749 |
+
pos_embed = torch.cat([image_pos_embed_2, pos_embed], dim=1)
|
| 750 |
+
|
| 751 |
+
pos_q = self.pos_q_linear(pos_embed).view(
|
| 752 |
+
x.size(1), x.size(0), self.num_attention_heads, -1
|
| 753 |
+
).transpose(1, 2) * self.pos_scaling
|
| 754 |
+
pos_k = self.pos_k_linear(pos_embed).view(
|
| 755 |
+
x.size(1), x.size(0), self.num_attention_heads, -1
|
| 756 |
+
).transpose(1, 2)
|
| 757 |
+
abs_pos_bias = torch.matmul(pos_q, pos_k.transpose(2, 3))
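        # abs_pos_bias has shape (batch, num_heads, seq_len, seq_len): a per-head
        # bias computed from the layer-normed absolute position embeddings through
        # separate q/k projections, with pos_q scaled by
        # (head_dim * attn_scale_factor) ** -0.5. Each encoder layer below clones it
        # and adds its own relative position bias before attention.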
|
| 758 |
+
|
| 759 |
+
encoder_states = []
|
| 760 |
+
|
| 761 |
+
if return_all_hiddens:
|
| 762 |
+
encoder_states.append(x)
|
| 763 |
+
|
| 764 |
+
# encoder layers
|
| 765 |
+
for idx, layer in enumerate(self.layers):
|
| 766 |
+
self_attn_bias = abs_pos_bias.clone()
|
| 767 |
+
self_attn_bias[:, :, -src_tokens.size(1):, -src_tokens.size(1):] += self.get_rel_pos_bias(src_tokens, idx)
|
| 768 |
+
if patch_images_2 is not None:
|
| 769 |
+
self_attn_bias[:, :, :image_num_patches_2, :image_num_patches_2] += \
|
| 770 |
+
self.get_image_rel_pos_bias(image_position_ids_2, idx)
|
| 771 |
+
self_attn_bias[:, :, image_num_patches_2:image_num_patches_2+image_num_patches, image_num_patches_2:image_num_patches_2+image_num_patches] += \
|
| 772 |
+
self.get_image_rel_pos_bias(image_position_ids, idx)
|
| 773 |
+
elif patch_images is not None:
|
| 774 |
+
self_attn_bias[:, :, :x.size(0) - src_tokens.size(1), :x.size(0) - src_tokens.size(1)] += \
|
| 775 |
+
self.get_image_rel_pos_bias(image_position_ids, idx)
|
| 776 |
+
self_attn_bias = self_attn_bias.reshape(-1, x.size(0), x.size(0))
|
| 777 |
+
|
| 778 |
+
x = layer(
|
| 779 |
+
x, encoder_padding_mask=encoder_padding_mask if has_pads else None, self_attn_bias=self_attn_bias
|
| 780 |
+
)
|
| 781 |
+
if return_all_hiddens:
|
| 782 |
+
assert encoder_states is not None
|
| 783 |
+
encoder_states.append(x)
|
| 784 |
+
|
| 785 |
+
if self.layer_norm is not None:
|
| 786 |
+
x = self.layer_norm(x)
|
| 787 |
+
|
| 788 |
+
# The PyTorch Mobile lite interpreter does not support returning NamedTuple in
|
| 789 |
+
# `forward` so we use a dictionary instead.
|
| 790 |
+
# TorchScript does not support mixed values so the values are all lists.
|
| 791 |
+
# The empty list is equivalent to None.
|
| 792 |
+
return {
|
| 793 |
+
"encoder_out": [x], # T x B x C
|
| 794 |
+
"encoder_padding_mask": [encoder_padding_mask], # B x T
|
| 795 |
+
"encoder_embedding": [], # B x T x C
|
| 796 |
+
"encoder_states": encoder_states, # List[T x B x C]
|
| 797 |
+
"src_tokens": [],
|
| 798 |
+
"src_lengths": [],
|
| 799 |
+
"position_embeddings": [pos_embed], # B x T x C
|
| 800 |
+
}
|
| 801 |
+
|
| 802 |
+
@torch.jit.export
|
| 803 |
+
def reorder_encoder_out(self, encoder_out: Dict[str, List[Tensor]], new_order):
|
| 804 |
+
"""
|
| 805 |
+
Reorder encoder output according to *new_order*.
|
| 806 |
+
|
| 807 |
+
Args:
|
| 808 |
+
encoder_out: output from the ``forward()`` method
|
| 809 |
+
new_order (LongTensor): desired order
|
| 810 |
+
|
| 811 |
+
Returns:
|
| 812 |
+
*encoder_out* rearranged according to *new_order*
|
| 813 |
+
"""
|
| 814 |
+
if len(encoder_out["encoder_out"]) == 0:
|
| 815 |
+
new_encoder_out = []
|
| 816 |
+
else:
|
| 817 |
+
new_encoder_out = [encoder_out["encoder_out"][0].index_select(1, new_order)]
|
| 818 |
+
if len(encoder_out["encoder_padding_mask"]) == 0:
|
| 819 |
+
new_encoder_padding_mask = []
|
| 820 |
+
else:
|
| 821 |
+
new_encoder_padding_mask = [
|
| 822 |
+
encoder_out["encoder_padding_mask"][0].index_select(0, new_order)
|
| 823 |
+
]
|
| 824 |
+
if len(encoder_out["encoder_embedding"]) == 0:
|
| 825 |
+
new_encoder_embedding = []
|
| 826 |
+
else:
|
| 827 |
+
new_encoder_embedding = [
|
| 828 |
+
encoder_out["encoder_embedding"][0].index_select(0, new_order)
|
| 829 |
+
]
|
| 830 |
+
|
| 831 |
+
if len(encoder_out["src_tokens"]) == 0:
|
| 832 |
+
new_src_tokens = []
|
| 833 |
+
else:
|
| 834 |
+
new_src_tokens = [(encoder_out["src_tokens"][0]).index_select(0, new_order)]
|
| 835 |
+
|
| 836 |
+
if len(encoder_out["src_lengths"]) == 0:
|
| 837 |
+
new_src_lengths = []
|
| 838 |
+
else:
|
| 839 |
+
new_src_lengths = [(encoder_out["src_lengths"][0]).index_select(0, new_order)]
|
| 840 |
+
|
| 841 |
+
if len(encoder_out["position_embeddings"]) == 0:
|
| 842 |
+
new_position_embeddings = []
|
| 843 |
+
else:
|
| 844 |
+
new_position_embeddings = [(encoder_out["position_embeddings"][0]).index_select(0, new_order)]
|
| 845 |
+
|
| 846 |
+
encoder_states = encoder_out["encoder_states"]
|
| 847 |
+
if len(encoder_states) > 0:
|
| 848 |
+
for idx, state in enumerate(encoder_states):
|
| 849 |
+
encoder_states[idx] = state.index_select(1, new_order)
|
| 850 |
+
|
| 851 |
+
return {
|
| 852 |
+
"encoder_out": new_encoder_out, # T x B x C
|
| 853 |
+
"encoder_padding_mask": new_encoder_padding_mask, # B x T
|
| 854 |
+
"encoder_embedding": new_encoder_embedding, # B x T x C
|
| 855 |
+
"encoder_states": encoder_states, # List[T x B x C]
|
| 856 |
+
"src_tokens": new_src_tokens, # B x T
|
| 857 |
+
"src_lengths": new_src_lengths, # B x 1
|
| 858 |
+
"position_embeddings": new_position_embeddings, # B x T x C
|
| 859 |
+
}
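        # reorder_encoder_out is what fairseq's SequenceGenerator calls during beam
        # search so cached encoder states follow the surviving hypotheses. A sketch
        # of the expected call (illustrative, see sequence_generator.py):
        #
        #   new_order = torch.arange(bsz).repeat_interleave(beam_size).to(device)
        #   encoder_out = encoder.reorder_encoder_out(encoder_out, new_order)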
|
| 860 |
+
|
| 861 |
+
def max_positions(self):
|
| 862 |
+
"""Maximum input length supported by the encoder."""
|
| 863 |
+
if self.embed_positions is None:
|
| 864 |
+
return self.max_source_positions
|
| 865 |
+
return self.max_source_positions
|
| 866 |
+
|
| 867 |
+
def upgrade_state_dict_named(self, state_dict, name):
|
| 868 |
+
"""Upgrade a (possibly old) state dict for new versions of fairseq."""
|
| 869 |
+
if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
|
| 870 |
+
weights_key = "{}.embed_positions.weights".format(name)
|
| 871 |
+
if weights_key in state_dict:
|
| 872 |
+
print("deleting {0}".format(weights_key))
|
| 873 |
+
del state_dict[weights_key]
|
| 874 |
+
state_dict[
|
| 875 |
+
"{}.embed_positions._float_tensor".format(name)
|
| 876 |
+
] = torch.FloatTensor(1)
|
| 877 |
+
for i in range(self.num_layers):
|
| 878 |
+
# update layer norms
|
| 879 |
+
self.layers[i].upgrade_state_dict_named(
|
| 880 |
+
state_dict, "{}.layers.{}".format(name, i)
|
| 881 |
+
)
|
| 882 |
+
|
| 883 |
+
# version_key = "{}.version".format(name)
|
| 884 |
+
# if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) < 2:
|
| 885 |
+
# # earlier checkpoints did not normalize after the stack of layers
|
| 886 |
+
# self.layer_norm = None
|
| 887 |
+
# self.normalize = False
|
| 888 |
+
# state_dict[version_key] = torch.Tensor([1])
|
| 889 |
+
|
| 890 |
+
prefix = name + "." if name != "" else ""
|
| 891 |
+
for param_name, param_tensor in self.state_dict().items():
|
| 892 |
+
if (prefix + param_name) not in state_dict and param_name in self.state_dict():
|
| 893 |
+
state_dict[prefix + param_name] = self.state_dict()[param_name]
|
| 894 |
+
|
| 895 |
+
if len(state_dict["encoder.embed_image_positions.weight"]) < len(self.state_dict()["embed_image_positions.weight"]):
|
| 896 |
+
num_posids_to_add = len(self.state_dict()["embed_image_positions.weight"]) - len(state_dict["encoder.embed_image_positions.weight"])
|
| 897 |
+
embed_dim = state_dict["encoder.embed_image_positions.weight"].size(1)
|
| 898 |
+
new_pos_embed_to_add = torch.zeros(num_posids_to_add, embed_dim)
|
| 899 |
+
nn.init.normal_(new_pos_embed_to_add, mean=0, std=embed_dim ** -0.5)
|
| 900 |
+
new_pos_embed_to_add = new_pos_embed_to_add.to(
|
| 901 |
+
dtype=state_dict["encoder.embed_image_positions.weight"].dtype,
|
| 902 |
+
)
|
| 903 |
+
state_dict["encoder.embed_image_positions.weight"] = torch.cat(
|
| 904 |
+
[state_dict["encoder.embed_image_positions.weight"], new_pos_embed_to_add]
|
| 905 |
+
)
|
| 906 |
+
return state_dict
|
| 907 |
+
|
| 908 |
+
|
| 909 |
+
class TransformerDecoder(FairseqIncrementalDecoder):
|
| 910 |
+
"""
|
| 911 |
+
Transformer decoder consisting of *args.decoder_layers* layers. Each layer
|
| 912 |
+
is a :class:`TransformerDecoderLayer`.
|
| 913 |
+
|
| 914 |
+
Args:
|
| 915 |
+
args (argparse.Namespace): parsed command-line arguments
|
| 916 |
+
dictionary (~fairseq.data.Dictionary): decoding dictionary
|
| 917 |
+
embed_tokens (torch.nn.Embedding): output embedding
|
| 918 |
+
no_encoder_attn (bool, optional): whether to attend to encoder outputs
|
| 919 |
+
(default: False).
|
| 920 |
+
"""
|
| 921 |
+
|
| 922 |
+
def __init__(
|
| 923 |
+
self,
|
| 924 |
+
args,
|
| 925 |
+
dictionary,
|
| 926 |
+
embed_tokens,
|
| 927 |
+
no_encoder_attn=False,
|
| 928 |
+
output_projection=None,
|
| 929 |
+
):
|
| 930 |
+
self.args = args
|
| 931 |
+
super().__init__(dictionary)
|
| 932 |
+
self.register_buffer("version", torch.Tensor([3]))
|
| 933 |
+
self._future_mask = torch.empty(0)
|
| 934 |
+
|
| 935 |
+
self.dropout_module = FairseqDropout(
|
| 936 |
+
args.dropout, module_name=self.__class__.__name__
|
| 937 |
+
)
|
| 938 |
+
self.decoder_layerdrop = args.decoder_layerdrop
|
| 939 |
+
self.share_input_output_embed = args.share_decoder_input_output_embed
|
| 940 |
+
self.num_attention_heads = args.decoder_attention_heads
|
| 941 |
+
|
| 942 |
+
input_embed_dim = embed_tokens.embedding_dim
|
| 943 |
+
embed_dim = args.decoder_embed_dim
|
| 944 |
+
self.embed_dim = embed_dim
|
| 945 |
+
self.output_embed_dim = args.decoder_output_dim
|
| 946 |
+
|
| 947 |
+
self.padding_idx = embed_tokens.padding_idx
|
| 948 |
+
self.max_target_positions = args.max_target_positions
|
| 949 |
+
|
| 950 |
+
self.embed_tokens = embed_tokens
|
| 951 |
+
|
| 952 |
+
self.embed_scale = 1.0 if args.no_scale_embedding else math.sqrt(embed_dim)
|
| 953 |
+
|
| 954 |
+
if not args.adaptive_input and args.quant_noise_pq > 0:
|
| 955 |
+
self.quant_noise = apply_quant_noise_(
|
| 956 |
+
nn.Linear(embed_dim, embed_dim, bias=False),
|
| 957 |
+
args.quant_noise_pq,
|
| 958 |
+
args.quant_noise_pq_block_size,
|
| 959 |
+
)
|
| 960 |
+
else:
|
| 961 |
+
self.quant_noise = None
|
| 962 |
+
|
| 963 |
+
self.project_in_dim = (
|
| 964 |
+
Linear(input_embed_dim, embed_dim, bias=False)
|
| 965 |
+
if embed_dim != input_embed_dim
|
| 966 |
+
else None
|
| 967 |
+
)
|
| 968 |
+
|
| 969 |
+
if getattr(args, "layernorm_embedding", False):
|
| 970 |
+
self.layernorm_embedding = LayerNorm(embed_dim)
|
| 971 |
+
else:
|
| 972 |
+
self.layernorm_embedding = None
|
| 973 |
+
|
| 974 |
+
self.window_size = args.code_image_size // 8
|
| 975 |
+
|
| 976 |
+
self.embed_positions = Embedding(args.max_target_positions + 2, embed_dim)
|
| 977 |
+
self.embed_image_positions = Embedding(args.image_bucket_size ** 2 + 1, embed_dim)
|
| 978 |
+
self.pos_ln = LayerNorm(embed_dim)
|
| 979 |
+
self.image_pos_ln = LayerNorm(embed_dim)
|
| 980 |
+
self.pos_scaling = float(embed_dim / self.num_attention_heads * args.attn_scale_factor) ** -0.5
|
| 981 |
+
self.self_pos_q_linear = nn.Linear(embed_dim, embed_dim)
|
| 982 |
+
self.self_pos_k_linear = nn.Linear(embed_dim, embed_dim)
|
| 983 |
+
self.cross_pos_q_linear = nn.Linear(embed_dim, embed_dim)
|
| 984 |
+
self.cross_pos_k_linear = nn.Linear(embed_dim, embed_dim)
|
| 985 |
+
|
| 986 |
+
if getattr(args, "code_layernorm_embedding", False):
|
| 987 |
+
self.code_layernorm_embedding = LayerNorm(embed_dim)
|
| 988 |
+
else:
|
| 989 |
+
self.code_layernorm_embedding = None
|
| 990 |
+
|
| 991 |
+
self.cross_self_attention = getattr(args, "cross_self_attention", False)
|
| 992 |
+
|
| 993 |
+
if self.decoder_layerdrop > 0.0:
|
| 994 |
+
self.layers = LayerDropModuleList(p=self.decoder_layerdrop)
|
| 995 |
+
else:
|
| 996 |
+
self.layers = nn.ModuleList([])
|
| 997 |
+
|
| 998 |
+
dpr = [x.item() for x in torch.linspace(0, args.decoder_drop_path_rate, args.decoder_layers)]
|
| 999 |
+
self.layers.extend(
|
| 1000 |
+
[
|
| 1001 |
+
self.build_decoder_layer(args, no_encoder_attn, drop_path_rate=dpr[i])
|
| 1002 |
+
for i in range(args.decoder_layers)
|
| 1003 |
+
]
|
| 1004 |
+
)
|
| 1005 |
+
self.num_layers = len(self.layers)
|
| 1006 |
+
|
| 1007 |
+
if args.decoder_normalize_before:
|
| 1008 |
+
self.layer_norm = LayerNorm(embed_dim)
|
| 1009 |
+
else:
|
| 1010 |
+
self.layer_norm = None
|
| 1011 |
+
|
| 1012 |
+
self.project_out_dim = (
|
| 1013 |
+
Linear(embed_dim, self.output_embed_dim, bias=False)
|
| 1014 |
+
if embed_dim != self.output_embed_dim and not args.tie_adaptive_weights
|
| 1015 |
+
else None
|
| 1016 |
+
)
|
| 1017 |
+
|
| 1018 |
+
self.adaptive_softmax = None
|
| 1019 |
+
self.output_projection = output_projection
|
| 1020 |
+
if self.output_projection is None:
|
| 1021 |
+
self.build_output_projection(args, dictionary, embed_tokens)
|
| 1022 |
+
|
| 1023 |
+
token_bucket_size = args.token_bucket_size
|
| 1024 |
+
token_num_rel_dis = 2 * token_bucket_size - 1
|
| 1025 |
+
token_rp_bucket = make_token_bucket_position(token_bucket_size)
|
| 1026 |
+
self.token_rel_pos_table_list = nn.ModuleList(
|
| 1027 |
+
[Embedding(token_num_rel_dis, self.num_attention_heads, zero_init=True) for _ in range(args.decoder_layers)]
|
| 1028 |
+
)
|
| 1029 |
+
|
| 1030 |
+
image_bucket_size = args.image_bucket_size
|
| 1031 |
+
image_num_rel_dis = (2 * image_bucket_size - 1) * (2 * image_bucket_size - 1) + 3
|
| 1032 |
+
image_rp_bucket = make_image_bucket_position(image_bucket_size, image_num_rel_dis)
|
| 1033 |
+
image_position_idx = torch.arange(self.window_size).unsqueeze(0).expand(self.window_size, self.window_size) + \
|
| 1034 |
+
torch.arange(self.window_size).unsqueeze(1) * image_bucket_size + 1
|
| 1035 |
+
image_position_idx = torch.cat([torch.tensor([0]), image_position_idx.view(-1)])
|
| 1036 |
+
image_position_idx = torch.cat([image_position_idx, torch.tensor([1024] * 768)])
|
| 1037 |
+
self.image_rel_pos_table_list = nn.ModuleList(
|
| 1038 |
+
[Embedding(image_num_rel_dis, self.num_attention_heads, zero_init=True) for _ in range(args.decoder_layers)]
|
| 1039 |
+
)
|
| 1040 |
+
|
| 1041 |
+
self.register_buffer("token_rp_bucket", token_rp_bucket)
|
| 1042 |
+
self.register_buffer("image_rp_bucket", image_rp_bucket)
|
| 1043 |
+
self.register_buffer("image_position_idx", image_position_idx)
|
| 1044 |
+
self.entangle_position_embedding = args.entangle_position_embedding
|
| 1045 |
+
|
| 1046 |
+
def build_output_projection(self, args, dictionary, embed_tokens):
|
| 1047 |
+
if args.adaptive_softmax_cutoff is not None:
|
| 1048 |
+
self.adaptive_softmax = AdaptiveSoftmax(
|
| 1049 |
+
len(dictionary),
|
| 1050 |
+
self.output_embed_dim,
|
| 1051 |
+
utils.eval_str_list(args.adaptive_softmax_cutoff, type=int),
|
| 1052 |
+
dropout=args.adaptive_softmax_dropout,
|
| 1053 |
+
adaptive_inputs=embed_tokens if args.tie_adaptive_weights else None,
|
| 1054 |
+
factor=args.adaptive_softmax_factor,
|
| 1055 |
+
tie_proj=args.tie_adaptive_proj,
|
| 1056 |
+
)
|
| 1057 |
+
elif self.share_input_output_embed:
|
| 1058 |
+
self.output_projection = nn.Linear(
|
| 1059 |
+
self.embed_tokens.weight.shape[1],
|
| 1060 |
+
self.embed_tokens.weight.shape[0],
|
| 1061 |
+
bias=False,
|
| 1062 |
+
)
|
| 1063 |
+
self.output_projection.weight = self.embed_tokens.weight
|
| 1064 |
+
else:
|
| 1065 |
+
self.output_projection = nn.Linear(
|
| 1066 |
+
self.output_embed_dim, len(dictionary), bias=False
|
| 1067 |
+
)
|
| 1068 |
+
nn.init.normal_(
|
| 1069 |
+
self.output_projection.weight, mean=0, std=self.output_embed_dim ** -0.5
|
| 1070 |
+
)
|
| 1071 |
+
num_base_layers = getattr(args, "base_layers", 0)
|
| 1072 |
+
for i in range(num_base_layers):
|
| 1073 |
+
self.layers.insert(((i+1) * args.decoder_layers) // (num_base_layers + 1), BaseLayer(args))
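        # When --share-decoder-input-output-embed is set, the output projection
        # reuses embed_tokens.weight, so logits are computed as x @ E^T and the
        # vocabulary matrix is shared between input lookup and output scoring;
        # otherwise a fresh vocab-sized projection is initialized with
        # std = output_embed_dim ** -0.5.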
|
| 1074 |
+
|
| 1075 |
+
def build_decoder_layer(self, args, no_encoder_attn=False, drop_path_rate=0.0):
|
| 1076 |
+
layer = TransformerDecoderLayer(args, no_encoder_attn, drop_path_rate=drop_path_rate)
|
| 1077 |
+
checkpoint = getattr(args, "checkpoint_activations", False)
|
| 1078 |
+
if checkpoint:
|
| 1079 |
+
offload_to_cpu = getattr(args, "offload_activations", False)
|
| 1080 |
+
layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu)
|
| 1081 |
+
# if we are checkpointing, enforce that FSDP always wraps the
|
| 1082 |
+
# checkpointed layer, regardless of layer size
|
| 1083 |
+
min_params_to_wrap = (
|
| 1084 |
+
getattr(args, "min_params_to_wrap", DEFAULT_MIN_PARAMS_TO_WRAP)
|
| 1085 |
+
if not checkpoint else 0
|
| 1086 |
+
)
|
| 1087 |
+
layer = fsdp_wrap(layer, min_num_params=min_params_to_wrap)
|
| 1088 |
+
return layer
|
| 1089 |
+
|
| 1090 |
+
def get_rel_pos_bias(self, x, idx):
|
| 1091 |
+
seq_len = x.size(1)
|
| 1092 |
+
rp_bucket = self.token_rp_bucket[:seq_len, :seq_len]
|
| 1093 |
+
values = F.embedding(rp_bucket, self.token_rel_pos_table_list[idx].weight)
|
| 1094 |
+
values = values.permute([2, 0, 1])
|
| 1095 |
+
return values.contiguous()
|
| 1096 |
+
|
| 1097 |
+
def get_image_rel_pos_bias(self, x, idx):
|
| 1098 |
+
seq_len = x.size(1)
|
| 1099 |
+
image_position_idx = self.image_position_idx[:seq_len]
|
| 1100 |
+
rp_bucket = self.image_rp_bucket[image_position_idx][:, image_position_idx]
|
| 1101 |
+
values = F.embedding(rp_bucket, self.image_rel_pos_table_list[idx].weight)
|
| 1102 |
+
values = values.permute(2, 0, 1)
|
| 1103 |
+
return values
|
| 1104 |
+
|
| 1105 |
+
def get_pos_info(self, tokens, tgt_pos_embed, src_pos_embed=None, use_image=False):
|
| 1106 |
+
batch_size = tokens.size(0)
|
| 1107 |
+
tgt_len = tokens.size(1)
|
| 1108 |
+
tgt_pos_embed = self.image_pos_ln(tgt_pos_embed) if use_image else self.pos_ln(tgt_pos_embed)
|
| 1109 |
+
if src_pos_embed is not None:
|
| 1110 |
+
src_len = src_pos_embed.size(1)
|
| 1111 |
+
pos_q = self.cross_pos_q_linear(tgt_pos_embed).view(
|
| 1112 |
+
batch_size, tgt_len, self.num_attention_heads, -1
|
| 1113 |
+
).transpose(1, 2) * self.pos_scaling
|
| 1114 |
+
pos_k = self.cross_pos_k_linear(src_pos_embed).view(
|
| 1115 |
+
batch_size, src_len, self.num_attention_heads, -1
|
| 1116 |
+
).transpose(1, 2)
|
| 1117 |
+
else:
|
| 1118 |
+
src_len = tgt_pos_embed.size(1)
|
| 1119 |
+
pos_q = self.self_pos_q_linear(tgt_pos_embed).view(
|
| 1120 |
+
batch_size, tgt_len, self.num_attention_heads, -1
|
| 1121 |
+
).transpose(1, 2) * self.pos_scaling
|
| 1122 |
+
pos_k = self.self_pos_k_linear(tgt_pos_embed).view(
|
| 1123 |
+
batch_size, src_len, self.num_attention_heads, -1
|
| 1124 |
+
).transpose(1, 2)
|
| 1125 |
+
abs_pos_bias = torch.matmul(pos_q, pos_k.transpose(2, 3))
|
| 1126 |
+
return abs_pos_bias
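        # get_pos_info builds the decoder's absolute-position attention bias:
        # with src_pos_embed it uses the cross-attention q/k projections and yields
        # a (batch, num_heads, tgt_len, src_len) bias; without it, the
        # self-attention projections yield (batch, num_heads, tgt_len, tgt_len).
        # use_image selects the image position LayerNorm for image-code targets.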
|
| 1127 |
+
|
| 1128 |
+
def forward(
|
| 1129 |
+
self,
|
| 1130 |
+
prev_output_tokens,
|
| 1131 |
+
code_masks: Optional[torch.Tensor] = None,
|
| 1132 |
+
encoder_out: Optional[Dict[str, List[Tensor]]] = None,
|
| 1133 |
+
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
|
| 1134 |
+
features_only: bool = False,
|
| 1135 |
+
full_context_alignment: bool = False,
|
| 1136 |
+
alignment_layer: Optional[int] = None,
|
| 1137 |
+
alignment_heads: Optional[int] = None,
|
| 1138 |
+
src_lengths: Optional[Any] = None,
|
| 1139 |
+
return_all_hiddens: bool = False,
|
| 1140 |
+
):
|
| 1141 |
+
"""
|
| 1142 |
+
Args:
|
| 1143 |
+
prev_output_tokens (LongTensor): previous decoder outputs of shape
|
| 1144 |
+
`(batch, tgt_len)`, for teacher forcing
|
| 1145 |
+
encoder_out (optional): output from the encoder, used for
|
| 1146 |
+
encoder-side attention, should be of size T x B x C
|
| 1147 |
+
incremental_state (dict): dictionary used for storing state during
|
| 1148 |
+
:ref:`Incremental decoding`
|
| 1149 |
+
features_only (bool, optional): only return features without
|
| 1150 |
+
applying output layer (default: False).
|
| 1151 |
+
full_context_alignment (bool, optional): don't apply
|
| 1152 |
+
auto-regressive mask to self-attention (default: False).
|
| 1153 |
+
|
| 1154 |
+
Returns:
|
| 1155 |
+
tuple:
|
| 1156 |
+
- the decoder's output of shape `(batch, tgt_len, vocab)`
|
| 1157 |
+
- a dictionary with any model-specific outputs
|
| 1158 |
+
"""
|
| 1159 |
+
|
| 1160 |
+
x, extra = self.extract_features(
|
| 1161 |
+
prev_output_tokens,
|
| 1162 |
+
code_masks=code_masks,
|
| 1163 |
+
encoder_out=encoder_out,
|
| 1164 |
+
incremental_state=incremental_state,
|
| 1165 |
+
full_context_alignment=full_context_alignment,
|
| 1166 |
+
alignment_layer=alignment_layer,
|
| 1167 |
+
alignment_heads=alignment_heads,
|
| 1168 |
+
)
|
| 1169 |
+
|
| 1170 |
+
if not features_only:
|
| 1171 |
+
x = self.output_layer(x)
|
| 1172 |
+
return x, extra
|
| 1173 |
+
|
| 1174 |
+
def extract_features(
|
| 1175 |
+
self,
|
| 1176 |
+
prev_output_tokens,
|
| 1177 |
+
code_masks: Optional[torch.Tensor],
|
| 1178 |
+
encoder_out: Optional[Dict[str, List[Tensor]]],
|
| 1179 |
+
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
|
| 1180 |
+
full_context_alignment: bool = False,
|
| 1181 |
+
alignment_layer: Optional[int] = None,
|
| 1182 |
+
alignment_heads: Optional[int] = None,
|
| 1183 |
+
):
|
| 1184 |
+
return self.extract_features_scriptable(
|
| 1185 |
+
prev_output_tokens,
|
| 1186 |
+
code_masks,
|
| 1187 |
+
encoder_out,
|
| 1188 |
+
incremental_state,
|
| 1189 |
+
full_context_alignment,
|
| 1190 |
+
alignment_layer,
|
| 1191 |
+
alignment_heads,
|
| 1192 |
+
)
|
| 1193 |
+
|
| 1194 |
+
"""
|
| 1195 |
+
A scriptable subclass of this class has an extract_features method and calls
|
| 1196 |
+
super().extract_features, but super() is not supported in torchscript. A copy of
|
| 1197 |
+
this function is made to be used in the subclass instead.
|
| 1198 |
+
"""
|
| 1199 |
+
|
| 1200 |
+
def extract_features_scriptable(
|
| 1201 |
+
self,
|
| 1202 |
+
prev_output_tokens,
|
| 1203 |
+
code_masks: Optional[torch.Tensor],
|
| 1204 |
+
encoder_out: Optional[Dict[str, List[Tensor]]],
|
| 1205 |
+
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
|
| 1206 |
+
full_context_alignment: bool = False,
|
| 1207 |
+
alignment_layer: Optional[int] = None,
|
| 1208 |
+
alignment_heads: Optional[int] = None,
|
| 1209 |
+
):
|
| 1210 |
+
"""
|
| 1211 |
+
Similar to *forward* but only return features.
|
| 1212 |
+
|
| 1213 |
+
Includes several features from "Jointly Learning to Align and
|
| 1214 |
+
Translate with Transformer Models" (Garg et al., EMNLP 2019).
|
| 1215 |
+
|
| 1216 |
+
Args:
|
| 1217 |
+
full_context_alignment (bool, optional): don't apply
|
| 1218 |
+
auto-regressive mask to self-attention (default: False).
|
| 1219 |
+
alignment_layer (int, optional): return mean alignment over
|
| 1220 |
+
heads at this layer (default: last layer).
|
| 1221 |
+
alignment_heads (int, optional): only average alignment over
|
| 1222 |
+
this many heads (default: all heads).
|
| 1223 |
+
|
| 1224 |
+
Returns:
|
| 1225 |
+
tuple:
|
| 1226 |
+
- the decoder's features of shape `(batch, tgt_len, embed_dim)`
|
| 1227 |
+
- a dictionary with any model-specific outputs
|
| 1228 |
+
"""
|
| 1229 |
+
bs, slen = prev_output_tokens.size()
|
| 1230 |
+
if alignment_layer is None:
|
| 1231 |
+
alignment_layer = self.num_layers - 1
|
| 1232 |
+
|
| 1233 |
+
enc: Optional[Tensor] = None
|
| 1234 |
+
padding_mask: Optional[Tensor] = None
|
| 1235 |
+
if encoder_out is not None and len(encoder_out["encoder_out"]) > 0:
|
| 1236 |
+
enc = encoder_out["encoder_out"][0]
|
| 1237 |
+
assert (
|
| 1238 |
+
enc.size()[1] == bs
|
| 1239 |
+
), f"Expected enc.shape == (t, {bs}, c) got {enc.shape}"
|
| 1240 |
+
if encoder_out is not None and len(encoder_out["encoder_padding_mask"]) > 0:
|
| 1241 |
+
padding_mask = encoder_out["encoder_padding_mask"][0]
|
| 1242 |
+
|
| 1243 |
+
bsz, tgt_len = prev_output_tokens.shape
|
| 1244 |
+
token_position_idx = utils.new_arange(prev_output_tokens)
|
| 1245 |
+
tgt_pos_embed = self.embed_positions(token_position_idx)
|
| 1246 |
+
if code_masks is not None and torch.any(code_masks):
|
| 1247 |
+
image_position_idx = self.image_position_idx[:prev_output_tokens.size(1)].unsqueeze(0).expand(bsz, tgt_len)
|
| 1248 |
+
tgt_pos_embed[code_masks] = self.embed_image_positions(image_position_idx)[code_masks]
|
| 1249 |
+
|
| 1250 |
+
# self attn position bias
|
| 1251 |
+
self_abs_pos_bias = self.get_pos_info(prev_output_tokens, tgt_pos_embed, use_image=False)
|
| 1252 |
+
if code_masks is not None and torch.any(code_masks):
|
| 1253 |
+
self_image_abs_pos_bias = self.get_pos_info(prev_output_tokens, tgt_pos_embed, use_image=True)
|
| 1254 |
+
self_abs_pos_bias[code_masks] = self_image_abs_pos_bias[code_masks]
|
| 1255 |
+
# cross attn position bias
|
| 1256 |
+
src_pos_embed = encoder_out['position_embeddings'][0]
|
| 1257 |
+
cross_abs_pos_bias = self.get_pos_info(prev_output_tokens, tgt_pos_embed, src_pos_embed=src_pos_embed)
|
| 1258 |
+
if code_masks is not None and torch.any(code_masks):
|
| 1259 |
+
cross_image_abs_pos_bias = self.get_pos_info(prev_output_tokens, tgt_pos_embed, src_pos_embed=src_pos_embed, use_image=True)
|
| 1260 |
+
cross_abs_pos_bias[code_masks] = cross_image_abs_pos_bias[code_masks]
|
| 1261 |
+
cross_abs_pos_bias = cross_abs_pos_bias.reshape(-1, *cross_abs_pos_bias.size()[-2:])
|
| 1262 |
+
|
| 1263 |
+
all_prev_output_tokens = prev_output_tokens.clone()
|
| 1264 |
+
if incremental_state is not None:
|
| 1265 |
+
prev_output_tokens = prev_output_tokens[:, -1:]
|
| 1266 |
+
cross_abs_pos_bias = cross_abs_pos_bias[:, -1:, :]
|
| 1267 |
+
tgt_pos_embed = tgt_pos_embed[:, -1:, :]
|
| 1268 |
+
|
| 1269 |
+
# embed tokens and positions
|
| 1270 |
+
x = self.embed_scale * self.embed_tokens(prev_output_tokens)
|
| 1271 |
+
|
| 1272 |
+
if self.quant_noise is not None:
|
| 1273 |
+
x = self.quant_noise(x)
|
| 1274 |
+
|
| 1275 |
+
if self.project_in_dim is not None:
|
| 1276 |
+
x = self.project_in_dim(x)
|
| 1277 |
+
|
| 1278 |
+
if self.entangle_position_embedding is not None and not self.args.disable_entangle:
|
| 1279 |
+
x += tgt_pos_embed
|
| 1280 |
+
|
| 1281 |
+
if self.layernorm_embedding is not None:
|
| 1282 |
+
if code_masks is None or not code_masks.any() or not getattr(self, "code_layernorm_embedding", False):
|
| 1283 |
+
x = self.layernorm_embedding(x)
|
| 1284 |
+
elif code_masks is not None and code_masks.all():
|
| 1285 |
+
x = self.code_layernorm_embedding(x)
|
| 1286 |
+
else:
|
| 1287 |
+
x[~code_masks] = self.layernorm_embedding(x[~code_masks])
|
| 1288 |
+
x[code_masks] = self.code_layernorm_embedding(x[code_masks])
|
| 1289 |
+
|
| 1290 |
+
x = self.dropout_module(x)
|
| 1291 |
+
|
| 1292 |
+
# B x T x C -> T x B x C
|
| 1293 |
+
x = x.transpose(0, 1)
|
| 1294 |
+
|
| 1295 |
+
self_attn_padding_mask: Optional[Tensor] = None
|
| 1296 |
+
if self.cross_self_attention or prev_output_tokens.eq(self.padding_idx).any():
|
| 1297 |
+
self_attn_padding_mask = prev_output_tokens.eq(self.padding_idx)
|
| 1298 |
+
|
| 1299 |
+
# decoder layers
|
| 1300 |
+
attn: Optional[Tensor] = None
|
| 1301 |
+
inner_states: List[Optional[Tensor]] = [x]
|
| 1302 |
+
for idx, layer in enumerate(self.layers):
|
| 1303 |
+
if incremental_state is None and not full_context_alignment:
|
| 1304 |
+
self_attn_mask = self.buffered_future_mask(x)
|
| 1305 |
+
else:
|
| 1306 |
+
self_attn_mask = None
|
| 1307 |
+
|
| 1308 |
+
self_attn_bias = self_abs_pos_bias.clone()
|
| 1309 |
+
if code_masks is None or not code_masks.any():
|
| 1310 |
+
self_attn_bias += self.get_rel_pos_bias(all_prev_output_tokens, idx).unsqueeze(0)
|
| 1311 |
+
elif code_masks is not None and code_masks.all():
|
| 1312 |
+
self_attn_bias += self.get_image_rel_pos_bias(all_prev_output_tokens, idx).unsqueeze(0)
|
| 1313 |
+
else:
|
| 1314 |
+
self_attn_bias[~code_masks] += self.get_rel_pos_bias(all_prev_output_tokens, idx).unsqueeze(0)
|
| 1315 |
+
self_attn_bias[code_masks] += self.get_image_rel_pos_bias(all_prev_output_tokens, idx).unsqueeze(0)
|
| 1316 |
+
self_attn_bias = self_attn_bias.reshape(-1, *self_attn_bias.size()[-2:])
|
| 1317 |
+
if incremental_state is not None:
|
| 1318 |
+
self_attn_bias = self_attn_bias[:, -1:, :]
|
| 1319 |
+
|
| 1320 |
+
x, layer_attn, _ = layer(
|
| 1321 |
+
x,
|
| 1322 |
+
enc,
|
| 1323 |
+
padding_mask,
|
| 1324 |
+
incremental_state,
|
| 1325 |
+
self_attn_mask=self_attn_mask,
|
| 1326 |
+
self_attn_padding_mask=self_attn_padding_mask,
|
| 1327 |
+
need_attn=bool((idx == alignment_layer)),
|
| 1328 |
+
need_head_weights=bool((idx == alignment_layer)),
|
| 1329 |
+
self_attn_bias=self_attn_bias,
|
| 1330 |
+
cross_attn_bias=cross_abs_pos_bias
|
| 1331 |
+
)
|
| 1332 |
+
inner_states.append(x)
|
| 1333 |
+
if layer_attn is not None and idx == alignment_layer:
|
| 1334 |
+
attn = layer_attn.float().to(x)
|
| 1335 |
+
|
| 1336 |
+
if attn is not None:
|
| 1337 |
+
if alignment_heads is not None:
|
| 1338 |
+
attn = attn[:alignment_heads]
|
| 1339 |
+
|
| 1340 |
+
# average probabilities over heads
|
| 1341 |
+
attn = attn.mean(dim=0)
|
| 1342 |
+
|
| 1343 |
+
if self.layer_norm is not None:
|
| 1344 |
+
x = self.layer_norm(x)
|
| 1345 |
+
|
| 1346 |
+
# T x B x C -> B x T x C
|
| 1347 |
+
x = x.transpose(0, 1)
|
| 1348 |
+
|
| 1349 |
+
if self.project_out_dim is not None:
|
| 1350 |
+
x = self.project_out_dim(x)
|
| 1351 |
+
|
| 1352 |
+
return x, {"attn": [attn], "inner_states": inner_states}
|
| 1353 |
+
|
| 1354 |
+
def output_layer(self, features):
|
| 1355 |
+
"""Project features to the vocabulary size."""
|
| 1356 |
+
if self.adaptive_softmax is None:
|
| 1357 |
+
# project back to size of vocabulary
|
| 1358 |
+
return self.output_projection(features)
|
| 1359 |
+
else:
|
| 1360 |
+
return features
|
| 1361 |
+
|
| 1362 |
+
def max_positions(self):
|
| 1363 |
+
"""Maximum output length supported by the decoder."""
|
| 1364 |
+
if self.embed_positions is None:
|
| 1365 |
+
return self.max_target_positions
|
| 1366 |
+
return self.max_target_positions
|
| 1367 |
+
|
| 1368 |
+
def buffered_future_mask(self, tensor):
|
| 1369 |
+
dim = tensor.size(0)
|
| 1370 |
+
# self._future_mask.device != tensor.device is not working in TorchScript. This is a workaround.
|
| 1371 |
+
if (
|
| 1372 |
+
self._future_mask.size(0) == 0
|
| 1373 |
+
or (not self._future_mask.device == tensor.device)
|
| 1374 |
+
or self._future_mask.size(0) < dim
|
| 1375 |
+
):
|
| 1376 |
+
self._future_mask = torch.triu(
|
| 1377 |
+
utils.fill_with_neg_inf(torch.zeros([dim, dim])), 1
|
| 1378 |
+
)
|
| 1379 |
+
self._future_mask = self._future_mask.to(tensor)
|
| 1380 |
+
return self._future_mask[:dim, :dim]
|
| 1381 |
+
|
| 1382 |
+
def upgrade_state_dict_named(self, state_dict, name):
|
| 1383 |
+
"""Upgrade a (possibly old) state dict for new versions of fairseq."""
|
| 1384 |
+
if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
|
| 1385 |
+
weights_key = "{}.embed_positions.weights".format(name)
|
| 1386 |
+
if weights_key in state_dict:
|
| 1387 |
+
del state_dict[weights_key]
|
| 1388 |
+
state_dict[
|
| 1389 |
+
"{}.embed_positions._float_tensor".format(name)
|
| 1390 |
+
] = torch.FloatTensor(1)
|
| 1391 |
+
|
| 1392 |
+
if f"{name}.output_projection.weight" not in state_dict:
|
| 1393 |
+
if self.share_input_output_embed:
|
| 1394 |
+
embed_out_key = f"{name}.embed_tokens.weight"
|
| 1395 |
+
else:
|
| 1396 |
+
embed_out_key = f"{name}.embed_out"
|
| 1397 |
+
if embed_out_key in state_dict:
|
| 1398 |
+
state_dict[f"{name}.output_projection.weight"] = state_dict[
|
| 1399 |
+
embed_out_key
|
| 1400 |
+
]
|
| 1401 |
+
if not self.share_input_output_embed:
|
| 1402 |
+
del state_dict[embed_out_key]
|
| 1403 |
+
|
| 1404 |
+
for i in range(self.num_layers):
|
| 1405 |
+
# update layer norms
|
| 1406 |
+
self.layers[i].upgrade_state_dict_named(
|
| 1407 |
+
state_dict, "{}.layers.{}".format(name, i)
|
| 1408 |
+
)
|
| 1409 |
+
|
| 1410 |
+
# version_key = "{}.version".format(name)
|
| 1411 |
+
# if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) <= 2:
|
| 1412 |
+
# # earlier checkpoints did not normalize after the stack of layers
|
| 1413 |
+
# self.layer_norm = None
|
| 1414 |
+
# self.normalize = False
|
| 1415 |
+
# state_dict[version_key] = torch.Tensor([1])
|
| 1416 |
+
|
| 1417 |
+
prefix = name + "." if name != "" else ""
|
| 1418 |
+
image_params = ["image_position_idx"]
|
| 1419 |
+
for image_param in image_params:
|
| 1420 |
+
state_dict[prefix + image_param] = self.state_dict()[image_param]
|
| 1421 |
+
for param_name, param_tensor in self.state_dict().items():
|
| 1422 |
+
if (prefix + param_name) not in state_dict and param_name in self.state_dict():
|
| 1423 |
+
state_dict[prefix + param_name] = self.state_dict()[param_name]
|
| 1424 |
+
|
| 1425 |
+
if len(state_dict["decoder.embed_image_positions.weight"]) < len(self.state_dict()["embed_image_positions.weight"]):
|
| 1426 |
+
num_posids_to_add = len(self.state_dict()["embed_image_positions.weight"]) - len(state_dict["decoder.embed_image_positions.weight"])
|
| 1427 |
+
embed_dim = state_dict["decoder.embed_image_positions.weight"].size(1)
|
| 1428 |
+
new_pos_embed_to_add = torch.zeros(num_posids_to_add, embed_dim)
|
| 1429 |
+
nn.init.normal_(new_pos_embed_to_add, mean=0, std=embed_dim ** -0.5)
|
| 1430 |
+
new_pos_embed_to_add = new_pos_embed_to_add.to(
|
| 1431 |
+
dtype=state_dict["decoder.embed_image_positions.weight"].dtype,
|
| 1432 |
+
)
|
| 1433 |
+
state_dict["decoder.embed_image_positions.weight"] = torch.cat(
|
| 1434 |
+
[state_dict["decoder.embed_image_positions.weight"], new_pos_embed_to_add]
|
| 1435 |
+
)
|
| 1436 |
+
return state_dict
|
| 1437 |
+
|
| 1438 |
+
|
| 1439 |
+
def Embedding(num_embeddings, embedding_dim, padding_idx=None, zero_init=False):
|
| 1440 |
+
m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
|
| 1441 |
+
nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
|
| 1442 |
+
if padding_idx is not None:
|
| 1443 |
+
nn.init.constant_(m.weight[padding_idx], 0)
|
| 1444 |
+
if zero_init:
|
| 1445 |
+
nn.init.constant_(m.weight, 0)
|
| 1446 |
+
return m
|
| 1447 |
+
|
| 1448 |
+
|
| 1449 |
+
def Linear(in_features, out_features, bias=True):
|
| 1450 |
+
m = nn.Linear(in_features, out_features, bias)
|
| 1451 |
+
nn.init.xavier_uniform_(m.weight)
|
| 1452 |
+
if bias:
|
| 1453 |
+
nn.init.constant_(m.bias, 0.0)
|
| 1454 |
+
return m
|
| 1455 |
+
|
| 1456 |
+
|
| 1457 |
+
@register_model_architecture("unify_transformer", "unify_transformer")
|
| 1458 |
+
def base_architecture(args):
|
| 1459 |
+
args.encoder_embed_path = getattr(args, "encoder_embed_path", None)
|
| 1460 |
+
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
|
| 1461 |
+
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048)
|
| 1462 |
+
args.encoder_layers = getattr(args, "encoder_layers", 6)
|
| 1463 |
+
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8)
|
| 1464 |
+
args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
|
| 1465 |
+
args.encoder_learned_pos = getattr(args, "encoder_learned_pos", False)
|
| 1466 |
+
args.decoder_embed_path = getattr(args, "decoder_embed_path", None)
|
| 1467 |
+
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim)
|
| 1468 |
+
args.decoder_ffn_embed_dim = getattr(
|
| 1469 |
+
args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim
|
| 1470 |
+
)
|
| 1471 |
+
args.decoder_layers = getattr(args, "decoder_layers", 6)
|
| 1472 |
+
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
|
| 1473 |
+
args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False)
|
| 1474 |
+
args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False)
|
| 1475 |
+
args.attention_dropout = getattr(args, "attention_dropout", 0.0)
|
| 1476 |
+
args.activation_dropout = getattr(args, "activation_dropout", 0.0)
|
| 1477 |
+
args.activation_fn = getattr(args, "activation_fn", "relu")
|
| 1478 |
+
args.dropout = getattr(args, "dropout", 0.1)
|
| 1479 |
+
args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
|
| 1480 |
+
args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
|
| 1481 |
+
args.share_decoder_input_output_embed = getattr(
|
| 1482 |
+
args, "share_decoder_input_output_embed", False
|
| 1483 |
+
)
|
| 1484 |
+
args.share_all_embeddings = getattr(args, "share_all_embeddings", False)
|
| 1485 |
+
args.no_token_positional_embeddings = getattr(
|
| 1486 |
+
args, "no_token_positional_embeddings", False
|
| 1487 |
+
)
|
| 1488 |
+
args.adaptive_input = getattr(args, "adaptive_input", False)
|
| 1489 |
+
args.no_cross_attention = getattr(args, "no_cross_attention", False)
|
| 1490 |
+
args.cross_self_attention = getattr(args, "cross_self_attention", False)
|
| 1491 |
+
|
| 1492 |
+
args.decoder_output_dim = getattr(
|
| 1493 |
+
args, "decoder_output_dim", args.decoder_embed_dim
|
| 1494 |
+
)
|
| 1495 |
+
args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim)
|
| 1496 |
+
|
| 1497 |
+
args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
|
| 1498 |
+
args.layernorm_embedding = getattr(args, "layernorm_embedding", False)
|
| 1499 |
+
args.tie_adaptive_weights = getattr(args, "tie_adaptive_weights", False)
|
| 1500 |
+
args.checkpoint_activations = getattr(args, "checkpoint_activations", False)
|
| 1501 |
+
args.offload_activations = getattr(args, "offload_activations", False)
|
| 1502 |
+
if args.offload_activations:
|
| 1503 |
+
args.checkpoint_activations = True
|
| 1504 |
+
args.encoder_layers_to_keep = getattr(args, "encoder_layers_to_keep", None)
|
| 1505 |
+
args.decoder_layers_to_keep = getattr(args, "decoder_layers_to_keep", None)
|
| 1506 |
+
args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0)
|
| 1507 |
+
args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0)
|
| 1508 |
+
args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)
|
| 1509 |
+
args.quant_noise_pq_block_size = getattr(args, "quant_noise_pq_block_size", 8)
|
| 1510 |
+
args.quant_noise_scalar = getattr(args, "quant_noise_scalar", 0)
|
models/ofa/unify_transformer_layer.py
ADDED
|
@@ -0,0 +1,542 @@
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
from typing import Dict, List, Optional
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
from fairseq import utils
|
| 11 |
+
from fairseq.modules import LayerNorm
|
| 12 |
+
from fairseq.modules.fairseq_dropout import FairseqDropout
|
| 13 |
+
from fairseq.modules.quant_noise import quant_noise
|
| 14 |
+
from torch import Tensor
|
| 15 |
+
|
| 16 |
+
from .unify_multihead_attention import MultiheadAttention
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def drop_path(x, drop_prob: float = 0.0, training: bool = False):
|
| 20 |
+
"""
|
| 21 |
+
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
| 22 |
+
Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
|
| 23 |
+
however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
|
| 24 |
+
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
|
| 25 |
+
layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
|
| 26 |
+
argument.
|
| 27 |
+
"""
|
| 28 |
+
if drop_prob == 0.0 or not training:
|
| 29 |
+
return x
|
| 30 |
+
keep_prob = 1 - drop_prob
|
| 31 |
+
shape = (1, x.shape[1], 1)
|
| 32 |
+
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
|
| 33 |
+
random_tensor.floor_() # binarize
|
| 34 |
+
output = x.div(keep_prob) * random_tensor
|
| 35 |
+
return output
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class DropPath(nn.Module):
|
| 39 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
|
| 40 |
+
|
| 41 |
+
def __init__(self, drop_prob=None):
|
| 42 |
+
super().__init__()
|
| 43 |
+
self.drop_prob = drop_prob
|
| 44 |
+
|
| 45 |
+
def forward(self, x):
|
| 46 |
+
return drop_path(x, self.drop_prob, self.training)
|
| 47 |
+
|
| 48 |
+
def extra_repr(self) -> str:
|
| 49 |
+
return "p={}".format(self.drop_prob)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class TransformerEncoderLayer(nn.Module):
|
| 53 |
+
"""Encoder layer block.
|
| 54 |
+
|
| 55 |
+
In the original paper each operation (multi-head attention or FFN) is
|
| 56 |
+
postprocessed with: `dropout -> add residual -> layernorm`. In the
|
| 57 |
+
tensor2tensor code they suggest that learning is more robust when
|
| 58 |
+
preprocessing each layer with layernorm and postprocessing with:
|
| 59 |
+
`dropout -> add residual`. We default to the approach in the paper, but the
|
| 60 |
+
tensor2tensor approach can be enabled by setting
|
| 61 |
+
*args.encoder_normalize_before* to ``True``.
|
| 62 |
+
|
| 63 |
+
Args:
|
| 64 |
+
args (argparse.Namespace): parsed command-line arguments
|
| 65 |
+
"""
|
| 66 |
+
|
| 67 |
+
def __init__(self, args, drop_path_rate=0.0):
|
| 68 |
+
super().__init__()
|
| 69 |
+
self.args = args
|
| 70 |
+
self.embed_dim = args.encoder_embed_dim
|
| 71 |
+
self.quant_noise = getattr(args, 'quant_noise_pq', 0)
|
| 72 |
+
self.quant_noise_block_size = getattr(args, 'quant_noise_pq_block_size', 8) or 8
|
| 73 |
+
self.self_attn = self.build_self_attention(self.embed_dim, args)
|
| 74 |
+
self.self_attn_layer_norm = LayerNorm(self.embed_dim)
|
| 75 |
+
self.dropout_module = FairseqDropout(
|
| 76 |
+
args.dropout, module_name=self.__class__.__name__
|
| 77 |
+
)
|
| 78 |
+
self.activation_fn = utils.get_activation_fn(
|
| 79 |
+
activation=getattr(args, 'activation_fn', 'relu') or "relu"
|
| 80 |
+
)
|
| 81 |
+
activation_dropout_p = getattr(args, "activation_dropout", 0) or 0
|
| 82 |
+
if activation_dropout_p == 0:
|
| 83 |
+
# for backwards compatibility with models that use args.relu_dropout
|
| 84 |
+
activation_dropout_p = getattr(args, "relu_dropout", 0) or 0
|
| 85 |
+
self.activation_dropout_module = FairseqDropout(
|
| 86 |
+
float(activation_dropout_p), module_name=self.__class__.__name__
|
| 87 |
+
)
|
| 88 |
+
self.normalize_before = args.encoder_normalize_before
|
| 89 |
+
self.fc1 = self.build_fc1(
|
| 90 |
+
self.embed_dim,
|
| 91 |
+
args.encoder_ffn_embed_dim,
|
| 92 |
+
self.quant_noise,
|
| 93 |
+
self.quant_noise_block_size,
|
| 94 |
+
)
|
| 95 |
+
self.fc2 = self.build_fc2(
|
| 96 |
+
args.encoder_ffn_embed_dim,
|
| 97 |
+
self.embed_dim,
|
| 98 |
+
self.quant_noise,
|
| 99 |
+
self.quant_noise_block_size,
|
| 100 |
+
)
|
| 101 |
+
|
| 102 |
+
self.attn_ln = LayerNorm(self.embed_dim) if getattr(args, 'scale_attn', False) else None
|
| 103 |
+
self.nh = self.self_attn.num_heads
|
| 104 |
+
self.head_dim = self.self_attn.head_dim
|
| 105 |
+
|
| 106 |
+
self.ffn_layernorm = LayerNorm(args.encoder_ffn_embed_dim) if getattr(args, 'scale_fc', False) else None
|
| 107 |
+
self.w_resid = nn.Parameter(torch.ones(self.embed_dim, ), requires_grad=True) if getattr(args, 'scale_resids', False) else None
|
| 108 |
+
|
| 109 |
+
self.final_layer_norm = LayerNorm(self.embed_dim)
|
| 110 |
+
|
| 111 |
+
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
|
| 112 |
+
|
| 113 |
+
def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size):
|
| 114 |
+
return quant_noise(
|
| 115 |
+
nn.Linear(input_dim, output_dim), p=q_noise, block_size=qn_block_size
|
| 116 |
+
)
|
| 117 |
+
|
| 118 |
+
def build_fc2(self, input_dim, output_dim, q_noise, qn_block_size):
|
| 119 |
+
return quant_noise(
|
| 120 |
+
nn.Linear(input_dim, output_dim), p=q_noise, block_size=qn_block_size
|
| 121 |
+
)
|
| 122 |
+
|
| 123 |
+
def build_self_attention(self, embed_dim, args):
|
| 124 |
+
return MultiheadAttention(
|
| 125 |
+
embed_dim,
|
| 126 |
+
args.encoder_attention_heads,
|
| 127 |
+
dropout=args.attention_dropout,
|
| 128 |
+
self_attention=True,
|
| 129 |
+
q_noise=self.quant_noise,
|
| 130 |
+
qn_block_size=self.quant_noise_block_size,
|
| 131 |
+
scale_factor=args.attn_scale_factor,
|
| 132 |
+
scale_heads=getattr(args, 'scale_heads', False)
|
| 133 |
+
)
|
| 134 |
+
|
| 135 |
+
def residual_connection(self, x, residual):
|
| 136 |
+
return residual + self.drop_path(x)
|
| 137 |
+
|
| 138 |
+
def upgrade_state_dict_named(self, state_dict, name):
|
| 139 |
+
"""
|
| 140 |
+
Rename layer norm states from `...layer_norms.0.weight` to
|
| 141 |
+
`...self_attn_layer_norm.weight` and `...layer_norms.1.weight` to
|
| 142 |
+
`...final_layer_norm.weight`
|
| 143 |
+
"""
|
| 144 |
+
layer_norm_map = {"0": "self_attn_layer_norm", "1": "final_layer_norm"}
|
| 145 |
+
for old, new in layer_norm_map.items():
|
| 146 |
+
for m in ("weight", "bias"):
|
| 147 |
+
k = "{}.layer_norms.{}.{}".format(name, old, m)
|
| 148 |
+
if k in state_dict:
|
| 149 |
+
state_dict["{}.{}.{}".format(name, new, m)] = state_dict[k]
|
| 150 |
+
del state_dict[k]
|
| 151 |
+
if "{}.{}.{}".format(name, new, m) not in state_dict and "{}.{}".format(new, m) in self.state_dict():
|
| 152 |
+
state_dict[
|
| 153 |
+
"{}.{}.{}".format(name, new, m)
|
| 154 |
+
] = self.state_dict()["{}.{}".format(new, m)]
|
| 155 |
+
|
| 156 |
+
prefix = name + "." if name != "" else ""
|
| 157 |
+
for param_name, param_tensor in self.state_dict().items():
|
| 158 |
+
if (prefix + param_name) not in state_dict and param_name in self.state_dict():
|
| 159 |
+
state_dict[prefix + param_name] = self.state_dict()[param_name]
|
| 160 |
+
|
| 161 |
+
def forward(
|
| 162 |
+
self,
|
| 163 |
+
x,
|
| 164 |
+
encoder_padding_mask: Optional[Tensor],
|
| 165 |
+
attn_mask: Optional[Tensor] = None,
|
| 166 |
+
self_attn_bias: Optional[Tensor] = None
|
| 167 |
+
):
|
| 168 |
+
"""
|
| 169 |
+
Args:
|
| 170 |
+
x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
|
| 171 |
+
encoder_padding_mask (ByteTensor): binary ByteTensor of shape
|
| 172 |
+
`(batch, seq_len)` where padding elements are indicated by ``1``.
|
| 173 |
+
attn_mask (ByteTensor): binary tensor of shape `(tgt_len, src_len)`,
|
| 174 |
+
where `tgt_len` is the length of output and `src_len` is the
|
| 175 |
+
length of input, though here both are equal to `seq_len`.
|
| 176 |
+
`attn_mask[tgt_i, src_j] = 1` means that when calculating the
|
| 177 |
+
embedding for `tgt_i`, we exclude (mask out) `src_j`. This is
|
| 178 |
+
useful for strided self-attention.
|
| 179 |
+
|
| 180 |
+
Returns:
|
| 181 |
+
encoded output of shape `(seq_len, batch, embed_dim)`
|
| 182 |
+
"""
|
| 183 |
+
# anything in original attn_mask = 1, becomes -1e8
|
| 184 |
+
# anything in original attn_mask = 0, becomes 0
|
| 185 |
+
# Note that we cannot use -inf here, because at some edge cases,
|
| 186 |
+
# the attention weight (before softmax) for some padded element in query
|
| 187 |
+
# will become -inf, which results in NaN in model parameters
|
| 188 |
+
if attn_mask is not None:
|
| 189 |
+
attn_mask = attn_mask.masked_fill(
|
| 190 |
+
attn_mask.to(torch.bool),
|
| 191 |
+
-1e8 if x.dtype == torch.float32 else -1e4
|
| 192 |
+
)
|
| 193 |
+
|
| 194 |
+
residual = x
|
| 195 |
+
if self.normalize_before:
|
| 196 |
+
x = self.self_attn_layer_norm(x)
|
| 197 |
+
x, _ = self.self_attn(
|
| 198 |
+
query=x,
|
| 199 |
+
key=x,
|
| 200 |
+
value=x,
|
| 201 |
+
key_padding_mask=encoder_padding_mask,
|
| 202 |
+
need_weights=False,
|
| 203 |
+
attn_mask=attn_mask,
|
| 204 |
+
attn_bias=self_attn_bias
|
| 205 |
+
)
|
| 206 |
+
if self.attn_ln is not None:
|
| 207 |
+
x = self.attn_ln(x)
|
| 208 |
+
x = self.dropout_module(x)
|
| 209 |
+
x = self.residual_connection(x, residual)
|
| 210 |
+
if not self.normalize_before:
|
| 211 |
+
x = self.self_attn_layer_norm(x)
|
| 212 |
+
|
| 213 |
+
residual = x
|
| 214 |
+
if self.normalize_before:
|
| 215 |
+
x = self.final_layer_norm(x)
|
| 216 |
+
x = self.activation_fn(self.fc1(x))
|
| 217 |
+
x = self.activation_dropout_module(x)
|
| 218 |
+
if self.ffn_layernorm is not None:
|
| 219 |
+
x = self.ffn_layernorm(x)
|
| 220 |
+
x = self.fc2(x)
|
| 221 |
+
x = self.dropout_module(x)
|
| 222 |
+
if self.w_resid is not None:
|
| 223 |
+
residual = torch.mul(self.w_resid, residual)
|
| 224 |
+
x = self.residual_connection(x, residual)
|
| 225 |
+
if not self.normalize_before:
|
| 226 |
+
x = self.final_layer_norm(x)
|
| 227 |
+
return x
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
class TransformerDecoderLayer(nn.Module):
|
| 231 |
+
"""Decoder layer block.
|
| 232 |
+
|
| 233 |
+
In the original paper each operation (multi-head attention, encoder
|
| 234 |
+
attention or FFN) is postprocessed with: `dropout -> add residual ->
|
| 235 |
+
layernorm`. In the tensor2tensor code they suggest that learning is more
|
| 236 |
+
robust when preprocessing each layer with layernorm and postprocessing with:
|
| 237 |
+
`dropout -> add residual`. We default to the approach in the paper, but the
|
| 238 |
+
tensor2tensor approach can be enabled by setting
|
| 239 |
+
*args.decoder_normalize_before* to ``True``.
|
| 240 |
+
|
| 241 |
+
Args:
|
| 242 |
+
args (argparse.Namespace): parsed command-line arguments
|
| 243 |
+
no_encoder_attn (bool, optional): whether to attend to encoder outputs
|
| 244 |
+
(default: False).
|
| 245 |
+
"""
|
| 246 |
+
|
| 247 |
+
def __init__(
|
| 248 |
+
self, args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False, drop_path_rate=0.0
|
| 249 |
+
):
|
| 250 |
+
super().__init__()
|
| 251 |
+
self.embed_dim = args.decoder_embed_dim
|
| 252 |
+
self.dropout_module = FairseqDropout(
|
| 253 |
+
args.dropout, module_name=self.__class__.__name__
|
| 254 |
+
)
|
| 255 |
+
self.quant_noise = getattr(args, "quant_noise_pq", 0)
|
| 256 |
+
self.quant_noise_block_size = getattr(args, "quant_noise_pq_block_size", 8)
|
| 257 |
+
|
| 258 |
+
self.cross_self_attention = getattr(args, "cross_self_attention", False)
|
| 259 |
+
|
| 260 |
+
self.self_attn = self.build_self_attention(
|
| 261 |
+
self.embed_dim,
|
| 262 |
+
args,
|
| 263 |
+
add_bias_kv=add_bias_kv,
|
| 264 |
+
add_zero_attn=add_zero_attn,
|
| 265 |
+
)
|
| 266 |
+
self.self_attn_ln = LayerNorm(self.embed_dim) if getattr(args, 'scale_attn', False) else None
|
| 267 |
+
self.cross_attn_ln = LayerNorm(self.embed_dim) if getattr(args, 'scale_attn', False) else None
|
| 268 |
+
self.nh = self.self_attn.num_heads
|
| 269 |
+
self.head_dim = self.self_attn.head_dim
|
| 270 |
+
|
| 271 |
+
self.activation_fn = utils.get_activation_fn(
|
| 272 |
+
activation=str(args.activation_fn)
|
| 273 |
+
if getattr(args, "activation_fn", None) is not None
|
| 274 |
+
else "relu"
|
| 275 |
+
)
|
| 276 |
+
activation_dropout_p = getattr(args, "activation_dropout", 0) or 0
|
| 277 |
+
if activation_dropout_p == 0:
|
| 278 |
+
# for backwards compatibility with models that use args.relu_dropout
|
| 279 |
+
activation_dropout_p = getattr(args, "relu_dropout", 0) or 0
|
| 280 |
+
self.activation_dropout_module = FairseqDropout(
|
| 281 |
+
float(activation_dropout_p), module_name=self.__class__.__name__
|
| 282 |
+
)
|
| 283 |
+
self.normalize_before = args.decoder_normalize_before
|
| 284 |
+
|
| 285 |
+
# use layerNorm rather than FusedLayerNorm for exporting.
|
| 286 |
+
# char_inputs can be used to determine this.
|
| 287 |
+
# TODO remove this once we update apex with the fix
|
| 288 |
+
export = getattr(args, "char_inputs", False)
|
| 289 |
+
self.self_attn_layer_norm = LayerNorm(self.embed_dim, export=export)
|
| 290 |
+
|
| 291 |
+
if no_encoder_attn:
|
| 292 |
+
self.encoder_attn = None
|
| 293 |
+
self.encoder_attn_layer_norm = None
|
| 294 |
+
else:
|
| 295 |
+
self.encoder_attn = self.build_encoder_attention(self.embed_dim, args)
|
| 296 |
+
self.encoder_attn_layer_norm = LayerNorm(self.embed_dim, export=export)
|
| 297 |
+
|
| 298 |
+
self.ffn_layernorm = LayerNorm(args.decoder_ffn_embed_dim) if getattr(args, 'scale_fc', False) else None
|
| 299 |
+
self.w_resid = nn.Parameter(torch.ones(self.embed_dim, ), requires_grad=True) if getattr(args, 'scale_resids', False) else None
|
| 300 |
+
|
| 301 |
+
self.fc1 = self.build_fc1(
|
| 302 |
+
self.embed_dim,
|
| 303 |
+
args.decoder_ffn_embed_dim,
|
| 304 |
+
self.quant_noise,
|
| 305 |
+
self.quant_noise_block_size,
|
| 306 |
+
)
|
| 307 |
+
self.fc2 = self.build_fc2(
|
| 308 |
+
args.decoder_ffn_embed_dim,
|
| 309 |
+
self.embed_dim,
|
| 310 |
+
self.quant_noise,
|
| 311 |
+
self.quant_noise_block_size,
|
| 312 |
+
)
|
| 313 |
+
|
| 314 |
+
self.final_layer_norm = LayerNorm(self.embed_dim, export=export)
|
| 315 |
+
self.need_attn = True
|
| 316 |
+
|
| 317 |
+
self.onnx_trace = False
|
| 318 |
+
|
| 319 |
+
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
|
| 320 |
+
|
| 321 |
+
def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size):
|
| 322 |
+
return quant_noise(nn.Linear(input_dim, output_dim), q_noise, qn_block_size)
|
| 323 |
+
|
| 324 |
+
def build_fc2(self, input_dim, output_dim, q_noise, qn_block_size):
|
| 325 |
+
return quant_noise(nn.Linear(input_dim, output_dim), q_noise, qn_block_size)
|
| 326 |
+
|
| 327 |
+
def build_self_attention(
|
| 328 |
+
self, embed_dim, args, add_bias_kv=False, add_zero_attn=False
|
| 329 |
+
):
|
| 330 |
+
return MultiheadAttention(
|
| 331 |
+
embed_dim,
|
| 332 |
+
args.decoder_attention_heads,
|
| 333 |
+
dropout=args.attention_dropout,
|
| 334 |
+
add_bias_kv=add_bias_kv,
|
| 335 |
+
add_zero_attn=add_zero_attn,
|
| 336 |
+
self_attention=not getattr(args, "cross_self_attention", False),
|
| 337 |
+
q_noise=self.quant_noise,
|
| 338 |
+
qn_block_size=self.quant_noise_block_size,
|
| 339 |
+
scale_factor=args.attn_scale_factor,
|
| 340 |
+
scale_heads=getattr(args, 'scale_heads', False)
|
| 341 |
+
)
|
| 342 |
+
|
| 343 |
+
def build_encoder_attention(self, embed_dim, args):
|
| 344 |
+
return MultiheadAttention(
|
| 345 |
+
embed_dim,
|
| 346 |
+
args.decoder_attention_heads,
|
| 347 |
+
kdim=getattr(args, "encoder_embed_dim", None),
|
| 348 |
+
vdim=getattr(args, "encoder_embed_dim", None),
|
| 349 |
+
dropout=args.attention_dropout,
|
| 350 |
+
encoder_decoder_attention=True,
|
| 351 |
+
q_noise=self.quant_noise,
|
| 352 |
+
qn_block_size=self.quant_noise_block_size,
|
| 353 |
+
scale_factor=args.attn_scale_factor,
|
| 354 |
+
scale_heads=getattr(args, 'scale_heads', False)
|
| 355 |
+
)
|
| 356 |
+
|
| 357 |
+
def prepare_for_onnx_export_(self):
|
| 358 |
+
self.onnx_trace = True
|
| 359 |
+
|
| 360 |
+
def residual_connection(self, x, residual):
|
| 361 |
+
return residual + self.drop_path(x)
|
| 362 |
+
|
| 363 |
+
def forward(
|
| 364 |
+
self,
|
| 365 |
+
x,
|
| 366 |
+
encoder_out: Optional[torch.Tensor] = None,
|
| 367 |
+
encoder_padding_mask: Optional[torch.Tensor] = None,
|
| 368 |
+
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
|
| 369 |
+
prev_self_attn_state: Optional[List[torch.Tensor]] = None,
|
| 370 |
+
prev_attn_state: Optional[List[torch.Tensor]] = None,
|
| 371 |
+
self_attn_mask: Optional[torch.Tensor] = None,
|
| 372 |
+
self_attn_padding_mask: Optional[torch.Tensor] = None,
|
| 373 |
+
need_attn: bool = False,
|
| 374 |
+
need_head_weights: bool = False,
|
| 375 |
+
self_attn_bias: Optional[Tensor] = None,
|
| 376 |
+
cross_attn_bias: Optional[Tensor] = None
|
| 377 |
+
):
|
| 378 |
+
"""
|
| 379 |
+
Args:
|
| 380 |
+
x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
|
| 381 |
+
encoder_padding_mask (ByteTensor, optional): binary
|
| 382 |
+
ByteTensor of shape `(batch, src_len)` where padding
|
| 383 |
+
elements are indicated by ``1``.
|
| 384 |
+
need_attn (bool, optional): return attention weights
|
| 385 |
+
need_head_weights (bool, optional): return attention weights
|
| 386 |
+
for each head (default: return average over heads).
|
| 387 |
+
|
| 388 |
+
Returns:
|
| 389 |
+
encoded output of shape `(seq_len, batch, embed_dim)`
|
| 390 |
+
"""
|
| 391 |
+
if need_head_weights:
|
| 392 |
+
need_attn = True
|
| 393 |
+
|
| 394 |
+
residual = x
|
| 395 |
+
if self.normalize_before:
|
| 396 |
+
x = self.self_attn_layer_norm(x)
|
| 397 |
+
if prev_self_attn_state is not None:
|
| 398 |
+
prev_key, prev_value = prev_self_attn_state[:2]
|
| 399 |
+
saved_state: Dict[str, Optional[Tensor]] = {
|
| 400 |
+
"prev_key": prev_key,
|
| 401 |
+
"prev_value": prev_value,
|
| 402 |
+
}
|
| 403 |
+
if len(prev_self_attn_state) >= 3:
|
| 404 |
+
saved_state["prev_key_padding_mask"] = prev_self_attn_state[2]
|
| 405 |
+
assert incremental_state is not None
|
| 406 |
+
self.self_attn._set_input_buffer(incremental_state, saved_state)
|
| 407 |
+
_self_attn_input_buffer = self.self_attn._get_input_buffer(incremental_state)
|
| 408 |
+
if self.cross_self_attention and not (
|
| 409 |
+
incremental_state is not None
|
| 410 |
+
and _self_attn_input_buffer is not None
|
| 411 |
+
and "prev_key" in _self_attn_input_buffer
|
| 412 |
+
):
|
| 413 |
+
if self_attn_mask is not None:
|
| 414 |
+
assert encoder_out is not None
|
| 415 |
+
self_attn_mask = torch.cat(
|
| 416 |
+
(x.new_zeros(x.size(0), encoder_out.size(0)), self_attn_mask), dim=1
|
| 417 |
+
)
|
| 418 |
+
if self_attn_padding_mask is not None:
|
| 419 |
+
if encoder_padding_mask is None:
|
| 420 |
+
assert encoder_out is not None
|
| 421 |
+
encoder_padding_mask = self_attn_padding_mask.new_zeros(
|
| 422 |
+
encoder_out.size(1), encoder_out.size(0)
|
| 423 |
+
)
|
| 424 |
+
self_attn_padding_mask = torch.cat(
|
| 425 |
+
(encoder_padding_mask, self_attn_padding_mask), dim=1
|
| 426 |
+
)
|
| 427 |
+
assert encoder_out is not None
|
| 428 |
+
y = torch.cat((encoder_out, x), dim=0)
|
| 429 |
+
else:
|
| 430 |
+
y = x
|
| 431 |
+
|
| 432 |
+
x, attn = self.self_attn(
|
| 433 |
+
query=x,
|
| 434 |
+
key=y,
|
| 435 |
+
value=y,
|
| 436 |
+
key_padding_mask=self_attn_padding_mask,
|
| 437 |
+
incremental_state=incremental_state,
|
| 438 |
+
need_weights=False,
|
| 439 |
+
attn_mask=self_attn_mask,
|
| 440 |
+
attn_bias=self_attn_bias
|
| 441 |
+
)
|
| 442 |
+
if self.self_attn_ln is not None:
|
| 443 |
+
x = self.self_attn_ln(x)
|
| 444 |
+
x = self.dropout_module(x)
|
| 445 |
+
x = self.residual_connection(x, residual)
|
| 446 |
+
if not self.normalize_before:
|
| 447 |
+
x = self.self_attn_layer_norm(x)
|
| 448 |
+
|
| 449 |
+
if self.encoder_attn is not None and encoder_out is not None:
|
| 450 |
+
residual = x
|
| 451 |
+
if self.normalize_before:
|
| 452 |
+
x = self.encoder_attn_layer_norm(x)
|
| 453 |
+
if prev_attn_state is not None:
|
| 454 |
+
prev_key, prev_value = prev_attn_state[:2]
|
| 455 |
+
saved_state: Dict[str, Optional[Tensor]] = {
|
| 456 |
+
"prev_key": prev_key,
|
| 457 |
+
"prev_value": prev_value,
|
| 458 |
+
}
|
| 459 |
+
if len(prev_attn_state) >= 3:
|
| 460 |
+
saved_state["prev_key_padding_mask"] = prev_attn_state[2]
|
| 461 |
+
assert incremental_state is not None
|
| 462 |
+
self.encoder_attn._set_input_buffer(incremental_state, saved_state)
|
| 463 |
+
|
| 464 |
+
x, attn = self.encoder_attn(
|
| 465 |
+
query=x,
|
| 466 |
+
key=encoder_out,
|
| 467 |
+
value=encoder_out,
|
| 468 |
+
key_padding_mask=encoder_padding_mask,
|
| 469 |
+
incremental_state=incremental_state,
|
| 470 |
+
static_kv=True,
|
| 471 |
+
need_weights=need_attn or (not self.training and self.need_attn),
|
| 472 |
+
need_head_weights=need_head_weights,
|
| 473 |
+
attn_bias=cross_attn_bias
|
| 474 |
+
)
|
| 475 |
+
if self.cross_attn_ln is not None:
|
| 476 |
+
x = self.cross_attn_ln(x)
|
| 477 |
+
x = self.dropout_module(x)
|
| 478 |
+
x = self.residual_connection(x, residual)
|
| 479 |
+
if not self.normalize_before:
|
| 480 |
+
x = self.encoder_attn_layer_norm(x)
|
| 481 |
+
|
| 482 |
+
residual = x
|
| 483 |
+
if self.normalize_before:
|
| 484 |
+
x = self.final_layer_norm(x)
|
| 485 |
+
|
| 486 |
+
x = self.activation_fn(self.fc1(x))
|
| 487 |
+
x = self.activation_dropout_module(x)
|
| 488 |
+
if self.ffn_layernorm is not None:
|
| 489 |
+
x = self.ffn_layernorm(x)
|
| 490 |
+
x = self.fc2(x)
|
| 491 |
+
x = self.dropout_module(x)
|
| 492 |
+
if self.w_resid is not None:
|
| 493 |
+
residual = torch.mul(self.w_resid, residual)
|
| 494 |
+
x = self.residual_connection(x, residual)
|
| 495 |
+
if not self.normalize_before:
|
| 496 |
+
x = self.final_layer_norm(x)
|
| 497 |
+
if self.onnx_trace and incremental_state is not None:
|
| 498 |
+
saved_state = self.self_attn._get_input_buffer(incremental_state)
|
| 499 |
+
assert saved_state is not None
|
| 500 |
+
if self_attn_padding_mask is not None:
|
| 501 |
+
self_attn_state = [
|
| 502 |
+
saved_state["prev_key"],
|
| 503 |
+
saved_state["prev_value"],
|
| 504 |
+
saved_state["prev_key_padding_mask"],
|
| 505 |
+
]
|
| 506 |
+
else:
|
| 507 |
+
self_attn_state = [saved_state["prev_key"], saved_state["prev_value"]]
|
| 508 |
+
return x, attn, self_attn_state
|
| 509 |
+
return x, attn, None
|
| 510 |
+
|
| 511 |
+
def make_generation_fast_(self, need_attn: bool = False, **kwargs):
|
| 512 |
+
self.need_attn = need_attn
|
| 513 |
+
|
| 514 |
+
def upgrade_state_dict_named(self, state_dict, name):
|
| 515 |
+
"""
|
| 516 |
+
Rename layer norm states from `...layer_norms.0.weight` to
|
| 517 |
+
`...self_attn_layer_norm.weight` and `...layer_norms.1.weight` to
|
| 518 |
+
`...final_layer_norm.weight`
|
| 519 |
+
"""
|
| 520 |
+
# update layer norms
|
| 521 |
+
layer_norm_map = {
|
| 522 |
+
"0": "self_attn_layer_norm",
|
| 523 |
+
"1": "encoder_attn_layer_norm",
|
| 524 |
+
"2": "final_layer_norm",
|
| 525 |
+
}
|
| 526 |
+
for old, new in layer_norm_map.items():
|
| 527 |
+
for m in ("weight", "bias"):
|
| 528 |
+
k = "{}.layer_norms.{}.{}".format(name, old, m)
|
| 529 |
+
if k in state_dict:
|
| 530 |
+
state_dict[
|
| 531 |
+
"{}.{}.{}".format(name, new, m)
|
| 532 |
+
] = state_dict[k]
|
| 533 |
+
del state_dict[k]
|
| 534 |
+
if "{}.{}.{}".format(name, new, m) not in state_dict and "{}.{}".format(new, m) in self.state_dict():
|
| 535 |
+
state_dict[
|
| 536 |
+
"{}.{}.{}".format(name, new, m)
|
| 537 |
+
] = self.state_dict()["{}.{}".format(new, m)]
|
| 538 |
+
|
| 539 |
+
prefix = name + "." if name != "" else ""
|
| 540 |
+
for param_name, param_tensor in self.state_dict().items():
|
| 541 |
+
if (prefix + param_name) not in state_dict and param_name in self.state_dict():
|
| 542 |
+
state_dict[prefix + param_name] = self.state_dict()[param_name]
|
models/search.py
ADDED
|
@@ -0,0 +1,814 @@
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import math
|
| 7 |
+
from typing import List, Optional
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
from fairseq.token_generation_constraints import (
|
| 12 |
+
ConstraintState,
|
| 13 |
+
OrderedConstraintState,
|
| 14 |
+
UnorderedConstraintState,
|
| 15 |
+
)
|
| 16 |
+
from torch import Tensor
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class Search(nn.Module):
|
| 20 |
+
def __init__(self, tgt_dict):
|
| 21 |
+
super().__init__()
|
| 22 |
+
self.pad = tgt_dict.pad()
|
| 23 |
+
self.unk = tgt_dict.unk()
|
| 24 |
+
self.eos = tgt_dict.eos()
|
| 25 |
+
self.vocab_size = len(tgt_dict)
|
| 26 |
+
self.src_lengths = torch.tensor(-1)
|
| 27 |
+
self.supports_constraints = False
|
| 28 |
+
self.stop_on_max_len = False
|
| 29 |
+
|
| 30 |
+
def step(
|
| 31 |
+
self, step, lprobs, scores, prev_output_tokens=None, original_batch_idxs=None
|
| 32 |
+
):
|
| 33 |
+
"""Take a single search step.
|
| 34 |
+
|
| 35 |
+
Args:
|
| 36 |
+
step: the current search step, starting at 0
|
| 37 |
+
lprobs: (bsz x input_beam_size x vocab_size)
|
| 38 |
+
the model's log-probabilities over the vocabulary at the current step
|
| 39 |
+
scores: (bsz x input_beam_size x step)
|
| 40 |
+
the historical model scores of each hypothesis up to this point
|
| 41 |
+
prev_output_tokens: (bsz x step)
|
| 42 |
+
the previously generated output tokens
|
| 43 |
+
original_batch_idxs: (bsz)
|
| 44 |
+
the tensor with the batch indices, in the range [0, bsz)
|
| 45 |
+
this is useful in case there has been applied a re-ordering
|
| 46 |
+
and we need to know the original indices
|
| 47 |
+
|
| 48 |
+
Return: A tuple of (scores, indices, beams) where:
|
| 49 |
+
scores: (bsz x output_beam_size)
|
| 50 |
+
the scores of the chosen elements; output_beam_size can be
|
| 51 |
+
larger than input_beam_size, e.g., we may return
|
| 52 |
+
2*input_beam_size to account for EOS
|
| 53 |
+
indices: (bsz x output_beam_size)
|
| 54 |
+
the indices of the chosen elements
|
| 55 |
+
beams: (bsz x output_beam_size)
|
| 56 |
+
the hypothesis ids of the chosen elements, in the range [0, input_beam_size)
|
| 57 |
+
"""
|
| 58 |
+
raise NotImplementedError
|
| 59 |
+
|
| 60 |
+
@torch.jit.export
|
| 61 |
+
def set_src_lengths(self, src_lengths):
|
| 62 |
+
self.src_lengths = src_lengths
|
| 63 |
+
|
| 64 |
+
@torch.jit.export
|
| 65 |
+
def init_constraints(self, batch_constraints: Optional[Tensor], beam_size: int):
|
| 66 |
+
"""Initialize constraint states for constrained decoding (if supported).
|
| 67 |
+
|
| 68 |
+
Args:
|
| 69 |
+
batch_constraints: (torch.Tensor, optional)
|
| 70 |
+
the list of constraints, in packed form
|
| 71 |
+
beam_size: (int)
|
| 72 |
+
the beam size
|
| 73 |
+
Returns:
|
| 74 |
+
nothing; the initialized constraint states are stored on the search object
|
| 75 |
+
"""
|
| 76 |
+
pass
|
| 77 |
+
|
| 78 |
+
def prune_sentences(self, batch_idxs: Tensor):
|
| 79 |
+
"""
|
| 80 |
+
Removes constraint states for completed sentences (if supported).
|
| 81 |
+
This is called from sequence_generator._generate() when sentences are
|
| 82 |
+
deleted from the batch.
|
| 83 |
+
|
| 84 |
+
Args:
|
| 85 |
+
batch_idxs: Indices of *sentences* whose constraint state should be *kept*.
|
| 86 |
+
"""
|
| 87 |
+
pass
|
| 88 |
+
|
| 89 |
+
def update_constraints(self, active_hypos: Tensor):
|
| 90 |
+
"""
|
| 91 |
+
Updates the constraint states by selecting the beam items that are retained.
|
| 92 |
+
This is called at each time step of sequence_generator._generate() when
|
| 93 |
+
the set of 2 * {beam_size} candidate hypotheses are reduced to the beam size.
|
| 94 |
+
|
| 95 |
+
Args:
|
| 96 |
+
active_hypos: (batch size, beam size)
|
| 97 |
+
list of integers denoting, for each sentence, which beam candidate items
|
| 98 |
+
should be kept.
|
| 99 |
+
"""
|
| 100 |
+
pass
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
class BeamSearch(Search):
|
| 104 |
+
def __init__(self, tgt_dict):
|
| 105 |
+
super().__init__(tgt_dict)
|
| 106 |
+
self.constraint_states = None
|
| 107 |
+
|
| 108 |
+
@torch.jit.export
|
| 109 |
+
def step(
|
| 110 |
+
self,
|
| 111 |
+
step: int,
|
| 112 |
+
lprobs,
|
| 113 |
+
scores: Optional[Tensor],
|
| 114 |
+
prev_output_tokens: Optional[Tensor] = None,
|
| 115 |
+
original_batch_idxs: Optional[Tensor] = None,
|
| 116 |
+
):
|
| 117 |
+
bsz, beam_size, vocab_size = lprobs.size()
|
| 118 |
+
|
| 119 |
+
if step == 0:
|
| 120 |
+
# at the first step all hypotheses are equally likely, so use
|
| 121 |
+
# only the first beam
|
| 122 |
+
lprobs = lprobs[:, ::beam_size, :].contiguous()
|
| 123 |
+
else:
|
| 124 |
+
# make probs contain cumulative scores for each hypothesis
|
| 125 |
+
assert scores is not None
|
| 126 |
+
lprobs = lprobs + scores[:, :, step - 1].unsqueeze(-1)
|
| 127 |
+
|
| 128 |
+
top_prediction = torch.topk(
|
| 129 |
+
lprobs.view(bsz, -1),
|
| 130 |
+
k=min(
|
| 131 |
+
# Take the best 2 x beam_size predictions. We'll choose the first
|
| 132 |
+
# beam_size of these which don't predict eos to continue with.
|
| 133 |
+
beam_size * 2,
|
| 134 |
+
lprobs.view(bsz, -1).size(1) - 1, # -1 so we never select pad
|
| 135 |
+
),
|
| 136 |
+
)
|
| 137 |
+
scores_buf = top_prediction[0]
|
| 138 |
+
indices_buf = top_prediction[1]
|
| 139 |
+
# Project back into relative indices and beams
|
| 140 |
+
beams_buf = indices_buf // vocab_size
|
| 141 |
+
indices_buf = indices_buf.fmod(vocab_size)
|
| 142 |
+
|
| 143 |
+
# At this point, beams_buf and indices_buf are single-dim and contain relative indices
|
| 144 |
+
return scores_buf, indices_buf, beams_buf
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
class PrefixConstrainedBeamSearch(Search):
|
| 148 |
+
def __init__(self, tgt_dict, prefix_allowed_tokens_fn):
|
| 149 |
+
super().__init__(tgt_dict)
|
| 150 |
+
self.prefix_allowed_tokens_fn = prefix_allowed_tokens_fn
|
| 151 |
+
self.stop_on_max_len = True
|
| 152 |
+
|
| 153 |
+
@torch.jit.export
|
| 154 |
+
def apply_mask(self, x, prev_output_tokens, original_batch_idxs):
|
| 155 |
+
beam_size = x.shape[0] // original_batch_idxs.shape[0]
|
| 156 |
+
original_batch_idxs = (
|
| 157 |
+
original_batch_idxs.unsqueeze(-1).repeat((1, beam_size)).flatten().tolist()
|
| 158 |
+
)
|
| 159 |
+
|
| 160 |
+
mask = torch.full_like(x, -math.inf)
|
| 161 |
+
for sent_i, (sent, batch_i) in enumerate(
|
| 162 |
+
zip(prev_output_tokens, original_batch_idxs)
|
| 163 |
+
):
|
| 164 |
+
mask[sent_i, :, self.prefix_allowed_tokens_fn(batch_i, sent)] = 0
|
| 165 |
+
|
| 166 |
+
return mask
|
| 167 |
+
|
| 168 |
+
@torch.jit.export
|
| 169 |
+
def step(
|
| 170 |
+
self,
|
| 171 |
+
step: int,
|
| 172 |
+
lprobs: Tensor,
|
| 173 |
+
scores: Tensor,
|
| 174 |
+
prev_output_tokens: Tensor,
|
| 175 |
+
original_batch_idxs: Tensor,
|
| 176 |
+
):
|
| 177 |
+
bsz, beam_size, vocab_size = lprobs.size()
|
| 178 |
+
|
| 179 |
+
lprobs += self.apply_mask(
|
| 180 |
+
lprobs.view(bsz * beam_size, 1, vocab_size),
|
| 181 |
+
prev_output_tokens,
|
| 182 |
+
original_batch_idxs,
|
| 183 |
+
).view(bsz, beam_size, vocab_size)
|
| 184 |
+
|
| 185 |
+
if step == 0:
|
| 186 |
+
# at the first step all hypotheses are equally likely, so use
|
| 187 |
+
# only the first beam
|
| 188 |
+
lprobs = lprobs[:, ::beam_size, :].contiguous()
|
| 189 |
+
else:
|
| 190 |
+
# make probs contain cumulative scores for each hypothesis
|
| 191 |
+
assert scores is not None
|
| 192 |
+
lprobs = lprobs + scores[:, :, step - 1].unsqueeze(-1)
|
| 193 |
+
|
| 194 |
+
top_prediction = torch.topk(
|
| 195 |
+
lprobs.view(bsz, -1),
|
| 196 |
+
k=min(
|
| 197 |
+
# Take the best beam_size predictions. We'll choose the first
|
| 198 |
+
# beam_size of these which don't predict eos to continue with.
|
| 199 |
+
beam_size,
|
| 200 |
+
lprobs.view(bsz, -1).size(1) - 1, # -1 so we never select pad
|
| 201 |
+
),
|
| 202 |
+
)
|
| 203 |
+
scores_buf = top_prediction[0]
|
| 204 |
+
indices_buf = top_prediction[1]
|
| 205 |
+
beams_buf = indices_buf // vocab_size
|
| 206 |
+
indices_buf = indices_buf.fmod(vocab_size)
|
| 207 |
+
return scores_buf, indices_buf, beams_buf
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
class LexicallyConstrainedBeamSearch(Search):
|
| 211 |
+
"""Implements lexically constrained beam search as described in
|
| 212 |
+
|
| 213 |
+
Fast Lexically Constrained Decoding with Dynamic Beam
|
| 214 |
+
Allocation for Neural Machine Translation. Post & Vilar,
|
| 215 |
+
NAACL 2018. https://www.aclweb.org/anthology/N18-1119/
|
| 216 |
+
|
| 217 |
+
and
|
| 218 |
+
|
| 219 |
+
Improved Lexically Constrained Decoding for Translation and
|
| 220 |
+
Monolingual Rewriting. Hu et al, NAACL
|
| 221 |
+
2019. https://www.aclweb.org/anthology/N19-1090/
|
| 222 |
+
|
| 223 |
+
This is accomplished by maintaining, for each beam hypothesis, a
|
| 224 |
+
ConstraintState object (see constraints.py) that tracks which
|
| 225 |
+
constraints have been generated and using this information to
|
| 226 |
+
shape the beam for each input sentence.
|
| 227 |
+
"""
|
| 228 |
+
|
| 229 |
+
def __init__(self, tgt_dict, representation):
|
| 230 |
+
super().__init__(tgt_dict)
|
| 231 |
+
self.representation = representation
|
| 232 |
+
self.vocab_size = len(tgt_dict)
|
| 233 |
+
self.num_cands = 0
|
| 234 |
+
self.supports_constraints = True
|
| 235 |
+
|
| 236 |
+
@torch.jit.export
|
| 237 |
+
def init_constraints(self, batch_constraints: Optional[Tensor], beam_size: int):
|
| 238 |
+
self.constraint_states = []
|
| 239 |
+
for constraint_tensor in batch_constraints:
|
| 240 |
+
if self.representation == "ordered":
|
| 241 |
+
constraint_state = OrderedConstraintState.create(constraint_tensor)
|
| 242 |
+
elif self.representation == "unordered":
|
| 243 |
+
constraint_state = UnorderedConstraintState.create(constraint_tensor)
|
| 244 |
+
|
| 245 |
+
self.constraint_states.append([constraint_state for i in range(beam_size)])
|
| 246 |
+
|
| 247 |
+
@torch.jit.export
|
| 248 |
+
def prune_sentences(self, batch_idxs: Tensor):
|
| 249 |
+
self.constraint_states = [
|
| 250 |
+
self.constraint_states[i] for i in batch_idxs.tolist()
|
| 251 |
+
]
|
| 252 |
+
|
| 253 |
+
@torch.jit.export
|
| 254 |
+
def update_constraints(self, active_hypos: Tensor):
|
| 255 |
+
if self.constraint_states:
|
| 256 |
+
batch_size = active_hypos.size(0)
|
| 257 |
+
for sentid in range(batch_size):
|
| 258 |
+
self.constraint_states[sentid] = [
|
| 259 |
+
self.constraint_states[sentid][i] for i in active_hypos[sentid]
|
| 260 |
+
]
|
| 261 |
+
|
| 262 |
+
@torch.jit.export
|
| 263 |
+
def step(
|
| 264 |
+
self,
|
| 265 |
+
step: int,
|
| 266 |
+
lprobs: Tensor,
|
| 267 |
+
scores: Optional[Tensor],
|
| 268 |
+
prev_output_tokens: Optional[Tensor] = None,
|
| 269 |
+
original_batch_idxs: Optional[Tensor] = None,
|
| 270 |
+
):
|
| 271 |
+
"""
|
| 272 |
+
A constrained step builds a large candidates list from the following:
|
| 273 |
+
- the top 2 * {beam_size} items over the whole beam
|
| 274 |
+
- for each item in the beam
|
| 275 |
+
- the top {each_k} (default 1)
|
| 276 |
+
- all next constraints
|
| 277 |
+
We then compute the constrained state of each beam item, and assign
|
| 278 |
+
stripe codes: 0 to the best in each bank, 1 to the 2nd-best, and so
|
| 279 |
+
on. We then sort by (stripe, score), and truncate the list at
|
| 280 |
+
2 * beam size.
|
| 281 |
+
|
| 282 |
+
Args:
|
| 283 |
+
step: the decoder step
|
| 284 |
+
lprobs: (batch size, beam size, target vocab)
|
| 285 |
+
the target-vocab distributions for each item in the beam.
|
| 286 |
+
Return: A tuple of (scores, indices, beams, constraints) where:
|
| 287 |
+
scores: (batch, output beam size)
|
| 288 |
+
the scores of the chosen elements
|
| 289 |
+
indices: (batch, output beam size)
|
| 290 |
+
the target vocab indices of the chosen elements
|
| 291 |
+
beams: (batch, output beam size)
|
| 292 |
+
the 0-indexed hypothesis ids of the chosen elements
|
| 293 |
+
constraints: (batch, output beam size)
|
| 294 |
+
the new constraint states
|
| 295 |
+
"""
|
| 296 |
+
each_k = 1
|
| 297 |
+
device = lprobs.device
|
| 298 |
+
|
| 299 |
+
batch_size, beam_size, vocab_size = lprobs.size()
|
| 300 |
+
|
| 301 |
+
self.num_cands = min(
|
| 302 |
+
# Just take the k-best. We'll get another k from the 1-best from each
|
| 303 |
+
# row, plus more from the constraints
|
| 304 |
+
beam_size * 2,
|
| 305 |
+
lprobs.view(batch_size, -1).size(1) - 1, # -1 so we never select pad
|
| 306 |
+
)
|
| 307 |
+
|
| 308 |
+
# STEP 0: Preliminary. Prevent EOS for unfinished hyps across all batch items
|
| 309 |
+
constraint_states = self.constraint_states
|
| 310 |
+
if constraint_states and step > 0:
|
| 311 |
+
not_finished_indices = []
|
| 312 |
+
for sentno, sent_constraints in enumerate(constraint_states):
|
| 313 |
+
for beamno, state in enumerate(sent_constraints):
|
| 314 |
+
index = sentno * beam_size + beamno
|
| 315 |
+
if not state.finished:
|
| 316 |
+
not_finished_indices.append(index)
|
| 317 |
+
not_finished_indices = torch.tensor(not_finished_indices)
|
| 318 |
+
if not_finished_indices.numel() > 0:
|
| 319 |
+
lprobs.view(batch_size * beam_size, -1)[
|
| 320 |
+
not_finished_indices, self.eos
|
| 321 |
+
] = -math.inf
|
| 322 |
+
|
| 323 |
+
if step == 0:
|
| 324 |
+
# at the first step all hypotheses are equally likely, so use
|
| 325 |
+
# only the first beam entry for each batch item
|
| 326 |
+
lprobs = lprobs[:, ::beam_size, :].contiguous()
|
| 327 |
+
else:
|
| 328 |
+
# make probs contain cumulative scores for each hypothesis
|
| 329 |
+
assert scores is not None
|
| 330 |
+
lprobs = lprobs + scores[:, :, step - 1].unsqueeze(-1)
|
| 331 |
+
|
| 332 |
+
top_prediction = torch.topk(
|
| 333 |
+
lprobs.view(batch_size, -1),
|
| 334 |
+
self.num_cands,
|
| 335 |
+
)
|
| 336 |
+
scores_buf, indices_buf = top_prediction
|
| 337 |
+
# Project back into relative indices and beams
|
| 338 |
+
beams_buf = indices_buf // vocab_size
|
| 339 |
+
indices_buf = indices_buf.fmod(vocab_size)
|
| 340 |
+
|
| 341 |
+
# Short circuit if there are no constraints in this batch
|
| 342 |
+
if not constraint_states:
|
| 343 |
+
return scores_buf, indices_buf, beams_buf
|
| 344 |
+
|
| 345 |
+
# STEP 1: get top-1 from each hypothesis across all sentences in the batch
|
| 346 |
+
if step > 0:
|
| 347 |
+
top_scores, top_indices = torch.topk(
|
| 348 |
+
lprobs.view(batch_size * beam_size, -1),
|
| 349 |
+
k=each_k,
|
| 350 |
+
dim=1,
|
| 351 |
+
)
|
| 352 |
+
top_scores = top_scores.view(batch_size, -1)
|
| 353 |
+
top_indices = top_indices.view(batch_size, -1)
|
| 354 |
+
scores_buf = torch.cat((scores_buf, top_scores), dim=1)
|
| 355 |
+
indices_buf = torch.cat((indices_buf, top_indices), dim=1)
|
| 356 |
+
new_beams = torch.arange(0, beam_size, device=device).repeat(batch_size, 1)
|
| 357 |
+
beams_buf = torch.cat((beams_buf, new_beams), dim=1)
|
| 358 |
+
|
| 359 |
+
# Now, process sentences in the batch one by one.
|
| 360 |
+
new_scores_buf = torch.zeros((batch_size, 2 * beam_size), device=device)
|
| 361 |
+
new_indices_buf = torch.zeros((batch_size, 2 * beam_size), device=device).long()
|
| 362 |
+
new_beams_buf = torch.zeros((batch_size, 2 * beam_size), device=device).long()
|
| 363 |
+
for sentno, states in enumerate(constraint_states):
|
| 364 |
+
scores, indices, beams, new_states = self.step_sentence(
|
| 365 |
+
step,
|
| 366 |
+
sentno,
|
| 367 |
+
lprobs[sentno],
|
| 368 |
+
constraint_states[sentno],
|
| 369 |
+
beams_buf[sentno].clone(),
|
| 370 |
+
indices_buf[sentno].clone(),
|
| 371 |
+
scores_buf[sentno].clone(),
|
| 372 |
+
)
|
| 373 |
+
new_scores_buf[sentno] = scores
|
| 374 |
+
new_indices_buf[sentno] = indices
|
| 375 |
+
new_beams_buf[sentno] = beams
|
| 376 |
+
self.constraint_states[sentno] = new_states
|
| 377 |
+
|
| 378 |
+
return new_scores_buf, new_indices_buf, new_beams_buf
|
| 379 |
+
|
| 380 |
+
@torch.jit.export
|
| 381 |
+
def step_sentence(
|
| 382 |
+
self,
|
| 383 |
+
step: int,
|
| 384 |
+
sentno: int,
|
| 385 |
+
lprobs: Tensor,
|
| 386 |
+
constraint_states: List[List[ConstraintState]],
|
| 387 |
+
beams_buf: Tensor,
|
| 388 |
+
indices_buf: Tensor,
|
| 389 |
+
scores_buf: Tensor,
|
| 390 |
+
):
|
| 391 |
+
"""Does per-sentence processing. Adds all constraints for each
|
| 392 |
+
hypothesis to the list of candidates; then removes duplicates,
|
| 393 |
+
sorts, and dynamically stripes across the banks. All tensor inputs
|
| 394 |
+
are collapsed to those pertaining to a single input sentence.
|
| 395 |
+
"""
|
| 396 |
+
device = lprobs.device
|
| 397 |
+
|
| 398 |
+
# STEP 2: Add all constraints for each beam item
|
| 399 |
+
for beamno, state in enumerate(constraint_states):
|
| 400 |
+
next_tokens = torch.tensor(list(state.next_tokens()), device=device).long()
|
| 401 |
+
if next_tokens.numel() != 0:
|
| 402 |
+
indices_buf = torch.cat((indices_buf, next_tokens))
|
| 403 |
+
next_beams = (
|
| 404 |
+
torch.tensor(beamno, device=device)
|
| 405 |
+
.repeat(next_tokens.size(0))
|
| 406 |
+
.long()
|
| 407 |
+
)
|
| 408 |
+
beams_buf = torch.cat((beams_buf, next_beams))
|
| 409 |
+
next_values = lprobs[beamno].take(next_tokens.view(-1))
|
| 410 |
+
scores_buf = torch.cat((scores_buf, next_values))
|
| 411 |
+
|
| 412 |
+
# At the 0th time step, there is just one beam item
|
| 413 |
+
if step == 0:
|
| 414 |
+
break
|
| 415 |
+
|
| 416 |
+
# STEP 3: Compute the "bank" for each candidate. This is the
|
| 417 |
+
# number of constraints it's generated. We need this so that
|
| 418 |
+
# we can do round-robin allocation of the beam across these
|
| 419 |
+
# banks. If C is the number of constraints, we select the best
|
| 420 |
+
# item in bank C, then the best in bank C-1, etc, followed by
|
| 421 |
+
# the 2nd-best in bank C, the 2nd-best in bank C-1, etc, and so
|
| 422 |
+
# on, until the maximum beam size. We accomplish this by
|
| 423 |
+
# creating a sort key and striping across the banks.
|
| 424 |
+
|
| 425 |
+
# Compute the new states for all candidates
|
| 426 |
+
cands_size = indices_buf.size(0)
|
| 427 |
+
constraint_states = [
|
| 428 |
+
constraint_states[beams_buf[i]].advance(indices_buf[i])
|
| 429 |
+
for i in range(cands_size)
|
| 430 |
+
]
|
| 431 |
+
|
| 432 |
+
banks = torch.tensor([state.bank for state in constraint_states], device=device)
|
| 433 |
+
|
| 434 |
+
# STEP 4: Sort
|
| 435 |
+
num_constraint_tokens = len(state.tokens)
|
| 436 |
+
|
| 437 |
+
# Sort by keys (bank, score) (i.e., sort banks together, and scores
|
| 438 |
+
# within banks). AFAIK pytorch doesn't support either stable sort or
|
| 439 |
+
# multi-key sorting, so we have to hack this.
|
| 440 |
+
MAX_SCORE = -100
|
| 441 |
+
sort_key = (num_constraint_tokens - banks) * MAX_SCORE + scores_buf
|
| 442 |
+
sort_values, sort_indices = sort_key.sort(dim=0, descending=True)
|
| 443 |
+
scores_buf = scores_buf[sort_indices]
|
| 444 |
+
indices_buf = indices_buf[sort_indices]
|
| 445 |
+
beams_buf = beams_buf[sort_indices]
|
| 446 |
+
banks = banks[sort_indices]
|
| 447 |
+
|
| 448 |
+
# Sort the constraints to follow suit
|
| 449 |
+
constraint_states = [constraint_states[i] for i in sort_indices]
|
| 450 |
+
|
| 451 |
+
# STEP 5: Remove duplicates. The topk calls (overall and
|
| 452 |
+
# per-row) plus the per-row generation of constraints will
|
| 453 |
+
# produce duplicates. Here we remove them.
|
| 454 |
+
|
| 455 |
+
def roll(t):
|
| 456 |
+
"""Rolls a 1d tensor left by 1.
|
| 457 |
+
|
| 458 |
+
[0, 1, 2, 3, 4] becomes [4, 0, 1, 2, 3]
|
| 459 |
+
"""
|
| 460 |
+
return torch.cat((t[-1].unsqueeze(0), t[0:-1]), dim=0)
|
| 461 |
+
|
| 462 |
+
# We map candidates (beam, token_id) to a single dimension.
|
| 463 |
+
# This is then shifted by 1. We can then easily identify
|
| 464 |
+
# duplicates and create a mask that identifies unique
|
| 465 |
+
# extensions.
|
| 466 |
+
uniques_mask = beams_buf * (self.vocab_size + 1) + indices_buf
|
| 467 |
+
uniques_mask = roll(uniques_mask) != uniques_mask
|
| 468 |
+
|
| 469 |
+
# Use the mask to pare down the data structures
|
| 470 |
+
scores_buf = torch.masked_select(scores_buf, uniques_mask)
|
| 471 |
+
indices_buf = torch.masked_select(indices_buf, uniques_mask)
|
| 472 |
+
beams_buf = torch.masked_select(beams_buf, uniques_mask)
|
| 473 |
+
banks = torch.masked_select(banks, uniques_mask)
|
| 474 |
+
i = 1
|
| 475 |
+
for mask in uniques_mask[1:]:
|
| 476 |
+
if not mask:
|
| 477 |
+
constraint_states.pop(i)
|
| 478 |
+
i += mask
|
| 479 |
+
|
| 480 |
+
# STEP 6: Assign IDs round-robin across banks, sort, and
|
| 481 |
+
# truncate. Now that the candidates are sorted by (bank,
|
| 482 |
+
# score) and uniqed, we dynamically allocate the {beam_size}
|
| 483 |
+
# beam by striping across the candidates. These stripes will
|
| 484 |
+
# be used as sort keys to do round-robin selection. This is
|
| 485 |
+
# accomplished in a single pass with offsets. Sorting by
|
| 486 |
+
# highest-banks (furthest-along hypotheses) first ensures
|
| 487 |
+
# progress through the constraints.
|
| 488 |
+
#
|
| 489 |
+
# e.g., BANKS: 3 3 3 2 2 2 2 1 1 1 0 0
|
| 490 |
+
# OLD STRIPES: 0 1 2 0 1 2 3 0 1 2 0 1
|
| 491 |
+
# NEW STRIPES: 0 1+4 2+8 0+1 1+5 2+9 3+11 0+2 1+6 2+10 0+3 1+7
|
| 492 |
+
# = 0 5 10 1 6 11 13 2 7 12 3 8
|
| 493 |
+
#
|
| 494 |
+
# Sorting by this then gives the following banks:
|
| 495 |
+
#
|
| 496 |
+
# 3 2 1 0 3 2 1 0 3 2 1 2
|
| 497 |
+
#
|
| 498 |
+
# We'll take the top {beam_size} of these.
|
| 499 |
+
stripe_offsets = [offset * (len(banks) + 1) for offset in range(len(banks) + 1)]
|
| 500 |
+
stripes = torch.zeros_like(banks)
|
| 501 |
+
cur_bank_count = -1
|
| 502 |
+
cur_bank = banks[0]
|
| 503 |
+
for i, bank in enumerate(banks):
|
| 504 |
+
if bank != cur_bank:
|
| 505 |
+
cur_bank_count = 0
|
| 506 |
+
cur_bank = bank
|
| 507 |
+
else:
|
| 508 |
+
cur_bank_count += 1
|
| 509 |
+
stripes[i] = num_constraint_tokens - bank + stripe_offsets[cur_bank_count]
|
| 510 |
+
|
| 511 |
+
# STEP 7: Sort by the stripes values
|
| 512 |
+
sort_values, sort_indices = stripes.sort(dim=0)
|
| 513 |
+
scores_buf = scores_buf[sort_indices]
|
| 514 |
+
indices_buf = indices_buf[sort_indices]
|
| 515 |
+
beams_buf = beams_buf[sort_indices]
|
| 516 |
+
constraint_states = [constraint_states[i] for i in sort_indices]
|
| 517 |
+
|
| 518 |
+
# STEP 8: Truncate to the candidates size!
|
| 519 |
+
scores_buf = scores_buf[: self.num_cands]
|
| 520 |
+
indices_buf = indices_buf[: self.num_cands]
|
| 521 |
+
beams_buf = beams_buf[: self.num_cands]
|
| 522 |
+
|
| 523 |
+
return scores_buf, indices_buf, beams_buf, constraint_states
|
| 524 |
+
|
| 525 |
+
|
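A small standalone illustration (values made up, not part of the committed file) of the roll() trick used in STEP 5 above, which keeps only the first element of each run of equal (beam, token) keys:

import torch

keys = torch.tensor([3, 3, 7, 7, 7, 9])            # duplicates are adjacent
rolled = torch.cat((keys[-1:], keys[:-1]), dim=0)  # shift right by one position
uniques_mask = rolled != keys                      # True at each run's first element
print(uniques_mask)  # tensor([ True, False,  True, False, False,  True])

An element equal to its left neighbour is treated as a duplicate; comparing position 0 against the last element is the corner case the real code inherits.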
| 526 |
+
class LengthConstrainedBeamSearch(Search):
|
| 527 |
+
def __init__(self, tgt_dict, min_len_a, min_len_b, max_len_a, max_len_b):
|
| 528 |
+
super().__init__(tgt_dict)
|
| 529 |
+
self.min_len_a = min_len_a
|
| 530 |
+
self.min_len_b = min_len_b
|
| 531 |
+
self.max_len_a = max_len_a
|
| 532 |
+
self.max_len_b = max_len_b
|
| 533 |
+
self.beam = BeamSearch(tgt_dict)
|
| 534 |
+
self.needs_src_lengths = True
|
| 535 |
+
|
| 536 |
+
def step(
|
| 537 |
+
self,
|
| 538 |
+
step: int,
|
| 539 |
+
lprobs,
|
| 540 |
+
scores,
|
| 541 |
+
prev_output_tokens: Optional[Tensor] = None,
|
| 542 |
+
original_batch_idxs: Optional[Tensor] = None,
|
| 543 |
+
):
|
| 544 |
+
min_lens = self.min_len_a * self.src_lengths + self.min_len_b
|
| 545 |
+
max_lens = self.max_len_a * self.src_lengths + self.max_len_b
|
| 546 |
+
lprobs[step < min_lens, :, self.eos] = -math.inf
|
| 547 |
+
lprobs[step >= max_lens, :, self.eos] = 0
|
| 548 |
+
return self.beam.step(step, lprobs, scores)
|
| 549 |
+
|
| 550 |
+
|
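The per-sentence length bounds used by LengthConstrainedBeamSearch above, shown on toy numbers (the coefficients here are assumptions, not values from this repo):

import torch

src_lengths = torch.tensor([4, 10])
min_len_a, min_len_b = 0.5, 1
max_len_a, max_len_b = 1.5, 2
min_lens = min_len_a * src_lengths + min_len_b   # tensor([3., 6.])
max_lens = max_len_a * src_lengths + max_len_b   # tensor([8., 17.])

step, eos = 2, 2
lprobs = torch.zeros(2, 1, 5)
lprobs[step < min_lens, :, eos] = -float("inf")  # EOS banned while the hypothesis is too short
lprobs[step >= max_lens, :, eos] = 0             # log-prob 0 makes EOS the best choice once too long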
| 551 |
+
class DiverseBeamSearch(Search):
|
| 552 |
+
"""Diverse Beam Search.
|
| 553 |
+
|
| 554 |
+
See "Diverse Beam Search: Decoding Diverse Solutions from Neural Sequence
|
| 555 |
+
Models" for details.
|
| 556 |
+
|
| 557 |
+
We only implement the Hamming Diversity penalty here, which performed best
|
| 558 |
+
in the original paper.
|
| 559 |
+
"""
|
| 560 |
+
|
| 561 |
+
def __init__(self, tgt_dict, num_groups, diversity_strength):
|
| 562 |
+
super().__init__(tgt_dict)
|
| 563 |
+
self.num_groups = num_groups
|
| 564 |
+
self.diversity_strength = -diversity_strength
|
| 565 |
+
self.beam = BeamSearch(tgt_dict)
|
| 566 |
+
|
| 567 |
+
@torch.jit.export
|
| 568 |
+
def step(
|
| 569 |
+
self,
|
| 570 |
+
step: int,
|
| 571 |
+
lprobs,
|
| 572 |
+
scores,
|
| 573 |
+
prev_output_tokens: Optional[Tensor] = None,
|
| 574 |
+
original_batch_idxs: Optional[Tensor] = None,
|
| 575 |
+
):
|
| 576 |
+
bsz, beam_size, vocab_size = lprobs.size()
|
| 577 |
+
if beam_size % self.num_groups != 0:
|
| 578 |
+
raise ValueError(
|
| 579 |
+
"DiverseBeamSearch requires --beam to be divisible by the number of groups"
|
| 580 |
+
)
|
| 581 |
+
|
| 582 |
+
# initialize diversity penalty
|
| 583 |
+
diversity_buf = torch.zeros(lprobs[:, 0, :].size()).to(lprobs)
|
| 584 |
+
|
| 585 |
+
scores_G, indices_G, beams_G = [], [], []
|
| 586 |
+
for g in range(self.num_groups):
|
| 587 |
+
lprobs_g = lprobs[:, g :: self.num_groups, :]
|
| 588 |
+
scores_g = scores[:, g :: self.num_groups, :] if step > 0 else None
|
| 589 |
+
|
| 590 |
+
# apply diversity penalty
|
| 591 |
+
if g > 0:
|
| 592 |
+
lprobs_g = torch.add(
|
| 593 |
+
lprobs_g,
|
| 594 |
+
other=diversity_buf.unsqueeze(1),
|
| 595 |
+
alpha=self.diversity_strength,
|
| 596 |
+
)
|
| 597 |
+
else:
|
| 598 |
+
lprobs_g = lprobs_g.contiguous()
|
| 599 |
+
|
| 600 |
+
scores_buf, indices_buf, beams_buf = self.beam.step(
|
| 601 |
+
step, lprobs_g, scores_g
|
| 602 |
+
)
|
| 603 |
+
beams_buf.mul_(self.num_groups).add_(g)
|
| 604 |
+
|
| 605 |
+
scores_G.append(scores_buf.clone())
|
| 606 |
+
indices_G.append(indices_buf.clone())
|
| 607 |
+
beams_G.append(beams_buf.clone())
|
| 608 |
+
|
| 609 |
+
# update diversity penalty
|
| 610 |
+
diversity_buf.scatter_add_(
|
| 611 |
+
1, indices_buf, torch.ones(indices_buf.size()).to(diversity_buf)
|
| 612 |
+
)
|
| 613 |
+
|
| 614 |
+
# interleave results from different groups
|
| 615 |
+
scores_buf = torch.stack(scores_G, dim=2).view(bsz, -1)
|
| 616 |
+
indices_buf = torch.stack(indices_G, dim=2).view(bsz, -1)
|
| 617 |
+
beams_buf = torch.stack(beams_G, dim=2).view(bsz, -1)
|
| 618 |
+
return scores_buf, indices_buf, beams_buf
|
| 619 |
+
|
| 620 |
+
|
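A hedged sketch of the Hamming-diversity bookkeeping DiverseBeamSearch performs above: token counts from earlier groups are accumulated with scatter_add_ and subtracted (scaled by the diversity strength) from the next group's log-probs. Sizes and values below are invented for illustration.

import torch

vocab_size, diversity_strength = 6, 0.5
diversity_buf = torch.zeros(1, vocab_size)
picked = torch.tensor([[2, 4]])  # tokens chosen by group 0 for one sentence
diversity_buf.scatter_add_(1, picked, torch.ones_like(picked, dtype=torch.float))

lprobs_next_group = torch.zeros(1, vocab_size)
penalized = lprobs_next_group - diversity_strength * diversity_buf
# tokens 2 and 4 now score -0.5, steering group 1 toward different tokens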
| 621 |
+
class Sampling(Search):
|
| 622 |
+
sampling_topk: int
|
| 623 |
+
sampling_topp: float
|
| 624 |
+
|
| 625 |
+
def __init__(self, tgt_dict, sampling_topk=-1, sampling_topp=-1.0):
|
| 626 |
+
super().__init__(tgt_dict)
|
| 627 |
+
self.sampling_topk = sampling_topk
|
| 628 |
+
self.sampling_topp = sampling_topp
|
| 629 |
+
|
| 630 |
+
def _sample_topp(self, lprobs):
|
| 631 |
+
"""Sample among the smallest set of elements whose cumulative probability mass exceeds p.
|
| 632 |
+
|
| 633 |
+
See `"The Curious Case of Neural Text Degeneration"
|
| 634 |
+
(Holtzman et al., 2019) <https://arxiv.org/abs/1904.09751>`_.
|
| 635 |
+
|
| 636 |
+
Args:
|
| 637 |
+
lprobs: (bsz x input_beam_size x vocab_size)
|
| 638 |
+
the model's log-probabilities over the vocabulary at the current step
|
| 639 |
+
|
| 640 |
+
Return: A tuple of (trimed_probs, truncated_indices) where:
|
| 641 |
+
trimed_probs: (bsz x input_beam_size x ?)
|
| 642 |
+
the model's probabilities over the elements selected to sample from. The
|
| 643 |
+
width of the third dimension is determined by top-P.
|
| 644 |
+
truncated_indices: (bsz x input_beam_size x ?)
|
| 645 |
+
the indices of the chosen elements.
|
| 646 |
+
"""
|
| 647 |
+
probs = lprobs.exp_()
|
| 648 |
+
|
| 649 |
+
# sort the last dimension (vocab dimension) in descending order
|
| 650 |
+
sorted_probs, sorted_indices = probs.sort(descending=True)
|
| 651 |
+
|
| 652 |
+
# compute a mask to indicate the words to be included in the top-P set.
|
| 653 |
+
cumsum_probs = sorted_probs.cumsum(dim=2)
|
| 654 |
+
mask = cumsum_probs.lt(self.sampling_topp)
|
| 655 |
+
|
| 656 |
+
# note that mask was computed by 'lt'. One more word needs to be included
|
| 657 |
+
# so that the cumulative probability mass can exceed p.
|
| 658 |
+
cumsum_mask = mask.cumsum(dim=2)
|
| 659 |
+
last_included = cumsum_mask[:, :, -1:]
|
| 660 |
+
last_included.clamp_(0, mask.size()[2] - 1)
|
| 661 |
+
mask = mask.scatter_(2, last_included, 1)
|
| 662 |
+
|
| 663 |
+
# truncate unnecessary dims.
|
| 664 |
+
max_dim = last_included.max()
|
| 665 |
+
truncated_mask = mask[:, :, : max_dim + 1]
|
| 666 |
+
truncated_probs = sorted_probs[:, :, : max_dim + 1]
|
| 667 |
+
truncated_indices = sorted_indices[:, :, : max_dim + 1]
|
| 668 |
+
|
| 669 |
+
# trim the words that are not in top-P by setting their probabilities
|
| 670 |
+
# to 0, so that they would not be sampled later.
|
| 671 |
+
trim_mask = ~truncated_mask
|
| 672 |
+
trimed_probs = truncated_probs.masked_fill_(trim_mask, 0)
|
| 673 |
+
return trimed_probs, truncated_indices
|
| 674 |
+
|
| 675 |
+
@torch.jit.export
|
| 676 |
+
def step(
|
| 677 |
+
self,
|
| 678 |
+
step: int,
|
| 679 |
+
lprobs,
|
| 680 |
+
scores,
|
| 681 |
+
prev_output_tokens: Optional[Tensor] = None,
|
| 682 |
+
original_batch_idxs: Optional[Tensor] = None,
|
| 683 |
+
):
|
| 684 |
+
bsz, beam_size, vocab_size = lprobs.size()
|
| 685 |
+
|
| 686 |
+
if step == 0:
|
| 687 |
+
# at the first step all hypotheses are equally likely, so use
|
| 688 |
+
# only the first beam
|
| 689 |
+
lprobs = lprobs[:, ::beam_size, :].contiguous()
|
| 690 |
+
|
| 691 |
+
if self.sampling_topp > 0:
|
| 692 |
+
# only sample from the smallest set of words whose cumulative probability mass exceeds p
|
| 693 |
+
probs, top_indices = self._sample_topp(lprobs)
|
| 694 |
+
elif self.sampling_topk > 0:
|
| 695 |
+
# only sample from top-k candidates
|
| 696 |
+
lprobs, top_indices = lprobs.topk(self.sampling_topk)
|
| 697 |
+
probs = lprobs.exp_()
|
| 698 |
+
else:
|
| 699 |
+
probs = lprobs.exp_()
|
| 700 |
+
|
| 701 |
+
# dummy data to be consistent with true branch for type check
|
| 702 |
+
top_indices = torch.empty(0).to(probs)
|
| 703 |
+
# sample
|
| 704 |
+
if step == 0:
|
| 705 |
+
indices_buf = torch.multinomial(
|
| 706 |
+
probs.view(bsz, -1),
|
| 707 |
+
beam_size,
|
| 708 |
+
replacement=True,
|
| 709 |
+
).view(bsz, beam_size)
|
| 710 |
+
else:
|
| 711 |
+
indices_buf = torch.multinomial(
|
| 712 |
+
probs.view(bsz * beam_size, -1),
|
| 713 |
+
1,
|
| 714 |
+
replacement=True,
|
| 715 |
+
).view(bsz, beam_size)
|
| 716 |
+
|
| 717 |
+
if step == 0:
|
| 718 |
+
# expand to beam size
|
| 719 |
+
probs = probs.expand(bsz, beam_size, -1)
|
| 720 |
+
|
| 721 |
+
# gather scores
|
| 722 |
+
scores_buf = torch.gather(probs, dim=2, index=indices_buf.unsqueeze(-1))
|
| 723 |
+
scores_buf = scores_buf.log_().view(bsz, -1)
|
| 724 |
+
|
| 725 |
+
# remap indices if using top-k or top-P sampling
|
| 726 |
+
if self.sampling_topk > 0 or self.sampling_topp > 0:
|
| 727 |
+
indices_buf = torch.gather(
|
| 728 |
+
top_indices.expand(bsz, beam_size, -1),
|
| 729 |
+
dim=2,
|
| 730 |
+
index=indices_buf.unsqueeze(-1),
|
| 731 |
+
).squeeze(2)
|
| 732 |
+
|
| 733 |
+
if step == 0:
|
| 734 |
+
beams_buf = indices_buf.new_zeros(bsz, beam_size)
|
| 735 |
+
else:
|
| 736 |
+
beams_buf = torch.arange(0, beam_size).to(indices_buf).repeat(bsz, 1)
|
| 737 |
+
# make scores cumulative
|
| 738 |
+
scores_buf.add_(
|
| 739 |
+
torch.gather(scores[:, :, step - 1], dim=1, index=beams_buf)
|
| 740 |
+
)
|
| 741 |
+
|
| 742 |
+
return scores_buf, indices_buf, beams_buf
|
| 743 |
+
|
| 744 |
+
|
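The top-p truncation performed by Sampling._sample_topp above, reduced to a single made-up distribution (the committed code operates on 3-D tensors, but the arithmetic is the same):

import torch

p = 0.8
probs = torch.tensor([0.5, 0.2, 0.15, 0.1, 0.05])
sorted_probs, sorted_indices = probs.sort(descending=True)
cumsum = sorted_probs.cumsum(dim=0)            # 0.50, 0.70, 0.85, 0.95, 1.00
mask = cumsum.lt(p)                            # True, True, False, False, False
last_included = min(int(mask.sum()), len(probs) - 1)
mask[last_included] = True                     # include one more word so the kept mass exceeds p
keep = sorted_indices[mask]                    # tensor([0, 1, 2])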
| 745 |
+
class DiverseSiblingsSearch(Search):
|
| 746 |
+
"""
|
| 747 |
+
Beam search with diverse siblings.
|
| 748 |
+
|
| 749 |
+
See "A Simple, Fast Diverse Decoding Algorithm for Neural Generation" for details.
|
| 750 |
+
https://arxiv.org/abs/1611.08562
|
| 751 |
+
|
| 752 |
+
1/ Calculate hypotheses for each beam
|
| 753 |
+
2/ Intra-sibling ordering
|
| 754 |
+
3/ Rewrite scores
|
| 755 |
+
4/ Choose top K hypotheses
|
| 756 |
+
|
| 757 |
+
Setting diversity_rate == 0 makes this equivalent to BeamSearch.
|
| 758 |
+
"""
|
| 759 |
+
|
| 760 |
+
def __init__(self, tgt_dict, diversity_rate):
|
| 761 |
+
super().__init__(tgt_dict)
|
| 762 |
+
self.diversity_rate = diversity_rate
|
| 763 |
+
self.beam = BeamSearch(tgt_dict)
|
| 764 |
+
|
| 765 |
+
def step(
|
| 766 |
+
self,
|
| 767 |
+
step: int,
|
| 768 |
+
lprobs,
|
| 769 |
+
scores,
|
| 770 |
+
prev_output_tokens: Optional[Tensor] = None,
|
| 771 |
+
original_batch_idxs: Optional[Tensor] = None,
|
| 772 |
+
):
|
| 773 |
+
bsz, beam_size, vocab_size = lprobs.size()
|
| 774 |
+
k = min(
|
| 775 |
+
# Take the best 2 x beam_size predictions. We'll choose the first
|
| 776 |
+
# beam_size of these which don't predict eos to continue with.
|
| 777 |
+
beam_size * 2,
|
| 778 |
+
lprobs.view(bsz, -1).size(1) - 1, # -1 so we never select pad
|
| 779 |
+
)
|
| 780 |
+
s_list: List[Tensor]
|
| 781 |
+
i_list: List[Tensor]
|
| 782 |
+
s_list = [torch.empty(0).to(lprobs) for i in range(beam_size)]
|
| 783 |
+
i_list = [torch.LongTensor().to(device=lprobs.device) for i in range(beam_size)]
|
| 784 |
+
sibling_score = torch.arange(1, k + 1).to(lprobs) * self.diversity_rate
|
| 785 |
+
|
| 786 |
+
if step == 0:
|
| 787 |
+
return self.beam.step(step, lprobs, scores)
|
| 788 |
+
lprobs.add_(scores[:, :, step - 1].unsqueeze(-1))
|
| 789 |
+
|
| 790 |
+
# 1/ Calculate hypotheses for each beam
|
| 791 |
+
for i in range(beam_size):
|
| 792 |
+
torch.topk(lprobs[:, i, :].view(bsz, -1), k, out=(s_list[i], i_list[i]))
|
| 793 |
+
i_list[i].fmod_(vocab_size)
|
| 794 |
+
|
| 795 |
+
# 2/ Intra-sibling ordering by default from topk + 3/ Rewrite scores
|
| 796 |
+
s_list[i].sub_(sibling_score)
|
| 797 |
+
|
| 798 |
+
# 4/ Choose top K hypotheses
|
| 799 |
+
indices = torch.stack(i_list, dim=1).view(bsz, -1)
|
| 800 |
+
|
| 801 |
+
final_scores = torch.empty(0).to(lprobs)
|
| 802 |
+
final_indices = torch.LongTensor().to(device=lprobs.device)
|
| 803 |
+
final_beams = torch.LongTensor().to(device=lprobs.device)
|
| 804 |
+
(final_scores, final_indices) = torch.topk(
|
| 805 |
+
torch.stack(s_list, dim=1).view(bsz, -1),
|
| 806 |
+
k,
|
| 807 |
+
)
|
| 808 |
+
|
| 809 |
+
final_beams = final_indices // k
|
| 810 |
+
|
| 811 |
+
for i in range(bsz):
|
| 812 |
+
final_indices[i] = indices[i][final_indices[i]]
|
| 813 |
+
|
| 814 |
+
return final_scores, final_indices, final_beams
|
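A small standalone sketch (made-up numbers) of the sibling-rank penalty DiverseSiblingsSearch subtracts before its final top-k; the penalty grows with sibling rank, so lower-ranked siblings of one beam fall below candidates from other beams:

import torch

diversity_rate, k = 0.3, 4
scores_one_beam = torch.tensor([-0.1, -0.4, -0.9, -1.5])         # already sorted by topk
sibling_penalty = torch.arange(1, k + 1).float() * diversity_rate
penalized = scores_one_beam - sibling_penalty
# ~ [-0.4, -1.0, -1.8, -2.7]: rank 1 loses 0.3, rank 2 loses 0.6, and so on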
models/sequence_generator.py
ADDED
|
@@ -0,0 +1,1053 @@
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import math
|
| 7 |
+
from typing import Dict, List, Optional
|
| 8 |
+
import sys
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
from fairseq import search, utils
|
| 13 |
+
from fairseq.models import FairseqIncrementalDecoder
|
| 14 |
+
from torch import Tensor
|
| 15 |
+
from fairseq.ngram_repeat_block import NGramRepeatBlock
|
| 16 |
+
|
| 17 |
+
from data import data_utils
|
| 18 |
+
|
| 19 |
+
class SequenceGenerator(nn.Module):
|
| 20 |
+
def __init__(
|
| 21 |
+
self,
|
| 22 |
+
models,
|
| 23 |
+
tgt_dict,
|
| 24 |
+
beam_size=1,
|
| 25 |
+
max_len_a=0,
|
| 26 |
+
max_len_b=200,
|
| 27 |
+
max_len=0,
|
| 28 |
+
min_len=1,
|
| 29 |
+
normalize_scores=True,
|
| 30 |
+
len_penalty=1.0,
|
| 31 |
+
unk_penalty=0.0,
|
| 32 |
+
temperature=1.0,
|
| 33 |
+
match_source_len=False,
|
| 34 |
+
no_repeat_ngram_size=0,
|
| 35 |
+
search_strategy=None,
|
| 36 |
+
eos=None,
|
| 37 |
+
symbols_to_strip_from_output=None,
|
| 38 |
+
lm_model=None,
|
| 39 |
+
lm_weight=1.0,
|
| 40 |
+
constraint_trie=None,
|
| 41 |
+
constraint_range=None,
|
| 42 |
+
gen_code=False,
|
| 43 |
+
gen_box=False,
|
| 44 |
+
ignore_eos=False,
|
| 45 |
+
zero_shot=False
|
| 46 |
+
):
|
| 47 |
+
"""Generates translations of a given source sentence.
|
| 48 |
+
|
| 49 |
+
Args:
|
| 50 |
+
models (List[~fairseq.models.FairseqModel]): ensemble of models,
|
| 51 |
+
currently support fairseq.models.TransformerModel for scripting
|
| 52 |
+
beam_size (int, optional): beam width (default: 1)
|
| 53 |
+
max_len_a/b (int, optional): generate sequences of maximum length
|
| 54 |
+
ax + b, where x is the source length
|
| 55 |
+
max_len (int, optional): the maximum length of the generated output
|
| 56 |
+
(not including end-of-sentence)
|
| 57 |
+
min_len (int, optional): the minimum length of the generated output
|
| 58 |
+
(not including end-of-sentence)
|
| 59 |
+
normalize_scores (bool, optional): normalize scores by the length
|
| 60 |
+
of the output (default: True)
|
| 61 |
+
len_penalty (float, optional): length penalty, where <1.0 favors
|
| 62 |
+
shorter, >1.0 favors longer sentences (default: 1.0)
|
| 63 |
+
unk_penalty (float, optional): unknown word penalty, where <0
|
| 64 |
+
produces more unks, >0 produces fewer (default: 0.0)
|
| 65 |
+
temperature (float, optional): temperature, where values
|
| 66 |
+
>1.0 produce more uniform samples and values <1.0 produce
|
| 67 |
+
sharper samples (default: 1.0)
|
| 68 |
+
match_source_len (bool, optional): outputs should match the source
|
| 69 |
+
length (default: False)
|
| 70 |
+
"""
|
| 71 |
+
super().__init__()
|
| 72 |
+
if isinstance(models, EnsembleModel):
|
| 73 |
+
self.model = models
|
| 74 |
+
else:
|
| 75 |
+
self.model = EnsembleModel(models)
|
| 76 |
+
self.gen_code = gen_code
|
| 77 |
+
self.gen_box = gen_box
|
| 78 |
+
self.ignore_eos = ignore_eos
|
| 79 |
+
self.tgt_dict = tgt_dict
|
| 80 |
+
self.pad = tgt_dict.pad()
|
| 81 |
+
self.unk = tgt_dict.unk()
|
| 82 |
+
self.bos = tgt_dict.bos()
|
| 83 |
+
self.eos = tgt_dict.eos() if eos is None else eos
|
| 84 |
+
self.symbols_to_strip_from_output = (
|
| 85 |
+
symbols_to_strip_from_output.union({self.eos})
|
| 86 |
+
if symbols_to_strip_from_output is not None
|
| 87 |
+
else {self.bos, self.eos}
|
| 88 |
+
)
|
| 89 |
+
self.vocab_size = len(tgt_dict)
|
| 90 |
+
self.beam_size = beam_size
|
| 91 |
+
# the max beam size is the dictionary size - 1, since we never select pad
|
| 92 |
+
self.beam_size = min(beam_size, self.vocab_size - 1)
|
| 93 |
+
self.max_len_a = max_len_a
|
| 94 |
+
self.max_len_b = max_len_b
|
| 95 |
+
self.min_len = min_len
|
| 96 |
+
self.max_len = max_len or self.model.max_decoder_positions()
|
| 97 |
+
|
| 98 |
+
self.normalize_scores = normalize_scores
|
| 99 |
+
self.len_penalty = len_penalty
|
| 100 |
+
self.unk_penalty = unk_penalty
|
| 101 |
+
self.temperature = temperature
|
| 102 |
+
self.match_source_len = match_source_len
|
| 103 |
+
self.zero_shot = zero_shot
|
| 104 |
+
|
| 105 |
+
if no_repeat_ngram_size > 0:
|
| 106 |
+
self.repeat_ngram_blocker = NGramRepeatBlock(no_repeat_ngram_size)
|
| 107 |
+
else:
|
| 108 |
+
self.repeat_ngram_blocker = None
|
| 109 |
+
|
| 110 |
+
assert temperature > 0, "--temperature must be greater than 0"
|
| 111 |
+
|
| 112 |
+
self.search = (
|
| 113 |
+
search.BeamSearch(tgt_dict) if search_strategy is None else search_strategy
|
| 114 |
+
)
|
| 115 |
+
# We only need to set src_lengths in LengthConstrainedBeamSearch.
|
| 116 |
+
# As a module attribute, setting it would break in multithread
|
| 117 |
+
# settings when the model is shared.
|
| 118 |
+
self.should_set_src_lengths = (
|
| 119 |
+
hasattr(self.search, "needs_src_lengths") and self.search.needs_src_lengths
|
| 120 |
+
)
|
| 121 |
+
|
| 122 |
+
self.model.eval()
|
| 123 |
+
|
| 124 |
+
self.lm_model = lm_model
|
| 125 |
+
self.lm_weight = lm_weight
|
| 126 |
+
if self.lm_model is not None:
|
| 127 |
+
self.lm_model.eval()
|
| 128 |
+
|
| 129 |
+
self.constraint_trie = constraint_trie
|
| 130 |
+
|
| 131 |
+
self.constraint_start = None
|
| 132 |
+
self.constraint_end = None
|
| 133 |
+
if constraint_range is not None:
|
| 134 |
+
constraint_start, constraint_end = constraint_range.split(',')
|
| 135 |
+
self.constraint_start = int(constraint_start)
|
| 136 |
+
self.constraint_end = int(constraint_end)
|
| 137 |
+
|
| 138 |
+
def cuda(self):
|
| 139 |
+
self.model.cuda()
|
| 140 |
+
return self
|
| 141 |
+
|
| 142 |
+
@torch.no_grad()
|
| 143 |
+
def forward(
|
| 144 |
+
self,
|
| 145 |
+
sample: Dict[str, Dict[str, Tensor]],
|
| 146 |
+
prefix_tokens: Optional[Tensor] = None,
|
| 147 |
+
bos_token: Optional[int] = None,
|
| 148 |
+
):
|
| 149 |
+
"""Generate a batch of translations.
|
| 150 |
+
|
| 151 |
+
Args:
|
| 152 |
+
sample (dict): batch
|
| 153 |
+
prefix_tokens (torch.LongTensor, optional): force decoder to begin
|
| 154 |
+
with these tokens
|
| 155 |
+
bos_token (int, optional): beginning of sentence token
|
| 156 |
+
(default: self.eos)
|
| 157 |
+
"""
|
| 158 |
+
return self._generate(sample, prefix_tokens, bos_token=bos_token)
|
| 159 |
+
|
| 160 |
+
# TODO(myleott): unused, deprecate after pytorch-translate migration
|
| 161 |
+
def generate_batched_itr(self, data_itr, beam_size=None, cuda=False, timer=None):
|
| 162 |
+
"""Iterate over a batched dataset and yield individual translations.
|
| 163 |
+
Args:
|
| 164 |
+
cuda (bool, optional): use GPU for generation
|
| 165 |
+
timer (StopwatchMeter, optional): time generations
|
| 166 |
+
"""
|
| 167 |
+
for sample in data_itr:
|
| 168 |
+
s = utils.move_to_cuda(sample) if cuda else sample
|
| 169 |
+
if "net_input" not in s:
|
| 170 |
+
continue
|
| 171 |
+
input = s["net_input"]
|
| 172 |
+
# model.forward normally channels prev_output_tokens into the decoder
|
| 173 |
+
# separately, but SequenceGenerator directly calls model.encoder
|
| 174 |
+
encoder_input = {
|
| 175 |
+
k: v for k, v in input.items() if k != "prev_output_tokens"
|
| 176 |
+
}
|
| 177 |
+
if timer is not None:
|
| 178 |
+
timer.start()
|
| 179 |
+
with torch.no_grad():
|
| 180 |
+
hypos = self.generate(encoder_input)
|
| 181 |
+
if timer is not None:
|
| 182 |
+
timer.stop(sum(len(h[0]["tokens"]) for h in hypos))
|
| 183 |
+
for i, id in enumerate(s["id"].data):
|
| 184 |
+
# remove padding
|
| 185 |
+
src = utils.strip_pad(input["src_tokens"].data[i, :], self.pad)
|
| 186 |
+
ref = (
|
| 187 |
+
utils.strip_pad(s["target"].data[i, :], self.pad)
|
| 188 |
+
if s["target"] is not None
|
| 189 |
+
else None
|
| 190 |
+
)
|
| 191 |
+
yield id, src, ref, hypos[i]
|
| 192 |
+
|
| 193 |
+
@torch.no_grad()
|
| 194 |
+
def generate(self, models, sample: Dict[str, Dict[str, Tensor]], **kwargs) -> List[List[Dict[str, Tensor]]]:
|
| 195 |
+
"""Generate translations. Match the api of other fairseq generators.
|
| 196 |
+
|
| 197 |
+
Args:
|
| 198 |
+
models (List[~fairseq.models.FairseqModel]): ensemble of models
|
| 199 |
+
sample (dict): batch
|
| 200 |
+
prefix_tokens (torch.LongTensor, optional): force decoder to begin
|
| 201 |
+
with these tokens
|
| 202 |
+
constraints (torch.LongTensor, optional): force decoder to include
|
| 203 |
+
the list of constraints
|
| 204 |
+
bos_token (int, optional): beginning of sentence token
|
| 205 |
+
(default: self.eos)
|
| 206 |
+
"""
|
| 207 |
+
return self._generate(models, sample, **kwargs)
|
| 208 |
+
|
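A toy illustration (the pad and eos ids below are assumptions, not the ids from utils/BPE) of the source-length computation _generate performs next, which counts every token that is neither EOS nor padding:

import torch

pad, eos = 1, 2
src_tokens = torch.tensor([[5, 6, 7, 2, 1, 1],
                           [8, 9, 2, 1, 1, 1]])
src_lengths = (src_tokens.ne(eos) & src_tokens.ne(pad)).long().sum(dim=1)
print(src_lengths)  # tensor([3, 2])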
| 209 |
+
def _generate(
|
| 210 |
+
self,
|
| 211 |
+
models,
|
| 212 |
+
sample: Dict[str, Dict[str, Tensor]],
|
| 213 |
+
prefix_tokens: Optional[Tensor] = None,
|
| 214 |
+
constraints: Optional[Tensor] = None,
|
| 215 |
+
bos_token: Optional[int] = None,
|
| 216 |
+
):
|
| 217 |
+
model = EnsembleModel(models)
|
| 218 |
+
incremental_states = torch.jit.annotate(
|
| 219 |
+
List[Dict[str, Dict[str, Optional[Tensor]]]],
|
| 220 |
+
[
|
| 221 |
+
torch.jit.annotate(Dict[str, Dict[str, Optional[Tensor]]], {})
|
| 222 |
+
for i in range(model.models_size)
|
| 223 |
+
],
|
| 224 |
+
)
|
| 225 |
+
net_input = sample["net_input"]
|
| 226 |
+
|
| 227 |
+
if "src_tokens" in net_input:
|
| 228 |
+
src_tokens = net_input["src_tokens"]
|
| 229 |
+
# length of the source text being the character length except EndOfSentence and pad
|
| 230 |
+
src_lengths = (
|
| 231 |
+
(src_tokens.ne(self.eos) & src_tokens.ne(self.pad)).long().sum(dim=1)
|
| 232 |
+
)
|
| 233 |
+
elif "source" in net_input:
|
| 234 |
+
src_tokens = net_input["source"]
|
| 235 |
+
src_lengths = (
|
| 236 |
+
net_input["padding_mask"].size(-1) - net_input["padding_mask"].sum(-1)
|
| 237 |
+
if net_input["padding_mask"] is not None
|
| 238 |
+
else torch.tensor(src_tokens.size(-1)).to(src_tokens)
|
| 239 |
+
)
|
| 240 |
+
elif "features" in net_input:
|
| 241 |
+
src_tokens = net_input["features"]
|
| 242 |
+
src_lengths = (
|
| 243 |
+
net_input["padding_mask"].size(-1) - net_input["padding_mask"].sum(-1)
|
| 244 |
+
if net_input["padding_mask"] is not None
|
| 245 |
+
else torch.tensor(src_tokens.size(-1)).to(src_tokens)
|
| 246 |
+
)
|
| 247 |
+
else:
|
| 248 |
+
raise Exception("expected src_tokens or source in net input. input keys: " + str(net_input.keys()))
|
| 249 |
+
|
| 250 |
+
# bsz: total number of sentences in beam
|
| 251 |
+
# Note that src_tokens may have more than 2 dimensions (i.e. audio features)
|
| 252 |
+
bsz, src_len = src_tokens.size()[:2]
|
| 253 |
+
beam_size = self.beam_size
|
| 254 |
+
|
| 255 |
+
if constraints is not None and not self.search.supports_constraints:
|
| 256 |
+
raise NotImplementedError(
|
| 257 |
+
"Target-side constraints were provided, but search method doesn't support them"
|
| 258 |
+
)
|
| 259 |
+
|
| 260 |
+
# Initialize constraints, when active
|
| 261 |
+
self.search.init_constraints(constraints, beam_size)
|
| 262 |
+
|
| 263 |
+
max_len: int = -1
|
| 264 |
+
if self.match_source_len:
|
| 265 |
+
max_len = src_lengths.max().item()
|
| 266 |
+
else:
|
| 267 |
+
max_len = int(self.max_len_a * src_len + self.max_len_b)
|
| 268 |
+
assert (
|
| 269 |
+
self.min_len <= max_len
|
| 270 |
+
), "min_len cannot be larger than max_len, please adjust these!"
|
| 271 |
+
# compute the encoder output for each beam
|
| 272 |
+
with torch.autograd.profiler.record_function("EnsembleModel: forward_encoder"):
|
| 273 |
+
encoder_outs = model.forward_encoder(net_input)
|
| 274 |
+
|
| 275 |
+
# placeholder of indices for bsz * beam_size to hold tokens and accumulative scores
|
| 276 |
+
new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1)
|
| 277 |
+
new_order = new_order.to(src_tokens.device).long()
|
| 278 |
+
encoder_outs = model.reorder_encoder_out(encoder_outs, new_order)
|
| 279 |
+
# ensure encoder_outs is a List.
|
| 280 |
+
assert encoder_outs is not None
|
| 281 |
+
|
| 282 |
+
# initialize buffers
|
| 283 |
+
scores = (
|
| 284 |
+
torch.zeros(bsz * beam_size, max_len + 1).to(src_tokens).float()
|
| 285 |
+
) # +1 for eos; pad is never chosen for scoring
|
| 286 |
+
tokens = (
|
| 287 |
+
torch.zeros(bsz * beam_size, max_len + 2)
|
| 288 |
+
.to(src_tokens)
|
| 289 |
+
.long()
|
| 290 |
+
.fill_(self.pad)
|
| 291 |
+
) # +2 for eos and pad
|
| 292 |
+
# tokens[:, 0] = self.eos if bos_token is None else bos_token
|
| 293 |
+
tokens[:, 0] = self.bos
|
| 294 |
+
attn: Optional[Tensor] = None
|
| 295 |
+
|
| 296 |
+
# A list that indicates candidates that should be ignored.
|
| 297 |
+
# For example, suppose we're sampling and have already finalized 2/5
|
| 298 |
+
# samples. Then cands_to_ignore would mark 2 positions as being ignored,
|
| 299 |
+
# so that we only finalize the remaining 3 samples.
|
| 300 |
+
cands_to_ignore = (
|
| 301 |
+
torch.zeros(bsz, beam_size).to(src_tokens).eq(-1)
|
| 302 |
+
) # forward and backward-compatible False mask
|
| 303 |
+
|
| 304 |
+
# list of completed sentences
|
| 305 |
+
finalized = torch.jit.annotate(
|
| 306 |
+
List[List[Dict[str, Tensor]]],
|
| 307 |
+
[torch.jit.annotate(List[Dict[str, Tensor]], []) for i in range(bsz)],
|
| 308 |
+
) # contains lists of dictionaries of information about the hypothesis being finalized at each step
|
| 309 |
+
|
| 310 |
+
# a boolean array indicating if the sentence at the index is finished or not
|
| 311 |
+
finished = [False for i in range(bsz)]
|
| 312 |
+
num_remaining_sent = bsz # number of sentences remaining
|
| 313 |
+
|
| 314 |
+
# number of candidate hypos per step
|
| 315 |
+
cand_size = 2 * beam_size # 2 x beam size in case half are EOS
|
| 316 |
+
|
| 317 |
+
# offset arrays for converting between different indexing schemes
|
| 318 |
+
bbsz_offsets = (
|
| 319 |
+
(torch.arange(0, bsz) * beam_size)
|
| 320 |
+
.unsqueeze(1)
|
| 321 |
+
.type_as(tokens)
|
| 322 |
+
.to(src_tokens.device)
|
| 323 |
+
)
|
| 324 |
+
cand_offsets = torch.arange(0, cand_size).type_as(tokens).to(src_tokens.device)
|
| 325 |
+
|
| 326 |
+
reorder_state: Optional[Tensor] = None
|
| 327 |
+
batch_idxs: Optional[Tensor] = None
|
| 328 |
+
|
| 329 |
+
original_batch_idxs: Optional[Tensor] = None
|
| 330 |
+
if "id" in sample and isinstance(sample["id"], Tensor):
|
| 331 |
+
original_batch_idxs = sample["id"]
|
| 332 |
+
else:
|
| 333 |
+
original_batch_idxs = torch.arange(0, bsz).type_as(tokens)
|
| 334 |
+
|
| 335 |
+
for step in range(max_len + 1): # one extra step for EOS marker
|
| 336 |
+
# reorder decoder internal states based on the prev choice of beams
|
| 337 |
+
if reorder_state is not None:
|
| 338 |
+
if batch_idxs is not None:
|
| 339 |
+
# update beam indices to take into account removed sentences
|
| 340 |
+
corr = batch_idxs - torch.arange(batch_idxs.numel()).type_as(
|
| 341 |
+
batch_idxs
|
| 342 |
+
)
|
| 343 |
+
reorder_state.view(-1, beam_size).add_(
|
| 344 |
+
corr.unsqueeze(-1) * beam_size
|
| 345 |
+
)
|
| 346 |
+
original_batch_idxs = original_batch_idxs[batch_idxs]
|
| 347 |
+
model.reorder_incremental_state(incremental_states, reorder_state)
|
| 348 |
+
encoder_outs = model.reorder_encoder_out(
|
| 349 |
+
encoder_outs, reorder_state
|
| 350 |
+
)
|
| 351 |
+
with torch.autograd.profiler.record_function("EnsembleModel: forward_decoder"):
|
| 352 |
+
lprobs, avg_attn_scores = model.forward_decoder(
|
| 353 |
+
tokens[:, : step + 1],
|
| 354 |
+
encoder_outs,
|
| 355 |
+
incremental_states,
|
| 356 |
+
self.temperature,
|
| 357 |
+
constraint_trie=self.constraint_trie,
|
| 358 |
+
constraint_start=self.constraint_start,
|
| 359 |
+
constraint_end=self.constraint_end,
|
| 360 |
+
gen_code=self.gen_code,
|
| 361 |
+
zero_shot=self.zero_shot,
|
| 362 |
+
prefix_tokens=prefix_tokens
|
| 363 |
+
)
|
| 364 |
+
|
| 365 |
+
if self.lm_model is not None:
|
| 366 |
+
lm_out = self.lm_model(tokens[:, : step + 1])
|
| 367 |
+
probs = self.lm_model.get_normalized_probs(
|
| 368 |
+
lm_out, log_probs=True, sample=None
|
| 369 |
+
)
|
| 370 |
+
probs = probs[:, -1, :] * self.lm_weight
|
| 371 |
+
lprobs += probs
|
| 372 |
+
# handle prefix tokens (possibly with different lengths)
|
| 373 |
+
if (
|
| 374 |
+
prefix_tokens is not None
|
| 375 |
+
and step < prefix_tokens.size(1)
|
| 376 |
+
and step < max_len
|
| 377 |
+
):
|
| 378 |
+
lprobs, tokens, scores = self._prefix_tokens(
|
| 379 |
+
step, lprobs, scores, tokens, prefix_tokens, beam_size
|
| 380 |
+
)
|
| 381 |
+
elif step < self.min_len:
|
| 382 |
+
# minimum length constraint (does not apply if using prefix_tokens)
|
| 383 |
+
lprobs[:, self.eos] = -math.inf
|
| 384 |
+
|
| 385 |
+
lprobs[lprobs != lprobs] = torch.tensor(-math.inf).to(lprobs)
|
| 386 |
+
|
| 387 |
+
lprobs[:, self.pad] = -math.inf # never select pad
|
| 388 |
+
lprobs[:, self.unk] -= self.unk_penalty # apply unk penalty
|
| 389 |
+
|
| 390 |
+
if (self.gen_code or self.gen_box) and step < max_len:
|
| 391 |
+
lprobs[:, :4] = -math.inf
|
| 392 |
+
if self.gen_box:
|
| 393 |
+
lprobs[:, -1] = -math.inf
|
| 394 |
+
if (step + 1) % 5 == 0:
|
| 395 |
+
lprobs[:, self.constraint_start:59457] = -math.inf
|
| 396 |
+
else:
|
| 397 |
+
lprobs[:, 59457:] = -math.inf
|
| 398 |
+
|
| 399 |
+
# handle max length constraint
|
| 400 |
+
if step >= max_len:
|
| 401 |
+
lprobs[:, : self.eos] = -math.inf
|
| 402 |
+
lprobs[:, self.eos + 1 :] = -math.inf
|
| 403 |
+
if self.ignore_eos:
|
| 404 |
+
lprobs[:, self.eos] = 1
|
| 405 |
+
|
| 406 |
+
# Record attention scores, only support avg_attn_scores is a Tensor
|
| 407 |
+
if avg_attn_scores is not None:
|
| 408 |
+
if attn is None:
|
| 409 |
+
attn = torch.empty(
|
| 410 |
+
bsz * beam_size, avg_attn_scores.size(1), max_len + 2
|
| 411 |
+
).to(scores)
|
| 412 |
+
attn[:, :, step + 1].copy_(avg_attn_scores)
|
| 413 |
+
|
| 414 |
+
scores = scores.type_as(lprobs)
|
| 415 |
+
eos_bbsz_idx = torch.empty(0).to(
|
| 416 |
+
tokens
|
| 417 |
+
) # indices of hypothesis ending with eos (finished sentences)
|
| 418 |
+
eos_scores = torch.empty(0).to(
|
| 419 |
+
scores
|
| 420 |
+
) # scores of hypothesis ending with eos (finished sentences)
|
| 421 |
+
|
| 422 |
+
if self.should_set_src_lengths:
|
| 423 |
+
self.search.set_src_lengths(src_lengths)
|
| 424 |
+
|
| 425 |
+
if self.repeat_ngram_blocker is not None:
|
| 426 |
+
lprobs = self.repeat_ngram_blocker(tokens, lprobs, bsz, beam_size, step)
|
| 427 |
+
|
| 428 |
+
# Shape: (batch, cand_size)
|
| 429 |
+
cand_scores, cand_indices, cand_beams = self.search.step(
|
| 430 |
+
step,
|
| 431 |
+
lprobs.view(bsz, -1, self.vocab_size),
|
| 432 |
+
scores.view(bsz, beam_size, -1)[:, :, :step],
|
| 433 |
+
tokens[:, : step + 1],
|
| 434 |
+
original_batch_idxs,
|
| 435 |
+
)
|
| 436 |
+
|
| 437 |
+
# cand_bbsz_idx contains beam indices for the top candidate
|
| 438 |
+
# hypotheses, with a range of values: [0, bsz*beam_size),
|
| 439 |
+
# and dimensions: [bsz, cand_size]
|
| 440 |
+
cand_bbsz_idx = cand_beams.add(bbsz_offsets)
|
| 441 |
+
|
| 442 |
+
# finalize hypotheses that end in eos
|
| 443 |
+
# Shape of eos_mask: (batch size, beam size)
|
| 444 |
+
eos_mask = cand_indices.eq(self.eos) & cand_scores.ne(-math.inf)
|
| 445 |
+
eos_mask[:, :beam_size][cands_to_ignore] = torch.tensor(0).to(eos_mask)
|
| 446 |
+
|
| 447 |
+
# only consider eos when it's among the top beam_size indices
|
| 448 |
+
# Now we know what beam item(s) to finish
|
| 449 |
+
# Shape: 1d list of absolute-numbered
|
| 450 |
+
eos_bbsz_idx = torch.masked_select(
|
| 451 |
+
cand_bbsz_idx[:, :beam_size], mask=eos_mask[:, :beam_size]
|
| 452 |
+
)
|
| 453 |
+
|
| 454 |
+
finalized_sents: List[int] = []
|
| 455 |
+
if eos_bbsz_idx.numel() > 0:
|
| 456 |
+
eos_scores = torch.masked_select(
|
| 457 |
+
cand_scores[:, :beam_size], mask=eos_mask[:, :beam_size]
|
| 458 |
+
)
|
| 459 |
+
|
| 460 |
+
finalized_sents = self.finalize_hypos(
|
| 461 |
+
step,
|
| 462 |
+
eos_bbsz_idx,
|
| 463 |
+
eos_scores,
|
| 464 |
+
tokens,
|
| 465 |
+
scores,
|
| 466 |
+
finalized,
|
| 467 |
+
finished,
|
| 468 |
+
beam_size,
|
| 469 |
+
attn,
|
| 470 |
+
src_lengths,
|
| 471 |
+
max_len,
|
| 472 |
+
)
|
| 473 |
+
num_remaining_sent -= len(finalized_sents)
|
| 474 |
+
|
| 475 |
+
assert num_remaining_sent >= 0
|
| 476 |
+
if num_remaining_sent == 0:
|
| 477 |
+
break
|
| 478 |
+
if self.search.stop_on_max_len and step >= max_len:
|
| 479 |
+
break
|
| 480 |
+
assert step < max_len, f"{step} < {max_len}"
|
| 481 |
+
|
| 482 |
+
# Remove finalized sentences (ones for which {beam_size}
|
| 483 |
+
# finished hypotheses have been generated) from the batch.
|
| 484 |
+
if len(finalized_sents) > 0:
|
| 485 |
+
new_bsz = bsz - len(finalized_sents)
|
| 486 |
+
|
| 487 |
+
# construct batch_idxs which holds indices of batches to keep for the next pass
|
| 488 |
+
batch_mask = torch.ones(
|
| 489 |
+
bsz, dtype=torch.bool, device=cand_indices.device
|
| 490 |
+
)
|
| 491 |
+
batch_mask[finalized_sents] = False
|
| 492 |
+
# TODO replace `nonzero(as_tuple=False)` after TorchScript supports it
|
| 493 |
+
batch_idxs = torch.arange(
|
| 494 |
+
bsz, device=cand_indices.device
|
| 495 |
+
).masked_select(batch_mask)
|
| 496 |
+
|
| 497 |
+
# Choose the subset of the hypothesized constraints that will continue
|
| 498 |
+
self.search.prune_sentences(batch_idxs)
|
| 499 |
+
|
| 500 |
+
eos_mask = eos_mask[batch_idxs]
|
| 501 |
+
cand_beams = cand_beams[batch_idxs]
|
| 502 |
+
bbsz_offsets.resize_(new_bsz, 1)
|
| 503 |
+
cand_bbsz_idx = cand_beams.add(bbsz_offsets)
|
| 504 |
+
cand_scores = cand_scores[batch_idxs]
|
| 505 |
+
cand_indices = cand_indices[batch_idxs]
|
| 506 |
+
|
| 507 |
+
if prefix_tokens is not None:
|
| 508 |
+
prefix_tokens = prefix_tokens[batch_idxs]
|
| 509 |
+
src_lengths = src_lengths[batch_idxs]
|
| 510 |
+
cands_to_ignore = cands_to_ignore[batch_idxs]
|
| 511 |
+
|
| 512 |
+
scores = scores.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
|
| 513 |
+
tokens = tokens.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
|
| 514 |
+
if attn is not None:
|
| 515 |
+
attn = attn.view(bsz, -1)[batch_idxs].view(
|
| 516 |
+
new_bsz * beam_size, attn.size(1), -1
|
| 517 |
+
)
|
| 518 |
+
bsz = new_bsz
|
| 519 |
+
else:
|
| 520 |
+
batch_idxs = None
|
| 521 |
+
|
| 522 |
+
# Set active_mask so that values > cand_size indicate eos hypos
|
| 523 |
+
# and values < cand_size indicate candidate active hypos.
|
| 524 |
+
# After, the min values per row are the top candidate active hypos
|
| 525 |
+
|
| 526 |
+
# Rewrite the operator since the element wise or is not supported in torchscript.
|
| 527 |
+
|
| 528 |
+
eos_mask[:, :beam_size] = ~((~cands_to_ignore) & (~eos_mask[:, :beam_size]))
|
| 529 |
+
active_mask = torch.add(
|
| 530 |
+
eos_mask.type_as(cand_offsets) * cand_size,
|
| 531 |
+
cand_offsets[: eos_mask.size(1)],
|
| 532 |
+
)
|
| 533 |
+
|
| 534 |
+
# get the top beam_size active hypotheses, which are just
|
| 535 |
+
# the hypos with the smallest values in active_mask.
|
| 536 |
+
# {active_hypos} indicates which {beam_size} hypotheses
|
| 537 |
+
# from the list of {2 * beam_size} candidates were
|
| 538 |
+
# selected. Shapes: (batch size, beam size)
|
| 539 |
+
new_cands_to_ignore, active_hypos = torch.topk(
|
| 540 |
+
active_mask, k=beam_size, dim=1, largest=False
|
| 541 |
+
)
|
| 542 |
+
|
| 543 |
+
# update cands_to_ignore to ignore any finalized hypos.
|
| 544 |
+
cands_to_ignore = new_cands_to_ignore.ge(cand_size)[:, :beam_size]
|
| 545 |
+
# Make sure there is at least one active item for each sentence in the batch.
|
| 546 |
+
assert (~cands_to_ignore).any(dim=1).all()
|
| 547 |
+
|
| 548 |
+
# update cands_to_ignore to ignore any finalized hypos
|
| 549 |
+
|
| 550 |
+
# {active_bbsz_idx} denotes which beam number is continued for each new hypothesis (a beam
|
| 551 |
+
# can be selected more than once).
|
| 552 |
+
active_bbsz_idx = torch.gather(cand_bbsz_idx, dim=1, index=active_hypos)
|
| 553 |
+
active_scores = torch.gather(cand_scores, dim=1, index=active_hypos)
|
| 554 |
+
|
| 555 |
+
active_bbsz_idx = active_bbsz_idx.view(-1)
|
| 556 |
+
active_scores = active_scores.view(-1)
|
| 557 |
+
|
| 558 |
+
# copy tokens and scores for active hypotheses
|
| 559 |
+
|
| 560 |
+
# Set the tokens for each beam (can select the same row more than once)
|
| 561 |
+
tokens[:, : step + 1] = torch.index_select(
|
| 562 |
+
tokens[:, : step + 1], dim=0, index=active_bbsz_idx
|
| 563 |
+
)
|
| 564 |
+
# Select the next token for each of them
|
| 565 |
+
tokens.view(bsz, beam_size, -1)[:, :, step + 1] = torch.gather(
|
| 566 |
+
cand_indices, dim=1, index=active_hypos
|
| 567 |
+
)
|
| 568 |
+
if step > 0:
|
| 569 |
+
scores[:, :step] = torch.index_select(
|
| 570 |
+
scores[:, :step], dim=0, index=active_bbsz_idx
|
| 571 |
+
)
|
| 572 |
+
scores.view(bsz, beam_size, -1)[:, :, step] = torch.gather(
|
| 573 |
+
cand_scores, dim=1, index=active_hypos
|
| 574 |
+
)
|
| 575 |
+
|
| 576 |
+
# Update constraints based on which candidates were selected for the next beam
|
| 577 |
+
self.search.update_constraints(active_hypos)
|
| 578 |
+
|
| 579 |
+
# copy attention for active hypotheses
|
| 580 |
+
if attn is not None:
|
| 581 |
+
attn[:, :, : step + 2] = torch.index_select(
|
| 582 |
+
attn[:, :, : step + 2], dim=0, index=active_bbsz_idx
|
| 583 |
+
)
|
| 584 |
+
|
| 585 |
+
# reorder incremental state in decoder
|
| 586 |
+
reorder_state = active_bbsz_idx
|
| 587 |
+
|
| 588 |
+
# sort by score descending
|
| 589 |
+
for sent in range(len(finalized)):
|
| 590 |
+
scores = torch.tensor(
|
| 591 |
+
[float(elem["score"].item()) for elem in finalized[sent]]
|
| 592 |
+
)
|
| 593 |
+
_, sorted_scores_indices = torch.sort(scores, descending=True)
|
| 594 |
+
finalized[sent] = [finalized[sent][ssi] for ssi in sorted_scores_indices]
|
| 595 |
+
finalized[sent] = torch.jit.annotate(
|
| 596 |
+
List[Dict[str, Tensor]], finalized[sent]
|
| 597 |
+
)
|
| 598 |
+
return finalized
|
| 599 |
+
|
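A minimal sketch (toy sizes) of the offset arithmetic _generate uses above: bbsz_offsets turns per-sentence beam ids into row indices of the flat (bsz * beam_size) token and score buffers.

import torch

bsz, beam_size = 2, 3
bbsz_offsets = (torch.arange(0, bsz) * beam_size).unsqueeze(1)  # [[0], [3]]
cand_beams = torch.tensor([[0, 2, 1, 0],   # beam ids chosen for sentence 0
                           [1, 1, 0, 2]])  # beam ids chosen for sentence 1
cand_bbsz_idx = cand_beams + bbsz_offsets
# tensor([[0, 2, 1, 0],
#         [4, 4, 3, 5]])  -> rows of the (bsz * beam_size, ...) buffers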
| 600 |
+
def _prefix_tokens(
|
| 601 |
+
self, step: int, lprobs, scores, tokens, prefix_tokens, beam_size: int
|
| 602 |
+
):
|
| 603 |
+
"""Handle prefix tokens"""
|
| 604 |
+
prefix_toks = prefix_tokens[:, step].unsqueeze(-1).repeat(1, beam_size).view(-1)
|
| 605 |
+
prefix_lprobs = lprobs.gather(-1, prefix_toks.unsqueeze(-1))
|
| 606 |
+
prefix_mask = prefix_toks.ne(self.pad)
|
| 607 |
+
if self.constraint_trie is None:
|
| 608 |
+
lprobs[prefix_mask] = torch.min(prefix_lprobs) - 1
|
| 609 |
+
else:
|
| 610 |
+
lprobs[prefix_mask] = -math.inf
|
| 611 |
+
lprobs[prefix_mask] = lprobs[prefix_mask].scatter(
|
| 612 |
+
-1, prefix_toks[prefix_mask].unsqueeze(-1), prefix_lprobs[prefix_mask]
|
| 613 |
+
)
|
| 614 |
+
# if prefix includes eos, then we should make sure tokens and
|
| 615 |
+
# scores are the same across all beams
|
| 616 |
+
eos_mask = prefix_toks.eq(self.eos)
|
| 617 |
+
if eos_mask.any():
|
| 618 |
+
# validate that the first beam matches the prefix
|
| 619 |
+
first_beam = tokens[eos_mask].view(-1, beam_size, tokens.size(-1))[
|
| 620 |
+
:, 0, 1 : step + 1
|
| 621 |
+
]
|
| 622 |
+
eos_mask_batch_dim = eos_mask.view(-1, beam_size)[:, 0]
|
| 623 |
+
target_prefix = prefix_tokens[eos_mask_batch_dim][:, :step]
|
| 624 |
+
assert (first_beam == target_prefix).all()
|
| 625 |
+
|
| 626 |
+
# copy tokens, scores and lprobs from the first beam to all beams
|
| 627 |
+
tokens = self.replicate_first_beam(tokens, eos_mask_batch_dim, beam_size)
|
| 628 |
+
scores = self.replicate_first_beam(scores, eos_mask_batch_dim, beam_size)
|
| 629 |
+
lprobs = self.replicate_first_beam(lprobs, eos_mask_batch_dim, beam_size)
|
| 630 |
+
return lprobs, tokens, scores
|
| 631 |
+
|
| 632 |
+
def replicate_first_beam(self, tensor, mask, beam_size: int):
|
| 633 |
+
tensor = tensor.view(-1, beam_size, tensor.size(-1))
|
| 634 |
+
tensor[mask] = tensor[mask][:, :1, :]
|
| 635 |
+
return tensor.view(-1, tensor.size(-1))
|
| 636 |
+
|
| 637 |
+
def finalize_hypos(
|
| 638 |
+
self,
|
| 639 |
+
step: int,
|
| 640 |
+
bbsz_idx,
|
| 641 |
+
eos_scores,
|
| 642 |
+
tokens,
|
| 643 |
+
scores,
|
| 644 |
+
finalized: List[List[Dict[str, Tensor]]],
|
| 645 |
+
finished: List[bool],
|
| 646 |
+
beam_size: int,
|
| 647 |
+
attn: Optional[Tensor],
|
| 648 |
+
src_lengths,
|
| 649 |
+
max_len: int,
|
| 650 |
+
):
|
| 651 |
+
"""Finalize hypothesis, store finalized information in `finalized`, and change `finished` accordingly.
|
| 652 |
+
A sentence is finalized when {beam_size} finished items have been collected for it.
|
| 653 |
+
|
| 654 |
+
Returns number of sentences (not beam items) being finalized.
|
| 655 |
+
These will be removed from the batch and not processed further.
|
| 656 |
+
Args:
|
| 657 |
+
bbsz_idx (Tensor):
|
| 658 |
+
"""
|
| 659 |
+
assert bbsz_idx.numel() == eos_scores.numel()
|
| 660 |
+
|
| 661 |
+
# clone relevant token and attention tensors.
|
| 662 |
+
# tokens is (batch * beam, max_len). So the index_select
|
| 663 |
+
# gets the newly EOS rows, then selects cols 1..{step + 2}
|
| 664 |
+
tokens_clone = tokens.index_select(0, bbsz_idx)[
|
| 665 |
+
:, 1 : step + 2
|
| 666 |
+
] # skip the first index, which is EOS
|
| 667 |
+
|
| 668 |
+
tokens_clone[:, step] = self.eos
|
| 669 |
+
attn_clone = (
|
| 670 |
+
attn.index_select(0, bbsz_idx)[:, :, 1 : step + 2]
|
| 671 |
+
if attn is not None
|
| 672 |
+
else None
|
| 673 |
+
)
|
| 674 |
+
|
| 675 |
+
# compute scores per token position
|
| 676 |
+
pos_scores = scores.index_select(0, bbsz_idx)[:, : step + 1]
|
| 677 |
+
pos_scores[:, step] = eos_scores
|
| 678 |
+
# convert from cumulative to per-position scores
|
| 679 |
+
pos_scores[:, 1:] = pos_scores[:, 1:] - pos_scores[:, :-1]
|
| 680 |
+
|
| 681 |
+
# normalize sentence-level scores
|
| 682 |
+
if self.normalize_scores:
|
| 683 |
+
eos_scores /= (step + 1) ** self.len_penalty
|
| 684 |
+
|
| 685 |
+
# cum_unfin records which sentences in the batch are finished.
|
| 686 |
+
# It helps match indexing between (a) the original sentences
|
| 687 |
+
# in the batch and (b) the current, possibly-reduced set of
|
| 688 |
+
# sentences.
|
| 689 |
+
cum_unfin: List[int] = []
|
| 690 |
+
prev = 0
|
| 691 |
+
for f in finished:
|
| 692 |
+
if f:
|
| 693 |
+
prev += 1
|
| 694 |
+
else:
|
| 695 |
+
cum_unfin.append(prev)
|
| 696 |
+
cum_fin_tensor = torch.tensor(cum_unfin, dtype=torch.int).to(bbsz_idx)
|
| 697 |
+
|
| 698 |
+
unfin_idx = bbsz_idx // beam_size
|
| 699 |
+
sent = unfin_idx + torch.index_select(cum_fin_tensor, 0, unfin_idx)
|
| 700 |
+
|
| 701 |
+
# Create a set of "{sent}{unfin_idx}", where
|
| 702 |
+
# "unfin_idx" is the index in the current (possibly reduced)
|
| 703 |
+
# list of sentences, and "sent" is the index in the original,
|
| 704 |
+
# unreduced batch
|
| 705 |
+
# For every finished beam item
|
| 706 |
+
# sentence index in the current (possibly reduced) batch
|
| 707 |
+
seen = (sent << 32) + unfin_idx
|
| 708 |
+
unique_seen: List[int] = torch.unique(seen).tolist()
|
| 709 |
+
|
| 710 |
+
if self.match_source_len:
|
| 711 |
+
condition = step > torch.index_select(src_lengths, 0, unfin_idx)
|
| 712 |
+
eos_scores = torch.where(condition, torch.tensor(-math.inf), eos_scores)
|
| 713 |
+
sent_list: List[int] = sent.tolist()
|
| 714 |
+
for i in range(bbsz_idx.size()[0]):
|
| 715 |
+
# An input sentence (among those in a batch) is finished when
|
| 716 |
+
# beam_size hypotheses have been collected for it
|
| 717 |
+
if len(finalized[sent_list[i]]) < beam_size:
|
| 718 |
+
if attn_clone is not None:
|
| 719 |
+
# remove padding tokens from attn scores
|
| 720 |
+
hypo_attn = attn_clone[i]
|
| 721 |
+
else:
|
| 722 |
+
hypo_attn = torch.empty(0)
|
| 723 |
+
|
| 724 |
+
finalized[sent_list[i]].append(
|
| 725 |
+
{
|
| 726 |
+
"tokens": tokens_clone[i],
|
| 727 |
+
"score": eos_scores[i],
|
| 728 |
+
"attention": hypo_attn, # src_len x tgt_len
|
| 729 |
+
"alignment": torch.empty(0),
|
| 730 |
+
"positional_scores": pos_scores[i],
|
| 731 |
+
}
|
| 732 |
+
)
|
| 733 |
+
|
| 734 |
+
newly_finished: List[int] = []
|
| 735 |
+
for unique_s in unique_seen:
|
| 736 |
+
# check termination conditions for this sentence
|
| 737 |
+
unique_sent: int = unique_s >> 32
|
| 738 |
+
unique_unfin_idx: int = unique_s - (unique_sent << 32)
|
| 739 |
+
|
| 740 |
+
if not finished[unique_sent] and self.is_finished(
|
| 741 |
+
step, unique_unfin_idx, max_len, len(finalized[unique_sent]), beam_size
|
| 742 |
+
):
|
| 743 |
+
finished[unique_sent] = True
|
| 744 |
+
newly_finished.append(unique_unfin_idx)
|
| 745 |
+
|
| 746 |
+
return newly_finished
|
| 747 |
+
|
| 748 |
+
def is_finished(
|
| 749 |
+
self,
|
| 750 |
+
step: int,
|
| 751 |
+
unfin_idx: int,
|
| 752 |
+
max_len: int,
|
| 753 |
+
finalized_sent_len: int,
|
| 754 |
+
beam_size: int,
|
| 755 |
+
):
|
| 756 |
+
"""
|
| 757 |
+
Check whether decoding for a sentence is finished, which
|
| 758 |
+
occurs when the list of finalized sentences has reached the
|
| 759 |
+
beam size, or when we reach the maximum length.
|
| 760 |
+
"""
|
| 761 |
+
assert finalized_sent_len <= beam_size
|
| 762 |
+
if finalized_sent_len == beam_size or step == max_len:
|
| 763 |
+
return True
|
| 764 |
+
return False
|
| 765 |
+
|
| 766 |
+
|
| 767 |
+
class EnsembleModel(nn.Module):
|
| 768 |
+
"""A wrapper around an ensemble of models."""
|
| 769 |
+
|
| 770 |
+
def __init__(self, models):
|
| 771 |
+
super().__init__()
|
| 772 |
+
self.models_size = len(models)
|
| 773 |
+
# method '__len__' is not supported in ModuleList for torch script
|
| 774 |
+
self.single_model = models[0]
|
| 775 |
+
self.models = nn.ModuleList(models)
|
| 776 |
+
|
| 777 |
+
self.has_incremental: bool = False
|
| 778 |
+
if all(
|
| 779 |
+
hasattr(m, "decoder") and isinstance(m.decoder, FairseqIncrementalDecoder)
|
| 780 |
+
for m in models
|
| 781 |
+
):
|
| 782 |
+
self.has_incremental = True
|
| 783 |
+
|
| 784 |
+
def forward(self):
|
| 785 |
+
pass
|
| 786 |
+
|
| 787 |
+
def has_encoder(self):
|
| 788 |
+
return hasattr(self.single_model, "encoder")
|
| 789 |
+
|
| 790 |
+
def has_incremental_states(self):
|
| 791 |
+
return self.has_incremental
|
| 792 |
+
|
| 793 |
+
def max_decoder_positions(self):
|
| 794 |
+
return min([m.max_decoder_positions() for m in self.models if hasattr(m, "max_decoder_positions")] + [sys.maxsize])
|
| 795 |
+
|
| 796 |
+
@torch.jit.export
|
| 797 |
+
def forward_encoder(self, net_input: Dict[str, Tensor]):
|
| 798 |
+
if not self.has_encoder():
|
| 799 |
+
return None
|
| 800 |
+
return [model.encoder.forward_torchscript(net_input) for model in self.models]
|
| 801 |
+
|
| 802 |
+
@torch.jit.export
|
| 803 |
+
def forward_decoder(
|
| 804 |
+
self,
|
| 805 |
+
tokens,
|
| 806 |
+
encoder_outs: List[Dict[str, List[Tensor]]],
|
| 807 |
+
incremental_states: List[Dict[str, Dict[str, Optional[Tensor]]]],
|
| 808 |
+
temperature: float = 1.0,
|
| 809 |
+
constraint_trie=None,
|
| 810 |
+
constraint_start=None,
|
| 811 |
+
constraint_end=None,
|
| 812 |
+
gen_code=False,
|
| 813 |
+
zero_shot=False,
|
| 814 |
+
prefix_tokens=None
|
| 815 |
+
):
|
| 816 |
+
log_probs = []
|
| 817 |
+
avg_attn: Optional[Tensor] = None
|
| 818 |
+
encoder_out: Optional[Dict[str, List[Tensor]]] = None
|
| 819 |
+
code_mask = (tokens.new_ones(tokens.size(0))*gen_code).bool()
|
| 820 |
+
for i, model in enumerate(self.models):
|
| 821 |
+
if self.has_encoder():
|
| 822 |
+
encoder_out = encoder_outs[i]
|
| 823 |
+
# decode each model
|
| 824 |
+
if self.has_incremental_states():
|
| 825 |
+
decoder_out = model.decoder.forward(
|
| 826 |
+
tokens,
|
| 827 |
+
code_masks=code_mask,
|
| 828 |
+
encoder_out=encoder_out,
|
| 829 |
+
incremental_state=incremental_states[i],
|
| 830 |
+
)
|
| 831 |
+
else:
|
| 832 |
+
if hasattr(model, "decoder"):
|
| 833 |
+
decoder_out = model.decoder.forward(tokens, code_masks=code_mask, encoder_out=encoder_out)
|
| 834 |
+
else:
|
| 835 |
+
decoder_out = model.forward(tokens)
|
| 836 |
+
|
| 837 |
+
attn: Optional[Tensor] = None
|
| 838 |
+
decoder_len = len(decoder_out)
|
| 839 |
+
if decoder_len > 1 and decoder_out[1] is not None:
|
| 840 |
+
if isinstance(decoder_out[1], Tensor):
|
| 841 |
+
attn = decoder_out[1]
|
| 842 |
+
else:
|
| 843 |
+
attn_holder = decoder_out[1]["attn"]
|
| 844 |
+
if isinstance(attn_holder, Tensor):
|
| 845 |
+
attn = attn_holder
|
| 846 |
+
elif attn_holder is not None:
|
| 847 |
+
attn = attn_holder[0]
|
| 848 |
+
if attn is not None:
|
| 849 |
+
attn = attn[:, -1, :]
|
| 850 |
+
|
| 851 |
+
decoder_out_tuple = (
|
| 852 |
+
decoder_out[0][:, -1:, :].div_(temperature),
|
| 853 |
+
None if decoder_len <= 1 else decoder_out[1],
|
| 854 |
+
)
|
| 855 |
+
|
| 856 |
+
beam_size = decoder_out_tuple[0].size(0) // prefix_tokens.size(0) if prefix_tokens is not None else 0
|
| 857 |
+
if constraint_trie is not None and not zero_shot:
|
| 858 |
+
assert constraint_start is None and constraint_end is None
|
| 859 |
+
constraint_masks = decoder_out_tuple[0].new_zeros(decoder_out_tuple[0].size()).bool()
|
| 860 |
+
constraint_prefix_tokens = tokens.tolist()
|
| 861 |
+
for token_index, constraint_prefix_token in enumerate(constraint_prefix_tokens):
|
| 862 |
+
prefix_len = prefix_tokens[token_index // beam_size].ne(1).sum().item() if prefix_tokens is not None else 0
|
| 863 |
+
if len(constraint_prefix_token) > prefix_len:
|
| 864 |
+
constraint_prefix_token = [0] + constraint_prefix_token[prefix_len+1:]
|
| 865 |
+
constraint_nodes = constraint_trie.get_next_layer(constraint_prefix_token)
|
| 866 |
+
constraint_masks[token_index][:, constraint_nodes] = True
|
| 867 |
+
else:
|
| 868 |
+
constraint_masks[token_index] = True
|
| 869 |
+
decoder_out_tuple[0].masked_fill_(~constraint_masks, -math.inf)
|
| 870 |
+
if constraint_start is not None and constraint_end is not None and not zero_shot:
|
| 871 |
+
assert constraint_trie is None
|
| 872 |
+
decoder_out_tuple[0][:, :, 4:constraint_start] = -math.inf
|
| 873 |
+
decoder_out_tuple[0][:, :, constraint_end:] = -math.inf
|
| 874 |
+
|
| 875 |
+
probs = model.get_normalized_probs(
|
| 876 |
+
decoder_out_tuple, log_probs=True, sample=None
|
| 877 |
+
)
|
| 878 |
+
if constraint_trie is not None and zero_shot:
|
| 879 |
+
assert constraint_start is None and constraint_end is None
|
| 880 |
+
constraint_masks = decoder_out_tuple[0].new_zeros(decoder_out_tuple[0].size()).bool()
|
| 881 |
+
constraint_prefix_tokens = tokens.tolist()
|
| 882 |
+
for token_index, constraint_prefix_token in enumerate(constraint_prefix_tokens):
|
| 883 |
+
constraint_nodes = constraint_trie.get_next_layer(constraint_prefix_token)
|
| 884 |
+
constraint_masks[token_index][:, constraint_nodes] = True
|
| 885 |
+
probs.masked_fill_(~constraint_masks, -math.inf)
|
| 886 |
+
if constraint_start is not None and constraint_end is not None and zero_shot:
|
| 887 |
+
assert constraint_trie is None
|
| 888 |
+
probs[:, :, 4:constraint_start] = -math.inf
|
| 889 |
+
probs[:, :, constraint_end:] = -math.inf
|
| 890 |
+
probs = probs[:, -1, :]
|
| 891 |
+
if self.models_size == 1:
|
| 892 |
+
return probs, attn
|
| 893 |
+
|
| 894 |
+
log_probs.append(probs)
|
| 895 |
+
if attn is not None:
|
| 896 |
+
if avg_attn is None:
|
| 897 |
+
avg_attn = attn
|
| 898 |
+
else:
|
| 899 |
+
avg_attn.add_(attn)
|
| 900 |
+
|
| 901 |
+
avg_probs = torch.logsumexp(torch.stack(log_probs, dim=0), dim=0) - math.log(
|
| 902 |
+
self.models_size
|
| 903 |
+
)
|
| 904 |
+
|
| 905 |
+
if avg_attn is not None:
|
| 906 |
+
avg_attn.div_(self.models_size)
|
| 907 |
+
return avg_probs, avg_attn
|
| 908 |
+
|
| 909 |
+
@torch.jit.export
|
| 910 |
+
def reorder_encoder_out(
|
| 911 |
+
self, encoder_outs: Optional[List[Dict[str, List[Tensor]]]], new_order
|
| 912 |
+
):
|
| 913 |
+
"""
|
| 914 |
+
Reorder encoder output according to *new_order*.
|
| 915 |
+
|
| 916 |
+
Args:
|
| 917 |
+
encoder_out: output from the ``forward()`` method
|
| 918 |
+
new_order (LongTensor): desired order
|
| 919 |
+
|
| 920 |
+
Returns:
|
| 921 |
+
*encoder_out* rearranged according to *new_order*
|
| 922 |
+
"""
|
| 923 |
+
new_outs: List[Dict[str, List[Tensor]]] = []
|
| 924 |
+
if not self.has_encoder():
|
| 925 |
+
return new_outs
|
| 926 |
+
for i, model in enumerate(self.models):
|
| 927 |
+
assert encoder_outs is not None
|
| 928 |
+
new_outs.append(
|
| 929 |
+
model.encoder.reorder_encoder_out(encoder_outs[i], new_order)
|
| 930 |
+
)
|
| 931 |
+
return new_outs
|
| 932 |
+
|
| 933 |
+
@torch.jit.export
|
| 934 |
+
def reorder_incremental_state(
|
| 935 |
+
self,
|
| 936 |
+
incremental_states: List[Dict[str, Dict[str, Optional[Tensor]]]],
|
| 937 |
+
new_order,
|
| 938 |
+
):
|
| 939 |
+
if not self.has_incremental_states():
|
| 940 |
+
return
|
| 941 |
+
for i, model in enumerate(self.models):
|
| 942 |
+
model.decoder.reorder_incremental_state_scripting(
|
| 943 |
+
incremental_states[i], new_order
|
| 944 |
+
)
|
| 945 |
+
|
| 946 |
+
|
| 947 |
+
class SequenceGeneratorWithAlignment(SequenceGenerator):
|
| 948 |
+
def __init__(
|
| 949 |
+
self, models, tgt_dict, left_pad_target=False, print_alignment="hard", **kwargs
|
| 950 |
+
):
|
| 951 |
+
"""Generates translations of a given source sentence.
|
| 952 |
+
|
| 953 |
+
Produces alignments following "Jointly Learning to Align and
|
| 954 |
+
Translate with Transformer Models" (Garg et al., EMNLP 2019).
|
| 955 |
+
|
| 956 |
+
Args:
|
| 957 |
+
left_pad_target (bool, optional): Whether or not the
|
| 958 |
+
hypothesis should be left padded or not when they are
|
| 959 |
+
teacher forced for generating alignments.
|
| 960 |
+
"""
|
| 961 |
+
super().__init__(EnsembleModelWithAlignment(models), tgt_dict, **kwargs)
|
| 962 |
+
self.left_pad_target = left_pad_target
|
| 963 |
+
|
| 964 |
+
if print_alignment == "hard":
|
| 965 |
+
self.extract_alignment = utils.extract_hard_alignment
|
| 966 |
+
elif print_alignment == "soft":
|
| 967 |
+
self.extract_alignment = utils.extract_soft_alignment
|
| 968 |
+
|
| 969 |
+
@torch.no_grad()
|
| 970 |
+
def generate(self, models, sample, **kwargs):
|
| 971 |
+
finalized = super()._generate(sample, **kwargs)
|
| 972 |
+
|
| 973 |
+
src_tokens = sample["net_input"]["src_tokens"]
|
| 974 |
+
bsz = src_tokens.shape[0]
|
| 975 |
+
beam_size = self.beam_size
|
| 976 |
+
(
|
| 977 |
+
src_tokens,
|
| 978 |
+
src_lengths,
|
| 979 |
+
prev_output_tokens,
|
| 980 |
+
tgt_tokens,
|
| 981 |
+
) = self._prepare_batch_for_alignment(sample, finalized)
|
| 982 |
+
if any(getattr(m, "full_context_alignment", False) for m in self.model.models):
|
| 983 |
+
attn = self.model.forward_align(src_tokens, src_lengths, prev_output_tokens)
|
| 984 |
+
else:
|
| 985 |
+
attn = [
|
| 986 |
+
finalized[i // beam_size][i % beam_size]["attention"].transpose(1, 0)
|
| 987 |
+
for i in range(bsz * beam_size)
|
| 988 |
+
]
|
| 989 |
+
|
| 990 |
+
if src_tokens.device != "cpu":
|
| 991 |
+
src_tokens = src_tokens.to("cpu")
|
| 992 |
+
tgt_tokens = tgt_tokens.to("cpu")
|
| 993 |
+
attn = [i.to("cpu") for i in attn]
|
| 994 |
+
|
| 995 |
+
# Process the attn matrix to extract hard alignments.
|
| 996 |
+
for i in range(bsz * beam_size):
|
| 997 |
+
alignment = self.extract_alignment(
|
| 998 |
+
attn[i], src_tokens[i], tgt_tokens[i], self.pad, self.eos
|
| 999 |
+
)
|
| 1000 |
+
finalized[i // beam_size][i % beam_size]["alignment"] = alignment
|
| 1001 |
+
return finalized
|
| 1002 |
+
|
| 1003 |
+
def _prepare_batch_for_alignment(self, sample, hypothesis):
|
| 1004 |
+
src_tokens = sample["net_input"]["src_tokens"]
|
| 1005 |
+
bsz = src_tokens.shape[0]
|
| 1006 |
+
src_tokens = (
|
| 1007 |
+
src_tokens[:, None, :]
|
| 1008 |
+
.expand(-1, self.beam_size, -1)
|
| 1009 |
+
.contiguous()
|
| 1010 |
+
.view(bsz * self.beam_size, -1)
|
| 1011 |
+
)
|
| 1012 |
+
src_lengths = sample["net_input"]["src_lengths"]
|
| 1013 |
+
src_lengths = (
|
| 1014 |
+
src_lengths[:, None]
|
| 1015 |
+
.expand(-1, self.beam_size)
|
| 1016 |
+
.contiguous()
|
| 1017 |
+
.view(bsz * self.beam_size)
|
| 1018 |
+
)
|
| 1019 |
+
prev_output_tokens = data_utils.collate_tokens(
|
| 1020 |
+
[beam["tokens"] for example in hypothesis for beam in example],
|
| 1021 |
+
self.pad,
|
| 1022 |
+
self.eos,
|
| 1023 |
+
self.left_pad_target,
|
| 1024 |
+
move_eos_to_beginning=True,
|
| 1025 |
+
)
|
| 1026 |
+
tgt_tokens = data_utils.collate_tokens(
|
| 1027 |
+
[beam["tokens"] for example in hypothesis for beam in example],
|
| 1028 |
+
self.pad,
|
| 1029 |
+
self.eos,
|
| 1030 |
+
self.left_pad_target,
|
| 1031 |
+
move_eos_to_beginning=False,
|
| 1032 |
+
)
|
| 1033 |
+
return src_tokens, src_lengths, prev_output_tokens, tgt_tokens
|
| 1034 |
+
|
| 1035 |
+
|
| 1036 |
+
class EnsembleModelWithAlignment(EnsembleModel):
|
| 1037 |
+
"""A wrapper around an ensemble of models."""
|
| 1038 |
+
|
| 1039 |
+
def __init__(self, models):
|
| 1040 |
+
super().__init__(models)
|
| 1041 |
+
|
| 1042 |
+
def forward_align(self, src_tokens, src_lengths, prev_output_tokens):
|
| 1043 |
+
avg_attn = None
|
| 1044 |
+
for model in self.models:
|
| 1045 |
+
decoder_out = model(src_tokens, src_lengths, prev_output_tokens)
|
| 1046 |
+
attn = decoder_out[1]["attn"][0]
|
| 1047 |
+
if avg_attn is None:
|
| 1048 |
+
avg_attn = attn
|
| 1049 |
+
else:
|
| 1050 |
+
avg_attn.add_(attn)
|
| 1051 |
+
if len(self.models) > 1:
|
| 1052 |
+
avg_attn.div_(len(self.models))
|
| 1053 |
+
return avg_attn
|
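A note on the ensemble averaging in `EnsembleModel.forward_decoder` above: with more than one model it combines per-model log-probabilities via `logsumexp` minus `log(models_size)`, which is exactly the log of the arithmetic mean of the probabilities. A minimal, self-contained sketch of that identity (the tensors here are toy values, not part of the repo):

import math
import torch

# Two toy per-model log-probability rows over a 4-token vocabulary.
log_probs = [
    torch.log(torch.tensor([[0.1, 0.2, 0.3, 0.4]])),
    torch.log(torch.tensor([[0.4, 0.3, 0.2, 0.1]])),
]

# What forward_decoder computes when models_size > 1:
avg_log_probs = torch.logsumexp(torch.stack(log_probs, dim=0), dim=0) - math.log(len(log_probs))

# Reference: log of the mean probability, computed directly.
mean_probs = torch.stack([lp.exp() for lp in log_probs], dim=0).mean(dim=0)
assert torch.allclose(avg_log_probs, mean_probs.log())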
notebooks/caption_infer.ipynb
ADDED
The diff for this file is too large to render. See raw diff.
ofa_module/__init__.py
ADDED
@@ -0,0 +1,5 @@
import data
import models
import tasks
import criterions
import utils
run_scripts/caption/coco_eval.py
ADDED
@@ -0,0 +1,42 @@
import json
import sys
import os.path as op

from pycocotools.coco import COCO
from pycocoevalcap.eval import COCOEvalCap


def evaluate_on_coco_caption(res_file, label_file, outfile=None):
    """
    res_file: txt file, each row is [image_key, json format list of captions].
        Each caption is a dict, with fields "caption", "conf".
    label_file: JSON file of ground truth captions in COCO format.
    """
    coco = COCO(label_file)
    cocoRes = coco.loadRes(res_file)
    cocoEval = COCOEvalCap(coco, cocoRes)

    # evaluate on a subset of images by setting
    # cocoEval.params['image_id'] = cocoRes.getImgIds()
    # please remove this line when evaluating the full validation set
    cocoEval.params['image_id'] = cocoRes.getImgIds()

    # evaluate results
    # SPICE will take a few minutes the first time, but speeds up due to caching
    cocoEval.evaluate()
    result = cocoEval.eval
    if not outfile:
        print(result)
    else:
        with open(outfile, 'w') as fp:
            json.dump(result, fp, indent=4)
    return result


if __name__ == "__main__":
    if len(sys.argv) == 3:
        evaluate_on_coco_caption(sys.argv[1], sys.argv[2])
    elif len(sys.argv) == 4:
        evaluate_on_coco_caption(sys.argv[1], sys.argv[2], sys.argv[3])
    else:
        raise NotImplementedError
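A hedged usage sketch for the helper above. The paths mirror those used by evaluate_caption.sh further down but are placeholders for wherever your predictions and COCO-format references live; pycocotools and pycocoevalcap must be installed:

# Hypothetical paths -- adjust to your own result/annotation files.
from coco_eval import evaluate_on_coco_caption

result = evaluate_on_coco_caption(
    "results/caption/test_predict.json",                    # predictions loadable by COCO.loadRes
    "dataset/caption_data/test_caption_coco_format.json",   # ground-truth captions, COCO format
    outfile="results/caption/metrics.json",                 # optional: dump the metric dict as JSON
)
print(result.get("CIDEr"))  # inspect a single metric from the returned dict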
run_scripts/caption/evaluate_caption.sh
ADDED
@@ -0,0 +1,29 @@
#!/usr/bin/env bash

user_dir=../../ofa_module
bpe_dir=../../utils/BPE

data=../../dataset/caption_data/caption_test.tsv
path=../../checkpoints/caption_large_best_clean.pt
result_path=../../results/caption
selected_cols=1,4,2
split='test'

CUDA_VISIBLE_DEVICES=4,5,6,7 python3 ../../evaluate.py \
    ${data} \
    --path=${path} \
    --user-dir=${user_dir} \
    --task=caption \
    --batch-size=16 \
    --log-format=simple --log-interval=10 \
    --seed=7 \
    --gen-subset=${split} \
    --results-path=${result_path} \
    --beam=5 \
    --max-len-b=16 \
    --no-repeat-ngram-size=3 \
    --fp16 \
    --num-workers=0 \
    --model-overrides="{\"data\":\"${data}\",\"bpe_dir\":\"${bpe_dir}\",\"eval_cider\":False,\"selected_cols\":\"${selected_cols}\"}"

python coco_eval.py ../../results/caption/test_predict.json ../../dataset/caption_data/test_caption_coco_format.json
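A quick sanity check of the --model-overrides value built by the script above, assuming that string is parsed as a Python-literal dict (fairseq-style evaluation entry points typically do this with ast.literal_eval); the concrete values are just the ones the script would interpolate:

import ast

overrides = '{"data":"../../dataset/caption_data/caption_test.tsv","bpe_dir":"../../utils/BPE","eval_cider":False,"selected_cols":"1,4,2"}'
parsed = ast.literal_eval(overrides)          # must be a valid Python literal, not JSON
assert isinstance(parsed, dict) and parsed["eval_cider"] is False
print(parsed["selected_cols"])                # -> 1,4,2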
run_scripts/caption/train_caption_stage1.sh
ADDED
@@ -0,0 +1,104 @@
#!/usr/bin/env bash

log_dir=./stage1_logs
save_dir=./stage1_checkpoints
mkdir -p $log_dir $save_dir

bpe_dir=../../utils/BPE
user_dir=../../ofa_module

data_dir=../../dataset/caption_data
data=${data_dir}/caption_stage1_train.tsv,${data_dir}/caption_val.tsv
restore_file=../../checkpoints/ofa_large.pt
selected_cols=0,4,2

task=caption
arch=ofa_large
criterion=ajust_label_smoothed_cross_entropy
label_smoothing=0.1
lr=1e-5
max_epoch=5
warmup_ratio=0.06
batch_size=8
update_freq=4
resnet_drop_path_rate=0.0
encoder_drop_path_rate=0.1
decoder_drop_path_rate=0.1
dropout=0.1
attention_dropout=0.0
max_src_length=80
max_tgt_length=20
num_bins=1000
patch_image_size=480
eval_cider_cached=${data_dir}/cider_cached_tokens/coco-valid-words.p
drop_worst_ratio=0.2

for max_epoch in {2,}; do
  echo "max_epoch "${max_epoch}
  for warmup_ratio in {0.06,}; do
    echo "warmup_ratio "${warmup_ratio}
    for drop_worst_after in {2500,}; do
      echo "drop_worst_after "${drop_worst_after}

      log_file=${log_dir}/${max_epoch}"_"${warmup_ratio}"_"${drop_worst_after}".log"
      save_path=${save_dir}/${max_epoch}"_"${warmup_ratio}"_"${drop_worst_after}
      mkdir -p $save_path

      CUDA_VISIBLE_DEVICES=0,1,2,3 python3 ../../train.py \
          $data \
          --selected-cols=${selected_cols} \
          --bpe-dir=${bpe_dir} \
          --user-dir=${user_dir} \
          --restore-file=${restore_file} \
          --reset-optimizer --reset-dataloader --reset-meters \
          --save-dir=${save_path} \
          --task=${task} \
          --arch=${arch} \
          --criterion=${criterion} \
          --label-smoothing=${label_smoothing} \
          --batch-size=${batch_size} \
          --update-freq=${update_freq} \
          --encoder-normalize-before \
          --decoder-normalize-before \
          --share-decoder-input-output-embed \
          --share-all-embeddings \
          --layernorm-embedding \
          --patch-layernorm-embedding \
          --code-layernorm-embedding \
          --resnet-drop-path-rate=${resnet_drop_path_rate} \
          --encoder-drop-path-rate=${encoder_drop_path_rate} \
          --decoder-drop-path-rate=${decoder_drop_path_rate} \
          --dropout=${dropout} \
          --attention-dropout=${attention_dropout} \
          --weight-decay=0.01 --optimizer=adam --adam-betas="(0.9,0.999)" --adam-eps=1e-08 --clip-norm=1.0 \
          --lr-scheduler=polynomial_decay --lr=${lr} \
          --max-epoch=${max_epoch} --warmup-ratio=${warmup_ratio} \
          --log-format=simple --log-interval=10 \
          --fixed-validation-seed=7 \
          --no-epoch-checkpoints --keep-best-checkpoints=1 \
          --save-interval=1 --validate-interval=1 \
          --save-interval-updates=500 --validate-interval-updates=500 \
          --eval-cider \
          --eval-cider-cached-tokens=${eval_cider_cached} \
          --eval-args='{"beam":5,"max_len_b":16,"no_repeat_ngram_size":3}' \
          --best-checkpoint-metric=cider --maximize-best-checkpoint-metric \
          --max-src-length=${max_src_length} \
          --max-tgt-length=${max_tgt_length} \
          --find-unused-parameters \
          --freeze-encoder-embedding \
          --freeze-decoder-embedding \
          --add-type-embedding \
          --scale-attn \
          --scale-fc \
          --scale-heads \
          --disable-entangle \
          --num-bins=${num_bins} \
          --patch-image-size=${patch_image_size} \
          --drop-worst-ratio=${drop_worst_ratio} \
          --drop-worst-after=${drop_worst_after} \
          --fp16 \
          --fp16-scale-window=512 \
          --num-workers=0 >> ${log_file} 2>&1
    done
  done
done
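The --drop-worst-ratio / --drop-worst-after pair used above drops the highest-loss fraction of samples in a batch once training has passed a given update count. The repo implements this inside its ajust_label_smoothed_cross_entropy criterion; the snippet below is only a generic sketch of the idea, not the repo's code, with made-up loss values:

import torch

def drop_worst(per_sample_loss: torch.Tensor, drop_worst_ratio: float,
               update_num: int, drop_worst_after: int) -> torch.Tensor:
    """Keep only the (1 - ratio) lowest-loss samples once update_num > drop_worst_after."""
    if drop_worst_ratio > 0 and update_num > drop_worst_after:
        keep = int(per_sample_loss.size(0) * (1 - drop_worst_ratio))
        per_sample_loss, _ = torch.topk(per_sample_loss, k=keep, largest=False)
    return per_sample_loss

losses = torch.tensor([0.5, 2.0, 0.7, 9.0, 1.1])
print(drop_worst(losses, drop_worst_ratio=0.2, update_num=3000, drop_worst_after=2500))
# -> the four smallest losses; the 9.0 outlier is dropped from the objective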
run_scripts/caption/train_caption_stage2.sh
ADDED
@@ -0,0 +1,101 @@
#!/usr/bin/env bash

log_dir=./stage2_logs
save_dir=./stage2_checkpoints
mkdir -p $log_dir $save_dir

bpe_dir=../../utils/BPE
user_dir=../../ofa_module

data_dir=../../dataset/caption_data
data=${data_dir}/caption_stage2_train.tsv,${data_dir}/caption_val.tsv
restore_file=../../checkpoints/caption_stage1_best.pt
selected_cols=1,4,2

task=caption
arch=ofa_large
criterion=scst_reward_criterion
label_smoothing=0.1
lr=1e-5
max_epoch=5
warmup_ratio=0.06
batch_size=2
update_freq=4
resnet_drop_path_rate=0.0
encoder_drop_path_rate=0.0
decoder_drop_path_rate=0.0
dropout=0.0
attention_dropout=0.0
max_src_length=80
max_tgt_length=20
num_bins=1000
patch_image_size=480
eval_cider_cached=${data_dir}/cider_cached_tokens/coco-valid-words.p
scst_cider_cached=${data_dir}/cider_cached_tokens/coco-train-words.p

for lr in {1e-5,}; do
  echo "lr "${lr}
  for max_epoch in {4,}; do
    echo "max_epoch "${max_epoch}

    log_file=${log_dir}/${lr}"_"${max_epoch}".log"
    save_path=${save_dir}/${lr}"_"${max_epoch}
    mkdir -p $save_path

    CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python3 ../../train.py \
        $data \
        --selected-cols=${selected_cols} \
        --bpe-dir=${bpe_dir} \
        --user-dir=${user_dir} \
        --restore-file=${restore_file} \
        --reset-optimizer --reset-dataloader --reset-meters \
        --save-dir=${save_path} \
        --task=${task} \
        --arch=${arch} \
        --criterion=${criterion} \
        --batch-size=${batch_size} \
        --update-freq=${update_freq} \
        --encoder-normalize-before \
        --decoder-normalize-before \
        --share-decoder-input-output-embed \
        --share-all-embeddings \
        --layernorm-embedding \
        --patch-layernorm-embedding \
        --code-layernorm-embedding \
        --resnet-drop-path-rate=${resnet_drop_path_rate} \
        --encoder-drop-path-rate=${encoder_drop_path_rate} \
        --decoder-drop-path-rate=${decoder_drop_path_rate} \
        --dropout=${dropout} \
        --attention-dropout=${attention_dropout} \
        --weight-decay=0.01 --optimizer=adam --adam-betas="(0.9,0.999)" --adam-eps=1e-08 --clip-norm=1.0 \
        --lr-scheduler=polynomial_decay --lr=${lr} \
        --max-epoch=${max_epoch} --warmup-ratio=${warmup_ratio} \
        --log-format=simple --log-interval=10 \
        --fixed-validation-seed=7 \
        --no-epoch-checkpoints --keep-best-checkpoints=1 \
        --save-interval=1 --validate-interval=1 \
        --save-interval-updates=500 --validate-interval-updates=500 \
        --eval-cider \
        --eval-cider-cached-tokens=${eval_cider_cached} \
        --eval-args='{"beam":5,"max_len_b":16,"no_repeat_ngram_size":3}' \
        --best-checkpoint-metric=cider --maximize-best-checkpoint-metric \
        --max-src-length=${max_src_length} \
        --max-tgt-length=${max_tgt_length} \
        --find-unused-parameters \
        --freeze-encoder-embedding \
        --freeze-decoder-embedding \
        --add-type-embedding \
        --scale-attn \
        --scale-fc \
        --scale-heads \
        --disable-entangle \
        --num-bins=${num_bins} \
        --patch-image-size=${patch_image_size} \
        --scst \
        --scst-cider-cached-tokens=${scst_cider_cached} \
        --scst-args='{"beam":5,"max_len_b":16,"no_repeat_ngram_size":3}' \
        --memory-efficient-fp16 \
        --fp16-scale-window=512 \
        --num-workers=0 >> ${log_file} 2>&1
  done
done
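Stage 2 switches the criterion to scst_reward_criterion (self-critical sequence training), where generated captions are rewarded with CIDEr relative to a baseline decode and the reward difference weights the sequence log-probability. The repo's actual loss lives in criterions/scst_loss.py; the following is only an illustrative sketch of the SCST objective with toy numbers, not that implementation:

import torch

def scst_loss_sketch(sample_logprob: torch.Tensor, sample_reward: torch.Tensor,
                     baseline_reward: torch.Tensor) -> torch.Tensor:
    """REINFORCE with a self-critical baseline: push up sequences that beat the baseline."""
    advantage = sample_reward - baseline_reward   # e.g. CIDEr(sampled) - CIDEr(baseline decode)
    return -(advantage * sample_logprob).mean()

# Toy values: two sampled captions, their total log-probs and CIDEr rewards.
logp = torch.tensor([-12.3, -9.8], requires_grad=True)
loss = scst_loss_sketch(logp, torch.tensor([1.10, 0.70]), torch.tensor([0.95, 0.95]))
loss.backward()
print(loss.item(), logp.grad)  # the caption with positive advantage is pushed toward higher log-prob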
tasks/__init__.py
ADDED
@@ -0,0 +1,2 @@
from .mm_tasks import *
from .ofa_task import OFATask
tasks/mm_tasks/__init__.py
ADDED
@@ -0,0 +1 @@
from .caption import CaptionTask
tasks/mm_tasks/caption.py
ADDED
@@ -0,0 +1,249 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass, field
import json
import logging
from typing import Optional
from argparse import Namespace
from itertools import zip_longest
from collections import OrderedDict

import numpy as np
import sacrebleu
import string
from fairseq import metrics, utils
from fairseq.tasks import register_task

from tasks.ofa_task import OFATask, OFAConfig
from data.mm_data.caption_dataset import CaptionDataset
from data.file_dataset import FileDataset
from utils.cider.pyciderevalcap.ciderD.ciderD import CiderD

EVAL_BLEU_ORDER = 4

logger = logging.getLogger(__name__)


@dataclass
class CaptionConfig(OFAConfig):
    eval_bleu: bool = field(
        default=False, metadata={"help": "evaluation with BLEU scores"}
    )
    eval_cider: bool = field(
        default=False, metadata={"help": "evaluation with CIDEr scores"}
    )
    eval_args: Optional[str] = field(
        default='{}',
        metadata={
            "help": 'generation args for BLEU or CIDEr scoring, e.g., \'{"beam": 4, "lenpen": 0.6}\', as JSON string'
        },
    )
    eval_print_samples: bool = field(
        default=False, metadata={"help": "print sample generations during validation"}
    )
    eval_cider_cached_tokens: Optional[str] = field(
        default=None,
        metadata={"help": "path to cached cPickle file used to calculate CIDEr scores"},
    )

    scst: bool = field(
        default=False, metadata={"help": "Self-critical sequence training"}
    )
    scst_args: str = field(
        default='{}',
        metadata={
            "help": 'generation args for Self-critical sequence training, as JSON string'
        },
    )


@register_task("caption", dataclass=CaptionConfig)
class CaptionTask(OFATask):
    def __init__(self, cfg: CaptionConfig, src_dict, tgt_dict):
        super().__init__(cfg, src_dict, tgt_dict)

    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        paths = self.cfg.data.split(',')
        assert len(paths) > 0

        if split == 'train':
            file_path = paths[(epoch - 1) % (len(paths) - 1)]
        else:
            file_path = paths[-1]
        dataset = FileDataset(file_path, self.cfg.selected_cols)

        self.datasets[split] = CaptionDataset(
            split,
            dataset,
            self.bpe,
            self.src_dict,
            self.tgt_dict,
            max_src_length=self.cfg.max_src_length,
            max_tgt_length=self.cfg.max_tgt_length,
            patch_image_size=self.cfg.patch_image_size,
            imagenet_default_mean_and_std=self.cfg.imagenet_default_mean_and_std,
            scst=getattr(self.cfg, 'scst', False)
        )

    def build_model(self, cfg):
        model = super().build_model(cfg)
        if self.cfg.eval_bleu or self.cfg.eval_cider:
            gen_args = json.loads(self.cfg.eval_args)
            self.sequence_generator = self.build_generator(
                [model], Namespace(**gen_args)
            )
            if self.cfg.eval_cider:
                self.CiderD_scorer = CiderD(df=self.cfg.eval_cider_cached_tokens)
        if self.cfg.scst:
            scst_args = json.loads(self.cfg.scst_args)
            self.scst_generator = self.build_generator(
                [model], Namespace(**scst_args)
            )

        return model

    def _calculate_cider_scores(self, gen_res, gt_res):
        '''
        gen_res: generated captions, list of str
        gt_idx: list of int, of the same length as gen_res
        gt_res: ground truth captions, list of list of str.
            gen_res[i] corresponds to gt_res[gt_idx[i]]
            Each image can have multiple ground truth captions
        '''
        gen_res_size = len(gen_res)

        res = OrderedDict()
        for i in range(gen_res_size):
            res[i] = [gen_res[i].strip()]

        gts = OrderedDict()
        gt_res_ = [
            [gt_res[i][j].strip() for j in range(len(gt_res[i]))]
            for i in range(len(gt_res))
        ]
        for i in range(gen_res_size):
            gts[i] = gt_res_[i]

        res_ = [{'image_id': i, 'caption': res[i]} for i in range(len(res))]
        _, scores = self.CiderD_scorer.compute_score(gts, res_)
        return scores

    def valid_step(self, sample, model, criterion):
        loss, sample_size, logging_output = criterion(model, sample)

        model.eval()
        if self.cfg.eval_bleu or self.cfg.eval_cider:
            hyps, refs = self._inference(self.sequence_generator, sample, model)
            if self.cfg.eval_bleu:
                if self.cfg.eval_tokenized_bleu:
                    bleu = sacrebleu.corpus_bleu(hyps, list(zip_longest(*refs)), tokenize="none")
                else:
                    bleu = sacrebleu.corpus_bleu(hyps, list(zip_longest(*refs)))
                logging_output["_bleu_sys_len"] = bleu.sys_len
                logging_output["_bleu_ref_len"] = bleu.ref_len
                # we split counts into separate entries so that they can be
                # summed efficiently across workers using fast-stat-sync
                assert len(bleu.counts) == EVAL_BLEU_ORDER
                for i in range(EVAL_BLEU_ORDER):
                    logging_output["_bleu_counts_" + str(i)] = bleu.counts[i]
                    logging_output["_bleu_totals_" + str(i)] = bleu.totals[i]
            if self.cfg.eval_cider:
                scores = self._calculate_cider_scores(hyps, refs)
                logging_output["_cider_score_sum"] = scores.sum()
                logging_output["_cider_cnt"] = scores.size

        return loss, sample_size, logging_output

    def reduce_metrics(self, logging_outputs, criterion):
        super().reduce_metrics(logging_outputs, criterion)

        def sum_logs(key):
            import torch
            result = sum(log.get(key, 0) for log in logging_outputs)
            if torch.is_tensor(result):
                result = result.cpu()
            return result

        if self.cfg.eval_bleu:
            counts, totals = [], []
            for i in range(EVAL_BLEU_ORDER):
                counts.append(sum_logs("_bleu_counts_" + str(i)))
                totals.append(sum_logs("_bleu_totals_" + str(i)))

            if max(totals) > 0:
                # log counts as numpy arrays -- log_scalar will sum them correctly
                metrics.log_scalar("_bleu_counts", np.array(counts))
                metrics.log_scalar("_bleu_totals", np.array(totals))
                metrics.log_scalar("_bleu_sys_len", sum_logs("_bleu_sys_len"))
                metrics.log_scalar("_bleu_ref_len", sum_logs("_bleu_ref_len"))

                def compute_bleu(meters):
                    import inspect
                    import sacrebleu

                    fn_sig = inspect.getfullargspec(sacrebleu.compute_bleu)[0]
                    if "smooth_method" in fn_sig:
                        smooth = {"smooth_method": "exp"}
                    else:
                        smooth = {"smooth": "exp"}
                    bleu = sacrebleu.compute_bleu(
                        correct=meters["_bleu_counts"].sum,
                        total=meters["_bleu_totals"].sum,
                        sys_len=meters["_bleu_sys_len"].sum,
                        ref_len=meters["_bleu_ref_len"].sum,
                        **smooth
                    )
                    return round(bleu.score, 2)

                metrics.log_derived("bleu", compute_bleu)

        if self.cfg.eval_cider:
            def compute_cider(meters):
                cider = meters["_cider_score_sum"].sum / meters["_cider_cnt"].sum
                cider = cider if isinstance(cider, float) else cider.item()
                return round(cider, 3)

            if sum_logs("_cider_cnt") > 0:
                metrics.log_scalar("_cider_score_sum", sum_logs("_cider_score_sum"))
                metrics.log_scalar("_cider_cnt", sum_logs("_cider_cnt"))
                metrics.log_derived("cider", compute_cider)

    def _inference(self, generator, sample, model):

        def decode(toks, escape_unk=False):
            s = self.tgt_dict.string(
                toks.int().cpu(),
                # The default unknown string in fairseq is `<unk>`, but
                # this is tokenized by sacrebleu as `< unk >`, inflating
                # BLEU scores. Instead, we use a somewhat more verbose
                # alternative that is unlikely to appear in the real
                # reference, but doesn't get split into multiple tokens.
                unk_string=("UNKNOWNTOKENINREF" if escape_unk else "UNKNOWNTOKENINHYP"),
            )
            if self.bpe:
                s = self.bpe.decode(s)
            return s

        gen_out = self.inference_step(generator, [model], sample)
        hyps, refs = [], []
        transtab = str.maketrans({key: None for key in string.punctuation})
        for i in range(len(gen_out)):
            decode_tokens = decode(gen_out[i][0]["tokens"])
            hyps.append(decode_tokens.translate(transtab).strip())
            refs.append(
                [
                    sent.translate(transtab).strip()
                    for sent in decode(
                        utils.strip_pad(sample["target"][i], self.tgt_dict.pad()),
                        escape_unk=True,  # don't count <unk> as matches to the hypo
                    ).split('&&')
                ]
            )
        if self.cfg.eval_print_samples:
            logger.info("example hypothesis: " + hyps[0])
            logger.info("example reference: " + ' && '.join(refs[0]))

        return hyps, refs
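For reference, _inference above strips punctuation with a str.maketrans table and splits multi-reference targets on "&&", so _calculate_cider_scores receives one hypothesis string per image and a list of reference strings per image. A tiny standalone sketch of that data shape (the caption strings are invented for illustration):

import string

transtab = str.maketrans({key: None for key in string.punctuation})

decoded_hyp = "A dog runs on the beach."
decoded_refs = "a dog running on a beach && a brown dog on the sand"

hyp = decoded_hyp.translate(transtab).strip()
refs = [sent.translate(transtab).strip() for sent in decoded_refs.split('&&')]

print(hyp)   # 'A dog runs on the beach'
print(refs)  # ['a dog running on a beach', 'a brown dog on the sand']
# lists of hyp / refs built this way are what the CiderD scorer consumes downstream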
tasks/ofa_task.py
ADDED
|
@@ -0,0 +1,338 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
from dataclasses import dataclass, field
|
| 7 |
+
import logging
|
| 8 |
+
import os
|
| 9 |
+
import math
|
| 10 |
+
import torch
|
| 11 |
+
from typing import Dict, Optional
|
| 12 |
+
|
| 13 |
+
from fairseq import search
|
| 14 |
+
from fairseq.data import FairseqDataset, iterators
|
| 15 |
+
from fairseq.optim.amp_optimizer import AMPOptimizer
|
| 16 |
+
from fairseq.dataclass import FairseqDataclass
|
| 17 |
+
from fairseq.tasks import FairseqTask, register_task
|
| 18 |
+
from omegaconf import DictConfig
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
logger = logging.getLogger(__name__)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
@dataclass
|
| 25 |
+
class OFAConfig(FairseqDataclass):
|
| 26 |
+
data: Optional[str] = field(
|
| 27 |
+
default=None,
|
| 28 |
+
metadata={
|
| 29 |
+
"help": "colon separated path to data directories list, will be iterated upon during epochs "
|
| 30 |
+
"in round-robin manner; however, valid and test data are always in the first directory "
|
| 31 |
+
"to avoid the need for repeating them in all directories"
|
| 32 |
+
},
|
| 33 |
+
)
|
| 34 |
+
selected_cols: Optional[str] = field(
|
| 35 |
+
default=None,
|
| 36 |
+
metadata={"help": "selected cols"},
|
| 37 |
+
)
|
| 38 |
+
bpe_dir: Optional[str] = field(
|
| 39 |
+
default=None,
|
| 40 |
+
metadata={"help": "bpe dir"},
|
| 41 |
+
)
|
| 42 |
+
max_source_positions: int = field(
|
| 43 |
+
default=1024, metadata={"help": "max number of tokens in the source sequence"}
|
| 44 |
+
)
|
| 45 |
+
max_target_positions: int = field(
|
| 46 |
+
default=1024, metadata={"help": "max number of tokens in the target sequence"}
|
| 47 |
+
)
|
| 48 |
+
max_src_length: int = field(
|
| 49 |
+
default=128, metadata={"help": "the maximum src sequence length"}
|
| 50 |
+
)
|
| 51 |
+
max_tgt_length: int = field(
|
| 52 |
+
default=30, metadata={"help": "the maximum target sequence length"}
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
code_dict_size: int = field(
|
| 56 |
+
default=8192, metadata={"help": "code dict size"}
|
| 57 |
+
)
|
| 58 |
+
patch_image_size: int = field(
|
| 59 |
+
default=480, metadata={"help": "patch image size"}
|
| 60 |
+
)
|
| 61 |
+
num_bins: int = field(
|
| 62 |
+
default=1000, metadata={"help": "number of quantization bins"}
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
imagenet_default_mean_and_std: bool = field(
|
| 66 |
+
default=False,
|
| 67 |
+
metadata={"help": "imagenet normalize"},
|
| 68 |
+
)
|
| 69 |
+
constraint_range: Optional[str] = field(
|
| 70 |
+
default=None,
|
| 71 |
+
metadata={"help": "constraint range"}
|
| 72 |
+
)
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
@register_task("ofa", dataclass=OFAConfig)
|
| 76 |
+
class OFATask(FairseqTask):
|
| 77 |
+
def __init__(self, cfg: OFAConfig, src_dict, tgt_dict):
|
| 78 |
+
super().__init__(cfg)
|
| 79 |
+
self.src_dict = src_dict
|
| 80 |
+
self.tgt_dict = tgt_dict
|
| 81 |
+
|
| 82 |
+
@classmethod
|
| 83 |
+
def setup_task(cls, cfg: DictConfig, **kwargs):
|
| 84 |
+
"""Setup the task."""
|
| 85 |
+
|
| 86 |
+
# load dictionaries
|
| 87 |
+
src_dict = cls.load_dictionary(
|
| 88 |
+
os.path.join(cfg.bpe_dir, "dict.txt")
|
| 89 |
+
)
|
| 90 |
+
tgt_dict = cls.load_dictionary(
|
| 91 |
+
os.path.join(cfg.bpe_dir, "dict.txt")
|
| 92 |
+
)
|
| 93 |
+
src_dict.add_symbol("<mask>")
|
| 94 |
+
tgt_dict.add_symbol("<mask>")
|
| 95 |
+
for i in range(cfg.code_dict_size):
|
| 96 |
+
src_dict.add_symbol("<code_{}>".format(i))
|
| 97 |
+
tgt_dict.add_symbol("<code_{}>".format(i))
|
| 98 |
+
# quantization
|
| 99 |
+
for i in range(cfg.num_bins):
|
| 100 |
+
src_dict.add_symbol("<bin_{}>".format(i))
|
| 101 |
+
tgt_dict.add_symbol("<bin_{}>".format(i))
|
| 102 |
+
|
| 103 |
+
logger.info("source dictionary: {} types".format(len(src_dict)))
|
| 104 |
+
logger.info("target dictionary: {} types".format(len(tgt_dict)))
|
| 105 |
+
return cls(cfg, src_dict, tgt_dict)
|
| 106 |
+
|
| 107 |
+
def get_batch_iterator(
|
| 108 |
+
self,
|
| 109 |
+
dataset,
|
| 110 |
+
max_tokens=None,
|
| 111 |
+
max_sentences=None,
|
| 112 |
+
max_positions=None,
|
| 113 |
+
ignore_invalid_inputs=False,
|
| 114 |
+
required_batch_size_multiple=1,
|
| 115 |
+
seed=1,
|
| 116 |
+
num_shards=1,
|
| 117 |
+
shard_id=0,
|
| 118 |
+
num_workers=0,
|
| 119 |
+
epoch=1,
|
| 120 |
+
data_buffer_size=0,
|
| 121 |
+
disable_iterator_cache=False,
|
| 122 |
+
):
|
| 123 |
+
assert isinstance(dataset, FairseqDataset)
|
| 124 |
+
|
| 125 |
+
# initialize the dataset with the correct starting epoch
|
| 126 |
+
dataset.set_epoch(epoch)
|
| 127 |
+
|
| 128 |
+
# create mini-batches with given size constraints
|
| 129 |
+
batch_sampler = [
|
| 130 |
+
[j for j in range(i, min(i + max_sentences, len(dataset)))]
|
| 131 |
+
for i in range(0, len(dataset), max_sentences)
|
| 132 |
+
]
|
| 133 |
+
total_row_count = dataset.dataset.get_total_row_count()
|
| 134 |
+
num_batches = math.ceil(math.ceil(total_row_count / num_shards) / max_sentences)
|
| 135 |
+
if len(batch_sampler) < num_batches:
|
| 136 |
+
batch_sampler.append([])
|
| 137 |
+
|
| 138 |
+
# return a reusable, sharded iterator
|
| 139 |
+
epoch_iter = iterators.EpochBatchIterator(
|
| 140 |
+
dataset=dataset,
|
| 141 |
+
collate_fn=dataset.collater,
|
| 142 |
+
batch_sampler=batch_sampler,
|
| 143 |
+
seed=seed,
|
| 144 |
+
num_shards=1,
|
| 145 |
+
shard_id=0,
|
| 146 |
+
num_workers=num_workers,
|
| 147 |
+
epoch=epoch,
|
| 148 |
+
buffer_size=data_buffer_size
|
| 149 |
+
)
|
| 150 |
+
|
| 151 |
+
return epoch_iter
|
| 152 |
+
|
| 153 |
+
def build_model(self, cfg: FairseqDataclass):
|
| 154 |
+
model = super().build_model(cfg)
|
| 155 |
+
bpe_dict = {
|
| 156 |
+
"_name": "gpt2",
|
| 157 |
+
"gpt2_encoder_json": os.path.join(self.cfg.bpe_dir, "encoder.json"),
|
| 158 |
+
"gpt2_vocab_bpe": os.path.join(self.cfg.bpe_dir, "vocab.bpe")
|
| 159 |
+
}
|
| 160 |
+
bpe_dict = DictConfig(bpe_dict)
|
| 161 |
+
self.bpe = self.build_bpe(bpe_dict)
|
| 162 |
+
return model
|
| 163 |
+
|
| 164 |
+
def build_generator(
|
| 165 |
+
self, models, args, seq_gen_cls=None, extra_gen_cls_kwargs=None, prefix_allowed_tokens_fn=None,
|
| 166 |
+
):
|
| 167 |
+
"""
|
| 168 |
+
Build a :class:`~fairseq.SequenceGenerator` instance for this
|
| 169 |
+
task.
|
| 170 |
+
|
| 171 |
+
Args:
|
| 172 |
+
models (List[~fairseq.models.FairseqModel]): ensemble of models
|
| 173 |
+
args (fairseq.dataclass.configs.GenerationConfig):
|
| 174 |
+
configuration object (dataclass) for generation
|
| 175 |
+
extra_gen_cls_kwargs (Dict[str, Any]): extra options to pass
|
| 176 |
+
through to SequenceGenerator
|
| 177 |
+
prefix_allowed_tokens_fn (Callable[[int, torch.Tensor], List[int]]):
|
| 178 |
+
If provided, this function constrains the beam search to
|
| 179 |
+
allowed tokens only at each step. The provided function
|
| 180 |
+
should take 2 arguments: the batch ID (`batch_id: int`)
|
| 181 |
+
and a unidimensional tensor of token ids (`inputs_ids:
|
| 182 |
+
torch.Tensor`). It has to return a `List[int]` with the
|
| 183 |
+
allowed tokens for the next generation step conditioned
|
| 184 |
+
on the previously generated tokens (`inputs_ids`) and
|
| 185 |
+
the batch ID (`batch_id`). This argument is useful for
|
| 186 |
+
constrained generation conditioned on the prefix, as
|
| 187 |
+
described in "Autoregressive Entity Retrieval"
|
| 188 |
+
(https://arxiv.org/abs/2010.00904) and
|
| 189 |
+
https://github.com/facebookresearch/GENRE.
        """
        if getattr(args, "score_reference", False):
            from fairseq.sequence_scorer import SequenceScorer

            return SequenceScorer(
                self.target_dictionary,
                compute_alignment=getattr(args, "print_alignment", False),
            )

        from fairseq.sequence_generator import (
            # SequenceGenerator,
            SequenceGeneratorWithAlignment,
        )
        from models.sequence_generator import SequenceGenerator

        # Choose search strategy. Defaults to Beam Search.
        sampling = getattr(args, "sampling", False)
        sampling_topk = getattr(args, "sampling_topk", -1)
        sampling_topp = getattr(args, "sampling_topp", -1.0)
        diverse_beam_groups = getattr(args, "diverse_beam_groups", -1)
        diverse_beam_strength = getattr(args, "diverse_beam_strength", 0.5)
        match_source_len = getattr(args, "match_source_len", False)
        diversity_rate = getattr(args, "diversity_rate", -1)
        constrained = getattr(args, "constraints", False)
        if prefix_allowed_tokens_fn is None:
            prefix_allowed_tokens_fn = getattr(args, "prefix_allowed_tokens_fn", None)
        if (
            sum(
                int(cond)
                for cond in [
                    sampling,
                    diverse_beam_groups > 0,
                    match_source_len,
                    diversity_rate > 0,
                ]
            )
            > 1
        ):
            raise ValueError("Provided Search parameters are mutually exclusive.")
        assert sampling_topk < 0 or sampling, "--sampling-topk requires --sampling"
        assert sampling_topp < 0 or sampling, "--sampling-topp requires --sampling"

        if sampling:
            search_strategy = search.Sampling(
                self.target_dictionary, sampling_topk, sampling_topp
            )
        elif diverse_beam_groups > 0:
            search_strategy = search.DiverseBeamSearch(
                self.target_dictionary, diverse_beam_groups, diverse_beam_strength
            )
        elif match_source_len:
            # this is useful for tagging applications where the output
            # length should match the input length, so we hardcode the
            # length constraints for simplicity
            search_strategy = search.LengthConstrainedBeamSearch(
                self.target_dictionary,
                min_len_a=1,
                min_len_b=0,
                max_len_a=1,
                max_len_b=0,
            )
        elif diversity_rate > -1:
            search_strategy = search.DiverseSiblingsSearch(
                self.target_dictionary, diversity_rate
            )
        elif constrained:
            search_strategy = search.LexicallyConstrainedBeamSearch(
                self.target_dictionary, args.constraints
            )
        elif prefix_allowed_tokens_fn:
            search_strategy = search.PrefixConstrainedBeamSearch(
                self.target_dictionary, prefix_allowed_tokens_fn
            )
        else:
            search_strategy = search.BeamSearch(self.target_dictionary)

        extra_gen_cls_kwargs = extra_gen_cls_kwargs or {}
        if seq_gen_cls is None:
            if getattr(args, "print_alignment", False):
                seq_gen_cls = SequenceGeneratorWithAlignment
                extra_gen_cls_kwargs["print_alignment"] = args.print_alignment
            else:
                seq_gen_cls = SequenceGenerator

        return seq_gen_cls(
            models,
            self.target_dictionary,
            beam_size=getattr(args, "beam", 5),
            max_len_a=getattr(args, "max_len_a", 0),
            max_len_b=getattr(args, "max_len_b", 200),
            min_len=getattr(args, "min_len", 1),
            normalize_scores=(not getattr(args, "unnormalized", False)),
            len_penalty=getattr(args, "lenpen", 1),
            unk_penalty=getattr(args, "unkpen", 0),
            temperature=getattr(args, "temperature", 1.0),
            match_source_len=getattr(args, "match_source_len", False),
            no_repeat_ngram_size=getattr(args, "no_repeat_ngram_size", 0),
            search_strategy=search_strategy,
            constraint_range=self.cfg.constraint_range,
            **extra_gen_cls_kwargs,
        )

    def train_step(
        self, sample, model, criterion, optimizer, update_num, ignore_grad=False, **extra_kwargs
    ):
        """
        Do forward and backward, and return the loss as computed by *criterion*
        for the given *model* and *sample*.

        Args:
            sample (dict): the mini-batch. The format is defined by the
                :class:`~fairseq.data.FairseqDataset`.
            model (~fairseq.models.BaseFairseqModel): the model
            criterion (~fairseq.criterions.FairseqCriterion): the criterion
            optimizer (~fairseq.optim.FairseqOptimizer): the optimizer
            update_num (int): the current update
            ignore_grad (bool): multiply loss by 0 if this is set to True

        Returns:
            tuple:
                - the loss
                - the sample size, which is used as the denominator for the
                  gradient
                - logging outputs to display while training
        """
        model.train()
        model.set_num_updates(update_num)
        with torch.autograd.profiler.record_function("forward"):
            with torch.cuda.amp.autocast(enabled=(isinstance(optimizer, AMPOptimizer))):
                loss, sample_size, logging_output = criterion(model, sample, update_num)
        if ignore_grad:
            loss *= 0
        with torch.autograd.profiler.record_function("backward"):
            optimizer.backward(loss)
        return loss, sample_size, logging_output

    def max_positions(self):
        """Return the max sentence length allowed by the task."""
        return (self.cfg.max_source_positions, self.cfg.max_target_positions)

    @property
    def source_dictionary(self):
        """Return the source :class:`~fairseq.data.Dictionary`."""
        return self.src_dict

    @property
    def target_dictionary(self):
        """Return the target :class:`~fairseq.data.Dictionary`."""
        return self.tgt_dict
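The dispatch above selects exactly one search strategy from mutually exclusive generation flags on `args`, falling back to plain beam search. A minimal usage sketch, assuming a task and model ensemble have already been loaded; `task`, `models`, and the argument values below are illustrative assumptions, not fixed by this repo:

# Sketch only: `task` and `models` are assumed to come from something like
# checkpoint_utils.load_model_ensemble_and_task([...]); values are illustrative.
from argparse import Namespace

# Default path: plain beam search (search.BeamSearch) with beam size 5.
beam_generator = task.build_generator(models, Namespace(beam=5, max_len_b=16))

# Nucleus sampling path: selects search.Sampling; combining this with, e.g.,
# diverse_beam_groups > 0 would trip the "mutually exclusive" ValueError above.
sample_generator = task.build_generator(
    models, Namespace(sampling=True, sampling_topp=0.9, beam=5)
)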
train.py
ADDED
@@ -0,0 +1,523 @@
#!/usr/bin/env python3 -u
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Train a new model on one or across multiple GPUs.
"""

import argparse
import logging
import math
import os
import sys
from typing import Dict, Optional, Any, List, Tuple, Callable

# We need to setup root logger before importing any fairseq libraries.
logging.basicConfig(
    format='%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s',
    datefmt="%Y-%m-%d %H:%M:%S",
    level=os.environ.get("LOGLEVEL", "INFO").upper(),
    stream=sys.stdout,
)
logger = logging.getLogger("fairseq_cli.train")

import numpy as np
import torch
from fairseq import (
    # checkpoint_utils,
    options,
    quantization_utils,
    tasks,
    utils,
)
from fairseq.data import iterators
from fairseq.data.plasma_utils import PlasmaStore
from fairseq.dataclass.configs import FairseqConfig
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from fairseq.distributed import fsdp_enable_wrap, fsdp_wrap, utils as distributed_utils
from fairseq.file_io import PathManager
from fairseq.logging import meters, metrics, progress_bar
from fairseq.model_parallel.megatron_trainer import MegatronTrainer
# from fairseq.trainer import Trainer
from omegaconf import DictConfig, OmegaConf

from utils import checkpoint_utils
from trainer import Trainer


def main(cfg: FairseqConfig) -> None:
    if isinstance(cfg, argparse.Namespace):
        cfg = convert_namespace_to_omegaconf(cfg)

    utils.import_user_module(cfg.common)

    if distributed_utils.is_master(cfg.distributed_training) and "job_logging_cfg" in cfg:
        # make hydra logging work with ddp (see https://github.com/facebookresearch/hydra/issues/1126)
        logging.config.dictConfig(OmegaConf.to_container(cfg.job_logging_cfg))

    assert (
        cfg.dataset.max_tokens is not None or cfg.dataset.batch_size is not None
    ), "Must specify batch size either with --max-tokens or --batch-size"
    metrics.reset()

    if cfg.common.log_file is not None:
        handler = logging.FileHandler(filename=cfg.common.log_file)
        logger.addHandler(handler)

    np.random.seed(cfg.common.seed)
    utils.set_torch_seed(cfg.common.seed)

    if distributed_utils.is_master(cfg.distributed_training):
        checkpoint_utils.verify_checkpoint_directory(cfg.checkpoint.save_dir)

    # Print args
    logger.info(cfg)

    if cfg.checkpoint.write_checkpoints_asynchronously:
        try:
            import iopath  # noqa: F401
        except ImportError:
            logging.exception(
                "Asynchronous checkpoint writing is specified but iopath is "
                "not installed: `pip install iopath`"
            )
            return

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(cfg.task)

    assert cfg.criterion, "Please specify criterion to train a model"

    # Build model and criterion
    if cfg.distributed_training.ddp_backend == "fully_sharded":
        with fsdp_enable_wrap(cfg.distributed_training):
            model = fsdp_wrap(task.build_model(cfg.model))
    else:
        model = task.build_model(cfg.model)
    criterion = task.build_criterion(cfg.criterion)
    logger.info(model)
    logger.info("task: {}".format(task.__class__.__name__))
    logger.info("model: {}".format(model.__class__.__name__))
    logger.info("criterion: {}".format(criterion.__class__.__name__))
    logger.info(
        "num. shared model params: {:,} (num. trained: {:,})".format(
            sum(p.numel() for p in model.parameters() if not getattr(p, "expert", False)),
            sum(p.numel() for p in model.parameters() if not getattr(p, "expert", False) and p.requires_grad)
        )
    )

    logger.info(
        "num. expert model params: {} (num. trained: {})".format(
            sum(p.numel() for p in model.parameters() if getattr(p, "expert", False)),
            sum(p.numel() for p in model.parameters() if getattr(p, "expert", False) and p.requires_grad),
        )
    )

    # Load valid dataset (we load training data below, based on the latest checkpoint)
    # We load the valid dataset AFTER building the model
    # data_utils.raise_if_valid_subsets_unintentionally_ignored(cfg)
    if cfg.dataset.combine_valid_subsets:
        task.load_dataset("valid", combine=True, epoch=1)
    else:
        for valid_sub_split in cfg.dataset.valid_subset.split(","):
            task.load_dataset(valid_sub_split, combine=False, epoch=1)

    # (optionally) Configure quantization
    if cfg.common.quantization_config_path is not None:
        quantizer = quantization_utils.Quantizer(
            config_path=cfg.common.quantization_config_path,
            max_epoch=cfg.optimization.max_epoch,
            max_update=cfg.optimization.max_update,
        )
    else:
        quantizer = None

    # Build trainer
    if cfg.common.model_parallel_size == 1:
        trainer = Trainer(cfg, task, model, criterion, quantizer)
    else:
        trainer = MegatronTrainer(cfg, task, model, criterion)
    logger.info(
        "training on {} devices (GPUs/TPUs)".format(
            cfg.distributed_training.distributed_world_size
        )
    )
    logger.info(
        "max tokens per device = {} and max sentences per device = {}".format(
            cfg.dataset.max_tokens,
            cfg.dataset.batch_size,
        )
    )

    # Load the latest checkpoint if one is available and restore the
    # corresponding train iterator
    extra_state, epoch_itr = checkpoint_utils.load_checkpoint(
        cfg.checkpoint,
        trainer,
        # don't cache epoch iterators for sharded datasets
        disable_iterator_cache=task.has_sharded_data("train"),
    )
    if cfg.common.tpu:
        import torch_xla.core.xla_model as xm
        xm.rendezvous("load_checkpoint")  # wait for all workers

    max_epoch = cfg.optimization.max_epoch or math.inf
    if max_epoch > 0:
        num_iter_per_epoch = (len(epoch_itr) + cfg.distributed_training.distributed_world_size - 1) \
            // cfg.distributed_training.distributed_world_size
        trainer.lr_reinit(num_iter_per_epoch * max_epoch, trainer.get_num_updates())
    lr = trainer.get_lr()

    train_meter = meters.StopwatchMeter()
    train_meter.start()
    while epoch_itr.next_epoch_idx <= max_epoch:
        if lr <= cfg.optimization.stop_min_lr:
            logger.info(
                f"stopping training because current learning rate ({lr}) is smaller "
                "than or equal to minimum learning rate "
                f"(--stop-min-lr={cfg.optimization.stop_min_lr})"
            )
            break

        # train for one epoch
        valid_losses, should_stop = train(cfg, trainer, task, epoch_itr)
        if should_stop:
            break

        # only use first validation loss to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        epoch_itr = trainer.get_train_iterator(
            epoch_itr.next_epoch_idx,
            # sharded data: get train iterator for next epoch
            load_dataset=True,
            # don't cache epoch iterators for sharded datasets
            disable_iterator_cache=task.has_sharded_data("train"),
        )
    train_meter.stop()
    logger.info("done training in {:.1f} seconds".format(train_meter.sum))

    # ioPath implementation to wait for all asynchronous file writes to complete.
    if cfg.checkpoint.write_checkpoints_asynchronously:
        logger.info(
            "ioPath PathManager waiting for all asynchronous checkpoint "
            "writes to finish."
        )
        PathManager.async_close()
        logger.info("ioPath PathManager finished waiting.")


def should_stop_early(cfg: DictConfig, valid_loss: float) -> bool:
    # skip check if no validation was done in the current epoch
    if valid_loss is None:
        return False
    if cfg.checkpoint.patience <= 0:
        return False

    def is_better(a, b):
        return a > b if cfg.checkpoint.maximize_best_checkpoint_metric else a < b

    prev_best = getattr(should_stop_early, "best", None)
    if prev_best is None or is_better(valid_loss, prev_best):
        should_stop_early.best = valid_loss
        should_stop_early.num_runs = 0
        return False
    else:
        should_stop_early.num_runs += 1
        if should_stop_early.num_runs >= cfg.checkpoint.patience:
            logger.info(
                "early stop since valid performance hasn't improved for last {} runs".format(
                    cfg.checkpoint.patience
                )
            )
            return True
        else:
            return False


@metrics.aggregate("train")
def train(
    cfg: DictConfig, trainer: Trainer, task: tasks.FairseqTask, epoch_itr
) -> Tuple[List[Optional[float]], bool]:
    """Train the model for one epoch and return validation losses."""
    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=cfg.distributed_training.fix_batches_to_gpus,
        shuffle=(epoch_itr.next_epoch_idx > cfg.dataset.curriculum),
    )
    update_freq = (
        cfg.optimization.update_freq[epoch_itr.epoch - 1]
        if epoch_itr.epoch <= len(cfg.optimization.update_freq)
        else cfg.optimization.update_freq[-1]
    )
    itr = iterators.GroupedIterator(itr, update_freq)
    if cfg.common.tpu:
        itr = utils.tpu_data_loader(itr)
    progress = progress_bar.progress_bar(
        itr,
        log_format=cfg.common.log_format,
        log_file=cfg.common.log_file,
        log_interval=cfg.common.log_interval,
        epoch=epoch_itr.epoch,
        tensorboard_logdir=(
            cfg.common.tensorboard_logdir
            if distributed_utils.is_master(cfg.distributed_training)
            else None
        ),
        default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
        wandb_project=(
            cfg.common.wandb_project
            if distributed_utils.is_master(cfg.distributed_training)
            else None
        ),
        wandb_run_name=os.environ.get(
            "WANDB_NAME", os.path.basename(cfg.checkpoint.save_dir)
        ),
        azureml_logging=(
            cfg.common.azureml_logging
            if distributed_utils.is_master(cfg.distributed_training)
            else False
        ),
    )
    progress.update_config(_flatten_config(cfg))

    trainer.begin_epoch(epoch_itr.epoch)

    valid_subsets = cfg.dataset.valid_subset.split(",")
    should_stop = False
    num_updates = trainer.get_num_updates()
    logger.info("Start iterating over samples")
    for i, samples in enumerate(progress):
        with metrics.aggregate("train_inner"), torch.autograd.profiler.record_function(
            "train_step-%d" % i
        ):
            log_output = trainer.train_step(samples)

        if log_output is not None:  # not OOM, overflow, ...
            # log mid-epoch stats
            num_updates = trainer.get_num_updates()
            if num_updates % cfg.common.log_interval == 0:
                stats = get_training_stats(metrics.get_smoothed_values("train_inner"))
                progress.log(stats, tag="train_inner", step=num_updates)

                # reset mid-epoch stats after each log interval
                # the end-of-epoch stats will still be preserved
                metrics.reset_meters("train_inner")

        end_of_epoch = not itr.has_next()
        valid_losses, should_stop = validate_and_save(
            cfg, trainer, task, epoch_itr, valid_subsets, end_of_epoch
        )

        if should_stop:
            break

    # log end-of-epoch stats
    logger.info("end of epoch {} (average epoch stats below)".format(epoch_itr.epoch))
    stats = get_training_stats(metrics.get_smoothed_values("train"))
    progress.print(stats, tag="train", step=num_updates)

    # reset epoch-level meters
    metrics.reset_meters("train")
    return valid_losses, should_stop


def _flatten_config(cfg: DictConfig):
    config = OmegaConf.to_container(cfg)
    # remove any legacy Namespaces and replace with a single "args"
    namespace = None
    for k, v in list(config.items()):
        if isinstance(v, argparse.Namespace):
            namespace = v
            del config[k]
    if namespace is not None:
        config["args"] = vars(namespace)
    return config


def validate_and_save(
    cfg: DictConfig,
    trainer: Trainer,
    task: tasks.FairseqTask,
    epoch_itr,
    valid_subsets: List[str],
    end_of_epoch: bool,
) -> Tuple[List[Optional[float]], bool]:
    num_updates = trainer.get_num_updates()
    max_update = cfg.optimization.max_update or math.inf

    # Stopping conditions (and an additional one based on validation loss later
    # on)
    should_stop = False
    if num_updates >= max_update:
        should_stop = True
        logger.info(
            f"Stopping training due to "
            f"num_updates: {num_updates} >= max_update: {max_update}"
        )

    training_time_hours = trainer.cumulative_training_time() / (60 * 60)
    if (
        cfg.optimization.stop_time_hours > 0
        and training_time_hours > cfg.optimization.stop_time_hours
    ):
        should_stop = True
        logger.info(
            f"Stopping training due to "
            f"cumulative_training_time: {training_time_hours} > "
            f"stop_time_hours: {cfg.optimization.stop_time_hours} hour(s)"
        )

    do_save = (
        (end_of_epoch and epoch_itr.epoch % cfg.checkpoint.save_interval == 0)
        or should_stop
        or (
            cfg.checkpoint.save_interval_updates > 0
            and num_updates > 0
            and num_updates % cfg.checkpoint.save_interval_updates == 0
            and num_updates >= cfg.dataset.validate_after_updates
        )
    )
    do_validate = (
        (not end_of_epoch and do_save)  # validate during mid-epoch saves
        or (end_of_epoch and epoch_itr.epoch % cfg.dataset.validate_interval == 0)
        or should_stop
        or (
            cfg.dataset.validate_interval_updates > 0
            and num_updates > 0
            and num_updates % cfg.dataset.validate_interval_updates == 0
        )
    ) and not cfg.dataset.disable_validation and num_updates >= cfg.dataset.validate_after_updates

    # Validate
    valid_losses = [None]
    if do_validate:
        valid_losses = validate(cfg, trainer, task, epoch_itr, valid_subsets)

    should_stop |= should_stop_early(cfg, valid_losses[0])

    # Save checkpoint
    if do_save or should_stop:
        checkpoint_utils.save_checkpoint(
            cfg.checkpoint, trainer, epoch_itr, valid_losses[0]
        )

    return valid_losses, should_stop


def get_training_stats(stats: Dict[str, Any]) -> Dict[str, Any]:
    stats["wall"] = round(metrics.get_meter("default", "wall").elapsed_time, 0)
    return stats


def validate(
    cfg: DictConfig,
    trainer: Trainer,
    task: tasks.FairseqTask,
    epoch_itr,
    subsets: List[str],
) -> List[Optional[float]]:
    """Evaluate the model on the validation set(s) and return the losses."""

    if cfg.dataset.fixed_validation_seed is not None:
        # set fixed seed for every validation
        utils.set_torch_seed(cfg.dataset.fixed_validation_seed)

    trainer.begin_valid_epoch(epoch_itr.epoch)
    valid_losses = []
    for subset in subsets:
        logger.info('begin validation on "{}" subset'.format(subset))

        # Initialize data iterator
        itr = trainer.get_valid_iterator(subset).next_epoch_itr(
            shuffle=False, set_dataset_epoch=False  # use a fixed valid set
        )
        if cfg.common.tpu:
            itr = utils.tpu_data_loader(itr)
        progress = progress_bar.progress_bar(
            itr,
            log_format=cfg.common.log_format,
            log_interval=cfg.common.log_interval,
            epoch=epoch_itr.epoch,
            prefix=f"valid on '{subset}' subset",
            tensorboard_logdir=(
                cfg.common.tensorboard_logdir
                if distributed_utils.is_master(cfg.distributed_training)
                else None
            ),
            default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
            wandb_project=(
                cfg.common.wandb_project
                if distributed_utils.is_master(cfg.distributed_training)
                else None
            ),
            wandb_run_name=os.environ.get(
                "WANDB_NAME", os.path.basename(cfg.checkpoint.save_dir)
            ),
        )

        # create a new root metrics aggregator so validation metrics
        # don't pollute other aggregators (e.g., train meters)
        with metrics.aggregate(new_root=True) as agg:
            for i, sample in enumerate(progress):
                if cfg.dataset.max_valid_steps is not None and i > cfg.dataset.max_valid_steps:
                    break
                trainer.valid_step(sample)

        # log validation stats
        if hasattr(task, 'get_valid_stats'):
            stats = task.get_valid_stats(cfg, trainer, agg.get_smoothed_values())
        else:
            stats = agg.get_smoothed_values()
        stats = get_valid_stats(cfg, trainer, stats)

        if hasattr(task, "post_validate"):
            task.post_validate(trainer.get_model(), stats, agg)

        progress.print(stats, tag=subset, step=trainer.get_num_updates())

        valid_losses.append(stats[cfg.checkpoint.best_checkpoint_metric])
    return valid_losses


def get_valid_stats(
    cfg: DictConfig, trainer: Trainer, stats: Dict[str, Any]
) -> Dict[str, Any]:
    stats["num_updates"] = trainer.get_num_updates()
    if hasattr(checkpoint_utils.save_checkpoint, "best"):
        key = "best_{0}".format(cfg.checkpoint.best_checkpoint_metric)
        best_function = max if cfg.checkpoint.maximize_best_checkpoint_metric else min
        stats[key] = best_function(
            checkpoint_utils.save_checkpoint.best,
            stats[cfg.checkpoint.best_checkpoint_metric],
        )
    return stats


def cli_main(
    modify_parser: Optional[Callable[[argparse.ArgumentParser], None]] = None
) -> None:
    parser = options.get_training_parser()
    args = options.parse_args_and_arch(parser, modify_parser=modify_parser)

    cfg = convert_namespace_to_omegaconf(args)

    if cfg.common.use_plasma_view:
        server = PlasmaStore(path=cfg.common.plasma_path)
        logger.info(f"Started plasma server pid {server.server.pid} {cfg.common.plasma_path}")

    if args.profile:
        with torch.cuda.profiler.profile():
            with torch.autograd.profiler.emit_nvtx():
                distributed_utils.call_main(cfg, main)
    else:
        distributed_utils.call_main(cfg, main)

    # if cfg.common.use_plasma_view:
    #     server.server.kill()


if __name__ == "__main__":
    cli_main()
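should_stop_early() above keeps its running best score and miss counter as attributes on the function object itself, so patience is tracked across epochs without any class. A small self-contained sketch of the same patience pattern (independent of this repo, runnable on its own):

def patience_stop(valid_loss, patience=2, maximize=False, _state={"best": None, "misses": 0}):
    """Return True once `valid_loss` has failed to improve `patience` times in a row.

    The mutable default argument plays the role of the function attributes used above.
    """
    if valid_loss is None or patience <= 0:
        return False
    improved = _state["best"] is None or (
        valid_loss > _state["best"] if maximize else valid_loss < _state["best"]
    )
    if improved:
        _state["best"], _state["misses"] = valid_loss, 0
        return False
    _state["misses"] += 1
    return _state["misses"] >= patience


# Losses 1.0 and 0.9 improve; 0.95 and 1.1 are two consecutive misses -> stop.
print([patience_stop(l) for l in [1.0, 0.9, 0.95, 1.1]])  # [False, False, False, True]

In train.py the equivalent decision is OR-ed into should_stop inside validate_and_save(), so hitting the patience limit also triggers a final checkpoint save.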
trainer.py
ADDED
@@ -0,0 +1,1531 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
Train a network across multiple GPUs.
"""

import contextlib
import logging
import sys
import time
from argparse import Namespace
from itertools import chain
from typing import Any, Dict, List

import torch
from fairseq import models, optim, utils
from fairseq.dataclass.configs import FairseqConfig
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from fairseq.distributed import utils as distributed_utils
from fairseq.file_io import PathManager
from fairseq.logging import meters, metrics
from fairseq.models.ema import build_ema
from fairseq.nan_detector import NanDetector
from fairseq.optim import lr_scheduler
from omegaconf import OmegaConf

from utils import checkpoint_utils

logger = logging.getLogger(__name__)


class Trainer(object):
    """Main class for data parallel training.

    This class supports synchronous distributed data parallel training,
    where multiple workers each have a full model replica and gradients
    are accumulated across workers before each update. We use
    :class:`~torch.nn.parallel.DistributedDataParallel` to handle
    communication of the gradients across workers.
    """

    def __init__(self, cfg: FairseqConfig, task, model, criterion, quantizer=None):

        if isinstance(cfg, Namespace):
            logger.warning(
                "argparse.Namespace configuration is deprecated! Automatically converting to OmegaConf"
            )
            cfg = convert_namespace_to_omegaconf(cfg)

        self.cfg = cfg
        self.task = task

        # catalog shared parameters
        shared_params = _catalog_shared_params(model)
        self.tpu = cfg.common.tpu
        self.cuda = torch.cuda.is_available() and not cfg.common.cpu and not self.tpu
        if self.cuda:
            self.device = torch.device("cuda")
        elif self.tpu:
            self.device = utils.get_tpu_device()
        else:
            self.device = torch.device("cpu")

        if self.is_fsdp:
            import fairscale
            if self.cfg.common.bf16:
                raise ValueError(
                    "FullyShardedDataParallel is not compatible with --bf16 or "
                    "--memory-efficient-bf16"
                )
            if self.cfg.distributed_training.zero_sharding != "none":
                raise ValueError(
                    "FullyShardedDataParallel is not compatible with --zero-sharding "
                    "option (it's already built in)"
                )
            if max(self.cfg.optimization.update_freq) > 1 and fairscale.__version__ < "0.4.0":
                raise RuntimeError(
                    "Please update to fairscale 0.4.0 or newer when combining "
                    "--update-freq with FullyShardedDataParallel"
                )
        else:
            if (
                hasattr(self.cfg.distributed_training, "cpu_offload")
                and self.cfg.distributed_training.cpu_offload
            ):
                raise ValueError("--cpu-offload requires --ddp-backend=fully_sharded")

        # copy model and criterion to current device/dtype
        self._criterion = criterion
        self._model = model
        if not self.is_fsdp:
            if cfg.common.fp16:
                assert not cfg.common.amp, "Cannot use fp16 and AMP together"
                self._criterion = self._criterion.half()
                self._model = self._model.half()
            elif cfg.common.bf16:
                self._criterion = self._criterion.to(dtype=torch.bfloat16)
                self._model = self._model.to(dtype=torch.bfloat16)
            elif cfg.common.amp:
                self._amp_retries = 0
        if (
            not cfg.distributed_training.pipeline_model_parallel
            # the DistributedFairseqModel wrapper will handle moving to device,
            # so only handle cases which don't use the wrapper
            and not self.use_distributed_wrapper
        ):
            self._criterion = self._criterion.to(device=self.device)
            self._model = self._model.to(device=self.device)
        self.pipeline_model_parallel = cfg.distributed_training.pipeline_model_parallel
        self.last_device = None
        if self.cuda and self.pipeline_model_parallel:
            self.last_device = torch.device(
                cfg.distributed_training.pipeline_devices[-1]
            )

        # check that shared parameters are preserved after device transfer
        for shared_param in shared_params:
            ref = _get_module_by_path(self._model, shared_param[0])
            for path in shared_param[1:]:
                logger.info(
                    "detected shared parameter: {} <- {}".format(shared_param[0], path)
                )
                _set_module_by_path(self._model, path, ref)

        self._dummy_batch = None  # indicates we don't have a dummy batch at first
        self._lr_scheduler = None
        self._num_updates = 0
        self._num_xla_compiles = 0  # for TPUs
        self._optim_history = None
        self._optimizer = None
        self._warn_once = set()
        self._wrapped_criterion = None
        self._wrapped_model = None
        self._ema = None

        # TODO(myleott): support tpu
        if self.cuda and self.data_parallel_world_size > 1:
            self._grad_norm_buf = torch.cuda.DoubleTensor(self.data_parallel_world_size)
        else:
            self._grad_norm_buf = None

        self.quantizer = quantizer
        if self.quantizer is not None:
            self.quantizer.set_trainer(self)

        # get detailed cuda environment
        if self.cuda:
            self.cuda_env = utils.CudaEnvironment()
            if self.data_parallel_world_size > 1:
                self.cuda_env_arr = distributed_utils.all_gather_list(
                    self.cuda_env, group=distributed_utils.get_global_group()
                )
            else:
                self.cuda_env_arr = [self.cuda_env]
            if self.data_parallel_rank == 0:
                utils.CudaEnvironment.pretty_print_cuda_env_list(self.cuda_env_arr)
        else:
            self.cuda_env = None
            self.cuda_env_arr = None

        metrics.log_start_time("wall", priority=790, round=0)

        self._start_time = time.time()
        self._previous_training_time = 0
        self._cumulative_training_time = None

    def reinitialize(self):
        """Reinitialize the Trainer, typically after model params change."""
        self._lr_scheduler = None
        self._optimizer = None
        self._wrapped_criterion = None
        self._wrapped_model = None

    @property
    def data_parallel_world_size(self):
        if self.cfg.distributed_training.distributed_world_size == 1:
            return 1
        return distributed_utils.get_data_parallel_world_size()

    @property
    def data_parallel_process_group(self):
        return distributed_utils.get_data_parallel_group()

    @property
    def data_parallel_rank(self):
        if self.cfg.distributed_training.distributed_world_size == 1:
            return 0
        return distributed_utils.get_data_parallel_rank()

    @property
    def is_data_parallel_master(self):
        # NOTE: this returns true for all model parallel replicas with data
        # parallel rank 0
        return self.data_parallel_rank == 0

    @property
    def use_distributed_wrapper(self) -> bool:
        return (
            self.data_parallel_world_size > 1 and not self.cfg.optimization.use_bmuf
        ) or (
            self.is_fsdp and self.cfg.distributed_training.cpu_offload
        )

    @property
    def should_save_checkpoint_on_current_rank(self) -> bool:
        """Indicates whether to save checkpoints on the current DDP rank."""
        if (
            self.is_fsdp and self.cfg.distributed_training.use_sharded_state
        ) or getattr(self.cfg.model, "base_layers", 0) > 0:
            return True
        else:
            return self.is_data_parallel_master

    @property
    def always_call_state_dict_during_save_checkpoint(self) -> bool:
        if self.is_fsdp and not self.cfg.distributed_training.use_sharded_state:
            # FSDP calls communication collective when consolidating checkpoints
            return True
        else:
            return False

    @property
    def checkpoint_suffix(self) -> str:
        """Suffix to add to the checkpoint file name."""
        if self.is_fsdp and self.cfg.distributed_training.use_sharded_state:
            return self.cfg.checkpoint.checkpoint_suffix + "-shard{0}".format(
                self.data_parallel_rank
            )
        else:
            return self.cfg.checkpoint.checkpoint_suffix or ""

    @property
    def criterion(self):
        if self._wrapped_criterion is None:
            if utils.has_parameters(self._criterion) and self.use_distributed_wrapper:
                self._wrapped_criterion = models.DistributedFairseqModel(
                    self.cfg.distributed_training,
                    self._criterion,
                    process_group=self.data_parallel_process_group,
                    device=self.device,
                )
            else:
                self._wrapped_criterion = self._criterion
        return self._wrapped_criterion

    @property
    def model(self):
        if self._wrapped_model is None:
            if self.use_distributed_wrapper:
                self._wrapped_model = models.DistributedFairseqModel(
                    self.cfg.distributed_training,
                    self._model,
                    process_group=self.data_parallel_process_group,
                    device=self.device,
                )
            else:
                self._wrapped_model = self._model
        return self._wrapped_model

    @property
    def ema(self):
        if self._ema is None:
            self._build_ema()
        return self._ema

    def _build_ema(self):
        if self.cfg.ema.store_ema:
            self._ema = build_ema(self._model, self.cfg.ema, self.device)
            logger.info(
                "Exponential Moving Average Shadow Model is initialized."
            )

    @property
    def optimizer(self):
        if self._optimizer is None:
            self._build_optimizer()
        return self._optimizer

    @property
    def lr_scheduler(self):
        if self._lr_scheduler is None:
            self._build_optimizer()  # this will initialize self._lr_scheduler
        return self._lr_scheduler

    def _build_optimizer(self):
        params = list(
            filter(
                lambda p: p.requires_grad,
                chain(self.model.parameters(), self.criterion.parameters()),
            )
        )

        if self.is_fsdp and self.cfg.common.fp16:
            # FullyShardedDataParallel always uses MemoryEfficientFP16 wrapper,
            # mostly for the grad scaling. But if we don't have the
            # --memory-efficient-fp16 flag set, then we're effectively doing
            # regular --fp16 and can allow the use of optimizers that would
            # otherwise be unsupported by MemoryEfficientFP16Optimizer.
            allow_unsupported = not self.cfg.common.memory_efficient_fp16
            self._optimizer = optim.MemoryEfficientFP16Optimizer.build_optimizer(
                self.cfg, params, allow_unsupported=allow_unsupported
            )
        elif self.cfg.common.fp16 or self.cfg.common.bf16 or self.cfg.common.amp:
            if self.cuda and torch.cuda.get_device_capability(0)[0] < 7:
                logger.info(
                    "NOTE: your device does NOT support faster training with --fp16 or --amp, "
                    "please switch to FP32 which is likely to be faster"
                )
            if (
                self.cfg.common.memory_efficient_fp16
                or self.cfg.common.memory_efficient_bf16
            ):
                self._optimizer = optim.MemoryEfficientFP16Optimizer.build_optimizer(
                    self.cfg, params
                )
            elif self.cfg.common.amp:
                self._optimizer = optim.AMPOptimizer.build_optimizer(self.cfg, params)
            else:
                self._optimizer = optim.FP16Optimizer.build_optimizer(self.cfg, params)
        else:
            if self.cuda and torch.cuda.get_device_capability(0)[0] >= 7:
                logger.info("NOTE: your device may support faster training with --fp16 or --amp")
            self._optimizer = optim.build_optimizer(self.cfg.optimizer, params)

        if self.is_fsdp:
            assert (
                not self.cfg.optimization.use_bmuf
            ), "--ddp-backend=fully_sharded is not compatible with BMUF"
            assert self._optimizer.supports_flat_params, (
                "--ddp-backend=fully_sharded is only compatible with pointwise "
                "optimizers (e.g., Adam, AdamW, Adadelta, Adamax, SGD, etc.). "
                "However, the sharding will result in slightly different results when "
                "using non-pointwise optimizers (e.g., Adagrad, Adafactor, LAMB)"
            )

        if self.cfg.optimization.use_bmuf:
            self._optimizer = optim.FairseqBMUF(
                self.cfg.bmuf,
                self._optimizer,
            )

        if self.cfg.distributed_training.zero_sharding == "os":
            if (
                self.cfg.common.fp16
                and not self.cfg.common.memory_efficient_fp16
                and not self.cfg.common.memory_efficient_bf16
            ) and not self.cfg.common.fp16_no_flatten_grads:
                raise ValueError(
                    "ZeRO is incomptabile with fp16 and flattened grads. "
                    "Please use --fp16-no-flatten-grads"
                )
            else:
                optim.shard_(self._optimizer, self.data_parallel_process_group)

        # We should initialize the learning rate scheduler immediately after
        # building the optimizer, so that the initial learning rate is set.
        self._lr_scheduler = lr_scheduler.build_lr_scheduler(
            self.cfg.lr_scheduler,
            self.optimizer,
        )
        self._lr_scheduler.step_update(0)

    @property
    def is_fsdp(self):
        return self.cfg.distributed_training.ddp_backend == "fully_sharded"

    def consolidate_optimizer(self):
        """For OSS, we need to consolidate the state dict."""
        if self.cfg.checkpoint.no_save_optimizer_state:
            return
        self._gathered_optim_state = None
        if hasattr(self.optimizer.optimizer, "consolidate_state_dict"):
            self.optimizer.optimizer.consolidate_state_dict()
        elif self.is_fsdp and not self.model.use_sharded_state:
            st = self.model.gather_full_optim_state_dict(
                self.optimizer
            )  # only returns on rank 0
            self._gathered_optim_state = st

    def state_dict(self):
        state_dict = {
            "args": None,  # legacy
            "cfg": (
                OmegaConf.to_container(self.cfg, resolve=True, enum_to_str=True)
                if OmegaConf.is_config(self.cfg)
                else self.cfg
            ),
            "model": self.model.state_dict(),
            "criterion": (
                self.criterion.state_dict()
                if utils.has_parameters(self.criterion)
                else None
            ),
            "optimizer_history": (self._optim_history or [])
            + [
                {
                    "criterion_name": self.get_criterion().__class__.__name__,
                    "optimizer_name": self.optimizer.__class__.__name__,
                    "lr_scheduler_state": self.lr_scheduler.state_dict(),
                    "num_updates": self.get_num_updates(),
                }
            ],
            "task_state": self.task.state_dict() if self.task is not None else {},
            "extra_state": {
                "metrics": metrics.state_dict(),
                "previous_training_time": self.cumulative_training_time(),
            },
        }
        if self.cfg.ema.store_ema:
            # Save EMA model state as extra state
            state_dict["extra_state"]["ema"] = self.ema.get_model().state_dict()
            if self.cfg.ema.ema_fp32:
                # Save EMA params in fp32
                state_dict["extra_state"]["ema_fp32_params"] = self.ema.fp32_params
        if not self.cfg.checkpoint.no_save_optimizer_state:
            if self._gathered_optim_state is not None:
                state_dict["last_optimizer_state"] = self._gathered_optim_state
                self._gathered_optim_state = None
            else:
                state_dict["last_optimizer_state"] = self.optimizer.state_dict()
        if self.is_fsdp:
            # save meta data for recombining checkpoint upon loading
            state_dict["fsdp_metadata"] = self.model.local_metadata_dict()
        return state_dict

    def save_checkpoint(self, filename, extra_state):
        """Save all training state in a checkpoint file."""
        logger.info(f"Saving checkpoint to {filename}")
        # call state_dict on all ranks in case it needs internal communication
        state_dict = utils.move_to_cpu(self.state_dict())
        state_dict["extra_state"].update(extra_state)
        if self.should_save_checkpoint_on_current_rank:
            checkpoint_utils.torch_persistent_save(
                state_dict,
                filename,
                async_write=self.cfg.checkpoint.write_checkpoints_asynchronously,
            )
        logger.info(f"Finished saving checkpoint to {filename}")

    def load_checkpoint(
        self,
        filename,
        reset_optimizer=False,
        reset_lr_scheduler=False,
        optimizer_overrides=None,
        reset_meters=False,
    ):
        """
        Load all training state from a checkpoint file.
        rank = 0 will load the checkpoint, and then broadcast it to all
        other ranks.
        """
        extra_state, self._optim_history, last_optim_state = None, [], None

        logger.info(f"Preparing to load checkpoint {filename}")
        is_distributed = self.data_parallel_world_size > 1
        bexists = PathManager.isfile(filename)
        if bexists:
            load_on_all_ranks = (
                self.cfg.checkpoint.load_checkpoint_on_all_dp_ranks
                # TPUs don't support broadcast yet, so load checkpoints
                # on every worker for now
                or self.tpu
                # FSDP requires loading checkpoint shards on all ranks
                or (self.is_fsdp and self.cfg.distributed_training.use_sharded_state)
                or getattr(self.cfg.model, "base_layers", 0) > 0
            )

            if load_on_all_ranks or self.data_parallel_rank == 0:
                state = checkpoint_utils.load_checkpoint_to_cpu(
                    filename, load_on_all_ranks=load_on_all_ranks
                )
                last_optim_state = state.get("last_optimizer_state", None)

                # If doing zero_sharding, do not broadcast global optimizer
                # state. Later we will broadcast sharded states to each rank
                # to avoid memory from exploding.
                if (
                    not load_on_all_ranks
                    and self.cfg.distributed_training.zero_sharding == "os"
                    and "last_optimizer_state" in state
                    and is_distributed
                ):
                    state["last_optimizer_state"] = "SHARDED"
            else:
                last_optim_state = None
                state = None

            if is_distributed and not load_on_all_ranks:
                state = distributed_utils.broadcast_object(
                    state,
                    src_rank=0,
                    group=self.data_parallel_process_group,
                    dist_device=self.device,
                )
                if self.data_parallel_rank > 0:
                    last_optim_state = state.get("last_optimizer_state", None)

            # load model parameters
            try:
                if self.cfg.checkpoint.use_ema_weights_to_init_param and "extra_state" in state and "ema" in state["extra_state"]:
                    logger.info("use_ema_weights_to_init_param = True, will use EMA weights in the ckpt to init the model param...")
                    ema_state_dict = state["extra_state"]["ema_fp32_params"] if "ema_fp32_params" in state["extra_state"] else state["extra_state"]["ema"]
                    self.model.load_state_dict(
                        ema_state_dict, strict=True, model_cfg=self.cfg.model
                    )
                else:
                    self.model.load_state_dict(
                        state["model"], strict=True, model_cfg=self.cfg.model
                    )
                # save memory for later steps
                if not (self.cfg.ema.store_ema and (self.cfg.checkpoint.use_latest_weights_to_init_ema or not ("extra_state" in state and "ema" in state["extra_state"]))):
                    del state["model"]
                if utils.has_parameters(self.get_criterion()):
                    self.get_criterion().load_state_dict(
                        state["criterion"], strict=True
                    )
                    del state["criterion"]

            except Exception:
                raise Exception(
                    "Cannot load model parameters from checkpoint {}; "
                    "please ensure that the architectures match.".format(filename)
                )
            extra_state = state["extra_state"]
            self._optim_history = state["optimizer_history"]

        if last_optim_state is not None and not reset_optimizer:
            # rebuild optimizer after loading model, since params may have changed
            self._build_optimizer()

            # only reload optimizer and lr_scheduler if they match
            last_optim = self._optim_history[-1]
            assert (
                last_optim["criterion_name"] == self.get_criterion().__class__.__name__
            ), f"Criterion does not match; please reset the optimizer (--reset-optimizer). {last_optim['criterion_name']} vs {self.get_criterion().__class__.__name__}"
|
| 540 |
+
assert (
|
| 541 |
+
last_optim["optimizer_name"] == self.optimizer.__class__.__name__
|
| 542 |
+
), f"Optimizer does not match; please reset the optimizer (--reset-optimizer). {last_optim['optimizer_name']} vs {self.optimizer.__class__.__name__}"
|
| 543 |
+
|
| 544 |
+
if not reset_lr_scheduler:
|
| 545 |
+
self.lr_scheduler.load_state_dict(last_optim["lr_scheduler_state"])
|
| 546 |
+
|
| 547 |
+
if self.is_fsdp and not self.model.use_sharded_state:
|
| 548 |
+
# if use_sharded_state, the last_optim_state is already sharded, skip this
|
| 549 |
+
last_optim_state = self.model.get_shard_from_optim_state_dict(
|
| 550 |
+
last_optim_state
|
| 551 |
+
)
|
| 552 |
+
elif not load_on_all_ranks and is_distributed:
|
| 553 |
+
last_optim_state = self.optimizer.broadcast_global_state_dict(
|
| 554 |
+
last_optim_state
|
| 555 |
+
)
|
| 556 |
+
|
| 557 |
+
self.optimizer.load_state_dict(last_optim_state, optimizer_overrides)
|
| 558 |
+
|
| 559 |
+
self.set_num_updates(last_optim["num_updates"])
|
| 560 |
+
|
| 561 |
+
if extra_state is not None:
|
| 562 |
+
itr_state = extra_state["train_iterator"]
|
| 563 |
+
epoch = itr_state["epoch"]
|
| 564 |
+
|
| 565 |
+
if "previous_training_time" in extra_state:
|
| 566 |
+
self._previous_training_time = extra_state["previous_training_time"]
|
| 567 |
+
self._start_time = time.time()
|
| 568 |
+
|
| 569 |
+
self.lr_step(epoch)
|
| 570 |
+
|
| 571 |
+
if (
|
| 572 |
+
itr_state.get("version", 1) >= 2
|
| 573 |
+
and itr_state["iterations_in_epoch"] == 0
|
| 574 |
+
):
|
| 575 |
+
# reset meters at start of epoch
|
| 576 |
+
reset_meters = True
|
| 577 |
+
|
| 578 |
+
if "metrics" in extra_state and not reset_meters:
|
| 579 |
+
metrics.load_state_dict(extra_state["metrics"])
|
| 580 |
+
|
| 581 |
+
# reset TimeMeters, since their start times don't make sense anymore
|
| 582 |
+
for meter in metrics.get_meters("default"):
|
| 583 |
+
if isinstance(meter, meters.TimeMeter):
|
| 584 |
+
meter.reset()
|
| 585 |
+
|
| 586 |
+
if self.cfg.ema.store_ema:
|
| 587 |
+
if self.cfg.checkpoint.use_latest_weights_to_init_ema or "ema" not in extra_state:
|
| 588 |
+
if "ema" not in extra_state:
|
| 589 |
+
logger.warning(
|
| 590 |
+
"EMA not found in checkpoint. But store_ema is True. "
|
| 591 |
+
"EMA is re-initialized from checkpoint."
|
| 592 |
+
)
|
| 593 |
+
elif self.cfg.checkpoint.use_latest_weights_to_init_ema:
|
| 594 |
+
logger.info(
|
| 595 |
+
"use_latest_weights_to_init_ema = True. EMA is re-initialized from checkpoint."
|
| 596 |
+
)
|
| 597 |
+
self.ema.restore(state["model"], build_fp32_params=self.cfg.ema.ema_fp32)
|
| 598 |
+
del state["model"]
|
| 599 |
+
else:
|
| 600 |
+
logger.info(
|
| 601 |
+
"Loading EMA from checkpoint"
|
| 602 |
+
)
|
| 603 |
+
self.ema.restore(extra_state["ema"], build_fp32_params=False)
|
| 604 |
+
|
| 605 |
+
if self.cfg.ema.ema_fp32:
|
| 606 |
+
if "ema_fp32_params" in extra_state:
|
| 607 |
+
logger.info(
|
| 608 |
+
"Loading EMA fp32 params from checkpoint"
|
| 609 |
+
)
|
| 610 |
+
self.ema.build_fp32_params(extra_state["ema_fp32_params"])
|
| 611 |
+
else:
|
| 612 |
+
logger.info(
|
| 613 |
+
"Building EMA fp32 params from EMA model in checkpoint"
|
| 614 |
+
)
|
| 615 |
+
self.ema.build_fp32_params()
|
| 616 |
+
|
| 617 |
+
logger.info(
|
| 618 |
+
"Loaded checkpoint {} (epoch {} @ {} updates)".format(
|
| 619 |
+
filename, epoch, self.get_num_updates()
|
| 620 |
+
)
|
| 621 |
+
)
|
| 622 |
+
|
| 623 |
+
else:
|
| 624 |
+
logger.info("No existing checkpoint found {}".format(filename))
|
| 625 |
+
|
| 626 |
+
return extra_state
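As the EMA branch above shows, the averaged weights live under extra_state, and the fp32 copy is preferred when present. A small offline sketch for extracting them from a saved checkpoint (the file names below are illustrative, not produced by this code):

    import torch

    ckpt = torch.load("checkpoints/checkpoint_best.pt", map_location="cpu")
    extra = ckpt["extra_state"]
    # mirror the preference order used in load_checkpoint: fp32 EMA params first
    ema_state = extra.get("ema_fp32_params") or extra.get("ema")
    if ema_state is not None:
        torch.save(
            {"cfg": ckpt["cfg"], "model": ema_state},
            "checkpoints/checkpoint_best_ema.pt",
        )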
|
| 627 |
+
|
| 628 |
+
def get_train_iterator(
|
| 629 |
+
self,
|
| 630 |
+
epoch,
|
| 631 |
+
combine=True,
|
| 632 |
+
load_dataset=True,
|
| 633 |
+
data_selector=None,
|
| 634 |
+
shard_batch_itr=True,
|
| 635 |
+
disable_iterator_cache=False,
|
| 636 |
+
):
|
| 637 |
+
"""Return an EpochBatchIterator over the training set for a given epoch."""
|
| 638 |
+
if load_dataset:
|
| 639 |
+
logger.info("loading train data for epoch {}".format(epoch))
|
| 640 |
+
self.task.load_dataset(
|
| 641 |
+
self.cfg.dataset.train_subset,
|
| 642 |
+
epoch=epoch,
|
| 643 |
+
combine=combine,
|
| 644 |
+
data_selector=data_selector,
|
| 645 |
+
tpu=self.tpu,
|
| 646 |
+
)
|
| 647 |
+
batch_iterator = self.task.get_batch_iterator(
|
| 648 |
+
dataset=self.task.dataset(self.cfg.dataset.train_subset),
|
| 649 |
+
max_tokens=self.cfg.dataset.max_tokens,
|
| 650 |
+
max_sentences=self.cfg.dataset.batch_size,
|
| 651 |
+
max_positions=utils.resolve_max_positions(
|
| 652 |
+
self.task.max_positions(),
|
| 653 |
+
self.model.max_positions(),
|
| 654 |
+
self.cfg.dataset.max_tokens,
|
| 655 |
+
),
|
| 656 |
+
ignore_invalid_inputs=True,
|
| 657 |
+
required_batch_size_multiple=self.cfg.dataset.required_batch_size_multiple,
|
| 658 |
+
seed=self.cfg.common.seed,
|
| 659 |
+
num_shards=self.data_parallel_world_size if shard_batch_itr else 1,
|
| 660 |
+
shard_id=self.data_parallel_rank if shard_batch_itr else 0,
|
| 661 |
+
num_workers=self.cfg.dataset.num_workers,
|
| 662 |
+
epoch=epoch,
|
| 663 |
+
data_buffer_size=self.cfg.dataset.data_buffer_size,
|
| 664 |
+
disable_iterator_cache=disable_iterator_cache,
|
| 665 |
+
)
|
| 666 |
+
self.reset_dummy_batch(batch_iterator.first_batch)
|
| 667 |
+
batch_iterator.dataset.dataset._seek()
|
| 668 |
+
return batch_iterator
|
| 669 |
+
|
| 670 |
+
def get_valid_iterator(
|
| 671 |
+
self,
|
| 672 |
+
subset,
|
| 673 |
+
disable_iterator_cache=False,
|
| 674 |
+
):
|
| 675 |
+
"""Return an EpochBatchIterator over given validation subset for a given epoch."""
|
| 676 |
+
self.task.dataset(subset).dataset._seek()
|
| 677 |
+
batch_iterator = self.task.get_batch_iterator(
|
| 678 |
+
dataset=self.task.dataset(subset),
|
| 679 |
+
max_tokens=self.cfg.dataset.max_tokens_valid,
|
| 680 |
+
max_sentences=self.cfg.dataset.batch_size_valid,
|
| 681 |
+
max_positions=utils.resolve_max_positions(
|
| 682 |
+
self.task.max_positions(),
|
| 683 |
+
self.model.max_positions(),
|
| 684 |
+
),
|
| 685 |
+
ignore_invalid_inputs=self.cfg.dataset.skip_invalid_size_inputs_valid_test,
|
| 686 |
+
required_batch_size_multiple=self.cfg.dataset.required_batch_size_multiple,
|
| 687 |
+
seed=self.cfg.common.seed,
|
| 688 |
+
num_shards=self.data_parallel_world_size,
|
| 689 |
+
shard_id=self.data_parallel_rank,
|
| 690 |
+
num_workers=self.cfg.dataset.num_workers,
|
| 691 |
+
# always pass a fixed "epoch" to keep validation data consistent
|
| 692 |
+
# across training epochs
|
| 693 |
+
epoch=1,
|
| 694 |
+
data_buffer_size=self.cfg.dataset.data_buffer_size,
|
| 695 |
+
disable_iterator_cache=disable_iterator_cache,
|
| 696 |
+
)
|
| 697 |
+
self.reset_dummy_batch(batch_iterator.first_batch)
|
| 698 |
+
batch_iterator.dataset.dataset._seek()
|
| 699 |
+
return batch_iterator
|
| 700 |
+
|
| 701 |
+
def begin_epoch(self, epoch):
|
| 702 |
+
"""Called at the beginning of each epoch."""
|
| 703 |
+
logger.info("begin training epoch {}".format(epoch))
|
| 704 |
+
|
| 705 |
+
self.lr_step_begin_epoch(epoch)
|
| 706 |
+
|
| 707 |
+
if self.quantizer is not None:
|
| 708 |
+
self.quantizer.begin_epoch(epoch)
|
| 709 |
+
|
| 710 |
+
# task specific setup per epoch
|
| 711 |
+
self.task.begin_epoch(epoch, self.get_model())
|
| 712 |
+
|
| 713 |
+
if self.tpu:
|
| 714 |
+
import torch_xla.core.xla_model as xm
|
| 715 |
+
|
| 716 |
+
xm.rendezvous("begin_epoch") # wait for all workers
|
| 717 |
+
xm.mark_step()
|
| 718 |
+
|
| 719 |
+
def begin_valid_epoch(self, epoch):
|
| 720 |
+
"""Called at the beginning of each validation epoch."""
|
| 721 |
+
|
| 722 |
+
# task specific setup per validation epoch
|
| 723 |
+
self.task.begin_valid_epoch(epoch, self.get_model())
|
| 724 |
+
|
| 725 |
+
def reset_dummy_batch(self, batch):
|
| 726 |
+
self._dummy_batch = batch
|
| 727 |
+
|
| 728 |
+
@metrics.aggregate("train")
|
| 729 |
+
def train_step(self, samples, raise_oom=False):
|
| 730 |
+
"""Do forward, backward and parameter update."""
|
| 731 |
+
self._set_seed()
|
| 732 |
+
self.model.train()
|
| 733 |
+
self.criterion.train()
|
| 734 |
+
self.zero_grad()
|
| 735 |
+
|
| 736 |
+
metrics.log_start_time("train_wall", priority=800, round=0)
|
| 737 |
+
|
| 738 |
+
# If EMA is enabled through store_ema=True
|
| 739 |
+
# and task.uses_ema is True, pass the EMA model as a keyword
|
| 740 |
+
# argument to the task.
|
| 741 |
+
extra_kwargs = {}
|
| 742 |
+
if self.cfg.ema.store_ema and getattr(self.task, "uses_ema", False):
|
| 743 |
+
extra_kwargs["ema_model"] = self.ema.get_model()
|
| 744 |
+
|
| 745 |
+
# forward and backward pass
|
| 746 |
+
logging_outputs, sample_size, ooms = [], 0, 0
|
| 747 |
+
for i, sample in enumerate(samples): # delayed update loop
|
| 748 |
+
sample, is_dummy_batch = self._prepare_sample(sample)
|
| 749 |
+
|
| 750 |
+
def maybe_no_sync():
|
| 751 |
+
"""
|
| 752 |
+
Whenever *samples* contains more than one mini-batch, we
|
| 753 |
+
want to accumulate gradients locally and only call
|
| 754 |
+
all-reduce in the last backwards pass.
|
| 755 |
+
"""
|
| 756 |
+
if (
|
| 757 |
+
self.data_parallel_world_size > 1
|
| 758 |
+
and hasattr(self.model, "no_sync")
|
| 759 |
+
and i < len(samples) - 1
|
| 760 |
+
# The no_sync context manager results in increased memory
|
| 761 |
+
# usage with FSDP, since full-size gradients will be
|
| 762 |
+
# accumulated on each GPU. It's typically a better tradeoff
|
| 763 |
+
# to do the extra communication with FSDP.
|
| 764 |
+
and not self.is_fsdp
|
| 765 |
+
):
|
| 766 |
+
return self.model.no_sync()
|
| 767 |
+
else:
|
| 768 |
+
return contextlib.ExitStack() # dummy contextmanager
|
| 769 |
+
|
| 770 |
+
try:
|
| 771 |
+
with maybe_no_sync():
|
| 772 |
+
# forward and backward
|
| 773 |
+
loss, sample_size_i, logging_output = self.task.train_step(
|
| 774 |
+
sample=sample,
|
| 775 |
+
model=self.model,
|
| 776 |
+
criterion=self.criterion,
|
| 777 |
+
optimizer=self.optimizer,
|
| 778 |
+
update_num=self.get_num_updates(),
|
| 779 |
+
ignore_grad=is_dummy_batch,
|
| 780 |
+
**extra_kwargs,
|
| 781 |
+
)
|
| 782 |
+
del loss
|
| 783 |
+
|
| 784 |
+
logging_outputs.append(logging_output)
|
| 785 |
+
sample_size += sample_size_i
|
| 786 |
+
|
| 787 |
+
# emptying the CUDA cache after the first step can
|
| 788 |
+
# reduce the chance of OOM
|
| 789 |
+
if self.cuda and self.get_num_updates() == 0:
|
| 790 |
+
torch.cuda.empty_cache()
|
| 791 |
+
except RuntimeError as e:
|
| 792 |
+
if "out of memory" in str(e):
|
| 793 |
+
self._log_oom(e)
|
| 794 |
+
if raise_oom:
|
| 795 |
+
raise e
|
| 796 |
+
logger.warning(
|
| 797 |
+
"attempting to recover from OOM in forward/backward pass"
|
| 798 |
+
)
|
| 799 |
+
ooms += 1
|
| 800 |
+
self.zero_grad()
|
| 801 |
+
if self.cuda:
|
| 802 |
+
torch.cuda.empty_cache()
|
| 803 |
+
if self.cfg.distributed_training.distributed_world_size == 1:
|
| 804 |
+
return None
|
| 805 |
+
else:
|
| 806 |
+
raise e
|
| 807 |
+
|
| 808 |
+
if self.tpu and i < len(samples) - 1:
|
| 809 |
+
# tpu-comment: every XLA operation before marking step is
|
| 810 |
+
# appended to the IR graph, and processing too many batches
|
| 811 |
+
# before marking step can lead to OOM errors.
|
| 812 |
+
# To handle gradient accumulation use case, we explicitly
|
| 813 |
+
# mark step here for every forward pass without a backward pass
|
| 814 |
+
self._xla_markstep_and_send_to_cpu()
|
| 815 |
+
|
| 816 |
+
if is_dummy_batch:
|
| 817 |
+
if torch.is_tensor(sample_size):
|
| 818 |
+
sample_size.zero_()
|
| 819 |
+
else:
|
| 820 |
+
sample_size *= 0.0
|
| 821 |
+
|
| 822 |
+
if torch.is_tensor(sample_size):
|
| 823 |
+
sample_size = sample_size.float()
|
| 824 |
+
else:
|
| 825 |
+
sample_size = float(sample_size)
|
| 826 |
+
|
| 827 |
+
# gather logging outputs from all replicas
|
| 828 |
+
if self._sync_stats():
|
| 829 |
+
train_time = self._local_cumulative_training_time()
|
| 830 |
+
logging_outputs, (
|
| 831 |
+
sample_size,
|
| 832 |
+
ooms,
|
| 833 |
+
total_train_time,
|
| 834 |
+
) = self._aggregate_logging_outputs(
|
| 835 |
+
logging_outputs, sample_size, ooms, train_time, ignore=is_dummy_batch
|
| 836 |
+
)
|
| 837 |
+
self._cumulative_training_time = (
|
| 838 |
+
total_train_time / self.data_parallel_world_size
|
| 839 |
+
)
|
| 840 |
+
|
| 841 |
+
overflow = False
|
| 842 |
+
try:
|
| 843 |
+
with torch.autograd.profiler.record_function("reduce-grads"):
|
| 844 |
+
# reduce gradients across workers
|
| 845 |
+
self.optimizer.all_reduce_grads(self.model)
|
| 846 |
+
if utils.has_parameters(self.criterion):
|
| 847 |
+
self.optimizer.all_reduce_grads(self.criterion)
|
| 848 |
+
|
| 849 |
+
with torch.autograd.profiler.record_function("multiply-grads"):
|
| 850 |
+
# multiply gradients by (data_parallel_size / sample_size) since
|
| 851 |
+
# DDP normalizes by the number of data parallel workers for
|
| 852 |
+
# improved fp16 precision.
|
| 853 |
+
# Thus we get (sum_of_gradients / sample_size) at the end.
|
| 854 |
+
# In case of fp16, this step also undoes loss scaling.
|
| 855 |
+
# (Debugging note: Some optimizers perform this scaling on the
|
| 856 |
+
# fly, so inspecting model.parameters() or optimizer.params may
|
| 857 |
+
# still show the original, unscaled gradients.)
|
| 858 |
+
numer = (
|
| 859 |
+
self.data_parallel_world_size
|
| 860 |
+
if not self.cfg.optimization.use_bmuf or self._sync_stats()
|
| 861 |
+
else 1
|
| 862 |
+
)
|
| 863 |
+
self.optimizer.multiply_grads(numer / (sample_size or 1.0))
|
| 864 |
+
# Note: (sample_size or 1.0) handles the case of a zero gradient, in a
|
| 865 |
+
# way that avoids CPU/device transfers in case sample_size is a GPU or
|
| 866 |
+
# TPU object. The assumption is that the gradient itself is also 0.
|
| 867 |
+
|
| 868 |
+
with torch.autograd.profiler.record_function("clip-grads"):
|
| 869 |
+
# clip grads
|
| 870 |
+
grad_norm = self.clip_grad_norm(self.cfg.optimization.clip_norm)
|
| 871 |
+
|
| 872 |
+
# check that grad norms are consistent across workers
|
| 873 |
+
# on tpu check tensor is slow
|
| 874 |
+
if not self.tpu:
|
| 875 |
+
if (
|
| 876 |
+
not self.cfg.optimization.use_bmuf
|
| 877 |
+
and self.cfg.distributed_training.ddp_backend != "slow_mo"
|
| 878 |
+
):
|
| 879 |
+
self._check_grad_norms(grad_norm)
|
| 880 |
+
if not torch.isfinite(grad_norm).all():
|
| 881 |
+
# in case of AMP, if gradients are Nan/Inf then
|
| 882 |
+
# optimizer step is still required
|
| 883 |
+
if self.cfg.common.amp:
|
| 884 |
+
overflow = True
|
| 885 |
+
else:
|
| 886 |
+
# check local gradnorm single GPU case, trigger NanDetector
|
| 887 |
+
raise FloatingPointError("gradients are Nan/Inf")
|
| 888 |
+
|
| 889 |
+
with torch.autograd.profiler.record_function("optimizer"):
|
| 890 |
+
# take an optimization step
|
| 891 |
+
self.task.optimizer_step(
|
| 892 |
+
self.optimizer, model=self.model, update_num=self.get_num_updates()
|
| 893 |
+
)
|
| 894 |
+
if self.cfg.common.amp and overflow:
|
| 895 |
+
if self._amp_retries == self.cfg.common.amp_batch_retries:
|
| 896 |
+
logger.info("AMP: skipping this batch.")
|
| 897 |
+
self._amp_retries = 0
|
| 898 |
+
else:
|
| 899 |
+
self._amp_retries += 1
|
| 900 |
+
return self.train_step(samples, raise_oom) # recursion to feed in same batch
|
| 901 |
+
|
| 902 |
+
except FloatingPointError:
|
| 903 |
+
# re-run the forward and backward pass with hooks attached to print
|
| 904 |
+
# out where it fails
|
| 905 |
+
self.zero_grad()
|
| 906 |
+
with NanDetector(self.get_model()):
|
| 907 |
+
for _, sample in enumerate(samples):
|
| 908 |
+
sample, _ = self._prepare_sample(sample)
|
| 909 |
+
self.task.train_step(
|
| 910 |
+
sample,
|
| 911 |
+
self.model,
|
| 912 |
+
self.criterion,
|
| 913 |
+
self.optimizer,
|
| 914 |
+
self.get_num_updates(),
|
| 915 |
+
ignore_grad=False,
|
| 916 |
+
**extra_kwargs,
|
| 917 |
+
)
|
| 918 |
+
raise
|
| 919 |
+
except OverflowError as e:
|
| 920 |
+
overflow = True
|
| 921 |
+
logger.info(
|
| 922 |
+
f"NOTE: gradient overflow detected, ignoring gradient, {str(e)}"
|
| 923 |
+
)
|
| 924 |
+
grad_norm = torch.tensor(0.0).cuda()
|
| 925 |
+
self.zero_grad()
|
| 926 |
+
except RuntimeError as e:
|
| 927 |
+
if "out of memory" in str(e):
|
| 928 |
+
self._log_oom(e)
|
| 929 |
+
logger.error("OOM during optimization, irrecoverable")
|
| 930 |
+
raise e
|
| 931 |
+
|
| 932 |
+
# Some distributed wrappers (e.g., SlowMo) need access to the optimizer
|
| 933 |
+
# after the step
|
| 934 |
+
if hasattr(self.model, "perform_additional_optimizer_actions"):
|
| 935 |
+
if hasattr(self.optimizer, "fp32_params"):
|
| 936 |
+
self.model.perform_additional_optimizer_actions(
|
| 937 |
+
self.optimizer.optimizer, self.optimizer.fp32_params
|
| 938 |
+
)
|
| 939 |
+
else:
|
| 940 |
+
self.model.perform_additional_optimizer_actions(
|
| 941 |
+
self.optimizer.optimizer
|
| 942 |
+
)
|
| 943 |
+
|
| 944 |
+
logging_output = None
|
| 945 |
+
if not overflow or self.cfg.distributed_training.ddp_backend == "slow_mo":
|
| 946 |
+
self.set_num_updates(self.get_num_updates() + 1)
|
| 947 |
+
|
| 948 |
+
if self.cfg.ema.store_ema:
|
| 949 |
+
# Step EMA forward with new model.
|
| 950 |
+
self.ema.step(
|
| 951 |
+
self.get_model(),
|
| 952 |
+
self.get_num_updates(),
|
| 953 |
+
)
|
| 954 |
+
metrics.log_scalar(
|
| 955 |
+
"ema_decay",
|
| 956 |
+
self.ema.get_decay(),
|
| 957 |
+
priority=10000,
|
| 958 |
+
round=5,
|
| 959 |
+
weight=0,
|
| 960 |
+
)
|
| 961 |
+
|
| 962 |
+
if self.tpu:
|
| 963 |
+
import torch_xla.core.xla_model as xm
|
| 964 |
+
|
| 965 |
+
# mark step on TPUs
|
| 966 |
+
self._xla_markstep_and_send_to_cpu()
|
| 967 |
+
|
| 968 |
+
# only log stats every log_interval steps
|
| 969 |
+
# this causes wps to be misreported when log_interval > 1
|
| 970 |
+
logging_output = {}
|
| 971 |
+
if self.get_num_updates() % self.cfg.common.log_interval == 0:
|
| 972 |
+
# log memory usage
|
| 973 |
+
mem_info = xm.get_memory_info(self.device)
|
| 974 |
+
gb_free = mem_info["kb_free"] / 1024 / 1024
|
| 975 |
+
gb_total = mem_info["kb_total"] / 1024 / 1024
|
| 976 |
+
metrics.log_scalar(
|
| 977 |
+
"gb_free", gb_free, priority=1500, round=1, weight=0
|
| 978 |
+
)
|
| 979 |
+
metrics.log_scalar(
|
| 980 |
+
"gb_total", gb_total, priority=1600, round=1, weight=0
|
| 981 |
+
)
|
| 982 |
+
logging_outputs = self._xla_markstep_and_send_to_cpu(
|
| 983 |
+
logging_outputs
|
| 984 |
+
)
|
| 985 |
+
logging_output = self._reduce_and_log_stats(
|
| 986 |
+
logging_outputs, sample_size, grad_norm
|
| 987 |
+
)
|
| 988 |
+
|
| 989 |
+
# log whenever there's an XLA compilation, since these
|
| 990 |
+
# slow down training and may indicate opportunities for
|
| 991 |
+
# optimization
|
| 992 |
+
self._check_xla_compilation()
|
| 993 |
+
else:
|
| 994 |
+
if self.cuda and self.cuda_env is not None:
|
| 995 |
+
# log minimum free memory over the iteration
|
| 996 |
+
gb_used = torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024
|
| 997 |
+
torch.cuda.reset_peak_memory_stats()
|
| 998 |
+
gb_free = self.cuda_env.total_memory_in_GB - gb_used
|
| 999 |
+
metrics.log_scalar(
|
| 1000 |
+
"gb_free", gb_free, priority=1500, round=1, weight=0
|
| 1001 |
+
)
|
| 1002 |
+
|
| 1003 |
+
# log stats
|
| 1004 |
+
logging_output = self._reduce_and_log_stats(
|
| 1005 |
+
logging_outputs, sample_size, grad_norm
|
| 1006 |
+
)
|
| 1007 |
+
|
| 1008 |
+
# clear CUDA cache to reduce memory fragmentation
|
| 1009 |
+
if (
|
| 1010 |
+
self.cuda
|
| 1011 |
+
and self.cfg.common.empty_cache_freq > 0
|
| 1012 |
+
and (
|
| 1013 |
+
(self.get_num_updates() + self.cfg.common.empty_cache_freq - 1)
|
| 1014 |
+
% self.cfg.common.empty_cache_freq
|
| 1015 |
+
)
|
| 1016 |
+
== 0
|
| 1017 |
+
):
|
| 1018 |
+
torch.cuda.empty_cache()
|
| 1019 |
+
|
| 1020 |
+
if self.cfg.common.fp16 or self.cfg.common.amp:
|
| 1021 |
+
metrics.log_scalar(
|
| 1022 |
+
"loss_scale",
|
| 1023 |
+
(
|
| 1024 |
+
self.optimizer.scaler.loss_scale
|
| 1025 |
+
if self.cfg.common.fp16
|
| 1026 |
+
else self.optimizer.scaler.get_scale()
|
| 1027 |
+
),
|
| 1028 |
+
priority=700,
|
| 1029 |
+
round=4,
|
| 1030 |
+
weight=0,
|
| 1031 |
+
)
|
| 1032 |
+
|
| 1033 |
+
metrics.log_stop_time("train_wall")
|
| 1034 |
+
return logging_output
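maybe_no_sync() above is the usual gradient-accumulation pattern for DDP-style wrappers: suppress the gradient all-reduce on every micro-batch except the last one, then take a single optimizer step. A stripped-down sketch of the same idea outside this trainer (function and argument names are illustrative):

    import contextlib

    def accumulated_step(model, optimizer, compute_loss, micro_batches):
        optimizer.zero_grad()
        for i, batch in enumerate(micro_batches):
            last = i == len(micro_batches) - 1
            # skip gradient synchronization until the final backward pass
            ctx = (
                model.no_sync()
                if hasattr(model, "no_sync") and not last
                else contextlib.nullcontext()
            )
            with ctx:
                compute_loss(model, batch).backward()
        optimizer.step()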
|
| 1035 |
+
|
| 1036 |
+
@metrics.aggregate("valid")
|
| 1037 |
+
def valid_step(self, sample, raise_oom=False):
|
| 1038 |
+
"""Do forward pass in evaluation mode."""
|
| 1039 |
+
if self.tpu:
|
| 1040 |
+
import torch_xla.core.xla_model as xm
|
| 1041 |
+
|
| 1042 |
+
xm.rendezvous("valid_step") # wait for all workers
|
| 1043 |
+
|
| 1044 |
+
# If EMA is enabled through store_ema=True
|
| 1045 |
+
# and task.uses_ema is True, pass the EMA model as a keyword
|
| 1046 |
+
# argument to the task.
|
| 1047 |
+
extra_kwargs = {}
|
| 1048 |
+
if self.cfg.ema.store_ema and getattr(self.task, "uses_ema", False):
|
| 1049 |
+
extra_kwargs["ema_model"] = self.ema.get_model()
|
| 1050 |
+
|
| 1051 |
+
with torch.no_grad():
|
| 1052 |
+
self.model.eval()
|
| 1053 |
+
self.criterion.eval()
|
| 1054 |
+
|
| 1055 |
+
sample, is_dummy_batch = self._prepare_sample(sample)
|
| 1056 |
+
|
| 1057 |
+
try:
|
| 1058 |
+
_loss, sample_size, logging_output = self.task.valid_step(
|
| 1059 |
+
sample, self.model, self.criterion, **extra_kwargs
|
| 1060 |
+
)
|
| 1061 |
+
except RuntimeError as e:
|
| 1062 |
+
if "out of memory" in str(e):
|
| 1063 |
+
self._log_oom(e)
|
| 1064 |
+
if not raise_oom:
|
| 1065 |
+
logger.warning(
|
| 1066 |
+
"ran out of memory in validation step, retrying batch"
|
| 1067 |
+
)
|
| 1068 |
+
for p in self.model.parameters():
|
| 1069 |
+
if p.grad is not None:
|
| 1070 |
+
p.grad = None # free some memory
|
| 1071 |
+
if self.cuda:
|
| 1072 |
+
torch.cuda.empty_cache()
|
| 1073 |
+
return self.valid_step(sample, raise_oom=True)
|
| 1074 |
+
raise e
|
| 1075 |
+
|
| 1076 |
+
logging_outputs = [logging_output]
|
| 1077 |
+
if is_dummy_batch:
|
| 1078 |
+
if torch.is_tensor(sample_size):
|
| 1079 |
+
sample_size.zero_()
|
| 1080 |
+
else:
|
| 1081 |
+
sample_size *= 0.0
|
| 1082 |
+
|
| 1083 |
+
# gather logging outputs from all replicas
|
| 1084 |
+
if self.data_parallel_world_size > 1:
|
| 1085 |
+
logging_outputs, (sample_size,) = self._aggregate_logging_outputs(
|
| 1086 |
+
logging_outputs,
|
| 1087 |
+
sample_size,
|
| 1088 |
+
ignore=is_dummy_batch,
|
| 1089 |
+
)
|
| 1090 |
+
|
| 1091 |
+
# log validation stats
|
| 1092 |
+
if self.tpu:
|
| 1093 |
+
logging_outputs = self._xla_markstep_and_send_to_cpu(logging_outputs)
|
| 1094 |
+
logging_output = self._reduce_and_log_stats(logging_outputs, sample_size)
|
| 1095 |
+
|
| 1096 |
+
return logging_output
|
| 1097 |
+
|
| 1098 |
+
def zero_grad(self):
|
| 1099 |
+
self.optimizer.zero_grad()
|
| 1100 |
+
|
| 1101 |
+
def lr_step_begin_epoch(self, epoch):
|
| 1102 |
+
"""Adjust the learning rate at the beginning of the epoch."""
|
| 1103 |
+
self.lr_scheduler.step_begin_epoch(epoch)
|
| 1104 |
+
# prefer updating the LR based on the number of steps
|
| 1105 |
+
return self.lr_step_update()
|
| 1106 |
+
|
| 1107 |
+
def lr_reinit(self, total_updates, num_updates):
|
| 1108 |
+
self.lr_scheduler.reinit(total_updates, num_updates)
|
| 1109 |
+
|
| 1110 |
+
def lr_step(self, epoch, val_loss=None):
|
| 1111 |
+
"""Adjust the learning rate at the end of the epoch."""
|
| 1112 |
+
self.lr_scheduler.step(epoch, val_loss)
|
| 1113 |
+
# prefer updating the LR based on the number of steps
|
| 1114 |
+
return self.lr_step_update()
|
| 1115 |
+
|
| 1116 |
+
def lr_step_update(self):
|
| 1117 |
+
"""Update the learning rate after each update."""
|
| 1118 |
+
new_lr = self.lr_scheduler.step_update(self.get_num_updates())
|
| 1119 |
+
if isinstance(new_lr, dict):
|
| 1120 |
+
for k, v in new_lr.items():
|
| 1121 |
+
metrics.log_scalar(f"lr_{k}", v, weight=0, priority=300)
|
| 1122 |
+
new_lr = new_lr.get("default", next(iter(new_lr.values())))
|
| 1123 |
+
else:
|
| 1124 |
+
metrics.log_scalar("lr", new_lr, weight=0, priority=300)
|
| 1125 |
+
return new_lr
|
| 1126 |
+
|
| 1127 |
+
def get_lr(self):
|
| 1128 |
+
"""Get the current learning rate."""
|
| 1129 |
+
return self.optimizer.get_lr()
|
| 1130 |
+
|
| 1131 |
+
def get_model(self):
|
| 1132 |
+
"""Get the (non-wrapped) model instance."""
|
| 1133 |
+
return self._model
|
| 1134 |
+
|
| 1135 |
+
def get_criterion(self):
|
| 1136 |
+
"""Get the (non-wrapped) criterion instance."""
|
| 1137 |
+
return self._criterion
|
| 1138 |
+
|
| 1139 |
+
def get_meter(self, name):
|
| 1140 |
+
"""[deprecated] Get a specific meter by name."""
|
| 1141 |
+
from fairseq import meters
|
| 1142 |
+
|
| 1143 |
+
if "get_meter" not in self._warn_once:
|
| 1144 |
+
self._warn_once.add("get_meter")
|
| 1145 |
+
utils.deprecation_warning(
|
| 1146 |
+
"Trainer.get_meter is deprecated. Please use fairseq.metrics instead."
|
| 1147 |
+
)
|
| 1148 |
+
|
| 1149 |
+
train_meters = metrics.get_meters("train")
|
| 1150 |
+
if train_meters is None:
|
| 1151 |
+
train_meters = {}
|
| 1152 |
+
|
| 1153 |
+
if name == "train_loss" and "loss" in train_meters:
|
| 1154 |
+
return train_meters["loss"]
|
| 1155 |
+
elif name == "train_nll_loss":
|
| 1156 |
+
# support for legacy train.py, which assumed this meter is
|
| 1157 |
+
# always initialized
|
| 1158 |
+
m = train_meters.get("nll_loss", None)
|
| 1159 |
+
return m or meters.AverageMeter()
|
| 1160 |
+
elif name == "wall":
|
| 1161 |
+
# support for legacy train.py, which assumed this meter is
|
| 1162 |
+
# always initialized
|
| 1163 |
+
m = metrics.get_meter("default", "wall")
|
| 1164 |
+
return m or meters.TimeMeter()
|
| 1165 |
+
elif name == "wps":
|
| 1166 |
+
m = metrics.get_meter("train", "wps")
|
| 1167 |
+
return m or meters.TimeMeter()
|
| 1168 |
+
elif name in {"valid_loss", "valid_nll_loss"}:
|
| 1169 |
+
# support for legacy train.py, which assumed these meters
|
| 1170 |
+
# are always initialized
|
| 1171 |
+
k = name[len("valid_") :]
|
| 1172 |
+
m = metrics.get_meter("valid", k)
|
| 1173 |
+
return m or meters.AverageMeter()
|
| 1174 |
+
elif name == "oom":
|
| 1175 |
+
return meters.AverageMeter()
|
| 1176 |
+
elif name in train_meters:
|
| 1177 |
+
return train_meters[name]
|
| 1178 |
+
return None
|
| 1179 |
+
|
| 1180 |
+
def get_num_updates(self):
|
| 1181 |
+
"""Get the number of parameters updates."""
|
| 1182 |
+
return self._num_updates
|
| 1183 |
+
|
| 1184 |
+
def set_num_updates(self, num_updates):
|
| 1185 |
+
"""Set the number of parameters updates."""
|
| 1186 |
+
self._num_updates = num_updates
|
| 1187 |
+
self.lr_step_update()
|
| 1188 |
+
if self.quantizer:
|
| 1189 |
+
self.quantizer.step_update(self._num_updates)
|
| 1190 |
+
metrics.log_scalar("num_updates", self._num_updates, weight=0, priority=200)
|
| 1191 |
+
|
| 1192 |
+
def clip_grad_norm(self, clip_norm):
|
| 1193 |
+
def agg_norm_fn(total_norm):
|
| 1194 |
+
total_norm = total_norm.cuda().float() ** 2
|
| 1195 |
+
total_norm = distributed_utils.all_reduce(
|
| 1196 |
+
total_norm, group=self.data_parallel_process_group
|
| 1197 |
+
)
|
| 1198 |
+
return total_norm ** 0.5
|
| 1199 |
+
|
| 1200 |
+
should_agg_norm = (
|
| 1201 |
+
self.is_fsdp
|
| 1202 |
+
and (
|
| 1203 |
+
self.data_parallel_process_group is not None
|
| 1204 |
+
or torch.distributed.is_initialized()
|
| 1205 |
+
)
|
| 1206 |
+
)
|
| 1207 |
+
return self.optimizer.clip_grad_norm(
|
| 1208 |
+
clip_norm, aggregate_norm_fn=agg_norm_fn if should_agg_norm else None
|
| 1209 |
+
)
|
| 1210 |
+
|
| 1211 |
+
def cumulative_training_time(self):
|
| 1212 |
+
if self._cumulative_training_time is None:
|
| 1213 |
+
# single GPU
|
| 1214 |
+
return self._local_cumulative_training_time()
|
| 1215 |
+
else:
|
| 1216 |
+
return self._cumulative_training_time
|
| 1217 |
+
|
| 1218 |
+
def _local_cumulative_training_time(self):
|
| 1219 |
+
"""Aggregate training time in seconds."""
|
| 1220 |
+
return time.time() - self._start_time + self._previous_training_time
|
| 1221 |
+
|
| 1222 |
+
def _fp_convert_sample(self, sample):
|
| 1223 |
+
def apply_half(t):
|
| 1224 |
+
if t.dtype is torch.float32:
|
| 1225 |
+
return t.to(dtype=torch.half)
|
| 1226 |
+
return t
|
| 1227 |
+
|
| 1228 |
+
def apply_bfloat16(t):
|
| 1229 |
+
if t.dtype is torch.float32:
|
| 1230 |
+
return t.to(dtype=torch.bfloat16)
|
| 1231 |
+
return t
|
| 1232 |
+
|
| 1233 |
+
if self.cfg.common.fp16:
|
| 1234 |
+
sample = utils.apply_to_sample(apply_half, sample)
|
| 1235 |
+
|
| 1236 |
+
if self.cfg.common.bf16:
|
| 1237 |
+
sample = utils.apply_to_sample(apply_bfloat16, sample)
|
| 1238 |
+
|
| 1239 |
+
return sample
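_fp_convert_sample delegates the traversal to fairseq's utils.apply_to_sample; conceptually it just maps a dtype cast over every tensor in the nested sample. A dependency-free sketch of that idea, covering only dict/list/tuple/tensor nesting:

    import torch

    def map_over_sample(fn, sample):
        if torch.is_tensor(sample):
            return fn(sample)
        if isinstance(sample, dict):
            return {k: map_over_sample(fn, v) for k, v in sample.items()}
        if isinstance(sample, (list, tuple)):
            return type(sample)(map_over_sample(fn, v) for v in sample)
        return sample

    def to_half(t):
        return t.to(torch.half) if t.dtype is torch.float32 else t

    sample = {"net_input": {"src_tokens": torch.zeros(2, 8)}, "id": [0, 1]}
    sample = map_over_sample(to_half, sample)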
|
| 1240 |
+
|
| 1241 |
+
def _prepare_sample(self, sample, is_dummy=False):
|
| 1242 |
+
if sample == "DUMMY":
|
| 1243 |
+
raise Exception(
|
| 1244 |
+
"Trying to use an uninitialized 'dummy' batch. This usually indicates "
|
| 1245 |
+
"that the total number of batches is smaller than the number of "
|
| 1246 |
+
"participating GPUs. Try reducing the batch size or using fewer GPUs."
|
| 1247 |
+
)
|
| 1248 |
+
|
| 1249 |
+
if sample is None or len(sample) == 0:
|
| 1250 |
+
assert (
|
| 1251 |
+
self._dummy_batch is not None and len(self._dummy_batch) > 0
|
| 1252 |
+
), "Invalid dummy batch: {}".format(self._dummy_batch)
|
| 1253 |
+
sample, _ = self._prepare_sample(self._dummy_batch, is_dummy=True)
|
| 1254 |
+
return sample, True
|
| 1255 |
+
|
| 1256 |
+
# Given that PCIe/NVLink bandwidth is significantly smaller than DRAM bandwidth
|
| 1257 |
+
# it makes sense to do the format conversion on the CPU and then transfer
|
| 1258 |
+
# a smaller buffer to the device. This also saves GPU memory capacity.
|
| 1259 |
+
|
| 1260 |
+
if self.cfg.common.on_cpu_convert_precision:
|
| 1261 |
+
sample = self._fp_convert_sample(sample)
|
| 1262 |
+
|
| 1263 |
+
if self.cuda:
|
| 1264 |
+
if self.pipeline_model_parallel:
|
| 1265 |
+
if 'target' in sample:
|
| 1266 |
+
sample['target'] = utils.move_to_cuda(sample['target'], device=self.last_device)
|
| 1267 |
+
else:
|
| 1268 |
+
sample = utils.move_to_cuda(sample)
|
| 1269 |
+
elif self.tpu and is_dummy:
|
| 1270 |
+
# the dummy batch may not be on the appropriate device
|
| 1271 |
+
sample = utils.move_to_cuda(sample, device=self.device)
|
| 1272 |
+
|
| 1273 |
+
if not self.cfg.common.on_cpu_convert_precision:
|
| 1274 |
+
sample = self._fp_convert_sample(sample)
|
| 1275 |
+
|
| 1276 |
+
if self._dummy_batch == "DUMMY":
|
| 1277 |
+
self._dummy_batch = sample
|
| 1278 |
+
|
| 1279 |
+
return sample, False
|
| 1280 |
+
|
| 1281 |
+
def _set_seed(self):
|
| 1282 |
+
# Set seed based on args.seed and the update number so that we get
|
| 1283 |
+
# reproducible results when resuming from checkpoints
|
| 1284 |
+
seed = self.cfg.common.seed + self.get_num_updates()
|
| 1285 |
+
utils.set_torch_seed(seed)
|
| 1286 |
+
|
| 1287 |
+
def _sync_stats(self):
|
| 1288 |
+
# Return True if it's using multiple GPUs and DDP or multiple GPUs with
|
| 1289 |
+
# BMUF and it's a bmuf sync with warmup iterations completed before.
|
| 1290 |
+
if self.data_parallel_world_size == 1:
|
| 1291 |
+
return False
|
| 1292 |
+
elif self.cfg.optimization.use_bmuf:
|
| 1293 |
+
return (
|
| 1294 |
+
self.get_num_updates() + 1
|
| 1295 |
+
) % self.cfg.bmuf.global_sync_iter == 0 and (
|
| 1296 |
+
self.get_num_updates() + 1
|
| 1297 |
+
) > self.cfg.bmuf.warmup_iterations
|
| 1298 |
+
else:
|
| 1299 |
+
return True
|
| 1300 |
+
|
| 1301 |
+
def _log_oom(self, exc):
|
| 1302 |
+
msg = "OOM: Ran out of memory with exception: {}".format(exc)
|
| 1303 |
+
logger.warning(msg)
|
| 1304 |
+
if torch.cuda.is_available() and hasattr(torch.cuda, "memory_summary"):
|
| 1305 |
+
for device_idx in range(torch.cuda.device_count()):
|
| 1306 |
+
logger.warning(torch.cuda.memory_summary(device=device_idx))
|
| 1307 |
+
sys.stderr.flush()
|
| 1308 |
+
|
| 1309 |
+
def _aggregate_logging_outputs(
|
| 1310 |
+
self,
|
| 1311 |
+
logging_outputs: List[Dict[str, Any]],
|
| 1312 |
+
*extra_stats_to_sum,
|
| 1313 |
+
ignore=False,
|
| 1314 |
+
):
|
| 1315 |
+
if self.task.__class__.logging_outputs_can_be_summed(self.get_criterion()):
|
| 1316 |
+
return self._fast_stat_sync_sum(
|
| 1317 |
+
logging_outputs, *extra_stats_to_sum, ignore=ignore
|
| 1318 |
+
)
|
| 1319 |
+
else:
|
| 1320 |
+
return self._all_gather_list_sync(
|
| 1321 |
+
logging_outputs, *extra_stats_to_sum, ignore=ignore
|
| 1322 |
+
)
|
| 1323 |
+
|
| 1324 |
+
def _all_gather_list_sync(
|
| 1325 |
+
self,
|
| 1326 |
+
logging_outputs: List[Dict[str, Any]],
|
| 1327 |
+
*extra_stats_to_sum,
|
| 1328 |
+
ignore=False,
|
| 1329 |
+
):
|
| 1330 |
+
"""
|
| 1331 |
+
Sync logging outputs across workers. all_gather_list_sync is
|
| 1332 |
+
suitable when logging outputs are complex types.
|
| 1333 |
+
"""
|
| 1334 |
+
if self.tpu:
|
| 1335 |
+
raise NotImplementedError
|
| 1336 |
+
if ignore:
|
| 1337 |
+
logging_outputs = []
|
| 1338 |
+
results = list(
|
| 1339 |
+
zip(
|
| 1340 |
+
*distributed_utils.all_gather_list(
|
| 1341 |
+
[logging_outputs] + list(extra_stats_to_sum),
|
| 1342 |
+
max_size=getattr(self.cfg.common, "all_gather_list_size", 16384),
|
| 1343 |
+
group=self.data_parallel_process_group,
|
| 1344 |
+
)
|
| 1345 |
+
)
|
| 1346 |
+
)
|
| 1347 |
+
logging_outputs, extra_stats_to_sum = results[0], results[1:]
|
| 1348 |
+
logging_outputs = list(chain.from_iterable(logging_outputs))
|
| 1349 |
+
extra_stats_to_sum = [sum(s) for s in extra_stats_to_sum]
|
| 1350 |
+
return logging_outputs, extra_stats_to_sum
|
| 1351 |
+
|
| 1352 |
+
def _fast_stat_sync_sum(
|
| 1353 |
+
self,
|
| 1354 |
+
logging_outputs: List[Dict[str, Any]],
|
| 1355 |
+
*extra_stats_to_sum,
|
| 1356 |
+
ignore=False,
|
| 1357 |
+
):
|
| 1358 |
+
"""
|
| 1359 |
+
Sync logging outputs across workers. fast_stat_sync_sum is
|
| 1360 |
+
faster than all_gather_list_sync, but is only suitable when
|
| 1361 |
+
logging outputs are scalars and can be summed. Note that
|
| 1362 |
+
*logging_outputs* cannot contain any nested dicts/lists.
|
| 1363 |
+
"""
|
| 1364 |
+
data = {}
|
| 1365 |
+
for i, stat in enumerate(extra_stats_to_sum):
|
| 1366 |
+
data["extra_stats_" + str(i)] = stat
|
| 1367 |
+
if len(logging_outputs) > 0:
|
| 1368 |
+
log_keys = list(logging_outputs[0].keys())
|
| 1369 |
+
for k in log_keys:
|
| 1370 |
+
if not ignore:
|
| 1371 |
+
v = sum(log[k] for log in logging_outputs if k in log)
|
| 1372 |
+
else:
|
| 1373 |
+
v = logging_outputs[0][k]
|
| 1374 |
+
v = torch.zeros_like(v) if torch.is_tensor(v) else 0
|
| 1375 |
+
data["logging_outputs_" + k] = v
|
| 1376 |
+
else:
|
| 1377 |
+
log_keys = None
|
| 1378 |
+
|
| 1379 |
+
data = distributed_utils.all_reduce_dict(
|
| 1380 |
+
data, device=self.device, group=self.data_parallel_process_group
|
| 1381 |
+
)
|
| 1382 |
+
|
| 1383 |
+
extra_stats_to_sum = [
|
| 1384 |
+
data["extra_stats_" + str(i)] for i in range(len(extra_stats_to_sum))
|
| 1385 |
+
]
|
| 1386 |
+
if log_keys is not None:
|
| 1387 |
+
logging_outputs = [{k: data["logging_outputs_" + k] for k in log_keys}]
|
| 1388 |
+
else:
|
| 1389 |
+
logging_outputs = []
|
| 1390 |
+
return logging_outputs, extra_stats_to_sum
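Before the single all_reduce_dict call, _fast_stat_sync_sum flattens the per-batch logging dicts plus the extra stats into one flat dict of summable values. A toy, single-process illustration of that flattening:

    logging_outputs = [{"loss": 2.0, "ntokens": 7}, {"loss": 1.5, "ntokens": 5}]
    extra_stats = (12,)  # e.g. sample_size

    data = {f"extra_stats_{i}": s for i, s in enumerate(extra_stats)}
    for k in logging_outputs[0]:
        data[f"logging_outputs_{k}"] = sum(log[k] for log in logging_outputs if k in log)
    # data == {"extra_stats_0": 12, "logging_outputs_loss": 3.5,
    #          "logging_outputs_ntokens": 12}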
|
| 1391 |
+
|
| 1392 |
+
def _check_grad_norms(self, grad_norm):
|
| 1393 |
+
"""Check that grad norms are consistent across workers."""
|
| 1394 |
+
if self._grad_norm_buf is not None:
|
| 1395 |
+
self._grad_norm_buf.zero_()
|
| 1396 |
+
self._grad_norm_buf[self.data_parallel_rank] = grad_norm
|
| 1397 |
+
distributed_utils.all_reduce(
|
| 1398 |
+
self._grad_norm_buf, group=self.data_parallel_process_group
|
| 1399 |
+
)
|
| 1400 |
+
|
| 1401 |
+
def is_consistent(tensor):
|
| 1402 |
+
max_abs_diff = torch.max(torch.abs(tensor - tensor[0]))
|
| 1403 |
+
return (
|
| 1404 |
+
(torch.isfinite(tensor).all()
|
| 1405 |
+
and (max_abs_diff / (tensor[0] + 1e-6) < 1e-6).all())
|
| 1406 |
+
or
|
| 1407 |
+
(self.cfg.common.amp and not torch.isfinite(tensor).all())
|
| 1408 |
+
# in case of amp non-finite grads are fine
|
| 1409 |
+
)
|
| 1410 |
+
|
| 1411 |
+
if not is_consistent(self._grad_norm_buf):
|
| 1412 |
+
pretty_detail = "\n".join(
|
| 1413 |
+
"rank {:3d} = {:.8f}".format(r, n)
|
| 1414 |
+
for r, n in enumerate(self._grad_norm_buf.tolist())
|
| 1415 |
+
)
|
| 1416 |
+
error_detail = "grad_norm across the workers:\n{}\n".format(
|
| 1417 |
+
pretty_detail
|
| 1418 |
+
)
|
| 1419 |
+
# use FloatingPointError to trigger NanDetector
|
| 1420 |
+
raise FloatingPointError(
|
| 1421 |
+
"Fatal error: gradients are inconsistent between workers. "
|
| 1422 |
+
"Try --ddp-backend=legacy_ddp. "
|
| 1423 |
+
"Or are you mixing up different generation of GPUs in training?"
|
| 1424 |
+
+ "\n"
|
| 1425 |
+
+ "-" * 80
|
| 1426 |
+
+ "\n{}\n".format(error_detail)
|
| 1427 |
+
+ "-" * 80
|
| 1428 |
+
)
|
| 1429 |
+
|
| 1430 |
+
def _reduce_and_log_stats(self, logging_outputs, sample_size, grad_norm=None):
|
| 1431 |
+
if grad_norm is not None and (
|
| 1432 |
+
not torch.is_tensor(grad_norm) or torch.isfinite(grad_norm)
|
| 1433 |
+
):
|
| 1434 |
+
metrics.log_speed("ups", 1.0, priority=100, round=2)
|
| 1435 |
+
metrics.log_scalar("gnorm", grad_norm, priority=400, round=3)
|
| 1436 |
+
if self.cfg.optimization.clip_norm > 0:
|
| 1437 |
+
metrics.log_scalar(
|
| 1438 |
+
"clip",
|
| 1439 |
+
torch.where(
|
| 1440 |
+
grad_norm > self.cfg.optimization.clip_norm,
|
| 1441 |
+
grad_norm.new_tensor(100),
|
| 1442 |
+
grad_norm.new_tensor(0),
|
| 1443 |
+
),
|
| 1444 |
+
priority=500,
|
| 1445 |
+
round=1,
|
| 1446 |
+
)
|
| 1447 |
+
|
| 1448 |
+
with metrics.aggregate() as agg:
|
| 1449 |
+
if logging_outputs is not None:
|
| 1450 |
+
self.task.reduce_metrics(logging_outputs, self.get_criterion())
|
| 1451 |
+
del logging_outputs
|
| 1452 |
+
|
| 1453 |
+
# extra warning for criterions that don't properly log a loss value
|
| 1454 |
+
if "loss" not in agg:
|
| 1455 |
+
if "loss" not in self._warn_once:
|
| 1456 |
+
self._warn_once.add("loss")
|
| 1457 |
+
logger.warning(
|
| 1458 |
+
"Criterion.reduce_metrics did not log a 'loss' value, "
|
| 1459 |
+
"which may break some functionality"
|
| 1460 |
+
)
|
| 1461 |
+
metrics.log_scalar("loss", -1)
|
| 1462 |
+
|
| 1463 |
+
# support legacy interface
|
| 1464 |
+
if self.tpu:
|
| 1465 |
+
logging_output = {}
|
| 1466 |
+
else:
|
| 1467 |
+
logging_output = agg.get_smoothed_values()
|
| 1468 |
+
logging_output["sample_size"] = sample_size
|
| 1469 |
+
for key_to_delete in ["ppl", "wps", "wpb", "bsz"]:
|
| 1470 |
+
if key_to_delete in logging_output:
|
| 1471 |
+
del logging_output[key_to_delete]
|
| 1472 |
+
return logging_output
|
| 1473 |
+
|
| 1474 |
+
def _check_xla_compilation(self):
|
| 1475 |
+
import torch_xla.debug.metrics as met
|
| 1476 |
+
|
| 1477 |
+
compile_stats = met.metric_data("CompileTime")
|
| 1478 |
+
if compile_stats is None:
|
| 1479 |
+
return
|
| 1480 |
+
num_xla_compiles = compile_stats[0]
|
| 1481 |
+
if num_xla_compiles > self._num_xla_compiles:
|
| 1482 |
+
logger.warning(
|
| 1483 |
+
"XLA compilation detected on device #{}; too many of these can lead "
|
| 1484 |
+
"to slow training, but we expect a few in the beginning".format(
|
| 1485 |
+
self.cfg.distributed_training.distributed_rank
|
| 1486 |
+
)
|
| 1487 |
+
)
|
| 1488 |
+
self._num_xla_compiles = num_xla_compiles
|
| 1489 |
+
|
| 1490 |
+
def _xla_markstep_and_send_to_cpu(self, data=None):
|
| 1491 |
+
import torch_xla.core.xla_model as xm
|
| 1492 |
+
|
| 1493 |
+
xm.mark_step()
|
| 1494 |
+
if data is not None:
|
| 1495 |
+
from fairseq.utils import xla_device_to_cpu
|
| 1496 |
+
|
| 1497 |
+
return xla_device_to_cpu(data)
|
| 1498 |
+
|
| 1499 |
+
|
| 1500 |
+
def _catalog_shared_params(module, memo=None, prefix=""):
    if memo is None:
        first_call = True
        memo = {}
    else:
        first_call = False
    for name, param in module._parameters.items():
        param_prefix = prefix + ("." if prefix else "") + name
        if param not in memo:
            memo[param] = []
        memo[param].append(param_prefix)
    for name, m in module._modules.items():
        if m is None:
            continue
        submodule_prefix = prefix + ("." if prefix else "") + name
        _catalog_shared_params(m, memo, submodule_prefix)
    if first_call:
        return [x for x in memo.values() if len(x) > 1]


def _get_module_by_path(module, path):
    path = path.split(".")
    for name in path:
        module = getattr(module, name)
    return module


def _set_module_by_path(module, path, value):
    path = path.split(".")
    for name in path[:-1]:
        module = getattr(module, name)
    setattr(module, path[-1], value)
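_catalog_shared_params returns, for every parameter reachable through more than one attribute path, the list of those paths, which the _get_module_by_path / _set_module_by_path helpers can then use to restore parameter sharing. A self-contained example with a purely illustrative module:

    import torch.nn as nn

    class TiedLM(nn.Module):
        def __init__(self):
            super().__init__()
            self.embed = nn.Embedding(10, 4)
            self.out = nn.Linear(4, 10, bias=False)
            self.out.weight = self.embed.weight  # weight tying

    print(_catalog_shared_params(TiedLM()))
    # [['embed.weight', 'out.weight']]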
|
utils/BPE/__init__.py ADDED (file without changes)
utils/BPE/dict.txt ADDED (the diff for this file is too large to render; see raw diff)
utils/BPE/encoder.json ADDED (the diff for this file is too large to render; see raw diff)
utils/BPE/vocab.bpe ADDED (the diff for this file is too large to render; see raw diff)
utils/__init__.py ADDED (file without changes)
utils/checkpoint_utils.py ADDED
@@ -0,0 +1,875 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import ast
import collections
import contextlib
import logging
import numpy as np
import os
import re
import time
import traceback
import math
from collections import OrderedDict
from typing import Any, Dict, Optional, Union

import torch
from fairseq.dataclass.configs import CheckpointConfig
from fairseq.dataclass.utils import (
    convert_namespace_to_omegaconf,
    overwrite_args_by_name,
)
from fairseq.distributed.fully_sharded_data_parallel import FSDP, has_FSDP
from fairseq.file_io import PathManager
from fairseq.models import FairseqDecoder, FairseqEncoder
from omegaconf import DictConfig, open_dict, OmegaConf

from data import data_utils

logger = logging.getLogger(__name__)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def save_checkpoint(cfg: CheckpointConfig, trainer, epoch_itr, val_loss):
|
| 36 |
+
from fairseq import meters
|
| 37 |
+
|
| 38 |
+
# only one worker should attempt to create the required dir
|
| 39 |
+
if trainer.data_parallel_rank == 0:
|
| 40 |
+
os.makedirs(cfg.save_dir, exist_ok=True)
|
| 41 |
+
|
| 42 |
+
prev_best = getattr(save_checkpoint, "best", val_loss)
|
| 43 |
+
if val_loss is not None:
|
| 44 |
+
best_function = max if cfg.maximize_best_checkpoint_metric else min
|
| 45 |
+
save_checkpoint.best = best_function(val_loss, prev_best)
|
| 46 |
+
|
| 47 |
+
if cfg.no_save:
|
| 48 |
+
return
|
| 49 |
+
|
| 50 |
+
trainer.consolidate_optimizer() # TODO(SS): do we need this if no_save_optimizer_state
|
| 51 |
+
|
| 52 |
+
if not trainer.should_save_checkpoint_on_current_rank:
|
| 53 |
+
if trainer.always_call_state_dict_during_save_checkpoint:
|
| 54 |
+
trainer.state_dict()
|
| 55 |
+
return
|
| 56 |
+
|
| 57 |
+
write_timer = meters.StopwatchMeter()
|
| 58 |
+
write_timer.start()
|
| 59 |
+
|
| 60 |
+
epoch = epoch_itr.epoch
|
| 61 |
+
end_of_epoch = epoch_itr.end_of_epoch()
|
| 62 |
+
updates = trainer.get_num_updates()
|
| 63 |
+
|
| 64 |
+
logger.info(f"Preparing to save checkpoint for epoch {epoch} @ {updates} updates")
|
| 65 |
+
|
| 66 |
+
def is_better(a, b):
|
| 67 |
+
return a >= b if cfg.maximize_best_checkpoint_metric else a <= b
|
| 68 |
+
|
| 69 |
+
suffix = trainer.checkpoint_suffix
|
| 70 |
+
checkpoint_conds = collections.OrderedDict()
|
| 71 |
+
checkpoint_conds["checkpoint{}{}.pt".format(epoch, suffix)] = (
|
| 72 |
+
end_of_epoch and not cfg.no_epoch_checkpoints and epoch % cfg.save_interval == 0
|
| 73 |
+
)
|
| 74 |
+
checkpoint_conds["checkpoint_{}_{}{}.pt".format(epoch, updates, suffix)] = (
|
| 75 |
+
not end_of_epoch
|
| 76 |
+
and cfg.save_interval_updates > 0
|
| 77 |
+
and updates % cfg.save_interval_updates == 0
|
| 78 |
+
)
|
| 79 |
+
checkpoint_conds["checkpoint_best{}.pt".format(suffix)] = val_loss is not None and (
|
| 80 |
+
not hasattr(save_checkpoint, "best")
|
| 81 |
+
or is_better(val_loss, save_checkpoint.best)
|
| 82 |
+
)
|
| 83 |
+
if val_loss is not None and cfg.keep_best_checkpoints > 0:
|
| 84 |
+
worst_best = getattr(save_checkpoint, "best", None)
|
| 85 |
+
chkpts = checkpoint_paths(
|
| 86 |
+
cfg.save_dir,
|
| 87 |
+
pattern=r"checkpoint\.best_{}_(\d+\.?\d*){}\.pt".format(
|
| 88 |
+
cfg.best_checkpoint_metric, suffix
|
| 89 |
+
),
|
| 90 |
+
)
|
| 91 |
+
if len(chkpts) > 0:
|
| 92 |
+
p = chkpts[-1] if cfg.maximize_best_checkpoint_metric else chkpts[0]
|
| 93 |
+
worst_best = float(p.rsplit("_")[-1].replace("{}.pt".format(suffix), ""))
|
| 94 |
+
# add random digits to resolve ties
|
| 95 |
+
with data_utils.numpy_seed(epoch, updates, val_loss):
|
| 96 |
+
rand_sfx = np.random.randint(0, cfg.keep_best_checkpoints)
|
| 97 |
+
|
| 98 |
+
checkpoint_conds[
|
| 99 |
+
"checkpoint.best_{}_{:.3f}{}{}.pt".format(
|
| 100 |
+
cfg.best_checkpoint_metric,
|
| 101 |
+
val_loss,
|
| 102 |
+
rand_sfx,
|
| 103 |
+
suffix
|
| 104 |
+
)
|
| 105 |
+
] = worst_best is None or is_better(val_loss, worst_best)
|
| 106 |
+
checkpoint_conds[
|
| 107 |
+
"checkpoint_last{}.pt".format(suffix)
|
| 108 |
+
] = not cfg.no_last_checkpoints
|
| 109 |
+
|
| 110 |
+
extra_state = {"train_iterator": epoch_itr.state_dict(), "val_loss": val_loss}
|
| 111 |
+
if hasattr(save_checkpoint, "best"):
|
| 112 |
+
extra_state.update({"best": save_checkpoint.best})
|
| 113 |
+
|
| 114 |
+
checkpoints = [
|
| 115 |
+
os.path.join(cfg.save_dir, fn) for fn, cond in checkpoint_conds.items() if cond
|
| 116 |
+
]
|
| 117 |
+
if len(checkpoints) > 0:
|
| 118 |
+
trainer.save_checkpoint(checkpoints[0], extra_state)
|
| 119 |
+
for cp in checkpoints[1:]:
|
| 120 |
+
if cfg.write_checkpoints_asynchronously:
|
| 121 |
+
# TODO[ioPath]: Need to implement a delayed asynchronous
|
| 122 |
+
# file copying/moving feature.
|
| 123 |
+
logger.warning(
|
| 124 |
+
f"ioPath is not copying {checkpoints[0]} to {cp} "
|
| 125 |
+
"since async write mode is on."
|
| 126 |
+
)
|
| 127 |
+
else:
|
| 128 |
+
assert PathManager.copy(
|
| 129 |
+
checkpoints[0], cp, overwrite=True
|
| 130 |
+
), f"Failed to copy {checkpoints[0]} to {cp}"
|
| 131 |
+
|
| 132 |
+
write_timer.stop()
|
| 133 |
+
logger.info(
|
| 134 |
+
"Saved checkpoint {} (epoch {} @ {} updates, score {}) (writing took {} seconds)".format(
|
| 135 |
+
checkpoints[0], epoch, updates, val_loss, write_timer.sum
|
| 136 |
+
)
|
| 137 |
+
)
|
| 138 |
+
|
| 139 |
+
if not end_of_epoch and cfg.keep_interval_updates > 0:
|
| 140 |
+
# remove old checkpoints; checkpoints are sorted in descending order
|
| 141 |
+
if cfg.keep_interval_updates_pattern == -1:
|
| 142 |
+
checkpoints = checkpoint_paths(
|
| 143 |
+
cfg.save_dir, pattern=r"checkpoint_\d+_(\d+){}\.pt".format(suffix)
|
| 144 |
+
)
|
| 145 |
+
else:
|
| 146 |
+
checkpoints = checkpoint_paths(
|
| 147 |
+
cfg.save_dir,
|
| 148 |
+
pattern=r"checkpoint_\d+_(\d+){}\.pt".format(suffix),
|
| 149 |
+
keep_match=True,
|
| 150 |
+
)
|
| 151 |
+
checkpoints = [
|
| 152 |
+
x[0]
|
| 153 |
+
for x in checkpoints
|
| 154 |
+
if x[1] % cfg.keep_interval_updates_pattern != 0
|
| 155 |
+
]
|
| 156 |
+
|
| 157 |
+
for old_chk in checkpoints[cfg.keep_interval_updates :]:
|
| 158 |
+
if os.path.lexists(old_chk):
|
| 159 |
+
os.remove(old_chk)
|
| 160 |
+
elif PathManager.exists(old_chk):
|
| 161 |
+
PathManager.rm(old_chk)
|
| 162 |
+
|
| 163 |
+
if cfg.keep_last_epochs > 0:
|
| 164 |
+
# remove old epoch checkpoints; checkpoints are sorted in descending order
|
| 165 |
+
checkpoints = checkpoint_paths(
|
| 166 |
+
cfg.save_dir, pattern=r"checkpoint(\d+){}\.pt".format(suffix)
|
| 167 |
+
)
|
| 168 |
+
for old_chk in checkpoints[cfg.keep_last_epochs :]:
|
| 169 |
+
if os.path.lexists(old_chk):
|
| 170 |
+
os.remove(old_chk)
|
| 171 |
+
elif PathManager.exists(old_chk):
|
| 172 |
+
PathManager.rm(old_chk)
|
| 173 |
+
|
| 174 |
+
if cfg.keep_best_checkpoints > 0:
|
| 175 |
+
# only keep the best N checkpoints according to validation metric
|
| 176 |
+
checkpoints = checkpoint_paths(
|
| 177 |
+
cfg.save_dir,
|
| 178 |
+
pattern=r"checkpoint\.best_{}_(\d+\.?\d*){}\.pt".format(
|
| 179 |
+
cfg.best_checkpoint_metric, suffix
|
| 180 |
+
),
|
| 181 |
+
)
|
| 182 |
+
if not cfg.maximize_best_checkpoint_metric:
|
| 183 |
+
checkpoints = checkpoints[::-1]
|
| 184 |
+
for old_chk in checkpoints[cfg.keep_best_checkpoints :]:
|
| 185 |
+
if os.path.lexists(old_chk):
|
| 186 |
+
os.remove(old_chk)
|
| 187 |
+
elif PathManager.exists(old_chk):
|
| 188 |
+
PathManager.rm(old_chk)
|
| 189 |
+
|
| 190 |
+
|
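The best-checkpoint bookkeeping above works purely through filenames: the validation metric value is embedded in checkpoint.best_<metric>_<value><rand><suffix>.pt, so the pruning step later in this function can recover and sort scores with a single regex. A minimal, self-contained sketch of that pruning logic under assumed filenames (the directory listing is hypothetical and the random tie-break digit is omitted for brevity):

import re

# Hypothetical save_dir contents following the naming scheme above.
files = [
    "checkpoint.best_loss_2.480.pt",
    "checkpoint.best_loss_2.150.pt",
    "checkpoint.best_loss_2.310.pt",
    "checkpoint_last.pt",
]

pattern = re.compile(r"checkpoint\.best_loss_(\d+\.?\d*)\.pt")
# checkpoint_paths() sorts by the captured metric value in descending order.
entries = sorted(
    ((float(m.group(1)), f) for f in files if (m := pattern.fullmatch(f))),
    reverse=True,
)

keep_best_checkpoints = 2
maximize_metric = False          # loss: smaller is better, so reverse the order
ordered = entries if maximize_metric else entries[::-1]
stale = [f for _, f in ordered[keep_best_checkpoints:]]
print(stale)                     # ['checkpoint.best_loss_2.480.pt']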
| 191 |
+
def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args):
|
| 192 |
+
"""
|
| 193 |
+
Load a checkpoint and restore the training iterator.
|
| 194 |
+
|
| 195 |
+
*passthrough_args* will be passed through to
|
| 196 |
+
``trainer.get_train_iterator``.
|
| 197 |
+
"""
|
| 198 |
+
|
| 199 |
+
reset_optimizer = cfg.reset_optimizer
|
| 200 |
+
reset_lr_scheduler = cfg.reset_lr_scheduler
|
| 201 |
+
optimizer_overrides = ast.literal_eval(cfg.optimizer_overrides)
|
| 202 |
+
reset_meters = cfg.reset_meters
|
| 203 |
+
reset_dataloader = cfg.reset_dataloader
|
| 204 |
+
|
| 205 |
+
if cfg.finetune_from_model is not None and (
|
| 206 |
+
reset_optimizer or reset_lr_scheduler or reset_meters or reset_dataloader
|
| 207 |
+
):
|
| 208 |
+
raise ValueError(
|
| 209 |
+
"--finetune-from-model can not be set together with either --reset-optimizer"
|
| 210 |
+
" or reset_lr_scheduler or reset_meters or reset_dataloader"
|
| 211 |
+
)
|
| 212 |
+
|
| 213 |
+
suffix = trainer.checkpoint_suffix
|
| 214 |
+
if (
|
| 215 |
+
cfg.restore_file == "checkpoint_last.pt"
|
| 216 |
+
): # default value of restore_file is 'checkpoint_last.pt'
|
| 217 |
+
checkpoint_path = os.path.join(
|
| 218 |
+
cfg.save_dir, "checkpoint_last{}.pt".format(suffix)
|
| 219 |
+
)
|
| 220 |
+
first_launch = not PathManager.exists(checkpoint_path)
|
| 221 |
+
if cfg.finetune_from_model is not None and first_launch:
|
| 222 |
+
# if there is no last checkpoint to restore, start the finetune from pretrained model
|
| 223 |
+
# else just use usual logic to load checkpoint, e.g. restart from last checkpoint and etc.
|
| 224 |
+
if PathManager.exists(cfg.finetune_from_model):
|
| 225 |
+
checkpoint_path = cfg.finetune_from_model
|
| 226 |
+
reset_optimizer = True
|
| 227 |
+
reset_lr_scheduler = True
|
| 228 |
+
reset_meters = True
|
| 229 |
+
reset_dataloader = True
|
| 230 |
+
logger.info(
|
| 231 |
+
f"loading pretrained model from {checkpoint_path}: "
|
| 232 |
+
"optimizer, lr scheduler, meters, dataloader will be reset"
|
| 233 |
+
)
|
| 234 |
+
else:
|
| 235 |
+
raise ValueError(
|
| 236 |
+
f"--funetune-from-model {cfg.finetune_from_model} does not exist"
|
| 237 |
+
)
|
| 238 |
+
elif suffix is not None:
|
| 239 |
+
checkpoint_path = cfg.restore_file.replace(".pt", suffix + ".pt")
|
| 240 |
+
else:
|
| 241 |
+
checkpoint_path = cfg.restore_file
|
| 242 |
+
|
| 243 |
+
if cfg.restore_file != "checkpoint_last.pt" and cfg.finetune_from_model:
|
| 244 |
+
raise ValueError(
|
| 245 |
+
"--finetune-from-model and --restore-file (non-default value) "
|
| 246 |
+
"can not be specified together: " + str(cfg)
|
| 247 |
+
)
|
| 248 |
+
|
| 249 |
+
extra_state = trainer.load_checkpoint(
|
| 250 |
+
checkpoint_path,
|
| 251 |
+
reset_optimizer,
|
| 252 |
+
reset_lr_scheduler,
|
| 253 |
+
optimizer_overrides,
|
| 254 |
+
reset_meters=reset_meters,
|
| 255 |
+
)
|
| 256 |
+
|
| 257 |
+
if (
|
| 258 |
+
extra_state is not None
|
| 259 |
+
and "best" in extra_state
|
| 260 |
+
and not reset_optimizer
|
| 261 |
+
and not reset_meters
|
| 262 |
+
):
|
| 263 |
+
save_checkpoint.best = extra_state["best"]
|
| 264 |
+
|
| 265 |
+
if extra_state is not None and not reset_dataloader:
|
| 266 |
+
# restore iterator from checkpoint
|
| 267 |
+
itr_state = extra_state["train_iterator"]
|
| 268 |
+
epoch_itr = trainer.get_train_iterator(
|
| 269 |
+
epoch=itr_state["epoch"], load_dataset=True, **passthrough_args
|
| 270 |
+
)
|
| 271 |
+
epoch_itr.load_state_dict(itr_state)
|
| 272 |
+
_n = itr_state['iterations_in_epoch']
|
| 273 |
+
offset = sum(len(_) for _ in epoch_itr.batch_sampler[:_n])
|
| 274 |
+
epoch_itr.dataset.dataset._seek(offset=offset)
|
| 275 |
+
true_num = int(math.ceil(len(epoch_itr.dataset) / 8)) * 8
|
| 276 |
+
another_offset = ((epoch_itr.epoch - 1) * true_num + offset) // 8
|
| 277 |
+
if hasattr(epoch_itr.dataset, 'pure_text_dataset'):
|
| 278 |
+
text_offset = (2 * another_offset) % len(epoch_itr.dataset.pure_text_dataset)
|
| 279 |
+
epoch_itr.dataset.pure_text_dataset._seek(offset=text_offset)
|
| 280 |
+
if hasattr(epoch_itr.dataset, 'pure_image_dataset'):
|
| 281 |
+
image_offset = another_offset % len(epoch_itr.dataset.pure_image_dataset)
|
| 282 |
+
epoch_itr.dataset.pure_image_dataset._seek(offset=image_offset)
|
| 283 |
+
if hasattr(epoch_itr.dataset, 'detection_dataset'):
|
| 284 |
+
detection_offset = another_offset % len(epoch_itr.dataset.detection_dataset)
|
| 285 |
+
epoch_itr.dataset.detection_dataset._seek(offset=detection_offset)
|
| 286 |
+
else:
|
| 287 |
+
epoch_itr = trainer.get_train_iterator(
|
| 288 |
+
epoch=1, load_dataset=True, **passthrough_args
|
| 289 |
+
)
|
| 290 |
+
|
| 291 |
+
trainer.lr_step(epoch_itr.epoch)
|
| 292 |
+
|
| 293 |
+
return extra_state, epoch_itr
|
| 294 |
+
|
| 295 |
+
|
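The seek arithmetic in the resume branch above decides how far each backing file dataset is advanced when training restarts mid-epoch. A worked example with assumed numbers (all values hypothetical; the constant 8 mirrors the / 8 in the code):

import math

# Hypothetical resume state.
dataset_len = 1000    # len(epoch_itr.dataset)
offset = 960          # samples already consumed in the interrupted epoch
epoch = 3             # epoch_itr.epoch

# Round the dataset length up to a multiple of 8, then convert the global
# sample position into a per-shard offset, exactly as in load_checkpoint().
true_num = int(math.ceil(dataset_len / 8)) * 8            # -> 1000
another_offset = ((epoch - 1) * true_num + offset) // 8   # -> (2000 + 960) // 8 = 370

# The auxiliary pure-text / pure-image / detection datasets are then seeked to
# a multiple of this value modulo their own lengths, keeping every stream in
# step with the main image-text stream after the restart.
print(true_num, another_offset)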
| 296 |
+
def load_checkpoint_to_cpu(path, arg_overrides=None, load_on_all_ranks=False):
|
| 297 |
+
"""Loads a checkpoint to CPU (with upgrading for backward compatibility).
|
| 298 |
+
|
| 299 |
+
If doing single-GPU training or if the checkpoint is only being loaded by at
|
| 300 |
+
most one process on each node (current default behavior is for only rank 0
|
| 301 |
+
to read the checkpoint from disk), load_on_all_ranks should be False to
|
| 302 |
+
avoid errors from torch.distributed not having been initialized or
|
| 303 |
+
torch.distributed.barrier() hanging.
|
| 304 |
+
|
| 305 |
+
If all processes on each node may be loading the checkpoint
|
| 306 |
+
simultaneously, load_on_all_ranks should be set to True to avoid I/O
|
| 307 |
+
conflicts.
|
| 308 |
+
|
| 309 |
+
There's currently no support for > 1 but < all processes loading the
|
| 310 |
+
checkpoint on each node.
|
| 311 |
+
"""
|
| 312 |
+
local_path = PathManager.get_local_path(path)
|
| 313 |
+
# The locally cached file returned by get_local_path() may be stale for
|
| 314 |
+
# remote files that are periodically updated/overwritten (ex:
|
| 315 |
+
# checkpoint_last.pt) - so we remove the local copy, sync across processes
|
| 316 |
+
# (if needed), and then download a fresh copy.
|
| 317 |
+
if local_path != path and PathManager.path_requires_pathmanager(path):
|
| 318 |
+
try:
|
| 319 |
+
os.remove(local_path)
|
| 320 |
+
except FileNotFoundError:
|
| 321 |
+
# With potentially multiple processes removing the same file, the
|
| 322 |
+
# file being missing is benign (missing_ok isn't available until
|
| 323 |
+
# Python 3.8).
|
| 324 |
+
pass
|
| 325 |
+
if load_on_all_ranks:
|
| 326 |
+
torch.distributed.barrier()
|
| 327 |
+
local_path = PathManager.get_local_path(path)
|
| 328 |
+
|
| 329 |
+
with open(local_path, "rb") as f:
|
| 330 |
+
state = torch.load(f, map_location=torch.device("cpu"))
|
| 331 |
+
|
| 332 |
+
if "args" in state and state["args"] is not None and arg_overrides is not None:
|
| 333 |
+
args = state["args"]
|
| 334 |
+
for arg_name, arg_val in arg_overrides.items():
|
| 335 |
+
setattr(args, arg_name, arg_val)
|
| 336 |
+
|
| 337 |
+
if "cfg" in state and state["cfg"] is not None:
|
| 338 |
+
|
| 339 |
+
# hack to be able to set Namespace in dict config. this should be removed when we update to newer
|
| 340 |
+
# omegaconf version that supports object flags, or when we migrate all existing models
|
| 341 |
+
from omegaconf import _utils
|
| 342 |
+
|
| 343 |
+
old_primitive = _utils.is_primitive_type
|
| 344 |
+
_utils.is_primitive_type = lambda _: True
|
| 345 |
+
|
| 346 |
+
state["cfg"] = OmegaConf.create(state["cfg"])
|
| 347 |
+
|
| 348 |
+
_utils.is_primitive_type = old_primitive
|
| 349 |
+
OmegaConf.set_struct(state["cfg"], True)
|
| 350 |
+
|
| 351 |
+
if arg_overrides is not None:
|
| 352 |
+
overwrite_args_by_name(state["cfg"], arg_overrides)
|
| 353 |
+
|
| 354 |
+
state = _upgrade_state_dict(state)
|
| 355 |
+
return state
|
| 356 |
+
|
| 357 |
+
|
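A short usage sketch of the loader above (the checkpoint path is hypothetical; arg_overrides keys have to name fields that exist in the stored config):

# Load on CPU and point the config at a different data directory than the one
# recorded at training time.
state = load_checkpoint_to_cpu(
    "checkpoints/checkpoint_best.pt",                  # hypothetical path
    arg_overrides={"data": "/path/to/local/dataset"},  # hypothetical override
)
print(sorted(state.keys()))   # e.g. cfg / model / extra_state / optimizer history entries
model_weights = state["model"]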
| 358 |
+
def load_model_ensemble(
|
| 359 |
+
filenames,
|
| 360 |
+
arg_overrides: Optional[Dict[str, Any]] = None,
|
| 361 |
+
task=None,
|
| 362 |
+
strict=True,
|
| 363 |
+
suffix="",
|
| 364 |
+
num_shards=1,
|
| 365 |
+
state=None,
|
| 366 |
+
):
|
| 367 |
+
"""Loads an ensemble of models.
|
| 368 |
+
|
| 369 |
+
Args:
|
| 370 |
+
filenames (List[str]): checkpoint files to load
|
| 371 |
+
arg_overrides (Dict[str,Any], optional): override model args that
|
| 372 |
+
were used during model training
|
| 373 |
+
task (fairseq.tasks.FairseqTask, optional): task to use for loading
|
| 374 |
+
"""
|
| 375 |
+
assert not (
|
| 376 |
+
strict and num_shards > 1
|
| 377 |
+
), "Cannot load state dict with strict=True and checkpoint shards > 1"
|
| 378 |
+
ensemble, args, _task = load_model_ensemble_and_task(
|
| 379 |
+
filenames,
|
| 380 |
+
arg_overrides,
|
| 381 |
+
task,
|
| 382 |
+
strict,
|
| 383 |
+
suffix,
|
| 384 |
+
num_shards,
|
| 385 |
+
state,
|
| 386 |
+
)
|
| 387 |
+
return ensemble, args
|
| 388 |
+
|
| 389 |
+
|
| 390 |
+
def get_maybe_sharded_checkpoint_filename(
|
| 391 |
+
filename: str, suffix: str, shard_idx: int, num_shards: int
|
| 392 |
+
) -> str:
|
| 393 |
+
orig_filename = filename
|
| 394 |
+
filename = filename.replace(".pt", suffix + ".pt")
|
| 395 |
+
fsdp_filename = filename[:-3] + f"-shard{shard_idx}.pt"
|
| 396 |
+
model_parallel_filename = orig_filename[:-3] + f"_part{shard_idx}.pt"
|
| 397 |
+
if PathManager.exists(fsdp_filename):
|
| 398 |
+
return fsdp_filename
|
| 399 |
+
elif num_shards > 1:
|
| 400 |
+
return model_parallel_filename
|
| 401 |
+
else:
|
| 402 |
+
return filename
|
| 403 |
+
|
| 404 |
+
|
| 405 |
+
def load_model_ensemble_and_task(
|
| 406 |
+
filenames,
|
| 407 |
+
arg_overrides: Optional[Dict[str, Any]] = None,
|
| 408 |
+
task=None,
|
| 409 |
+
strict=True,
|
| 410 |
+
suffix="",
|
| 411 |
+
num_shards=1,
|
| 412 |
+
state=None,
|
| 413 |
+
):
|
| 414 |
+
assert state is None or len(filenames) == 1
|
| 415 |
+
|
| 416 |
+
from fairseq import tasks
|
| 417 |
+
|
| 418 |
+
assert not (
|
| 419 |
+
strict and num_shards > 1
|
| 420 |
+
), "Cannot load state dict with strict=True and checkpoint shards > 1"
|
| 421 |
+
ensemble = []
|
| 422 |
+
cfg = None
|
| 423 |
+
for filename in filenames:
|
| 424 |
+
orig_filename = filename
|
| 425 |
+
model_shard_state = {"shard_weights": [], "shard_metadata": []}
|
| 426 |
+
assert num_shards > 0
|
| 427 |
+
st = time.time()
|
| 428 |
+
for shard_idx in range(num_shards):
|
| 429 |
+
filename = get_maybe_sharded_checkpoint_filename(
|
| 430 |
+
orig_filename, suffix, shard_idx, num_shards
|
| 431 |
+
)
|
| 432 |
+
|
| 433 |
+
if not PathManager.exists(filename):
|
| 434 |
+
raise IOError("Model file not found: {}".format(filename))
|
| 435 |
+
if state is None:
|
| 436 |
+
state = load_checkpoint_to_cpu(filename, arg_overrides)
|
| 437 |
+
if "args" in state and state["args"] is not None:
|
| 438 |
+
cfg = convert_namespace_to_omegaconf(state["args"])
|
| 439 |
+
elif "cfg" in state and state["cfg"] is not None:
|
| 440 |
+
cfg = state["cfg"]
|
| 441 |
+
else:
|
| 442 |
+
raise RuntimeError(
|
| 443 |
+
f"Neither args nor cfg exist in state keys = {state.keys()}"
|
| 444 |
+
)
|
| 445 |
+
|
| 446 |
+
if task is None:
|
| 447 |
+
task = tasks.setup_task(cfg.task)
|
| 448 |
+
|
| 449 |
+
if "task_state" in state:
|
| 450 |
+
task.load_state_dict(state["task_state"])
|
| 451 |
+
|
| 452 |
+
if "fsdp_metadata" in state and num_shards > 1:
|
| 453 |
+
model_shard_state["shard_weights"].append(state["model"])
|
| 454 |
+
model_shard_state["shard_metadata"].append(state["fsdp_metadata"])
|
| 455 |
+
# check FSDP import before the code goes too far
|
| 456 |
+
if not has_FSDP:
|
| 457 |
+
raise ImportError(
|
| 458 |
+
"Cannot find FullyShardedDataParallel. "
|
| 459 |
+
"Please install fairscale with: pip install fairscale"
|
| 460 |
+
)
|
| 461 |
+
if shard_idx == num_shards - 1:
|
| 462 |
+
consolidated_model_state = FSDP.consolidate_shard_weights(
|
| 463 |
+
shard_weights=model_shard_state["shard_weights"],
|
| 464 |
+
shard_metadata=model_shard_state["shard_metadata"],
|
| 465 |
+
)
|
| 466 |
+
model = task.build_model(cfg.model)
|
| 467 |
+
model.load_state_dict(
|
| 468 |
+
consolidated_model_state, strict=strict, model_cfg=cfg.model
|
| 469 |
+
)
|
| 470 |
+
else:
|
| 471 |
+
# model parallel checkpoint or unsharded checkpoint
|
| 472 |
+
model = task.build_model(cfg.model)
|
| 473 |
+
model.load_state_dict(
|
| 474 |
+
state["model"], strict=strict, model_cfg=cfg.model
|
| 475 |
+
)
|
| 476 |
+
|
| 477 |
+
# reset state so it gets loaded for the next model in ensemble
|
| 478 |
+
state = None
|
| 479 |
+
if shard_idx % 10 == 0 and shard_idx > 0:
|
| 480 |
+
elapsed = time.time() - st
|
| 481 |
+
logger.info(
|
| 482 |
+
f"Loaded {shard_idx} shards in {elapsed:.2f}s, {elapsed / (shard_idx+1):.2f}s/shard"
|
| 483 |
+
)
|
| 484 |
+
|
| 485 |
+
# build model for ensemble
|
| 486 |
+
ensemble.append(model)
|
| 487 |
+
return ensemble, cfg, task
|
| 488 |
+
|
| 489 |
+
|
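For inference, the three return values are typically used together. A hedged sketch, with a hypothetical checkpoint path, of one common pattern (build_generator and inference_step are the usual fairseq task hooks):

# Hypothetical fine-tuned caption checkpoint.
models, cfg, task = load_model_ensemble_and_task(
    ["checkpoints/caption_best.pt"],
)
for model in models:
    model.eval()

# Build a sequence generator from the task and run it on a prepared sample.
generator = task.build_generator(models, cfg.generation)
# hypos = task.inference_step(generator, models, sample)   # `sample` comes from the task's dataloader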
| 490 |
+
def checkpoint_paths(path, pattern=r"checkpoint(\d+)\.pt", keep_match=False):
|
| 491 |
+
"""Retrieves all checkpoints found in `path` directory.
|
| 492 |
+
|
| 493 |
+
Checkpoints are identified by matching filename to the specified pattern. If
|
| 494 |
+
the pattern contains groups, the result will be sorted by the first group in
|
| 495 |
+
descending order.
|
| 496 |
+
"""
|
| 497 |
+
pt_regexp = re.compile(pattern)
|
| 498 |
+
files = PathManager.ls(path)
|
| 499 |
+
|
| 500 |
+
entries = []
|
| 501 |
+
for i, f in enumerate(files):
|
| 502 |
+
m = pt_regexp.fullmatch(f)
|
| 503 |
+
if m is not None:
|
| 504 |
+
idx = float(m.group(1)) if len(m.groups()) > 0 else i
|
| 505 |
+
entries.append((idx, m.group(0)))
|
| 506 |
+
if keep_match:
|
| 507 |
+
return [(os.path.join(path, x[1]), x[0]) for x in sorted(entries, reverse=True)]
|
| 508 |
+
else:
|
| 509 |
+
return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)]
|
| 510 |
+
|
| 511 |
+
|
| 512 |
+
def torch_persistent_save(obj, filename, async_write: bool = False):
|
| 513 |
+
if async_write:
|
| 514 |
+
with PathManager.opena(filename, "wb") as f:
|
| 515 |
+
_torch_persistent_save(obj, f)
|
| 516 |
+
else:
|
| 517 |
+
with PathManager.open(filename, "wb") as f:
|
| 518 |
+
_torch_persistent_save(obj, f)
|
| 519 |
+
# if PathManager.supports_rename(filename):
|
| 520 |
+
# # do atomic save
|
| 521 |
+
# with PathManager.open(filename + ".tmp", "wb") as f:
|
| 522 |
+
# _torch_persistent_save(obj, f)
|
| 523 |
+
# PathManager.rename(filename + ".tmp", filename)
|
| 524 |
+
# else:
|
| 525 |
+
# # fallback to non-atomic save
|
| 526 |
+
# with PathManager.open(filename, "wb") as f:
|
| 527 |
+
# _torch_persistent_save(obj, f)
|
| 528 |
+
|
| 529 |
+
|
| 530 |
+
def _torch_persistent_save(obj, f):
|
| 531 |
+
if isinstance(f, str):
|
| 532 |
+
with PathManager.open(f, "wb") as h:
|
| 533 |
+
torch_persistent_save(obj, h)
|
| 534 |
+
return
|
| 535 |
+
for i in range(3):
|
| 536 |
+
try:
|
| 537 |
+
return torch.save(obj, f)
|
| 538 |
+
except Exception:
|
| 539 |
+
if i == 2:
|
| 540 |
+
logger.error(traceback.format_exc())
|
| 541 |
+
raise
|
| 542 |
+
|
| 543 |
+
|
| 544 |
+
def _upgrade_state_dict(state):
|
| 545 |
+
"""Helper for upgrading old model checkpoints."""
|
| 546 |
+
|
| 547 |
+
# add optimizer_history
|
| 548 |
+
if "optimizer_history" not in state:
|
| 549 |
+
state["optimizer_history"] = [
|
| 550 |
+
{"criterion_name": "CrossEntropyCriterion", "best_loss": state["best_loss"]}
|
| 551 |
+
]
|
| 552 |
+
state["last_optimizer_state"] = state["optimizer"]
|
| 553 |
+
del state["optimizer"]
|
| 554 |
+
del state["best_loss"]
|
| 555 |
+
# move extra_state into sub-dictionary
|
| 556 |
+
if "epoch" in state and "extra_state" not in state:
|
| 557 |
+
state["extra_state"] = {
|
| 558 |
+
"epoch": state["epoch"],
|
| 559 |
+
"batch_offset": state["batch_offset"],
|
| 560 |
+
"val_loss": state["val_loss"],
|
| 561 |
+
}
|
| 562 |
+
del state["epoch"]
|
| 563 |
+
del state["batch_offset"]
|
| 564 |
+
del state["val_loss"]
|
| 565 |
+
# reduce optimizer history's memory usage (only keep the last state)
|
| 566 |
+
if "optimizer" in state["optimizer_history"][-1]:
|
| 567 |
+
state["last_optimizer_state"] = state["optimizer_history"][-1]["optimizer"]
|
| 568 |
+
for optim_hist in state["optimizer_history"]:
|
| 569 |
+
del optim_hist["optimizer"]
|
| 570 |
+
# record the optimizer class name
|
| 571 |
+
if "optimizer_name" not in state["optimizer_history"][-1]:
|
| 572 |
+
state["optimizer_history"][-1]["optimizer_name"] = "FairseqNAG"
|
| 573 |
+
# move best_loss into lr_scheduler_state
|
| 574 |
+
if "lr_scheduler_state" not in state["optimizer_history"][-1]:
|
| 575 |
+
state["optimizer_history"][-1]["lr_scheduler_state"] = {
|
| 576 |
+
"best": state["optimizer_history"][-1]["best_loss"]
|
| 577 |
+
}
|
| 578 |
+
del state["optimizer_history"][-1]["best_loss"]
|
| 579 |
+
# keep track of number of updates
|
| 580 |
+
if "num_updates" not in state["optimizer_history"][-1]:
|
| 581 |
+
state["optimizer_history"][-1]["num_updates"] = 0
|
| 582 |
+
# old model checkpoints may not have separate source/target positions
|
| 583 |
+
if (
|
| 584 |
+
"args" in state
|
| 585 |
+
and hasattr(state["args"], "max_positions")
|
| 586 |
+
and not hasattr(state["args"], "max_source_positions")
|
| 587 |
+
):
|
| 588 |
+
state["args"].max_source_positions = state["args"].max_positions
|
| 589 |
+
state["args"].max_target_positions = state["args"].max_positions
|
| 590 |
+
# use stateful training data iterator
|
| 591 |
+
if "train_iterator" not in state["extra_state"]:
|
| 592 |
+
state["extra_state"]["train_iterator"] = {
|
| 593 |
+
"epoch": state["extra_state"]["epoch"],
|
| 594 |
+
"iterations_in_epoch": state["extra_state"].get("batch_offset", 0),
|
| 595 |
+
}
|
| 596 |
+
|
| 597 |
+
# backward compatibility, cfg updates
|
| 598 |
+
if "args" in state and state["args"] is not None:
|
| 599 |
+
# default to translation task
|
| 600 |
+
if not hasattr(state["args"], "task"):
|
| 601 |
+
state["args"].task = "translation"
|
| 602 |
+
# --raw-text and --lazy-load are deprecated
|
| 603 |
+
if getattr(state["args"], "raw_text", False):
|
| 604 |
+
state["args"].dataset_impl = "raw"
|
| 605 |
+
elif getattr(state["args"], "lazy_load", False):
|
| 606 |
+
state["args"].dataset_impl = "lazy"
|
| 607 |
+
# epochs start at 1
|
| 608 |
+
if state["extra_state"]["train_iterator"] is not None:
|
| 609 |
+
state["extra_state"]["train_iterator"]["epoch"] = max(
|
| 610 |
+
state["extra_state"]["train_iterator"].get("epoch", 1), 1
|
| 611 |
+
)
|
| 612 |
+
# --remove-bpe ==> --postprocess
|
| 613 |
+
if hasattr(state["args"], "remove_bpe"):
|
| 614 |
+
state["args"].post_process = state["args"].remove_bpe
|
| 615 |
+
# --min-lr ==> --stop-min-lr
|
| 616 |
+
if hasattr(state["args"], "min_lr"):
|
| 617 |
+
state["args"].stop_min_lr = state["args"].min_lr
|
| 618 |
+
del state["args"].min_lr
|
| 619 |
+
# binary_cross_entropy / kd_binary_cross_entropy => wav2vec criterion
|
| 620 |
+
if (
|
| 621 |
+
hasattr(state["args"], "criterion")
|
| 622 |
+
and state["args"].criterion in [
|
| 623 |
+
"binary_cross_entropy",
|
| 624 |
+
"kd_binary_cross_entropy",
|
| 625 |
+
]
|
| 626 |
+
):
|
| 627 |
+
state["args"].criterion = "wav2vec"
|
| 628 |
+
# remove log_keys if it's None (criteria will supply a default value of [])
|
| 629 |
+
if hasattr(state["args"], "log_keys") and state["args"].log_keys is None:
|
| 630 |
+
delattr(state["args"], "log_keys")
|
| 631 |
+
# speech_pretraining => audio pretraining
|
| 632 |
+
if (
|
| 633 |
+
hasattr(state["args"], "task")
|
| 634 |
+
and state["args"].task == "speech_pretraining"
|
| 635 |
+
):
|
| 636 |
+
state["args"].task = "audio_pretraining"
|
| 637 |
+
# audio_cpc => wav2vec
|
| 638 |
+
if hasattr(state["args"], "arch") and state["args"].arch == "audio_cpc":
|
| 639 |
+
state["args"].arch = "wav2vec"
|
| 640 |
+
# convert legacy float learning rate to List[float]
|
| 641 |
+
if hasattr(state["args"], "lr") and isinstance(state["args"].lr, float):
|
| 642 |
+
state["args"].lr = [state["args"].lr]
|
| 643 |
+
# convert task data arg to a string instead of List[string]
|
| 644 |
+
if (
|
| 645 |
+
hasattr(state["args"], "data")
|
| 646 |
+
and isinstance(state["args"].data, list)
|
| 647 |
+
and len(state["args"].data) > 0
|
| 648 |
+
):
|
| 649 |
+
state["args"].data = state["args"].data[0]
|
| 650 |
+
# remove keys in state["args"] related to teacher-student learning
|
| 651 |
+
for key in [
|
| 652 |
+
"static_teachers",
|
| 653 |
+
"static_teacher_weights",
|
| 654 |
+
"dynamic_teachers",
|
| 655 |
+
"dynamic_teacher_weights",
|
| 656 |
+
]:
|
| 657 |
+
if key in state["args"]:
|
| 658 |
+
delattr(state["args"], key)
|
| 659 |
+
|
| 660 |
+
state["cfg"] = convert_namespace_to_omegaconf(state["args"])
|
| 661 |
+
|
| 662 |
+
if "cfg" in state and state["cfg"] is not None:
|
| 663 |
+
cfg = state["cfg"]
|
| 664 |
+
with open_dict(cfg):
|
| 665 |
+
# any upgrades for Hydra-based configs
|
| 666 |
+
if (
|
| 667 |
+
"task" in cfg
|
| 668 |
+
and "eval_wer_config" in cfg.task
|
| 669 |
+
and isinstance(cfg.task.eval_wer_config.print_alignment, bool)
|
| 670 |
+
):
|
| 671 |
+
cfg.task.eval_wer_config.print_alignment = "hard"
|
| 672 |
+
if "generation" in cfg and isinstance(cfg.generation.print_alignment, bool):
|
| 673 |
+
cfg.generation.print_alignment = "hard" if cfg.generation.print_alignment else None
|
| 674 |
+
if (
|
| 675 |
+
"model" in cfg
|
| 676 |
+
and "w2v_args" in cfg.model
|
| 677 |
+
and cfg.model.w2v_args is not None
|
| 678 |
+
and (
|
| 679 |
+
hasattr(cfg.model.w2v_args, "task") or "task" in cfg.model.w2v_args
|
| 680 |
+
)
|
| 681 |
+
and hasattr(cfg.model.w2v_args.task, "eval_wer_config")
|
| 682 |
+
and cfg.model.w2v_args.task.eval_wer_config is not None
|
| 683 |
+
and isinstance(
|
| 684 |
+
cfg.model.w2v_args.task.eval_wer_config.print_alignment, bool
|
| 685 |
+
)
|
| 686 |
+
):
|
| 687 |
+
cfg.model.w2v_args.task.eval_wer_config.print_alignment = "hard"
|
| 688 |
+
|
| 689 |
+
return state
|
| 690 |
+
|
| 691 |
+
|
| 692 |
+
def prune_state_dict(state_dict, model_cfg: Optional[DictConfig]):
|
| 693 |
+
"""Prune the given state_dict if desired for LayerDrop
|
| 694 |
+
(https://arxiv.org/abs/1909.11556).
|
| 695 |
+
|
| 696 |
+
Training with LayerDrop allows models to be robust to pruning at inference
|
| 697 |
+
time. This function prunes state_dict to allow smaller models to be loaded
|
| 698 |
+
from a larger model and re-maps the existing state_dict for this to occur.
|
| 699 |
+
|
| 700 |
+
It's called by functions that load models from checkpoints and does not
|
| 701 |
+
need to be called directly.
|
| 702 |
+
"""
|
| 703 |
+
arch = None
|
| 704 |
+
if model_cfg is not None:
|
| 705 |
+
arch = (
|
| 706 |
+
model_cfg._name
|
| 707 |
+
if isinstance(model_cfg, DictConfig)
|
| 708 |
+
else getattr(model_cfg, "arch", None)
|
| 709 |
+
)
|
| 710 |
+
|
| 711 |
+
if not model_cfg or arch is None or arch == "ptt_transformer":
|
| 712 |
+
# args should not be none, but don't crash if it is.
|
| 713 |
+
return state_dict
|
| 714 |
+
|
| 715 |
+
encoder_layers_to_keep = getattr(model_cfg, "encoder_layers_to_keep", None)
|
| 716 |
+
decoder_layers_to_keep = getattr(model_cfg, "decoder_layers_to_keep", None)
|
| 717 |
+
|
| 718 |
+
if not encoder_layers_to_keep and not decoder_layers_to_keep:
|
| 719 |
+
return state_dict
|
| 720 |
+
|
| 721 |
+
# apply pruning
|
| 722 |
+
logger.info(
|
| 723 |
+
"Pruning model to specified layer configuration - this works best if the model was trained with LayerDrop"
|
| 724 |
+
)
|
| 725 |
+
|
| 726 |
+
def create_pruning_pass(layers_to_keep, layer_name):
|
| 727 |
+
keep_layers = sorted(
|
| 728 |
+
int(layer_string) for layer_string in layers_to_keep.split(",")
|
| 729 |
+
)
|
| 730 |
+
mapping_dict = {}
|
| 731 |
+
for i in range(len(keep_layers)):
|
| 732 |
+
mapping_dict[str(keep_layers[i])] = str(i)
|
| 733 |
+
|
| 734 |
+
regex = re.compile(r"^{layer}.*\.layers\.(\d+)".format(layer=layer_name))
|
| 735 |
+
return {"substitution_regex": regex, "mapping_dict": mapping_dict}
|
| 736 |
+
|
| 737 |
+
pruning_passes = []
|
| 738 |
+
if encoder_layers_to_keep:
|
| 739 |
+
pruning_passes.append(create_pruning_pass(encoder_layers_to_keep, "encoder"))
|
| 740 |
+
if decoder_layers_to_keep:
|
| 741 |
+
pruning_passes.append(create_pruning_pass(decoder_layers_to_keep, "decoder"))
|
| 742 |
+
|
| 743 |
+
new_state_dict = {}
|
| 744 |
+
for layer_name in state_dict.keys():
|
| 745 |
+
match = re.search(r"\.layers\.(\d+)\.", layer_name)
|
| 746 |
+
# if layer has no number in it, it is a supporting layer, such as an
|
| 747 |
+
# embedding
|
| 748 |
+
if not match:
|
| 749 |
+
new_state_dict[layer_name] = state_dict[layer_name]
|
| 750 |
+
continue
|
| 751 |
+
|
| 752 |
+
# otherwise, layer should be pruned.
|
| 753 |
+
original_layer_number = match.group(1)
|
| 754 |
+
# figure out which mapping dict to replace from
|
| 755 |
+
for pruning_pass in pruning_passes:
|
| 756 |
+
if original_layer_number in pruning_pass["mapping_dict"] and pruning_pass[
|
| 757 |
+
"substitution_regex"
|
| 758 |
+
].search(layer_name):
|
| 759 |
+
new_layer_number = pruning_pass["mapping_dict"][original_layer_number]
|
| 760 |
+
substitution_match = pruning_pass["substitution_regex"].search(
|
| 761 |
+
layer_name
|
| 762 |
+
)
|
| 763 |
+
new_state_key = (
|
| 764 |
+
layer_name[: substitution_match.start(1)]
|
| 765 |
+
+ new_layer_number
|
| 766 |
+
+ layer_name[substitution_match.end(1) :]
|
| 767 |
+
)
|
| 768 |
+
new_state_dict[new_state_key] = state_dict[layer_name]
|
| 769 |
+
|
| 770 |
+
# Since layers are now pruned, *_layers_to_keep are no longer needed.
|
| 771 |
+
# This is more of a "make it work" fix than a proper fix.
|
| 772 |
+
if isinstance(model_cfg, DictConfig):
|
| 773 |
+
context = open_dict(model_cfg)
|
| 774 |
+
else:
|
| 775 |
+
context = contextlib.ExitStack()
|
| 776 |
+
with context:
|
| 777 |
+
if hasattr(model_cfg, "encoder_layers_to_keep"):
|
| 778 |
+
model_cfg.encoder_layers_to_keep = None
|
| 779 |
+
if hasattr(model_cfg, "decoder_layers_to_keep"):
|
| 780 |
+
model_cfg.decoder_layers_to_keep = None
|
| 781 |
+
|
| 782 |
+
return new_state_dict
|
| 783 |
+
|
| 784 |
+
|
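To make the key re-mapping above concrete, here is a standalone sketch (the state-dict keys are hypothetical) of what a single pruning pass does when only encoder layers 0, 2 and 4 are kept:

import re

# Hypothetical keys from a 6-layer encoder checkpoint.
state_dict = {
    "encoder.layers.0.fc1.weight": "...",
    "encoder.layers.2.fc1.weight": "...",
    "encoder.layers.4.fc1.weight": "...",
    "encoder.embed_tokens.weight": "...",   # no layer index: copied unchanged
}

keep = [0, 2, 4]
mapping = {str(old): str(new) for new, old in enumerate(keep)}   # {'0': '0', '2': '1', '4': '2'}
regex = re.compile(r"^encoder.*\.layers\.(\d+)")

pruned = {}
for key, value in state_dict.items():
    m = regex.search(key)
    if m is None:
        pruned[key] = value
        continue
    new_idx = mapping[m.group(1)]
    pruned[key[: m.start(1)] + new_idx + key[m.end(1):]] = value

# encoder.layers.4.fc1.weight -> encoder.layers.2.fc1.weight, and so on.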
| 785 |
+
def load_pretrained_component_from_model(
|
| 786 |
+
component: Union[FairseqEncoder, FairseqDecoder], checkpoint: str
|
| 787 |
+
):
|
| 788 |
+
"""
|
| 789 |
+
Load a pretrained FairseqEncoder or FairseqDecoder from checkpoint into the
|
| 790 |
+
provided `component` object. If state_dict fails to load, there may be a
|
| 791 |
+
mismatch in the architecture of the corresponding `component` found in the
|
| 792 |
+
`checkpoint` file.
|
| 793 |
+
"""
|
| 794 |
+
if not PathManager.exists(checkpoint):
|
| 795 |
+
raise IOError("Model file not found: {}".format(checkpoint))
|
| 796 |
+
state = load_checkpoint_to_cpu(checkpoint)
|
| 797 |
+
if isinstance(component, FairseqEncoder):
|
| 798 |
+
component_type = "encoder"
|
| 799 |
+
elif isinstance(component, FairseqDecoder):
|
| 800 |
+
component_type = "decoder"
|
| 801 |
+
else:
|
| 802 |
+
raise ValueError(
|
| 803 |
+
"component to load must be either a FairseqEncoder or "
|
| 804 |
+
"FairseqDecoder. Loading other component types are not supported."
|
| 805 |
+
)
|
| 806 |
+
component_state_dict = OrderedDict()
|
| 807 |
+
for key in state["model"].keys():
|
| 808 |
+
if key.startswith(component_type):
|
| 809 |
+
# encoder.input_layers.0.0.weight --> input_layers.0.0.weight
|
| 810 |
+
component_subkey = key[len(component_type) + 1 :]
|
| 811 |
+
component_state_dict[component_subkey] = state["model"][key]
|
| 812 |
+
component.load_state_dict(component_state_dict, strict=True)
|
| 813 |
+
return component
|
| 814 |
+
|
| 815 |
+
|
| 816 |
+
def verify_checkpoint_directory(save_dir: str) -> None:
|
| 817 |
+
if not os.path.exists(save_dir):
|
| 818 |
+
os.makedirs(save_dir, exist_ok=True)
|
| 819 |
+
temp_file_path = os.path.join(save_dir, "dummy")
|
| 820 |
+
try:
|
| 821 |
+
with open(temp_file_path, "w"):
|
| 822 |
+
pass
|
| 823 |
+
except OSError as e:
|
| 824 |
+
logger.warning(
|
| 825 |
+
"Unable to access checkpoint save directory: {}".format(save_dir)
|
| 826 |
+
)
|
| 827 |
+
raise e
|
| 828 |
+
else:
|
| 829 |
+
os.remove(temp_file_path)
|
| 830 |
+
|
| 831 |
+
|
| 832 |
+
def load_ema_from_checkpoint(fpath):
|
| 833 |
+
"""Loads exponential moving averaged (EMA) checkpoint from input and
|
| 834 |
+
returns a model with ema weights.
|
| 835 |
+
|
| 836 |
+
Args:
|
| 837 |
+
fpath: A string path of checkpoint to load from.
|
| 838 |
+
|
| 839 |
+
Returns:
|
| 840 |
+
A dict of string keys mapping to various values. The 'model' key
|
| 841 |
+
from the returned dict should correspond to an OrderedDict mapping
|
| 842 |
+
string parameter names to torch Tensors.
|
| 843 |
+
"""
|
| 844 |
+
params_dict = collections.OrderedDict()
|
| 845 |
+
new_state = None
|
| 846 |
+
|
| 847 |
+
with PathManager.open(fpath, 'rb') as f:
|
| 848 |
+
new_state = torch.load(
|
| 849 |
+
f,
|
| 850 |
+
map_location=(
|
| 851 |
+
lambda s, _: torch.serialization.default_restore_location(s, 'cpu')
|
| 852 |
+
),
|
| 853 |
+
)
|
| 854 |
+
|
| 855 |
+
# EMA model is stored in a separate "extra state"
|
| 856 |
+
model_params = new_state['extra_state']['ema']
|
| 857 |
+
|
| 858 |
+
for key in list(model_params.keys()):
|
| 859 |
+
p = model_params[key]
|
| 860 |
+
if isinstance(p, torch.HalfTensor):
|
| 861 |
+
p = p.float()
|
| 862 |
+
if key not in params_dict:
|
| 863 |
+
params_dict[key] = p.clone()
|
| 864 |
+
# NOTE: clone() is needed in case p is a shared parameter
|
| 865 |
+
else:
|
| 866 |
+
raise ValueError("Key {} is repeated in EMA model params.".format(key))
|
| 867 |
+
|
| 868 |
+
if len(params_dict) == 0:
|
| 869 |
+
raise ValueError(
|
| 870 |
+
f"Input checkpoint path '{fpath}' does not contain "
|
| 871 |
+
"ema model weights, is this model trained with EMA?"
|
| 872 |
+
)
|
| 873 |
+
|
| 874 |
+
new_state['model'] = params_dict
|
| 875 |
+
return new_state
|
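A hedged usage sketch of the EMA loader (the path and the surrounding task object are assumptions; the checkpoint must have been trained with EMA enabled so that extra_state['ema'] exists):

# Swap the EMA weights in as the model weights and load them as usual.
state = load_ema_from_checkpoint("checkpoints/checkpoint_last.pt")  # hypothetical path
model = task.build_model(state["cfg"].model)   # assumes `task` exists and a cfg-style checkpoint
model.load_state_dict(state["model"], strict=True)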
utils/cider/pyciderevalcap/__init__.py
ADDED
|
@@ -0,0 +1 @@
__author__ = 'tylin'
utils/cider/pyciderevalcap/cider/__init__.py
ADDED
|
@@ -0,0 +1 @@
__author__ = 'tylin'
utils/cider/pyciderevalcap/cider/cider.py
ADDED
|
@@ -0,0 +1,65 @@
# Filename: cider.py
#
#
# Description: Describes the class to compute the CIDEr
#              (Consensus-Based Image Description Evaluation) Metric
#              by Vedantam, Zitnick, and Parikh (http://arxiv.org/abs/1411.5726)
#
# Creation Date: Sun Feb 8 14:16:54 2015
#
# Authors: Ramakrishna Vedantam <[email protected]> and
#          Tsung-Yi Lin <[email protected]>
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from .cider_scorer import CiderScorer


class Cider:
    """
    Main Class to compute the CIDEr metric

    """
    def __init__(self, n=4, df="corpus"):
        """
        Initialize the CIDEr scoring function
        : param n (int): n-gram size
        : param df (string): specifies where to get the IDF values from
                    takes values 'corpus', 'coco-train'
        : return: None
        """
        # set cider to sum over 1 to 4-grams
        self._n = n
        self._df = df
        self.cider_scorer = CiderScorer(n=self._n, df_mode=self._df)

    def compute_score(self, gts, res):
        """
        Main function to compute CIDEr score
        : param gts (dict) : {image:tokenized reference sentence}
        : param res (dict) : {image:tokenized candidate sentence}
        : return: cider (float) : computed CIDEr score for the corpus
        """

        # clear all the previous hypos and refs
        self.cider_scorer.clear()

        for res_id in res:

            hypo = res_id['caption']
            ref = gts[res_id['image_id']]

            # Sanity check.
            assert(type(hypo) is list)
            assert(len(hypo) == 1)
            assert(type(ref) is list)
            assert(len(ref) > 0)
            self.cider_scorer += (hypo[0], ref)

        (score, scores) = self.cider_scorer.compute_score()

        return score, scores

    def method(self):
        return "CIDEr"
utils/cider/pyciderevalcap/cider/cider_scorer.py
ADDED
|
@@ -0,0 +1,207 @@
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# Tsung-Yi Lin <[email protected]>
|
| 3 |
+
# Ramakrishna Vedantam <[email protected]>
|
| 4 |
+
from __future__ import absolute_import
|
| 5 |
+
from __future__ import division
|
| 6 |
+
from __future__ import print_function
|
| 7 |
+
|
| 8 |
+
import copy
|
| 9 |
+
import six
|
| 10 |
+
from six.moves import cPickle
|
| 11 |
+
from collections import defaultdict
|
| 12 |
+
import numpy as np
|
| 13 |
+
import math
|
| 14 |
+
import os
|
| 15 |
+
|
| 16 |
+
def precook(s, n=4, out=False):
|
| 17 |
+
"""
|
| 18 |
+
Takes a string as input and returns an object that can be given to
|
| 19 |
+
either cook_refs or cook_test. This is optional: cook_refs and cook_test
|
| 20 |
+
can take string arguments as well.
|
| 21 |
+
:param s: string : sentence to be converted into ngrams
|
| 22 |
+
:param n: int : number of ngrams for which representation is calculated
|
| 23 |
+
:return: term frequency vector for occurring ngrams
|
| 24 |
+
"""
|
| 25 |
+
words = s.split()
|
| 26 |
+
counts = defaultdict(int)
|
| 27 |
+
for k in range(1,n+1):
|
| 28 |
+
for i in range(len(words)-k+1):
|
| 29 |
+
ngram = tuple(words[i:i+k])
|
| 30 |
+
counts[ngram] += 1
|
| 31 |
+
return counts
|
| 32 |
+
|
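A tiny worked example of the counting above (precook returns a defaultdict mapping n-gram tuples to counts):

precook("a dog runs", n=2)
# -> {('a',): 1, ('dog',): 1, ('runs',): 1, ('a', 'dog'): 1, ('dog', 'runs'): 1}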
| 33 |
+
def cook_refs(refs, n=4): ## lhuang: oracle will call with "average"
|
| 34 |
+
'''Takes a list of reference sentences for a single segment
|
| 35 |
+
and returns an object that encapsulates everything that BLEU
|
| 36 |
+
needs to know about them.
|
| 37 |
+
:param refs: list of string : reference sentences for some image
|
| 38 |
+
:param n: int : number of ngrams for which (ngram) representation is calculated
|
| 39 |
+
:return: result (list of dict)
|
| 40 |
+
'''
|
| 41 |
+
return [precook(ref, n) for ref in refs]
|
| 42 |
+
|
| 43 |
+
def cook_test(test, n=4):
|
| 44 |
+
'''Takes a test sentence and returns an object that
|
| 45 |
+
encapsulates everything that BLEU needs to know about it.
|
| 46 |
+
:param test: list of string : hypothesis sentence for some image
|
| 47 |
+
:param n: int : number of ngrams for which (ngram) representation is calculated
|
| 48 |
+
:return: result (dict)
|
| 49 |
+
'''
|
| 50 |
+
return precook(test, n, True)
|
| 51 |
+
|
| 52 |
+
class CiderScorer(object):
|
| 53 |
+
"""CIDEr scorer.
|
| 54 |
+
"""
|
| 55 |
+
|
| 56 |
+
def copy(self):
|
| 57 |
+
''' copy the refs.'''
|
| 58 |
+
new = CiderScorer(n=self.n)
|
| 59 |
+
new.ctest = copy.copy(self.ctest)
|
| 60 |
+
new.crefs = copy.copy(self.crefs)
|
| 61 |
+
return new
|
| 62 |
+
|
| 63 |
+
def __init__(self, df_mode="corpus", test=None, refs=None, n=4, sigma=6.0):
|
| 64 |
+
''' singular instance '''
|
| 65 |
+
self.n = n
|
| 66 |
+
self.sigma = sigma
|
| 67 |
+
self.crefs = []
|
| 68 |
+
self.ctest = []
|
| 69 |
+
self.df_mode = df_mode
|
| 70 |
+
self.ref_len = None
|
| 71 |
+
if self.df_mode != "corpus":
|
| 72 |
+
pkl_file = cPickle.load(open(os.path.join('data', df_mode + '.p'),'rb'), **(dict(encoding='latin1') if six.PY3 else {}))
|
| 73 |
+
self.ref_len = np.log(float(pkl_file['ref_len']))
|
| 74 |
+
self.document_frequency = pkl_file['document_frequency']
|
| 75 |
+
self.cook_append(test, refs)
|
| 76 |
+
|
| 77 |
+
def clear(self):
|
| 78 |
+
self.crefs = []
|
| 79 |
+
self.ctest = []
|
| 80 |
+
|
| 81 |
+
def cook_append(self, test, refs):
|
| 82 |
+
'''called by constructor and __iadd__ to avoid creating new instances.'''
|
| 83 |
+
|
| 84 |
+
if refs is not None:
|
| 85 |
+
self.crefs.append(cook_refs(refs))
|
| 86 |
+
if test is not None:
|
| 87 |
+
self.ctest.append(cook_test(test)) ## N.B.: -1
|
| 88 |
+
else:
|
| 89 |
+
self.ctest.append(None) # lens of crefs and ctest have to match
|
| 90 |
+
|
| 91 |
+
def size(self):
|
| 92 |
+
assert len(self.crefs) == len(self.ctest), "refs/test mismatch! %d<>%d" % (len(self.crefs), len(self.ctest))
|
| 93 |
+
return len(self.crefs)
|
| 94 |
+
|
| 95 |
+
def __iadd__(self, other):
|
| 96 |
+
'''add an instance (e.g., from another sentence).'''
|
| 97 |
+
|
| 98 |
+
if type(other) is tuple:
|
| 99 |
+
## avoid creating new CiderScorer instances
|
| 100 |
+
self.cook_append(other[0], other[1])
|
| 101 |
+
else:
|
| 102 |
+
self.ctest.extend(other.ctest)
|
| 103 |
+
self.crefs.extend(other.crefs)
|
| 104 |
+
|
| 105 |
+
return self
|
| 106 |
+
def compute_doc_freq(self):
|
| 107 |
+
'''
|
| 108 |
+
Compute document frequency for reference data.
|
| 109 |
+
This will be used to compute idf (inverse document frequency) later.
|
| 110 |
+
The document frequency is stored in the object.
|
| 111 |
+
:return: None
|
| 112 |
+
'''
|
| 113 |
+
for refs in self.crefs:
|
| 114 |
+
# refs, k ref captions of one image
|
| 115 |
+
for ngram in set([ngram for ref in refs for (ngram,count) in ref.items()]):
|
| 116 |
+
self.document_frequency[ngram] += 1
|
| 117 |
+
# maxcounts[ngram] = max(maxcounts.get(ngram,0), count)
|
| 118 |
+
|
| 119 |
+
def compute_cider(self):
|
| 120 |
+
def counts2vec(cnts):
|
| 121 |
+
"""
|
| 122 |
+
Function maps counts of ngram to vector of tfidf weights.
|
| 123 |
+
The function returns vec, an array of dictionaries that stores the mapping of n-grams to tf-idf weights.
|
| 124 |
+
The n-th entry of array denotes length of n-grams.
|
| 125 |
+
:param cnts:
|
| 126 |
+
:return: vec (array of dict), norm (array of float), length (int)
|
| 127 |
+
"""
|
| 128 |
+
vec = [defaultdict(float) for _ in range(self.n)]
|
| 129 |
+
length = 0
|
| 130 |
+
norm = [0.0 for _ in range(self.n)]
|
| 131 |
+
for (ngram,term_freq) in cnts.items():
|
| 132 |
+
# give word count 1 if it doesn't appear in reference corpus
|
| 133 |
+
df = np.log(max(1.0, self.document_frequency[ngram]))
|
| 134 |
+
# ngram index
|
| 135 |
+
n = len(ngram)-1
|
| 136 |
+
# tf (term_freq) * idf (precomputed idf) for n-grams
|
| 137 |
+
vec[n][ngram] = float(term_freq)*(self.ref_len - df)
|
| 138 |
+
# compute norm for the vector. the norm will be used for
|
| 139 |
+
# computing similarity
|
| 140 |
+
norm[n] += pow(vec[n][ngram], 2)
|
| 141 |
+
|
| 142 |
+
if n == 1:
|
| 143 |
+
length += term_freq
|
| 144 |
+
norm = [np.sqrt(n) for n in norm]
|
| 145 |
+
return vec, norm, length
|
| 146 |
+
|
| 147 |
+
def sim(vec_hyp, vec_ref, norm_hyp, norm_ref, length_hyp, length_ref):
|
| 148 |
+
'''
|
| 149 |
+
Compute the cosine similarity of two vectors.
|
| 150 |
+
:param vec_hyp: array of dictionary for vector corresponding to hypothesis
|
| 151 |
+
:param vec_ref: array of dictionary for vector corresponding to reference
|
| 152 |
+
:param norm_hyp: array of float for vector corresponding to hypothesis
|
| 153 |
+
:param norm_ref: array of float for vector corresponding to reference
|
| 154 |
+
:param length_hyp: int containing length of hypothesis
|
| 155 |
+
:param length_ref: int containing length of reference
|
| 156 |
+
:return: array of score for each n-grams cosine similarity
|
| 157 |
+
'''
|
| 158 |
+
delta = float(length_hyp - length_ref)
|
| 159 |
+
# measure cosine similarity
|
| 160 |
+
val = np.array([0.0 for _ in range(self.n)])
|
| 161 |
+
for n in range(self.n):
|
| 162 |
+
# ngram
|
| 163 |
+
for (ngram,count) in vec_hyp[n].items():
|
| 164 |
+
val[n] += vec_hyp[n][ngram] * vec_ref[n][ngram]
|
| 165 |
+
|
| 166 |
+
if (norm_hyp[n] != 0) and (norm_ref[n] != 0):
|
| 167 |
+
val[n] /= (norm_hyp[n]*norm_ref[n])
|
| 168 |
+
|
| 169 |
+
assert(not math.isnan(val[n]))
|
| 170 |
+
return val
|
| 171 |
+
|
| 172 |
+
# compute log reference length
|
| 173 |
+
if self.df_mode == "corpus":
|
| 174 |
+
self.ref_len = np.log(float(len(self.crefs)))
|
| 175 |
+
|
| 176 |
+
scores = []
|
| 177 |
+
for test, refs in zip(self.ctest, self.crefs):
|
| 178 |
+
# compute vector for test captions
|
| 179 |
+
vec, norm, length = counts2vec(test)
|
| 180 |
+
# compute vector for ref captions
|
| 181 |
+
score = np.array([0.0 for _ in range(self.n)])
|
| 182 |
+
for ref in refs:
|
| 183 |
+
vec_ref, norm_ref, length_ref = counts2vec(ref)
|
| 184 |
+
score += sim(vec, vec_ref, norm, norm_ref, length, length_ref)
|
| 185 |
+
# change by vrama91 - mean of ngram scores, instead of sum
|
| 186 |
+
score_avg = np.mean(score)
|
| 187 |
+
# divide by number of references
|
| 188 |
+
score_avg /= len(refs)
|
| 189 |
+
# multiply score by 10
|
| 190 |
+
score_avg *= 10.0
|
| 191 |
+
# append score of an image to the score list
|
| 192 |
+
scores.append(score_avg)
|
| 193 |
+
return scores
|
| 194 |
+
|
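In equation form (LaTeX), what compute_cider evaluates per image, matching the code above including the final x10 scaling; g^n(.) is the tf-idf n-gram vector built by counts2vec, c_i the candidate caption and S_i the set of references:

\mathrm{CIDEr}_n(c_i, S_i) = \frac{10}{|S_i|} \sum_{s_{ij} \in S_i}
    \frac{g^n(c_i) \cdot g^n(s_{ij})}{\lVert g^n(c_i) \rVert \, \lVert g^n(s_{ij}) \rVert},
\qquad
\mathrm{CIDEr}(c_i, S_i) = \frac{1}{N} \sum_{n=1}^{N} \mathrm{CIDEr}_n(c_i, S_i)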
| 195 |
+
def compute_score(self, option=None, verbose=0):
|
| 196 |
+
# compute idf
|
| 197 |
+
if self.df_mode == "corpus":
|
| 198 |
+
self.document_frequency = defaultdict(float)
|
| 199 |
+
self.compute_doc_freq()
|
| 200 |
+
# assert to check document frequency
|
| 201 |
+
assert(len(self.ctest) >= max(self.document_frequency.values()))
|
| 202 |
+
# import json for now and write the corresponding files
|
| 203 |
+
# compute cider score
|
| 204 |
+
score = self.compute_cider()
|
| 205 |
+
# debug
|
| 206 |
+
# print score
|
| 207 |
+
return np.mean(np.array(score)), np.array(score)
|
utils/cider/pyciderevalcap/ciderD/__init__.py
ADDED
|
@@ -0,0 +1 @@
__author__ = 'tylin'
utils/cider/pyciderevalcap/ciderD/ciderD.py
ADDED
|
@@ -0,0 +1,58 @@
| 1 |
+
# Filename: ciderD.py
|
| 2 |
+
#
|
| 3 |
+
# Description: Describes the class to compute the CIDEr-D (Consensus-Based Image Description Evaluation) Metric
|
| 4 |
+
# by Vedantam, Zitnick, and Parikh (http://arxiv.org/abs/1411.5726)
|
| 5 |
+
#
|
| 6 |
+
# Creation Date: Sun Feb 8 14:16:54 2015
|
| 7 |
+
#
|
| 8 |
+
# Authors: Ramakrishna Vedantam <[email protected]> and Tsung-Yi Lin <[email protected]>
|
| 9 |
+
from __future__ import absolute_import
|
| 10 |
+
from __future__ import division
|
| 11 |
+
from __future__ import print_function
|
| 12 |
+
|
| 13 |
+
from .ciderD_scorer import CiderScorer
|
| 14 |
+
import pdb
|
| 15 |
+
|
| 16 |
+
class CiderD:
|
| 17 |
+
"""
|
| 18 |
+
Main Class to compute the CIDEr metric
|
| 19 |
+
|
| 20 |
+
"""
|
| 21 |
+
def __init__(self, n=4, sigma=6.0, df="corpus"):
|
| 22 |
+
# set cider to sum over 1 to 4-grams
|
| 23 |
+
self._n = n
|
| 24 |
+
# set the standard deviation parameter for gaussian penalty
|
| 25 |
+
self._sigma = sigma
|
| 26 |
+
# set which where to compute document frequencies from
|
| 27 |
+
self._df = df
|
| 28 |
+
self.cider_scorer = CiderScorer(n=self._n, df_mode=self._df)
|
| 29 |
+
|
| 30 |
+
def compute_score(self, gts, res):
|
| 31 |
+
"""
|
| 32 |
+
Main function to compute CIDEr score
|
| 33 |
+
:param hypo_for_image (dict) : dictionary with key <image> and value <tokenized hypothesis / candidate sentence>
|
| 34 |
+
ref_for_image (dict) : dictionary with key <image> and value <tokenized reference sentence>
|
| 35 |
+
:return: cider (float) : computed CIDEr score for the corpus
|
| 36 |
+
"""
|
| 37 |
+
|
| 38 |
+
# clear all the previous hypos and refs
|
| 39 |
+
tmp_cider_scorer = self.cider_scorer.copy_empty()
|
| 40 |
+
tmp_cider_scorer.clear()
|
| 41 |
+
for res_id in res:
|
| 42 |
+
|
| 43 |
+
hypo = res_id['caption']
|
| 44 |
+
ref = gts[res_id['image_id']]
|
| 45 |
+
|
| 46 |
+
# Sanity check.
|
| 47 |
+
assert(type(hypo) is list)
|
| 48 |
+
assert(len(hypo) == 1)
|
| 49 |
+
assert(type(ref) is list)
|
| 50 |
+
assert(len(ref) > 0)
|
| 51 |
+
tmp_cider_scorer += (hypo[0], ref)
|
| 52 |
+
|
| 53 |
+
(score, scores) = tmp_cider_scorer.compute_score()
|
| 54 |
+
|
| 55 |
+
return score, scores
|
| 56 |
+
|
| 57 |
+
def method(self):
|
| 58 |
+
return "CIDEr-D"
|
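Both scorers share the calling convention asserted in compute_score above: res is a list of {'image_id', 'caption'} entries with a single-element caption list, and gts maps image ids to lists of reference captions. A minimal usage sketch (the data is made up; with only one image the corpus-level IDF is degenerate, so real use passes many images or a precomputed df pickle as loaded in the scorer):

scorer = CiderD(n=4, sigma=6.0, df="corpus")
gts = {"42": ["a dog runs on the grass", "a brown dog running outside"]}
res = [{"image_id": "42", "caption": ["a dog is running on grass"]}]
corpus_score, per_image_scores = scorer.compute_score(gts, res)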
utils/cider/pyciderevalcap/ciderD/ciderD_scorer.py
ADDED
|
@@ -0,0 +1,222 @@
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# Tsung-Yi Lin <[email protected]>
|
| 3 |
+
# Ramakrishna Vedantam <[email protected]>
|
| 4 |
+
from __future__ import absolute_import
|
| 5 |
+
from __future__ import division
|
| 6 |
+
from __future__ import print_function
|
| 7 |
+
|
| 8 |
+
import copy
|
| 9 |
+
from collections import defaultdict
|
| 10 |
+
import numpy as np
|
| 11 |
+
import pdb
|
| 12 |
+
import math
|
| 13 |
+
import six
|
| 14 |
+
from six.moves import cPickle
|
| 15 |
+
import os
|
| 16 |
+
|
| 17 |
+
def precook(s, n=4, out=False):
|
| 18 |
+
"""
|
| 19 |
+
Takes a string as input and returns an object that can be given to
|
| 20 |
+
either cook_refs or cook_test. This is optional: cook_refs and cook_test
|
| 21 |
+
can take string arguments as well.
|
| 22 |
+
:param s: string : sentence to be converted into ngrams
|
| 23 |
+
:param n: int : number of ngrams for which representation is calculated
|
| 24 |
+
:return: term frequency vector for occurring ngrams
|
| 25 |
+
"""
|
| 26 |
+
words = s.split()
|
| 27 |
+
counts = defaultdict(int)
|
| 28 |
+
for k in range(1,n+1):
|
| 29 |
+
for i in range(len(words)-k+1):
|
| 30 |
+
ngram = tuple(words[i:i+k])
|
| 31 |
+
counts[ngram] += 1
|
| 32 |
+
return counts
|
| 33 |
+
|
| 34 |
+
def cook_refs(refs, n=4): ## lhuang: oracle will call with "average"
|
| 35 |
+
'''Takes a list of reference sentences for a single segment
|
| 36 |
+
and returns an object that encapsulates everything that BLEU
|
| 37 |
+
needs to know about them.
|
| 38 |
+
:param refs: list of string : reference sentences for some image
|
| 39 |
+
:param n: int : number of ngrams for which (ngram) representation is calculated
|
| 40 |
+
:return: result (list of dict)
|
| 41 |
+
'''
|
| 42 |
+
return [precook(ref, n) for ref in refs]
|
| 43 |
+
|
| 44 |
+
def cook_test(test, n=4):
|
| 45 |
+
'''Takes a test sentence and returns an object that
|
| 46 |
+
encapsulates everything that BLEU needs to know about it.
|
| 47 |
+
:param test: list of string : hypothesis sentence for some image
|
| 48 |
+
:param n: int : number of ngrams for which (ngram) representation is calculated
|
| 49 |
+
:return: result (dict)
|
| 50 |
+
'''
|
| 51 |
+
return precook(test, n, True)
|
| 52 |
+
|
| 53 |
+
class CiderScorer(object):
|
| 54 |
+
"""CIDEr scorer.
|
| 55 |
+
"""
|
| 56 |
+
|
| 57 |
+
def copy(self):
|
| 58 |
+
''' copy the refs.'''
|
| 59 |
+
new = CiderScorer(n=self.n)
|
| 60 |
+
new.ctest = copy.copy(self.ctest)
|
| 61 |
+
new.crefs = copy.copy(self.crefs)
|
| 62 |
+
return new
|
| 63 |
+
|
| 64 |
+
def copy_empty(self):
|
| 65 |
+
new = CiderScorer(df_mode="corpus", n=self.n, sigma=self.sigma)
|
| 66 |
+
new.df_mode = self.df_mode
|
| 67 |
+
new.ref_len = self.ref_len
|
| 68 |
+
new.document_frequency = self.document_frequency
|
| 69 |
+
return new
|
| 70 |
+
|
| 71 |
+
def __init__(self, df_mode="corpus", test=None, refs=None, n=4, sigma=6.0):
|
| 72 |
+
''' singular instance '''
|
| 73 |
+
self.n = n
|
| 74 |
+
self.sigma = sigma
|
| 75 |
+
self.crefs = []
|
| 76 |
+
self.ctest = []
|
| 77 |
+
self.df_mode = df_mode
|
| 78 |
+
self.ref_len = None
|
| 79 |
+
if self.df_mode != "corpus":
|
| 80 |
+
pkl_file = cPickle.load(open(df_mode,'rb'), **(dict(encoding='latin1') if six.PY3 else {}))
|
| 81 |
+
self.ref_len = np.log(float(pkl_file['ref_len']))
|
| 82 |
+
self.document_frequency = pkl_file['document_frequency']
|
| 83 |
+
else:
|
| 84 |
+
self.document_frequency = None
|
| 85 |
+
self.cook_append(test, refs)
|
| 86 |
+
|
| 87 |
+
def clear(self):
|
| 88 |
+
self.crefs = []
|
| 89 |
+
self.ctest = []
|
| 90 |
+
|
| 91 |
+
def cook_append(self, test, refs):
|
| 92 |
+
'''called by constructor and __iadd__ to avoid creating new instances.'''
|
| 93 |
+
|
| 94 |
+
if refs is not None:
|
| 95 |
+
self.crefs.append(cook_refs(refs))
|
| 96 |
+
if test is not None:
|
| 97 |
+
self.ctest.append(cook_test(test)) ## N.B.: -1
|
| 98 |
+
else:
|
| 99 |
+
self.ctest.append(None) # lens of crefs and ctest have to match
|
| 100 |
+
|
| 101 |
+
def size(self):
|
| 102 |
+
assert len(self.crefs) == len(self.ctest), "refs/test mismatch! %d<>%d" % (len(self.crefs), len(self.ctest))
|
| 103 |
+
return len(self.crefs)
|
| 104 |
+
|
| 105 |
+
def __iadd__(self, other):
|
| 106 |
+
'''add an instance (e.g., from another sentence).'''
|
| 107 |
+
|
| 108 |
+
if type(other) is tuple:
|
| 109 |
+
## avoid creating new CiderScorer instances
|
| 110 |
+
self.cook_append(other[0], other[1])
|
| 111 |
+
else:
|
| 112 |
+
self.ctest.extend(other.ctest)
|
| 113 |
+
self.crefs.extend(other.crefs)
|
| 114 |
+
|
| 115 |
+
return self
|
| 116 |
+
def compute_doc_freq(self):
|
| 117 |
+
'''
|
| 118 |
+
Compute document frequency for reference data.
|
| 119 |
+
This will be used to compute idf (inverse document frequency) later.
|
| 120 |
+
The document frequency is stored in the object.
|
| 121 |
+
:return: None
|
| 122 |
+
'''
|
| 123 |
+
for refs in self.crefs:
|
| 124 |
+
# refs, k ref captions of one image
|
| 125 |
+
for ngram in set([ngram for ref in refs for (ngram,count) in ref.items()]):
|
| 126 |
+
self.document_frequency[ngram] += 1
|
| 127 |
+
# maxcounts[ngram] = max(maxcounts.get(ngram,0), count)
|
| 128 |
+
|
| 129 |
+
def compute_cider(self):
|
| 130 |
+
def counts2vec(cnts):
|
| 131 |
+
"""
|
| 132 |
+
Function maps counts of ngram to vector of tfidf weights.
|
| 133 |
+
The function returns vec, an array of dictionaries that stores the mapping of n-grams to tf-idf weights.
|
| 134 |
+
The n-th entry of array denotes length of n-grams.
|
| 135 |
+
:param cnts:
|
| 136 |
+
:return: vec (array of dict), norm (array of float), length (int)
|
| 137 |
+
"""
|
| 138 |
+
vec = [defaultdict(float) for _ in range(self.n)]
|
| 139 |
+
length = 0
|
| 140 |
+
norm = [0.0 for _ in range(self.n)]
|
| 141 |
+
for (ngram,term_freq) in cnts.items():
|
| 142 |
+
# give word count 1 if it doesn't appear in reference corpus
|
| 143 |
+
df = np.log(max(1.0, self.document_frequency[ngram]))
|
| 144 |
+
# ngram index
|
| 145 |
+
n = len(ngram)-1
|
| 146 |
+
# tf (term_freq) * idf (precomputed idf) for n-grams
|
| 147 |
+
vec[n][ngram] = float(term_freq)*(self.ref_len - df)
|
| 148 |
+
# compute norm for the vector. the norm will be used for computing similarity
|
| 149 |
+
norm[n] += pow(vec[n][ngram], 2)
|
| 150 |
+
|
| 151 |
+
if n == 1:
|
| 152 |
+
length += term_freq
|
| 153 |
+
norm = [np.sqrt(n) for n in norm]
|
| 154 |
+
return vec, norm, length
|
| 155 |
+
|
| 156 |
+
def sim(vec_hyp, vec_ref, norm_hyp, norm_ref, length_hyp, length_ref):
|
| 157 |
+
'''
|
| 158 |
+
Compute the cosine similarity of two vectors.
|
| 159 |
+
:param vec_hyp: array of dictionary for vector corresponding to hypothesis
|
| 160 |
+
:param vec_ref: array of dictionary for vector corresponding to reference
|
| 161 |
+
:param norm_hyp: array of float for vector corresponding to hypothesis
|
| 162 |
+
:param norm_ref: array of float for vector corresponding to reference
|
| 163 |
+
:param length_hyp: int containing length of hypothesis
|
| 164 |
+
:param length_ref: int containing length of reference
|
| 165 |
+
:return: array of score for each n-grams cosine similarity
|
| 166 |
+
'''
|
| 167 |
+
delta = float(length_hyp - length_ref)
|
| 168 |
+
# measure cosine similarity
|
| 169 |
+
val = np.array([0.0 for _ in range(self.n)])
|
| 170 |
+
for n in range(self.n):
|
| 171 |
+
# ngram
|
| 172 |
+
for (ngram,count) in vec_hyp[n].items():
|
| 173 |
+
# vrama91 : added clipping
|
| 174 |
+
val[n] += min(vec_hyp[n][ngram], vec_ref[n][ngram]) * vec_ref[n][ngram]
|
| 175 |
+
|
| 176 |
+
if (norm_hyp[n] != 0) and (norm_ref[n] != 0):
|
| 177 |
+
val[n] /= (norm_hyp[n]*norm_ref[n])
|
| 178 |
+
|
| 179 |
+
assert(not math.isnan(val[n]))
|
| 180 |
+
# vrama91: added a length based gaussian penalty
|
| 181 |
+
val[n] *= np.e**(-(delta**2)/(2*self.sigma**2))
|
| 182 |
+
return val
|
| 183 |
+
|
| 184 |
+
# compute log reference length
|
| 185 |
+
if self.df_mode == "corpus":
|
| 186 |
+
self.ref_len = np.log(float(len(self.crefs)))
|
| 187 |
+
#elif self.df_mode == "coco-val-df":
|
| 188 |
+
# if coco option selected, use length of coco-val set
|
| 189 |
+
# self.ref_len = np.log(float(40504))
|
| 190 |
+
|
| 191 |
+
scores = []
|
| 192 |
+
for test, refs in zip(self.ctest, self.crefs):
|
| 193 |
+
# compute vector for test captions
|
| 194 |
+
vec, norm, length = counts2vec(test)
|
| 195 |
+
# compute vector for ref captions
|
| 196 |
+
score = np.array([0.0 for _ in range(self.n)])
|
| 197 |
+
for ref in refs:
|
| 198 |
+
vec_ref, norm_ref, length_ref = counts2vec(ref)
|
| 199 |
+
score += sim(vec, vec_ref, norm, norm_ref, length, length_ref)
|
| 200 |
+
# change by vrama91 - mean of ngram scores, instead of sum
|
| 201 |
+
score_avg = np.mean(score)
|
| 202 |
+
# divide by number of references
|
| 203 |
+
score_avg /= len(refs)
|
| 204 |
+
# multiply score by 10
|
| 205 |
+
score_avg *= 10.0
|
| 206 |
+
# append score of an image to the score list
|
| 207 |
+
scores.append(score_avg)
|
| 208 |
+
return scores
|
| 209 |
+
|
| 210 |
+
def compute_score(self, option=None, verbose=0):
|
| 211 |
+
# compute idf
|
| 212 |
+
if self.df_mode == "corpus":
|
| 213 |
+
self.document_frequency = defaultdict(float)
|
| 214 |
+
self.compute_doc_freq()
|
| 215 |
+
# assert to check document frequency
|
| 216 |
+
assert(len(self.ctest) >= max(self.document_frequency.values()))
|
| 217 |
+
# import json for now and write the corresponding files
|
| 218 |
+
# compute cider score
|
| 219 |
+
score = self.compute_cider()
|
| 220 |
+
# debug
|
| 221 |
+
# print score
|
| 222 |
+
return np.mean(np.array(score)), np.array(score)
|
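The scorer above only turns the accumulated (hypothesis, references) pairs into corpus-level statistics when compute_score() is called. Below is a minimal, illustrative sketch of that flow; the constructor arguments and the `+=` accumulation follow the reference cider implementation this file is adapted from and are an assumption, not something shown in this commit, and the captions are placeholders.

# Illustrative sketch only -- constructor signature assumed from the reference
# cider implementation (df_mode/n/sigma), not confirmed by this commit.
from utils.cider.pyciderevalcap.ciderD.ciderD_scorer import CiderScorer

scorer = CiderScorer(df_mode="corpus", n=4, sigma=6.0)
for hypo, refs in [
    ("a man riding a wave on a surfboard",
     ["a surfer rides a wave", "a man is surfing in the ocean"]),
    ("two dogs playing in the grass",
     ["two dogs run across a field", "dogs playing outside on the grass"]),
]:
    scorer += (hypo, refs)   # cooks n-gram counts and stores one (test, refs) pair

# in "corpus" mode the document frequencies come from everything appended so far
mean_cider_d, per_image_scores = scorer.compute_score()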
utils/eval_utils.py
ADDED
@@ -0,0 +1,39 @@
import string
import math

import torch

from data import data_utils


def get_symbols_to_strip_from_output(generator):
    if hasattr(generator, "symbols_to_strip_from_output"):
        return generator.symbols_to_strip_from_output
    else:
        return {generator.bos, generator.eos}


def decode_fn(x, tgt_dict, bpe, generator, tokenizer=None):
    x = tgt_dict.string(x.int().cpu(), extra_symbols_to_ignore=get_symbols_to_strip_from_output(generator))
    if bpe is not None:
        x = bpe.decode(x)
    if tokenizer is not None:
        x = tokenizer.decode(x)
    return x


def eval_caption(task, generator, models, sample):
    transtab = str.maketrans({key: None for key in string.punctuation})
    hypos = task.inference_step(generator, models, sample)
    results = []
    for i, sample_id in enumerate(sample["id"].tolist()):
        detok_hypo_str = decode_fn(hypos[i][0]["tokens"], task.tgt_dict, task.bpe, generator)
        results.append({"image_id": str(sample_id), "caption": detok_hypo_str.translate(transtab).strip()})
    return results, None


def eval_step(task, generator, models, sample):
    if task.cfg._name == 'caption':
        return eval_caption(task, generator, models, sample)
    else:
        raise NotImplementedError
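eval_step dispatches on the task name and, for captioning, returns one {"image_id", "caption"} record per example with punctuation stripped. The sketch below shows how it is typically driven from an evaluation loop; the task, generator, models and batch iterator are assumed to be built elsewhere (e.g. by the checkpoint-loading code in evaluate.py), so treat the function name and arguments as illustrative.

# Hedged sketch of a driver loop around eval_step; only the eval_utils calls
# are taken from this commit, the surrounding helper is hypothetical.
from fairseq import utils as fairseq_utils

from utils import eval_utils


def run_caption_eval(task, generator, models, batch_iterator, use_cuda=True):
    # batch_iterator is assumed to yield samples already prepared by the task's dataset/collater
    results = []
    for sample in batch_iterator:
        sample = fairseq_utils.move_to_cuda(sample) if use_cuda else sample
        hyps, _ = eval_utils.eval_step(task, generator, models, sample)
        results += hyps   # each item: {"image_id": "...", "caption": "..."}
    return results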
utils/transforms.py
ADDED
@@ -0,0 +1,508 @@
import random

import torch
import torchvision.transforms as T
import torchvision.transforms.functional as F
import numpy as np
from PIL import Image


def crop(image, target, region, delete=True):
    cropped_image = F.crop(image, *region)

    target = target.copy()
    i, j, h, w = region

    # should we do something wrt the original size?
    target["size"] = torch.tensor([h, w])

    fields = ["labels", "area"]

    if "boxes" in target:
        boxes = target["boxes"]
        max_size = torch.as_tensor([w, h], dtype=torch.float32)
        cropped_boxes = boxes - torch.as_tensor([j, i, j, i])
        cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size)
        cropped_boxes = cropped_boxes.clamp(min=0)
        area = (cropped_boxes[:, 1, :] - cropped_boxes[:, 0, :]).prod(dim=1)
        target["boxes"] = cropped_boxes.reshape(-1, 4)
        target["area"] = area
        fields.append("boxes")

    if "polygons" in target:
        polygons = target["polygons"]
        num_polygons = polygons.shape[0]
        max_size = torch.as_tensor([w, h], dtype=torch.float32)
        start_coord = torch.cat([torch.tensor([j, i], dtype=torch.float32)
                                 for _ in range(polygons.shape[1] // 2)], dim=0)
        cropped_boxes = polygons - start_coord
        cropped_boxes = torch.min(cropped_boxes.reshape(num_polygons, -1, 2), max_size)
        cropped_boxes = cropped_boxes.clamp(min=0)
        target["polygons"] = cropped_boxes.reshape(num_polygons, -1)
        fields.append("polygons")

    if "masks" in target:
        # FIXME should we update the area here if there are no boxes?
        target['masks'] = target['masks'][:, i:i + h, j:j + w]
        fields.append("masks")

    # remove elements for which the boxes or masks that have zero area
    if delete and ("boxes" in target or "masks" in target):
        # favor boxes selection when defining which elements to keep
        # this is compatible with previous implementation
        if "boxes" in target:
            cropped_boxes = target['boxes'].reshape(-1, 2, 2)
            keep = torch.all(cropped_boxes[:, 1, :] > cropped_boxes[:, 0, :], dim=1)
        else:
            keep = target['masks'].flatten(1).any(1)

        for field in fields:
            target[field] = target[field][keep.tolist()]

    return cropped_image, target


def hflip(image, target):
    flipped_image = F.hflip(image)

    w, h = image.size

    target = target.copy()
    if "boxes" in target:
        boxes = target["boxes"]
        boxes = boxes[:, [2, 1, 0, 3]] * torch.as_tensor([-1, 1, -1, 1]) + torch.as_tensor([w, 0, w, 0])
        target["boxes"] = boxes

    if "polygons" in target:
        polygons = target["polygons"]
        num_polygons = polygons.shape[0]
        polygons = polygons.reshape(num_polygons, -1, 2) * torch.as_tensor([-1, 1]) + torch.as_tensor([w, 0])
        target["polygons"] = polygons

    if "masks" in target:
        target['masks'] = target['masks'].flip(-1)

    return flipped_image, target


def resize(image, target, size, max_size=None):
    # size can be min_size (scalar) or (w, h) tuple

    def get_size_with_aspect_ratio(image_size, size, max_size=None):
        w, h = image_size

        if (w <= h and w == size) or (h <= w and h == size):
            if max_size is not None:
                max_size = int(max_size)
                h = min(h, max_size)
                w = min(w, max_size)
            return (h, w)

        if w < h:
            ow = size
            oh = int(size * h / w)
        else:
            oh = size
            ow = int(size * w / h)

        if max_size is not None:
            max_size = int(max_size)
            oh = min(oh, max_size)
            ow = min(ow, max_size)

        return (oh, ow)

    def get_size(image_size, size, max_size=None):
        if isinstance(size, (list, tuple)):
            return size[::-1]
        else:
            return get_size_with_aspect_ratio(image_size, size, max_size)

    size = get_size(image.size, size, max_size)
    rescaled_image = F.resize(image, size, interpolation=Image.BICUBIC)

    if target is None:
        return rescaled_image

    ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(rescaled_image.size, image.size))
    ratio_width, ratio_height = ratios

    target = target.copy()
    if "boxes" in target:
        boxes = target["boxes"]
        scaled_boxes = boxes * torch.as_tensor([ratio_width, ratio_height, ratio_width, ratio_height])
        target["boxes"] = scaled_boxes

    if "polygons" in target:
        polygons = target["polygons"]
        scaled_ratio = torch.cat([torch.tensor([ratio_width, ratio_height])
                                  for _ in range(polygons.shape[1] // 2)], dim=0)
        scaled_polygons = polygons * scaled_ratio
        target["polygons"] = scaled_polygons

    if "area" in target:
        area = target["area"]
        scaled_area = area * (ratio_width * ratio_height)
        target["area"] = scaled_area

    h, w = size
    target["size"] = torch.tensor([h, w])

    if "masks" in target:
        assert False
        # target['masks'] = interpolate(
        #     target['masks'][:, None].float(), size, mode="nearest")[:, 0] > 0.5

    return rescaled_image, target


class CenterCrop(object):
    def __init__(self, size):
        self.size = size

    def __call__(self, img, target):
        image_width, image_height = img.size
        crop_height, crop_width = self.size
        crop_top = int(round((image_height - crop_height) / 2.))
        crop_left = int(round((image_width - crop_width) / 2.))
        return crop(img, target, (crop_top, crop_left, crop_height, crop_width))


class ObjectCenterCrop(object):
    def __init__(self, size):
        self.size = size

    def __call__(self, img, target):
        image_width, image_height = img.size
        crop_height, crop_width = self.size

        x0 = float(target['boxes'][0][0])
        y0 = float(target['boxes'][0][1])
        x1 = float(target['boxes'][0][2])
        y1 = float(target['boxes'][0][3])

        center_x = (x0 + x1) / 2
        center_y = (y0 + y1) / 2
        crop_left = max(center_x-crop_width/2 + min(image_width-center_x-crop_width/2, 0), 0)
        crop_top = max(center_y-crop_height/2 + min(image_height-center_y-crop_height/2, 0), 0)

        return crop(img, target, (crop_top, crop_left, crop_height, crop_width), delete=False)


class RandomHorizontalFlip(object):
    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, img, target):
        if random.random() < self.p:
            return hflip(img, target)
        return img, target


class RandomResize(object):
    def __init__(self, sizes, max_size=None, equal=False):
        assert isinstance(sizes, (list, tuple))
        self.sizes = sizes
        self.max_size = max_size
        self.equal = equal

    def __call__(self, img, target=None):
        size = random.choice(self.sizes)
        if self.equal:
            return resize(img, target, size, size)
        else:
            return resize(img, target, size, self.max_size)


class ToTensor(object):
    def __call__(self, img, target):
        return F.to_tensor(img), target


class Normalize(object):
    def __init__(self, mean, std, max_image_size=512):
        self.mean = mean
        self.std = std
        self.max_image_size = max_image_size

    def __call__(self, image, target=None):
        image = F.normalize(image, mean=self.mean, std=self.std)
        if target is None:
            return image, None
        target = target.copy()
        # h, w = image.shape[-2:]
        h, w = target["size"][0], target["size"][1]
        if "boxes" in target:
            boxes = target["boxes"]
            boxes = boxes / self.max_image_size
            target["boxes"] = boxes
        if "polygons" in target:
            polygons = target["polygons"]
            scale = torch.cat([torch.tensor([w, h], dtype=torch.float32)
                               for _ in range(polygons.shape[1] // 2)], dim=0)
            polygons = polygons / scale
            target["polygons"] = polygons
        return image, target


class Compose(object):
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image, target):
        for t in self.transforms:
            image, target = t(image, target)
        return image, target

    def __repr__(self):
        format_string = self.__class__.__name__ + "("
        for t in self.transforms:
            format_string += "\n"
            format_string += "    {0}".format(t)
        format_string += "\n)"
        return format_string


class LargeScaleJitter(object):
    """
    implementation of large scale jitter from copy_paste
    """

    def __init__(self, output_size=512, aug_scale_min=0.3, aug_scale_max=2.0):
        self.desired_size = torch.tensor([output_size])
        self.aug_scale_min = aug_scale_min
        self.aug_scale_max = aug_scale_max

    def rescale_target(self, scaled_size, image_size, target):
        # compute rescaled targets
        image_scale = scaled_size / image_size
        ratio_height, ratio_width = image_scale

        target = target.copy()
        target["size"] = scaled_size

        if "boxes" in target:
            boxes = target["boxes"]
            scaled_boxes = boxes * torch.as_tensor([ratio_width, ratio_height, ratio_width, ratio_height])
            target["boxes"] = scaled_boxes

        if "area" in target:
            area = target["area"]
            scaled_area = area * (ratio_width * ratio_height)
            target["area"] = scaled_area

        if "masks" in target:
            assert False
            masks = target['masks']
            # masks = interpolate(
            #     masks[:, None].float(), scaled_size, mode="nearest")[:, 0] > 0.5
            target['masks'] = masks
        return target

    def crop_target(self, region, target):
        i, j, h, w = region
        fields = ["labels", "area"]

        target = target.copy()
        target["size"] = torch.tensor([h, w])

        if "boxes" in target:
            boxes = target["boxes"]
            max_size = torch.as_tensor([w, h], dtype=torch.float32)
            cropped_boxes = boxes - torch.as_tensor([j, i, j, i])
            cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size)
            cropped_boxes = cropped_boxes.clamp(min=0)
            area = (cropped_boxes[:, 1, :] - cropped_boxes[:, 0, :]).prod(dim=1)
            target["boxes"] = cropped_boxes.reshape(-1, 4)
            target["area"] = area
            fields.append("boxes")

        if "masks" in target:
            # FIXME should we update the area here if there are no boxes?
            target['masks'] = target['masks'][:, i:i + h, j:j + w]
            fields.append("masks")

        # remove elements for which the boxes or masks that have zero area
        if "boxes" in target or "masks" in target:
            # favor boxes selection when defining which elements to keep
            # this is compatible with previous implementation
            if "boxes" in target:
                cropped_boxes = target['boxes'].reshape(-1, 2, 2)
                keep = torch.all(cropped_boxes[:, 1, :] > cropped_boxes[:, 0, :], dim=1)
            else:
                keep = target['masks'].flatten(1).any(1)

            for field in fields:
                target[field] = target[field][keep.tolist()]
        return target

    def pad_target(self, padding, target):
        target = target.copy()
        if "masks" in target:
            target['masks'] = torch.nn.functional.pad(target['masks'], (0, padding[1], 0, padding[0]))
        return target

    def __call__(self, image, target=None):
        image_size = image.size
        image_size = torch.tensor(image_size[::-1])

        random_scale = torch.rand(1) * (self.aug_scale_max - self.aug_scale_min) + self.aug_scale_min
        scaled_size = (random_scale * self.desired_size).round()

        scale = torch.maximum(scaled_size / image_size[0], scaled_size / image_size[1])
        scaled_size = (image_size * scale).round().int()

        scaled_image = F.resize(image, scaled_size.tolist(), interpolation=Image.BICUBIC)

        if target is not None:
            target = self.rescale_target(scaled_size, image_size, target)

        # randomly crop or pad images
        if random_scale >= 1:
            # Selects non-zero random offset (x, y) if scaled image is larger than desired_size.
            max_offset = scaled_size - self.desired_size
            offset = (max_offset * torch.rand(2)).floor().int()
            region = (offset[0].item(), offset[1].item(),
                      self.desired_size[0].item(), self.desired_size[0].item())
            output_image = F.crop(scaled_image, *region)
            if target is not None:
                target = self.crop_target(region, target)
        else:
            assert False
            padding = self.desired_size - scaled_size
            output_image = F.pad(scaled_image, [0, 0, padding[1].item(), padding[0].item()])
            if target is not None:
                target = self.pad_target(padding, target)

        return output_image, target


class OriginLargeScaleJitter(object):
    """
    implementation of large scale jitter from copy_paste
    """

    def __init__(self, output_size=512, aug_scale_min=0.3, aug_scale_max=2.0):
        self.desired_size = torch.tensor(output_size)
        self.aug_scale_min = aug_scale_min
        self.aug_scale_max = aug_scale_max

    def rescale_target(self, scaled_size, image_size, target):
        # compute rescaled targets
        image_scale = scaled_size / image_size
        ratio_height, ratio_width = image_scale

        target = target.copy()
        target["size"] = scaled_size

        if "boxes" in target:
            boxes = target["boxes"]
            scaled_boxes = boxes * torch.as_tensor([ratio_width, ratio_height, ratio_width, ratio_height])
            target["boxes"] = scaled_boxes

        if "area" in target:
            area = target["area"]
            scaled_area = area * (ratio_width * ratio_height)
            target["area"] = scaled_area

        if "masks" in target:
            assert False
            masks = target['masks']
            # masks = interpolate(
            #     masks[:, None].float(), scaled_size, mode="nearest")[:, 0] > 0.5
            target['masks'] = masks
        return target

    def crop_target(self, region, target):
        i, j, h, w = region
        fields = ["labels", "area"]

        target = target.copy()
        target["size"] = torch.tensor([h, w])

        if "boxes" in target:
            boxes = target["boxes"]
            max_size = torch.as_tensor([w, h], dtype=torch.float32)
            cropped_boxes = boxes - torch.as_tensor([j, i, j, i])
            cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size)
            cropped_boxes = cropped_boxes.clamp(min=0)
            area = (cropped_boxes[:, 1, :] - cropped_boxes[:, 0, :]).prod(dim=1)
            target["boxes"] = cropped_boxes.reshape(-1, 4)
            target["area"] = area
            fields.append("boxes")

        if "masks" in target:
            # FIXME should we update the area here if there are no boxes?
            target['masks'] = target['masks'][:, i:i + h, j:j + w]
            fields.append("masks")

        # remove elements for which the boxes or masks that have zero area
        if "boxes" in target or "masks" in target:
            # favor boxes selection when defining which elements to keep
            # this is compatible with previous implementation
            if "boxes" in target:
                cropped_boxes = target['boxes'].reshape(-1, 2, 2)
                keep = torch.all(cropped_boxes[:, 1, :] > cropped_boxes[:, 0, :], dim=1)
            else:
                keep = target['masks'].flatten(1).any(1)

            for field in fields:
                target[field] = target[field][keep.tolist()]
        return target

    def pad_target(self, padding, target):
        target = target.copy()
        if "masks" in target:
            target['masks'] = torch.nn.functional.pad(target['masks'], (0, padding[1], 0, padding[0]))
        return target

    def __call__(self, image, target=None):
        image_size = image.size
        image_size = torch.tensor(image_size[::-1])

        out_desired_size = (self.desired_size * image_size / max(image_size)).round().int()

        random_scale = torch.rand(1) * (self.aug_scale_max - self.aug_scale_min) + self.aug_scale_min
        scaled_size = (random_scale * self.desired_size).round()

        scale = torch.minimum(scaled_size / image_size[0], scaled_size / image_size[1])
        scaled_size = (image_size * scale).round().int()

        scaled_image = F.resize(image, scaled_size.tolist())

        if target is not None:
            target = self.rescale_target(scaled_size, image_size, target)

        # randomly crop or pad images
        if random_scale > 1:
            # Selects non-zero random offset (x, y) if scaled image is larger than desired_size.
            max_offset = scaled_size - out_desired_size
            offset = (max_offset * torch.rand(2)).floor().int()
            region = (offset[0].item(), offset[1].item(),
                      out_desired_size[0].item(), out_desired_size[1].item())
            output_image = F.crop(scaled_image, *region)
            if target is not None:
                target = self.crop_target(region, target)
        else:
            padding = out_desired_size - scaled_size
            output_image = F.pad(scaled_image, [0, 0, padding[1].item(), padding[0].item()])
            if target is not None:
                target = self.pad_target(padding, target)

        return output_image, target


class RandomDistortion(object):
    """
    Distort image w.r.t hue, saturation and exposure.
    """

    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0, prob=0.5):
        self.prob = prob
        self.tfm = T.ColorJitter(brightness, contrast, saturation, hue)

    def __call__(self, img, target=None):
        if np.random.random() < self.prob:
            return self.tfm(img), target
        else:
            return img, target
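All of these transforms share the (image, target) calling convention, which is what lets Compose chain them while keeping boxes, polygons, areas and sizes in sync with the pixels. The sketch below shows one way such a pipeline could be wired together; the image path, resolution and mean/std values are placeholders chosen for illustration, not values taken from this commit.

# Illustrative pipeline built from the transforms above (placeholder values).
import torch
from PIL import Image

import utils.transforms as T

transform = T.Compose([
    T.RandomResize([480], max_size=480),   # resize shorter side to 480, cap longer side at 480
    T.ToTensor(),                           # PIL image -> float tensor
    T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_image_size=480),
])

img = Image.open("example.jpg").convert("RGB")             # hypothetical input image
target = {"size": torch.tensor([img.height, img.width])}   # minimal target dict
patch_image, target = transform(img, target)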