earlab committed on
Commit b3c4dc3 · verified · 1 Parent(s): 88ac091

Upload folder using huggingface_hub

.gitattributes CHANGED
@@ -1,35 +1 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
1
+ *.pyt filter=lfs diff=lfs merge=lfs -text
 
.gitignore ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # pretrained_weight/*
2
+ results
3
+ data/*
4
+ __pycache__
5
+ *.pyc
6
+ docs
7
+ images
LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README.md CHANGED
@@ -1,3 +1,147 @@
1
- ---
2
- license: apache-2.0
3
- ---
1
+ # εar-VAE: High Fidelity Music Reconstruction Model
2
+
3
+ This repository contains the official inference code for εar-VAE, a 44.1 kHz music signal reconstruction model that rethinks and optimizes VAE training for audio. It targets two common weaknesses of existing open-source VAEs (phase accuracy and stereophonic spatial representation) by aligning objectives with auditory perception and introducing phase-aware training. Experiments show substantial improvements across diverse metrics, with particular strength in high-frequency harmonics and spatial characteristics.
4
+
5
+ <p align="center">
6
+ <img src="./images/all_compares.jpg" width=90%>
7
+ <img src="./images/table.png" width=90%>
8
+ </p>
9
+
10
+ <p align="center">
11
+ <em>Top: Ablation study across our training components.</em> <em>Bottom: Cross-model metric comparison on the evaluation dataset.</em>
12
+ </p>
13
+
14
+ Why εar-VAE:
15
+ - 🎧 Perceptual alignment: A K-weighting perceptual filter is applied before loss computation to better match human hearing.
16
+ - 🔁 Phase-aware objectives: Two novel phase losses
17
+ - Stereo Correlation Loss for robust inter-channel coherence.
18
+ - Phase-Derivative Loss using Instantaneous Frequency and Group Delay for phase precision.
19
+ - 🌈 Spectral supervision paradigm: Magnitude is supervised across all four MSLR (Mid/Side/Left/Right) components, while phase is supervised only on LR (Left/Right), improving stability and fidelity (see the sketch after this list).
20
+ - 📈 44.1 kHz performance: Outperforms leading open-source models across diverse metrics.
21
+
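+ As a reference for the MSLR supervision split, here is a minimal PyTorch sketch. It is not the actual training code; the function names, STFT settings, and L1 loss choice are illustrative assumptions.
+
+ ```python
+ import torch
+ import torch.nn.functional as F
+
+ def to_mslr(x: torch.Tensor) -> torch.Tensor:
+     """Stack Mid/Side/Left/Right views of a stereo batch [B, 2, T] into [B, 4, T]."""
+     left, right = x[:, 0], x[:, 1]
+     mid, side = 0.5 * (left + right), 0.5 * (left - right)
+     return torch.stack([mid, side, left, right], dim=1)
+
+ def mslr_magnitude_loss(pred: torch.Tensor, target: torch.Tensor, n_fft: int = 2048) -> torch.Tensor:
+     """L1 distance between STFT magnitudes, averaged over the four MSLR components."""
+     window = torch.hann_window(n_fft, device=pred.device)
+     def mag(x: torch.Tensor) -> torch.Tensor:
+         b, c, t = x.shape
+         spec = torch.stft(x.reshape(b * c, t), n_fft, hop_length=n_fft // 4,
+                           window=window, return_complex=True)
+         return spec.abs()
+     return F.l1_loss(mag(to_mslr(pred)), mag(to_mslr(target)))
+ ```
+
+ The phase terms (stereo correlation, instantaneous frequency, group delay) would then be computed on the Left/Right channels only, as described above.
+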
22
+ ## 1. Installation
23
+
24
+ Follow these steps to set up the environment and install the necessary dependencies.
25
+
26
+ ### Installation Steps
27
+
28
+ 1. **Clone the repository:**
29
+ ```bash
30
+ git clone <your-repo-url>
31
+ cd ear_vae
32
+ ```
33
+
34
+ 2. **Create and activate a conda environment:**
35
+ ```bash
36
+ conda create -n ear_vae python=3.8
37
+ conda activate ear_vae
38
+ ```
39
+
40
+ 3. **Run the installation script:**
41
+
42
+ This script will install the remaining dependencies.
43
+ ```bash
44
+ bash install_requirements.sh
45
+ ```
46
+ This will install:
47
+ - `descript-audio-codec`
48
+ - `alias-free-torch`
49
+ - `ffmpeg < 7` (via conda)
50
+
51
+ 4. **Download the model weight:**
52
+
53
+ You can download the model checkpoint from **[Hugging Face](https://huggingface.co/earlab/EAR_VAE)**.
+
54
+ ## 2. Usage
55
+
56
+ The `inference.py` script is used to process audio files from an input directory and save the reconstructed audio to an output directory.
57
+
58
+ ### Running Inference
59
+
60
+ You can run the inference with the following command:
61
+
62
+ ```bash
63
+ python inference.py --indir <input_directory> --outdir <output_directory> --model_path <path_to_model> --device <device>
64
+ ```
65
+
66
+ ### Command-Line Arguments
67
+
68
+ - `--indir`: (Optional) Path to the input directory containing audio files. Default: `./data`.
69
+ - `--outdir`: (Optional) Path to the output directory where reconstructed audio will be saved. Default: `./results`.
70
+ - `--model_path`: (Optional) Path to the pretrained model weights (`.pyt` file). Default: `./pretrained_weight/ear_vae_44k.pyt`.
71
+ - `--device`: (Optional) The device to run the model on (e.g., `cuda:0` or `cpu`). Defaults to `cuda` if available, otherwise `cpu`.
+ - `--config`: (Optional) Path to the model configuration JSON. Default: `./config/model_config.json`.
72
+
73
+ ### Example
74
+
75
+ 1. Place your input audio files (e.g., `.wav`, `.mp3`) into the `data/` directory.
76
+ 2. Run the inference script:
77
+
78
+ ```bash
79
+ python inference.py
80
+ ```
81
+ This will use the default paths. The reconstructed audio files will be saved in the `results/` directory.
82
+
83
+ ## 3. Project Structure
84
+
85
+ ```
86
+ .
87
+ ├── README.md # This file
88
+ ├── config/ # For model configurations
89
+ │ └── model_config.json
90
+ ├── data/ # Default directory for input audio files
91
+ ├── eval/ # Scripts for model evaluation
92
+ │ ├── eval_compare_matrix.py
93
+ │ ├── install_requirements.sh
94
+ │ └── README.md
95
+ ├── inference.py # Main script for running audio reconstruction
96
+ ├── install_requirements.sh # Installation script for dependencies
97
+ ├── model/ # Contains the model architecture code
98
+ │ ├── autoencoders.py
99
+ │ ├── ear_vae.py
100
+ │ └── transformer.py
101
+ ├── pretrained_weight/ # Directory for pretrained model weights
102
+ │ └── your_weight_here
103
+ ```
104
+
105
+ ## 4. Model Details
106
+
107
+ The model is a Variational Autoencoder with a Generative Adversarial Network (VAE-GAN) structure.
108
+ - **Encoder**: An Oobleck-style encoder that downsamples the input audio into a latent representation.
109
+ - **Bottleneck**: A VAE bottleneck that introduces a probabilistic latent space, sampling from a learned mean and variance.
110
+ - **Decoder**: An Oobleck-style decoder that upsamples the latent representation back into an audio waveform.
111
+ - **Transformer**: A Continuous Transformer can optionally be placed in the bottleneck to further process the latent sequence.
112
+
113
+ This architecture allows for efficient and high-quality audio reconstruction.
114
+
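+ For programmatic use (outside the `inference.py` CLI), the model can be driven directly. A minimal sketch based on the bundled `model/ear_vae.py` and `config/model_config.json`; the random tensor stands in for real 44.1 kHz stereo audio:
+
+ ```python
+ import json
+ import torch
+ from model.ear_vae import EAR_VAE
+
+ with open("config/model_config.json") as f:
+     config = json.load(f)
+
+ model = EAR_VAE(model_config=config)
+ model.load_state_dict(torch.load("pretrained_weight/ear_vae_44k.pyt", map_location="cpu"))
+ model.eval()
+
+ audio = torch.randn(1, 2, 2 * 44100)   # [batch, stereo channels, samples]
+ with torch.no_grad():
+     z = model.encode(audio)            # latent: 64 channels, 1024x temporal downsampling
+     recon = model.decode(z)            # reconstructed stereo waveform
+ ```
+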
115
+ ## 5. Evaluation
116
+
117
+ The `eval/` directory contains scripts to evaluate the model's reconstruction performance using objective metrics.
118
+
119
+ ### Evaluation Prerequisites
120
+
121
+ 1. **Install Dependencies**: The evaluation script has its own set of dependencies. Install them by running the script in the `eval` directory:
122
+ ```bash
123
+ bash eval/install_requirements.sh
124
+ ```
125
+ This will install libraries such as `auraloss`.
126
+
127
+ 2. **FFmpeg**: The script uses `ffmpeg` for loudness analysis. Make sure `ffmpeg` is installed and available in your system's PATH. You can install it via conda:
128
+ ```bash
129
+ conda install -c conda-forge 'ffmpeg<7'
130
+ ```
131
+
132
+ ### Running Evaluation
133
+
134
+ The `eval_compare_matrix.py` script compares the reconstructed audio with the original ground truth files and computes various metrics.
135
+
136
+ For more details on the evaluation metrics and options, refer to the `eval/README.md` file.
137
+
138
+ ## 6. Acknowledgements
139
+
140
+ This project builds upon the work of several open-source projects. We would like to extend our special thanks to:
141
+
142
+ - **[Stability AI's Stable Audio Tools](https://github.com/Stability-AI/stable-audio-tools)**: For providing a foundational framework and tools for audio generation.
143
+ - **[Descript's Audio Codec](https://github.com/descriptinc/descript-audio-codec)**: For the weight-normalized convolutional layers.
144
+
145
+ Their contributions have been invaluable to the development of εar-VAE.
146
+
147
+
config/model_config.json ADDED
@@ -0,0 +1,34 @@
1
+ {
2
+ "transformer": {
3
+ "depth": 2,
4
+ "config": {
5
+ "rotary_pos_emb": true,
6
+ "dim_heads": 32
7
+ }
8
+ },
9
+ "encoder": {
10
+ "config": {
11
+ "in_channels": 2,
12
+ "channels": 128,
13
+ "c_mults": [1, 2, 4, 8, 16],
14
+ "strides": [2, 4, 4, 4, 8],
15
+ "latent_dim": 128,
16
+ "use_snake": true
17
+ }
18
+ },
19
+ "decoder": {
20
+ "config": {
21
+ "out_channels": 2,
22
+ "channels": 128,
23
+ "c_mults": [1, 2, 4, 8, 16],
24
+ "strides": [2, 4, 4, 4, 8],
25
+ "latent_dim": 64,
26
+ "use_nearest_upsample": false,
27
+ "use_snake": true,
28
+ "final_tanh": false
29
+ }
30
+ },
31
+ "latent_dim": 64,
32
+ "downsampling_ratio": 1024,
33
+ "io_channels": 2
34
+ }
eval/README.md ADDED
@@ -0,0 +1,123 @@
1
+ # VAE Audio Evaluation
2
+
3
+ This directory contains the script and resources for evaluating the performance of models in audio reconstruction tasks. The primary script, `eval_compare_matrix.py`, computes a suite of objective metrics to compare the quality of audio generated by the model against the original ground truth audio.
4
+
5
+ ## Features
6
+
7
+ - **Comprehensive Metrics**: Calculates a wide range of industry-standard and research-grade metrics:
8
+ - **Time-Domain**: Scale-Invariant Signal-to-Distortion Ratio (SI-SDR).
9
+ - **Frequency-Domain**: Multi-Resolution STFT Loss and Multi-Resolution Mel-Spectrogram Loss.
10
+ - **Phase**: Multi-Resolution Phase Coherence (both per-channel and inter-channel for stereo).
11
+ - **Loudness**: Integrated Loudness (LUFS-I), Loudness Range (LRA), and True Peak, analyzed using `ffmpeg`.
12
+ - **Batch Processing**: Automatically discovers and processes multiple model output directories.
13
+ - **File Matching**: Intelligently pairs reconstructed audio files (e.g., `*_vae_rec.wav`) with their corresponding ground truth files (e.g., `*.wav`).
14
+ - **Robust & Resilient**: Handles missing files, audio processing errors, and varying sample rates gracefully.
15
+ - **Organized Output**: Saves aggregated results in both machine-readable (`.json`) and human-readable (`.txt`) formats for each model.
16
+ - **Command-Line Interface**: Easy-to-use CLI for specifying the input directory and other options.
17
+
18
+ ## Prerequisites
19
+
20
+ ### 1. Python Environment
21
+ Ensure you have a Python environment (3.8 or newer recommended) with the required packages installed. You can install them using pip:
22
+ ```bash
23
+ pip install torch torchaudio auraloss numpy
24
+ ```
25
+
26
+ ### 2. FFmpeg
27
+ The script relies on `ffmpeg` for loudness analysis. You must have `ffmpeg` installed and accessible in your system's PATH.
28
+
29
+ **On Ubuntu/Debian:**
30
+ ```bash
31
+ sudo apt update && sudo apt install ffmpeg
32
+ ```
33
+
34
+ **On macOS (using Homebrew):**
35
+ ```bash
36
+ brew install ffmpeg
37
+ ```
38
+
39
+ **On Windows:**
40
+ Download the executable from the [official FFmpeg website](https://ffmpeg.org/download.html) and add its `bin` directory to your system's PATH environment variable.
41
+
42
+ You can verify the installation by running:
43
+ ```bash
44
+ ffmpeg -version
45
+ ```
46
+
47
+ **Or, in a conda environment:**
48
+ ```bash
49
+ conda install -c conda-forge 'ffmpeg<7'
50
+ ```
51
+
52
+ ## Directory Structure
53
+
54
+ The script expects a specific directory structure for the evaluation data. The root input directory should contain subdirectories, where each subdirectory represents a different model or experiment to be evaluated.
55
+
56
+ Inside each model's subdirectory, you should place the pairs of ground truth and reconstructed audio files. The script identifies pairs based on a naming convention:
57
+ - **Ground Truth**: `your_audio_file.wav`
58
+ - **Reconstructed**: `your_audio_file_vae_rec.wav`
59
+
60
+ Here is an example structure:
61
+ ```
62
+ /path/to/your/evaluation_data/
63
+ ├── model_A/
64
+ │ ├── song1.wav # Ground Truth 1
65
+ │ ├── song1_vae_rec.wav # Reconstructed 1
66
+ │ ├── song2.wav # Ground Truth 2
67
+ │ ├── song2_vae_rec.wav # Reconstructed 2
68
+ │ └── ...
69
+ ├── model_B/
70
+ │ ├── trackA.wav
71
+ │ ├── trackA_vae_rec.wav
72
+ │ └── ...
73
+ └── ...
74
+ ```
75
+
76
+ ## Usage
77
+
78
+ Run the evaluation script from the command line, pointing it to the root directory containing your model outputs.
79
+
80
+ ```bash
81
+ python eval_compare_matrix.py --input_dir /path/to/your/evaluation_data/
82
+ ```
83
+
84
+ ### Command-Line Arguments
85
+
86
+ - `--input_dir` (required): The path to the root directory containing the model folders (e.g., `/path/to/your/evaluation_data/`).
87
+ - `--force` (optional): If specified, the script will re-run the evaluation for all models, even if results files (`evaluation_results.json`) already exist. By default, it skips models that have already been evaluated.
88
+ - `--echo` (optional): If specified, the script will print the detailed evaluation metrics for each individual audio pair during processing. By default, only the progress bar and final summary are shown.
89
+
90
+ ### Example
91
+ ```bash
92
+ python eval/eval_compare_matrix.py --input_dir ./results/
93
+ ```
94
+
95
+ ## Output
96
+
97
+ After running, the script will generate two files inside each model's directory:
98
+
99
+ 1. **`evaluation_results.json`**: A JSON file containing the aggregated average of all computed metrics. This is ideal for programmatic analysis.
100
+ ```json
101
+ {
102
+ "model_name": "model_A",
103
+ "file_count": 50,
104
+ "avg_sisdr": 15.78,
105
+ "avg_mel_distance": 0.45,
106
+ "avg_stft_distance": 0.89,
107
+ "avg_per_channel_coherence": 0.95,
108
+ "avg_interchannel_coherence": 0.92,
109
+ "avg_gen_lufs-i": -14.2,
110
+ "avg_gt_lufs-i": -14.0,
111
+ ...
112
+ }
113
+ ```
114
+
115
+ 2. **`evaluation_summary.txt`**: A human-readable text file summarizing the results.
116
+ ```
117
+ model_name: model_A
118
+ file_count: 50
119
+ avg_sisdr: 15.78...
120
+ avg_mel_distance: 0.45...
121
+ ...
122
+ ```
123
+ This allows for quick inspection of a model's performance without needing to parse the JSON.
eval/eval_compare_matrix.py ADDED
@@ -0,0 +1,406 @@
1
+ """
2
+ Audio Evaluation Script
3
+
4
+ This script evaluates the quality of generated audio against ground truth audio
5
+ using a variety of metrics, including:
6
+ - SI-SDR (Scale-Invariant Signal-to-Distortion Ratio)
7
+ - Multi-Resolution STFT Loss
8
+ - Multi-Resolution Mel-Spectrogram Loss
9
+ - Phase Coherence (Per-channel and Inter-channel)
10
+ - Loudness metrics (LUFS-I, LRA, True Peak) via ffmpeg.
11
+
12
+ The script processes a directory of models, where each model directory contains
13
+ pairs of reconstructed (_vae_rec.wav) and ground truth (.wav) audio files.
14
+ """
15
+
16
+ import os
17
+ import re
18
+ import sys
19
+ import json
20
+ import logging
21
+ import argparse
22
+ import subprocess
23
+ from pathlib import Path
24
+ from typing import Dict, List, Tuple, Optional
25
+
26
+ import numpy as np
27
+ import torch
28
+ import torch.nn as nn
29
+ import torchaudio
30
+ import auraloss
31
+ from tqdm import tqdm
32
+
33
+ # --- Setup ---
34
+ logging.basicConfig(
35
+ level=logging.INFO,
36
+ format='%(asctime)s - %(levelname)s - %(message)s',
37
+ stream=sys.stdout
38
+ )
39
+
40
+ DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
41
+ SAMPLE_RATE = 44100
42
+
43
+ # --- Metric Definitions ---
44
+
45
+ # SI-SDR
46
+ sisdr_criteria = auraloss.time.SISDRLoss().to(DEVICE)
47
+
48
+ # Multi-Resolution Mel-Spectrogram Loss
49
+ mel_fft_sizes = [4096, 2048, 1024, 512]
50
+ mel_win_sizes = mel_fft_sizes
51
+ mel_hop_sizes = [i // 4 for i in mel_fft_sizes]
52
+ mel_criteria = auraloss.freq.MultiResolutionSTFTLoss(
53
+ fft_sizes=mel_fft_sizes,
54
+ hop_sizes=mel_hop_sizes,
55
+ win_lengths=mel_win_sizes,
56
+ sample_rate=SAMPLE_RATE,
57
+ scale="mel",
58
+ n_bins=64,
59
+ perceptual_weighting=True
60
+ ).to(DEVICE)
61
+
62
+ # Multi-Resolution STFT Loss
63
+ fft_sizes = [4096, 2048, 1024, 512, 256, 128]
64
+ win_sizes = fft_sizes
65
+ hop_sizes = [i // 4 for i in fft_sizes]
66
+ stft_criteria = auraloss.freq.MultiResolutionSTFTLoss(
67
+ fft_sizes=fft_sizes,
68
+ hop_sizes=hop_sizes,
69
+ win_lengths=win_sizes,
70
+ sample_rate=SAMPLE_RATE,
71
+ perceptual_weighting=True
72
+ ).to(DEVICE)
73
+
74
+
75
+ def analyze_loudness(file_path: str) -> Optional[Dict[str, float]]:
76
+ """
77
+ Analyzes audio file loudness using ffmpeg's ebur128 filter.
78
+
79
+ Args:
80
+ file_path: Path to the audio file.
81
+
82
+ Returns:
83
+ A dictionary with LUFS-I, LRA, and True Peak, or None on failure.
84
+ """
85
+ if not Path(file_path).exists():
86
+ logging.warning(f"Loudness analysis skipped: File not found at {file_path}")
87
+ return None
88
+
89
+ command = [
90
+ "ffmpeg",
91
+ "-nostats",
92
+ "-i", file_path,
93
+ "-af", "ebur128=peak=true,ametadata=mode=print:file=-",
94
+ "-f", "null",
95
+ "-"
96
+ ]
97
+
98
+ try:
99
+ result = subprocess.run(command, capture_output=True, text=True, check=True, encoding='utf-8')
100
+ output_text = result.stderr
101
+ except FileNotFoundError:
102
+ logging.error("ffmpeg not found. Please install ffmpeg and ensure it's in your PATH.")
103
+ return None
104
+ except subprocess.CalledProcessError as e:
105
+ logging.error(f"ffmpeg analysis failed for {file_path}. Error: {e.stderr}")
106
+ return None
107
+
108
+ loudness_data = {}
109
+
110
+ i_match = re.search(r"^\s*I:\s*(-?[\d\.]+)\s*LUFS", output_text, re.MULTILINE)
111
+ if i_match:
112
+ loudness_data['LUFS-I'] = float(i_match.group(1))
113
+
114
+ lra_match = re.search(r"^\s*LRA:\s*([\d\.]+)\s*LU", output_text, re.MULTILINE)
115
+ if lra_match:
116
+ loudness_data['LRA'] = float(lra_match.group(1))
117
+
118
+ tp_match = re.search(r"Peak:\s*(-?[\d\.]+)\s*dBFS", output_text, re.MULTILINE)
119
+ if tp_match:
120
+ loudness_data['True Peak'] = float(tp_match.group(1))
121
+
122
+ if not loudness_data:
123
+ logging.warning(f"Could not parse loudness data for {file_path}.")
124
+ return None
125
+
126
+ return loudness_data
127
+
128
+
129
+ class PhaseCoherenceLoss(nn.Module):
130
+ """
131
+ Calculates phase coherence between two audio signals.
132
+ Adapted for stereo and multi-resolution analysis.
133
+ """
134
+ def __init__(self, fft_size, hop_size, win_size, mag_threshold=1e-6, eps=1e-8):
135
+ super().__init__()
136
+ self.fft_size = int(fft_size)
137
+ self.hop_size = int(hop_size)
138
+ self.win_size = int(win_size)
139
+ self.register_buffer("window", torch.hann_window(win_size))
140
+ self.mag_threshold = float(mag_threshold)
141
+ self.eps = float(eps)
142
+
143
+ def _to_complex(self, x):
144
+ if torch.is_complex(x):
145
+ return x
146
+ if x.dim() >= 1 and x.size(-1) == 2:
147
+ return torch.complex(x[..., 0], x[..., 1])
148
+ raise ValueError("Input must be complex or real/imag tensor.")
149
+
150
+ def _stereo_stft(self, x):
151
+ if x.dim() == 2:
152
+ x = x.unsqueeze(0)
153
+ B, C, T = x.shape
154
+ stft = torch.stft(x.reshape(B * C, T),
155
+ n_fft=self.fft_size,
156
+ hop_length=self.hop_size,
157
+ win_length=self.win_size,
158
+ window=self.window,
159
+ return_complex=True)
160
+ return stft.view(B, C, -1, stft.size(-1))
161
+
162
+ def forward(self, pred, target):
163
+ pred_stft = self._stereo_stft(pred)
164
+ target_stft = self._stereo_stft(target)
165
+
166
+ pred_stft = self._to_complex(pred_stft)
167
+ target_stft = self._to_complex(target_stft)
168
+
169
+ B, C, F, T = pred_stft.shape
170
+
171
+ # magnitudes and weights
172
+ mag_pred = torch.abs(pred_stft)
173
+ mag_target = torch.abs(target_stft)
174
+ weights = mag_pred * mag_target
175
+ mask = (weights > self.mag_threshold).to(weights.dtype)
176
+ weights_masked = weights * mask
177
+
178
+ # phase difference Δφ = angle(pred) - angle(target)
179
+ delta = torch.angle(pred_stft) - torch.angle(target_stft)
180
+ # phasor e^{jΔφ}
181
+ phasor = torch.complex(torch.cos(delta), torch.sin(delta))
182
+
183
+ # weighted vector sum across frequency axis
184
+ num = torch.sum(weights_masked * phasor, dim=2) # [B, C, T], complex
185
+ den = torch.sum(weights_masked, dim=2).clamp_min(self.eps)
186
+ coherence_per_bin = torch.abs(num) / den
187
+
188
+ # pool across time (energy-weighted mean) -> per-channel scalar
189
+ # weight time pooling by per-frame energy sum to emphasize active frames
190
+ frame_energy = torch.sum(weights_masked, dim=2)
191
+ frame_energy_sum = torch.sum(frame_energy, dim=2).clamp_min(self.eps)
192
+
193
+ # energy-weighted average over time
194
+ coherence_chan = torch.sum(coherence_per_bin * frame_energy, dim=2) / frame_energy_sum
195
+
196
+ # mean across batch
197
+ per_channel_coherence = coherence_chan.mean(dim=0)
198
+
199
+ inter_coherence = None
200
+ if C >= 2:
201
+ Lp, Rp = pred_stft[:, 0], pred_stft[:, 1]
202
+ Lt, Rt = target_stft[:, 0], target_stft[:, 1]
203
+
204
+ # inter-channel phase: angle(L) - angle(R) <=> angle(L * conj(R))
205
+ inter_delta = torch.angle(Lp * torch.conj(Rp)) - torch.angle(Lt * torch.conj(Rt))
206
+ inter_weights = torch.abs(Lp) * torch.abs(Rp)
207
+ inter_mask = (inter_weights > self.mag_threshold).to(inter_weights.dtype)
208
+ inter_weights_masked = inter_weights * inter_mask
209
+ inter_phasor = torch.complex(torch.cos(inter_delta), torch.sin(inter_delta))
210
+ inter_num = torch.sum(inter_weights_masked * inter_phasor, dim=1)
211
+ inter_den = torch.sum(inter_weights_masked, dim=1).clamp_min(self.eps)
212
+ inter_coh_time = torch.abs(inter_num) / inter_den
213
+
214
+ # pool across time weighted by energy
215
+ inter_frame_energy = torch.sum(inter_weights_masked, dim=1)
216
+ inter_energy_sum = inter_frame_energy.sum(dim=1).clamp_min(self.eps)
217
+ inter_coh_b = (inter_coh_time * inter_frame_energy).sum(dim=1) / inter_energy_sum
218
+ inter_coherence = inter_coh_b.mean()
219
+
220
+ return {
221
+ "per_channel_coherence": per_channel_coherence.detach().cpu(),
222
+ "interchannel_coherence": (inter_coherence.detach().cpu() if inter_coherence is not None else None),
223
+ }
224
+
225
+
226
+ class MultiResolutionPhaseCoherenceLoss(nn.Module):
227
+ def __init__(self, fft_sizes, hop_sizes, win_sizes):
228
+ super().__init__()
229
+ self.criteria = nn.ModuleList([
230
+ PhaseCoherenceLoss(fft, hop, win) for fft, hop, win in zip(fft_sizes, hop_sizes, win_sizes)
231
+ ])
232
+
233
+ def forward(self, pred, target):
234
+ results = [criterion(pred, target) for criterion in self.criteria]
235
+
236
+ per_channel = torch.stack([r["per_channel_coherence"] for r in results]).mean(dim=0)
237
+ inter_items = [r["interchannel_coherence"] for r in results if r["interchannel_coherence"] is not None]
238
+ inter_channel = torch.stack(inter_items).mean() if inter_items else None
239
+
240
+ return {"per_channel_coherence": per_channel, "interchannel_coherence": inter_channel}
241
+
242
+ phase_coherence_criteria = MultiResolutionPhaseCoherenceLoss(
243
+ fft_sizes=mel_fft_sizes, hop_sizes=mel_hop_sizes, win_sizes=mel_win_sizes
244
+ ).to(DEVICE)
245
+
246
+
247
+ def find_audio_pairs(model_path: Path) -> List[Tuple[Path, Path]]:
248
+ """Finds pairs of reconstructed and ground truth audio files."""
249
+ rec_files = sorted(model_path.glob("*_vae_rec.wav"))
250
+ pairs = []
251
+ for rec_file in rec_files:
252
+ gt_file = model_path / rec_file.name.replace("_vae_rec.wav", ".wav")
253
+ if gt_file.exists():
254
+ pairs.append((rec_file, gt_file))
255
+ else:
256
+ logging.warning(f"Ground truth file not found for {rec_file.name}")
257
+ return pairs
258
+
259
+
260
+ def evaluate_pair(rec_path: Path, gt_path: Path) -> Optional[Dict[str, float]]:
261
+ """Evaluates a single pair of audio files."""
262
+ try:
263
+ gen_wav, gen_sr = torchaudio.load(rec_path, backend="ffmpeg")
264
+ gt_wav, gt_sr = torchaudio.load(gt_path, backend="ffmpeg")
265
+
266
+ if gen_sr != SAMPLE_RATE:
267
+ gen_wav = torchaudio.transforms.Resample(gen_sr, SAMPLE_RATE)(gen_wav)
268
+ if gt_sr != SAMPLE_RATE:
269
+ gt_wav = torchaudio.transforms.Resample(gt_sr, SAMPLE_RATE)(gt_wav)
270
+
271
+ # Trim to same length
272
+ if gen_wav.shape[-1] != gt_wav.shape[-1]:
273
+ logging.info(f"Shape mismatch; trimming audio files to the same length: {rec_path.name}, {gt_path.name}")
274
+ min_len = min(gen_wav.shape[-1], gt_wav.shape[-1])
275
+ gen_wav, gt_wav = gen_wav[:, :min_len], gt_wav[:, :min_len]
276
+
277
+ gen_wav, gt_wav = gen_wav.to(DEVICE).unsqueeze(0), gt_wav.to(DEVICE).unsqueeze(0)
278
+
279
+ metrics = {}
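+ # auraloss criteria return losses; SISDRLoss yields the negative SI-SDR, so it is negated here to report SI-SDR in dB (higher is better)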
280
+ metrics['sisdr'] = -sisdr_criteria(gen_wav, gt_wav).item()
281
+ metrics['mel_distance'] = mel_criteria(gen_wav, gt_wav).item()
282
+ metrics['stft_distance'] = stft_criteria(gen_wav, gt_wav).item()
283
+
284
+ phase_metrics = phase_coherence_criteria(gen_wav, gt_wav)
285
+ metrics['per_channel_coherence'] = phase_metrics["per_channel_coherence"].mean().item()
286
+ if phase_metrics["interchannel_coherence"] is not None:
287
+ metrics['interchannel_coherence'] = phase_metrics["interchannel_coherence"].item()
288
+
289
+ return metrics
290
+ except Exception as e:
291
+ logging.error(f"Error processing pair {rec_path.name}, {gt_path.name}: {e}")
292
+ return None
293
+
294
+
295
+ def process_model(model_path: Path, force_eval: bool = False, echo=True):
296
+ """Processes all audio pairs for a given model."""
297
+ logging.info(f"Processing model: {model_path.name}")
298
+ results_file = model_path / "evaluation_results.json"
299
+
300
+ if results_file.exists() and not force_eval:
301
+ logging.info(f"Results already exist for {model_path.name}, skipping.")
302
+ return
303
+
304
+ audio_pairs = find_audio_pairs(model_path)
305
+ if not audio_pairs:
306
+ logging.warning(f"No valid audio pairs found for {model_path.name}.")
307
+ return
308
+
309
+ all_metrics = []
310
+ gen_loudness_data, gt_loudness_data = [], []
311
+
312
+ with torch.no_grad():
313
+ for rec_path, gt_path in tqdm(audio_pairs, desc=f"Evaluating {model_path.name}"):
314
+ pair_metrics = evaluate_pair(rec_path, gt_path)
315
+ if pair_metrics:
316
+ all_metrics.append(pair_metrics)
317
+
318
+ gen_loudness = analyze_loudness(str(rec_path))
319
+ if gen_loudness:
320
+ gen_loudness_data.append(gen_loudness)
321
+
322
+ gt_loudness = analyze_loudness(str(gt_path))
323
+ if gt_loudness:
324
+ gt_loudness_data.append(gt_loudness)
325
+
326
+ if echo:
327
+ logging.info(f"Metrics for {rec_path.name}: {pair_metrics}")
328
+ if gen_loudness:
329
+ logging.info(f"Generated Loudness: {gen_loudness}")
330
+ if gt_loudness:
331
+ logging.info(f"Ground Truth Loudness: {gt_loudness}")
332
+
333
+ if not all_metrics:
334
+ logging.warning(f"No metrics could be calculated for {model_path.name}.")
335
+ return
336
+
337
+ # Aggregate results
338
+ summary = {"model_name": model_path.name, "file_count": len(all_metrics)}
339
+
340
+ # Average objective metrics
341
+ metric_keys = all_metrics[0].keys()
342
+ for key in metric_keys:
343
+ valid_values = [m[key] for m in all_metrics if key in m]
344
+ if valid_values:
345
+ summary[f"avg_{key}"] = float(np.mean(valid_values))
346
+
347
+ # Average loudness metrics
348
+ def _avg_loudness(data: List[Dict[str, float]], prefix: str):
349
+ if not data: return
350
+ for key in data[0].keys():
351
+ values = [d[key] for d in data if key in d]
352
+ if values:
353
+ summary[f"avg_{prefix}_{key.lower().replace(' ', '_')}"] = float(np.mean(values))
354
+
355
+ _avg_loudness(gen_loudness_data, "gen")
356
+ _avg_loudness(gt_loudness_data, "gt")
357
+
358
+ # Save results
359
+ logging.info(f"Saving results for {model_path.name} to {results_file}")
360
+ with open(results_file, 'w') as f:
361
+ json.dump(summary, f, indent=4)
362
+
363
+ # Also save a human-readable version
364
+ with open(model_path / "evaluation_summary.txt", "w") as f:
365
+ for key, value in summary.items():
366
+ f.write(f"{key}: {value}\n")
367
+
368
+
369
+ def main():
370
+ parser = argparse.ArgumentParser(description="Run evaluation on generated audio.")
371
+ parser.add_argument(
372
+ "--input_dir",
373
+ type=str,
374
+ required=True,
375
+ help="Root directory containing model output folders."
376
+ )
377
+ parser.add_argument(
378
+ "--force",
379
+ action="store_true",
380
+ help="Force re-evaluation even if results files exist."
381
+ )
382
+
383
+ parser.add_argument(
384
+ "--echo",
385
+ action="store_true",
386
+ help="Echo per-file metrics to console during evaluation."
387
+ )
388
+ args = parser.parse_args()
389
+
390
+ root_path = Path(args.input_dir)
391
+ if not root_path.is_dir():
392
+ logging.error(f"Input directory not found: {root_path}")
393
+ sys.exit(1)
394
+
395
+ model_paths = [p for p in root_path.iterdir() if p.is_dir() and not p.name.startswith('.')]
396
+
397
+ logging.info(f"Found {len(model_paths)} model(s) to evaluate: {[p.name for p in model_paths]}")
398
+
399
+ for model_path in sorted(model_paths):
400
+ process_model(model_path, args.force, args.echo)
401
+
402
+ logging.info("Evaluation complete.")
403
+
404
+
405
+ if __name__ == "__main__":
406
+ main()
eval/install_requirements.sh ADDED
@@ -0,0 +1 @@
 
 
1
+ pip install torch torchaudio auraloss numpy
inference.py ADDED
@@ -0,0 +1,85 @@
1
+ import torchaudio
2
+ from pathlib import Path
3
+ from tqdm import tqdm
4
+ import torch
5
+ import argparse
6
+ import json
7
+ from model.ear_vae import EAR_VAE
8
+
9
+ def main(args):
10
+ indir = args.indir
11
+ model_path = args.model_path
12
+ outdir = args.outdir
13
+ device = args.device
14
+ config_path = args.config
15
+
16
+ print(f"Input directory: {indir}")
17
+ print(f"Model path: {model_path}")
18
+ print(f"Output directory: {outdir}")
19
+ print(f"Device: {device}")
20
+ print(f"Config path: {config_path}")
21
+
22
+
23
+ input_path = Path(indir)
24
+ output_path_dir = Path(outdir)
25
+ output_path_dir.mkdir(parents=True, exist_ok=True)
26
+
27
+ with open(config_path, 'r') as f:
28
+ vae_gan_model_config = json.load(f)
29
+
30
+ print("Loading model...")
31
+ model = EAR_VAE(model_config=vae_gan_model_config).to(device)
32
+
33
+ state = torch.load(model_path, map_location="cpu")
34
+ model.load_state_dict(state)
35
+ model.eval()
36
+ print("Model loaded successfully.")
37
+
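+ # Collect every path under the input directory; non-audio files (and subdirectories) will fail to load below and are skipped by the try/except.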
38
+ audios = list(input_path.rglob("*"))
39
+ print(f"Found {len(audios)} audio files to process.")
40
+
41
+ with torch.no_grad():
42
+ for audio_path in tqdm(audios, desc="Processing audio files"):
43
+ try:
44
+ gt_y, sr = torchaudio.load(audio_path, backend="ffmpeg")
45
+
46
+ if len(gt_y.shape) == 1:
47
+ gt_y = gt_y.unsqueeze(0)
48
+
49
+ # Resample if necessary
50
+ if sr != 44100:
51
+ resampler = torchaudio.transforms.Resample(sr, 44100).to(device)
52
+ gt_y = resampler(gt_y)
53
+
54
+ gt_y = gt_y.to(device, torch.float32)
55
+
56
+ # Convert to stereo if mono
57
+ if gt_y.shape[0] == 1:
58
+ gt_y = torch.cat([gt_y, gt_y], dim=0)
59
+
60
+ # Add batch dimension
61
+ gt_y = gt_y.unsqueeze(0)
62
+
63
+ fake_audio = model.inference(gt_y)
64
+
65
+ output_filename = f"{Path(audio_path).stem}_{Path(model_path).stem}.wav"
66
+ output_path = output_path_dir / output_filename
67
+
68
+ fake_audio_processed = fake_audio.squeeze(0).cpu()
69
+ torchaudio.save(output_path, fake_audio_processed, sample_rate=44100, backend="ffmpeg")
70
+ except Exception as e:
71
+ print(f"Error processing {audio_path}: {e}")
72
+
73
+ continue
74
+
75
+
76
+ if __name__ == '__main__':
77
+ parser = argparse.ArgumentParser(description="Run VAE-GAN audio inference.")
78
+ parser.add_argument('--indir', type=str, default='./data', help='Input directory for audio files.')
79
+ parser.add_argument('--model_path', type=str, default='./pretrained_weight/ear_vae_44k.pyt', help='Path to the pretrained model weight.')
80
+ parser.add_argument('--outdir', type=str, default='./results', help='Output directory for generated audio files.')
81
+ parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu', help='Device to run the model on (e.g., "cuda:0" or "cpu").')
82
+ parser.add_argument('--config', type=str, default='./config/model_config.json', help='Path to the model config file.')
83
+
84
+ args = parser.parse_args()
85
+ main(args)
install_requirements.sh ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ pip install descript-audio-codec
2
+ python -m pip install alias-free-torch
3
+ conda install -c conda-forge 'ffmpeg<7'
model/autoencoders.py ADDED
@@ -0,0 +1,374 @@
1
+ import torch
2
+ import math
3
+
4
+ from torch import nn, pow
5
+ from alias_free_torch import Activation1d
6
+ from dac.nn.layers import WNConv1d, WNConvTranspose1d
7
+ from typing import Literal
8
+
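+ # SnakeBeta activation: f(x) = x + (1 / beta) * sin^2(alpha * x), with a small epsilon for stability;
+ # alpha sets the frequency of the periodic term and beta its magnitude.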
9
+ def snake_beta(x, alpha, beta):
10
+ return x + (1.0 / (beta + 0.000000001)) * pow(torch.sin(x * alpha), 2)
11
+
12
+ class SnakeBeta(nn.Module):
13
+ def __init__(
14
+ self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True
15
+ ):
16
+ super(SnakeBeta, self).__init__()
17
+ self.in_features = in_features
18
+
19
+ # initialize alpha
20
+ self.alpha_logscale = alpha_logscale
21
+ if self.alpha_logscale: # log scale alphas initialized to zeros
22
+ self.alpha = nn.Parameter(torch.zeros(in_features) * alpha)
23
+ self.beta = nn.Parameter(torch.zeros(in_features) * alpha)
24
+ else: # linear scale alphas initialized to ones
25
+ self.alpha = nn.Parameter(torch.ones(in_features) * alpha)
26
+ self.beta = nn.Parameter(torch.ones(in_features) * alpha)
27
+
28
+ self.alpha.requires_grad = alpha_trainable
29
+ self.beta.requires_grad = alpha_trainable
30
+
31
+ self.no_div_by_zero = 0.000000001
32
+
33
+ def forward(self, x):
34
+ alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
35
+ beta = self.beta.unsqueeze(0).unsqueeze(-1)
36
+ if self.alpha_logscale:
37
+ alpha = torch.exp(alpha)
38
+ beta = torch.exp(beta)
39
+ x = snake_beta(x, alpha, beta)
40
+
41
+ return x
42
+
43
+
44
+ def checkpoint(function, *args, **kwargs):
45
+ kwargs.setdefault("use_reentrant", False)
46
+ return torch.utils.checkpoint.checkpoint(function, *args, **kwargs)
47
+
48
+
49
+ def get_activation(
50
+ activation: Literal["elu", "snake", "none"], antialias=False, channels=None
51
+ ) -> nn.Module:
52
+ if activation == "elu":
53
+ act = nn.ELU()
54
+ elif activation == "snake":
55
+ act = SnakeBeta(channels)
56
+ elif activation == "none":
57
+ act = nn.Identity()
58
+ else:
59
+ raise ValueError(f"Unknown activation {activation}")
60
+
61
+ if antialias:
62
+ act = Activation1d(act)
63
+
64
+ return act
65
+
66
+
67
+ class ResidualUnit(nn.Module):
68
+ def __init__(
69
+ self,
70
+ in_channels,
71
+ out_channels,
72
+ dilation,
73
+ use_snake=False,
74
+ antialias_activation=False,
75
+ bias=True,
76
+ ):
77
+ super().__init__()
78
+
79
+ self.dilation = dilation
80
+
81
+ act = get_activation(
82
+ "snake" if use_snake else "elu",
83
+ antialias=antialias_activation,
84
+ channels=out_channels,
85
+ )
86
+
87
+ padding = (dilation * (7 - 1)) // 2
88
+
89
+ self.layers = nn.Sequential(
90
+ act,
91
+ WNConv1d(
92
+ in_channels=in_channels,
93
+ out_channels=out_channels,
94
+ kernel_size=7,
95
+ dilation=dilation,
96
+ padding=padding,
97
+ bias=bias,
98
+ ),
99
+ act,
100
+ WNConv1d(
101
+ in_channels=out_channels, out_channels=out_channels, kernel_size=1, bias=bias
102
+ ),
103
+ )
104
+
105
+ def forward(self, x):
106
+ res = x
107
+
108
+ # x = checkpoint(self.layers, x)
109
+ x = self.layers(x)
110
+
111
+ return x + res
112
+
113
+
114
+ class EncoderBlock(nn.Module):
115
+ def __init__(
116
+ self,
117
+ in_channels,
118
+ out_channels,
119
+ stride,
120
+ use_snake=False,
121
+ antialias_activation=False,
122
+ bias=True,
123
+ ):
124
+ super().__init__()
125
+
126
+ act = get_activation(
127
+ "snake" if use_snake else "elu",
128
+ antialias=antialias_activation,
129
+ channels=in_channels,
130
+ )
131
+
132
+ self.layers = nn.Sequential(
133
+ ResidualUnit(
134
+ in_channels=in_channels,
135
+ out_channels=in_channels,
136
+ dilation=1,
137
+ use_snake=use_snake,
138
+ bias=bias,
139
+ ),
140
+ ResidualUnit(
141
+ in_channels=in_channels,
142
+ out_channels=in_channels,
143
+ dilation=3,
144
+ use_snake=use_snake,
145
+ bias=bias,
146
+ ),
147
+ ResidualUnit(
148
+ in_channels=in_channels,
149
+ out_channels=in_channels,
150
+ dilation=9,
151
+ use_snake=use_snake,
152
+ bias=bias,
153
+ ),
154
+ act,
155
+ WNConv1d(
156
+ in_channels=in_channels,
157
+ out_channels=out_channels,
158
+ kernel_size=2 * stride,
159
+ stride=stride,
160
+ padding=math.ceil(stride / 2),
161
+ bias=bias,
162
+ ),
163
+ )
164
+
165
+ def forward(self, x):
166
+ return self.layers(x)
167
+
168
+
169
+ class AntiAliasUpsamplerBlock(nn.Module):
170
+ def __init__(self, in_channels, out_channels, stride=2, bias=True):
171
+ super().__init__()
172
+
173
+ self.upsample = nn.Upsample(scale_factor=stride, mode="nearest")
174
+
175
+ self.conv = WNConv1d(
176
+ in_channels=in_channels,
177
+ out_channels=out_channels,
178
+ kernel_size=2 * stride,
179
+ bias=bias,
180
+ padding="same",
181
+ )
182
+
183
+ def forward(self, x):
184
+ x = self.upsample(x)
185
+ x = self.conv(x)
186
+ return x
187
+
188
+
189
+ class DecoderBlock(nn.Module):
190
+ def __init__(
191
+ self,
192
+ in_channels,
193
+ out_channels,
194
+ stride,
195
+ use_snake=False,
196
+ antialias_activation=False,
197
+ use_nearest_upsample=False,
198
+ bias=True,
199
+ ):
200
+ super().__init__()
201
+
202
+ if use_nearest_upsample:
203
+ upsample_layer = AntiAliasUpsamplerBlock(
204
+ in_channels=in_channels, out_channels=out_channels, stride=stride, bias=bias
205
+ )
206
+ else:
207
+ upsample_layer = WNConvTranspose1d(
208
+ in_channels=in_channels,
209
+ out_channels=out_channels,
210
+ kernel_size=2 * stride,
211
+ stride=stride,
212
+ padding=math.ceil(stride / 2),
213
+ bias=bias,
214
+ )
215
+
216
+ act = get_activation(
217
+ "snake" if use_snake else "elu",
218
+ antialias=antialias_activation,
219
+ channels=in_channels,
220
+ )
221
+
222
+ self.layers = nn.Sequential(
223
+ act,
224
+ upsample_layer,
225
+ ResidualUnit(
226
+ in_channels=out_channels,
227
+ out_channels=out_channels,
228
+ dilation=1,
229
+ use_snake=use_snake,
230
+ bias=bias,
231
+ ),
232
+ ResidualUnit(
233
+ in_channels=out_channels,
234
+ out_channels=out_channels,
235
+ dilation=3,
236
+ use_snake=use_snake,
237
+ bias=bias,
238
+ ),
239
+ ResidualUnit(
240
+ in_channels=out_channels,
241
+ out_channels=out_channels,
242
+ dilation=9,
243
+ use_snake=use_snake,
244
+ bias=bias,
245
+ ),
246
+ )
247
+
248
+ def forward(self, x):
249
+ return self.layers(x)
250
+
251
+
252
+ class OobleckEncoder(nn.Module):
253
+ def __init__(
254
+ self,
255
+ in_channels=2,
256
+ channels=128,
257
+ latent_dim=32,
258
+ c_mults=[1, 2, 4, 8],
259
+ strides=[2, 4, 8, 8],
260
+ use_snake=False,
261
+ antialias_activation=False,
262
+ bias=True,
263
+ ):
264
+ super().__init__()
265
+
266
+ c_mults = [1] + c_mults
267
+
268
+ self.depth = len(c_mults)
269
+
270
+ layers = [
271
+ WNConv1d(
272
+ in_channels=in_channels,
273
+ out_channels=c_mults[0] * channels,
274
+ kernel_size=7,
275
+ padding=3,
276
+ bias=bias,
277
+ )
278
+ ]
279
+
280
+ for i in range(self.depth - 1):
281
+ layers += [
282
+ EncoderBlock(
283
+ in_channels=c_mults[i] * channels,
284
+ out_channels=c_mults[i + 1] * channels,
285
+ stride=strides[i],
286
+ use_snake=use_snake,
287
+ bias=bias,
288
+ )
289
+ ]
290
+
291
+ layers += [
292
+ get_activation(
293
+ "snake" if use_snake else "elu",
294
+ antialias=antialias_activation,
295
+ channels=c_mults[-1] * channels,
296
+ ),
297
+ WNConv1d(
298
+ in_channels=c_mults[-1] * channels,
299
+ out_channels=latent_dim,
300
+ kernel_size=3,
301
+ padding=1,
302
+ bias=bias,
303
+ ),
304
+ ]
305
+
306
+ self.layers = nn.Sequential(*layers)
307
+
308
+ def forward(self, x):
309
+ return self.layers(x)
310
+
311
+
312
+ class OobleckDecoder(nn.Module):
313
+ def __init__(
314
+ self,
315
+ out_channels=2,
316
+ channels=128,
317
+ latent_dim=32,
318
+ c_mults=[1, 2, 4, 8],
319
+ strides=[2, 4, 8, 8],
320
+ use_snake=False,
321
+ antialias_activation=False,
322
+ use_nearest_upsample=False,
323
+ final_tanh=True,
324
+ bias=True,
325
+ ):
326
+ super().__init__()
327
+
328
+ c_mults = [1] + c_mults
329
+
330
+ self.depth = len(c_mults)
331
+
332
+ layers = [
333
+ WNConv1d(
334
+ in_channels=latent_dim,
335
+ out_channels=c_mults[-1] * channels,
336
+ kernel_size=7,
337
+ padding=3,
338
+ bias=bias,
339
+ ),
340
+ ]
341
+
342
+ for i in range(self.depth - 1, 0, -1):
343
+ layers += [
344
+ DecoderBlock(
345
+ in_channels=c_mults[i] * channels,
346
+ out_channels=c_mults[i - 1] * channels,
347
+ stride=strides[i - 1],
348
+ use_snake=use_snake,
349
+ antialias_activation=antialias_activation,
350
+ use_nearest_upsample=use_nearest_upsample,
351
+ bias=bias,
352
+ )
353
+ ]
354
+
355
+ layers += [
356
+ get_activation(
357
+ "snake" if use_snake else "elu",
358
+ antialias=antialias_activation,
359
+ channels=c_mults[0] * channels,
360
+ ),
361
+ WNConv1d(
362
+ in_channels=c_mults[0] * channels,
363
+ out_channels=out_channels,
364
+ kernel_size=7,
365
+ padding=3,
366
+ bias=False,
367
+ ),
368
+ nn.Tanh() if final_tanh else nn.Identity(),
369
+ ]
370
+
371
+ self.layers = nn.Sequential(*layers)
372
+
373
+ def forward(self, x):
374
+ return self.layers(x)
model/ear_vae.py ADDED
@@ -0,0 +1,112 @@
1
+ import torch
2
+ import numpy as np
3
+ import torch.nn.functional as F
4
+ import torch.nn as nn
5
+
6
+ from torch import Tensor, nn, no_grad
7
+ from .autoencoders import OobleckDecoder, OobleckEncoder
8
+
9
+ from .transformer import ContinuousTransformer
10
+ LRELU_SLOPE = 0.1
11
+ padding_mode = "zeros"
12
+ sample_eps = 1e-6
13
+
14
+ def vae_sample(mean, scale):
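+ # Reparameterization: z = mean + softplus(scale) * eps, with eps ~ N(0, I).
+ # The returned kl is (twice) the per-element KL divergence to a standard normal,
+ # summed over latent channels and averaged over batch and time.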
15
+ stdev = nn.functional.softplus(scale)
16
+ var = stdev * stdev + sample_eps
17
+ logvar = torch.log(var)
18
+ latents = torch.randn_like(mean) * stdev + mean
19
+
20
+ kl = (mean * mean + var - logvar - 1).sum(1).mean()
21
+
22
+ return latents, kl
23
+
24
+
25
+ class EAR_VAE(nn.Module):
26
+
27
+ def __init__(self, model_config: dict = None):
28
+ super().__init__()
29
+
30
+ if model_config is None:
31
+ model_config = {
32
+ "encoder": {
33
+ "config": {
34
+ "in_channels": 2,
35
+ "channels": 128,
36
+ "c_mults": [1, 2, 4, 8, 16],
37
+ "strides": [2, 4, 4, 4, 8],
38
+ "latent_dim": 128,
39
+ "use_snake": True
40
+ }
41
+ },
42
+ "decoder": {
43
+ "config": {
44
+ "out_channels": 2,
45
+ "channels": 128,
46
+ "c_mults": [1, 2, 4, 8, 16],
47
+ "strides": [2, 4, 4, 4, 8],
48
+ "latent_dim": 64,
49
+ "use_nearest_upsample": False,
50
+ "use_snake": True,
51
+ "final_tanh": False,
52
+ },
53
+ },
54
+ "latent_dim": 64,
55
+ "downsampling_ratio": 1024,
56
+ "io_channels": 2,
57
+ }
58
+ else:
59
+ model_config = model_config
60
+
61
+ if model_config.get("transformer") is not None:
62
+ self.transformers = ContinuousTransformer(
63
+ dim=model_config["decoder"]["config"]["latent_dim"],
64
+ depth=model_config["transformer"]["depth"],
65
+ **model_config["transformer"].get("config", {}),
66
+ )
67
+ else:
68
+ self.transformers = None
69
+
70
+ self.encoder = OobleckEncoder(**model_config["encoder"]["config"])
71
+ self.decoder = OobleckDecoder(**model_config["decoder"]["config"])
72
+
73
+ def forward(self, audio):  # returns (reconstruction, kl)
74
+ """
75
+ audio: Input audio tensor [B,C,T]
76
+ """
77
+ status = self.encoder(audio)
78
+ mean, scale = status.chunk(2, dim=1)
79
+ z, kl = vae_sample(mean, scale)
80
+
81
+ if self.transformers is not None:
82
+ z = z.permute(0, 2, 1)
83
+ z = self.transformers(z)
84
+ z = z.permute(0, 2, 1)
85
+
86
+ x = self.decoder(z)
87
+ return x, kl
88
+
89
+ def encode(self, audio, use_sample=True):
90
+ x = self.encoder(audio)
91
+ mean, scale = x.chunk(2, dim=1)
92
+ if use_sample:
93
+ z, _ = vae_sample(mean, scale)
94
+ else:
95
+ z = mean
96
+ return z
97
+
98
+ def decode(self, z):
99
+
100
+ if self.transformers is not None:
101
+ z = z.permute(0, 2, 1)
102
+ z = self.transformers(z)
103
+ z = z.permute(0, 2, 1)
104
+
105
+ x = self.decoder(z)
106
+ return x
107
+
108
+ @no_grad()
109
+ def inference(self, audio):
110
+ z = self.encode(audio)
111
+ recon_audio = self.decode(z)
112
+ return recon_audio
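Illustrative sketch (not part of the upload): a minimal round trip through EAR_VAE using its built-in default config, assuming the model/ package above is importable as-is:

import torch
from model.ear_vae import EAR_VAE

vae = EAR_VAE().eval()                       # defaults: stereo I/O, 64-dim latent, 1024x downsampling

audio = torch.randn(1, 2, 1024 * 43)         # ~1 s of stereo noise at 44.1 kHz, length a multiple of 1024
with torch.no_grad():
    recon, kl = vae(audio)                   # training-style forward: reconstruction plus scalar KL term
    z = vae.encode(audio, use_sample=False)  # deterministic latent (posterior mean): [1, 64, T // 1024]
    recon2 = vae.inference(audio)            # convenience path: sampled encode followed by decode
print(recon.shape, z.shape, float(kl))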
model/transformer.py ADDED
@@ -0,0 +1,846 @@
1
+ from typing import Callable, Literal
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from einops import rearrange
6
+ from einops.layers.torch import Rearrange
7
+ from packaging import version
8
+ from torch import einsum, nn
9
+ from torch.cuda.amp import autocast
10
+
11
+ try:
12
+ from flash_attn import flash_attn_func, flash_attn_kvpacked_func
13
+ # flash_attn==2.3.3 is required
14
+ except ImportError as e:
15
+ print(e)
16
+ print('flash_attn not installed, disabling Flash Attention')
17
+ flash_attn_kvpacked_func = None
18
+ flash_attn_func = None
19
+
20
+ try:
21
+ import natten
22
+ except ImportError:
23
+ natten = None
24
+ import math
25
+ from functools import reduce
26
+
27
+ import numpy as np
28
+
29
+
30
+ class FourierFeatures(nn.Module):
31
+ def __init__(self, in_features, out_features, std=1.):
32
+ super().__init__()
33
+ assert out_features % 2 == 0
34
+ self.weight = nn.Parameter(torch.randn(
35
+ [out_features // 2, in_features]) * std)
36
+
37
+ def forward(self, input):
38
+ f = 2 * math.pi * input @ self.weight.T
39
+ return torch.cat([f.cos(), f.sin()], dim=-1)
40
+
41
+
42
+ def normalize(x, eps=1e-4):
43
+ dim = list(range(1, x.ndim))
44
+ n = torch.linalg.vector_norm(x, dim=dim, keepdim=True)
45
+ alpha = np.sqrt(n.numel() / x.numel())
46
+ return x / torch.add(eps, n, alpha=alpha)
47
+
48
+
49
+ def checkpoint(function, *args, **kwargs):
50
+ kwargs.setdefault("use_reentrant", False)
51
+ return torch.utils.checkpoint.checkpoint(function, *args, **kwargs)
52
+
53
+
54
+ def create_causal_mask(i, j, device):
55
+ return torch.ones((i, j), device=device, dtype=torch.bool).triu(j - i + 1)
56
+
57
+
58
+ def or_reduce(masks):
59
+ head, *body = masks
60
+ for rest in body:
61
+ head = head | rest
62
+ return head
63
+
64
+
65
+ # positional embeddings
66
+
67
+ class AbsolutePositionalEmbedding(nn.Module):
68
+ def __init__(self, dim, max_seq_len):
69
+ super().__init__()
70
+ self.scale = dim ** -0.5
71
+ self.max_seq_len = max_seq_len
72
+ self.emb = nn.Embedding(max_seq_len, dim)
73
+
74
+ def forward(self, x, pos=None, seq_start_pos=None):
75
+ seq_len, device = x.shape[1], x.device
76
+ assert seq_len <= self.max_seq_len, f'you are passing in a sequence length of {seq_len} but your absolute positional embedding has a max sequence length of {self.max_seq_len}'
77
+
78
+ if pos is None:
79
+ pos = torch.arange(seq_len, device=device)
80
+
81
+ if seq_start_pos is not None:
82
+ pos = (pos - seq_start_pos[..., None]).clamp(min=0)
83
+
84
+ pos_emb = self.emb(pos)
85
+ pos_emb = pos_emb * self.scale
86
+ return pos_emb
87
+
88
+
89
+ class ScaledSinusoidalEmbedding(nn.Module):
90
+ def __init__(self, dim, theta=10000):
91
+ super().__init__()
92
+ assert (dim % 2) == 0, 'dimension must be divisible by 2'
93
+ self.scale = nn.Parameter(torch.ones(1) * dim ** -0.5)
94
+
95
+ half_dim = dim // 2
96
+ freq_seq = torch.arange(half_dim).float() / half_dim
97
+ inv_freq = theta ** -freq_seq
98
+ self.register_buffer('inv_freq', inv_freq, persistent=False)
99
+
100
+ def forward(self, x, pos=None, seq_start_pos=None):
101
+ seq_len, device = x.shape[1], x.device
102
+
103
+ if pos is None:
104
+ pos = torch.arange(seq_len, device=device)
105
+
106
+ if seq_start_pos is not None:
107
+ pos = pos - seq_start_pos[..., None]
108
+
109
+ emb = einsum('i, j -> i j', pos, self.inv_freq)
110
+ emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
111
+ return emb * self.scale
112
+
113
+
114
+ class RotaryEmbedding(nn.Module):
115
+ def __init__(
116
+ self,
117
+ dim,
118
+ use_xpos=False,
119
+ scale_base=512,
120
+ interpolation_factor=1.,
121
+ base=10000,
122
+ base_rescale_factor=1.
123
+ ):
124
+ super().__init__()
125
+ # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
126
+ # has some connection to NTK literature
127
+ # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
128
+ base *= base_rescale_factor ** (dim / (dim - 2))
129
+
130
+ inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
131
+ self.register_buffer('inv_freq', inv_freq)
132
+
133
+ assert interpolation_factor >= 1.
134
+ self.interpolation_factor = interpolation_factor
135
+
136
+ if not use_xpos:
137
+ self.register_buffer('scale', None)
138
+ return
139
+
140
+ scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
141
+
142
+ self.scale_base = scale_base
143
+ self.register_buffer('scale', scale)
144
+
145
+ def forward_from_seq_len(self, seq_len):
146
+ device = self.inv_freq.device
147
+
148
+ t = torch.arange(seq_len, device=device)
149
+ return self.forward(t)
150
+
151
+ @torch.amp.autocast('cuda', enabled=False)
152
+ def forward(self, t):
153
+ device = self.inv_freq.device
154
+
155
+ t = t.to(torch.float32)
156
+
157
+ t = t / self.interpolation_factor
158
+
159
+ freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
160
+ freqs = torch.cat((freqs, freqs), dim=-1)
161
+
162
+ if self.scale is None:
163
+ return freqs, 1.
164
+
165
+ seq_len = t.shape[-1]  # derive the sequence length from the position tensor for the xpos branch
+ power = (torch.arange(seq_len, device=device) - (seq_len // 2)) / self.scale_base
166
+ scale = self.scale ** rearrange(power, 'n -> n 1')
167
+ scale = torch.cat((scale, scale), dim=-1)
168
+
169
+ return freqs, scale
170
+
171
+
172
+ def rotate_half(x):
173
+ x = rearrange(x, '... (j d) -> ... j d', j=2)
174
+ x1, x2 = x.unbind(dim=-2)
175
+ return torch.cat((-x2, x1), dim=-1)
176
+
177
+
178
+ @torch.amp.autocast('cuda', enabled=False)
179
+ def apply_rotary_pos_emb(t, freqs, scale=1):
180
+ out_dtype = t.dtype
181
+
182
+ # cast to float32 if necessary for numerical stability
183
+ dtype = reduce(torch.promote_types, (t.dtype, freqs.dtype, torch.float32))
184
+ rot_dim, seq_len = freqs.shape[-1], t.shape[-2]
185
+ freqs, t = freqs.to(dtype), t.to(dtype)
186
+ freqs = freqs[-seq_len:, :]
187
+
188
+ if t.ndim == 4 and freqs.ndim == 3:
189
+ freqs = rearrange(freqs, 'b n d -> b 1 n d')
190
+
191
+ # partial rotary embeddings, Wang et al. GPT-J
192
+ t, t_unrotated = t[..., :rot_dim], t[..., rot_dim:]
193
+ t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)
194
+
195
+ t, t_unrotated = t.to(out_dtype), t_unrotated.to(out_dtype)
196
+
197
+ return torch.cat((t, t_unrotated), dim=-1)
198
+
199
+
200
+ # norms
201
+ class LayerNorm(nn.Module):
202
+ def __init__(self, dim, bias=False, fix_scale=False):
203
+ """
204
+ bias-less layernorm has been shown to be more stable. most newer models have moved towards rmsnorm, also bias-less
205
+ """
206
+ super().__init__()
207
+
208
+ if fix_scale:
209
+ self.register_buffer("gamma", torch.ones(dim))
210
+ else:
211
+ self.gamma = nn.Parameter(torch.ones(dim))
212
+
213
+ if bias:
214
+ self.beta = nn.Parameter(torch.zeros(dim))
215
+ else:
216
+ self.register_buffer("beta", torch.zeros(dim))
217
+
218
+ def forward(self, x):
219
+ return F.layer_norm(x, x.shape[-1:], weight=self.gamma, bias=self.beta)
220
+
221
+
222
+ # feedforward
223
+
224
+ class GLU(nn.Module):
225
+ def __init__(
226
+ self,
227
+ dim_in,
228
+ dim_out,
229
+ activation: Callable,
230
+ use_conv=False,
231
+ conv_kernel_size=3,
232
+ ):
233
+ super().__init__()
234
+ self.act = activation
235
+ self.proj = nn.Linear(dim_in, dim_out * 2) if not use_conv else nn.Conv1d(dim_in, dim_out * 2, conv_kernel_size,
236
+ padding=(conv_kernel_size // 2))
237
+ self.use_conv = use_conv
238
+
239
+ def forward(self, x):
240
+ if self.use_conv:
241
+ x = rearrange(x, 'b n d -> b d n')
242
+ x = self.proj(x)
243
+ x = rearrange(x, 'b d n -> b n d')
244
+ else:
245
+ x = self.proj(x)
246
+
247
+ x, gate = x.chunk(2, dim=-1)
248
+ return x * self.act(gate)
249
+
250
+
251
+ class FeedForward(nn.Module):
252
+ def __init__(
253
+ self,
254
+ dim,
255
+ dim_out=None,
256
+ mult=4,
257
+ no_bias=False,
258
+ glu=True,
259
+ use_conv=False,
260
+ conv_kernel_size=3,
261
+ zero_init_output=True,
262
+ ):
263
+ super().__init__()
264
+ inner_dim = int(dim * mult)
265
+
266
+ # Default to SwiGLU
267
+
268
+ activation = nn.SiLU()
269
+
270
+ dim_out = dim if dim_out is None else dim_out
271
+
272
+ if glu:
273
+ linear_in = GLU(dim, inner_dim, activation)
274
+ else:
275
+ linear_in = nn.Sequential(
276
+ Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
277
+ nn.Linear(dim, inner_dim, bias=not no_bias) if not use_conv else nn.Conv1d(dim, inner_dim,
278
+ conv_kernel_size, padding=(
279
+ conv_kernel_size // 2), bias=not no_bias),
280
+ Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
281
+ activation
282
+ )
283
+
284
+ linear_out = nn.Linear(inner_dim, dim_out, bias=not no_bias) if not use_conv else nn.Conv1d(inner_dim, dim_out,
285
+ conv_kernel_size,
286
+ padding=(
287
+ conv_kernel_size // 2),
288
+ bias=not no_bias)
289
+
290
+ # init last linear layer to 0
291
+ if zero_init_output:
292
+ nn.init.zeros_(linear_out.weight)
293
+ if not no_bias:
294
+ nn.init.zeros_(linear_out.bias)
295
+
296
+ self.ff = nn.Sequential(
297
+ linear_in,
298
+ Rearrange('b d n -> b n d') if use_conv else nn.Identity(),
299
+ linear_out,
300
+ Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
301
+ )
302
+
303
+ def forward(self, x):
304
+ return self.ff(x)
305
+
306
+
307
+ class Attention(nn.Module):
308
+ def __init__(
309
+ self,
310
+ dim,
311
+ dim_heads=64,
312
+ dim_context=None,
313
+ causal=False,
314
+ zero_init_output=True,
315
+ qk_norm: Literal['l2', 'ln', 'none'] = 'none',
316
+ natten_kernel_size=None
317
+ ):
318
+ super().__init__()
319
+ self.dim = dim
320
+ self.dim_heads = dim_heads
321
+ self.causal = causal
322
+
323
+ dim_kv = dim_context if dim_context is not None else dim
324
+
325
+ self.num_heads = dim // dim_heads
326
+ self.kv_heads = dim_kv // dim_heads
327
+
328
+ if dim_context is not None:
329
+ self.to_q = nn.Linear(dim, dim, bias=False)
330
+ self.to_kv = nn.Linear(dim_kv, dim_kv * 2, bias=False)
331
+ else:
332
+ self.to_qkv = nn.Linear(dim, dim * 3, bias=False)
333
+
334
+ self.to_out = nn.Linear(dim, dim, bias=False)
335
+
336
+ if zero_init_output:
337
+ nn.init.zeros_(self.to_out.weight)
338
+
339
+ self.qk_norm = qk_norm
340
+
341
+ if self.qk_norm == "ln":
342
+ self.q_norm = nn.LayerNorm(dim_heads, elementwise_affine=True, eps=1.0e-6)
343
+ self.k_norm = nn.LayerNorm(dim_heads, elementwise_affine=True, eps=1.0e-6)
344
+
345
+ # Using 1d neighborhood attention
346
+ self.natten_kernel_size = natten_kernel_size
347
+ if natten_kernel_size is not None:
348
+ return
349
+
350
+ self.use_pt_flash = torch.cuda.is_available() and version.parse(torch.__version__) >= version.parse('2.0.0')
351
+
352
+ self.use_fa_flash = torch.cuda.is_available() and flash_attn_func is not None
353
+
354
+ self.sdp_kwargs = dict(
355
+ enable_flash=True,
356
+ enable_math=True,
357
+ enable_mem_efficient=True
358
+ )
359
+
360
+ def flash_attn(
361
+ self,
362
+ q,
363
+ k,
364
+ v,
365
+ mask=None,
366
+ causal=None
367
+ ):
368
+ batch, heads, q_len, _, k_len, device = *q.shape, k.shape[-2], q.device
369
+ kv_heads = k.shape[1]
370
+ # Recommended for multi-query single-key-value attention by Tri Dao
371
+ # kv shape torch.Size([1, 512, 64]) -> torch.Size([1, 8, 512, 64])
372
+
373
+ if heads != kv_heads:
374
+ # Repeat interleave kv_heads to match q_heads
375
+ heads_per_kv_head = heads // kv_heads
376
+ k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim=1), (k, v))
377
+
378
+ if k.ndim == 3:
379
+ k = rearrange(k, 'b ... -> b 1 ...').expand_as(q)
380
+
381
+ if v.ndim == 3:
382
+ v = rearrange(v, 'b ... -> b 1 ...').expand_as(q)
383
+
384
+ causal = self.causal if causal is None else causal
385
+
386
+ if q_len == 1 and causal:
387
+ causal = False
388
+
389
+ if mask is not None:
390
+ assert mask.ndim == 4
391
+ mask = mask.expand(batch, heads, q_len, k_len)
392
+
393
+ # handle kv cache - this should be bypassable in updated flash attention 2
394
+
395
+ if k_len > q_len and causal:
396
+ causal_mask = create_causal_mask(q_len, k_len, device=device)  # module-level helper defined above
397
+ if mask is None:
398
+ mask = ~causal_mask
399
+ else:
400
+ mask = mask & ~causal_mask
401
+ causal = False
402
+
403
+ # manually handle causal mask, if another mask was given
404
+
405
+ row_is_entirely_masked = None
406
+
407
+ if mask is not None and causal:
408
+ causal_mask = create_causal_mask(q_len, k_len, device=device)  # module-level helper defined above
409
+ mask = mask & ~causal_mask
410
+
411
+ # protect against an entire row being masked out
412
+
413
+ row_is_entirely_masked = ~mask.any(dim=-1)
414
+ mask[..., 0] = mask[..., 0] | row_is_entirely_masked
415
+
416
+ causal = False
417
+
418
+ with torch.backends.cuda.sdp_kernel(**self.sdp_kwargs):
419
+ out = F.scaled_dot_product_attention(
420
+ q, k, v,
421
+ attn_mask=mask,
422
+ is_causal=causal
423
+ )
424
+
425
+ # for a row that is entirely masked out, should zero out the output of that row token
426
+
427
+ if row_is_entirely_masked is not None:
428
+ out = out.masked_fill(row_is_entirely_masked[..., None], 0.)
429
+
430
+ return out
431
+
432
+ def forward(
433
+ self,
434
+ x,
435
+ context=None,
436
+ mask=None,
437
+ context_mask=None,
438
+ rotary_pos_emb=None,
439
+ causal=None
440
+ ):
441
+ h, kv_h, has_context = self.num_heads, self.kv_heads, context is not None
442
+
443
+ kv_input = context if has_context else x
444
+
445
+ if hasattr(self, 'to_q'):
446
+ # Use separate linear projections for q and k/v
447
+ q = self.to_q(x)
448
+ q = rearrange(q, 'b n (h d) -> b h n d', h=h)
449
+
450
+ k, v = self.to_kv(kv_input).chunk(2, dim=-1)
451
+
452
+ k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=kv_h), (k, v))
453
+ else:
454
+ # Use fused linear projection
455
+ q, k, v = self.to_qkv(x).chunk(3, dim=-1)
456
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v))
457
+
458
+ # Normalize q and k for cosine sim attention
459
+ if self.qk_norm == "l2":
460
+ q = F.normalize(q, dim=-1)
461
+ k = F.normalize(k, dim=-1)
462
+ elif self.qk_norm == "ln":
463
+ q = self.q_norm(q)
464
+ k = self.k_norm(k)
465
+
466
+ if rotary_pos_emb is not None and not has_context:
467
+ freqs, _ = rotary_pos_emb
468
+
469
+ q_dtype = q.dtype
470
+ k_dtype = k.dtype
471
+
472
+ q = q.to(torch.float32)
473
+ k = k.to(torch.float32)
474
+ freqs = freqs.to(torch.float32)
475
+
476
+ q = apply_rotary_pos_emb(q, freqs)
477
+ k = apply_rotary_pos_emb(k, freqs)
478
+
479
+ q = q.to(q_dtype)
480
+ k = k.to(k_dtype)
481
+
482
+ input_mask = context_mask
483
+
484
+ if input_mask is None and not has_context:
485
+ input_mask = mask
486
+
487
+ # determine masking
488
+ masks = []
489
+ final_attn_mask = None # The mask that will be applied to the attention matrix, taking all masks into account
490
+
491
+ if input_mask is not None:
492
+ input_mask = rearrange(input_mask, 'b j -> b 1 1 j')
493
+ masks.append(~input_mask)
494
+
495
+ # Other masks will be added here later
496
+
497
+ if len(masks) > 0:
498
+ final_attn_mask = ~or_reduce(masks)
499
+
500
+ n, device = q.shape[-2], q.device
501
+
502
+ causal = self.causal if causal is None else causal
503
+
504
+ if n == 1 and causal:
505
+ causal = False
506
+
507
+ if self.natten_kernel_size is not None:
508
+ if natten is None:
509
+ raise ImportError('natten not installed, please install natten to use neighborhood attention')
510
+
511
+ dtype_in = q.dtype
512
+ q, k, v = map(lambda t: t.to(torch.float32), (q, k, v))
513
+
514
+ attn = natten.functional.natten1dqk(q, k, kernel_size=self.natten_kernel_size, dilation=1)
515
+
516
+ if final_attn_mask is not None:
517
+ attn = attn.masked_fill(final_attn_mask, -torch.finfo(attn.dtype).max)
518
+
519
+ attn = F.softmax(attn, dim=-1, dtype=torch.float32)
520
+
521
+ out = natten.functional.natten1dav(attn, v, kernel_size=self.natten_kernel_size, dilation=1).to(dtype_in)
522
+
523
+ # Prioritize Flash Attention 2
524
+ elif self.use_fa_flash:
525
+ assert final_attn_mask is None, 'masking not yet supported for Flash Attention 2'
526
+ # Flash Attention 2 requires FP16 inputs
527
+ fa_dtype_in = q.dtype
528
+ q, k, v = map(lambda t: rearrange(t, 'b h n d -> b n h d').to(torch.float16), (q, k, v))
529
+
530
+ out = flash_attn_func(q, k, v, causal=causal)
531
+
532
+ out = rearrange(out.to(fa_dtype_in), 'b n h d -> b h n d')
533
+
534
+ # Fall back to PyTorch implementation
535
+ elif self.use_pt_flash:
536
+ out = self.flash_attn(q, k, v, causal=causal, mask=final_attn_mask)
537
+
538
+ else:
539
+ # Fall back to custom implementation
540
+
541
+ if h != kv_h:
542
+ # Repeat interleave kv_heads to match q_heads
543
+ heads_per_kv_head = h // kv_h
544
+ k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim=1), (k, v))
545
+
546
+ scale = 1. / (q.shape[-1] ** 0.5)
547
+
548
+ kv_einsum_eq = 'b j d' if k.ndim == 3 else 'b h j d'
549
+
550
+ dots = einsum(f'b h i d, {kv_einsum_eq} -> b h i j', q, k) * scale
551
+
552
+ i, j, dtype = *dots.shape[-2:], dots.dtype
553
+
554
+ mask_value = -torch.finfo(dots.dtype).max
555
+
556
+ if final_attn_mask is not None:
557
+ dots = dots.masked_fill(~final_attn_mask, mask_value)
558
+
559
+ if causal:
560
+ causal_mask = create_causal_mask(i, j, device=device)  # module-level helper defined above
561
+ dots = dots.masked_fill(causal_mask, mask_value)
562
+
563
+ attn = F.softmax(dots, dim=-1, dtype=torch.float32)
564
+ attn = attn.type(dtype)
565
+
566
+ out = einsum(f'b h i j, {kv_einsum_eq} -> b h i d', attn, v)
567
+
568
+ # merge heads
569
+ out = rearrange(out, ' b h n d -> b n (h d)')
570
+ out = self.to_out(out)
571
+
572
+ if mask is not None:
573
+ mask = rearrange(mask, 'b n -> b n 1')
574
+ out = out.masked_fill(~mask, 0.)
575
+
576
+ return out
577
+
578
+
579
+ class ConformerModule(nn.Module):
580
+ def __init__(
581
+ self,
582
+ dim,
583
+ norm_kwargs={},
584
+ ):
585
+ super().__init__()
586
+
587
+ self.dim = dim
588
+
589
+ self.in_norm = LayerNorm(dim, **norm_kwargs)
590
+ self.pointwise_conv = nn.Conv1d(dim, dim, kernel_size=1, bias=False)
591
+ self.glu = GLU(dim, dim, nn.SiLU())
592
+ self.depthwise_conv = nn.Conv1d(dim, dim, kernel_size=17, groups=dim, padding=8, bias=False)
593
+ self.mid_norm = LayerNorm(dim,
594
+ **norm_kwargs) # This is a batch norm in the original but I don't like batch norm
595
+ self.swish = nn.SiLU()
596
+ self.pointwise_conv_2 = nn.Conv1d(dim, dim, kernel_size=1, bias=False)
597
+
598
+ def forward(self, x):
599
+ x = self.in_norm(x)
600
+ x = rearrange(x, 'b n d -> b d n')
601
+ x = self.pointwise_conv(x)
602
+ x = rearrange(x, 'b d n -> b n d')
603
+ x = self.glu(x)
604
+ x = rearrange(x, 'b n d -> b d n')
605
+ x = self.depthwise_conv(x)
606
+ x = rearrange(x, 'b d n -> b n d')
607
+ x = self.mid_norm(x)
608
+ x = self.swish(x)
609
+ x = rearrange(x, 'b n d -> b d n')
610
+ x = self.pointwise_conv_2(x)
611
+ x = rearrange(x, 'b d n -> b n d')
612
+
613
+ return x
614
+
615
+
616
+ class TransformerBlock(nn.Module):
617
+ def __init__(
618
+ self,
619
+ dim,
620
+ dim_heads=64,
621
+ cross_attend=False,
622
+ dim_context=None,
623
+ global_cond_dim=None,
624
+ causal=False,
625
+ zero_init_branch_outputs=True,
626
+ conformer=False,
627
+ layer_ix=-1,
628
+ remove_norms=False,
629
+ attn_kwargs={},
630
+ ff_kwargs={},
631
+ norm_kwargs={}
632
+ ):
633
+
634
+ super().__init__()
635
+ self.dim = dim
636
+ self.dim_heads = dim_heads
637
+ self.cross_attend = cross_attend
638
+ self.dim_context = dim_context
639
+ self.causal = causal
640
+
641
+ self.pre_norm = LayerNorm(dim, **norm_kwargs) if not remove_norms else nn.Identity()
642
+
643
+ self.self_attn = Attention(
644
+ dim,
645
+ dim_heads=dim_heads,
646
+ causal=causal,
647
+ zero_init_output=zero_init_branch_outputs,
648
+ **attn_kwargs
649
+ )
650
+
651
+ if cross_attend:
652
+ self.cross_attend_norm = LayerNorm(dim, **norm_kwargs) if not remove_norms else nn.Identity()
653
+ self.cross_attn = Attention(
654
+ dim,
655
+ dim_heads=dim_heads,
656
+ dim_context=dim_context,
657
+ causal=causal,
658
+ zero_init_output=zero_init_branch_outputs,
659
+ **attn_kwargs
660
+ )
661
+
662
+ self.ff_norm = LayerNorm(dim, **norm_kwargs) if not remove_norms else nn.Identity()
663
+ self.ff = FeedForward(dim, zero_init_output=zero_init_branch_outputs, **ff_kwargs)
664
+
665
+ self.layer_ix = layer_ix
666
+
667
+ self.conformer = ConformerModule(dim, norm_kwargs=norm_kwargs) if conformer else None
668
+
669
+ self.global_cond_dim = global_cond_dim
670
+
671
+ if global_cond_dim is not None:
672
+ self.to_scale_shift_gate = nn.Sequential(
673
+ nn.SiLU(),
674
+ nn.Linear(global_cond_dim, dim * 6, bias=False)
675
+ )
676
+
677
+ nn.init.zeros_(self.to_scale_shift_gate[1].weight)
678
+ # nn.init.zeros_(self.to_scale_shift_gate_self[1].bias)
679
+
680
+ def forward(
681
+ self,
682
+ x,
683
+ context=None,
684
+ global_cond=None,
685
+ mask=None,
686
+ context_mask=None,
687
+ rotary_pos_emb=None
688
+ ):
689
+ if self.global_cond_dim is not None and self.global_cond_dim > 0 and global_cond is not None:
690
+
691
+ scale_self, shift_self, gate_self, scale_ff, shift_ff, gate_ff = self.to_scale_shift_gate(
692
+ global_cond).unsqueeze(1).chunk(6, dim=-1)
693
+
694
+ # self-attention with adaLN
695
+ residual = x
696
+ x = self.pre_norm(x)
697
+ x = x * (1 + scale_self) + shift_self
698
+ x = self.self_attn(x, mask=mask, rotary_pos_emb=rotary_pos_emb)
699
+ x = x * torch.sigmoid(1 - gate_self)
700
+ x = x + residual
701
+
702
+ if context is not None:
703
+ x = x + self.cross_attn(self.cross_attend_norm(x), context=context, context_mask=context_mask)
704
+
705
+ if self.conformer is not None:
706
+ x = x + self.conformer(x)
707
+
708
+ # feedforward with adaLN
709
+ residual = x
710
+ x = self.ff_norm(x)
711
+ x = x * (1 + scale_ff) + shift_ff
712
+ x = self.ff(x)
713
+ x = x * torch.sigmoid(1 - gate_ff)
714
+ x = x + residual
715
+
716
+ else:
717
+ x = x + self.self_attn(self.pre_norm(x), mask=mask, rotary_pos_emb=rotary_pos_emb)
718
+
719
+ if context is not None:
720
+ x = x + self.cross_attn(self.cross_attend_norm(x), context=context, context_mask=context_mask)
721
+
722
+ if self.conformer is not None:
723
+ x = x + self.conformer(x)
724
+
725
+ x = x + self.ff(self.ff_norm(x))
726
+
727
+ return x
728
+
729
+
730
+ class ContinuousTransformer(nn.Module):
731
+ def __init__(
732
+ self,
733
+ dim,
734
+ depth,
735
+ *,
736
+ dim_in=None,
737
+ dim_out=None,
738
+ dim_heads=64,
739
+ cross_attend=False,
740
+ cond_token_dim=None,
741
+ global_cond_dim=None,
742
+ causal=False,
743
+ rotary_pos_emb=True,
744
+ zero_init_branch_outputs=True,
745
+ conformer=False,
746
+ use_sinusoidal_emb=False,
747
+ use_abs_pos_emb=False,
748
+ abs_pos_emb_max_length=10000,
749
+ **kwargs
750
+ ):
751
+
752
+ super().__init__()
753
+
754
+ self.dim = dim
755
+ self.depth = depth
756
+ self.causal = causal
757
+ self.layers = nn.ModuleList([])
758
+
759
+ self.project_in = nn.Linear(dim_in, dim, bias=False) if dim_in is not None else nn.Identity()
760
+ self.project_out = nn.Linear(dim, dim_out, bias=False) if dim_out is not None else nn.Identity()
761
+
762
+ if rotary_pos_emb:
763
+ self.rotary_pos_emb = RotaryEmbedding(max(dim_heads // 2, 32))
764
+ else:
765
+ self.rotary_pos_emb = None
766
+
767
+ self.use_sinusoidal_emb = use_sinusoidal_emb
768
+ if use_sinusoidal_emb:
769
+ self.pos_emb = ScaledSinusoidalEmbedding(dim)
770
+
771
+ self.use_abs_pos_emb = use_abs_pos_emb
772
+ if use_abs_pos_emb:
773
+ self.pos_emb = AbsolutePositionalEmbedding(dim, abs_pos_emb_max_length)
774
+
775
+ for i in range(depth):
776
+ self.layers.append(
777
+ TransformerBlock(
778
+ dim,
779
+ dim_heads=dim_heads,
780
+ cross_attend=cross_attend,
781
+ dim_context=cond_token_dim,
782
+ global_cond_dim=global_cond_dim,
783
+ causal=causal,
784
+ zero_init_branch_outputs=zero_init_branch_outputs,
785
+ conformer=conformer,
786
+ layer_ix=i,
787
+ **kwargs
788
+ )
789
+ )
790
+
791
+ def forward(
792
+ self,
793
+ x,
794
+ mask=None,
795
+ prepend_embeds=None,
796
+ prepend_mask=None,
797
+ global_cond=None,
798
+ return_info=False,
799
+ **kwargs
800
+ ):
801
+ batch, seq, device = *x.shape[:2], x.device
802
+
803
+ info = {
804
+ "hidden_states": [],
805
+ }
806
+
807
+ x = self.project_in(x)
808
+
809
+ if prepend_embeds is not None:
810
+ prepend_length, prepend_dim = prepend_embeds.shape[1:]
811
+
812
+ assert prepend_dim == x.shape[-1], 'prepend dimension must match sequence dimension'
813
+
814
+ x = torch.cat((prepend_embeds, x), dim=-2)
815
+
816
+ if prepend_mask is not None or mask is not None:
817
+ mask = mask if mask is not None else torch.ones((batch, seq), device=device, dtype=torch.bool)
818
+ prepend_mask = prepend_mask if prepend_mask is not None else torch.ones((batch, prepend_length),
819
+ device=device, dtype=torch.bool)
820
+
821
+ mask = torch.cat((prepend_mask, mask), dim=-1)
822
+
823
+ # Attention layers
824
+
825
+ if self.rotary_pos_emb is not None:
826
+ rotary_pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1])
827
+ else:
828
+ rotary_pos_emb = None
829
+
830
+ if self.use_sinusoidal_emb or self.use_abs_pos_emb:
831
+ x = x + self.pos_emb(x)
832
+
833
+ # Iterate over the transformer layers
834
+ for layer in self.layers:
835
+ # x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs)
836
+ x = checkpoint(layer, x, rotary_pos_emb=rotary_pos_emb, global_cond=global_cond, **kwargs)
837
+
838
+ if return_info:
839
+ info["hidden_states"].append(x)
840
+
841
+ x = self.project_out(x)
842
+
843
+ if return_info:
844
+ return x, info
845
+
846
+ return x
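Illustrative sketch (not part of the upload): ContinuousTransformer operates on [batch, sequence, dim] tensors, which is why EAR_VAE permutes its [B, C, T] latents before and after the transformer post-net. A standalone CPU call, noting that the flash_attn and natten imports above are optional and the attention falls back to the plain einsum path when they are unavailable; the sizes here are hypothetical:

import torch
from model.transformer import ContinuousTransformer

# A 4-layer post-net over 64-dim latents, similar to what EAR_VAE would build from a
# "transformer" entry in its model_config.
post_net = ContinuousTransformer(dim=64, depth=4, dim_heads=64, rotary_pos_emb=True)

z = torch.randn(2, 128, 64)   # [B, T, D]: 128 latent frames of width 64
out = post_net(z)             # same shape out; pass return_info=True to also collect hidden states
print(out.shape)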
pretrained_weight/ear_vae_44k.pyt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0362dc7e96566869747dbe079b0a6d71c090b0a3a5d5077779e7be17c096d9d5
3
+ size 591453838
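Illustrative sketch (not part of the upload): ear_vae_44k.pyt is a Git LFS pointer to a roughly 590 MB checkpoint. Its exact layout is not visible here, so the loader below guesses that torch.load returns either a bare state_dict or a dict wrapping one under a "state_dict" key, and uses strict=False because the key names are unverified:

import torch
from model.ear_vae import EAR_VAE

ckpt = torch.load("pretrained_weight/ear_vae_44k.pyt", map_location="cpu")
state_dict = ckpt.get("state_dict", ckpt) if isinstance(ckpt, dict) else ckpt  # guessed layout

vae = EAR_VAE()
missing, unexpected = vae.load_state_dict(state_dict, strict=False)
print(f"missing keys: {len(missing)}, unexpected keys: {len(unexpected)}")
vae.eval()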