Upload folder using huggingface_hub

- .gitattributes +1 -35
- .gitignore +7 -0
- LICENSE +201 -0
- README.md +147 -3
- config/model_config.json +34 -0
- eval/README.md +123 -0
- eval/eval_compare_matrix.py +406 -0
- eval/install_requirements.sh +1 -0
- inference.py +85 -0
- install_requirements.sh +3 -0
- model/autoencoders.py +374 -0
- model/ear_vae.py +112 -0
- model/transformer.py +846 -0
- pretrained_weight/ear_vae_44k.pyt +3 -0
.gitattributes
CHANGED
@@ -1,35 +1 @@
-*.7z filter=lfs diff=lfs merge=lfs -text
-*.arrow filter=lfs diff=lfs merge=lfs -text
-*.bin filter=lfs diff=lfs merge=lfs -text
-*.bz2 filter=lfs diff=lfs merge=lfs -text
-*.ckpt filter=lfs diff=lfs merge=lfs -text
-*.ftz filter=lfs diff=lfs merge=lfs -text
-*.gz filter=lfs diff=lfs merge=lfs -text
-*.h5 filter=lfs diff=lfs merge=lfs -text
-*.joblib filter=lfs diff=lfs merge=lfs -text
-*.lfs.* filter=lfs diff=lfs merge=lfs -text
-*.mlmodel filter=lfs diff=lfs merge=lfs -text
-*.model filter=lfs diff=lfs merge=lfs -text
-*.msgpack filter=lfs diff=lfs merge=lfs -text
-*.npy filter=lfs diff=lfs merge=lfs -text
-*.npz filter=lfs diff=lfs merge=lfs -text
-*.onnx filter=lfs diff=lfs merge=lfs -text
-*.ot filter=lfs diff=lfs merge=lfs -text
-*.parquet filter=lfs diff=lfs merge=lfs -text
-*.pb filter=lfs diff=lfs merge=lfs -text
-*.pickle filter=lfs diff=lfs merge=lfs -text
-*.pkl filter=lfs diff=lfs merge=lfs -text
-*.pt filter=lfs diff=lfs merge=lfs -text
-*.pth filter=lfs diff=lfs merge=lfs -text
-*.rar filter=lfs diff=lfs merge=lfs -text
-*.safetensors filter=lfs diff=lfs merge=lfs -text
-saved_model/**/* filter=lfs diff=lfs merge=lfs -text
-*.tar.* filter=lfs diff=lfs merge=lfs -text
-*.tar filter=lfs diff=lfs merge=lfs -text
-*.tflite filter=lfs diff=lfs merge=lfs -text
-*.tgz filter=lfs diff=lfs merge=lfs -text
-*.wasm filter=lfs diff=lfs merge=lfs -text
-*.xz filter=lfs diff=lfs merge=lfs -text
-*.zip filter=lfs diff=lfs merge=lfs -text
-*.zst filter=lfs diff=lfs merge=lfs -text
-*tfevents* filter=lfs diff=lfs merge=lfs -text
+*.pyt filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,7 @@
# pretrained_weight/*
results
data/*
__pycache__
*.pyc
docs
images
LICENSE
ADDED
@@ -0,0 +1,201 @@
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
README.md
CHANGED
@@ -1,3 +1,147 @@
# εar-VAE: High Fidelity Music Reconstruction Model

This repository contains the official inference code for εar-VAE, a 44.1 kHz music signal reconstruction model that rethinks and optimizes VAE training for audio. It targets two common weaknesses in existing open-source VAEs—phase accuracy and stereophonic spatial representation—by aligning objectives with auditory perception and introducing phase-aware training. Experiments show substantial improvements across diverse metrics, with particular strength in high-frequency harmonics and spatial characteristics.

<p align="center">
<img src="./images/all_compares.jpg" width=90%>
<img src="./images/table.png" width=90%>
</p>

<p align="center">
<em>Upper: Ablation study across our training components.</em> <em>Lower: Cross-model metric comparison on the evaluation dataset.</em>
</p>

Why εar-VAE:
- 🎧 Perceptual alignment: A K-weighting perceptual filter is applied before loss computation to better match human hearing.
- 🔁 Phase-aware objectives: Two novel phase losses:
  - Stereo Correlation Loss for robust inter-channel coherence.
  - Phase-Derivative Loss using Instantaneous Frequency and Group Delay for phase precision.
- 🌈 Spectral supervision paradigm: Magnitude is supervised across MSLR (Mid/Side/Left/Right) components, while phase is supervised only on LR (Left/Right), improving stability and fidelity (illustrated in the sketch after this list).
- 📈 44.1 kHz performance: Outperforms leading open-source models across diverse metrics.

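The Mid/Side components referenced in the spectral-supervision bullet are a simple linear transform of the stereo channels. The snippet below is an illustrative sketch only (the training losses themselves are not part of this inference release); it assumes a `[batch, 2, time]` stereo tensor and shows the M/S/L/R components on which the magnitude and phase objectives described above would operate.

```python
import torch

def to_mslr(stereo: torch.Tensor):
    """Split a [batch, 2, time] stereo waveform into Mid, Side, Left, Right.

    Sketch: magnitude losses can be computed on all four components,
    while phase-aware losses use only Left/Right, as described above.
    """
    left, right = stereo[:, 0], stereo[:, 1]
    mid = 0.5 * (left + right)   # in-phase (centre) content
    side = 0.5 * (left - right)  # out-of-phase (stereo) content
    return mid, side, left, right
```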
## 1. Installation

Follow these steps to set up the environment and install the necessary dependencies.

### Installation Steps

1. **Clone the repository:**
   ```bash
   git clone <your-repo-url>
   cd ear_vae
   ```

2. **Create and activate a conda environment:**
   ```bash
   conda create -n ear_vae python=3.8
   conda activate ear_vae
   ```

3. **Run the installation script:**

   This script installs the remaining dependencies.
   ```bash
   bash install_requirements.sh
   ```
   This will install:
   - `descript-audio-codec`
   - `alias-free-torch`
   - `ffmpeg < 7` (via conda)

4. **Download the model weights:**

   You can download the model checkpoint from **[Hugging Face](https://huggingface.co/earlab/EAR_VAE)**.
## 2. Usage

The `inference.py` script processes audio files from an input directory and saves the reconstructed audio to an output directory.

### Running Inference

You can run inference with the following command:

```bash
python inference.py --indir <input_directory> --outdir <output_directory> --model_path <path_to_model> --device <device>
```

### Command-Line Arguments

- `--indir`: (Optional) Path to the input directory containing audio files. Default: `./data`.
- `--outdir`: (Optional) Path to the output directory where reconstructed audio will be saved. Default: `./results`.
- `--model_path`: (Optional) Path to the pretrained model weights (`.pyt` file). Default: `./pretrained_weight/ear_vae_44k.pyt`.
- `--device`: (Optional) The device to run the model on (e.g., `cuda:0` or `cpu`). Defaults to `cuda:0` if available, otherwise `cpu`.

### Example

1. Place your input audio files (e.g., `.wav`, `.mp3`) into the `data/` directory.
2. Run the inference script:

```bash
python inference.py
```
This uses the default paths; the reconstructed audio files will be saved in the `results/` directory.

## 3. Project Structure

```
.
├── README.md                  # This file
├── config/                    # Model configurations
│   └── model_config.json
├── data/                      # Default directory for input audio files
├── eval/                      # Scripts for model evaluation
│   ├── eval_compare_matrix.py
│   ├── install_requirements.sh
│   └── README.md
├── inference.py               # Main script for running audio reconstruction
├── install_requirements.sh    # Installation script for dependencies
├── model/                     # Model architecture code
│   ├── autoencoders.py
│   ├── ear_vae.py
│   └── transformer.py
├── pretrained_weight/         # Directory for pretrained model weights
│   └── your_weight_here
```

## 4. Model Details

The model is a Variational Autoencoder with a Generative Adversarial Network (VAE-GAN) training structure.
- **Encoder**: An Oobleck-style encoder that downsamples the input audio into a latent representation.
- **Bottleneck**: A VAE bottleneck that introduces a probabilistic latent space, sampling from a learned mean and variance.
- **Decoder**: An Oobleck-style decoder that upsamples the latent representation back into an audio waveform.
- **Transformer**: A continuous transformer can optionally be placed in the bottleneck to further process the latent sequence.

This architecture allows for efficient, high-quality audio reconstruction. A minimal Python loading sketch is shown below.

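For readers who prefer the Python API over the CLI, the following sketch mirrors what `inference.py` (added in this commit) does: load `config/model_config.json`, build `EAR_VAE`, load the `.pyt` weights, and call `model.inference` on a `[1, 2, T]` waveform. The input path is a placeholder; a stereo 44.1 kHz file is assumed.

```python
import json
import torch
import torchaudio
from model.ear_vae import EAR_VAE

with open("./config/model_config.json") as f:
    config = json.load(f)

model = EAR_VAE(model_config=config).eval()
model.load_state_dict(torch.load("./pretrained_weight/ear_vae_44k.pyt", map_location="cpu"))

wav, sr = torchaudio.load("data/example.wav")  # placeholder path; stereo, 44.1 kHz assumed
with torch.no_grad():
    recon = model.inference(wav.unsqueeze(0))  # [1, 2, T] in -> [1, 2, T] out
torchaudio.save("results/example_rec.wav", recon.squeeze(0), 44100)
```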
## 5. Evaluation

The `eval/` directory contains scripts to evaluate the model's reconstruction performance using objective metrics.

### Evaluation Prerequisites

1. **Install Dependencies**: The evaluation script has its own set of dependencies. Install them by running the script in the `eval` directory:
   ```bash
   bash eval/install_requirements.sh
   ```
   This will install libraries such as `auraloss`.

2. **FFmpeg**: The script uses `ffmpeg` for loudness analysis. Make sure `ffmpeg` is installed and available in your system's PATH. You can install it via conda:
   ```bash
   conda install -c conda-forge 'ffmpeg<7'
   ```

### Running Evaluation

The `eval_compare_matrix.py` script compares the reconstructed audio with the original ground-truth files and computes various metrics.

For more details on the evaluation metrics and options, refer to `eval/README.md`.

## 6. Acknowledgements

This project builds upon the work of several open-source projects. We would like to extend our special thanks to:

- **[Stability AI's Stable Audio Tools](https://github.com/Stability-AI/stable-audio-tools)**: For providing a foundational framework and tools for audio generation.
- **[Descript's Audio Codec](https://github.com/descriptinc/descript-audio-codec)**: For the weight-normed convolutional layers.

Their contributions have been invaluable to the development of εar-VAE.
config/model_config.json
ADDED
@@ -0,0 +1,34 @@
{
    "transformer": {
        "depth": 2,
        "config": {
            "rotary_pos_emb": true,
            "dim_heads": 32
        }
    },
    "encoder": {
        "config": {
            "in_channels": 2,
            "channels": 128,
            "c_mults": [1, 2, 4, 8, 16],
            "strides": [2, 4, 4, 4, 8],
            "latent_dim": 128,
            "use_snake": true
        }
    },
    "decoder": {
        "config": {
            "out_channels": 2,
            "channels": 128,
            "c_mults": [1, 2, 4, 8, 16],
            "strides": [2, 4, 4, 4, 8],
            "latent_dim": 64,
            "use_nearest_upsample": false,
            "use_snake": true,
            "final_tanh": false
        }
    },
    "latent_dim": 64,
    "downsampling_ratio": 1024,
    "io_channels": 2
}
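A quick sanity check on the configuration above: `downsampling_ratio` is the product of the encoder/decoder strides, so each latent frame summarizes 1024 samples of 44.1 kHz audio (roughly 43 latent frames per second with a 64-dimensional latent). A one-liner to verify:

```python
import math

strides = [2, 4, 4, 4, 8]     # from the encoder/decoder config above
ratio = math.prod(strides)    # 2*4*4*4*8 = 1024, matching "downsampling_ratio"
latent_rate = 44100 / ratio   # ~43.07 latent frames per second
print(ratio, latent_rate)
```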
eval/README.md
ADDED
@@ -0,0 +1,123 @@
# VAE Audio Evaluation

This directory contains the script and resources for evaluating model performance on audio reconstruction tasks. The primary script, `eval_compare_matrix.py`, computes a suite of objective metrics to compare the quality of audio generated by the model against the original ground-truth audio.

## Features

- **Comprehensive Metrics**: Calculates a wide range of industry-standard and research-grade metrics:
  - **Time-Domain**: Scale-Invariant Signal-to-Distortion Ratio (SI-SDR); a reference definition is sketched after this list.
  - **Frequency-Domain**: Multi-Resolution STFT Loss and Multi-Resolution Mel-Spectrogram Loss.
  - **Phase**: Multi-Resolution Phase Coherence (both per-channel and inter-channel for stereo).
  - **Loudness**: Integrated Loudness (LUFS-I), Loudness Range (LRA), and True Peak, analyzed using `ffmpeg`.
- **Batch Processing**: Automatically discovers and processes multiple model output directories.
- **File Matching**: Pairs reconstructed audio files (e.g., `*_vae_rec.wav`) with their corresponding ground-truth files (e.g., `*.wav`).
- **Robust & Resilient**: Handles missing files, audio processing errors, and varying sample rates gracefully.
- **Organized Output**: Saves aggregated results in both machine-readable (`.json`) and human-readable (`.txt`) formats for each model.
- **Command-Line Interface**: Easy-to-use CLI for specifying the input directory and other options.

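For reference, the SI-SDR listed above is conventionally defined by projecting the estimate onto the target before measuring distortion. The snippet below is a plain NumPy sketch of that textbook definition, not the `auraloss` implementation the evaluation script actually uses:

```python
import numpy as np

def si_sdr(estimate: np.ndarray, target: np.ndarray, eps: float = 1e-8) -> float:
    """Scale-Invariant SDR in dB for 1-D signals (textbook definition)."""
    # Scale the target so it best explains the estimate, then compare energies.
    alpha = np.dot(estimate, target) / (np.dot(target, target) + eps)
    projection = alpha * target
    noise = estimate - projection
    return 10.0 * np.log10((np.sum(projection ** 2) + eps) / (np.sum(noise ** 2) + eps))
```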
## Prerequisites

### 1. Python Environment
Ensure you have a Python environment (3.8 or newer recommended) with the required packages installed. You can install them using pip:
```bash
pip install torch torchaudio auraloss numpy
```

### 2. FFmpeg
The script relies on `ffmpeg` for loudness analysis. You must have `ffmpeg` installed and accessible in your system's PATH.

**On Ubuntu/Debian:**
```bash
sudo apt update && sudo apt install ffmpeg
```

**On macOS (using Homebrew):**
```bash
brew install ffmpeg
```

**On Windows:**
Download the executable from the [official FFmpeg website](https://ffmpeg.org/download.html) and add its `bin` directory to your system's PATH environment variable.

You can verify the installation by running:
```bash
ffmpeg -version
```

**In a conda environment:**
```bash
conda install -c conda-forge 'ffmpeg<7'
```

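For reference, the loudness analysis in `eval_compare_matrix.py` (added in this same commit) shells out to ffmpeg's `ebur128` filter roughly as follows, then parses the integrated loudness (I), loudness range (LRA), and true peak from the filter's output:

```bash
ffmpeg -nostats -i your_audio_file.wav \
  -af "ebur128=peak=true,ametadata=mode=print:file=-" \
  -f null -
```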
## Directory Structure

The script expects a specific directory structure for the evaluation data. The root input directory should contain subdirectories, where each subdirectory represents a different model or experiment to be evaluated.

Inside each model's subdirectory, place the pairs of ground-truth and reconstructed audio files. The script identifies pairs based on a naming convention:
- **Ground Truth**: `your_audio_file.wav`
- **Reconstructed**: `your_audio_file_vae_rec.wav`

Here is an example structure:
```
/path/to/your/evaluation_data/
├── model_A/
│   ├── song1.wav            # Ground Truth 1
│   ├── song1_vae_rec.wav    # Reconstructed 1
│   ├── song2.wav            # Ground Truth 2
│   ├── song2_vae_rec.wav    # Reconstructed 2
│   └── ...
├── model_B/
│   ├── trackA.wav
│   ├── trackA_vae_rec.wav
│   └── ...
└── ...
```

## Usage

Run the evaluation script from the command line, pointing it to the root directory containing your model outputs.

```bash
python eval_compare_matrix.py --input_dir /path/to/your/evaluation_data/
```

### Command-Line Arguments

- `--input_dir` (required): The path to the root directory containing the model folders (e.g., `/path/to/your/evaluation_data/`).
- `--force` (optional): If specified, the script re-runs the evaluation for all models, even if results files (`evaluation_results.json`) already exist. By default, it skips models that have already been evaluated.
- `--echo` (optional): If specified, the script prints the detailed evaluation metrics for each individual audio pair during processing. By default, only the progress bar and final summary are shown.

### Example
```bash
python eval/eval_compare_matrix.py --input_dir ./results/
```

## Output

After running, the script generates two files inside each model's directory:

1. **`evaluation_results.json`**: A JSON file containing the aggregated averages of all computed metrics. This is ideal for programmatic analysis.
   ```json
   {
       "model_name": "model_A",
       "file_count": 50,
       "avg_sisdr": 15.78,
       "avg_mel_distance": 0.45,
       "avg_stft_distance": 0.89,
       "avg_per_channel_coherence": 0.95,
       "avg_interchannel_coherence": 0.92,
       "avg_gen_lufs-i": -14.2,
       "avg_gt_lufs-i": -14.0,
       ...
   }
   ```

2. **`evaluation_summary.txt`**: A human-readable text file summarizing the results.
   ```
   model_name: model_A
   file_count: 50
   avg_sisdr: 15.78...
   avg_mel_distance: 0.45...
   ...
   ```
   This allows for quick inspection of a model's performance without needing to parse the JSON.
eval/eval_compare_matrix.py
ADDED
@@ -0,0 +1,406 @@
"""
Audio Evaluation Script

This script evaluates the quality of generated audio against ground truth audio
using a variety of metrics, including:
- SI-SDR (Scale-Invariant Signal-to-Distortion Ratio)
- Multi-Resolution STFT Loss
- Multi-Resolution Mel-Spectrogram Loss
- Phase Coherence (Per-channel and Inter-channel)
- Loudness metrics (LUFS-I, LRA, True Peak) via ffmpeg.

The script processes a directory of models, where each model directory contains
pairs of reconstructed (_rec.wav) and ground truth (.wav) audio files.
"""

import os
import re
import sys
import json
import logging
import argparse
import subprocess
from pathlib import Path
from typing import Dict, List, Tuple, Optional

import numpy as np
import torch
import torch.nn as nn
import torchaudio
import auraloss
from tqdm import tqdm

# --- Setup ---
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    stream=sys.stdout
)

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
SAMPLE_RATE = 44100

# --- Metric Definitions ---

# SI-SDR
sisdr_criteria = auraloss.time.SISDRLoss().to(DEVICE)

# Multi-Resolution Mel-Spectrogram Loss
mel_fft_sizes = [4096, 2048, 1024, 512]
mel_win_sizes = mel_fft_sizes
mel_hop_sizes = [i // 4 for i in mel_fft_sizes]
mel_criteria = auraloss.freq.MultiResolutionSTFTLoss(
    fft_sizes=mel_fft_sizes,
    hop_sizes=mel_hop_sizes,
    win_lengths=mel_win_sizes,
    sample_rate=SAMPLE_RATE,
    scale="mel",
    n_bins=64,
    perceptual_weighting=True
).to(DEVICE)

# Multi-Resolution STFT Loss
fft_sizes = [4096, 2048, 1024, 512, 256, 128]
win_sizes = fft_sizes
hop_sizes = [i // 4 for i in fft_sizes]
stft_criteria = auraloss.freq.MultiResolutionSTFTLoss(
    fft_sizes=fft_sizes,
    hop_sizes=hop_sizes,
    win_lengths=win_sizes,
    sample_rate=SAMPLE_RATE,
    perceptual_weighting=True
).to(DEVICE)


def analyze_loudness(file_path: str) -> Optional[Dict[str, float]]:
    """
    Analyzes audio file loudness using ffmpeg's ebur128 filter.

    Args:
        file_path: Path to the audio file.

    Returns:
        A dictionary with LUFS-I, LRA, and True Peak, or None on failure.
    """
    if not Path(file_path).exists():
        logging.warning(f"Loudness analysis skipped: File not found at {file_path}")
        return None

    command = [
        "ffmpeg",
        "-nostats",
        "-i", file_path,
        "-af", "ebur128=peak=true,ametadata=mode=print:file=-",
        "-f", "null",
        "-"
    ]

    try:
        result = subprocess.run(command, capture_output=True, text=True, check=True, encoding='utf-8')
        output_text = result.stderr
    except FileNotFoundError:
        logging.error("ffmpeg not found. Please install ffmpeg and ensure it's in your PATH.")
        return None
    except subprocess.CalledProcessError as e:
        logging.error(f"ffmpeg analysis failed for {file_path}. Error: {e.stderr}")
        return None

    loudness_data = {}

    i_match = re.search(r"^\s*I:\s*(-?[\d\.]+)\s*LUFS", output_text, re.MULTILINE)
    if i_match:
        loudness_data['LUFS-I'] = float(i_match.group(1))

    lra_match = re.search(r"^\s*LRA:\s*([\d\.]+)\s*LU", output_text, re.MULTILINE)
    if lra_match:
        loudness_data['LRA'] = float(lra_match.group(1))

    tp_match = re.search(r"Peak:\s*(-?[\d\.]+)\s*dBFS", output_text, re.MULTILINE)
    if tp_match:
        loudness_data['True Peak'] = float(tp_match.group(1))

    if not loudness_data:
        logging.warning(f"Could not parse loudness data for {file_path}.")
        return None

    return loudness_data


class PhaseCoherenceLoss(nn.Module):
    """
    Calculates phase coherence between two audio signals.
    Adapted for stereo and multi-resolution analysis.
    """
    def __init__(self, fft_size, hop_size, win_size, mag_threshold=1e-6, eps=1e-8):
        super().__init__()
        self.fft_size = int(fft_size)
        self.hop_size = int(hop_size)
        self.win_size = int(win_size)
        self.register_buffer("window", torch.hann_window(win_size))
        self.mag_threshold = float(mag_threshold)
        self.eps = float(eps)

    def _to_complex(self, x):
        if torch.is_complex(x):
            return x
        if x.dim() >= 1 and x.size(-1) == 2:
            return torch.complex(x[..., 0], x[..., 1])
        raise ValueError("Input must be complex or real/imag tensor.")

    def _stereo_stft(self, x):
        if x.dim() == 2:
            x = x.unsqueeze(0)
        B, C, T = x.shape
        stft = torch.stft(x.reshape(B * C, T),
                          n_fft=self.fft_size,
                          hop_length=self.hop_size,
                          win_length=self.win_size,
                          window=self.window,
                          return_complex=True)
        return stft.view(B, C, -1, stft.size(-1))

    def forward(self, pred, target):
        pred_stft = self._stereo_stft(pred)
        target_stft = self._stereo_stft(target)

        pred_stft = self._to_complex(pred_stft)
        target_stft = self._to_complex(target_stft)

        B, C, F, T = pred_stft.shape

        # magnitudes and weights
        mag_pred = torch.abs(pred_stft)
        mag_target = torch.abs(target_stft)
        weights = mag_pred * mag_target
        mask = (weights > self.mag_threshold).to(weights.dtype)
        weights_masked = weights * mask

        # phase difference Δφ = angle(pred) - angle(target)
        delta = torch.angle(pred_stft) - torch.angle(target_stft)
        # phasor e^{jΔφ}
        phasor = torch.complex(torch.cos(delta), torch.sin(delta))

        # weighted vector sum across frequency axis
        num = torch.sum(weights_masked * phasor, dim=2)  # [B, C, T], complex
        den = torch.sum(weights_masked, dim=2).clamp_min(self.eps)
        coherence_per_bin = torch.abs(num) / den

        # pool across time (energy-weighted mean) -> per-channel scalar
        # weight time pooling by per-frame energy sum to emphasize active frames
        frame_energy = torch.sum(weights_masked, dim=2)
        frame_energy_sum = torch.sum(frame_energy, dim=2).clamp_min(self.eps)

        # energy-weighted average over time
        coherence_chan = torch.sum(coherence_per_bin * frame_energy, dim=2) / frame_energy_sum

        # mean across batch
        per_channel_coherence = coherence_chan.mean(dim=0)

        inter_coherence = None
        if C >= 2:
            Lp, Rp = pred_stft[:, 0], pred_stft[:, 1]
            Lt, Rt = target_stft[:, 0], target_stft[:, 1]

            # inter-channel phase: angle(L) - angle(R) <=> angle(L * conj(R))
            inter_delta = torch.angle(Lp * torch.conj(Rp)) - torch.angle(Lt * torch.conj(Rt))
            inter_weights = torch.abs(Lp) * torch.abs(Rp)
            inter_mask = (inter_weights > self.mag_threshold).to(inter_weights.dtype)
            inter_weights_masked = inter_weights * inter_mask
            inter_phasor = torch.complex(torch.cos(inter_delta), torch.sin(inter_delta))
            inter_num = torch.sum(inter_weights_masked * inter_phasor, dim=1)
            inter_den = torch.sum(inter_weights_masked, dim=1).clamp_min(self.eps)
            inter_coh_time = torch.abs(inter_num) / inter_den

            # pool across time weighted by energy
            inter_frame_energy = torch.sum(inter_weights_masked, dim=1)
            inter_energy_sum = inter_frame_energy.sum(dim=1).clamp_min(self.eps)
            inter_coh_b = (inter_coh_time * inter_frame_energy).sum(dim=1) / inter_energy_sum
            inter_coherence = inter_coh_b.mean()

        return {
            "per_channel_coherence": per_channel_coherence.detach().cpu(),
            "interchannel_coherence": (inter_coherence.detach().cpu() if inter_coherence is not None else None),
        }


class MultiResolutionPhaseCoherenceLoss(nn.Module):
    def __init__(self, fft_sizes, hop_sizes, win_sizes):
        super().__init__()
        self.criteria = nn.ModuleList([
            PhaseCoherenceLoss(fft, hop, win) for fft, hop, win in zip(fft_sizes, hop_sizes, win_sizes)
        ])

    def forward(self, pred, target):
        results = [criterion(pred, target) for criterion in self.criteria]

        per_channel = torch.stack([r["per_channel_coherence"] for r in results]).mean(dim=0)
        inter_items = [r["interchannel_coherence"] for r in results if r["interchannel_coherence"] is not None]
        inter_channel = torch.stack(inter_items).mean() if inter_items else None

        return {"per_channel_coherence": per_channel, "interchannel_coherence": inter_channel}

phase_coherence_criteria = MultiResolutionPhaseCoherenceLoss(
    fft_sizes=mel_fft_sizes, hop_sizes=mel_hop_sizes, win_sizes=mel_win_sizes
).to(DEVICE)


def find_audio_pairs(model_path: Path) -> List[Tuple[Path, Path]]:
    """Finds pairs of reconstructed and ground truth audio files."""
    rec_files = sorted(model_path.glob("*_vae_rec.wav"))
    pairs = []
    for rec_file in rec_files:
        gt_file = model_path / rec_file.name.replace("_vae_rec.wav", ".wav")
        if gt_file.exists():
            pairs.append((rec_file, gt_file))
        else:
            logging.warning(f"Ground truth file not found for {rec_file.name}")
    return pairs


def evaluate_pair(rec_path: Path, gt_path: Path) -> Optional[Dict[str, float]]:
    """Evaluates a single pair of audio files."""
    try:
        gen_wav, gen_sr = torchaudio.load(rec_path, backend="ffmpeg")
        gt_wav, gt_sr = torchaudio.load(gt_path, backend="ffmpeg")

        if gen_sr != SAMPLE_RATE:
            gen_wav = torchaudio.transforms.Resample(gen_sr, SAMPLE_RATE)(gen_wav)
        if gt_sr != SAMPLE_RATE:
            gt_wav = torchaudio.transforms.Resample(gt_sr, SAMPLE_RATE)(gt_wav)

        # Trim to same length
        if gen_wav.shape[-1] != gt_wav.shape[-1]:
            logging.info(f"Shape Mismatched, Trimming audio files to the same length: {rec_path.name}, {gt_path.name}")
            min_len = min(gen_wav.shape[-1], gt_wav.shape[-1])
            gen_wav, gt_wav = gen_wav[:, :min_len], gt_wav[:, :min_len]

        gen_wav, gt_wav = gen_wav.to(DEVICE).unsqueeze(0), gt_wav.to(DEVICE).unsqueeze(0)

        metrics = {}
        metrics['sisdr'] = -sisdr_criteria(gen_wav, gt_wav).item()
        metrics['mel_distance'] = mel_criteria(gen_wav, gt_wav).item()
        metrics['stft_distance'] = stft_criteria(gen_wav, gt_wav).item()

        phase_metrics = phase_coherence_criteria(gen_wav, gt_wav)
        metrics['per_channel_coherence'] = phase_metrics["per_channel_coherence"].mean().item()
        if phase_metrics["interchannel_coherence"] is not None:
            metrics['interchannel_coherence'] = phase_metrics["interchannel_coherence"].item()

        return metrics
    except Exception as e:
        logging.error(f"Error processing pair {rec_path.name}, {gt_path.name}: {e}")
        return None


def process_model(model_path: Path, force_eval: bool = False, echo=True):
    """Processes all audio pairs for a given model."""
    logging.info(f"Processing model: {model_path.name}")
    results_file = model_path / "evaluation_results.json"

    if results_file.exists() and not force_eval:
        logging.info(f"Results already exist for {model_path.name}, skipping.")
        return

    audio_pairs = find_audio_pairs(model_path)
    if not audio_pairs:
        logging.warning(f"No valid audio pairs found for {model_path.name}.")
        return

    all_metrics = []
    gen_loudness_data, gt_loudness_data = [], []

    with torch.no_grad():
        for rec_path, gt_path in tqdm(audio_pairs, desc=f"Evaluating {model_path.name}"):
            pair_metrics = evaluate_pair(rec_path, gt_path)
            if pair_metrics:
                all_metrics.append(pair_metrics)

            gen_loudness = analyze_loudness(str(rec_path))
            if gen_loudness:
                gen_loudness_data.append(gen_loudness)

            gt_loudness = analyze_loudness(str(gt_path))
            if gt_loudness:
                gt_loudness_data.append(gt_loudness)

            if echo:
                logging.info(f"Metrics for {rec_path.name}: {pair_metrics}")
                if gen_loudness:
                    logging.info(f"Generated Loudness: {gen_loudness}")
                if gt_loudness:
                    logging.info(f"Ground Truth Loudness: {gt_loudness}")

    if not all_metrics:
        logging.warning(f"No metrics could be calculated for {model_path.name}.")
        return

    # Aggregate results
    summary = {"model_name": model_path.name, "file_count": len(all_metrics)}

    # Average objective metrics
    metric_keys = all_metrics[0].keys()
    for key in metric_keys:
        valid_values = [m[key] for m in all_metrics if key in m]
        if valid_values:
            summary[f"avg_{key}"] = float(np.mean(valid_values))

    # Average loudness metrics
    def _avg_loudness(data: List[Dict[str, float]], prefix: str):
        if not data: return
        for key in data[0].keys():
            values = [d[key] for d in data if key in d]
            if values:
                summary[f"avg_{prefix}_{key.lower().replace(' ', '_')}"] = float(np.mean(values))

    _avg_loudness(gen_loudness_data, "gen")
    _avg_loudness(gt_loudness_data, "gt")

    # Save results
    logging.info(f"Saving results for {model_path.name} to {results_file}")
    with open(results_file, 'w') as f:
        json.dump(summary, f, indent=4)

    # Also save a human-readable version
    with open(model_path / "evaluation_summary.txt", "w") as f:
        for key, value in summary.items():
            f.write(f"{key}: {value}\n")


def main():
    parser = argparse.ArgumentParser(description="Run evaluation on generated audio.")
    parser.add_argument(
        "--input_dir",
        type=str,
        required=True,
        help="Root directory containing model output folders."
    )
    parser.add_argument(
        "--force",
        action="store_true",
        help="Force re-evaluation even if results files exist."
    )

    parser.add_argument(
        "--echo",
        action="store_true",
        help="Echo per-file metrics to console during evaluation."
    )
    args = parser.parse_args()

    root_path = Path(args.input_dir)
    if not root_path.is_dir():
        logging.error(f"Input directory not found: {root_path}")
        sys.exit(1)

    model_paths = [p for p in root_path.iterdir() if p.is_dir() and not p.name.startswith('.')]

    logging.info(f"Found {len(model_paths)} model(s) to evaluate: {[p.name for p in model_paths]}")

    for model_path in sorted(model_paths):
        process_model(model_path, args.force, args.echo)

    logging.info("Evaluation complete.")


if __name__ == "__main__":
    main()
eval/install_requirements.sh
ADDED
@@ -0,0 +1 @@
pip install torch torchaudio auraloss numpy
inference.py
ADDED
@@ -0,0 +1,85 @@
import torchaudio
from pathlib import Path
from tqdm import tqdm
import torch
import argparse
import json
from model.ear_vae import EAR_VAE


def main(args):
    indir = args.indir
    model_path = args.model_path
    outdir = args.outdir
    device = args.device
    config_path = args.config

    print(f"Input directory: {indir}")
    print(f"Model path: {model_path}")
    print(f"Output directory: {outdir}")
    print(f"Device: {device}")
    print(f"Config path: {config_path}")

    input_path = Path(indir)
    output_path_dir = Path(outdir)
    output_path_dir.mkdir(parents=True, exist_ok=True)

    with open(config_path, 'r') as f:
        vae_gan_model_config = json.load(f)

    print("Loading model...")
    model = EAR_VAE(model_config=vae_gan_model_config).to(device)

    state = torch.load(model_path, map_location="cpu")
    model.load_state_dict(state)
    model.eval()
    print("Model loaded successfully.")

    audios = list(input_path.rglob("*"))
    print(f"Found {len(audios)} audio files to process.")

    with torch.no_grad():
        for audio_path in tqdm(audios, desc="Processing audio files"):
            try:
                gt_y, sr = torchaudio.load(audio_path, backend="ffmpeg")

                if len(gt_y.shape) == 1:
                    gt_y = gt_y.unsqueeze(0)

                # Move to the target device first so the waveform and the
                # Resample kernel live on the same device
                gt_y = gt_y.to(device, torch.float32)

                # Resample if necessary
                if sr != 44100:
                    resampler = torchaudio.transforms.Resample(sr, 44100).to(device)
                    gt_y = resampler(gt_y)

                # Convert to stereo if mono
                if gt_y.shape[0] == 1:
                    gt_y = torch.cat([gt_y, gt_y], dim=0)

                # Add batch dimension
                gt_y = gt_y.unsqueeze(0)

                fake_audio = model.inference(gt_y)

                output_filename = f"{Path(audio_path).stem}_{Path(model_path).stem}.wav"
                output_path = output_path_dir / output_filename

                fake_audio_processed = fake_audio.squeeze(0).cpu()
                torchaudio.save(output_path, fake_audio_processed, sample_rate=44100, backend="ffmpeg")
            except Exception as e:
                print(f"Error processing {audio_path}: {e}")
                continue


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Run VAE-GAN audio inference.")
    parser.add_argument('--indir', type=str, default='./data', help='Input directory for audio files.')
    parser.add_argument('--model_path', type=str, default='./pretrained_weight/ear_vae_44k.pyt', help='Path to the pretrained model weight.')
    parser.add_argument('--outdir', type=str, default='./results', help='Output directory for generated audio files.')
    parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu', help='Device to run the model on (e.g., "cuda:0" or "cpu").')
    parser.add_argument('--config', type=str, default='./config/model_config.json', help='Path to the model config file.')

    args = parser.parse_args()
    main(args)
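A minimal sketch for exercising inference.py without the command line, assuming the repository's default paths; the Namespace fields simply mirror the argparse options above, and nothing here is part of the upstream interface.

# Hypothetical driver for the script above; adjust the paths to your checkout.
import torch
from argparse import Namespace

from inference import main

args = Namespace(
    indir="./data",
    model_path="./pretrained_weight/ear_vae_44k.pyt",
    outdir="./results",
    device="cuda" if torch.cuda.is_available() else "cpu",
    config="./config/model_config.json",
)
main(args)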
install_requirements.sh
ADDED
@@ -0,0 +1,3 @@
pip install descript-audio-codec
python -m pip install alias-free-torch
conda install -c conda-forge 'ffmpeg<7'
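An assumed sanity check (not part of the repository) that the three dependencies above are importable and that torchaudio can see the ffmpeg backend the inference script requests.

# Assumed environment check; module names follow the imports used elsewhere in this repo.
import dac                   # installed by descript-audio-codec
import alias_free_torch      # installed by alias-free-torch
import torchaudio

print(torchaudio.list_audio_backends())  # should include "ffmpeg" after the conda install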
model/autoencoders.py
ADDED
@@ -0,0 +1,374 @@
import torch
import math

from torch import nn, pow
from alias_free_torch import Activation1d
from dac.nn.layers import WNConv1d, WNConvTranspose1d
from typing import Literal


def snake_beta(x, alpha, beta):
    return x + (1.0 / (beta + 0.000000001)) * pow(torch.sin(x * alpha), 2)


class SnakeBeta(nn.Module):
    def __init__(
        self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True
    ):
        super(SnakeBeta, self).__init__()
        self.in_features = in_features

        # initialize alpha
        self.alpha_logscale = alpha_logscale
        if self.alpha_logscale:  # log scale alphas initialized to zeros
            self.alpha = nn.Parameter(torch.zeros(in_features) * alpha)
            self.beta = nn.Parameter(torch.zeros(in_features) * alpha)
        else:  # linear scale alphas initialized to ones
            self.alpha = nn.Parameter(torch.ones(in_features) * alpha)
            self.beta = nn.Parameter(torch.ones(in_features) * alpha)

        self.alpha.requires_grad = alpha_trainable
        self.beta.requires_grad = alpha_trainable

        self.no_div_by_zero = 0.000000001

    def forward(self, x):
        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)  # line up with x to [B, C, T]
        beta = self.beta.unsqueeze(0).unsqueeze(-1)
        if self.alpha_logscale:
            alpha = torch.exp(alpha)
            beta = torch.exp(beta)
        x = snake_beta(x, alpha, beta)

        return x


def checkpoint(function, *args, **kwargs):
    kwargs.setdefault("use_reentrant", False)
    return torch.utils.checkpoint.checkpoint(function, *args, **kwargs)


def get_activation(
    activation: Literal["elu", "snake", "none"], antialias=False, channels=None
) -> nn.Module:
    if activation == "elu":
        act = nn.ELU()
    elif activation == "snake":
        act = SnakeBeta(channels)
    elif activation == "none":
        act = nn.Identity()
    else:
        raise ValueError(f"Unknown activation {activation}")

    if antialias:
        act = Activation1d(act)

    return act


class ResidualUnit(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        dilation,
        use_snake=False,
        antialias_activation=False,
        bias=True,
    ):
        super().__init__()

        self.dilation = dilation

        act = get_activation(
            "snake" if use_snake else "elu",
            antialias=antialias_activation,
            channels=out_channels,
        )

        padding = (dilation * (7 - 1)) // 2

        self.layers = nn.Sequential(
            act,
            WNConv1d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=7,
                dilation=dilation,
                padding=padding,
                bias=bias,
            ),
            act,
            WNConv1d(
                in_channels=out_channels, out_channels=out_channels, kernel_size=1, bias=bias
            ),
        )

    def forward(self, x):
        res = x

        # x = checkpoint(self.layers, x)
        x = self.layers(x)

        return x + res


class EncoderBlock(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        stride,
        use_snake=False,
        antialias_activation=False,
        bias=True,
    ):
        super().__init__()

        act = get_activation(
            "snake" if use_snake else "elu",
            antialias=antialias_activation,
            channels=in_channels,
        )

        self.layers = nn.Sequential(
            ResidualUnit(
                in_channels=in_channels,
                out_channels=in_channels,
                dilation=1,
                use_snake=use_snake,
                bias=bias,
            ),
            ResidualUnit(
                in_channels=in_channels,
                out_channels=in_channels,
                dilation=3,
                use_snake=use_snake,
                bias=bias,
            ),
            ResidualUnit(
                in_channels=in_channels,
                out_channels=in_channels,
                dilation=9,
                use_snake=use_snake,
                bias=bias,
            ),
            act,
            WNConv1d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=2 * stride,
                stride=stride,
                padding=math.ceil(stride / 2),
                bias=bias,
            ),
        )

    def forward(self, x):
        return self.layers(x)


class AntiAliasUpsamplerBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=2, bias=True):
        super().__init__()

        self.upsample = nn.Upsample(scale_factor=stride, mode="nearest")

        self.conv = WNConv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=2 * stride,
            bias=bias,
            padding="same",
        )

    def forward(self, x):
        x = self.upsample(x)
        x = self.conv(x)
        return x


class DecoderBlock(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        stride,
        use_snake=False,
        antialias_activation=False,
        use_nearest_upsample=False,
        bias=True,
    ):
        super().__init__()

        if use_nearest_upsample:
            upsample_layer = AntiAliasUpsamplerBlock(
                in_channels=in_channels, out_channels=out_channels, stride=stride, bias=bias
            )
        else:
            upsample_layer = WNConvTranspose1d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=2 * stride,
                stride=stride,
                padding=math.ceil(stride / 2),
                bias=bias,
            )

        act = get_activation(
            "snake" if use_snake else "elu",
            antialias=antialias_activation,
            channels=in_channels,
        )

        self.layers = nn.Sequential(
            act,
            upsample_layer,
            ResidualUnit(
                in_channels=out_channels,
                out_channels=out_channels,
                dilation=1,
                use_snake=use_snake,
                bias=bias,
            ),
            ResidualUnit(
                in_channels=out_channels,
                out_channels=out_channels,
                dilation=3,
                use_snake=use_snake,
                bias=bias,
            ),
            ResidualUnit(
                in_channels=out_channels,
                out_channels=out_channels,
                dilation=9,
                use_snake=use_snake,
                bias=bias,
            ),
        )

    def forward(self, x):
        return self.layers(x)


class OobleckEncoder(nn.Module):
    def __init__(
        self,
        in_channels=2,
        channels=128,
        latent_dim=32,
        c_mults=[1, 2, 4, 8],
        strides=[2, 4, 8, 8],
        use_snake=False,
        antialias_activation=False,
        bias=True,
    ):
        super().__init__()

        c_mults = [1] + c_mults

        self.depth = len(c_mults)

        layers = [
            WNConv1d(
                in_channels=in_channels,
                out_channels=c_mults[0] * channels,
                kernel_size=7,
                padding=3,
                bias=bias,
            )
        ]

        for i in range(self.depth - 1):
            layers += [
                EncoderBlock(
                    in_channels=c_mults[i] * channels,
                    out_channels=c_mults[i + 1] * channels,
                    stride=strides[i],
                    use_snake=use_snake,
                    bias=bias,
                )
            ]

        layers += [
            get_activation(
                "snake" if use_snake else "elu",
                antialias=antialias_activation,
                channels=c_mults[-1] * channels,
            ),
            WNConv1d(
                in_channels=c_mults[-1] * channels,
                out_channels=latent_dim,
                kernel_size=3,
                padding=1,
                bias=bias,
            ),
        ]

        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)


class OobleckDecoder(nn.Module):
    def __init__(
        self,
        out_channels=2,
        channels=128,
        latent_dim=32,
        c_mults=[1, 2, 4, 8],
        strides=[2, 4, 8, 8],
        use_snake=False,
        antialias_activation=False,
        use_nearest_upsample=False,
        final_tanh=True,
        bias=True,
    ):
        super().__init__()

        c_mults = [1] + c_mults

        self.depth = len(c_mults)

        layers = [
            WNConv1d(
                in_channels=latent_dim,
                out_channels=c_mults[-1] * channels,
                kernel_size=7,
                padding=3,
                bias=bias,
            ),
        ]

        for i in range(self.depth - 1, 0, -1):
            layers += [
                DecoderBlock(
                    in_channels=c_mults[i] * channels,
                    out_channels=c_mults[i - 1] * channels,
                    stride=strides[i - 1],
                    use_snake=use_snake,
                    antialias_activation=antialias_activation,
                    use_nearest_upsample=use_nearest_upsample,
                    bias=bias,
                )
            ]

        layers += [
            get_activation(
                "snake" if use_snake else "elu",
                antialias=antialias_activation,
                channels=c_mults[0] * channels,
            ),
            WNConv1d(
                in_channels=c_mults[0] * channels,
                out_channels=out_channels,
                kernel_size=7,
                padding=3,
                bias=False,
            ),
            nn.Tanh() if final_tanh else nn.Identity(),
        ]

        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)
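A rough shape check for the encoder/decoder pair, using the same hyperparameters EAR_VAE passes in by default; the batch size and clip length below are illustrative assumptions, not part of the repository. The total time compression is the product of the strides (2*4*4*4*8 = 1024), and the encoder's 128 output channels are later split into a 64-dim mean and a 64-dim scale.

# Sketch under assumed inputs: verify the ~1024x temporal compression.
import torch
from model.autoencoders import OobleckEncoder, OobleckDecoder

enc = OobleckEncoder(in_channels=2, channels=128, latent_dim=128,
                     c_mults=[1, 2, 4, 8, 16], strides=[2, 4, 4, 4, 8], use_snake=True)
dec = OobleckDecoder(out_channels=2, channels=128, latent_dim=64,
                     c_mults=[1, 2, 4, 8, 16], strides=[2, 4, 4, 4, 8],
                     use_snake=True, final_tanh=False)

x = torch.randn(1, 2, 44100)   # roughly one second of stereo audio
h = enc(x)                     # [1, 128, 43]: 44100 samples -> 43 latent frames
y = dec(h[:, :64, :])          # the first 64 channels are the posterior mean
print(h.shape, y.shape)        # decoder returns [1, 2, 44032], slightly short of 44100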
model/ear_vae.py
ADDED
@@ -0,0 +1,112 @@
import torch
import numpy as np
import torch.nn.functional as F
import torch.nn as nn

from torch import Tensor, nn, no_grad
from .autoencoders import OobleckDecoder, OobleckEncoder

from .transformer import ContinuousTransformer

LRELU_SLOPE = 0.1
padding_mode = "zeros"
sample_eps = 1e-6


def vae_sample(mean, scale):
    stdev = nn.functional.softplus(scale)
    var = stdev * stdev + sample_eps
    logvar = torch.log(var)
    latents = torch.randn_like(mean) * stdev + mean

    kl = (mean * mean + var - logvar - 1).sum(1).mean()

    return latents, kl


class EAR_VAE(nn.Module):

    def __init__(self, model_config: dict = None):
        super().__init__()

        if model_config is None:
            model_config = {
                "encoder": {
                    "config": {
                        "in_channels": 2,
                        "channels": 128,
                        "c_mults": [1, 2, 4, 8, 16],
                        "strides": [2, 4, 4, 4, 8],
                        "latent_dim": 128,
                        "use_snake": True
                    }
                },
                "decoder": {
                    "config": {
                        "out_channels": 2,
                        "channels": 128,
                        "c_mults": [1, 2, 4, 8, 16],
                        "strides": [2, 4, 4, 4, 8],
                        "latent_dim": 64,
                        "use_nearest_upsample": False,
                        "use_snake": True,
                        "final_tanh": False,
                    },
                },
                "latent_dim": 64,
                "downsampling_ratio": 1024,
                "io_channels": 2,
            }
        else:
            model_config = model_config

        if model_config.get("transformer") is not None:
            self.transformers = ContinuousTransformer(
                dim=model_config["decoder"]["config"]["latent_dim"],
                depth=model_config["transformer"]["depth"],
                **model_config["transformer"].get("config", {}),
            )
        else:
            self.transformers = None

        self.encoder = OobleckEncoder(**model_config["encoder"]["config"])
        self.decoder = OobleckDecoder(**model_config["decoder"]["config"])

    def forward(self, audio) -> Tensor:
        """
        audio: Input audio tensor [B,C,T]
        """
        status = self.encoder(audio)
        mean, scale = status.chunk(2, dim=1)
        z, kl = vae_sample(mean, scale)

        if self.transformers is not None:
            z = z.permute(0, 2, 1)
            z = self.transformers(z)
            z = z.permute(0, 2, 1)

        x = self.decoder(z)
        return x, kl

    def encode(self, audio, use_sample=True):
        x = self.encoder(audio)
        mean, scale = x.chunk(2, dim=1)
        if use_sample:
            z, _ = vae_sample(mean, scale)
        else:
            z = mean
        return z

    def decode(self, z):

        if self.transformers is not None:
            z = z.permute(0, 2, 1)
            z = self.transformers(z)
            z = z.permute(0, 2, 1)

        x = self.decoder(z)
        return x

    @no_grad()
    def inference(self, audio):
        z = self.encode(audio)
        recon_audio = self.decode(z)
        return recon_audio
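A minimal round trip through EAR_VAE's built-in default configuration; the random input and its length are assumptions for the example. encode() chunks the encoder output into a mean and a scale and optionally draws the reparameterized sample from vae_sample, while inference() is the same no-grad encode/decode path that inference.py calls.

# Sketch: encode/decode a random stereo clip with the default 44.1 kHz config.
import torch
from model.ear_vae import EAR_VAE

model = EAR_VAE().eval()
audio = torch.randn(1, 2, 44100)                 # [B, C, T]

with torch.no_grad():
    z = model.encode(audio, use_sample=False)    # posterior mean, [1, 64, 43]
    recon = model.decode(z)                      # [1, 2, 44032]
    recon_sampled = model.inference(audio)       # sampled-latent round trip

print(z.shape, recon.shape, recon_sampled.shape)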
model/transformer.py
ADDED
@@ -0,0 +1,846 @@
from typing import Callable, Literal

import torch
import torch.nn.functional as F
from einops import rearrange
from einops.layers.torch import Rearrange
from packaging import version
from torch import einsum, nn
from torch.cuda.amp import autocast

try:
    from flash_attn import flash_attn_func, flash_attn_kvpacked_func
    # flash_attn==2.3.3 is required
except ImportError as e:
    print(e)
    print('flash_attn not installed, disabling Flash Attention')
    flash_attn_kvpacked_func = None
    flash_attn_func = None

try:
    import natten
except ImportError:
    natten = None

import math
from functools import reduce

import numpy as np


class FourierFeatures(nn.Module):
    def __init__(self, in_features, out_features, std=1.):
        super().__init__()
        assert out_features % 2 == 0
        self.weight = nn.Parameter(torch.randn(
            [out_features // 2, in_features]) * std)

    def forward(self, input):
        f = 2 * math.pi * input @ self.weight.T
        return torch.cat([f.cos(), f.sin()], dim=-1)


def normalize(x, eps=1e-4):
    dim = list(range(1, x.ndim))
    n = torch.linalg.vector_norm(x, dim=dim, keepdim=True)
    alpha = np.sqrt(n.numel() / x.numel())
    return x / torch.add(eps, n, alpha=alpha)


def checkpoint(function, *args, **kwargs):
    kwargs.setdefault("use_reentrant", False)
    return torch.utils.checkpoint.checkpoint(function, *args, **kwargs)


def create_causal_mask(i, j, device):
    return torch.ones((i, j), device=device, dtype=torch.bool).triu(j - i + 1)


def or_reduce(masks):
    head, *body = masks
    for rest in body:
        head = head | rest
    return head


# positional embeddings

class AbsolutePositionalEmbedding(nn.Module):
    def __init__(self, dim, max_seq_len):
        super().__init__()
        self.scale = dim ** -0.5
        self.max_seq_len = max_seq_len
        self.emb = nn.Embedding(max_seq_len, dim)

    def forward(self, x, pos=None, seq_start_pos=None):
        seq_len, device = x.shape[1], x.device
        assert seq_len <= self.max_seq_len, f'you are passing in a sequence length of {seq_len} but your absolute positional embedding has a max sequence length of {self.max_seq_len}'

        if pos is None:
            pos = torch.arange(seq_len, device=device)

        if seq_start_pos is not None:
            pos = (pos - seq_start_pos[..., None]).clamp(min=0)

        pos_emb = self.emb(pos)
        pos_emb = pos_emb * self.scale
        return pos_emb


class ScaledSinusoidalEmbedding(nn.Module):
    def __init__(self, dim, theta=10000):
        super().__init__()
        assert (dim % 2) == 0, 'dimension must be divisible by 2'
        self.scale = nn.Parameter(torch.ones(1) * dim ** -0.5)

        half_dim = dim // 2
        freq_seq = torch.arange(half_dim).float() / half_dim
        inv_freq = theta ** -freq_seq
        self.register_buffer('inv_freq', inv_freq, persistent=False)

    def forward(self, x, pos=None, seq_start_pos=None):
        seq_len, device = x.shape[1], x.device

        if pos is None:
            pos = torch.arange(seq_len, device=device)

        if seq_start_pos is not None:
            pos = pos - seq_start_pos[..., None]

        emb = einsum('i, j -> i j', pos, self.inv_freq)
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb * self.scale


class RotaryEmbedding(nn.Module):
    def __init__(
        self,
        dim,
        use_xpos=False,
        scale_base=512,
        interpolation_factor=1.,
        base=10000,
        base_rescale_factor=1.
    ):
        super().__init__()
        # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
        # has some connection to NTK literature
        # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
        base *= base_rescale_factor ** (dim / (dim - 2))

        inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)

        assert interpolation_factor >= 1.
        self.interpolation_factor = interpolation_factor

        if not use_xpos:
            self.register_buffer('scale', None)
            return

        scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)

        self.scale_base = scale_base
        self.register_buffer('scale', scale)

    def forward_from_seq_len(self, seq_len):
        device = self.inv_freq.device

        t = torch.arange(seq_len, device=device)
        return self.forward(t)

    @torch.amp.autocast('cuda', enabled=False)
    def forward(self, t):
        device = self.inv_freq.device

        t = t.to(torch.float32)

        t = t / self.interpolation_factor

        freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
        freqs = torch.cat((freqs, freqs), dim=-1)

        if self.scale is None:
            return freqs, 1.

        power = (torch.arange(seq_len, device=device) - (seq_len // 2)) / self.scale_base
        scale = self.scale ** rearrange(power, 'n -> n 1')
        scale = torch.cat((scale, scale), dim=-1)

        return freqs, scale


def rotate_half(x):
    x = rearrange(x, '... (j d) -> ... j d', j=2)
    x1, x2 = x.unbind(dim=-2)
    return torch.cat((-x2, x1), dim=-1)


@torch.amp.autocast('cuda', enabled=False)
def apply_rotary_pos_emb(t, freqs, scale=1):
    out_dtype = t.dtype

    # cast to float32 if necessary for numerical stability
    dtype = reduce(torch.promote_types, (t.dtype, freqs.dtype, torch.float32))
    rot_dim, seq_len = freqs.shape[-1], t.shape[-2]
    freqs, t = freqs.to(dtype), t.to(dtype)
    freqs = freqs[-seq_len:, :]

    if t.ndim == 4 and freqs.ndim == 3:
        freqs = rearrange(freqs, 'b n d -> b 1 n d')

    # partial rotary embeddings, Wang et al. GPT-J
    t, t_unrotated = t[..., :rot_dim], t[..., rot_dim:]
    t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)

    t, t_unrotated = t.to(out_dtype), t_unrotated.to(out_dtype)

    return torch.cat((t, t_unrotated), dim=-1)


# norms
class LayerNorm(nn.Module):
    def __init__(self, dim, bias=False, fix_scale=False):
        """
        bias-less layernorm has been shown to be more stable. most newer models have moved towards rmsnorm, also bias-less
        """
        super().__init__()

        if fix_scale:
            self.register_buffer("gamma", torch.ones(dim))
        else:
            self.gamma = nn.Parameter(torch.ones(dim))

        if bias:
            self.beta = nn.Parameter(torch.zeros(dim))
        else:
            self.register_buffer("beta", torch.zeros(dim))

    def forward(self, x):
        return F.layer_norm(x, x.shape[-1:], weight=self.gamma, bias=self.beta)


# feedforward

class GLU(nn.Module):
    def __init__(
        self,
        dim_in,
        dim_out,
        activation: Callable,
        use_conv=False,
        conv_kernel_size=3,
    ):
        super().__init__()
        self.act = activation
        self.proj = nn.Linear(dim_in, dim_out * 2) if not use_conv else nn.Conv1d(
            dim_in, dim_out * 2, conv_kernel_size, padding=(conv_kernel_size // 2))
        self.use_conv = use_conv

    def forward(self, x):
        if self.use_conv:
            x = rearrange(x, 'b n d -> b d n')
            x = self.proj(x)
            x = rearrange(x, 'b d n -> b n d')
        else:
            x = self.proj(x)

        x, gate = x.chunk(2, dim=-1)
        return x * self.act(gate)


class FeedForward(nn.Module):
    def __init__(
        self,
        dim,
        dim_out=None,
        mult=4,
        no_bias=False,
        glu=True,
        use_conv=False,
        conv_kernel_size=3,
        zero_init_output=True,
    ):
        super().__init__()
        inner_dim = int(dim * mult)

        # Default to SwiGLU

        activation = nn.SiLU()

        dim_out = dim if dim_out is None else dim_out

        if glu:
            linear_in = GLU(dim, inner_dim, activation)
        else:
            linear_in = nn.Sequential(
                Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
                nn.Linear(dim, inner_dim, bias=not no_bias) if not use_conv else nn.Conv1d(
                    dim, inner_dim, conv_kernel_size, padding=(conv_kernel_size // 2), bias=not no_bias),
                Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
                activation
            )

        linear_out = nn.Linear(inner_dim, dim_out, bias=not no_bias) if not use_conv else nn.Conv1d(
            inner_dim, dim_out, conv_kernel_size, padding=(conv_kernel_size // 2), bias=not no_bias)

        # init last linear layer to 0
        if zero_init_output:
            nn.init.zeros_(linear_out.weight)
            if not no_bias:
                nn.init.zeros_(linear_out.bias)

        self.ff = nn.Sequential(
            linear_in,
            Rearrange('b d n -> b n d') if use_conv else nn.Identity(),
            linear_out,
            Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
        )

    def forward(self, x):
        return self.ff(x)


class Attention(nn.Module):
    def __init__(
        self,
        dim,
        dim_heads=64,
        dim_context=None,
        causal=False,
        zero_init_output=True,
        qk_norm: Literal['l2', 'ln', 'none'] = 'none',
        natten_kernel_size=None
    ):
        super().__init__()
        self.dim = dim
        self.dim_heads = dim_heads
        self.causal = causal

        dim_kv = dim_context if dim_context is not None else dim

        self.num_heads = dim // dim_heads
        self.kv_heads = dim_kv // dim_heads

        if dim_context is not None:
            self.to_q = nn.Linear(dim, dim, bias=False)
            self.to_kv = nn.Linear(dim_kv, dim_kv * 2, bias=False)
        else:
            self.to_qkv = nn.Linear(dim, dim * 3, bias=False)

        self.to_out = nn.Linear(dim, dim, bias=False)

        if zero_init_output:
            nn.init.zeros_(self.to_out.weight)

        self.qk_norm = qk_norm

        if self.qk_norm == "ln":
            self.q_norm = nn.LayerNorm(dim_heads, elementwise_affine=True, eps=1.0e-6)
            self.k_norm = nn.LayerNorm(dim_heads, elementwise_affine=True, eps=1.0e-6)

        # Using 1d neighborhood attention
        self.natten_kernel_size = natten_kernel_size
        if natten_kernel_size is not None:
            return

        self.use_pt_flash = torch.cuda.is_available() and version.parse(torch.__version__) >= version.parse('2.0.0')

        self.use_fa_flash = torch.cuda.is_available() and flash_attn_func is not None

        self.sdp_kwargs = dict(
            enable_flash=True,
            enable_math=True,
            enable_mem_efficient=True
        )

    def flash_attn(
        self,
        q,
        k,
        v,
        mask=None,
        causal=None
    ):
        batch, heads, q_len, _, k_len, device = *q.shape, k.shape[-2], q.device
        kv_heads = k.shape[1]
        # Recommended for multi-query single-key-value attention by Tri Dao
        # kv shape torch.Size([1, 512, 64]) -> torch.Size([1, 8, 512, 64])

        if heads != kv_heads:
            # Repeat interleave kv_heads to match q_heads
            heads_per_kv_head = heads // kv_heads
            k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim=1), (k, v))

        if k.ndim == 3:
            k = rearrange(k, 'b ... -> b 1 ...').expand_as(q)

        if v.ndim == 3:
            v = rearrange(v, 'b ... -> b 1 ...').expand_as(q)

        causal = self.causal if causal is None else causal

        if q_len == 1 and causal:
            causal = False

        if mask is not None:
            assert mask.ndim == 4
            mask = mask.expand(batch, heads, q_len, k_len)

        # handle kv cache - this should be bypassable in updated flash attention 2

        if k_len > q_len and causal:
            causal_mask = self.create_causal_mask(q_len, k_len, device=device)
            if mask is None:
                mask = ~causal_mask
            else:
                mask = mask & ~causal_mask
            causal = False

        # manually handle causal mask, if another mask was given

        row_is_entirely_masked = None

        if mask is not None and causal:
            causal_mask = self.create_causal_mask(q_len, k_len, device=device)
            mask = mask & ~causal_mask

            # protect against an entire row being masked out

            row_is_entirely_masked = ~mask.any(dim=-1)
            mask[..., 0] = mask[..., 0] | row_is_entirely_masked

            causal = False

        with torch.backends.cuda.sdp_kernel(**self.sdp_kwargs):
            out = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=mask,
                is_causal=causal
            )

        # for a row that is entirely masked out, should zero out the output of that row token

        if row_is_entirely_masked is not None:
            out = out.masked_fill(row_is_entirely_masked[..., None], 0.)

        return out

    def forward(
        self,
        x,
        context=None,
        mask=None,
        context_mask=None,
        rotary_pos_emb=None,
        causal=None
    ):
        h, kv_h, has_context = self.num_heads, self.kv_heads, context is not None

        kv_input = context if has_context else x

        if hasattr(self, 'to_q'):
            # Use separate linear projections for q and k/v
            q = self.to_q(x)
            q = rearrange(q, 'b n (h d) -> b h n d', h=h)

            k, v = self.to_kv(kv_input).chunk(2, dim=-1)

            k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=kv_h), (k, v))
        else:
            # Use fused linear projection
            q, k, v = self.to_qkv(x).chunk(3, dim=-1)
            q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v))

        # Normalize q and k for cosine sim attention
        if self.qk_norm == "l2":
            q = F.normalize(q, dim=-1)
            k = F.normalize(k, dim=-1)
        elif self.qk_norm == "ln":
            q = self.q_norm(q)
            k = self.k_norm(k)

        if rotary_pos_emb is not None and not has_context:
            freqs, _ = rotary_pos_emb

            q_dtype = q.dtype
            k_dtype = k.dtype

            q = q.to(torch.float32)
            k = k.to(torch.float32)
            freqs = freqs.to(torch.float32)

            q = apply_rotary_pos_emb(q, freqs)
            k = apply_rotary_pos_emb(k, freqs)

            q = q.to(q_dtype)
            k = k.to(k_dtype)

        input_mask = context_mask

        if input_mask is None and not has_context:
            input_mask = mask

        # determine masking
        masks = []
        final_attn_mask = None  # The mask that will be applied to the attention matrix, taking all masks into account

        if input_mask is not None:
            input_mask = rearrange(input_mask, 'b j -> b 1 1 j')
            masks.append(~input_mask)

        # Other masks will be added here later

        if len(masks) > 0:
            final_attn_mask = ~or_reduce(masks)

        n, device = q.shape[-2], q.device

        causal = self.causal if causal is None else causal

        if n == 1 and causal:
            causal = False

        if self.natten_kernel_size is not None:
            if natten is None:
                raise ImportError('natten not installed, please install natten to use neighborhood attention')

            dtype_in = q.dtype
            q, k, v = map(lambda t: t.to(torch.float32), (q, k, v))

            attn = natten.functional.natten1dqk(q, k, kernel_size=self.natten_kernel_size, dilation=1)

            if final_attn_mask is not None:
                attn = attn.masked_fill(final_attn_mask, -torch.finfo(attn.dtype).max)

            attn = F.softmax(attn, dim=-1, dtype=torch.float32)

            out = natten.functional.natten1dav(attn, v, kernel_size=self.natten_kernel_size, dilation=1).to(dtype_in)

        # Prioritize Flash Attention 2
        elif self.use_fa_flash:
            assert final_attn_mask is None, 'masking not yet supported for Flash Attention 2'
            # Flash Attention 2 requires FP16 inputs
            fa_dtype_in = q.dtype
            q, k, v = map(lambda t: rearrange(t, 'b h n d -> b n h d').to(torch.float16), (q, k, v))

            out = flash_attn_func(q, k, v, causal=causal)

            out = rearrange(out.to(fa_dtype_in), 'b n h d -> b h n d')

        # Fall back to PyTorch implementation
        elif self.use_pt_flash:
            out = self.flash_attn(q, k, v, causal=causal, mask=final_attn_mask)

        else:
            # Fall back to custom implementation

            if h != kv_h:
                # Repeat interleave kv_heads to match q_heads
                heads_per_kv_head = h // kv_h
                k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim=1), (k, v))

            scale = 1. / (q.shape[-1] ** 0.5)

            kv_einsum_eq = 'b j d' if k.ndim == 3 else 'b h j d'

            dots = einsum(f'b h i d, {kv_einsum_eq} -> b h i j', q, k) * scale

            i, j, dtype = *dots.shape[-2:], dots.dtype

            mask_value = -torch.finfo(dots.dtype).max

            if final_attn_mask is not None:
                dots = dots.masked_fill(~final_attn_mask, mask_value)

            if causal:
                causal_mask = self.create_causal_mask(i, j, device=device)
                dots = dots.masked_fill(causal_mask, mask_value)

            attn = F.softmax(dots, dim=-1, dtype=torch.float32)
            attn = attn.type(dtype)

            out = einsum(f'b h i j, {kv_einsum_eq} -> b h i d', attn, v)

        # merge heads
        out = rearrange(out, ' b h n d -> b n (h d)')
        out = self.to_out(out)

        if mask is not None:
            mask = rearrange(mask, 'b n -> b n 1')
            out = out.masked_fill(~mask, 0.)

        return out


class ConformerModule(nn.Module):
    def __init__(
        self,
        dim,
        norm_kwargs={},
    ):
        super().__init__()

        self.dim = dim

        self.in_norm = LayerNorm(dim, **norm_kwargs)
        self.pointwise_conv = nn.Conv1d(dim, dim, kernel_size=1, bias=False)
        self.glu = GLU(dim, dim, nn.SiLU())
        self.depthwise_conv = nn.Conv1d(dim, dim, kernel_size=17, groups=dim, padding=8, bias=False)
        self.mid_norm = LayerNorm(dim, **norm_kwargs)  # This is a batch norm in the original but I don't like batch norm
        self.swish = nn.SiLU()
        self.pointwise_conv_2 = nn.Conv1d(dim, dim, kernel_size=1, bias=False)

    def forward(self, x):
        x = self.in_norm(x)
        x = rearrange(x, 'b n d -> b d n')
        x = self.pointwise_conv(x)
        x = rearrange(x, 'b d n -> b n d')
        x = self.glu(x)
        x = rearrange(x, 'b n d -> b d n')
        x = self.depthwise_conv(x)
        x = rearrange(x, 'b d n -> b n d')
        x = self.mid_norm(x)
        x = self.swish(x)
        x = rearrange(x, 'b n d -> b d n')
        x = self.pointwise_conv_2(x)
        x = rearrange(x, 'b d n -> b n d')

        return x


class TransformerBlock(nn.Module):
    def __init__(
        self,
        dim,
        dim_heads=64,
        cross_attend=False,
        dim_context=None,
        global_cond_dim=None,
        causal=False,
        zero_init_branch_outputs=True,
        conformer=False,
        layer_ix=-1,
        remove_norms=False,
        attn_kwargs={},
        ff_kwargs={},
        norm_kwargs={}
    ):

        super().__init__()
        self.dim = dim
        self.dim_heads = dim_heads
        self.cross_attend = cross_attend
        self.dim_context = dim_context
        self.causal = causal

        self.pre_norm = LayerNorm(dim, **norm_kwargs) if not remove_norms else nn.Identity()

        self.self_attn = Attention(
            dim,
            dim_heads=dim_heads,
            causal=causal,
            zero_init_output=zero_init_branch_outputs,
            **attn_kwargs
        )

        if cross_attend:
            self.cross_attend_norm = LayerNorm(dim, **norm_kwargs) if not remove_norms else nn.Identity()
            self.cross_attn = Attention(
                dim,
                dim_heads=dim_heads,
                dim_context=dim_context,
                causal=causal,
                zero_init_output=zero_init_branch_outputs,
                **attn_kwargs
            )

        self.ff_norm = LayerNorm(dim, **norm_kwargs) if not remove_norms else nn.Identity()
        self.ff = FeedForward(dim, zero_init_output=zero_init_branch_outputs, **ff_kwargs)

        self.layer_ix = layer_ix

        self.conformer = ConformerModule(dim, norm_kwargs=norm_kwargs) if conformer else None

        self.global_cond_dim = global_cond_dim

        if global_cond_dim is not None:
            self.to_scale_shift_gate = nn.Sequential(
                nn.SiLU(),
                nn.Linear(global_cond_dim, dim * 6, bias=False)
            )

            nn.init.zeros_(self.to_scale_shift_gate[1].weight)
            # nn.init.zeros_(self.to_scale_shift_gate_self[1].bias)

    def forward(
        self,
        x,
        context=None,
        global_cond=None,
        mask=None,
        context_mask=None,
        rotary_pos_emb=None
    ):
        if self.global_cond_dim is not None and self.global_cond_dim > 0 and global_cond is not None:

            scale_self, shift_self, gate_self, scale_ff, shift_ff, gate_ff = self.to_scale_shift_gate(
                global_cond).unsqueeze(1).chunk(6, dim=-1)

            # self-attention with adaLN
            residual = x
            x = self.pre_norm(x)
            x = x * (1 + scale_self) + shift_self
            x = self.self_attn(x, mask=mask, rotary_pos_emb=rotary_pos_emb)
            x = x * torch.sigmoid(1 - gate_self)
            x = x + residual

            if context is not None:
                x = x + self.cross_attn(self.cross_attend_norm(x), context=context, context_mask=context_mask)

            if self.conformer is not None:
                x = x + self.conformer(x)

            # feedforward with adaLN
            residual = x
            x = self.ff_norm(x)
            x = x * (1 + scale_ff) + shift_ff
            x = self.ff(x)
            x = x * torch.sigmoid(1 - gate_ff)
            x = x + residual

        else:
            x = x + self.self_attn(self.pre_norm(x), mask=mask, rotary_pos_emb=rotary_pos_emb)

            if context is not None:
                x = x + self.cross_attn(self.cross_attend_norm(x), context=context, context_mask=context_mask)

            if self.conformer is not None:
                x = x + self.conformer(x)

            x = x + self.ff(self.ff_norm(x))

        return x


class ContinuousTransformer(nn.Module):
    def __init__(
        self,
        dim,
        depth,
        *,
        dim_in=None,
        dim_out=None,
        dim_heads=64,
        cross_attend=False,
        cond_token_dim=None,
        global_cond_dim=None,
        causal=False,
        rotary_pos_emb=True,
        zero_init_branch_outputs=True,
        conformer=False,
        use_sinusoidal_emb=False,
        use_abs_pos_emb=False,
        abs_pos_emb_max_length=10000,
        **kwargs
    ):

        super().__init__()

        self.dim = dim
        self.depth = depth
        self.causal = causal
        self.layers = nn.ModuleList([])

        self.project_in = nn.Linear(dim_in, dim, bias=False) if dim_in is not None else nn.Identity()
        self.project_out = nn.Linear(dim, dim_out, bias=False) if dim_out is not None else nn.Identity()

        if rotary_pos_emb:
            self.rotary_pos_emb = RotaryEmbedding(max(dim_heads // 2, 32))
        else:
            self.rotary_pos_emb = None

        self.use_sinusoidal_emb = use_sinusoidal_emb
        if use_sinusoidal_emb:
            self.pos_emb = ScaledSinusoidalEmbedding(dim)

        self.use_abs_pos_emb = use_abs_pos_emb
        if use_abs_pos_emb:
            self.pos_emb = AbsolutePositionalEmbedding(dim, abs_pos_emb_max_length)

        for i in range(depth):
            self.layers.append(
                TransformerBlock(
                    dim,
                    dim_heads=dim_heads,
                    cross_attend=cross_attend,
                    dim_context=cond_token_dim,
                    global_cond_dim=global_cond_dim,
                    causal=causal,
                    zero_init_branch_outputs=zero_init_branch_outputs,
                    conformer=conformer,
                    layer_ix=i,
                    **kwargs
                )
            )

    def forward(
        self,
        x,
        mask=None,
        prepend_embeds=None,
        prepend_mask=None,
        global_cond=None,
        return_info=False,
        **kwargs
    ):
        batch, seq, device = *x.shape[:2], x.device

        info = {
            "hidden_states": [],
        }

        x = self.project_in(x)

        if prepend_embeds is not None:
            prepend_length, prepend_dim = prepend_embeds.shape[1:]

            assert prepend_dim == x.shape[-1], 'prepend dimension must match sequence dimension'

            x = torch.cat((prepend_embeds, x), dim=-2)

            if prepend_mask is not None or mask is not None:
                mask = mask if mask is not None else torch.ones((batch, seq), device=device, dtype=torch.bool)
                prepend_mask = prepend_mask if prepend_mask is not None else torch.ones(
                    (batch, prepend_length), device=device, dtype=torch.bool)

                mask = torch.cat((prepend_mask, mask), dim=-1)

        # Attention layers

        if self.rotary_pos_emb is not None:
            rotary_pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1])
        else:
            rotary_pos_emb = None

        if self.use_sinusoidal_emb or self.use_abs_pos_emb:
            x = x + self.pos_emb(x)

        # Iterate over the transformer layers
        for layer in self.layers:
            # x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs)
            x = checkpoint(layer, x, rotary_pos_emb=rotary_pos_emb, global_cond=global_cond, **kwargs)

            if return_info:
                info["hidden_states"].append(x)

        x = self.project_out(x)

        if return_info:
            return x, info

        return x
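ContinuousTransformer only becomes part of EAR_VAE when the model config contains a "transformer" block; it operates on the latent sequence in [batch, time, dim] layout, which is why ear_vae.py permutes before and after calling it. A standalone sketch with small, assumed dimensions (not the shipped configuration):

# Sketch: forward a latent sequence through a tiny ContinuousTransformer.
import torch
from model.transformer import ContinuousTransformer

tfm = ContinuousTransformer(dim=64, depth=2, dim_heads=32)
z = torch.randn(2, 43, 64)     # [B, T_latent, dim], e.g. ~1 s of 1024x-compressed latents
out = tfm(z)                   # shape-preserving: [2, 43, 64]
print(out.shape)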
pretrained_weight/ear_vae_44k.pyt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0362dc7e96566869747dbe079b0a6d71c090b0a3a5d5077779e7be17c096d9d5
size 591453838
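The .pyt entry above is a Git LFS pointer, so the ~590 MB state dict only exists locally after an LFS-aware clone or download. A hedged way to confirm the object resolved and to peek at its layout; the expectation that keys start with "encoder"/"decoder" comes from ear_vae.py's attribute names and is an assumption, not a documented guarantee.

# Sketch: check that the LFS object was fetched (a bare pointer file is only ~130 bytes).
import os
import torch

path = "./pretrained_weight/ear_vae_44k.pyt"
print(os.path.getsize(path))                        # expected 591453838 bytes
state = torch.load(path, map_location="cpu")
print(len(state), sorted({k.split('.')[0] for k in state}))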