alexnasa committed
Commit 257f706 · verified · 1 Parent(s): 001b61a

Upload 69 files

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full set of 69 files.

Files changed (50)
  1. .gitattributes +4 -0
  2. INSTALL.md +55 -0
  3. LICENSE.txt +201 -0
  4. Makefile +5 -0
  5. README.md +12 -12
  6. app.py +546 -0
  7. examples/desi.mp4 +3 -0
  8. examples/desi.png +3 -0
  9. examples/man.png +3 -0
  10. examples/paul.mp4 +3 -0
  11. generate.py +236 -0
  12. pyproject.toml +66 -0
  13. requirements.txt +31 -0
  14. wan/__init__.py +7 -0
  15. wan/animate.py +653 -0
  16. wan/configs/__init__.py +50 -0
  17. wan/configs/shared_config.py +20 -0
  18. wan/configs/wan_animate_14B.py +40 -0
  19. wan/configs/wan_i2v_A14B.py +37 -0
  20. wan/configs/wan_s2v_14B.py +59 -0
  21. wan/configs/wan_t2v_A14B.py +37 -0
  22. wan/configs/wan_ti2v_5B.py +36 -0
  23. wan/distributed/__init__.py +1 -0
  24. wan/distributed/fsdp.py +45 -0
  25. wan/distributed/sequence_parallel.py +176 -0
  26. wan/distributed/ulysses.py +47 -0
  27. wan/distributed/util.py +51 -0
  28. wan/image2video.py +431 -0
  29. wan/modules/__init__.py +19 -0
  30. wan/modules/animate/__init__.py +4 -0
  31. wan/modules/animate/animate_utils.py +143 -0
  32. wan/modules/animate/clip.py +542 -0
  33. wan/modules/animate/face_blocks.py +383 -0
  34. wan/modules/animate/model_animate.py +500 -0
  35. wan/modules/animate/motion_encoder.py +307 -0
  36. wan/modules/animate/preprocess/UserGuider.md +70 -0
  37. wan/modules/animate/preprocess/__init__.py +3 -0
  38. wan/modules/animate/preprocess/human_visualization.py +1357 -0
  39. wan/modules/animate/preprocess/pose2d.py +430 -0
  40. wan/modules/animate/preprocess/pose2d_utils.py +1159 -0
  41. wan/modules/animate/preprocess/preprocess_data.py +121 -0
  42. wan/modules/animate/preprocess/process_pipepline.py +354 -0
  43. wan/modules/animate/preprocess/retarget_pose.py +847 -0
  44. wan/modules/animate/preprocess/sam_utils.py +155 -0
  45. wan/modules/animate/preprocess/utils.py +226 -0
  46. wan/modules/animate/preprocess/video_predictor.py +157 -0
  47. wan/modules/animate/xlm_roberta.py +170 -0
  48. wan/modules/attention.py +256 -0
  49. wan/modules/model.py +546 -0
  50. wan/modules/s2v/__init__.py +5 -0
.gitattributes CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ examples/desi.mp4 filter=lfs diff=lfs merge=lfs -text
+ examples/desi.png filter=lfs diff=lfs merge=lfs -text
+ examples/man.png filter=lfs diff=lfs merge=lfs -text
+ examples/paul.mp4 filter=lfs diff=lfs merge=lfs -text
INSTALL.md ADDED
@@ -0,0 +1,55 @@
+ # Installation Guide
+
+ ## Install with pip
+
+ ```bash
+ pip install .
+ pip install .[dev]  # also installs the dev tools
+ ```
+
+ ## Install with Poetry
+
+ Ensure you have [Poetry](https://python-poetry.org/docs/#installation) installed on your system.
+
+ To install all dependencies:
+
+ ```bash
+ poetry install
+ ```
+
+ ### Handling `flash-attn` Installation Issues
+
+ If `flash-attn` fails due to **PEP 517 build issues**, you can try one of the following fixes.
+
+ #### No-Build-Isolation Installation (Recommended)
+ ```bash
+ poetry run pip install --upgrade pip setuptools wheel
+ poetry run pip install flash-attn --no-build-isolation
+ poetry install
+ ```
+
+ #### Install from Git (Alternative)
+ ```bash
+ poetry run pip install git+https://github.com/Dao-AILab/flash-attention.git
+ ```
+
+ ---
+
+ ### Running the Model
+
+ Once the installation is complete, you can run **Wan2.2** using:
+
+ ```bash
+ poetry run python generate.py --task t2v-A14B --size '1280*720' --ckpt_dir ./Wan2.2-T2V-A14B --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."
+ ```
+
+ #### Test
+ ```bash
+ bash tests/test.sh
+ ```
+
+ #### Format
+ ```bash
+ black .
+ isort .
+ ```
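Editor's note: if you want to confirm that the `flash-attn` build actually succeeded before launching a long generation, a quick import check is enough. The snippet below is a minimal sketch and not part of this repository; the suggestion to fall back to PyTorch's built-in SDPA attention is an assumption about your setup, not something the install guide mandates.

```python
# Minimal post-install sanity check (illustrative sketch, not part of the repo).
try:
    import flash_attn
    print(f"flash-attn {flash_attn.__version__} is importable.")
except ImportError as err:
    # Assumption: falling back to torch.nn.functional.scaled_dot_product_attention
    # is acceptable for your workload; otherwise retry the no-build-isolation route above.
    print(f"flash-attn unavailable ({err}); see the --no-build-isolation instructions above.")
```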
LICENSE.txt ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
Makefile ADDED
@@ -0,0 +1,5 @@
+ .PHONY: format
+
+ format:
+ 	isort generate.py wan
+ 	yapf -i -r *.py generate.py wan
README.md CHANGED
@@ -1,12 +1,12 @@
- ---
- title: Wan2.2 Animate ZEROGPU
- emoji: 😻
- colorFrom: indigo
- colorTo: purple
- sdk: gradio
- sdk_version: 5.49.1
- app_file: app.py
- pinned: false
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ ---
+ title: Wan2.2 Animate [Local]
+ emoji: 🔥
+ colorFrom: pink
+ colorTo: yellow
+ sdk: gradio
+ sdk_version: 5.49.1
+ app_file: app.py
+ pinned: false
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,546 @@
1
+ import spaces
2
+ from huggingface_hub import snapshot_download, hf_hub_download
3
+ import os
4
+ import subprocess
5
+ import importlib, site
6
+ from PIL import Image
7
+ import uuid
8
+ import shutil
9
+ import time
10
+ import cv2
11
+ from generate import generate, load_model
12
+ import json
13
+
14
+ # Re-discover all .pth/.egg-link files
15
+ for sitedir in site.getsitepackages():
16
+ site.addsitedir(sitedir)
17
+
18
+ # Clear caches so importlib will pick up new modules
19
+ importlib.invalidate_caches()
20
+
21
+ def sh(cmd): subprocess.check_call(cmd, shell=True)
22
+
23
+ try:
24
+ print("Attempting to download and build sam2...")
25
+
26
+ print("download sam")
27
+ sam_dir = snapshot_download(repo_id="alexnasa/sam2")
28
+
29
+ @spaces.GPU(duration=450)
30
+ def install_sam():
31
+ os.environ["TORCH_CUDA_ARCH_LIST"] = "9.0"
32
+ sh(f"cd {sam_dir} && python setup.py build_ext --inplace && pip install -e .")
33
+
34
+ print("install sam")
35
+ install_sam()
36
+
37
+ # tell Python to re-scan site-packages now that the egg-link exists
38
+ import importlib, site; site.addsitedir(site.getsitepackages()[0]); importlib.invalidate_caches()
39
+
40
+ flash_attention_installed = True
41
+ print("sam2 installed successfully.")
42
+
43
+ except Exception as e:
44
+ print(f"⚠️ Could not install sam2: {e}")
45
+ print("Continuing without sam2...")
46
+
47
+ import torch
48
+ print(f"Torch version: {torch.__version__}")
49
+
50
+ os.environ["PROCESSED_RESULTS"] = f"{os.getcwd()}/processed_results"
51
+
52
+ import gradio as gr
53
+
54
+
55
+ snapshot_download(repo_id="Wan-AI/Wan2.2-Animate-14B", local_dir="./Wan2.2-Animate-14B")
56
+ wan_animate = load_model(True)
57
+
58
+
59
+ rc_mapping = {
60
+ "Video → Ref Image" : False,
61
+ "Video ← Ref Image" : True
62
+ }
63
+
64
+
65
+ def preprocess_video(input_video_path, session_id=None):
66
+
67
+ if session_id is None:
68
+ session_id = uuid.uuid4().hex
69
+
70
+ output_dir = os.path.join(os.environ["PROCESSED_RESULTS"], session_id)
71
+ os.makedirs(output_dir, exist_ok=True)
72
+
73
+ process_video_path = os.path.join(output_dir, 'input_video.mp4')
74
+
75
+ convert_video_to_30fps_and_clip(input_video_path, process_video_path, crop_width=720, crop_height=1280)
76
+
77
+ return process_video_path
78
+
79
+ def extract_audio_from_video_ffmpeg(video_path, output_wav_path, sample_rate=None):
80
+ """
81
+ Extracts the audio track from a video file and saves it as a WAV file.
82
+
83
+ Args:
84
+ video_path (str): Path to the input video file.
85
+ output_wav_path (str): Path to save the extracted WAV file.
86
+ sample_rate (int, optional): Output sample rate (e.g., 16000).
87
+ If None, keep the original.
88
+ """
89
+ cmd = [
90
+ 'ffmpeg',
91
+ '-i', video_path, # Input video
92
+ '-vn', # Disable video
93
+ '-acodec', 'pcm_s16le', # 16-bit PCM (WAV format)
94
+ '-ac', '1', # Mono channel (use '2' for stereo)
95
+ '-y', # Overwrite output
96
+ '-loglevel', 'error' # Cleaner output
97
+ ]
98
+
99
+ # Only add the sample rate option if explicitly specified
100
+ if sample_rate is not None:
101
+ cmd.extend(['-ar', str(sample_rate)])
102
+
103
+ cmd.append(output_wav_path)
104
+
105
+ try:
106
+ subprocess.run(cmd, check=True, capture_output=True, text=True)
107
+ except subprocess.CalledProcessError as e:
108
+ raise RuntimeError(f"ffmpeg failed ({e.returncode}): {e.stderr.strip()}")
109
+
110
+
111
+ def combine_video_and_audio_ffmpeg(video_path, audio_path, output_video_path):
112
+ """
113
+ Combines a silent MP4 video with a WAV audio file into a single MP4 with sound.
114
+
115
+ Args:
116
+ video_path (str): Path to the silent video file.
117
+ audio_path (str): Path to the WAV audio file.
118
+ output_video_path (str): Path to save the output MP4 with audio.
119
+ """
120
+ cmd = [
121
+ 'ffmpeg',
122
+ '-i', video_path, # Input video
123
+ '-i', audio_path, # Input audio
124
+ '-c:v', 'copy', # Copy video without re-encoding
125
+ '-c:a', 'aac', # Encode audio as AAC (MP4-compatible)
126
+ '-shortest', # Stop when the shortest stream ends
127
+ '-y', # Overwrite output
128
+ '-loglevel', 'error',
129
+ output_video_path
130
+ ]
131
+
132
+ try:
133
+ subprocess.run(cmd, check=True, capture_output=True, text=True)
134
+ except subprocess.CalledProcessError as e:
135
+ raise RuntimeError(f"ffmpeg failed ({e.returncode}): {e.stderr.strip()}")
136
+
137
+
138
+ def convert_video_to_30fps_and_clip(
139
+ input_video_path,
140
+ output_video_path,
141
+ duration_s=2,
142
+ target_fps=30,
143
+ crop_width=None,
144
+ crop_height=None
145
+ ):
146
+ # Get input video dimensions using ffprobe
147
+ if crop_width and crop_height:
148
+ probe_cmd = [
149
+ 'ffprobe', '-v', 'error', '-select_streams', 'v:0',
150
+ '-show_entries', 'stream=width,height',
151
+ '-of', 'json', input_video_path
152
+ ]
153
+ probe_result = subprocess.run(probe_cmd, capture_output=True, text=True, check=True)
154
+ video_info = json.loads(probe_result.stdout)
155
+ w = video_info['streams'][0]['width']
156
+ h = video_info['streams'][0]['height']
157
+
158
+ # Clamp crop size to not exceed actual dimensions
159
+ crop_width = min(crop_width, w)
160
+ crop_height = min(crop_height, h)
161
+
162
+ # Center crop offsets
163
+ crop_x = max((w - crop_width) // 2, 0)
164
+ crop_y = max((h - crop_height) // 2, 0)
165
+ crop_filter = f"crop={crop_width}:{crop_height}:{crop_x}:{crop_y}"
166
+ else:
167
+ crop_filter = None
168
+
169
+ cmd = [
170
+ 'ffmpeg',
171
+ '-i', input_video_path,
172
+ '-r', str(target_fps),
173
+ '-t', str(duration_s),
174
+ ]
175
+
176
+ if crop_filter:
177
+ cmd += ['-vf', crop_filter]
178
+
179
+ cmd += [
180
+ '-c:v', 'libx264',
181
+ '-c:a', 'aac',
182
+ '-strict', 'experimental',
183
+ '-y',
184
+ '-loglevel', 'error',
185
+ output_video_path
186
+ ]
187
+
188
+ try:
189
+ subprocess.run(cmd, check=True, capture_output=True, text=True)
190
+ except subprocess.CalledProcessError as e:
191
+ raise RuntimeError(f"ffmpeg failed ({e.returncode}): {e.stderr.strip()}")
192
+
193
+ def get_frames_count(video_file):
194
+
195
+ # Get video information
196
+ cap = cv2.VideoCapture(video_file)
197
+ if not cap.isOpened():
198
+ error_msg = "Cannot open video file"
199
+ gr.Warning(error_msg)
200
+
201
+ orig_frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
202
+ orig_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
203
+ orig_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
204
+
205
+ cap.release()
206
+
207
+ return orig_frame_count
208
+
209
+ def calculate_time_required(input_video, rc_bool):
210
+
211
+ frames_count = get_frames_count(input_video)
212
+
213
+ chunks = frames_count // 77 + 1
214
+
215
+
216
+ if rc_bool:
217
+ pose2d_tracking_duration_s = 75
218
+ iteration_per_step_s = 13
219
+ else:
220
+ pose2d_tracking_duration_s = 50
221
+ iteration_per_step_s = 12
222
+
223
+ time_required = pose2d_tracking_duration_s + iteration_per_step_s * 20 * chunks
224
+ print(f'for frames_count:{frames_count} doing {chunks} chunks the time_required is {time_required}')
225
+ return time_required
226
+
227
+ def update_time_required(input_video, rc_str):
228
+
229
+ if input_video is None:
230
+ return gr.update(value="⌚ Zero GPU Required: --")
231
+
232
+ rc_bool = rc_mapping[rc_str]
233
+
234
+ duration_s = calculate_time_required(input_video, rc_bool)
235
+ duration_m = duration_s / 60
236
+
237
+ return gr.update(value=f"⌚ Zero GPU Required: ~{duration_s}.0s ({duration_m:.1f} mins)")
238
+
239
+ def get_duration(input_video, edited_frame, rc_bool, session_id, progress):
240
+
241
+ return calculate_time_required(input_video, rc_bool)
242
+
243
+
244
+ @spaces.GPU(duration=get_duration)
245
+ def _animate(input_video, edited_frame, rc_bool, session_id = None, progress=gr.Progress(track_tqdm=True),):
246
+
247
+ if session_id is None:
248
+ session_id = uuid.uuid4().hex
249
+
250
+ output_dir = os.path.join(os.environ["PROCESSED_RESULTS"], session_id)
251
+ os.makedirs(output_dir, exist_ok=True)
252
+
253
+ preprocess_dir = os.path.join(output_dir, "preprocess_dir")
254
+ os.makedirs(preprocess_dir, exist_ok=True)
255
+
256
+ output_video_path = os.path.join(output_dir, 'result.mp4')
257
+
258
+ # --- Measure preprocess time ---
259
+ start_preprocess = time.time()
260
+
261
+ # w = 720
262
+ # h = 480
263
+
264
+ # w = 720
265
+ # h = 1280
266
+
267
+ w = 480
268
+ h = 832
269
+
270
+ # w = 480
271
+ # h = 720
272
+
273
+ tag_string = "retarget_flag"
274
+
275
+ if rc_bool:
276
+ tag_string = "replace_flag"
277
+
278
+ sh("python ./wan/modules/animate/preprocess/preprocess_data.py "
279
+ "--ckpt_path ./Wan2.2-Animate-14B/process_checkpoint "
280
+ f"--video_path {input_video} "
281
+ f"--refer_path {edited_frame} "
282
+ f"--save_path {preprocess_dir} "
283
+ f"--resolution_area {w} {h} --{tag_string} "
284
+ )
285
+
286
+ preprocess_time = time.time() - start_preprocess
287
+ print(f"Preprocess took {preprocess_time:.2f} seconds")
288
+
289
+ # --- Measure generate time ---
290
+ start_generate = time.time()
291
+
292
+ generate(wan_animate, preprocess_dir, output_video_path, rc_bool)
293
+
294
+ generate_time = time.time() - start_generate
295
+ print(f"Generate took {generate_time:.2f} seconds")
296
+
297
+ # --- Optional total time ---
298
+ total_time = preprocess_time + generate_time
299
+ print(f"Total time: {total_time:.2f} seconds")
300
+
301
+ return output_video_path
302
+
303
+ def animate_scene(input_video, edited_frame, rc_str, session_id = None, progress=gr.Progress(track_tqdm=True),):
304
+
305
+ if not input_video:
306
+ raise gr.Error("Please provide an video")
307
+
308
+ if not edited_frame:
309
+ raise gr.Error("Please provide an image")
310
+
311
+ if session_id is None:
312
+ session_id = uuid.uuid4().hex
313
+
314
+ rc_bool = rc_mapping[rc_str]
315
+
316
+ output_dir = os.path.join(os.environ["PROCESSED_RESULTS"], session_id)
317
+ os.makedirs(output_dir, exist_ok=True)
318
+
319
+ input_audio_path = os.path.join(output_dir, 'input_audio.wav')
320
+
321
+ extract_audio_from_video_ffmpeg(input_video, input_audio_path)
322
+
323
+ output_video_path = _animate(input_video, edited_frame, rc_bool, session_id, progress)
324
+
325
+ final_video_path = os.path.join(output_dir, 'final_result.mp4')
326
+
327
+ preprocess_dir = os.path.join(output_dir, "preprocess_dir")
328
+ pose_video = os.path.join(preprocess_dir, 'src_pose.mp4')
329
+
330
+ if rc_bool:
331
+ mask_video = os.path.join(preprocess_dir, 'src_mask.mp4')
332
+ bg_video = os.path.join(preprocess_dir, 'src_bg.mp4')
333
+ face_video = os.path.join(preprocess_dir, 'src_face.mp4')
334
+ else:
335
+ mask_video = os.path.join(preprocess_dir, 'src_pose.mp4')
336
+ bg_video = os.path.join(preprocess_dir, 'src_pose.mp4')
337
+ face_video = os.path.join(preprocess_dir, 'src_pose.mp4')
338
+
339
+ combine_video_and_audio_ffmpeg(output_video_path, input_audio_path, final_video_path)
340
+
341
+ return final_video_path, pose_video, bg_video, mask_video, face_video
342
+
343
+ css = """
344
+ #col-container {
345
+ margin: 0 auto;
346
+ max-width: 1600px;
347
+ }
348
+
349
+ #step-column {
350
+ padding: 20px;
351
+ border-radius: 8px;
352
+ box-shadow: var(--card-shadow);
353
+ margin: 10px;
354
+ }
355
+
356
+ #col-showcase {
357
+ margin: 0 auto;
358
+ max-width: 1100px;
359
+ }
360
+
361
+ .button-gradient {
362
+ background: linear-gradient(45deg, rgb(255, 65, 108), rgb(255, 75, 43), rgb(255, 155, 0), rgb(255, 65, 108)) 0% 0% / 400% 400%;
363
+ border: none;
364
+ padding: 14px 28px;
365
+ font-size: 16px;
366
+ font-weight: bold;
367
+ color: white;
368
+ border-radius: 10px;
369
+ cursor: pointer;
370
+ transition: 0.3s ease-in-out;
371
+ animation: 2s linear 0s infinite normal none running gradientAnimation;
372
+ box-shadow: rgba(255, 65, 108, 0.6) 0px 4px 10px;
373
+ }
374
+
375
+ .toggle-container {
376
+ display: inline-flex;
377
+ background-color: #ffd6ff; /* light pink background */
378
+ border-radius: 9999px;
379
+ padding: 4px;
380
+ position: relative;
381
+ width: fit-content;
382
+ font-family: sans-serif;
383
+ }
384
+
385
+ .toggle-container input[type="radio"] {
386
+ display: none;
387
+ }
388
+
389
+ .toggle-container label {
390
+ position: relative;
391
+ z-index: 2;
392
+ flex: 1;
393
+ text-align: center;
394
+ font-weight: 700;
395
+ color: #4b2ab5; /* dark purple text for unselected */
396
+ padding: 6px 22px;
397
+ border-radius: 9999px;
398
+ cursor: pointer;
399
+ transition: color 0.25s ease;
400
+ }
401
+
402
+ /* Moving highlight */
403
+ .toggle-highlight {
404
+ position: absolute;
405
+ top: 4px;
406
+ left: 4px;
407
+ width: calc(50% - 4px);
408
+ height: calc(100% - 8px);
409
+ background-color: #4b2ab5; /* dark purple background */
410
+ border-radius: 9999px;
411
+ transition: transform 0.25s ease;
412
+ z-index: 1;
413
+ }
414
+
415
+ /* When "True" is checked */
416
+ #true:checked ~ label[for="true"] {
417
+ color: #ffd6ff; /* light pink text */
418
+ }
419
+
420
+ /* When "False" is checked */
421
+ #false:checked ~ label[for="false"] {
422
+ color: #ffd6ff; /* light pink text */
423
+ }
424
+
425
+ /* Move highlight to right side when False is checked */
426
+ #false:checked ~ .toggle-highlight {
427
+ transform: translateX(100%);
428
+ }
429
+ """
430
+ def start_session(request: gr.Request):
431
+
432
+ return request.session_hash
433
+
434
+ def cleanup(request: gr.Request):
435
+
436
+ sid = request.session_hash
437
+
438
+ if sid:
439
+ d1 = os.path.join(os.environ["PROCESSED_RESULTS"], sid)
440
+ shutil.rmtree(d1, ignore_errors=True)
441
+
442
+ with gr.Blocks(css=css, title="Wan 2.2 Animate --replace", theme=gr.themes.Ocean()) as demo:
443
+
444
+ session_state = gr.State()
445
+ demo.load(start_session, outputs=[session_state])
446
+
447
+ with gr.Column(elem_id="col-container"):
448
+ with gr.Row():
449
+ gr.HTML(
450
+ """
451
+ <div style="text-align: center;">
452
+ <p style="font-size:16px; display: inline; margin: 0;">
453
+ <strong>Wan2.2-Animate-14B </strong>
454
+ </p>
455
+ <a href="https://huggingface.co/Wan-AI/Wan2.2-Animate-14B" style="display: inline-block; vertical-align: middle; margin-left: 0.5em;">
456
+ [Model]
457
+ </a>
458
+ <div style="text-align: center;">
459
+ <p style="font-size:16px; display: inline; margin: 0;">
460
+ HF Space By:
461
+ </p>
462
+ <a href="https://huggingface.co/alexnasa" style="display: inline-block; vertical-align: middle; margin-left: 0.5em;">
463
+ <img src="https://img.shields.io/badge/🤗-Follow Me-yellow.svg">
464
+ </a>
465
+ </div>
466
+ """
467
+ )
468
+ with gr.Row():
469
+ with gr.Column(elem_id="step-column"):
470
+ gr.HTML("""
471
+ <div>
472
+ <span style="font-size: 24px;">1. Upload a Video</span><br>
473
+ </div>
474
+ """)
475
+ input_video = gr.Video(label="Input Video", height=512)
476
+
477
+
478
+ with gr.Column(elem_id="step-column"):
479
+ gr.HTML("""
480
+ <div>
481
+ <span style="font-size: 24px;">2. Upload a Ref Image</span><br>
482
+ </div>
483
+ """)
484
+ edited_frame = gr.Image(label="Ref Image", type="filepath", height=512)
485
+ gr.HTML("""
486
+ <div>
487
+ <span style="font-size: 24px;">3. Choose Mode</span><br>
488
+ </div>
489
+ """)
490
+ replace_character_string = gr.Radio(
491
+ ["Video → Ref Image", "Video ← Ref Image"], value="Video → Ref Image", show_label=False
492
+ )
493
+
494
+ with gr.Column(elem_id="step-column"):
495
+ gr.HTML("""
496
+ <div>
497
+ <span style="font-size: 24px;">4. Wan Animate it!</span><br>
498
+ </div>
499
+ """)
500
+ output_video = gr.Video(label="Edited Video", height=512)
501
+
502
+ time_required = gr.Text(value="⌚ Zero GPU Required: --", show_label=False)
503
+ action_button = gr.Button("Wan Animate 🦆", variant='primary', elem_classes="button-gradient")
504
+
505
+ with gr.Accordion("Preprocessed Data", open=False, visible=False):
506
+ pose_video = gr.Video(label="Pose Video", height=512)
507
+ bg_video = gr.Video(label="Background Video", height=512)
508
+ face_video = gr.Video(label="Face Video", height=512)
509
+ mask_video = gr.Video(label="Mask Video", height=512)
510
+
511
+ with gr.Row():
512
+ with gr.Column(elem_id="col-showcase"):
513
+
514
+ gr.Examples(
515
+ examples=[
516
+
517
+ [
518
+ "./examples/desi.mp4",
519
+ "./examples/desi.png",
520
+ "Video ← Ref Image"
521
+ ],
522
+
523
+ [
524
+ "./examples/paul.mp4",
525
+ "./examples/man.png",
526
+ "Video → Ref Image"
527
+ ],
528
+
529
+
530
+ ],
531
+ inputs=[input_video, edited_frame, replace_character_string],
532
+ outputs=[output_video, pose_video, bg_video, mask_video, face_video],
533
+ fn=animate_scene,
534
+ cache_examples=True,
535
+ )
536
+
537
+ action_button.click(fn=animate_scene, inputs=[input_video, edited_frame, replace_character_string, session_state], outputs=[output_video, pose_video, bg_video, mask_video, face_video])
538
+
539
+ input_video.upload(preprocess_video, inputs=[input_video, session_state], outputs=[input_video]).then(update_time_required, inputs=[input_video, replace_character_string], outputs=[time_required])
540
+ replace_character_string.change(update_time_required, inputs=[input_video, replace_character_string], outputs=[time_required])
541
+
542
+ if __name__ == "__main__":
543
+ demo.queue()
544
+ demo.unload(cleanup)
545
+ demo.launch(ssr_mode=False, share=True)
546
+
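Editor's note: the ZeroGPU duration request in `app.py` is driven by `calculate_time_required`: the clip is split into 77-frame chunks (`frames_count // 77 + 1`), and each chunk is assumed to cost `iteration_per_step_s` seconds per diffusion step over 20 steps, plus a fixed pose-tracking overhead. The sketch below only restates that arithmetic; the timing constants are the empirical values hard-coded above and may not match your hardware.

```python
# Standalone sketch of app.py's GPU-time estimate (constants copied from above;
# they are empirical numbers for one machine, not guarantees).
def estimate_zero_gpu_seconds(frames_count: int, replace_mode: bool) -> int:
    chunks = frames_count // 77 + 1          # 77-frame clips, last one padded
    tracking_s = 75 if replace_mode else 50  # pose2d tracking overhead
    step_s = 13 if replace_mode else 12      # seconds per sampling step
    return tracking_s + step_s * 20 * chunks  # 20 diffusion steps per chunk

# Example: a 2 s clip at 30 fps (60 frames) in "Video → Ref Image" mode
# gives 1 chunk -> 50 + 12 * 20 * 1 = 290 seconds.
print(estimate_zero_gpu_seconds(60, replace_mode=False))  # 290
```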
examples/desi.mp4 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e02e84151e5625fb3863ebdf65dfab06940afac5fbd471db3b46a4ebd84b248d
+ size 551595
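Editor's note: the example media are committed as Git LFS pointer files rather than raw bytes; the three `version` / `oid` / `size` lines above are the entire on-disk content. The parser below is a hypothetical helper added only to illustrate that format; it is not part of this repo.

```python
# Hypothetical helper: parse a Git LFS pointer file into a dict.
# A pointer is just "key value" pairs such as "oid sha256:<hash>" and "size <bytes>".
def parse_lfs_pointer(text: str) -> dict:
    fields = {}
    for line in text.strip().splitlines():
        key, _, value = line.partition(" ")
        fields[key] = value
    return fields

pointer = """version https://git-lfs.github.com/spec/v1
oid sha256:e02e84151e5625fb3863ebdf65dfab06940afac5fbd471db3b46a4ebd84b248d
size 551595"""
print(parse_lfs_pointer(pointer)["size"])  # 551595 bytes for examples/desi.mp4
```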
examples/desi.png ADDED

Git LFS Details

  • SHA256: 3f1a6ac41049380ddb43dcfb9efe1a0b6c561c4bb4132332fe07a82df263df66
  • Pointer size: 131 Bytes
  • Size of remote file: 477 kB
examples/man.png ADDED

Git LFS Details

  • SHA256: 6dc2c61f01a0290a8478fe3b494cf69ca054b2502b00b0be8c68a42ac544d5b5
  • Pointer size: 132 Bytes
  • Size of remote file: 2.5 MB
examples/paul.mp4 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:fb065c2d24bff8a49955389f94c05c80d39638410dad8082f7e0eb7f2dc5c672
+ size 1029922
generate.py ADDED
@@ -0,0 +1,236 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import argparse
3
+ import logging
4
+ import os
5
+ import sys
6
+ import warnings
7
+ from datetime import datetime
8
+
9
+ warnings.filterwarnings('ignore')
10
+
11
+ import random
12
+
13
+ import torch
14
+ import torch.distributed as dist
15
+ from PIL import Image
16
+
17
+ import wan
18
+ from wan.configs import MAX_AREA_CONFIGS, SIZE_CONFIGS, SUPPORTED_SIZES, WAN_CONFIGS
19
+ from wan.distributed.util import init_distributed_group
20
+ from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
21
+ from wan.utils.utils import merge_video_audio, save_video, str2bool
22
+
23
+
24
+ EXAMPLE_PROMPT = {
25
+ "t2v-A14B": {
26
+ "prompt":
27
+ "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
28
+ },
29
+ "i2v-A14B": {
30
+ "prompt":
31
+ "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside.",
32
+ "image":
33
+ "examples/i2v_input.JPG",
34
+ },
35
+ "ti2v-5B": {
36
+ "prompt":
37
+ "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
38
+ },
39
+ "animate-14B": {
40
+ "prompt": "视频中的人在做动作",
41
+ "video": "",
42
+ "pose": "",
43
+ "mask": "",
44
+ },
45
+ "s2v-14B": {
46
+ "prompt":
47
+ "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside.",
48
+ "image":
49
+ "examples/i2v_input.JPG",
50
+ "audio":
51
+ "examples/talk.wav",
52
+ "tts_prompt_audio":
53
+ "examples/zero_shot_prompt.wav",
54
+ "tts_prompt_text":
55
+ "希望你以后能够做的比我还好呦。",
56
+ "tts_text":
57
+ "收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。"
58
+ },
59
+ }
60
+
61
+
62
+ def _validate_args(args):
63
+ # Basic check
64
+ assert args.ckpt_dir is not None, "Please specify the checkpoint directory."
65
+ assert args.task in WAN_CONFIGS, f"Unsupported task: {args.task}"
66
+ assert args.task in EXAMPLE_PROMPT, f"Unsupported task: {args.task}"
67
+
68
+ if args.prompt is None:
69
+ args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
70
+ if args.image is None and "image" in EXAMPLE_PROMPT[args.task]:
71
+ args.image = EXAMPLE_PROMPT[args.task]["image"]
72
+ if args.audio is None and args.enable_tts is False and "audio" in EXAMPLE_PROMPT[args.task]:
73
+ args.audio = EXAMPLE_PROMPT[args.task]["audio"]
74
+ if (args.tts_prompt_audio is None or args.tts_text is None) and args.enable_tts is True and "audio" in EXAMPLE_PROMPT[args.task]:
75
+ args.tts_prompt_audio = EXAMPLE_PROMPT[args.task]["tts_prompt_audio"]
76
+ args.tts_prompt_text = EXAMPLE_PROMPT[args.task]["tts_prompt_text"]
77
+ args.tts_text = EXAMPLE_PROMPT[args.task]["tts_text"]
78
+
79
+ if args.task == "i2v-A14B":
80
+ assert args.image is not None, "Please specify the image path for i2v."
81
+
82
+ cfg = WAN_CONFIGS[args.task]
83
+
84
+ if args.sample_steps is None:
85
+ args.sample_steps = cfg.sample_steps
86
+
87
+ if args.sample_shift is None:
88
+ args.sample_shift = cfg.sample_shift
89
+
90
+ if args.sample_guide_scale is None:
91
+ args.sample_guide_scale = cfg.sample_guide_scale
92
+
93
+ if args.frame_num is None:
94
+ args.frame_num = cfg.frame_num
95
+
96
+ args.base_seed = args.base_seed if args.base_seed >= 0 else random.randint(
97
+ 0, sys.maxsize)
98
+ # Size check
99
+ if not 's2v' in args.task:
100
+ assert args.size in SUPPORTED_SIZES[
101
+ args.
102
+ task], f"Unsupport size {args.size} for task {args.task}, supported sizes are: {', '.join(SUPPORTED_SIZES[args.task])}"
103
+
104
+
105
+ class _Args:
106
+ pass
107
+
108
+ def _parse_args():
109
+ args = _Args()
110
+
111
+ # core generation options
112
+ args.task = "animate-14B"
113
+ # args.size = "1280*720"
114
+ args.size = "720*1280"
115
+ args.frame_num = None
116
+ args.ckpt_dir = "./Wan2.2-Animate-14B/"
117
+ args.offload_model = True
118
+ args.ulysses_size = 1
119
+ args.t5_fsdp = False
120
+ args.t5_cpu = False
121
+ args.dit_fsdp = False
122
+ args.prompt = None
123
+ args.use_prompt_extend = False
124
+ args.prompt_extend_method = "local_qwen" # ["dashscope", "local_qwen"]
125
+ args.prompt_extend_model = None
126
+ args.prompt_extend_target_lang = "zh" # ["zh", "en"]
127
+ args.base_seed = 0
128
+ args.image = None
129
+ args.sample_solver = "unipc" # ['unipc', 'dpm++']
130
+ args.sample_steps = None
131
+ args.sample_shift = None
132
+ args.sample_guide_scale = None
133
+ args.convert_model_dtype = False
134
+
135
+ # animate
136
+ args.refert_num = 1
137
+
138
+ # s2v-only
139
+ args.num_clip = None
140
+ args.audio = None
141
+ args.enable_tts = False
142
+ args.tts_prompt_audio = None
143
+ args.tts_prompt_text = None
144
+ args.tts_text = None
145
+ args.pose_video = None
146
+ args.start_from_ref = False
147
+ args.infer_frames = 80
148
+
149
+ _validate_args(args)
150
+ return args
151
+
152
+
153
+
154
+ def _init_logging(rank):
155
+ # logging
156
+ if rank == 0:
157
+ # set format
158
+ logging.basicConfig(
159
+ level=logging.INFO,
160
+ format="[%(asctime)s] %(levelname)s: %(message)s",
161
+ handlers=[logging.StreamHandler(stream=sys.stdout)])
162
+ else:
163
+ logging.basicConfig(level=logging.ERROR)
164
+
165
+ def load_model(use_relighting_lora = False):
166
+
167
+ cfg = WAN_CONFIGS["animate-14B"]
168
+
169
+ return wan.WanAnimate(
170
+ config=cfg,
171
+ checkpoint_dir="./Wan2.2-Animate-14B/",
172
+ device_id=0,
173
+ rank=0,
174
+ t5_fsdp=False,
175
+ dit_fsdp=False,
176
+ use_sp=False,
177
+ t5_cpu=False,
178
+ convert_model_dtype=False,
179
+ use_relighting_lora=use_relighting_lora
180
+ )
181
+
182
+ def generate(wan_animate, preprocess_dir, save_file, replace_flag = False):
183
+ args = _parse_args()
184
+ rank = int(os.getenv("RANK", 0))
185
+ world_size = int(os.getenv("WORLD_SIZE", 1))
186
+ local_rank = int(os.getenv("LOCAL_RANK", 0))
187
+ device = local_rank
188
+ _init_logging(rank)
189
+
190
+ cfg = WAN_CONFIGS[args.task]
191
+
192
+ logging.info(f"Input prompt: {args.prompt}")
193
+ img = None
194
+ if args.image is not None:
195
+ img = Image.open(args.image).convert("RGB")
196
+ logging.info(f"Input image: {args.image}")
197
+
198
+ print(f'rank:{rank}')
199
+
200
+
201
+
202
+ logging.info(f"Generating video ...")
203
+ video = wan_animate.generate(
204
+ src_root_path=preprocess_dir,
205
+ replace_flag=replace_flag,
206
+ refert_num = args.refert_num,
207
+ clip_len=args.frame_num,
208
+ shift=args.sample_shift,
209
+ sample_solver=args.sample_solver,
210
+ sampling_steps=args.sample_steps,
211
+ guide_scale=args.sample_guide_scale,
212
+ seed=args.base_seed,
213
+ offload_model=args.offload_model)
214
+ if rank == 0:
215
+
216
+ save_video(
217
+ tensor=video[None],
218
+ save_file=save_file,
219
+ fps=cfg.sample_fps,
220
+ nrow=1,
221
+ normalize=True,
222
+ value_range=(-1, 1))
223
+ # if "s2v" in args.task:
224
+ # if args.enable_tts is False:
225
+ # merge_video_audio(video_path=args.save_file, audio_path=args.audio)
226
+ # else:
227
+ # merge_video_audio(video_path=args.save_file, audio_path="tts.wav")
228
+ del video
229
+
230
+ torch.cuda.synchronize()
231
+ if dist.is_initialized():
232
+ dist.barrier()
233
+ dist.destroy_process_group()
234
+
235
+ logging.info("Finished.")
236
+
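Editor's note: `generate.py` exposes the same two entry points that `app.py` calls: `load_model()` builds a `wan.WanAnimate` pipeline (optionally with the relighting LoRA used for character replacement), and `generate()` runs it over a preprocessed directory. A minimal usage sketch, assuming the checkpoint has already been downloaded to `./Wan2.2-Animate-14B/` and that `preprocess_data.py` has populated the preprocess directory (the paths below are hypothetical):

```python
# Sketch of driving generate.py directly (mirrors what app.py does).
# Assumes ./Wan2.2-Animate-14B/ exists and preprocess_dir was produced by
# wan/modules/animate/preprocess/preprocess_data.py for the chosen mode.
from generate import load_model, generate

wan_animate = load_model(use_relighting_lora=True)  # the LoRA only matters for replace mode
generate(
    wan_animate,
    preprocess_dir="./processed_results/demo/preprocess_dir",  # hypothetical path
    save_file="./processed_results/demo/result.mp4",           # hypothetical path
    replace_flag=True,   # True: put the reference character into the video;
                         # False: animate the reference image with the video's motion
)
```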
pyproject.toml ADDED
@@ -0,0 +1,66 @@
1
+ [build-system]
2
+ requires = ["setuptools>=61.0"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "wan"
7
+ version = "2.2.0"
8
+ description = "Wan: Open and Advanced Large-Scale Video Generative Models"
9
+ authors = [
10
+ { name = "Wan Team", email = "[email protected]" }
11
+ ]
12
+ license = { file = "LICENSE.txt" }
13
+ readme = "README.md"
14
+ requires-python = ">=3.10,<4.0"
15
+ dependencies = [
16
+ "torch>=2.4.0",
17
+ "torchvision>=0.19.0",
18
+ "opencv-python>=4.9.0.80",
19
+ "diffusers>=0.31.0",
20
+ "transformers>=4.49.0",
21
+ "tokenizers>=0.20.3",
22
+ "accelerate>=1.1.1",
23
+ "tqdm",
24
+ "imageio",
25
+ "easydict",
26
+ "ftfy",
27
+ "dashscope",
28
+ "imageio-ffmpeg",
29
+ "flash_attn",
30
+ "numpy>=1.23.5,<2"
31
+ ]
32
+
33
+ [project.optional-dependencies]
34
+ dev = [
35
+ "pytest",
36
+ "black",
37
+ "flake8",
38
+ "isort",
39
+ "mypy",
40
+ "huggingface-hub[cli]"
41
+ ]
42
+
43
+ [project.urls]
44
+ homepage = "https://wanxai.com"
45
+ documentation = "https://github.com/Wan-Video/Wan2.2"
46
+ repository = "https://github.com/Wan-Video/Wan2.2"
47
+ huggingface = "https://huggingface.co/Wan-AI/"
48
+ modelscope = "https://modelscope.cn/organization/Wan-AI"
49
+ discord = "https://discord.gg/p5XbdQV7"
50
+
51
+ [tool.setuptools]
52
+ packages = ["wan"]
53
+
54
+ [tool.setuptools.package-data]
55
+ "wan" = ["**/*.py"]
56
+
57
+ [tool.black]
58
+ line-length = 88
59
+
60
+ [tool.isort]
61
+ profile = "black"
62
+
63
+ [tool.mypy]
64
+ strict = true
65
+
66
+
requirements.txt ADDED
@@ -0,0 +1,31 @@
1
+ torch==2.8.0
2
+ decord
3
+ peft
4
+ pandas
5
+ matplotlib
6
+ loguru
7
+ sentencepiece
8
+ dashscope
9
+ ftfy
10
+ diffusers
11
+ opencv-python
12
+ moviepy
13
+ torchvision
14
+ torchaudio
15
+ transformers
16
+ tokenizers
17
+ accelerate
18
+ tqdm
19
+ imageio[ffmpeg]
20
+ easydict
21
+ imageio-ffmpeg
22
+ numpy>=1.23.5,<2
23
+ hydra-core
24
+ iopath
25
+ pytest
26
+ pillow
27
+ fvcore
28
+ librosa
29
+ flash-attn
30
+ onnxruntime-gpu
31
+ flash-attn-3 @ https://huggingface.co/alexnasa/flash-attn-3/resolve/main/128/flash_attn_3-3.0.0b1-cp39-abi3-linux_x86_64.whl
wan/__init__.py ADDED
@@ -0,0 +1,7 @@
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
+ from . import configs, distributed, modules
+ from .image2video import WanI2V
+ from .speech2video import WanS2V
+ from .text2video import WanT2V
+ from .textimage2video import WanTI2V
+ from .animate import WanAnimate
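Editor's note: the package root re-exports one pipeline class per task, so downstream code can stay at the `wan.<Class>` level. The mapping below is a reading aid showing which class corresponds to which `--task` string used in `generate.py`; it is an assumption drawn from the config names, not an object defined anywhere in the repo.

```python
# Illustrative mapping of generate.py task names to the classes exported by wan/__init__.py
# (a reading aid, not shipped code).
import wan

TASK_TO_PIPELINE = {
    "t2v-A14B": wan.WanT2V,         # text to video
    "i2v-A14B": wan.WanI2V,         # image to video
    "ti2v-5B": wan.WanTI2V,         # text+image to video
    "s2v-14B": wan.WanS2V,          # speech to video
    "animate-14B": wan.WanAnimate,  # motion transfer / character replacement
}
```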
wan/animate.py ADDED
@@ -0,0 +1,653 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import logging
3
+ import math
4
+ import os
5
+ import cv2
6
+ import types
7
+ from copy import deepcopy
8
+ from functools import partial
9
+ from einops import rearrange
10
+ import numpy as np
11
+ import torch
12
+
13
+ import torch.distributed as dist
14
+ from peft import set_peft_model_state_dict
15
+ from decord import VideoReader
16
+ from tqdm import tqdm
17
+ import torch.nn.functional as F
18
+ from .distributed.fsdp import shard_model
19
+ from .distributed.sequence_parallel import sp_attn_forward, sp_dit_forward
20
+ from .distributed.util import get_world_size
21
+
22
+ from .modules.animate import WanAnimateModel
23
+ from .modules.animate import CLIPModel
24
+ from .modules.t5 import T5EncoderModel
25
+ from .modules.vae2_1 import Wan2_1_VAE
26
+ from .modules.animate.animate_utils import TensorList, get_loraconfig
27
+ from .utils.fm_solvers import (
28
+ FlowDPMSolverMultistepScheduler,
29
+ get_sampling_sigmas,
30
+ retrieve_timesteps,
31
+ )
32
+ from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
33
+
34
+
35
+
36
+ class WanAnimate:
37
+
38
+ def __init__(
39
+ self,
40
+ config,
41
+ checkpoint_dir,
42
+ device_id=0,
43
+ rank=0,
44
+ t5_fsdp=False,
45
+ dit_fsdp=False,
46
+ use_sp=False,
47
+ t5_cpu=False,
48
+ init_on_cpu=True,
49
+ convert_model_dtype=False,
50
+ use_relighting_lora=False
51
+ ):
52
+ r"""
53
+ Initializes the generation model components.
54
+
55
+ Args:
56
+ config (EasyDict):
57
+ Object containing model parameters initialized from config.py
58
+ checkpoint_dir (`str`):
59
+ Path to directory containing model checkpoints
60
+ device_id (`int`, *optional*, defaults to 0):
61
+ Id of target GPU device
62
+ rank (`int`, *optional*, defaults to 0):
63
+ Process rank for distributed training
64
+ t5_fsdp (`bool`, *optional*, defaults to False):
65
+ Enable FSDP sharding for T5 model
66
+ dit_fsdp (`bool`, *optional*, defaults to False):
67
+ Enable FSDP sharding for DiT model
68
+ use_sp (`bool`, *optional*, defaults to False):
69
+ Enable distribution strategy of sequence parallel.
70
+ t5_cpu (`bool`, *optional*, defaults to False):
71
+ Whether to place T5 model on CPU. Only works without t5_fsdp.
72
+ init_on_cpu (`bool`, *optional*, defaults to True):
73
+ Enable initializing Transformer Model on CPU. Only works without FSDP or USP.
74
+ convert_model_dtype (`bool`, *optional*, defaults to False):
75
+ Convert DiT model parameters dtype to 'config.param_dtype'.
76
+ Only works without FSDP.
77
+ use_relighting_lora (`bool`, *optional*, defaults to False):
78
+ Whether to use relighting lora for character replacement.
79
+ """
80
+ self.device = torch.device(f"cuda:{device_id}")
81
+ self.config = config
82
+ self.rank = rank
83
+ self.t5_cpu = t5_cpu
84
+ self.init_on_cpu = init_on_cpu
85
+
86
+ self.num_train_timesteps = config.num_train_timesteps
87
+ self.param_dtype = config.param_dtype
88
+
89
+ if t5_fsdp or dit_fsdp or use_sp:
90
+ self.init_on_cpu = False
91
+
92
+ shard_fn = partial(shard_model, device_id=device_id)
93
+ self.text_encoder = T5EncoderModel(
94
+ text_len=config.text_len,
95
+ dtype=config.t5_dtype,
96
+ device=torch.device('cpu'),
97
+ checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint),
98
+ tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
99
+ shard_fn=shard_fn if t5_fsdp else None,
100
+ )
101
+
102
+ self.clip = CLIPModel(
103
+ dtype=torch.float16,
104
+ device=self.device,
105
+ checkpoint_path=os.path.join(checkpoint_dir,
106
+ config.clip_checkpoint),
107
+ tokenizer_path=os.path.join(checkpoint_dir, config.clip_tokenizer))
108
+
109
+ self.vae = Wan2_1_VAE(
110
+ vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
111
+ device=self.device)
112
+
113
+ logging.info(f"Creating WanAnimate from {checkpoint_dir}")
114
+
115
+ if not dit_fsdp:
116
+ self.noise_model = WanAnimateModel.from_pretrained(
117
+ checkpoint_dir,
118
+ torch_dtype=self.param_dtype,
119
+ device_map=self.device)
120
+ else:
121
+ self.noise_model = WanAnimateModel.from_pretrained(
122
+ checkpoint_dir, torch_dtype=self.param_dtype)
123
+
124
+ self.noise_model = self._configure_model(
125
+ model=self.noise_model,
126
+ use_sp=use_sp,
127
+ dit_fsdp=dit_fsdp,
128
+ shard_fn=shard_fn,
129
+ convert_model_dtype=convert_model_dtype,
130
+ use_lora=use_relighting_lora,
131
+ checkpoint_dir=checkpoint_dir,
132
+ config=config
133
+ )
134
+
135
+ if use_sp:
136
+ self.sp_size = get_world_size()
137
+ else:
138
+ self.sp_size = 1
139
+
140
+ self.sample_neg_prompt = config.sample_neg_prompt
141
+ self.sample_prompt = config.prompt
142
+
143
+
144
+ def _configure_model(self, model, use_sp, dit_fsdp, shard_fn,
145
+ convert_model_dtype, use_lora, checkpoint_dir, config):
146
+ """
147
+ Configures a model object. This includes setting evaluation modes,
148
+ applying distributed parallel strategy, and handling device placement.
149
+
150
+ Args:
151
+ model (torch.nn.Module):
152
+ The model instance to configure.
153
+ use_sp (`bool`):
154
+ Enable distribution strategy of sequence parallel.
155
+ dit_fsdp (`bool`):
156
+ Enable FSDP sharding for DiT model.
157
+ shard_fn (callable):
158
+ The function to apply FSDP sharding.
159
+ convert_model_dtype (`bool`):
160
+ Convert DiT model parameters dtype to 'config.param_dtype'.
161
+ Only works without FSDP.
162
+
163
+ Returns:
164
+ torch.nn.Module:
165
+ The configured model.
166
+ """
167
+ model.eval().requires_grad_(False)
168
+
169
+ if use_sp:
170
+ for block in model.blocks:
171
+ block.self_attn.forward = types.MethodType(
172
+ sp_attn_forward, block.self_attn)
173
+
174
+ model.use_context_parallel = True
175
+
176
+ if dist.is_initialized():
177
+ dist.barrier()
178
+
179
+ if use_lora:
180
+ logging.info("Loading Relighting Lora. ")
181
+ lora_config = get_loraconfig(
182
+ transformer=model,
183
+ rank=128,
184
+ alpha=128
185
+ )
186
+ model.add_adapter(lora_config)
187
+ lora_path = os.path.join(checkpoint_dir, config.lora_checkpoint)
188
+ peft_state_dict = torch.load(lora_path)["state_dict"]
189
+ set_peft_model_state_dict(model, peft_state_dict)
190
+
191
+ if dit_fsdp:
192
+ model = shard_fn(model, use_lora=use_lora)
193
+ else:
194
+ if convert_model_dtype:
195
+ model.to(self.param_dtype)
196
+ if not self.init_on_cpu:
197
+ model.to(self.device)
198
+
199
+ return model
200
+
201
+ def inputs_padding(self, array, target_len):
202
+ idx = 0
203
+ flip = False
204
+ target_array = []
205
+ while len(target_array) < target_len:
206
+ target_array.append(deepcopy(array[idx]))
207
+ if flip:
208
+ idx -= 1
209
+ else:
210
+ idx += 1
211
+ if idx == 0 or idx == len(array) - 1:
212
+ flip = not flip
213
+ return target_array[:target_len]
214
+
215
+ def get_valid_len(self, real_len, clip_len=81, overlap=1):
216
+ real_clip_len = clip_len - overlap
217
+ last_clip_num = (real_len - overlap) % real_clip_len
218
+ if last_clip_num == 0:
219
+ extra = 0
220
+ else:
221
+ extra = real_clip_len - last_clip_num
222
+ target_len = real_len + extra
223
+ return target_len
224
+
225
+
226
+ def get_i2v_mask(self, lat_t, lat_h, lat_w, mask_len=1, mask_pixel_values=None, device="cuda"):
227
+ if mask_pixel_values is None:
228
+ msk = torch.zeros(1, (lat_t-1) * 4 + 1, lat_h, lat_w, device=device)
229
+ else:
230
+ msk = mask_pixel_values.clone()
231
+ msk[:, :mask_len] = 1
232
+ msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
233
+ msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
234
+ msk = msk.transpose(1, 2)[0]
235
+ return msk
236
+
237
+ def padding_resize(self, img_ori, height=512, width=512, padding_color=(0, 0, 0), interpolation=cv2.INTER_LINEAR):
238
+ ori_height = img_ori.shape[0]
239
+ ori_width = img_ori.shape[1]
240
+ channel = img_ori.shape[2]
241
+
242
+ img_pad = np.zeros((height, width, channel))
243
+ if channel == 1:
244
+ img_pad[:, :, 0] = padding_color[0]
245
+ else:
246
+ img_pad[:, :, 0] = padding_color[0]
247
+ img_pad[:, :, 1] = padding_color[1]
248
+ img_pad[:, :, 2] = padding_color[2]
249
+
250
+ if (ori_height / ori_width) > (height / width):
251
+ new_width = int(height / ori_height * ori_width)
252
+ img = cv2.resize(img_ori, (new_width, height), interpolation=interpolation)
253
+ padding = int((width - new_width) / 2)
254
+ if len(img.shape) == 2:
255
+ img = img[:, :, np.newaxis]
256
+ img_pad[:, padding: padding + new_width, :] = img
257
+ else:
258
+ new_height = int(width / ori_width * ori_height)
259
+ img = cv2.resize(img_ori, (width, new_height), interpolation=interpolation)
260
+ padding = int((height - new_height) / 2)
261
+ if len(img.shape) == 2:
262
+ img = img[:, :, np.newaxis]
263
+ img_pad[padding: padding + new_height, :, :] = img
264
+
265
+ img_pad = np.uint8(img_pad)
266
+
267
+ return img_pad
268
+
269
+ def prepare_source(self, src_pose_path, src_face_path, src_ref_path):
270
+ pose_video_reader = VideoReader(src_pose_path)
271
+ pose_len = len(pose_video_reader)
272
+ pose_idxs = list(range(pose_len))
273
+ cond_images = pose_video_reader.get_batch(pose_idxs).asnumpy()
274
+
275
+ face_video_reader = VideoReader(src_face_path)
276
+ face_len = len(face_video_reader)
277
+ face_idxs = list(range(face_len))
278
+ face_images = face_video_reader.get_batch(face_idxs).asnumpy()
279
+ height, width = cond_images[0].shape[:2]
280
+ refer_images = cv2.imread(src_ref_path)[..., ::-1]
281
+ refer_images = self.padding_resize(refer_images, height=height, width=width)
282
+ return cond_images, face_images, refer_images
283
+
284
+ def prepare_source_for_replace(self, src_bg_path, src_mask_path):
285
+ bg_video_reader = VideoReader(src_bg_path)
286
+ bg_len = len(bg_video_reader)
287
+ bg_idxs = list(range(bg_len))
288
+ bg_images = bg_video_reader.get_batch(bg_idxs).asnumpy()
289
+
290
+ mask_video_reader = VideoReader(src_mask_path)
291
+ mask_len = len(mask_video_reader)
292
+ mask_idxs = list(range(mask_len))
293
+ mask_images = mask_video_reader.get_batch(mask_idxs).asnumpy()
294
+ mask_images = mask_images[:, :, :, 0] / 255
295
+ return bg_images, mask_images
296
+
297
+ def generate(
298
+ self,
299
+ src_root_path,
300
+ replace_flag=False,
301
+ clip_len=77,
302
+ refert_num=1,
303
+ shift=5.0,
304
+ sample_solver='dpm++',
305
+ sampling_steps=20,
306
+ guide_scale=1,
307
+ input_prompt="",
308
+ n_prompt="",
309
+ seed=-1,
310
+ offload_model=True,
311
+ ):
312
+ r"""
313
+ Generates animated video frames from the preprocessed pose, face, and reference inputs using the diffusion process.
314
+
315
+ Args:
316
+ src_root_path (`str`):
317
+ Path to the preprocessing output directory (containing src_pose.mp4, src_face.mp4 and src_ref.png; plus src_bg.mp4 and src_mask.mp4 when replace_flag is True).
318
+ replace_flag (`bool`, *optional*, defaults to False):
319
+ Whether to run in character replacement mode.
320
+ clip_len (`int`, *optional*, defaults to 77):
321
+ How many frames to generate per clip. The number should be 4n+1.
322
+ refert_num (`int`, *optional*, defaults to 1):
323
+ How many overlapping frames are used for temporal guidance between consecutive clips. Recommended to be 1 or 5.
324
+ shift (`float`, *optional*, defaults to 5.0):
325
+ Noise schedule shift parameter.
326
+ sample_solver (`str`, *optional*, defaults to 'dpm++'):
327
+ Solver used to sample the video.
328
+ sampling_steps (`int`, *optional*, defaults to 20):
329
+ Number of diffusion sampling steps. Higher values improve quality but slow generation
330
+ guide_scale (`float` or tuple[`float`], *optional*, defaults to 1.0):
331
+ Classifier-free guidance scale. We only use it for expression control.
332
+ In most cases, it's not necessary and faster generation can be achieved without it.
333
+ When expression adjustments are needed, you may consider using this feature.
334
+ input_prompt (`str`):
335
+ Text prompt for content generation. We don't recommend custom prompts (although they work)
336
+ n_prompt (`str`, *optional*, defaults to ""):
337
+ Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`
338
+ seed (`int`, *optional*, defaults to -1):
339
+ Random seed for noise generation. If -1, use random seed
340
+ offload_model (`bool`, *optional*, defaults to True):
341
+ If True, offloads models to CPU during generation to save VRAM
342
+
343
+ Returns:
344
+ torch.Tensor:
345
+ Generated video frames tensor. Dimensions: (C, N, H, W) where:
346
+ - C: Color channels (3 for RGB)
347
+ - N: Number of frames
348
+ - H: Frame height
349
+ - W: Frame width
350
+ """
351
+ assert refert_num == 1 or refert_num == 5, "refert_num should be 1 or 5."
352
+
353
+ seed_g = torch.Generator(device=self.device)
354
+ seed_g.manual_seed(seed)
355
+
356
+
357
+ if n_prompt == "":
358
+ n_prompt = self.sample_neg_prompt
359
+
360
+ if input_prompt == "":
361
+ input_prompt = self.sample_prompt
362
+
363
+ src_pose_path = os.path.join(src_root_path, "src_pose.mp4")
364
+ src_face_path = os.path.join(src_root_path, "src_face.mp4")
365
+ src_ref_path = os.path.join(src_root_path, "src_ref.png")
366
+
367
+ cond_images, face_images, refer_images = self.prepare_source(src_pose_path=src_pose_path, src_face_path=src_face_path, src_ref_path=src_ref_path)
368
+
369
+ if not self.t5_cpu:
370
+ self.text_encoder.model.to(self.device)
371
+ context = self.text_encoder([input_prompt], self.device)
372
+ context_null = self.text_encoder([n_prompt], self.device)
373
+ if offload_model:
374
+ self.text_encoder.model.cpu()
375
+ else:
376
+ context = self.text_encoder([input_prompt], torch.device('cpu'))
377
+ context_null = self.text_encoder([n_prompt], torch.device('cpu'))
378
+ context = [t.to(self.device) for t in context]
379
+ context_null = [t.to(self.device) for t in context_null]
380
+
381
+ real_frame_len = len(cond_images)
382
+ target_len = self.get_valid_len(real_frame_len, clip_len, overlap=refert_num)
383
+ logging.info('real frames: {} target frames: {}'.format(real_frame_len, target_len))
384
+ cond_images = self.inputs_padding(cond_images, target_len)
385
+ face_images = self.inputs_padding(face_images, target_len)
386
+
387
+ if replace_flag:
388
+ src_bg_path = os.path.join(src_root_path, "src_bg.mp4")
389
+ src_mask_path = os.path.join(src_root_path, "src_mask.mp4")
390
+ bg_images, mask_images = self.prepare_source_for_replace(src_bg_path, src_mask_path)
391
+ bg_images = self.inputs_padding(bg_images, target_len)
392
+ mask_images = self.inputs_padding(mask_images, target_len)
393
+ self.noise_model.disable_adapters()
394
+ else:
395
+ self.noise_model.disable_adapters()
396
+
397
+
398
+ height, width = refer_images.shape[:2]
399
+ start = 0
400
+ end = clip_len
401
+ all_out_frames = []
402
+ while True:
403
+ if start + refert_num >= len(cond_images):
404
+ break
405
+
406
+ if start == 0:
407
+ mask_reft_len = 0
408
+ else:
409
+ mask_reft_len = refert_num
410
+
411
+ batch = {
412
+ "conditioning_pixel_values": torch.zeros(1, 3, clip_len, height, width),
413
+ "bg_pixel_values": torch.zeros(1, 3, clip_len, height, width),
414
+ "mask_pixel_values": torch.zeros(1, 1, clip_len, height, width),
415
+ "face_pixel_values": torch.zeros(1, 3, clip_len, 512, 512),
416
+ "refer_pixel_values": torch.zeros(1, 3, height, width),
417
+ "refer_t_pixel_values": torch.zeros(refert_num, 3, height, width)
418
+ }
419
+
420
+ batch["conditioning_pixel_values"] = rearrange(
421
+ torch.tensor(np.stack(cond_images[start:end]) / 127.5 - 1),
422
+ "t h w c -> 1 c t h w",
423
+ )
424
+ batch["face_pixel_values"] = rearrange(
425
+ torch.tensor(np.stack(face_images[start:end]) / 127.5 - 1),
426
+ "t h w c -> 1 c t h w",
427
+ )
428
+
429
+ batch["refer_pixel_values"] = rearrange(
430
+ torch.tensor(refer_images / 127.5 - 1), "h w c -> 1 c h w"
431
+ )
432
+
433
+ if start > 0:
434
+ batch["refer_t_pixel_values"] = rearrange(
435
+ out_frames[0, :, -refert_num:].clone().detach(),
436
+ "c t h w -> t c h w",
437
+ )
438
+
439
+ batch["refer_t_pixel_values"] = rearrange(batch["refer_t_pixel_values"],
440
+ "t c h w -> 1 c t h w",
441
+ )
442
+
443
+ if replace_flag:
444
+ batch["bg_pixel_values"] = rearrange(
445
+ torch.tensor(np.stack(bg_images[start:end]) / 127.5 - 1),
446
+ "t h w c -> 1 c t h w",
447
+ )
448
+
449
+ batch["mask_pixel_values"] = rearrange(
450
+ torch.tensor(np.stack(mask_images[start:end])[:, :, :, None]),
451
+ "t h w c -> 1 t c h w",
452
+ )
453
+
454
+
455
+ for key, value in batch.items():
456
+ if isinstance(value, torch.Tensor):
457
+ batch[key] = value.to(device=self.device, dtype=torch.bfloat16)
458
+
459
+ ref_pixel_values = batch["refer_pixel_values"]
460
+ refer_t_pixel_values = batch["refer_t_pixel_values"]
461
+ conditioning_pixel_values = batch["conditioning_pixel_values"]
462
+ face_pixel_values = batch["face_pixel_values"]
463
+
464
+ B, _, H, W = ref_pixel_values.shape
465
+ T = clip_len
466
+ lat_h = H // 8
467
+ lat_w = W // 8
468
+ lat_t = T // 4 + 1
469
+ target_shape = [lat_t + 1, lat_h, lat_w]
470
+ noise = [
471
+ torch.randn(
472
+ 16,
473
+ target_shape[0],
474
+ target_shape[1],
475
+ target_shape[2],
476
+ dtype=torch.float32,
477
+ device=self.device,
478
+ generator=seed_g,
479
+ )
480
+ ]
481
+
482
+ max_seq_len = int(math.ceil(np.prod(target_shape) // 4 / self.sp_size)) * self.sp_size
483
+ if max_seq_len % self.sp_size != 0:
484
+ raise ValueError(f"max_seq_len {max_seq_len} is not divisible by sp_size {self.sp_size}")
485
+
486
+ with (
487
+ torch.autocast(device_type=str(self.device), dtype=torch.bfloat16, enabled=True),
488
+ torch.no_grad()
489
+ ):
490
+ if sample_solver == 'unipc':
491
+ sample_scheduler = FlowUniPCMultistepScheduler(
492
+ num_train_timesteps=self.num_train_timesteps,
493
+ shift=1,
494
+ use_dynamic_shifting=False)
495
+ sample_scheduler.set_timesteps(
496
+ sampling_steps, device=self.device, shift=shift)
497
+ timesteps = sample_scheduler.timesteps
498
+ elif sample_solver == 'dpm++':
499
+ sample_scheduler = FlowDPMSolverMultistepScheduler(
500
+ num_train_timesteps=self.num_train_timesteps,
501
+ shift=1,
502
+ use_dynamic_shifting=False)
503
+ sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
504
+ timesteps, _ = retrieve_timesteps(
505
+ sample_scheduler,
506
+ device=self.device,
507
+ sigmas=sampling_sigmas)
508
+ else:
509
+ raise NotImplementedError("Unsupported solver.")
510
+
511
+ latents = noise
512
+
513
+ pose_latents_no_ref = self.vae.encode(conditioning_pixel_values.to(torch.bfloat16))
514
+ pose_latents_no_ref = torch.stack(pose_latents_no_ref)
515
+ pose_latents = torch.cat([pose_latents_no_ref], dim=2)
516
+
517
+ ref_pixel_values = rearrange(ref_pixel_values, "t c h w -> 1 c t h w")
518
+ ref_latents = self.vae.encode(ref_pixel_values.to(torch.bfloat16))
519
+ ref_latents = torch.stack(ref_latents)
520
+
521
+ mask_ref = self.get_i2v_mask(1, lat_h, lat_w, 1, device=self.device)
522
+ y_ref = torch.concat([mask_ref, ref_latents[0]]).to(dtype=torch.bfloat16, device=self.device)
523
+
524
+ img = ref_pixel_values[0, :, 0]
525
+ clip_context = self.clip.visual([img[:, None, :, :]]).to(dtype=torch.bfloat16, device=self.device)
526
+
527
+ if mask_reft_len > 0:
528
+ if replace_flag:
529
+ bg_pixel_values = batch["bg_pixel_values"]
530
+ y_reft = self.vae.encode(
531
+ [
532
+ torch.concat([refer_t_pixel_values[0, :, :mask_reft_len], bg_pixel_values[0, :, mask_reft_len:]], dim=1).to(self.device)
533
+ ]
534
+ )[0]
535
+ mask_pixel_values = 1 - batch["mask_pixel_values"]
536
+ mask_pixel_values = rearrange(mask_pixel_values, "b t c h w -> (b t) c h w")
537
+ mask_pixel_values = F.interpolate(mask_pixel_values, size=(H//8, W//8), mode='nearest')
538
+ mask_pixel_values = rearrange(mask_pixel_values, "(b t) c h w -> b t c h w", b=1)[:,:,0]
539
+ msk_reft = self.get_i2v_mask(lat_t, lat_h, lat_w, mask_reft_len, mask_pixel_values=mask_pixel_values, device=self.device)
540
+ else:
541
+ y_reft = self.vae.encode(
542
+ [
543
+ torch.concat(
544
+ [
545
+ torch.nn.functional.interpolate(refer_t_pixel_values[0, :, :mask_reft_len].cpu(),
546
+ size=(H, W), mode="bicubic"),
547
+ torch.zeros(3, T - mask_reft_len, H, W),
548
+ ],
549
+ dim=1,
550
+ ).to(self.device)
551
+ ]
552
+ )[0]
553
+ msk_reft = self.get_i2v_mask(lat_t, lat_h, lat_w, mask_reft_len, device=self.device)
554
+ else:
555
+ if replace_flag:
556
+ bg_pixel_values = batch["bg_pixel_values"]
557
+ mask_pixel_values = 1 - batch["mask_pixel_values"]
558
+ mask_pixel_values = rearrange(mask_pixel_values, "b t c h w -> (b t) c h w")
559
+ mask_pixel_values = F.interpolate(mask_pixel_values, size=(H//8, W//8), mode='nearest')
560
+ mask_pixel_values = rearrange(mask_pixel_values, "(b t) c h w -> b t c h w", b=1)[:,:,0]
561
+ y_reft = self.vae.encode(
562
+ [
563
+ torch.concat(
564
+ [
565
+ bg_pixel_values[0],
566
+ ],
567
+ dim=1,
568
+ ).to(self.device)
569
+ ]
570
+ )[0]
571
+ msk_reft = self.get_i2v_mask(lat_t, lat_h, lat_w, mask_reft_len, mask_pixel_values=mask_pixel_values, device=self.device)
572
+ else:
573
+ y_reft = self.vae.encode(
574
+ [
575
+ torch.concat(
576
+ [
577
+ torch.zeros(3, T - mask_reft_len, H, W),
578
+ ],
579
+ dim=1,
580
+ ).to(self.device)
581
+ ]
582
+ )[0]
583
+ msk_reft = self.get_i2v_mask(lat_t, lat_h, lat_w, mask_reft_len, device=self.device)
584
+
585
+ y_reft = torch.concat([msk_reft, y_reft]).to(dtype=torch.bfloat16, device=self.device)
586
+ y = torch.concat([y_ref, y_reft], dim=1)
587
+
588
+ arg_c = {
589
+ "context": context,
590
+ "seq_len": max_seq_len,
591
+ "clip_fea": clip_context.to(dtype=torch.bfloat16, device=self.device),
592
+ "y": [y],
593
+ "pose_latents": pose_latents,
594
+ "face_pixel_values": face_pixel_values,
595
+ }
596
+
597
+ if guide_scale > 1:
598
+ face_pixel_values_uncond = face_pixel_values * 0 - 1
599
+ arg_null = {
600
+ "context": context_null,
601
+ "seq_len": max_seq_len,
602
+ "clip_fea": clip_context.to(dtype=torch.bfloat16, device=self.device),
603
+ "y": [y],
604
+ "pose_latents": pose_latents,
605
+ "face_pixel_values": face_pixel_values_uncond,
606
+ }
607
+
608
+ for i, t in enumerate(tqdm(timesteps)):
609
+ latent_model_input = latents
610
+ timestep = [t]
611
+
612
+ timestep = torch.stack(timestep)
613
+
614
+ noise_pred_cond = TensorList(
615
+ self.noise_model(TensorList(latent_model_input), t=timestep, **arg_c)
616
+ )
617
+
618
+ if guide_scale > 1:
619
+ noise_pred_uncond = TensorList(
620
+ self.noise_model(
621
+ TensorList(latent_model_input), t=timestep, **arg_null
622
+ )
623
+ )
624
+ noise_pred = noise_pred_uncond + guide_scale * (
625
+ noise_pred_cond - noise_pred_uncond
626
+ )
627
+ else:
628
+ noise_pred = noise_pred_cond
629
+
630
+ temp_x0 = sample_scheduler.step(
631
+ noise_pred[0].unsqueeze(0),
632
+ t,
633
+ latents[0].unsqueeze(0),
634
+ return_dict=False,
635
+ generator=seed_g,
636
+ )[0]
637
+ latents[0] = temp_x0.squeeze(0)
638
+
639
+ x0 = latents
640
+
641
+ x0 = [x.to(dtype=torch.float32) for x in x0]
642
+ out_frames = torch.stack(self.vae.decode([x0[0][:, 1:]]))
643
+
644
+ if start != 0:
645
+ out_frames = out_frames[:, :, refert_num:]
646
+
647
+ all_out_frames.append(out_frames.cpu())
648
+
649
+ start += clip_len - refert_num
650
+ end += clip_len - refert_num
651
+
652
+ videos = torch.cat(all_out_frames, dim=2)[:, :, :real_frame_len]
653
+ return videos[0] if self.rank == 0 else None
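
The sliding-window logic above (`get_valid_len`, `inputs_padding`, and the `while` loop in `generate`) is easiest to see with concrete numbers. The sketch below re-implements only the length/window arithmetic so it can be checked without loading any model; the 150-frame input length is an arbitrary assumption.

```python
# Minimal standalone sketch of the clip scheduling used by `generate` above.
def get_valid_len(real_len, clip_len=77, overlap=1):
    real_clip_len = clip_len - overlap
    last = (real_len - overlap) % real_clip_len
    return real_len + (0 if last == 0 else real_clip_len - last)

def clip_windows(target_len, clip_len=77, refert_num=1):
    start, end, windows = 0, clip_len, []
    while start + refert_num < target_len:      # mirrors the break condition in `generate`
        windows.append((start, end))
        start += clip_len - refert_num
        end += clip_len - refert_num
    return windows

real_len = 150                        # assumed: a 5 s pose video at 30 fps
target_len = get_valid_len(real_len)  # 153: inputs_padding ping-pongs frames up to this length
print(clip_windows(target_len))       # [(0, 77), (76, 153)] -- consecutive clips share 1 frame
```

The first clip keeps all 77 decoded frames; every later clip drops its first `refert_num` frames after decoding, so the concatenated output has exactly `target_len` frames and is finally trimmed back to `real_len`.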
wan/configs/__init__.py ADDED
@@ -0,0 +1,50 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import copy
3
+ import os
4
+
5
+ os.environ['TOKENIZERS_PARALLELISM'] = 'false'
6
+
7
+ from .wan_i2v_A14B import i2v_A14B
8
+ from .wan_s2v_14B import s2v_14B
9
+ from .wan_t2v_A14B import t2v_A14B
10
+ from .wan_ti2v_5B import ti2v_5B
11
+ from .wan_animate_14B import animate_14B
12
+
13
+ WAN_CONFIGS = {
14
+ 't2v-A14B': t2v_A14B,
15
+ 'i2v-A14B': i2v_A14B,
16
+ 'ti2v-5B': ti2v_5B,
17
+ 'animate-14B': animate_14B,
18
+ 's2v-14B': s2v_14B,
19
+ }
20
+
21
+ SIZE_CONFIGS = {
22
+ '720*1280': (720, 1280),
23
+ '1280*720': (1280, 720),
24
+ '480*832': (480, 832),
25
+ '832*480': (832, 480),
26
+ '704*1280': (704, 1280),
27
+ '1280*704': (1280, 704),
28
+ '1024*704': (1024, 704),
29
+ '704*1024': (704, 1024),
30
+ }
31
+
32
+ MAX_AREA_CONFIGS = {
33
+ '720*1280': 720 * 1280,
34
+ '1280*720': 1280 * 720,
35
+ '480*832': 480 * 832,
36
+ '832*480': 832 * 480,
37
+ '704*1280': 704 * 1280,
38
+ '1280*704': 1280 * 704,
39
+ '1024*704': 1024 * 704,
40
+ '704*1024': 704 * 1024,
41
+ }
42
+
43
+ SUPPORTED_SIZES = {
44
+ 't2v-A14B': ('720*1280', '1280*720', '480*832', '832*480'),
45
+ 'i2v-A14B': ('720*1280', '1280*720', '480*832', '832*480'),
46
+ 'ti2v-5B': ('704*1280', '1280*704'),
47
+ 's2v-14B': ('720*1280', '1280*720', '480*832', '832*480', '1024*704',
48
+ '704*1024', '704*1280', '1280*704'),
49
+ 'animate-14B': ('720*1280', '1280*720')
50
+ }
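
A hedged example of how these lookup tables might be consulted before launching a run (the task/size choice here is arbitrary, and the import assumes the package and its dependencies are installed):

```python
from wan.configs import WAN_CONFIGS, SIZE_CONFIGS, SUPPORTED_SIZES, MAX_AREA_CONFIGS

task, size = 'animate-14B', '1280*720'
assert size in SUPPORTED_SIZES[task], f'{size} is not supported for {task}'

cfg = WAN_CONFIGS[task]
print(cfg.sample_steps, cfg.sample_guide_scale, cfg.frame_num)  # 20 1.0 77
print(SIZE_CONFIGS[size], MAX_AREA_CONFIGS[size])               # (1280, 720) 921600
```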
wan/configs/shared_config.py ADDED
@@ -0,0 +1,20 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import torch
3
+ from easydict import EasyDict
4
+
5
+ #------------------------ Wan shared config ------------------------#
6
+ wan_shared_cfg = EasyDict()
7
+
8
+ # t5
9
+ wan_shared_cfg.t5_model = 'umt5_xxl'
10
+ wan_shared_cfg.t5_dtype = torch.bfloat16
11
+ wan_shared_cfg.text_len = 512
12
+
13
+ # transformer
14
+ wan_shared_cfg.param_dtype = torch.bfloat16
15
+
16
+ # inference
17
+ wan_shared_cfg.num_train_timesteps = 1000
18
+ wan_shared_cfg.sample_fps = 16
19
+ wan_shared_cfg.sample_neg_prompt = '色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走'
20
+ wan_shared_cfg.frame_num = 81
wan/configs/wan_animate_14B.py ADDED
@@ -0,0 +1,40 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ from easydict import EasyDict
3
+
4
+ from .shared_config import wan_shared_cfg
5
+
6
+ #------------------------ Wan animate 14B ------------------------#
7
+ animate_14B = EasyDict(__name__='Config: Wan animate 14B')
8
+ animate_14B.update(wan_shared_cfg)
9
+
10
+ animate_14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
11
+ animate_14B.t5_tokenizer = 'google/umt5-xxl'
12
+
13
+ animate_14B.clip_checkpoint = 'models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth'
14
+ animate_14B.clip_tokenizer = 'xlm-roberta-large'
15
+ animate_14B.lora_checkpoint = 'relighting_lora.ckpt'
16
+ # vae
17
+ animate_14B.vae_checkpoint = 'Wan2.1_VAE.pth'
18
+ animate_14B.vae_stride = (4, 8, 8)
19
+
20
+ # transformer
21
+ animate_14B.patch_size = (1, 2, 2)
22
+ animate_14B.dim = 5120
23
+ animate_14B.ffn_dim = 13824
24
+ animate_14B.freq_dim = 256
25
+ animate_14B.num_heads = 40
26
+ animate_14B.num_layers = 40
27
+ animate_14B.window_size = (-1, -1)
28
+ animate_14B.qk_norm = True
29
+ animate_14B.cross_attn_norm = True
30
+ animate_14B.eps = 1e-6
31
+ animate_14B.use_face_encoder = True
32
+ animate_14B.motion_encoder_dim = 512
33
+
34
+ # inference
35
+ animate_14B.sample_shift = 5.0
36
+ animate_14B.sample_steps = 20
37
+ animate_14B.sample_guide_scale = 1.0
38
+ animate_14B.frame_num = 77
39
+ animate_14B.sample_fps = 30
40
+ animate_14B.prompt = '视频中的人在做动作'
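
As a quick sanity check (not part of the config itself), these values imply the following per-clip sizes; the 1280*720 frame size is an assumption:

```python
# Quantities derived from the animate-14B config above.
frame_num, sample_fps = 77, 30
vae_stride, patch_size = (4, 8, 8), (1, 2, 2)

lat_t = (frame_num - 1) // vae_stride[0] + 1                 # 20 latent frames per clip
lat_h, lat_w = 720 // vae_stride[1], 1280 // vae_stride[2]   # 90 x 160 latent grid
tokens = lat_t * (lat_h // patch_size[1]) * (lat_w // patch_size[2])
print(lat_t, lat_h, lat_w, tokens)                           # 20 90 160 72000
print(frame_num / sample_fps)                                # ~2.57 s of video per clip
```

`generate` in wan/animate.py prepends one extra latent frame for the reference image, so its per-clip sequence length comes out at 75,600 tokens rather than 72,000.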
wan/configs/wan_i2v_A14B.py ADDED
@@ -0,0 +1,37 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import torch
3
+ from easydict import EasyDict
4
+
5
+ from .shared_config import wan_shared_cfg
6
+
7
+ #------------------------ Wan I2V A14B ------------------------#
8
+
9
+ i2v_A14B = EasyDict(__name__='Config: Wan I2V A14B')
10
+ i2v_A14B.update(wan_shared_cfg)
11
+
12
+ i2v_A14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
13
+ i2v_A14B.t5_tokenizer = 'google/umt5-xxl'
14
+
15
+ # vae
16
+ i2v_A14B.vae_checkpoint = 'Wan2.1_VAE.pth'
17
+ i2v_A14B.vae_stride = (4, 8, 8)
18
+
19
+ # transformer
20
+ i2v_A14B.patch_size = (1, 2, 2)
21
+ i2v_A14B.dim = 5120
22
+ i2v_A14B.ffn_dim = 13824
23
+ i2v_A14B.freq_dim = 256
24
+ i2v_A14B.num_heads = 40
25
+ i2v_A14B.num_layers = 40
26
+ i2v_A14B.window_size = (-1, -1)
27
+ i2v_A14B.qk_norm = True
28
+ i2v_A14B.cross_attn_norm = True
29
+ i2v_A14B.eps = 1e-6
30
+ i2v_A14B.low_noise_checkpoint = 'low_noise_model'
31
+ i2v_A14B.high_noise_checkpoint = 'high_noise_model'
32
+
33
+ # inference
34
+ i2v_A14B.sample_shift = 5.0
35
+ i2v_A14B.sample_steps = 40
36
+ i2v_A14B.boundary = 0.900
37
+ i2v_A14B.sample_guide_scale = (3.5, 3.5) # low noise, high noise
wan/configs/wan_s2v_14B.py ADDED
@@ -0,0 +1,59 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ from easydict import EasyDict
3
+
4
+ from .shared_config import wan_shared_cfg
5
+
6
+ #------------------------ Wan S2V 14B ------------------------#
7
+
8
+ s2v_14B = EasyDict(__name__='Config: Wan S2V 14B')
9
+ s2v_14B.update(wan_shared_cfg)
10
+
11
+ # t5
12
+ s2v_14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
13
+ s2v_14B.t5_tokenizer = 'google/umt5-xxl'
14
+
15
+ # vae
16
+ s2v_14B.vae_checkpoint = 'Wan2.1_VAE.pth'
17
+ s2v_14B.vae_stride = (4, 8, 8)
18
+
19
+ # wav2vec
20
+ s2v_14B.wav2vec = "wav2vec2-large-xlsr-53-english"
21
+
22
+ s2v_14B.num_heads = 40
23
+ # transformer
24
+ s2v_14B.transformer = EasyDict(
25
+ __name__="Config: Transformer config for WanModel_S2V")
26
+ s2v_14B.transformer.patch_size = (1, 2, 2)
27
+ s2v_14B.transformer.dim = 5120
28
+ s2v_14B.transformer.ffn_dim = 13824
29
+ s2v_14B.transformer.freq_dim = 256
30
+ s2v_14B.transformer.num_heads = 40
31
+ s2v_14B.transformer.num_layers = 40
32
+ s2v_14B.transformer.window_size = (-1, -1)
33
+ s2v_14B.transformer.qk_norm = True
34
+ s2v_14B.transformer.cross_attn_norm = True
35
+ s2v_14B.transformer.eps = 1e-6
36
+ s2v_14B.transformer.enable_adain = True
37
+ s2v_14B.transformer.adain_mode = "attn_norm"
38
+ s2v_14B.transformer.audio_inject_layers = [
39
+ 0, 4, 8, 12, 16, 20, 24, 27, 30, 33, 36, 39
40
+ ]
41
+ s2v_14B.transformer.zero_init = True
42
+ s2v_14B.transformer.zero_timestep = True
43
+ s2v_14B.transformer.enable_motioner = False
44
+ s2v_14B.transformer.add_last_motion = True
45
+ s2v_14B.transformer.trainable_token = False
46
+ s2v_14B.transformer.enable_tsm = False
47
+ s2v_14B.transformer.enable_framepack = True
48
+ s2v_14B.transformer.framepack_drop_mode = 'padd'
49
+ s2v_14B.transformer.audio_dim = 1024
50
+
51
+ s2v_14B.transformer.motion_frames = 73
52
+ s2v_14B.transformer.cond_dim = 16
53
+
54
+ # inference
55
+ s2v_14B.sample_neg_prompt = "画面模糊,最差质量,画面模糊,细节模糊不清,情绪激动剧烈,手快速抖动,字幕,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
56
+ s2v_14B.drop_first_motion = True
57
+ s2v_14B.sample_shift = 3
58
+ s2v_14B.sample_steps = 40
59
+ s2v_14B.sample_guide_scale = 4.5
wan/configs/wan_t2v_A14B.py ADDED
@@ -0,0 +1,37 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ from easydict import EasyDict
3
+
4
+ from .shared_config import wan_shared_cfg
5
+
6
+ #------------------------ Wan T2V A14B ------------------------#
7
+
8
+ t2v_A14B = EasyDict(__name__='Config: Wan T2V A14B')
9
+ t2v_A14B.update(wan_shared_cfg)
10
+
11
+ # t5
12
+ t2v_A14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
13
+ t2v_A14B.t5_tokenizer = 'google/umt5-xxl'
14
+
15
+ # vae
16
+ t2v_A14B.vae_checkpoint = 'Wan2.1_VAE.pth'
17
+ t2v_A14B.vae_stride = (4, 8, 8)
18
+
19
+ # transformer
20
+ t2v_A14B.patch_size = (1, 2, 2)
21
+ t2v_A14B.dim = 5120
22
+ t2v_A14B.ffn_dim = 13824
23
+ t2v_A14B.freq_dim = 256
24
+ t2v_A14B.num_heads = 40
25
+ t2v_A14B.num_layers = 40
26
+ t2v_A14B.window_size = (-1, -1)
27
+ t2v_A14B.qk_norm = True
28
+ t2v_A14B.cross_attn_norm = True
29
+ t2v_A14B.eps = 1e-6
30
+ t2v_A14B.low_noise_checkpoint = 'low_noise_model'
31
+ t2v_A14B.high_noise_checkpoint = 'high_noise_model'
32
+
33
+ # inference
34
+ t2v_A14B.sample_shift = 12.0
35
+ t2v_A14B.sample_steps = 40
36
+ t2v_A14B.boundary = 0.875
37
+ t2v_A14B.sample_guide_scale = (3.0, 4.0) # low noise, high noise
wan/configs/wan_ti2v_5B.py ADDED
@@ -0,0 +1,36 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ from easydict import EasyDict
3
+
4
+ from .shared_config import wan_shared_cfg
5
+
6
+ #------------------------ Wan TI2V 5B ------------------------#
7
+
8
+ ti2v_5B = EasyDict(__name__='Config: Wan TI2V 5B')
9
+ ti2v_5B.update(wan_shared_cfg)
10
+
11
+ # t5
12
+ ti2v_5B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
13
+ ti2v_5B.t5_tokenizer = 'google/umt5-xxl'
14
+
15
+ # vae
16
+ ti2v_5B.vae_checkpoint = 'Wan2.2_VAE.pth'
17
+ ti2v_5B.vae_stride = (4, 16, 16)
18
+
19
+ # transformer
20
+ ti2v_5B.patch_size = (1, 2, 2)
21
+ ti2v_5B.dim = 3072
22
+ ti2v_5B.ffn_dim = 14336
23
+ ti2v_5B.freq_dim = 256
24
+ ti2v_5B.num_heads = 24
25
+ ti2v_5B.num_layers = 30
26
+ ti2v_5B.window_size = (-1, -1)
27
+ ti2v_5B.qk_norm = True
28
+ ti2v_5B.cross_attn_norm = True
29
+ ti2v_5B.eps = 1e-6
30
+
31
+ # inference
32
+ ti2v_5B.sample_fps = 24
33
+ ti2v_5B.sample_shift = 5.0
34
+ ti2v_5B.sample_steps = 50
35
+ ti2v_5B.sample_guide_scale = 5.0
36
+ ti2v_5B.frame_num = 121
wan/distributed/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
wan/distributed/fsdp.py ADDED
@@ -0,0 +1,45 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import gc
3
+ from functools import partial
4
+
5
+ import torch
6
+ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
7
+ from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
8
+ from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy
9
+ from torch.distributed.utils import _free_storage
10
+
11
+
12
+ def shard_model(
13
+ model,
14
+ device_id,
15
+ param_dtype=torch.bfloat16,
16
+ reduce_dtype=torch.float32,
17
+ buffer_dtype=torch.float32,
18
+ process_group=None,
19
+ sharding_strategy=ShardingStrategy.FULL_SHARD,
20
+ sync_module_states=True,
21
+ use_lora=False
22
+ ):
23
+ model = FSDP(
24
+ module=model,
25
+ process_group=process_group,
26
+ sharding_strategy=sharding_strategy,
27
+ auto_wrap_policy=partial(
28
+ lambda_auto_wrap_policy, lambda_fn=lambda m: m in model.blocks),
29
+ mixed_precision=MixedPrecision(
30
+ param_dtype=param_dtype,
31
+ reduce_dtype=reduce_dtype,
32
+ buffer_dtype=buffer_dtype),
33
+ device_id=device_id,
34
+ sync_module_states=sync_module_states,
35
+ use_orig_params=True if use_lora else False)
36
+ return model
37
+
38
+
39
+ def free_model(model):
40
+ for m in model.modules():
41
+ if isinstance(m, FSDP):
42
+ _free_storage(m._handle.flat_param.data)
43
+ del model
44
+ gc.collect()
45
+ torch.cuda.empty_cache()
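
A minimal usage sketch for `shard_model`, assuming a `torchrun` launch with NCCL and at least one CUDA device; `ToyDiT` and its `blocks` attribute are hypothetical stand-ins for the real DiT (the auto-wrap policy above turns each entry of `model.blocks` into its own FSDP unit):

```python
# Run with: torchrun --nproc_per_node=<num_gpus> this_script.py   (assumed launcher)
import torch
import torch.nn as nn
import torch.distributed as dist
from wan.distributed.fsdp import shard_model

class ToyDiT(nn.Module):          # hypothetical stand-in with a `blocks` list, like WanModel
    def __init__(self, dim=256, depth=4):
        super().__init__()
        self.blocks = nn.ModuleList(nn.Linear(dim, dim) for _ in range(depth))
    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x

dist.init_process_group(backend='nccl')
device_id = dist.get_rank() % torch.cuda.device_count()
model = shard_model(ToyDiT().eval().requires_grad_(False), device_id=device_id)
out = model(torch.randn(2, 256, device=f'cuda:{device_id}', dtype=torch.bfloat16))
```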
wan/distributed/sequence_parallel.py ADDED
@@ -0,0 +1,176 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import torch
3
+ import torch.cuda.amp as amp
4
+
5
+ from ..modules.model import sinusoidal_embedding_1d
6
+ from .ulysses import distributed_attention
7
+ from .util import gather_forward, get_rank, get_world_size
8
+
9
+
10
+ def pad_freqs(original_tensor, target_len):
11
+ seq_len, s1, s2 = original_tensor.shape
12
+ pad_size = target_len - seq_len
13
+ padding_tensor = torch.ones(
14
+ pad_size,
15
+ s1,
16
+ s2,
17
+ dtype=original_tensor.dtype,
18
+ device=original_tensor.device)
19
+ padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0)
20
+ return padded_tensor
21
+
22
+
23
+ @torch.amp.autocast('cuda', enabled=False)
24
+ def rope_apply(x, grid_sizes, freqs):
25
+ """
26
+ x: [B, L, N, C].
27
+ grid_sizes: [B, 3].
28
+ freqs: [M, C // 2].
29
+ """
30
+ s, n, c = x.size(1), x.size(2), x.size(3) // 2
31
+ # split freqs
32
+ freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
33
+
34
+ # loop over samples
35
+ output = []
36
+ for i, (f, h, w) in enumerate(grid_sizes.tolist()):
37
+ seq_len = f * h * w
38
+
39
+ # precompute multipliers
40
+ x_i = torch.view_as_complex(x[i, :s].to(torch.float64).reshape(
41
+ s, n, -1, 2))
42
+ freqs_i = torch.cat([
43
+ freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
44
+ freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
45
+ freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
46
+ ],
47
+ dim=-1).reshape(seq_len, 1, -1)
48
+
49
+ # apply rotary embedding
50
+ sp_size = get_world_size()
51
+ sp_rank = get_rank()
52
+ freqs_i = pad_freqs(freqs_i, s * sp_size)
53
+ s_per_rank = s
54
+ freqs_i_rank = freqs_i[(sp_rank * s_per_rank):((sp_rank + 1) *
55
+ s_per_rank), :, :]
56
+ x_i = torch.view_as_real(x_i * freqs_i_rank).flatten(2)
57
+ x_i = torch.cat([x_i, x[i, s:]])
58
+
59
+ # append to collection
60
+ output.append(x_i)
61
+ return torch.stack(output).float()
62
+
63
+
64
+ def sp_dit_forward(
65
+ self,
66
+ x,
67
+ t,
68
+ context,
69
+ seq_len,
70
+ y=None,
71
+ ):
72
+ """
73
+ x: A list of videos each with shape [C, T, H, W].
74
+ t: [B].
75
+ context: A list of text embeddings each with shape [L, C].
76
+ """
77
+ if self.model_type == 'i2v':
78
+ assert y is not None
79
+ # params
80
+ device = self.patch_embedding.weight.device
81
+ if self.freqs.device != device:
82
+ self.freqs = self.freqs.to(device)
83
+
84
+ if y is not None:
85
+ x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
86
+
87
+ # embeddings
88
+ x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
89
+ grid_sizes = torch.stack(
90
+ [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
91
+ x = [u.flatten(2).transpose(1, 2) for u in x]
92
+ seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
93
+ assert seq_lens.max() <= seq_len
94
+ x = torch.cat([
95
+ torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1)
96
+ for u in x
97
+ ])
98
+
99
+ # time embeddings
100
+ if t.dim() == 1:
101
+ t = t.expand(t.size(0), seq_len)
102
+ with torch.amp.autocast('cuda', dtype=torch.float32):
103
+ bt = t.size(0)
104
+ t = t.flatten()
105
+ e = self.time_embedding(
106
+ sinusoidal_embedding_1d(self.freq_dim,
107
+ t).unflatten(0, (bt, seq_len)).float())
108
+ e0 = self.time_projection(e).unflatten(2, (6, self.dim))
109
+ assert e.dtype == torch.float32 and e0.dtype == torch.float32
110
+
111
+ # context
112
+ context_lens = None
113
+ context = self.text_embedding(
114
+ torch.stack([
115
+ torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
116
+ for u in context
117
+ ]))
118
+
119
+ # Context Parallel
120
+ x = torch.chunk(x, get_world_size(), dim=1)[get_rank()]
121
+ e = torch.chunk(e, get_world_size(), dim=1)[get_rank()]
122
+ e0 = torch.chunk(e0, get_world_size(), dim=1)[get_rank()]
123
+
124
+ # arguments
125
+ kwargs = dict(
126
+ e=e0,
127
+ seq_lens=seq_lens,
128
+ grid_sizes=grid_sizes,
129
+ freqs=self.freqs,
130
+ context=context,
131
+ context_lens=context_lens)
132
+
133
+ for block in self.blocks:
134
+ x = block(x, **kwargs)
135
+
136
+ # head
137
+ x = self.head(x, e)
138
+
139
+ # Context Parallel
140
+ x = gather_forward(x, dim=1)
141
+
142
+ # unpatchify
143
+ x = self.unpatchify(x, grid_sizes)
144
+ return [u.float() for u in x]
145
+
146
+
147
+ def sp_attn_forward(self, x, seq_lens, grid_sizes, freqs, dtype=torch.bfloat16):
148
+ b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
149
+ half_dtypes = (torch.float16, torch.bfloat16)
150
+
151
+ def half(x):
152
+ return x if x.dtype in half_dtypes else x.to(dtype)
153
+
154
+ # query, key, value function
155
+ def qkv_fn(x):
156
+ q = self.norm_q(self.q(x)).view(b, s, n, d)
157
+ k = self.norm_k(self.k(x)).view(b, s, n, d)
158
+ v = self.v(x).view(b, s, n, d)
159
+ return q, k, v
160
+
161
+ q, k, v = qkv_fn(x)
162
+ q = rope_apply(q, grid_sizes, freqs)
163
+ k = rope_apply(k, grid_sizes, freqs)
164
+
165
+ x = distributed_attention(
166
+ half(q),
167
+ half(k),
168
+ half(v),
169
+ seq_lens,
170
+ window_size=self.window_size,
171
+ )
172
+
173
+ # output
174
+ x = x.flatten(2)
175
+ x = self.o(x)
176
+ return x
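
A small check of `pad_freqs`: the padding rows are filled with ones, so when `rope_apply` multiplies by the padded complex rotary factors, positions beyond the real sequence length are left unrotated. This is a sketch that assumes `pad_freqs` from the file above is in scope; no distributed setup is needed for this particular check.

```python
import torch
# assumes pad_freqs from wan/distributed/sequence_parallel.py is importable or pasted in
seq_len, s1, s2 = 4, 1, 3
theta = torch.rand(seq_len, s1, s2)
freqs = torch.polar(torch.ones(seq_len, s1, s2), theta)  # complex rotary multipliers
padded = pad_freqs(freqs, 6)
print(padded.shape)       # torch.Size([6, 1, 3])
print(padded[seq_len:])   # all (1+0j): identity rotation for the padded slots
```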
wan/distributed/ulysses.py ADDED
@@ -0,0 +1,47 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import torch
3
+ import torch.distributed as dist
4
+
5
+ from ..modules.attention import flash_attention
6
+ from .util import all_to_all
7
+
8
+
9
+ def distributed_attention(
10
+ q,
11
+ k,
12
+ v,
13
+ seq_lens,
14
+ window_size=(-1, -1),
15
+ ):
16
+ """
17
+ Performs distributed attention based on DeepSpeed Ulysses attention mechanism.
18
+ please refer to https://arxiv.org/pdf/2309.14509
19
+
20
+ Args:
21
+ q: [B, Lq // p, Nq, C1].
22
+ k: [B, Lk // p, Nk, C1].
23
+ v: [B, Lk // p, Nk, C2]. Nq must be divisible by Nk.
24
+ seq_lens: [B], length of each sequence in batch
25
+ window_size: (left, right). If not (-1, -1), apply sliding window local attention.
26
+ """
27
+ if not dist.is_initialized():
28
+ raise ValueError("distributed group should be initialized.")
29
+ b = q.shape[0]
30
+
31
+ # gather q/k/v sequence
32
+ q = all_to_all(q, scatter_dim=2, gather_dim=1)
33
+ k = all_to_all(k, scatter_dim=2, gather_dim=1)
34
+ v = all_to_all(v, scatter_dim=2, gather_dim=1)
35
+
36
+ # apply attention
37
+ x = flash_attention(
38
+ q,
39
+ k,
40
+ v,
41
+ k_lens=seq_lens,
42
+ window_size=window_size,
43
+ )
44
+
45
+ # scatter the attention output back along the sequence dim and regather the heads
46
+ x = all_to_all(x, scatter_dim=1, gather_dim=2)
47
+ return x
wan/distributed/util.py ADDED
@@ -0,0 +1,51 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import torch
3
+ import torch.distributed as dist
4
+
5
+
6
+ def init_distributed_group():
7
+ """r initialize sequence parallel group.
8
+ """
9
+ if not dist.is_initialized():
10
+ dist.init_process_group(backend='nccl')
11
+
12
+
13
+ def get_rank():
14
+ return dist.get_rank()
15
+
16
+
17
+ def get_world_size():
18
+ return dist.get_world_size()
19
+
20
+
21
+ def all_to_all(x, scatter_dim, gather_dim, group=None, **kwargs):
22
+ """
23
+ `scatter` along one dimension and `gather` along another.
24
+ """
25
+ world_size = get_world_size()
26
+ if world_size > 1:
27
+ inputs = [u.contiguous() for u in x.chunk(world_size, dim=scatter_dim)]
28
+ outputs = [torch.empty_like(u) for u in inputs]
29
+ dist.all_to_all(outputs, inputs, group=group, **kwargs)
30
+ x = torch.cat(outputs, dim=gather_dim).contiguous()
31
+ return x
32
+
33
+
34
+ def all_gather(tensor):
35
+ world_size = dist.get_world_size()
36
+ if world_size == 1:
37
+ return [tensor]
38
+ tensor_list = [torch.empty_like(tensor) for _ in range(world_size)]
39
+ torch.distributed.all_gather(tensor_list, tensor)
40
+ return tensor_list
41
+
42
+
43
+ def gather_forward(input, dim):
44
+ # skip if world_size == 1
45
+ world_size = dist.get_world_size()
46
+ if world_size == 1:
47
+ return input
48
+
49
+ # gather sequence
50
+ output = all_gather(input)
51
+ return torch.cat(output, dim=dim).contiguous()
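
To see what `all_to_all` does without launching multiple processes, here is a hedged single-process simulation of its chunk/exchange/concat semantics for the Ulysses pattern used by `distributed_attention` above (scatter heads, gather sequence). `simulated_all_to_all` is an illustrative stand-in, not part of the package.

```python
import torch

def simulated_all_to_all(per_rank, scatter_dim, gather_dim):
    """Single-process model of dist.all_to_all: rank r keeps chunk r from every rank."""
    world_size = len(per_rank)
    chunks = [x.chunk(world_size, dim=scatter_dim) for x in per_rank]
    return [torch.cat([chunks[src][r] for src in range(world_size)], dim=gather_dim)
            for r in range(world_size)]

B, L, N, C, p = 1, 8, 4, 16, 2
shards = [torch.randn(B, L // p, N, C) for _ in range(p)]   # each "rank" holds L/p tokens, N heads
out = simulated_all_to_all(shards, scatter_dim=2, gather_dim=1)
print(out[0].shape)   # torch.Size([1, 8, 2, 16]): every token, but only N/p heads per rank
```

The second `all_to_all` in `distributed_attention` (scatter_dim=1, gather_dim=2) is the inverse exchange, restoring the [B, L/p, N, C] layout after attention.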
wan/image2video.py ADDED
@@ -0,0 +1,431 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import gc
3
+ import logging
4
+ import math
5
+ import os
6
+ import random
7
+ import sys
8
+ import types
9
+ from contextlib import contextmanager
10
+ from functools import partial
11
+
12
+ import numpy as np
13
+ import torch
14
+ import torch.cuda.amp as amp
15
+ import torch.distributed as dist
16
+ import torchvision.transforms.functional as TF
17
+ from tqdm import tqdm
18
+
19
+ from .distributed.fsdp import shard_model
20
+ from .distributed.sequence_parallel import sp_attn_forward, sp_dit_forward
21
+ from .distributed.util import get_world_size
22
+ from .modules.model import WanModel
23
+ from .modules.t5 import T5EncoderModel
24
+ from .modules.vae2_1 import Wan2_1_VAE
25
+ from .utils.fm_solvers import (
26
+ FlowDPMSolverMultistepScheduler,
27
+ get_sampling_sigmas,
28
+ retrieve_timesteps,
29
+ )
30
+ from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
31
+
32
+
33
+ class WanI2V:
34
+
35
+ def __init__(
36
+ self,
37
+ config,
38
+ checkpoint_dir,
39
+ device_id=0,
40
+ rank=0,
41
+ t5_fsdp=False,
42
+ dit_fsdp=False,
43
+ use_sp=False,
44
+ t5_cpu=False,
45
+ init_on_cpu=True,
46
+ convert_model_dtype=False,
47
+ ):
48
+ r"""
49
+ Initializes the image-to-video generation model components.
50
+
51
+ Args:
52
+ config (EasyDict):
53
+ Object containing model parameters initialized from config.py
54
+ checkpoint_dir (`str`):
55
+ Path to directory containing model checkpoints
56
+ device_id (`int`, *optional*, defaults to 0):
57
+ Id of target GPU device
58
+ rank (`int`, *optional*, defaults to 0):
59
+ Process rank for distributed training
60
+ t5_fsdp (`bool`, *optional*, defaults to False):
61
+ Enable FSDP sharding for T5 model
62
+ dit_fsdp (`bool`, *optional*, defaults to False):
63
+ Enable FSDP sharding for DiT model
64
+ use_sp (`bool`, *optional*, defaults to False):
65
+ Enable distribution strategy of sequence parallel.
66
+ t5_cpu (`bool`, *optional*, defaults to False):
67
+ Whether to place T5 model on CPU. Only works without t5_fsdp.
68
+ init_on_cpu (`bool`, *optional*, defaults to True):
69
+ Enable initializing Transformer Model on CPU. Only works without FSDP or USP.
70
+ convert_model_dtype (`bool`, *optional*, defaults to False):
71
+ Convert DiT model parameters dtype to 'config.param_dtype'.
72
+ Only works without FSDP.
73
+ """
74
+ self.device = torch.device(f"cuda:{device_id}")
75
+ self.config = config
76
+ self.rank = rank
77
+ self.t5_cpu = t5_cpu
78
+ self.init_on_cpu = init_on_cpu
79
+
80
+ self.num_train_timesteps = config.num_train_timesteps
81
+ self.boundary = config.boundary
82
+ self.param_dtype = config.param_dtype
83
+
84
+ if t5_fsdp or dit_fsdp or use_sp:
85
+ self.init_on_cpu = False
86
+
87
+ shard_fn = partial(shard_model, device_id=device_id)
88
+ self.text_encoder = T5EncoderModel(
89
+ text_len=config.text_len,
90
+ dtype=config.t5_dtype,
91
+ device=torch.device('cpu'),
92
+ checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint),
93
+ tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
94
+ shard_fn=shard_fn if t5_fsdp else None,
95
+ )
96
+
97
+ self.vae_stride = config.vae_stride
98
+ self.patch_size = config.patch_size
99
+ self.vae = Wan2_1_VAE(
100
+ vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
101
+ device=self.device)
102
+
103
+ logging.info(f"Creating WanModel from {checkpoint_dir}")
104
+ self.low_noise_model = WanModel.from_pretrained(
105
+ checkpoint_dir, subfolder=config.low_noise_checkpoint)
106
+ self.low_noise_model = self._configure_model(
107
+ model=self.low_noise_model,
108
+ use_sp=use_sp,
109
+ dit_fsdp=dit_fsdp,
110
+ shard_fn=shard_fn,
111
+ convert_model_dtype=convert_model_dtype)
112
+
113
+ self.high_noise_model = WanModel.from_pretrained(
114
+ checkpoint_dir, subfolder=config.high_noise_checkpoint)
115
+ self.high_noise_model = self._configure_model(
116
+ model=self.high_noise_model,
117
+ use_sp=use_sp,
118
+ dit_fsdp=dit_fsdp,
119
+ shard_fn=shard_fn,
120
+ convert_model_dtype=convert_model_dtype)
121
+ if use_sp:
122
+ self.sp_size = get_world_size()
123
+ else:
124
+ self.sp_size = 1
125
+
126
+ self.sample_neg_prompt = config.sample_neg_prompt
127
+
128
+ def _configure_model(self, model, use_sp, dit_fsdp, shard_fn,
129
+ convert_model_dtype):
130
+ """
131
+ Configures a model object. This includes setting evaluation modes,
132
+ applying distributed parallel strategy, and handling device placement.
133
+
134
+ Args:
135
+ model (torch.nn.Module):
136
+ The model instance to configure.
137
+ use_sp (`bool`):
138
+ Enable distribution strategy of sequence parallel.
139
+ dit_fsdp (`bool`):
140
+ Enable FSDP sharding for DiT model.
141
+ shard_fn (callable):
142
+ The function to apply FSDP sharding.
143
+ convert_model_dtype (`bool`):
144
+ Convert DiT model parameters dtype to 'config.param_dtype'.
145
+ Only works without FSDP.
146
+
147
+ Returns:
148
+ torch.nn.Module:
149
+ The configured model.
150
+ """
151
+ model.eval().requires_grad_(False)
152
+
153
+ if use_sp:
154
+ for block in model.blocks:
155
+ block.self_attn.forward = types.MethodType(
156
+ sp_attn_forward, block.self_attn)
157
+ model.forward = types.MethodType(sp_dit_forward, model)
158
+
159
+ if dist.is_initialized():
160
+ dist.barrier()
161
+
162
+ if dit_fsdp:
163
+ model = shard_fn(model)
164
+ else:
165
+ if convert_model_dtype:
166
+ model.to(self.param_dtype)
167
+ if not self.init_on_cpu:
168
+ model.to(self.device)
169
+
170
+ return model
171
+
172
+ def _prepare_model_for_timestep(self, t, boundary, offload_model):
173
+ r"""
174
+ Prepares and returns the required model for the current timestep.
175
+
176
+ Args:
177
+ t (torch.Tensor):
178
+ current timestep.
179
+ boundary (`int`):
180
+ The timestep threshold. If `t` is at or above this value,
181
+ the `high_noise_model` is considered as the required model.
182
+ offload_model (`bool`):
183
+ A flag intended to control the offloading behavior.
184
+
185
+ Returns:
186
+ torch.nn.Module:
187
+ The active model on the target device for the current timestep.
188
+ """
189
+ if t.item() >= boundary:
190
+ required_model_name = 'high_noise_model'
191
+ offload_model_name = 'low_noise_model'
192
+ else:
193
+ required_model_name = 'low_noise_model'
194
+ offload_model_name = 'high_noise_model'
195
+ if offload_model or self.init_on_cpu:
196
+ if next(getattr(
197
+ self,
198
+ offload_model_name).parameters()).device.type == 'cuda':
199
+ getattr(self, offload_model_name).to('cpu')
200
+ if next(getattr(
201
+ self,
202
+ required_model_name).parameters()).device.type == 'cpu':
203
+ getattr(self, required_model_name).to(self.device)
204
+ return getattr(self, required_model_name)
205
+
206
+ def generate(self,
207
+ input_prompt,
208
+ img,
209
+ max_area=720 * 1280,
210
+ frame_num=81,
211
+ shift=5.0,
212
+ sample_solver='unipc',
213
+ sampling_steps=40,
214
+ guide_scale=5.0,
215
+ n_prompt="",
216
+ seed=-1,
217
+ offload_model=True):
218
+ r"""
219
+ Generates video frames from input image and text prompt using diffusion process.
220
+
221
+ Args:
222
+ input_prompt (`str`):
223
+ Text prompt for content generation.
224
+ img (PIL.Image.Image):
225
+ Input image; converted internally to a [3, H, W] tensor.
226
+ max_area (`int`, *optional*, defaults to 720*1280):
227
+ Maximum pixel area for latent space calculation. Controls video resolution scaling
228
+ frame_num (`int`, *optional*, defaults to 81):
229
+ How many frames to sample from a video. The number should be 4n+1
230
+ shift (`float`, *optional*, defaults to 5.0):
231
+ Noise schedule shift parameter. Affects temporal dynamics
232
+ [NOTE]: If you want to generate a 480p video, it is recommended to set the shift value to 3.0.
233
+ sample_solver (`str`, *optional*, defaults to 'unipc'):
234
+ Solver used to sample the video.
235
+ sampling_steps (`int`, *optional*, defaults to 40):
236
+ Number of diffusion sampling steps. Higher values improve quality but slow generation
237
+ guide_scale (`float` or tuple[`float`], *optional*, defaults to 5.0):
238
+ Classifier-free guidance scale. Controls prompt adherence vs. creativity.
239
+ If tuple, the first guide_scale will be used for low noise model and
240
+ the second guide_scale will be used for high noise model.
241
+ n_prompt (`str`, *optional*, defaults to ""):
242
+ Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`
243
+ seed (`int`, *optional*, defaults to -1):
244
+ Random seed for noise generation. If -1, use random seed
245
+ offload_model (`bool`, *optional*, defaults to True):
246
+ If True, offloads models to CPU during generation to save VRAM
247
+
248
+ Returns:
249
+ torch.Tensor:
250
+ Generated video frames tensor. Dimensions: (C, N, H, W) where:
251
+ - C: Color channels (3 for RGB)
252
+ - N: Number of frames (81)
253
+ - H: Frame height (from max_area)
254
+ - W: Frame width (from max_area)
255
+ """
256
+ # preprocess
257
+ guide_scale = (guide_scale, guide_scale) if isinstance(
258
+ guide_scale, float) else guide_scale
259
+ img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(self.device)
260
+
261
+ F = frame_num
262
+ h, w = img.shape[1:]
263
+ aspect_ratio = h / w
264
+ lat_h = round(
265
+ np.sqrt(max_area * aspect_ratio) // self.vae_stride[1] //
266
+ self.patch_size[1] * self.patch_size[1])
267
+ lat_w = round(
268
+ np.sqrt(max_area / aspect_ratio) // self.vae_stride[2] //
269
+ self.patch_size[2] * self.patch_size[2])
270
+ h = lat_h * self.vae_stride[1]
271
+ w = lat_w * self.vae_stride[2]
272
+
273
+ max_seq_len = ((F - 1) // self.vae_stride[0] + 1) * lat_h * lat_w // (
274
+ self.patch_size[1] * self.patch_size[2])
275
+ max_seq_len = int(math.ceil(max_seq_len / self.sp_size)) * self.sp_size
276
+
277
+ seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
278
+ seed_g = torch.Generator(device=self.device)
279
+ seed_g.manual_seed(seed)
280
+ noise = torch.randn(
281
+ 16,
282
+ (F - 1) // self.vae_stride[0] + 1,
283
+ lat_h,
284
+ lat_w,
285
+ dtype=torch.float32,
286
+ generator=seed_g,
287
+ device=self.device)
288
+
289
+ msk = torch.ones(1, F, lat_h, lat_w, device=self.device)
290
+ msk[:, 1:] = 0
291
+ msk = torch.concat([
292
+ torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]
293
+ ],
294
+ dim=1)
295
+ msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
296
+ msk = msk.transpose(1, 2)[0]
297
+
298
+ if n_prompt == "":
299
+ n_prompt = self.sample_neg_prompt
300
+
301
+ # preprocess
302
+ if not self.t5_cpu:
303
+ self.text_encoder.model.to(self.device)
304
+ context = self.text_encoder([input_prompt], self.device)
305
+ context_null = self.text_encoder([n_prompt], self.device)
306
+ if offload_model:
307
+ self.text_encoder.model.cpu()
308
+ else:
309
+ context = self.text_encoder([input_prompt], torch.device('cpu'))
310
+ context_null = self.text_encoder([n_prompt], torch.device('cpu'))
311
+ context = [t.to(self.device) for t in context]
312
+ context_null = [t.to(self.device) for t in context_null]
313
+
314
+ y = self.vae.encode([
315
+ torch.concat([
316
+ torch.nn.functional.interpolate(
317
+ img[None].cpu(), size=(h, w), mode='bicubic').transpose(
318
+ 0, 1),
319
+ torch.zeros(3, F - 1, h, w)
320
+ ],
321
+ dim=1).to(self.device)
322
+ ])[0]
323
+ y = torch.concat([msk, y])
324
+
325
+ @contextmanager
326
+ def noop_no_sync():
327
+ yield
328
+
329
+ no_sync_low_noise = getattr(self.low_noise_model, 'no_sync',
330
+ noop_no_sync)
331
+ no_sync_high_noise = getattr(self.high_noise_model, 'no_sync',
332
+ noop_no_sync)
333
+
334
+ # evaluation mode
335
+ with (
336
+ torch.amp.autocast('cuda', dtype=self.param_dtype),
337
+ torch.no_grad(),
338
+ no_sync_low_noise(),
339
+ no_sync_high_noise(),
340
+ ):
341
+ boundary = self.boundary * self.num_train_timesteps
342
+
343
+ if sample_solver == 'unipc':
344
+ sample_scheduler = FlowUniPCMultistepScheduler(
345
+ num_train_timesteps=self.num_train_timesteps,
346
+ shift=1,
347
+ use_dynamic_shifting=False)
348
+ sample_scheduler.set_timesteps(
349
+ sampling_steps, device=self.device, shift=shift)
350
+ timesteps = sample_scheduler.timesteps
351
+ elif sample_solver == 'dpm++':
352
+ sample_scheduler = FlowDPMSolverMultistepScheduler(
353
+ num_train_timesteps=self.num_train_timesteps,
354
+ shift=1,
355
+ use_dynamic_shifting=False)
356
+ sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
357
+ timesteps, _ = retrieve_timesteps(
358
+ sample_scheduler,
359
+ device=self.device,
360
+ sigmas=sampling_sigmas)
361
+ else:
362
+ raise NotImplementedError("Unsupported solver.")
363
+
364
+ # sample videos
365
+ latent = noise
366
+
367
+ arg_c = {
368
+ 'context': [context[0]],
369
+ 'seq_len': max_seq_len,
370
+ 'y': [y],
371
+ }
372
+
373
+ arg_null = {
374
+ 'context': context_null,
375
+ 'seq_len': max_seq_len,
376
+ 'y': [y],
377
+ }
378
+
379
+ if offload_model:
380
+ torch.cuda.empty_cache()
381
+
382
+ for _, t in enumerate(tqdm(timesteps)):
383
+ latent_model_input = [latent.to(self.device)]
384
+ timestep = [t]
385
+
386
+ timestep = torch.stack(timestep).to(self.device)
387
+
388
+ model = self._prepare_model_for_timestep(
389
+ t, boundary, offload_model)
390
+ sample_guide_scale = guide_scale[1] if t.item(
391
+ ) >= boundary else guide_scale[0]
392
+
393
+ noise_pred_cond = model(
394
+ latent_model_input, t=timestep, **arg_c)[0]
395
+ if offload_model:
396
+ torch.cuda.empty_cache()
397
+ noise_pred_uncond = model(
398
+ latent_model_input, t=timestep, **arg_null)[0]
399
+ if offload_model:
400
+ torch.cuda.empty_cache()
401
+ noise_pred = noise_pred_uncond + sample_guide_scale * (
402
+ noise_pred_cond - noise_pred_uncond)
403
+
404
+ temp_x0 = sample_scheduler.step(
405
+ noise_pred.unsqueeze(0),
406
+ t,
407
+ latent.unsqueeze(0),
408
+ return_dict=False,
409
+ generator=seed_g)[0]
410
+ latent = temp_x0.squeeze(0)
411
+
412
+ x0 = [latent]
413
+ del latent_model_input, timestep
414
+
415
+ if offload_model:
416
+ self.low_noise_model.cpu()
417
+ self.high_noise_model.cpu()
418
+ torch.cuda.empty_cache()
419
+
420
+ if self.rank == 0:
421
+ videos = self.vae.decode(x0)
422
+
423
+ del noise, latent, x0
424
+ del sample_scheduler
425
+ if offload_model:
426
+ gc.collect()
427
+ torch.cuda.synchronize()
428
+ if dist.is_initialized():
429
+ dist.barrier()
430
+
431
+ return videos[0] if self.rank == 0 else None
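
The latent-resolution arithmetic at the top of `generate` is worth a worked example. With the A14B defaults above (`vae_stride=(4, 8, 8)`, `patch_size=(1, 2, 2)`), an assumed 720x1280 input at `max_area=720*1280` and `frame_num=81` gives:

```python
import numpy as np

vae_stride, patch_size = (4, 8, 8), (1, 2, 2)
max_area, frame_num = 720 * 1280, 81
h, w = 720, 1280                      # assumed input image size
aspect_ratio = h / w

# mirrors the computation in `generate` (sp_size=1, so the ceil-to-multiple step is a no-op)
lat_h = int(round(np.sqrt(max_area * aspect_ratio) // vae_stride[1] //
                  patch_size[1] * patch_size[1]))
lat_w = int(round(np.sqrt(max_area / aspect_ratio) // vae_stride[2] //
                  patch_size[2] * patch_size[2]))
max_seq_len = ((frame_num - 1) // vae_stride[0] + 1) * lat_h * lat_w // (
    patch_size[1] * patch_size[2])
print(lat_h, lat_w, max_seq_len)      # 90 160 75600
```

So the DiT sees 75,600 tokens per video, and the pixels actually generated are `h = lat_h * 8 = 720` by `w = lat_w * 8 = 1280`.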
wan/modules/__init__.py ADDED
@@ -0,0 +1,19 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ from .attention import flash_attention
3
+ from .model import WanModel
4
+ from .t5 import T5Decoder, T5Encoder, T5EncoderModel, T5Model
5
+ from .tokenizers import HuggingfaceTokenizer
6
+ from .vae2_1 import Wan2_1_VAE
7
+ from .vae2_2 import Wan2_2_VAE
8
+
9
+ __all__ = [
10
+ 'Wan2_1_VAE',
11
+ 'Wan2_2_VAE',
12
+ 'WanModel',
13
+ 'T5Model',
14
+ 'T5Encoder',
15
+ 'T5Decoder',
16
+ 'T5EncoderModel',
17
+ 'HuggingfaceTokenizer',
18
+ 'flash_attention',
19
+ ]
wan/modules/animate/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ from .model_animate import WanAnimateModel
3
+ from .clip import CLIPModel
4
+ __all__ = ['WanAnimateModel', 'CLIPModel']
wan/modules/animate/animate_utils.py ADDED
@@ -0,0 +1,143 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import torch
3
+ import numbers
4
+ from peft import LoraConfig
5
+
6
+
7
+ def get_loraconfig(transformer, rank=128, alpha=128, init_lora_weights="gaussian"):
8
+ target_modules = []
9
+ for name, module in transformer.named_modules():
10
+ if "blocks" in name and "face" not in name and "modulation" not in name and isinstance(module, torch.nn.Linear):
11
+ target_modules.append(name)
12
+
13
+ transformer_lora_config = LoraConfig(
14
+ r=rank,
15
+ lora_alpha=alpha,
16
+ init_lora_weights=init_lora_weights,
17
+ target_modules=target_modules,
18
+ )
19
+ return transformer_lora_config
20
+
21
+
22
+
23
+ class TensorList(object):
24
+
25
+ def __init__(self, tensors):
26
+ """
27
+ tensors: a list of torch.Tensor objects. No need to have uniform shape.
28
+ """
29
+ assert isinstance(tensors, (list, tuple))
30
+ assert all(isinstance(u, torch.Tensor) for u in tensors)
31
+ assert len(set([u.ndim for u in tensors])) == 1
32
+ assert len(set([u.dtype for u in tensors])) == 1
33
+ assert len(set([u.device for u in tensors])) == 1
34
+ self.tensors = tensors
35
+
36
+ def to(self, *args, **kwargs):
37
+ return TensorList([u.to(*args, **kwargs) for u in self.tensors])
38
+
39
+ def size(self, dim):
40
+ assert dim == 0, 'only support get the 0th size'
41
+ return len(self.tensors)
42
+
43
+ def pow(self, *args, **kwargs):
44
+ return TensorList([u.pow(*args, **kwargs) for u in self.tensors])
45
+
46
+ def squeeze(self, dim):
47
+ assert dim != 0
48
+ if dim > 0:
49
+ dim -= 1
50
+ return TensorList([u.squeeze(dim) for u in self.tensors])
51
+
52
+ def type(self, *args, **kwargs):
53
+ return TensorList([u.type(*args, **kwargs) for u in self.tensors])
54
+
55
+ def type_as(self, other):
56
+ assert isinstance(other, (torch.Tensor, TensorList))
57
+ if isinstance(other, torch.Tensor):
58
+ return TensorList([u.type_as(other) for u in self.tensors])
59
+ else:
60
+ return TensorList([u.type(other.dtype) for u in self.tensors])
61
+
62
+ @property
63
+ def dtype(self):
64
+ return self.tensors[0].dtype
65
+
66
+ @property
67
+ def device(self):
68
+ return self.tensors[0].device
69
+
70
+ @property
71
+ def ndim(self):
72
+ return 1 + self.tensors[0].ndim
73
+
74
+ def __getitem__(self, index):
75
+ return self.tensors[index]
76
+
77
+ def __len__(self):
78
+ return len(self.tensors)
79
+
80
+ def __add__(self, other):
81
+ return self._apply(other, lambda u, v: u + v)
82
+
83
+ def __radd__(self, other):
84
+ return self._apply(other, lambda u, v: v + u)
85
+
86
+ def __sub__(self, other):
87
+ return self._apply(other, lambda u, v: u - v)
88
+
89
+ def __rsub__(self, other):
90
+ return self._apply(other, lambda u, v: v - u)
91
+
92
+ def __mul__(self, other):
93
+ return self._apply(other, lambda u, v: u * v)
94
+
95
+ def __rmul__(self, other):
96
+ return self._apply(other, lambda u, v: v * u)
97
+
98
+ def __floordiv__(self, other):
99
+ return self._apply(other, lambda u, v: u // v)
100
+
101
+ def __truediv__(self, other):
102
+ return self._apply(other, lambda u, v: u / v)
103
+
104
+ def __rfloordiv__(self, other):
105
+ return self._apply(other, lambda u, v: v // u)
106
+
107
+ def __rtruediv__(self, other):
108
+ return self._apply(other, lambda u, v: v / u)
109
+
110
+ def __pow__(self, other):
111
+ return self._apply(other, lambda u, v: u ** v)
112
+
113
+ def __rpow__(self, other):
114
+ return self._apply(other, lambda u, v: v ** u)
115
+
116
+ def __neg__(self):
117
+ return TensorList([-u for u in self.tensors])
118
+
119
+ def __iter__(self):
120
+ for tensor in self.tensors:
121
+ yield tensor
122
+
123
+ def __repr__(self):
124
+ return 'TensorList: \n' + repr(self.tensors)
125
+
126
+ def _apply(self, other, op):
127
+ if isinstance(other, (list, tuple, TensorList)) or (
128
+ isinstance(other, torch.Tensor) and (
129
+ other.numel() > 1 or other.ndim > 1
130
+ )
131
+ ):
132
+ assert len(other) == len(self.tensors)
133
+ return TensorList([op(u, v) for u, v in zip(self.tensors, other)])
134
+ elif isinstance(other, numbers.Number) or (
135
+ isinstance(other, torch.Tensor) and (
136
+ other.numel() == 1 and other.ndim <= 1
137
+ )
138
+ ):
139
+ return TensorList([op(u, other) for u in self.tensors])
140
+ else:
141
+ raise TypeError(
142
+ f'unsupported operand for *: "TensorList" and "{type(other)}"'
143
+ )
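
A short usage sketch for `TensorList` (shapes are arbitrary assumptions). Scalar arithmetic broadcasts over every tensor in the list, and same-length sequences combine pairwise, which is what lets the CFG line `noise_pred_uncond + guide_scale * (noise_pred_cond - noise_pred_uncond)` in wan/animate.py operate on whole lists of latents:

```python
import torch
# assumes TensorList from wan/modules/animate/animate_utils.py is importable
a = torch.randn(16, 5, 32, 32)
b = torch.randn(16, 9, 32, 32)          # different shapes are fine; ndim/dtype/device must match
x = TensorList([a, b])

y = 2.0 * x + 1.0                        # scalar ops broadcast over the list
print(len(y), y[0].shape, y[1].shape)    # 2 torch.Size([16, 5, 32, 32]) torch.Size([16, 9, 32, 32])

z = y - TensorList([a, b])               # same-length lists are combined elementwise
print(torch.allclose(z[0], a + 1.0))     # True
```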
wan/modules/animate/clip.py ADDED
@@ -0,0 +1,542 @@
1
+ # Modified from ``https://github.com/openai/CLIP'' and ``https://github.com/mlfoundations/open_clip''
2
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
3
+ import logging
4
+ import math
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ import torchvision.transforms as T
10
+
11
+ from ..attention import flash_attention
12
+ from ..tokenizers import HuggingfaceTokenizer
13
+ from .xlm_roberta import XLMRoberta
14
+
15
+ __all__ = [
16
+ 'XLMRobertaCLIP',
17
+ 'clip_xlm_roberta_vit_h_14',
18
+ 'CLIPModel',
19
+ ]
20
+
21
+
22
+ def pos_interpolate(pos, seq_len):
23
+ if pos.size(1) == seq_len:
24
+ return pos
25
+ else:
26
+ src_grid = int(math.sqrt(pos.size(1)))
27
+ tar_grid = int(math.sqrt(seq_len))
28
+ n = pos.size(1) - src_grid * src_grid
29
+ return torch.cat([
30
+ pos[:, :n],
31
+ F.interpolate(
32
+ pos[:, n:].float().reshape(1, src_grid, src_grid, -1).permute(
33
+ 0, 3, 1, 2),
34
+ size=(tar_grid, tar_grid),
35
+ mode='bicubic',
36
+ align_corners=False).flatten(2).transpose(1, 2)
37
+ ],
38
+ dim=1)
39
+
40
+
41
+ class QuickGELU(nn.Module):
42
+
43
+ def forward(self, x):
44
+ return x * torch.sigmoid(1.702 * x)
45
+
46
+
47
+ class LayerNorm(nn.LayerNorm):
48
+
49
+ def forward(self, x):
50
+ return super().forward(x.float()).type_as(x)
51
+
52
+
53
+ class SelfAttention(nn.Module):
54
+
55
+ def __init__(self,
56
+ dim,
57
+ num_heads,
58
+ causal=False,
59
+ attn_dropout=0.0,
60
+ proj_dropout=0.0):
61
+ assert dim % num_heads == 0
62
+ super().__init__()
63
+ self.dim = dim
64
+ self.num_heads = num_heads
65
+ self.head_dim = dim // num_heads
66
+ self.causal = causal
67
+ self.attn_dropout = attn_dropout
68
+ self.proj_dropout = proj_dropout
69
+
70
+ # layers
71
+ self.to_qkv = nn.Linear(dim, dim * 3)
72
+ self.proj = nn.Linear(dim, dim)
73
+
74
+ def forward(self, x):
75
+ """
76
+ x: [B, L, C].
77
+ """
78
+ b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
79
+
80
+ # compute query, key, value
81
+ q, k, v = self.to_qkv(x).view(b, s, 3, n, d).unbind(2)
82
+
83
+ # compute attention
84
+ p = self.attn_dropout if self.training else 0.0
85
+ x = flash_attention(q, k, v, dropout_p=p, causal=self.causal, version=2)
86
+ x = x.reshape(b, s, c)
87
+
88
+ # output
89
+ x = self.proj(x)
90
+ x = F.dropout(x, self.proj_dropout, self.training)
91
+ return x
92
+
93
+
94
+ class SwiGLU(nn.Module):
95
+
96
+ def __init__(self, dim, mid_dim):
97
+ super().__init__()
98
+ self.dim = dim
99
+ self.mid_dim = mid_dim
100
+
101
+ # layers
102
+ self.fc1 = nn.Linear(dim, mid_dim)
103
+ self.fc2 = nn.Linear(dim, mid_dim)
104
+ self.fc3 = nn.Linear(mid_dim, dim)
105
+
106
+ def forward(self, x):
107
+ x = F.silu(self.fc1(x)) * self.fc2(x)
108
+ x = self.fc3(x)
109
+ return x
110
+
111
+
112
+ class AttentionBlock(nn.Module):
113
+
114
+ def __init__(self,
115
+ dim,
116
+ mlp_ratio,
117
+ num_heads,
118
+ post_norm=False,
119
+ causal=False,
120
+ activation='quick_gelu',
121
+ attn_dropout=0.0,
122
+ proj_dropout=0.0,
123
+ norm_eps=1e-5):
124
+ assert activation in ['quick_gelu', 'gelu', 'swi_glu']
125
+ super().__init__()
126
+ self.dim = dim
127
+ self.mlp_ratio = mlp_ratio
128
+ self.num_heads = num_heads
129
+ self.post_norm = post_norm
130
+ self.causal = causal
131
+ self.norm_eps = norm_eps
132
+
133
+ # layers
134
+ self.norm1 = LayerNorm(dim, eps=norm_eps)
135
+ self.attn = SelfAttention(dim, num_heads, causal, attn_dropout,
136
+ proj_dropout)
137
+ self.norm2 = LayerNorm(dim, eps=norm_eps)
138
+ if activation == 'swi_glu':
139
+ self.mlp = SwiGLU(dim, int(dim * mlp_ratio))
140
+ else:
141
+ self.mlp = nn.Sequential(
142
+ nn.Linear(dim, int(dim * mlp_ratio)),
143
+ QuickGELU() if activation == 'quick_gelu' else nn.GELU(),
144
+ nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout))
145
+
146
+ def forward(self, x):
147
+ if self.post_norm:
148
+ x = x + self.norm1(self.attn(x))
149
+ x = x + self.norm2(self.mlp(x))
150
+ else:
151
+ x = x + self.attn(self.norm1(x))
152
+ x = x + self.mlp(self.norm2(x))
153
+ return x
154
+
155
+
156
+ class AttentionPool(nn.Module):
157
+
158
+ def __init__(self,
159
+ dim,
160
+ mlp_ratio,
161
+ num_heads,
162
+ activation='gelu',
163
+ proj_dropout=0.0,
164
+ norm_eps=1e-5):
165
+ assert dim % num_heads == 0
166
+ super().__init__()
167
+ self.dim = dim
168
+ self.mlp_ratio = mlp_ratio
169
+ self.num_heads = num_heads
170
+ self.head_dim = dim // num_heads
171
+ self.proj_dropout = proj_dropout
172
+ self.norm_eps = norm_eps
173
+
174
+ # layers
175
+ gain = 1.0 / math.sqrt(dim)
176
+ self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
177
+ self.to_q = nn.Linear(dim, dim)
178
+ self.to_kv = nn.Linear(dim, dim * 2)
179
+ self.proj = nn.Linear(dim, dim)
180
+ self.norm = LayerNorm(dim, eps=norm_eps)
181
+ self.mlp = nn.Sequential(
182
+ nn.Linear(dim, int(dim * mlp_ratio)),
183
+ QuickGELU() if activation == 'quick_gelu' else nn.GELU(),
184
+ nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout))
185
+
186
+ def forward(self, x):
187
+ """
188
+ x: [B, L, C].
189
+ """
190
+ b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
191
+
192
+ # compute query, key, value
193
+ q = self.to_q(self.cls_embedding).view(1, 1, n, d).expand(b, -1, -1, -1)
194
+ k, v = self.to_kv(x).view(b, s, 2, n, d).unbind(2)
195
+
196
+ # compute attention
197
+ x = flash_attention(q, k, v, version=2)
198
+ x = x.reshape(b, 1, c)
199
+
200
+ # output
201
+ x = self.proj(x)
202
+ x = F.dropout(x, self.proj_dropout, self.training)
203
+
204
+ # mlp
205
+ x = x + self.mlp(self.norm(x))
206
+ return x[:, 0]
207
+
208
+
209
+ class VisionTransformer(nn.Module):
210
+
211
+ def __init__(self,
212
+ image_size=224,
213
+ patch_size=16,
214
+ dim=768,
215
+ mlp_ratio=4,
216
+ out_dim=512,
217
+ num_heads=12,
218
+ num_layers=12,
219
+ pool_type='token',
220
+ pre_norm=True,
221
+ post_norm=False,
222
+ activation='quick_gelu',
223
+ attn_dropout=0.0,
224
+ proj_dropout=0.0,
225
+ embedding_dropout=0.0,
226
+ norm_eps=1e-5):
227
+ if image_size % patch_size != 0:
228
+ print(
229
+ '[WARNING] image_size is not divisible by patch_size',
230
+ flush=True)
231
+ assert pool_type in ('token', 'token_fc', 'attn_pool')
232
+ out_dim = out_dim or dim
233
+ super().__init__()
234
+ self.image_size = image_size
235
+ self.patch_size = patch_size
236
+ self.num_patches = (image_size // patch_size)**2
237
+ self.dim = dim
238
+ self.mlp_ratio = mlp_ratio
239
+ self.out_dim = out_dim
240
+ self.num_heads = num_heads
241
+ self.num_layers = num_layers
242
+ self.pool_type = pool_type
243
+ self.post_norm = post_norm
244
+ self.norm_eps = norm_eps
245
+
246
+ # embeddings
247
+ gain = 1.0 / math.sqrt(dim)
248
+ self.patch_embedding = nn.Conv2d(
249
+ 3,
250
+ dim,
251
+ kernel_size=patch_size,
252
+ stride=patch_size,
253
+ bias=not pre_norm)
254
+ if pool_type in ('token', 'token_fc'):
255
+ self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
256
+ self.pos_embedding = nn.Parameter(gain * torch.randn(
257
+ 1, self.num_patches +
258
+ (1 if pool_type in ('token', 'token_fc') else 0), dim))
259
+ self.dropout = nn.Dropout(embedding_dropout)
260
+
261
+ # transformer
262
+ self.pre_norm = LayerNorm(dim, eps=norm_eps) if pre_norm else None
263
+ self.transformer = nn.Sequential(*[
264
+ AttentionBlock(dim, mlp_ratio, num_heads, post_norm, False,
265
+ activation, attn_dropout, proj_dropout, norm_eps)
266
+ for _ in range(num_layers)
267
+ ])
268
+ self.post_norm = LayerNorm(dim, eps=norm_eps)
269
+
270
+ # head
271
+ if pool_type == 'token':
272
+ self.head = nn.Parameter(gain * torch.randn(dim, out_dim))
273
+ elif pool_type == 'token_fc':
274
+ self.head = nn.Linear(dim, out_dim)
275
+ elif pool_type == 'attn_pool':
276
+ self.head = AttentionPool(dim, mlp_ratio, num_heads, activation,
277
+ proj_dropout, norm_eps)
278
+
279
+ def forward(self, x, interpolation=False, use_31_block=False):
280
+ b = x.size(0)
281
+
282
+ # embeddings
283
+ x = self.patch_embedding(x).flatten(2).permute(0, 2, 1)
284
+ if self.pool_type in ('token', 'token_fc'):
285
+ x = torch.cat([self.cls_embedding.expand(b, -1, -1), x], dim=1)
286
+ if interpolation:
287
+ e = pos_interpolate(self.pos_embedding, x.size(1))
288
+ else:
289
+ e = self.pos_embedding
290
+ x = self.dropout(x + e)
291
+ if self.pre_norm is not None:
292
+ x = self.pre_norm(x)
293
+
294
+ # transformer
295
+ if use_31_block:
296
+ x = self.transformer[:-1](x)
297
+ return x
298
+ else:
299
+ x = self.transformer(x)
300
+ return x
301
+
302
+
303
+ class XLMRobertaWithHead(XLMRoberta):
304
+
305
+ def __init__(self, **kwargs):
306
+ self.out_dim = kwargs.pop('out_dim')
307
+ super().__init__(**kwargs)
308
+
309
+ # head
310
+ mid_dim = (self.dim + self.out_dim) // 2
311
+ self.head = nn.Sequential(
312
+ nn.Linear(self.dim, mid_dim, bias=False), nn.GELU(),
313
+ nn.Linear(mid_dim, self.out_dim, bias=False))
314
+
315
+ def forward(self, ids):
316
+ # xlm-roberta
317
+ x = super().forward(ids)
318
+
319
+ # average pooling
320
+ mask = ids.ne(self.pad_id).unsqueeze(-1).to(x)
321
+ x = (x * mask).sum(dim=1) / mask.sum(dim=1)
322
+
323
+ # head
324
+ x = self.head(x)
325
+ return x
326
+
327
+
328
+ class XLMRobertaCLIP(nn.Module):
329
+
330
+ def __init__(self,
331
+ embed_dim=1024,
332
+ image_size=224,
333
+ patch_size=14,
334
+ vision_dim=1280,
335
+ vision_mlp_ratio=4,
336
+ vision_heads=16,
337
+ vision_layers=32,
338
+ vision_pool='token',
339
+ vision_pre_norm=True,
340
+ vision_post_norm=False,
341
+ activation='gelu',
342
+ vocab_size=250002,
343
+ max_text_len=514,
344
+ type_size=1,
345
+ pad_id=1,
346
+ text_dim=1024,
347
+ text_heads=16,
348
+ text_layers=24,
349
+ text_post_norm=True,
350
+ text_dropout=0.1,
351
+ attn_dropout=0.0,
352
+ proj_dropout=0.0,
353
+ embedding_dropout=0.0,
354
+ norm_eps=1e-5):
355
+ super().__init__()
356
+ self.embed_dim = embed_dim
357
+ self.image_size = image_size
358
+ self.patch_size = patch_size
359
+ self.vision_dim = vision_dim
360
+ self.vision_mlp_ratio = vision_mlp_ratio
361
+ self.vision_heads = vision_heads
362
+ self.vision_layers = vision_layers
363
+ self.vision_pre_norm = vision_pre_norm
364
+ self.vision_post_norm = vision_post_norm
365
+ self.activation = activation
366
+ self.vocab_size = vocab_size
367
+ self.max_text_len = max_text_len
368
+ self.type_size = type_size
369
+ self.pad_id = pad_id
370
+ self.text_dim = text_dim
371
+ self.text_heads = text_heads
372
+ self.text_layers = text_layers
373
+ self.text_post_norm = text_post_norm
374
+ self.norm_eps = norm_eps
375
+
376
+ # models
377
+ self.visual = VisionTransformer(
378
+ image_size=image_size,
379
+ patch_size=patch_size,
380
+ dim=vision_dim,
381
+ mlp_ratio=vision_mlp_ratio,
382
+ out_dim=embed_dim,
383
+ num_heads=vision_heads,
384
+ num_layers=vision_layers,
385
+ pool_type=vision_pool,
386
+ pre_norm=vision_pre_norm,
387
+ post_norm=vision_post_norm,
388
+ activation=activation,
389
+ attn_dropout=attn_dropout,
390
+ proj_dropout=proj_dropout,
391
+ embedding_dropout=embedding_dropout,
392
+ norm_eps=norm_eps)
393
+ self.textual = XLMRobertaWithHead(
394
+ vocab_size=vocab_size,
395
+ max_seq_len=max_text_len,
396
+ type_size=type_size,
397
+ pad_id=pad_id,
398
+ dim=text_dim,
399
+ out_dim=embed_dim,
400
+ num_heads=text_heads,
401
+ num_layers=text_layers,
402
+ post_norm=text_post_norm,
403
+ dropout=text_dropout)
404
+ self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([]))
405
+
406
+ def forward(self, imgs, txt_ids):
407
+ """
408
+ imgs: [B, 3, H, W] of torch.float32.
409
+ - mean: [0.48145466, 0.4578275, 0.40821073]
410
+ - std: [0.26862954, 0.26130258, 0.27577711]
411
+ txt_ids: [B, L] of torch.long.
412
+ Encoded by data.CLIPTokenizer.
413
+ """
414
+ xi = self.visual(imgs)
415
+ xt = self.textual(txt_ids)
416
+ return xi, xt
417
+
418
+ def param_groups(self):
419
+ groups = [{
420
+ 'params': [
421
+ p for n, p in self.named_parameters()
422
+ if 'norm' in n or n.endswith('bias')
423
+ ],
424
+ 'weight_decay': 0.0
425
+ }, {
426
+ 'params': [
427
+ p for n, p in self.named_parameters()
428
+ if not ('norm' in n or n.endswith('bias'))
429
+ ]
430
+ }]
431
+ return groups
432
+
433
+
434
+ def _clip(pretrained=False,
435
+ pretrained_name=None,
436
+ model_cls=XLMRobertaCLIP,
437
+ return_transforms=False,
438
+ return_tokenizer=False,
439
+ tokenizer_padding='eos',
440
+ dtype=torch.float32,
441
+ device='cpu',
442
+ **kwargs):
443
+ # init a model on device
444
+ with torch.device(device):
445
+ model = model_cls(**kwargs)
446
+
447
+ # set device
448
+ model = model.to(dtype=dtype, device=device)
449
+ output = (model,)
450
+
451
+ # init transforms
452
+ if return_transforms:
453
+ # mean and std
454
+ if 'siglip' in pretrained_name.lower():
455
+ mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
456
+ else:
457
+ mean = [0.48145466, 0.4578275, 0.40821073]
458
+ std = [0.26862954, 0.26130258, 0.27577711]
459
+
460
+ # transforms
461
+ transforms = T.Compose([
462
+ T.Resize((model.image_size, model.image_size),
463
+ interpolation=T.InterpolationMode.BICUBIC),
464
+ T.ToTensor(),
465
+ T.Normalize(mean=mean, std=std)
466
+ ])
467
+ output += (transforms,)
468
+ return output[0] if len(output) == 1 else output
469
+
470
+
471
+ def clip_xlm_roberta_vit_h_14(
472
+ pretrained=False,
473
+ pretrained_name='open-clip-xlm-roberta-large-vit-huge-14',
474
+ **kwargs):
475
+ cfg = dict(
476
+ embed_dim=1024,
477
+ image_size=224,
478
+ patch_size=14,
479
+ vision_dim=1280,
480
+ vision_mlp_ratio=4,
481
+ vision_heads=16,
482
+ vision_layers=32,
483
+ vision_pool='token',
484
+ activation='gelu',
485
+ vocab_size=250002,
486
+ max_text_len=514,
487
+ type_size=1,
488
+ pad_id=1,
489
+ text_dim=1024,
490
+ text_heads=16,
491
+ text_layers=24,
492
+ text_post_norm=True,
493
+ text_dropout=0.1,
494
+ attn_dropout=0.0,
495
+ proj_dropout=0.0,
496
+ embedding_dropout=0.0)
497
+ cfg.update(**kwargs)
498
+ return _clip(pretrained, pretrained_name, XLMRobertaCLIP, **cfg)
499
+
500
+
501
+ class CLIPModel:
502
+
503
+ def __init__(self, dtype, device, checkpoint_path, tokenizer_path):
504
+ self.dtype = dtype
505
+ self.device = device
506
+ self.checkpoint_path = checkpoint_path
507
+ self.tokenizer_path = tokenizer_path
508
+
509
+ # init model
510
+ self.model, self.transforms = clip_xlm_roberta_vit_h_14(
511
+ pretrained=False,
512
+ return_transforms=True,
513
+ return_tokenizer=False,
514
+ dtype=dtype,
515
+ device=device)
516
+ self.model = self.model.eval().requires_grad_(False)
517
+ logging.info(f'loading {checkpoint_path}')
518
+ self.model.load_state_dict(
519
+ torch.load(checkpoint_path, map_location='cpu'))
520
+
521
+ # init tokenizer
522
+ self.tokenizer = HuggingfaceTokenizer(
523
+ name=tokenizer_path,
524
+ seq_len=self.model.max_text_len - 2,
525
+ clean='whitespace')
526
+
527
+ def visual(self, videos):
528
+ # preprocess
529
+ size = (self.model.image_size,) * 2
530
+ videos = torch.cat([
531
+ F.interpolate(
532
+ u.transpose(0, 1),
533
+ size=size,
534
+ mode='bicubic',
535
+ align_corners=False) for u in videos
536
+ ])
537
+ videos = self.transforms.transforms[-1](videos.mul_(0.5).add_(0.5))
538
+
539
+ # forward
540
+ with torch.cuda.amp.autocast(dtype=self.dtype):
541
+ out = self.model.visual(videos, use_31_block=True)
542
+ return out
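As a hedged illustration of the `pos_interpolate` helper defined at the top of `clip.py`: when the token count differs from the trained grid, the leading class-token embeddings are kept and the patch-grid embeddings are bicubically resized. A minimal sketch, assuming the module above is in scope:

```python
import torch

dim = 1280
pos = torch.randn(1, 1 + 16 * 16, dim)               # 1 cls token + 16x16 patch grid
resized = pos_interpolate(pos, seq_len=1 + 32 * 32)  # resize grid to 32x32, keep cls token
print(resized.shape)                                  # torch.Size([1, 1025, 1280])
```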
wan/modules/animate/face_blocks.py ADDED
@@ -0,0 +1,383 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ from torch import nn
3
+ import torch
4
+ from typing import Tuple, Optional
5
+ from einops import rearrange
6
+ import torch.nn.functional as F
7
+ import math
8
+ from ...distributed.util import gather_forward, get_rank, get_world_size
9
+
10
+
11
+ try:
12
+ from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
13
+ except ImportError:
14
+ flash_attn_func = None
15
+
16
+ MEMORY_LAYOUT = {
17
+ "flash": (
18
+ lambda x: x.view(x.shape[0] * x.shape[1], *x.shape[2:]),
19
+ lambda x: x,
20
+ ),
21
+ "torch": (
22
+ lambda x: x.transpose(1, 2),
23
+ lambda x: x.transpose(1, 2),
24
+ ),
25
+ "vanilla": (
26
+ lambda x: x.transpose(1, 2),
27
+ lambda x: x.transpose(1, 2),
28
+ ),
29
+ }
30
+
31
+
32
+ def attention(
33
+ q,
34
+ k,
35
+ v,
36
+ mode="flash",
37
+ drop_rate=0,
38
+ attn_mask=None,
39
+ causal=False,
40
+ max_seqlen_q=None,
41
+ batch_size=1,
42
+ ):
43
+ """
44
+ Perform QKV self attention.
45
+
46
+ Args:
47
+ q (torch.Tensor): Query tensor with shape [b, s, a, d], where a is the number of heads.
48
+ k (torch.Tensor): Key tensor with shape [b, s1, a, d]
49
+ v (torch.Tensor): Value tensor with shape [b, s1, a, d]
50
+ mode (str): Attention mode. Choose from 'flash', 'torch', and 'vanilla'.
51
+ drop_rate (float): Dropout rate in attention map. (default: 0)
52
+ attn_mask (torch.Tensor): Attention mask with shape [b, s1] (cross_attn), or [b, a, s, s1] (torch or vanilla).
53
+ (default: None)
54
+ causal (bool): Whether to use causal attention. (default: False)
55
+ max_seqlen_q (int): The maximum sequence length in the batch of q; used to reshape the flash-attention output.
+ batch_size (int): Batch size used to reshape the flash-attention output.
61
+
62
+ Returns:
63
+ torch.Tensor: Output tensor after self attention with shape [b, s, ad]
64
+ """
65
+ pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode]
66
+
67
+ if mode == "torch":
68
+ if attn_mask is not None and attn_mask.dtype != torch.bool:
69
+ attn_mask = attn_mask.to(q.dtype)
70
+ x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal)
71
+
72
+ elif mode == "flash":
73
+ x = flash_attn_func(
74
+ q,
75
+ k,
76
+ v,
77
+ )
78
+ x = x.view(batch_size, max_seqlen_q, x.shape[-2], x.shape[-1]) # reshape x to [b, s, a, d]
79
+ elif mode == "vanilla":
80
+ scale_factor = 1 / math.sqrt(q.size(-1))
81
+
82
+ b, a, s, _ = q.shape
83
+ s1 = k.size(2)
84
+ attn_bias = torch.zeros(b, a, s, s1, dtype=q.dtype, device=q.device)
85
+ if causal:
86
+ # Only applied to self attention
87
+ assert attn_mask is None, "Causal mask and attn_mask cannot be used together"
88
+ temp_mask = torch.ones(b, a, s, s, dtype=torch.bool, device=q.device).tril(diagonal=0)
89
+ attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
90
+ attn_bias = attn_bias.to(q.dtype)
91
+
92
+ if attn_mask is not None:
93
+ if attn_mask.dtype == torch.bool:
94
+ attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
95
+ else:
96
+ attn_bias += attn_mask
97
+
98
+ attn = (q @ k.transpose(-2, -1)) * scale_factor
99
+ attn += attn_bias
100
+ attn = attn.softmax(dim=-1)
101
+ attn = torch.dropout(attn, p=drop_rate, train=True)
102
+ x = attn @ v
103
+ else:
104
+ raise NotImplementedError(f"Unsupported attention mode: {mode}")
105
+
106
+ x = post_attn_layout(x)
107
+ b, s, a, d = x.shape
108
+ out = x.reshape(b, s, -1)
109
+ return out
110
+
111
+
112
+ class CausalConv1d(nn.Module):
113
+
114
+ def __init__(self, chan_in, chan_out, kernel_size=3, stride=1, dilation=1, pad_mode="replicate", **kwargs):
115
+ super().__init__()
116
+
117
+ self.pad_mode = pad_mode
118
+ padding = (kernel_size - 1, 0) # T
119
+ self.time_causal_padding = padding
120
+
121
+ self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)
122
+
123
+ def forward(self, x):
124
+ x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)
125
+ return self.conv(x)
126
+
127
+
128
+
129
+ class FaceEncoder(nn.Module):
130
+ def __init__(self, in_dim: int, hidden_dim: int, num_heads: int, dtype=None, device=None):
131
+ factory_kwargs = {"dtype": dtype, "device": device}
132
+ super().__init__()
133
+
134
+ self.num_heads = num_heads
135
+ self.conv1_local = CausalConv1d(in_dim, 1024 * num_heads, 3, stride=1)
136
+ self.norm1 = nn.LayerNorm(hidden_dim // 8, elementwise_affine=False, eps=1e-6, **factory_kwargs)
137
+ self.act = nn.SiLU()
138
+ self.conv2 = CausalConv1d(1024, 1024, 3, stride=2)
139
+ self.conv3 = CausalConv1d(1024, 1024, 3, stride=2)
140
+
141
+ self.out_proj = nn.Linear(1024, hidden_dim)
142
+ self.norm1 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs)
143
+
144
+ self.norm2 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs)
145
+
146
+ self.norm3 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs)
147
+
148
+ self.padding_tokens = nn.Parameter(torch.zeros(1, 1, 1, hidden_dim))
149
+
150
+ def forward(self, x):
151
+
152
+ x = rearrange(x, "b t c -> b c t")
153
+ b, c, t = x.shape
154
+
155
+ x = self.conv1_local(x)
156
+ x = rearrange(x, "b (n c) t -> (b n) t c", n=self.num_heads)
157
+
158
+ x = self.norm1(x)
159
+ x = self.act(x)
160
+ x = rearrange(x, "b t c -> b c t")
161
+ x = self.conv2(x)
162
+ x = rearrange(x, "b c t -> b t c")
163
+ x = self.norm2(x)
164
+ x = self.act(x)
165
+ x = rearrange(x, "b t c -> b c t")
166
+ x = self.conv3(x)
167
+ x = rearrange(x, "b c t -> b t c")
168
+ x = self.norm3(x)
169
+ x = self.act(x)
170
+ x = self.out_proj(x)
171
+ x = rearrange(x, "(b n) t c -> b t n c", b=b)
172
+ padding = self.padding_tokens.repeat(b, x.shape[1], 1, 1)
173
+ x = torch.cat([x, padding], dim=-2)
174
+ x_local = x.clone()
175
+
176
+ return x_local
177
+
178
+
179
+
180
+ class RMSNorm(nn.Module):
181
+ def __init__(
182
+ self,
183
+ dim: int,
184
+ elementwise_affine=True,
185
+ eps: float = 1e-6,
186
+ device=None,
187
+ dtype=None,
188
+ ):
189
+ """
190
+ Initialize the RMSNorm normalization layer.
191
+
192
+ Args:
193
+ dim (int): The dimension of the input tensor.
194
+ eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
195
+
196
+ Attributes:
197
+ eps (float): A small value added to the denominator for numerical stability.
198
+ weight (nn.Parameter): Learnable scaling parameter.
199
+
200
+ """
201
+ factory_kwargs = {"device": device, "dtype": dtype}
202
+ super().__init__()
203
+ self.eps = eps
204
+ if elementwise_affine:
205
+ self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs))
206
+
207
+ def _norm(self, x):
208
+ """
209
+ Apply the RMSNorm normalization to the input tensor.
210
+
211
+ Args:
212
+ x (torch.Tensor): The input tensor.
213
+
214
+ Returns:
215
+ torch.Tensor: The normalized tensor.
216
+
217
+ """
218
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
219
+
220
+ def forward(self, x):
221
+ """
222
+ Forward pass through the RMSNorm layer.
223
+
224
+ Args:
225
+ x (torch.Tensor): The input tensor.
226
+
227
+ Returns:
228
+ torch.Tensor: The output tensor after applying RMSNorm.
229
+
230
+ """
231
+ output = self._norm(x.float()).type_as(x)
232
+ if hasattr(self, "weight"):
233
+ output = output * self.weight
234
+ return output
235
+
236
+
237
+ def get_norm_layer(norm_layer):
238
+ """
239
+ Get the normalization layer.
240
+
241
+ Args:
242
+ norm_layer (str): The type of normalization layer.
243
+
244
+ Returns:
245
+ norm_layer (nn.Module): The normalization layer.
246
+ """
247
+ if norm_layer == "layer":
248
+ return nn.LayerNorm
249
+ elif norm_layer == "rms":
250
+ return RMSNorm
251
+ else:
252
+ raise NotImplementedError(f"Norm layer {norm_layer} is not implemented")
253
+
254
+
255
+ class FaceAdapter(nn.Module):
256
+ def __init__(
257
+ self,
258
+ hidden_dim: int,
259
+ heads_num: int,
260
+ qk_norm: bool = True,
261
+ qk_norm_type: str = "rms",
262
+ num_adapter_layers: int = 1,
263
+ dtype=None,
264
+ device=None,
265
+ ):
266
+
267
+ factory_kwargs = {"dtype": dtype, "device": device}
268
+ super().__init__()
269
+ self.hidden_size = hidden_dim
270
+ self.heads_num = heads_num
271
+ self.fuser_blocks = nn.ModuleList(
272
+ [
273
+ FaceBlock(
274
+ self.hidden_size,
275
+ self.heads_num,
276
+ qk_norm=qk_norm,
277
+ qk_norm_type=qk_norm_type,
278
+ **factory_kwargs,
279
+ )
280
+ for _ in range(num_adapter_layers)
281
+ ]
282
+ )
283
+
284
+ def forward(
285
+ self,
286
+ x: torch.Tensor,
287
+ motion_embed: torch.Tensor,
288
+ idx: int,
289
+ freqs_cis_q: Tuple[torch.Tensor, torch.Tensor] = None,
290
+ freqs_cis_k: Tuple[torch.Tensor, torch.Tensor] = None,
291
+ ) -> torch.Tensor:
292
+
293
+ return self.fuser_blocks[idx](x, motion_embed, freqs_cis_q, freqs_cis_k)
294
+
295
+
296
+
297
+ class FaceBlock(nn.Module):
298
+ def __init__(
299
+ self,
300
+ hidden_size: int,
301
+ heads_num: int,
302
+ qk_norm: bool = True,
303
+ qk_norm_type: str = "rms",
304
+ qk_scale: float = None,
305
+ dtype: Optional[torch.dtype] = None,
306
+ device: Optional[torch.device] = None,
307
+ ):
308
+ factory_kwargs = {"device": device, "dtype": dtype}
309
+ super().__init__()
310
+
311
+ self.deterministic = False
312
+ self.hidden_size = hidden_size
313
+ self.heads_num = heads_num
314
+ head_dim = hidden_size // heads_num
315
+ self.scale = qk_scale or head_dim**-0.5
316
+
317
+ self.linear1_kv = nn.Linear(hidden_size, hidden_size * 2, **factory_kwargs)
318
+ self.linear1_q = nn.Linear(hidden_size, hidden_size, **factory_kwargs)
319
+
320
+ self.linear2 = nn.Linear(hidden_size, hidden_size, **factory_kwargs)
321
+
322
+ qk_norm_layer = get_norm_layer(qk_norm_type)
323
+ self.q_norm = (
324
+ qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
325
+ )
326
+ self.k_norm = (
327
+ qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
328
+ )
329
+
330
+ self.pre_norm_feat = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
331
+
332
+ self.pre_norm_motion = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
333
+
334
+ def forward(
335
+ self,
336
+ x: torch.Tensor,
337
+ motion_vec: torch.Tensor,
338
+ motion_mask: Optional[torch.Tensor] = None,
339
+ use_context_parallel=False,
340
+ ) -> torch.Tensor:
341
+
342
+ B, T, N, C = motion_vec.shape
343
+ T_comp = T
344
+
345
+ x_motion = self.pre_norm_motion(motion_vec)
346
+ x_feat = self.pre_norm_feat(x)
347
+
348
+ kv = self.linear1_kv(x_motion)
349
+ q = self.linear1_q(x_feat)
350
+
351
+ k, v = rearrange(kv, "B L N (K H D) -> K B L N H D", K=2, H=self.heads_num)
352
+ q = rearrange(q, "B S (H D) -> B S H D", H=self.heads_num)
353
+
354
+ # Apply QK-Norm if needed.
355
+ q = self.q_norm(q).to(v)
356
+ k = self.k_norm(k).to(v)
357
+
358
+ k = rearrange(k, "B L N H D -> (B L) N H D")
359
+ v = rearrange(v, "B L N H D -> (B L) N H D")
360
+
361
+ if use_context_parallel:
362
+ q = gather_forward(q, dim=1)
363
+
364
+ q = rearrange(q, "B (L S) H D -> (B L) S H D", L=T_comp)
365
+ # Compute attention.
366
+ attn = attention(
367
+ q,
368
+ k,
369
+ v,
370
+ max_seqlen_q=q.shape[1],
371
+ batch_size=q.shape[0],
372
+ )
373
+
374
+ attn = rearrange(attn, "(B L) S C -> B (L S) C", L=T_comp)
375
+ if use_context_parallel:
376
+ attn = torch.chunk(attn, get_world_size(), dim=1)[get_rank()]
377
+
378
+ output = self.linear2(attn)
379
+
380
+ if motion_mask is not None:
381
+ output = output * rearrange(motion_mask, "B T H W -> B (T H W)").unsqueeze(-1)
382
+
383
+ return output
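A small sketch (illustrative only) of the `CausalConv1d` used by `FaceEncoder` above: left-padding by `kernel_size - 1` keeps the temporal length unchanged while preventing the convolution from seeing future frames.

```python
import torch

conv = CausalConv1d(chan_in=8, chan_out=8, kernel_size=3)  # class from the diff above
x = torch.randn(2, 8, 16)                                  # [batch, channels, time]
y = conv(x)                                                 # replicate-pads 2 steps on the left
print(y.shape)                                              # torch.Size([2, 8, 16])
```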
wan/modules/animate/model_animate.py ADDED
@@ -0,0 +1,500 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import math
3
+ import types
4
+ from copy import deepcopy
5
+ from einops import rearrange
6
+ from typing import List
7
+ import numpy as np
8
+ import torch
9
+ import torch.cuda.amp as amp
10
+ import torch.nn as nn
11
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
12
+ from diffusers.models.modeling_utils import ModelMixin
13
+ from diffusers.loaders import PeftAdapterMixin
14
+
15
+ from ...distributed.sequence_parallel import (
16
+ distributed_attention,
17
+ gather_forward,
18
+ get_rank,
19
+ get_world_size,
20
+ )
21
+
22
+
23
+ from ..model import (
24
+ Head,
25
+ WanAttentionBlock,
26
+ WanLayerNorm,
27
+ WanRMSNorm,
28
+ WanModel,
29
+ WanSelfAttention,
30
+ flash_attention,
31
+ rope_params,
32
+ sinusoidal_embedding_1d,
33
+ rope_apply
34
+ )
35
+
36
+ from .face_blocks import FaceEncoder, FaceAdapter
37
+ from .motion_encoder import Generator
38
+
39
+ class HeadAnimate(Head):
40
+
41
+ def forward(self, x, e):
42
+ """
43
+ Args:
44
+ x(Tensor): Shape [B, L1, C]
45
+ e(Tensor): Shape [B, L1, C]
46
+ """
47
+ assert e.dtype == torch.float32
48
+ with amp.autocast(dtype=torch.float32):
49
+ e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1)
50
+ x = (self.head(self.norm(x) * (1 + e[1]) + e[0]))
51
+ return x
52
+
53
+
54
+ class WanAnimateSelfAttention(WanSelfAttention):
55
+
56
+ def forward(self, x, seq_lens, grid_sizes, freqs):
57
+ """
58
+ Args:
59
+ x(Tensor): Shape [B, L, num_heads, C / num_heads]
60
+ seq_lens(Tensor): Shape [B]
61
+ grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
62
+ freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
63
+ """
64
+ b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
65
+
66
+ # query, key, value function
67
+ def qkv_fn(x):
68
+ q = self.norm_q(self.q(x)).view(b, s, n, d)
69
+ k = self.norm_k(self.k(x)).view(b, s, n, d)
70
+ v = self.v(x).view(b, s, n, d)
71
+ return q, k, v
72
+
73
+ q, k, v = qkv_fn(x)
74
+
75
+ x = flash_attention(
76
+ q=rope_apply(q, grid_sizes, freqs),
77
+ k=rope_apply(k, grid_sizes, freqs),
78
+ v=v,
79
+ k_lens=seq_lens,
80
+ window_size=self.window_size)
81
+
82
+ # output
83
+ x = x.flatten(2)
84
+ x = self.o(x)
85
+ return x
86
+
87
+
88
+ class WanAnimateCrossAttention(WanSelfAttention):
89
+ def __init__(
90
+ self,
91
+ dim,
92
+ num_heads,
93
+ window_size=(-1, -1),
94
+ qk_norm=True,
95
+ eps=1e-6,
96
+ use_img_emb=True
97
+ ):
98
+ super().__init__(
99
+ dim,
100
+ num_heads,
101
+ window_size,
102
+ qk_norm,
103
+ eps
104
+ )
105
+ self.use_img_emb = use_img_emb
106
+
107
+ if use_img_emb:
108
+ self.k_img = nn.Linear(dim, dim)
109
+ self.v_img = nn.Linear(dim, dim)
110
+ self.norm_k_img = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
111
+
112
+ def forward(self, x, context, context_lens):
113
+ """
114
+ x: [B, L1, C].
115
+ context: [B, L2, C].
116
+ context_lens: [B].
117
+ """
118
+ if self.use_img_emb:
119
+ context_img = context[:, :257]
120
+ context = context[:, 257:]
121
+ else:
122
+ context = context
123
+
124
+ b, n, d = x.size(0), self.num_heads, self.head_dim
125
+
126
+ # compute query, key, value
127
+ q = self.norm_q(self.q(x)).view(b, -1, n, d)
128
+ k = self.norm_k(self.k(context)).view(b, -1, n, d)
129
+ v = self.v(context).view(b, -1, n, d)
130
+
131
+ if self.use_img_emb:
132
+ k_img = self.norm_k_img(self.k_img(context_img)).view(b, -1, n, d)
133
+ v_img = self.v_img(context_img).view(b, -1, n, d)
134
+ img_x = flash_attention(q, k_img, v_img, k_lens=None)
135
+ # compute attention
136
+ x = flash_attention(q, k, v, k_lens=context_lens)
137
+
138
+ # output
139
+ x = x.flatten(2)
140
+
141
+ if self.use_img_emb:
142
+ img_x = img_x.flatten(2)
143
+ x = x + img_x
144
+
145
+ x = self.o(x)
146
+ return x
147
+
148
+
149
+ class WanAnimateAttentionBlock(nn.Module):
150
+ def __init__(self,
151
+ dim,
152
+ ffn_dim,
153
+ num_heads,
154
+ window_size=(-1, -1),
155
+ qk_norm=True,
156
+ cross_attn_norm=True,
157
+ eps=1e-6,
158
+ use_img_emb=True):
159
+
160
+ super().__init__()
161
+ self.dim = dim
162
+ self.ffn_dim = ffn_dim
163
+ self.num_heads = num_heads
164
+ self.window_size = window_size
165
+ self.qk_norm = qk_norm
166
+ self.cross_attn_norm = cross_attn_norm
167
+ self.eps = eps
168
+
169
+ # layers
170
+ self.norm1 = WanLayerNorm(dim, eps)
171
+ self.self_attn = WanAnimateSelfAttention(dim, num_heads, window_size, qk_norm, eps)
172
+
173
+ self.norm3 = WanLayerNorm(
174
+ dim, eps, elementwise_affine=True
175
+ ) if cross_attn_norm else nn.Identity()
176
+
177
+ self.cross_attn = WanAnimateCrossAttention(dim, num_heads, (-1, -1), qk_norm, eps, use_img_emb=use_img_emb)
178
+ self.norm2 = WanLayerNorm(dim, eps)
179
+ self.ffn = nn.Sequential(
180
+ nn.Linear(dim, ffn_dim),
181
+ nn.GELU(approximate='tanh'),
182
+ nn.Linear(ffn_dim, dim)
183
+ )
184
+
185
+ # modulation
186
+ self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim ** 0.5)
187
+
188
+ def forward(
189
+ self,
190
+ x,
191
+ e,
192
+ seq_lens,
193
+ grid_sizes,
194
+ freqs,
195
+ context,
196
+ context_lens,
197
+ ):
198
+ """
199
+ Args:
200
+ x(Tensor): Shape [B, L, C]
201
+ e(Tensor): Shape [B, L1, 6, C]
202
+ seq_lens(Tensor): Shape [B], length of each sequence in batch
203
+ grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
204
+ freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
205
+ """
206
+ assert e.dtype == torch.float32
207
+ with amp.autocast(dtype=torch.float32):
208
+ e = (self.modulation + e).chunk(6, dim=1)
209
+ assert e[0].dtype == torch.float32
210
+
211
+ # self-attention
212
+ y = self.self_attn(
213
+ self.norm1(x).float() * (1 + e[1]) + e[0], seq_lens, grid_sizes, freqs
214
+ )
215
+ with amp.autocast(dtype=torch.float32):
216
+ x = x + y * e[2]
217
+
218
+ # cross-attention & ffn function
219
+ def cross_attn_ffn(x, context, context_lens, e):
220
+ x = x + self.cross_attn(self.norm3(x), context, context_lens)
221
+ y = self.ffn(self.norm2(x).float() * (1 + e[4]) + e[3])
222
+ with amp.autocast(dtype=torch.float32):
223
+ x = x + y * e[5]
224
+ return x
225
+
226
+ x = cross_attn_ffn(x, context, context_lens, e)
227
+ return x
228
+
229
+
230
+ class MLPProj(torch.nn.Module):
231
+ def __init__(self, in_dim, out_dim):
232
+ super().__init__()
233
+
234
+ self.proj = torch.nn.Sequential(
235
+ torch.nn.LayerNorm(in_dim),
236
+ torch.nn.Linear(in_dim, in_dim),
237
+ torch.nn.GELU(),
238
+ torch.nn.Linear(in_dim, out_dim),
239
+ torch.nn.LayerNorm(out_dim),
240
+ )
241
+
242
+ def forward(self, image_embeds):
243
+ clip_extra_context_tokens = self.proj(image_embeds)
244
+ return clip_extra_context_tokens
245
+
246
+ class WanAnimateModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
247
+ _no_split_modules = ['WanAttentionBlock']
248
+
249
+ @register_to_config
250
+ def __init__(self,
251
+ patch_size=(1, 2, 2),
252
+ text_len=512,
253
+ in_dim=36,
254
+ dim=5120,
255
+ ffn_dim=13824,
256
+ freq_dim=256,
257
+ text_dim=4096,
258
+ out_dim=16,
259
+ num_heads=40,
260
+ num_layers=40,
261
+ window_size=(-1, -1),
262
+ qk_norm=True,
263
+ cross_attn_norm=True,
264
+ eps=1e-6,
265
+ motion_encoder_dim=512,
266
+ use_context_parallel=False,
267
+ use_img_emb=True):
268
+
269
+ super().__init__()
270
+ self.patch_size = patch_size
271
+ self.text_len = text_len
272
+ self.in_dim = in_dim
273
+ self.dim = dim
274
+ self.ffn_dim = ffn_dim
275
+ self.freq_dim = freq_dim
276
+ self.text_dim = text_dim
277
+ self.out_dim = out_dim
278
+ self.num_heads = num_heads
279
+ self.num_layers = num_layers
280
+ self.window_size = window_size
281
+ self.qk_norm = qk_norm
282
+ self.cross_attn_norm = cross_attn_norm
283
+ self.eps = eps
284
+ self.motion_encoder_dim = motion_encoder_dim
285
+ self.use_context_parallel = use_context_parallel
286
+ self.use_img_emb = use_img_emb
287
+
288
+ # embeddings
289
+ self.patch_embedding = nn.Conv3d(
290
+ in_dim, dim, kernel_size=patch_size, stride=patch_size)
291
+
292
+ self.pose_patch_embedding = nn.Conv3d(
293
+ 16, dim, kernel_size=patch_size, stride=patch_size
294
+ )
295
+
296
+ self.text_embedding = nn.Sequential(
297
+ nn.Linear(text_dim, dim), nn.GELU(approximate='tanh'),
298
+ nn.Linear(dim, dim))
299
+
300
+ self.time_embedding = nn.Sequential(
301
+ nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
302
+ self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))
303
+
304
+ # blocks
305
+ self.blocks = nn.ModuleList([
306
+ WanAnimateAttentionBlock(dim, ffn_dim, num_heads, window_size, qk_norm,
307
+ cross_attn_norm, eps, use_img_emb) for _ in range(num_layers)
308
+ ])
309
+
310
+ # head
311
+ self.head = HeadAnimate(dim, out_dim, patch_size, eps)
312
+
313
+ # buffers (don't use register_buffer otherwise dtype will be changed in to())
314
+ assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
315
+ d = dim // num_heads
316
+ self.freqs = torch.cat([
317
+ rope_params(1024, d - 4 * (d // 6)),
318
+ rope_params(1024, 2 * (d // 6)),
319
+ rope_params(1024, 2 * (d // 6))
320
+ ], dim=1)
321
+
322
+ self.img_emb = MLPProj(1280, dim)
323
+
324
+ # initialize weights
325
+ self.init_weights()
326
+
327
+ self.motion_encoder = Generator(size=512, style_dim=512, motion_dim=20)
328
+ self.face_adapter = FaceAdapter(
329
+ heads_num=self.num_heads,
330
+ hidden_dim=self.dim,
331
+ num_adapter_layers=self.num_layers // 5,
332
+ )
333
+
334
+ self.face_encoder = FaceEncoder(
335
+ in_dim=motion_encoder_dim,
336
+ hidden_dim=self.dim,
337
+ num_heads=4,
338
+ )
339
+
340
+ def after_patch_embedding(self, x: List[torch.Tensor], pose_latents, face_pixel_values):
341
+ pose_latents = [self.pose_patch_embedding(u.unsqueeze(0)) for u in pose_latents]
342
+ for x_, pose_latents_ in zip(x, pose_latents):
343
+ x_[:, :, 1:] += pose_latents_
344
+
345
+ b,c,T,h,w = face_pixel_values.shape
346
+ face_pixel_values = rearrange(face_pixel_values, "b c t h w -> (b t) c h w")
347
+
348
+ encode_bs = 8
349
+ face_pixel_values_tmp = []
350
+ for i in range(math.ceil(face_pixel_values.shape[0]/encode_bs)):
351
+ face_pixel_values_tmp.append(self.motion_encoder.get_motion(face_pixel_values[i*encode_bs:(i+1)*encode_bs]))
352
+
353
+ motion_vec = torch.cat(face_pixel_values_tmp)
354
+
355
+ motion_vec = rearrange(motion_vec, "(b t) c -> b t c", t=T)
356
+ motion_vec = self.face_encoder(motion_vec)
357
+
358
+ B, L, H, C = motion_vec.shape
359
+ pad_face = torch.zeros(B, 1, H, C).type_as(motion_vec)
360
+ motion_vec = torch.cat([pad_face, motion_vec], dim=1)
361
+ return x, motion_vec
362
+
363
+
364
+ def after_transformer_block(self, block_idx, x, motion_vec, motion_masks=None):
365
+ if block_idx % 5 == 0:
366
+ adapter_args = [x, motion_vec, motion_masks, self.use_context_parallel]
367
+ residual_out = self.face_adapter.fuser_blocks[block_idx // 5](*adapter_args)
368
+ x = residual_out + x
369
+ return x
370
+
371
+
372
+ def forward(
373
+ self,
374
+ x,
375
+ t,
376
+ clip_fea,
377
+ context,
378
+ seq_len,
379
+ y=None,
380
+ pose_latents=None,
381
+ face_pixel_values=None
382
+ ):
383
+ # params
384
+ device = self.patch_embedding.weight.device
385
+ if self.freqs.device != device:
386
+ self.freqs = self.freqs.to(device)
387
+
388
+ if y is not None:
389
+ x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
390
+
391
+ # embeddings
392
+ x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
393
+ x, motion_vec = self.after_patch_embedding(x, pose_latents, face_pixel_values)
394
+
395
+ grid_sizes = torch.stack(
396
+ [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
397
+ x = [u.flatten(2).transpose(1, 2) for u in x]
398
+ seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
399
+ assert seq_lens.max() <= seq_len
400
+ x = torch.cat([
401
+ torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
402
+ dim=1) for u in x
403
+ ])
404
+
405
+ # time embeddings
406
+ with amp.autocast(dtype=torch.float32):
407
+ e = self.time_embedding(
408
+ sinusoidal_embedding_1d(self.freq_dim, t).float()
409
+ )
410
+ e0 = self.time_projection(e).unflatten(1, (6, self.dim))
411
+ assert e.dtype == torch.float32 and e0.dtype == torch.float32
412
+
413
+ # context
414
+ context_lens = None
415
+ context = self.text_embedding(
416
+ torch.stack([
417
+ torch.cat(
418
+ [u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
419
+ for u in context
420
+ ]))
421
+
422
+ if self.use_img_emb:
423
+ context_clip = self.img_emb(clip_fea) # bs x 257 x dim
424
+ context = torch.concat([context_clip, context], dim=1)
425
+
426
+ # arguments
427
+ kwargs = dict(
428
+ e=e0,
429
+ seq_lens=seq_lens,
430
+ grid_sizes=grid_sizes,
431
+ freqs=self.freqs,
432
+ context=context,
433
+ context_lens=context_lens)
434
+
435
+ if self.use_context_parallel:
436
+ x = torch.chunk(x, get_world_size(), dim=1)[get_rank()]
437
+
438
+ for idx, block in enumerate(self.blocks):
439
+ x = block(x, **kwargs)
440
+ x = self.after_transformer_block(idx, x, motion_vec)
441
+
442
+ # head
443
+ x = self.head(x, e)
444
+
445
+ if self.use_context_parallel:
446
+ x = gather_forward(x, dim=1)
447
+
448
+ # unpatchify
449
+ x = self.unpatchify(x, grid_sizes)
450
+ return [u.float() for u in x]
451
+
452
+
453
+ def unpatchify(self, x, grid_sizes):
454
+ r"""
455
+ Reconstruct video tensors from patch embeddings.
456
+
457
+ Args:
458
+ x (List[Tensor]):
459
+ List of patchified features, each with shape [L, C_out * prod(patch_size)]
460
+ grid_sizes (Tensor):
461
+ Original spatial-temporal grid dimensions before patching,
462
+ shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches)
463
+
464
+ Returns:
465
+ List[Tensor]:
466
+ Reconstructed video tensors with shape [C_out, F, H / 8, W / 8]
467
+ """
468
+
469
+ c = self.out_dim
470
+ out = []
471
+ for u, v in zip(x, grid_sizes.tolist()):
472
+ u = u[:math.prod(v)].view(*v, *self.patch_size, c)
473
+ u = torch.einsum('fhwpqrc->cfphqwr', u)
474
+ u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
475
+ out.append(u)
476
+ return out
477
+
478
+ def init_weights(self):
479
+ r"""
480
+ Initialize model parameters using Xavier initialization.
481
+ """
482
+
483
+ # basic init
484
+ for m in self.modules():
485
+ if isinstance(m, nn.Linear):
486
+ nn.init.xavier_uniform_(m.weight)
487
+ if m.bias is not None:
488
+ nn.init.zeros_(m.bias)
489
+
490
+ # init embeddings
491
+ nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1))
492
+ for m in self.text_embedding.modules():
493
+ if isinstance(m, nn.Linear):
494
+ nn.init.normal_(m.weight, std=.02)
495
+ for m in self.time_embedding.modules():
496
+ if isinstance(m, nn.Linear):
497
+ nn.init.normal_(m.weight, std=.02)
498
+
499
+ # init output layer
500
+ nn.init.zeros_(self.head.head.weight)
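To make the adapter interleaving in `after_transformer_block` concrete, here is a tiny sketch of the block-to-adapter mapping implied by the defaults above (`num_layers=40`, one `FaceBlock` per five attention blocks):

```python
num_layers = 40                        # default in WanAnimateModel
num_adapter_layers = num_layers // 5   # 8 FaceBlocks in the FaceAdapter

for block_idx in range(num_layers):
    if block_idx % 5 == 0:
        # the face adapter is applied after blocks 0, 5, 10, ...
        print(f"block {block_idx:2d} -> face_adapter.fuser_blocks[{block_idx // 5}]")
```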
wan/modules/animate/motion_encoder.py ADDED
@@ -0,0 +1,307 @@
1
+ # Modified from ``https://github.com/wyhsirius/LIA``
2
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
3
+ import torch
4
+ import torch.nn as nn
5
+ from torch.nn import functional as F
6
+ import math
7
+
8
+ def custom_qr(input_tensor):
9
+ original_dtype = input_tensor.dtype
10
+ if original_dtype == torch.bfloat16:
11
+ q, r = torch.linalg.qr(input_tensor.to(torch.float32))
12
+ return q.to(original_dtype), r.to(original_dtype)
13
+ return torch.linalg.qr(input_tensor)
14
+
15
+ def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
16
+ return F.leaky_relu(input + bias, negative_slope) * scale
17
+
18
+
19
+ def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1):
20
+ _, minor, in_h, in_w = input.shape
21
+ kernel_h, kernel_w = kernel.shape
22
+
23
+ out = input.view(-1, minor, in_h, 1, in_w, 1)
24
+ out = F.pad(out, [0, up_x - 1, 0, 0, 0, up_y - 1, 0, 0])
25
+ out = out.view(-1, minor, in_h * up_y, in_w * up_x)
26
+
27
+ out = F.pad(out, [max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
28
+ out = out[:, :, max(-pad_y0, 0): out.shape[2] - max(-pad_y1, 0),
29
+ max(-pad_x0, 0): out.shape[3] - max(-pad_x1, 0), ]
30
+
31
+ out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
32
+ w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
33
+ out = F.conv2d(out, w)
34
+ out = out.reshape(-1, minor, in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
35
+ in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, )
36
+ return out[:, :, ::down_y, ::down_x]
37
+
38
+
39
+ def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
40
+ return upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1])
41
+
42
+
43
+ def make_kernel(k):
44
+ k = torch.tensor(k, dtype=torch.float32)
45
+ if k.ndim == 1:
46
+ k = k[None, :] * k[:, None]
47
+ k /= k.sum()
48
+ return k
49
+
50
+
51
+ class FusedLeakyReLU(nn.Module):
52
+ def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5):
53
+ super().__init__()
54
+ self.bias = nn.Parameter(torch.zeros(1, channel, 1, 1))
55
+ self.negative_slope = negative_slope
56
+ self.scale = scale
57
+
58
+ def forward(self, input):
59
+ out = fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
60
+ return out
61
+
62
+
63
+ class Blur(nn.Module):
64
+ def __init__(self, kernel, pad, upsample_factor=1):
65
+ super().__init__()
66
+
67
+ kernel = make_kernel(kernel)
68
+
69
+ if upsample_factor > 1:
70
+ kernel = kernel * (upsample_factor ** 2)
71
+
72
+ self.register_buffer('kernel', kernel)
73
+
74
+ self.pad = pad
75
+
76
+ def forward(self, input):
77
+ return upfirdn2d(input, self.kernel, pad=self.pad)
78
+
79
+
80
+ class ScaledLeakyReLU(nn.Module):
81
+ def __init__(self, negative_slope=0.2):
82
+ super().__init__()
83
+
84
+ self.negative_slope = negative_slope
85
+
86
+ def forward(self, input):
87
+ return F.leaky_relu(input, negative_slope=self.negative_slope)
88
+
89
+
90
+ class EqualConv2d(nn.Module):
91
+ def __init__(self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True):
92
+ super().__init__()
93
+
94
+ self.weight = nn.Parameter(torch.randn(out_channel, in_channel, kernel_size, kernel_size))
95
+ self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)
96
+
97
+ self.stride = stride
98
+ self.padding = padding
99
+
100
+ if bias:
101
+ self.bias = nn.Parameter(torch.zeros(out_channel))
102
+ else:
103
+ self.bias = None
104
+
105
+ def forward(self, input):
106
+
107
+ return F.conv2d(input, self.weight * self.scale, bias=self.bias, stride=self.stride, padding=self.padding)
108
+
109
+ def __repr__(self):
110
+ return (
111
+ f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},'
112
+ f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})'
113
+ )
114
+
115
+
116
+ class EqualLinear(nn.Module):
117
+ def __init__(self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None):
118
+ super().__init__()
119
+
120
+ self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
121
+
122
+ if bias:
123
+ self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
124
+ else:
125
+ self.bias = None
126
+
127
+ self.activation = activation
128
+
129
+ self.scale = (1 / math.sqrt(in_dim)) * lr_mul
130
+ self.lr_mul = lr_mul
131
+
132
+ def forward(self, input):
133
+
134
+ if self.activation:
135
+ out = F.linear(input, self.weight * self.scale)
136
+ out = fused_leaky_relu(out, self.bias * self.lr_mul)
137
+ else:
138
+ out = F.linear(input, self.weight * self.scale, bias=self.bias * self.lr_mul)
139
+
140
+ return out
141
+
142
+ def __repr__(self):
143
+ return (f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})')
144
+
145
+
146
+ class ConvLayer(nn.Sequential):
147
+ def __init__(
148
+ self,
149
+ in_channel,
150
+ out_channel,
151
+ kernel_size,
152
+ downsample=False,
153
+ blur_kernel=[1, 3, 3, 1],
154
+ bias=True,
155
+ activate=True,
156
+ ):
157
+ layers = []
158
+
159
+ if downsample:
160
+ factor = 2
161
+ p = (len(blur_kernel) - factor) + (kernel_size - 1)
162
+ pad0 = (p + 1) // 2
163
+ pad1 = p // 2
164
+
165
+ layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
166
+
167
+ stride = 2
168
+ self.padding = 0
169
+
170
+ else:
171
+ stride = 1
172
+ self.padding = kernel_size // 2
173
+
174
+ layers.append(EqualConv2d(in_channel, out_channel, kernel_size, padding=self.padding, stride=stride,
175
+ bias=bias and not activate))
176
+
177
+ if activate:
178
+ if bias:
179
+ layers.append(FusedLeakyReLU(out_channel))
180
+ else:
181
+ layers.append(ScaledLeakyReLU(0.2))
182
+
183
+ super().__init__(*layers)
184
+
185
+
186
+ class ResBlock(nn.Module):
187
+ def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
188
+ super().__init__()
189
+
190
+ self.conv1 = ConvLayer(in_channel, in_channel, 3)
191
+ self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)
192
+
193
+ self.skip = ConvLayer(in_channel, out_channel, 1, downsample=True, activate=False, bias=False)
194
+
195
+ def forward(self, input):
196
+ out = self.conv1(input)
197
+ out = self.conv2(out)
198
+
199
+ skip = self.skip(input)
200
+ out = (out + skip) / math.sqrt(2)
201
+
202
+ return out
203
+
204
+
205
+ class EncoderApp(nn.Module):
206
+ def __init__(self, size, w_dim=512):
207
+ super(EncoderApp, self).__init__()
208
+
209
+ channels = {
210
+ 4: 512,
211
+ 8: 512,
212
+ 16: 512,
213
+ 32: 512,
214
+ 64: 256,
215
+ 128: 128,
216
+ 256: 64,
217
+ 512: 32,
218
+ 1024: 16
219
+ }
220
+
221
+ self.w_dim = w_dim
222
+ log_size = int(math.log(size, 2))
223
+
224
+ self.convs = nn.ModuleList()
225
+ self.convs.append(ConvLayer(3, channels[size], 1))
226
+
227
+ in_channel = channels[size]
228
+ for i in range(log_size, 2, -1):
229
+ out_channel = channels[2 ** (i - 1)]
230
+ self.convs.append(ResBlock(in_channel, out_channel))
231
+ in_channel = out_channel
232
+
233
+ self.convs.append(EqualConv2d(in_channel, self.w_dim, 4, padding=0, bias=False))
234
+
235
+ def forward(self, x):
236
+
237
+ res = []
238
+ h = x
239
+ for conv in self.convs:
240
+ h = conv(h)
241
+ res.append(h)
242
+
243
+ return res[-1].squeeze(-1).squeeze(-1), res[::-1][2:]
244
+
245
+
246
+ class Encoder(nn.Module):
247
+ def __init__(self, size, dim=512, dim_motion=20):
248
+ super(Encoder, self).__init__()
249
+
250
+ # appearance network
251
+ self.net_app = EncoderApp(size, dim)
252
+
253
+ # motion network
254
+ fc = [EqualLinear(dim, dim)]
255
+ for i in range(3):
256
+ fc.append(EqualLinear(dim, dim))
257
+
258
+ fc.append(EqualLinear(dim, dim_motion))
259
+ self.fc = nn.Sequential(*fc)
260
+
261
+ def enc_app(self, x):
262
+ h_source = self.net_app(x)
263
+ return h_source
264
+
265
+ def enc_motion(self, x):
266
+ h, _ = self.net_app(x)
267
+ h_motion = self.fc(h)
268
+ return h_motion
269
+
270
+
271
+ class Direction(nn.Module):
272
+ def __init__(self, motion_dim):
273
+ super(Direction, self).__init__()
274
+ self.weight = nn.Parameter(torch.randn(512, motion_dim))
275
+
276
+ def forward(self, input):
277
+
278
+ weight = self.weight + 1e-8
279
+ Q, R = custom_qr(weight)
280
+ if input is None:
281
+ return Q
282
+ else:
283
+ input_diag = torch.diag_embed(input) # alpha, diagonal matrix
284
+ out = torch.matmul(input_diag, Q.T)
285
+ out = torch.sum(out, dim=1)
286
+ return out
287
+
288
+
289
+ class Synthesis(nn.Module):
290
+ def __init__(self, motion_dim):
291
+ super(Synthesis, self).__init__()
292
+ self.direction = Direction(motion_dim)
293
+
294
+
295
+ class Generator(nn.Module):
296
+ def __init__(self, size, style_dim=512, motion_dim=20):
297
+ super().__init__()
298
+
299
+ self.enc = Encoder(size, style_dim, motion_dim)
300
+ self.dec = Synthesis(motion_dim)
301
+
302
+ def get_motion(self, img):
303
+ #motion_feat = self.enc.enc_motion(img)
304
+ motion_feat = torch.utils.checkpoint.checkpoint((self.enc.enc_motion), img, use_reentrant=True)
305
+ with torch.cuda.amp.autocast(dtype=torch.float32):
306
+ motion = self.dec.direction(motion_feat)
307
+ return motion
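A quick check (illustrative only) of `make_kernel` above: a 1D tap list is expanded into a separable 2D kernel via an outer product and normalized to sum to one, which is how the `[1, 3, 3, 1]` blur used by `Blur`/`ConvLayer` is constructed.

```python
import torch

# make_kernel is defined in the diff above
k = make_kernel([1, 3, 3, 1])
print(k.shape)         # torch.Size([4, 4])
print(k.sum().item())  # 1.0
```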
wan/modules/animate/preprocess/UserGuider.md ADDED
@@ -0,0 +1,70 @@
1
+ # Wan-animate Preprocessing User Guide
2
+
3
+ ## 1. Introduction
4
+
5
+
6
+ Wan-animate offers two generation modes: `animation` and `replacement`. While both modes extract the skeleton from the reference video, they each have a distinct preprocessing pipeline.
7
+
8
+ ### 1.1 Animation Mode
9
+
10
+ In this mode, it is highly recommended to enable pose retargeting, especially if the body proportions of the reference and driving characters are dissimilar.
11
+
12
+ - A simplified version of the pose retargeting pipeline is provided to help developers quickly implement this functionality.
13
+
14
+ - **NOTE:** Due to the potential complexity of input data, the results from this simplified retargeting version are NOT guaranteed to be perfect. It is strongly advised to verify the preprocessing results before proceeding.
15
+
16
+ - Community contributions that improve this feature are welcome.
17
+
18
+ ### 1.2 Replacement Mode
19
+
20
+ - Pose retargeting is DISABLED by default in this mode. This is a deliberate choice to account for potential spatial interactions between the character and the environment.
21
+
22
+ - **WARNING**: If there is a significant mismatch in body proportions between the reference and driving characters, artifacts or deformations may appear in the final output.
23
+
24
+ - A simplified version for extracting the character's mask is also provided.
25
+ - **WARNING:** This mask extraction process is designed for **single-person videos ONLY** and may produce incorrect results or fail on multi-person videos (incorrect pose tracking). For multi-person videos, users must either develop their own solution or integrate a suitable open-source tool.
26
+
27
+ ---
28
+
29
+ ## 2. Preprocessing Instructions and Recommendations
30
+
31
+ ### 2.1 Basic Usage
32
+
33
+ - The preprocessing pipeline requires additional models: pose detection (mandatory), plus mask extraction and image editing models (optional, as needed). Place them according to the following directory structure (a small layout-check sketch is included at the end of this guide):
34
+ ```
35
+ /path/to/your/ckpt_path/
36
+ ├── det/
37
+ │ └── yolov10m.onnx
38
+ ├── pose2d/
39
+ │ └── vitpose_h_wholebody.onnx
40
+ ├── sam2/
41
+ │ └── sam2_hiera_large.pt
42
+ └── FLUX.1-Kontext-dev/
43
+ ```
44
+ - `video_path`, `refer_path`, and `save_path` specify the paths of the input driving video, the reference character image, and the directory for the preprocessed results, respectively.
45
+
46
+ - When using `animation` mode, two videos, `src_face.mp4` and `src_pose.mp4`, will be generated in `save_path`. When using `replacement` mode, two additional videos, `src_bg.mp4` and `src_mask.mp4`, will also be generated.
47
+
48
+ - The `resolution_area` parameter determines the resolution used for both preprocessing and the generation model; the target size is specified as a total pixel area rather than a fixed width and height (see the sketch after this list).
49
+
50
+ - The `fps` parameter sets the frame rate used for video processing. A lower frame rate improves generation efficiency but may make the result stutter or look choppy.
51
+
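+ An illustrative sketch of the area-based sizing follows. This is NOT the repository's actual implementation; the helper name `area_to_size` and the rounding to a multiple of 16 are assumptions chosen to reflect a typical video-model constraint.
+
+ ```python
+ import math
+
+ def area_to_size(src_w: int, src_h: int, resolution_area: int, multiple: int = 16):
+     """Scale (src_w, src_h) so that w * h is close to `resolution_area`,
+     keeping the aspect ratio and snapping both sides to a multiple of `multiple`."""
+     scale = math.sqrt(resolution_area / (src_w * src_h))
+     w = max(multiple, round(src_w * scale / multiple) * multiple)
+     h = max(multiple, round(src_h * scale / multiple) * multiple)
+     return w, h
+
+ # e.g. a 1920x1080 driving video with a 1280*720 pixel-area budget
+ print(area_to_size(1920, 1080, 1280 * 720))  # -> (1280, 720)
+ ```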
52
+ ---
53
+
54
+ ### 2.2 Animation Mode
55
+
56
+ - Three options are supported: no pose retargeting, basic pose retargeting, and enhanced pose retargeting based on the `FLUX.1-Kontext-dev` image editing model. They are selected via the `retarget_flag` and `use_flux` parameters.
57
+
58
+ - Basic pose retargeting (enabled by specifying `retarget_flag` alone) requires that both the reference character and the character in the first frame of the driving video are in a front-facing, outstretched pose.
59
+
60
+ - Otherwise, we recommend enhanced pose retargeting by specifying both `retarget_flag` and `use_flux`. **NOTE:** Due to the limited capabilities of `FLUX.1-Kontext-dev`, the expected result is NOT guaranteed (e.g., consistency may not be maintained, or the pose may be incorrect). It is recommended to check both the intermediate results and the final generated pose video; both are stored in `save_path` (a small preview sketch follows below). Users can also substitute a stronger image editing model, or tune the Flux prompts on their own.
61
+
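+ A minimal sketch for eyeballing the preprocessing output is shown below. It only assumes the output file names listed in Section 2.1; the preview file name `pose_preview.png` is arbitrary.
+
+ ```python
+ import os
+ import cv2
+
+ save_path = "/path/to/save_path"  # the same directory passed to preprocessing
+ cap = cv2.VideoCapture(os.path.join(save_path, "src_pose.mp4"))
+ ok, frame = cap.read()
+ cap.release()
+ if ok:
+     # dump the first pose frame for a quick visual check
+     cv2.imwrite(os.path.join(save_path, "pose_preview.png"), frame)
+ else:
+     print("src_pose.mp4 is missing or empty - re-run preprocessing")
+ ```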
62
+ ---
63
+
64
+ ### 2.3 Replacement Mode
65
+
66
+ - Specify `replace_flag` to enable data preprocessing for this mode. Preprocessing will additionally produce a mask for the character in the video; its size and shape can be adjusted with the parameters below.
67
+ - `iterations` and `k` enlarge the mask so that it covers more area (see the dilation sketch at the end of this section).
68
+ - `w_len` and `h_len` adjust the mask's shape: smaller values produce a coarser outline, while larger values produce a finer one.
69
+
70
+ - A smaller mask with a finer contour preserves more of the original background, but may constrain the area in which the character can be generated (given possible appearance differences between characters, this can cause some shape leakage). A larger, coarser mask gives character generation more freedom and better consistency, but because it covers more of the background it may hurt background consistency. We recommend adjusting these parameters to suit your specific input data.
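+ The exact way `iterations`, `k`, `w_len`, and `h_len` are consumed is defined in the preprocessing code; the snippet below is only a generic sketch of the underlying idea for growing the mask, using standard OpenCV morphological dilation.
+
+ ```python
+ import cv2
+ import numpy as np
+
+ def grow_mask(mask: np.ndarray, k: int = 7, iterations: int = 3) -> np.ndarray:
+     """Enlarge a binary (0/255) character mask by morphological dilation.
+     A larger kernel size `k` or more `iterations` covers more area around the character."""
+     kernel = np.ones((k, k), np.uint8)
+     return cv2.dilate(mask, kernel, iterations=iterations)
+ ```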
wan/modules/animate/preprocess/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ from .process_pipepline import ProcessPipeline
3
+ from .video_predictor import SAM2VideoPredictor
wan/modules/animate/preprocess/human_visualization.py ADDED
@@ -0,0 +1,1357 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import os
3
+ import cv2
4
+ import time
5
+ import math
6
+ import matplotlib
7
+ import matplotlib.pyplot as plt
8
+ import numpy as np
9
+ from typing import Dict, List
10
+ import random
11
+ from pose2d_utils import AAPoseMeta
12
+
13
+
14
+ def draw_handpose(canvas, keypoints, hand_score_th=0.6):
15
+ """
16
+ Draw keypoints and connections representing hand pose on a given canvas.
17
+
18
+ Args:
19
+ canvas (np.ndarray): A 3D numpy array representing the canvas (image) on which to draw the hand pose.
20
+ keypoints (List[Keypoint]| None): A list of Keypoint objects representing the hand keypoints to be drawn
21
+ or None if no keypoints are present.
22
+
23
+ Returns:
24
+ np.ndarray: A 3D numpy array representing the modified canvas with the drawn hand pose.
25
+
26
+ Note:
27
+ The function expects the x and y coordinates of the keypoints to be normalized between 0 and 1.
28
+ """
29
+ eps = 0.01
30
+
31
+ H, W, C = canvas.shape
32
+ stickwidth = max(int(min(H, W) / 200), 1)
33
+
34
+ edges = [
35
+ [0, 1],
36
+ [1, 2],
37
+ [2, 3],
38
+ [3, 4],
39
+ [0, 5],
40
+ [5, 6],
41
+ [6, 7],
42
+ [7, 8],
43
+ [0, 9],
44
+ [9, 10],
45
+ [10, 11],
46
+ [11, 12],
47
+ [0, 13],
48
+ [13, 14],
49
+ [14, 15],
50
+ [15, 16],
51
+ [0, 17],
52
+ [17, 18],
53
+ [18, 19],
54
+ [19, 20],
55
+ ]
56
+
57
+ for ie, (e1, e2) in enumerate(edges):
58
+ k1 = keypoints[e1]
59
+ k2 = keypoints[e2]
60
+ if k1 is None or k2 is None:
61
+ continue
62
+ if k1[2] < hand_score_th or k2[2] < hand_score_th:
63
+ continue
64
+
65
+ x1 = int(k1[0])
66
+ y1 = int(k1[1])
67
+ x2 = int(k2[0])
68
+ y2 = int(k2[1])
69
+ if x1 > eps and y1 > eps and x2 > eps and y2 > eps:
70
+ cv2.line(
71
+ canvas,
72
+ (x1, y1),
73
+ (x2, y2),
74
+ matplotlib.colors.hsv_to_rgb([ie / float(len(edges)), 1.0, 1.0]) * 255,
75
+ thickness=stickwidth,
76
+ )
77
+
78
+ for keypoint in keypoints:
79
+
80
+ if keypoint is None:
81
+ continue
82
+ if keypoint[2] < hand_score_th:
83
+ continue
84
+
85
+ x, y = keypoint[0], keypoint[1]
86
+ x = int(x)
87
+ y = int(y)
88
+ if x > eps and y > eps:
89
+ cv2.circle(canvas, (x, y), stickwidth, (0, 0, 255), thickness=-1)
90
+ return canvas
91
+
92
+
93
+ def draw_handpose_new(canvas, keypoints, stickwidth_type='v2', hand_score_th=0.6):
94
+ """
95
+ Draw keypoints and connections representing hand pose on a given canvas.
96
+
97
+ Args:
98
+ canvas (np.ndarray): A 3D numpy array representing the canvas (image) on which to draw the hand pose.
99
+ keypoints (List[Keypoint]| None): A list of Keypoint objects representing the hand keypoints to be drawn
100
+ or None if no keypoints are present.
101
+
102
+ Returns:
103
+ np.ndarray: A 3D numpy array representing the modified canvas with the drawn hand pose.
104
+
105
+ Note:
106
+ The function expects the x and y coordinates of the keypoints to be normalized between 0 and 1.
107
+ """
108
+ eps = 0.01
109
+
110
+ H, W, C = canvas.shape
111
+ if stickwidth_type == 'v1':
112
+ stickwidth = max(int(min(H, W) / 200), 1)
113
+ elif stickwidth_type == 'v2':
114
+ stickwidth = max(max(int(min(H, W) / 200) - 1, 1) // 2, 1)
115
+
116
+ edges = [
117
+ [0, 1],
118
+ [1, 2],
119
+ [2, 3],
120
+ [3, 4],
121
+ [0, 5],
122
+ [5, 6],
123
+ [6, 7],
124
+ [7, 8],
125
+ [0, 9],
126
+ [9, 10],
127
+ [10, 11],
128
+ [11, 12],
129
+ [0, 13],
130
+ [13, 14],
131
+ [14, 15],
132
+ [15, 16],
133
+ [0, 17],
134
+ [17, 18],
135
+ [18, 19],
136
+ [19, 20],
137
+ ]
138
+
139
+ for ie, (e1, e2) in enumerate(edges):
140
+ k1 = keypoints[e1]
141
+ k2 = keypoints[e2]
142
+ if k1 is None or k2 is None:
143
+ continue
144
+ if k1[2] < hand_score_th or k2[2] < hand_score_th:
145
+ continue
146
+
147
+ x1 = int(k1[0])
148
+ y1 = int(k1[1])
149
+ x2 = int(k2[0])
150
+ y2 = int(k2[1])
151
+ if x1 > eps and y1 > eps and x2 > eps and y2 > eps:
152
+ cv2.line(
153
+ canvas,
154
+ (x1, y1),
155
+ (x2, y2),
156
+ matplotlib.colors.hsv_to_rgb([ie / float(len(edges)), 1.0, 1.0]) * 255,
157
+ thickness=stickwidth,
158
+ )
159
+
160
+ for keypoint in keypoints:
161
+
162
+ if keypoint is None:
163
+ continue
164
+ if keypoint[2] < hand_score_th:
165
+ continue
166
+
167
+ x, y = keypoint[0], keypoint[1]
168
+ x = int(x)
169
+ y = int(y)
170
+ if x > eps and y > eps:
171
+ cv2.circle(canvas, (x, y), stickwidth, (0, 0, 255), thickness=-1)
172
+ return canvas
173
+
174
+
175
+ def draw_ellipse_by_2kp(img, keypoint1, keypoint2, color, threshold=0.6):
176
+ H, W, C = img.shape
177
+ stickwidth = max(int(min(H, W) / 200), 1)
178
+
179
+ if keypoint1[-1] < threshold or keypoint2[-1] < threshold:
180
+ return img
181
+
182
+ Y = np.array([keypoint1[0], keypoint2[0]])
183
+ X = np.array([keypoint1[1], keypoint2[1]])
184
+ mX = np.mean(X)
185
+ mY = np.mean(Y)
186
+ length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
187
+ angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
188
+ polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1)
189
+ cv2.fillConvexPoly(img, polygon, [int(float(c) * 0.6) for c in color])
190
+ return img
191
+
192
+
193
+ def split_pose2d_kps_to_aa(kp2ds: np.ndarray) -> List[np.ndarray]:
194
+ """Convert the 133 keypoints from pose2d to body and hands keypoints.
195
+
196
+ Args:
197
+ kp2ds (np.ndarray): [133, 2]
198
+
199
+ Returns:
200
+ List[np.ndarray]: _description_
201
+ """
202
+ kp2ds_body = (
203
+ kp2ds[[0, 6, 6, 8, 10, 5, 7, 9, 12, 14, 16, 11, 13, 15, 2, 1, 4, 3, 17, 20]]
204
+ + kp2ds[[0, 5, 6, 8, 10, 5, 7, 9, 12, 14, 16, 11, 13, 15, 2, 1, 4, 3, 18, 21]]
205
+ ) / 2
206
+ kp2ds_lhand = kp2ds[91:112]
207
+ kp2ds_rhand = kp2ds[112:133]
208
+ return kp2ds_body.copy(), kp2ds_lhand.copy(), kp2ds_rhand.copy()
209
+
210
+
211
+ def draw_aapose_by_meta(img, meta: AAPoseMeta, threshold=0.5, stick_width_norm=200, draw_hand=True, draw_head=True):
212
+ kp2ds = np.concatenate([meta.kps_body, meta.kps_body_p[:, None]], axis=1)
213
+ kp2ds_lhand = np.concatenate([meta.kps_lhand, meta.kps_lhand_p[:, None]], axis=1)
214
+ kp2ds_rhand = np.concatenate([meta.kps_rhand, meta.kps_rhand_p[:, None]], axis=1)
215
+ pose_img = draw_aapose(img, kp2ds, threshold, kp2ds_lhand=kp2ds_lhand, kp2ds_rhand=kp2ds_rhand, stick_width_norm=stick_width_norm, draw_hand=draw_hand, draw_head=draw_head)
216
+ return pose_img
217
+
218
+ def draw_aapose_by_meta_new(img, meta: AAPoseMeta, threshold=0.5, stickwidth_type='v2', draw_hand=True, draw_head=True):
219
+ kp2ds = np.concatenate([meta.kps_body, meta.kps_body_p[:, None]], axis=1)
220
+ kp2ds_lhand = np.concatenate([meta.kps_lhand, meta.kps_lhand_p[:, None]], axis=1)
221
+ kp2ds_rhand = np.concatenate([meta.kps_rhand, meta.kps_rhand_p[:, None]], axis=1)
222
+ pose_img = draw_aapose_new(img, kp2ds, threshold, kp2ds_lhand=kp2ds_lhand, kp2ds_rhand=kp2ds_rhand,
223
+ stickwidth_type=stickwidth_type, draw_hand=draw_hand, draw_head=draw_head)
224
+ return pose_img
225
+
226
+ def draw_hand_by_meta(img, meta: AAPoseMeta, threshold=0.5, stick_width_norm=200):
227
+ kp2ds = np.concatenate([meta.kps_body, meta.kps_body_p[:, None] * 0], axis=1)
228
+ kp2ds_lhand = np.concatenate([meta.kps_lhand, meta.kps_lhand_p[:, None]], axis=1)
229
+ kp2ds_rhand = np.concatenate([meta.kps_rhand, meta.kps_rhand_p[:, None]], axis=1)
230
+ pose_img = draw_aapose(img, kp2ds, threshold, kp2ds_lhand=kp2ds_lhand, kp2ds_rhand=kp2ds_rhand, stick_width_norm=stick_width_norm, draw_hand=True, draw_head=False)
231
+ return pose_img
232
+
233
+
234
+ def draw_aaface_by_meta(img, meta: AAPoseMeta, threshold=0.5, stick_width_norm=200, draw_hand=False, draw_head=True):
235
+ kp2ds = np.concatenate([meta.kps_body, meta.kps_body_p[:, None]], axis=1)
236
+ # kp2ds_lhand = np.concatenate([meta.kps_lhand, meta.kps_lhand_p[:, None]], axis=1)
237
+ # kp2ds_rhand = np.concatenate([meta.kps_rhand, meta.kps_rhand_p[:, None]], axis=1)
238
+ pose_img = draw_M(img, kp2ds, threshold, kp2ds_lhand=None, kp2ds_rhand=None, stick_width_norm=stick_width_norm, draw_hand=draw_hand, draw_head=draw_head)
239
+ return pose_img
240
+
241
+
242
+ def draw_aanose_by_meta(img, meta: AAPoseMeta, threshold=0.5, stick_width_norm=100, draw_hand=False):
243
+ kp2ds = np.concatenate([meta.kps_body, meta.kps_body_p[:, None]], axis=1)
244
+ # kp2ds_lhand = np.concatenate([meta.kps_lhand, meta.kps_lhand_p[:, None]], axis=1)
245
+ # kp2ds_rhand = np.concatenate([meta.kps_rhand, meta.kps_rhand_p[:, None]], axis=1)
246
+ pose_img = draw_nose(img, kp2ds, threshold, kp2ds_lhand=None, kp2ds_rhand=None, stick_width_norm=stick_width_norm, draw_hand=draw_hand)
247
+ return pose_img
248
+
249
+
250
+ def gen_face_motion_seq(img, metas: List[AAPoseMeta], threshold=0.5, stick_width_norm=200):
251
+
252
+ return
253
+
254
+
255
+ def draw_M(
256
+ img,
257
+ kp2ds,
258
+ threshold=0.6,
259
+ data_to_json=None,
260
+ idx=-1,
261
+ kp2ds_lhand=None,
262
+ kp2ds_rhand=None,
263
+ draw_hand=False,
264
+ stick_width_norm=200,
265
+ draw_head=True
266
+ ):
267
+ """
268
+ Draw keypoints and connections representing hand pose on a given canvas.
269
+
270
+ Args:
271
+ canvas (np.ndarray): A 3D numpy array representing the canvas (image) on which to draw the hand pose.
272
+ keypoints (List[Keypoint]| None): A list of Keypoint objects representing the hand keypoints to be drawn
273
+ or None if no keypoints are present.
274
+
275
+ Returns:
276
+ np.ndarray: A 3D numpy array representing the modified canvas with the drawn hand pose.
277
+
278
+ Note:
279
+ The function expects the x and y coordinates of the keypoints to be normalized between 0 and 1.
280
+ """
281
+
282
+ new_kep_list = [
283
+ "Nose",
284
+ "Neck",
285
+ "RShoulder",
286
+ "RElbow",
287
+ "RWrist", # No.4
288
+ "LShoulder",
289
+ "LElbow",
290
+ "LWrist", # No.7
291
+ "RHip",
292
+ "RKnee",
293
+ "RAnkle", # No.10
294
+ "LHip",
295
+ "LKnee",
296
+ "LAnkle", # No.13
297
+ "REye",
298
+ "LEye",
299
+ "REar",
300
+ "LEar",
301
+ "LToe",
302
+ "RToe",
303
+ ]
304
+ # kp2ds_body = (kp2ds.copy()[[0, 6, 6, 8, 10, 5, 7, 9, 12, 14, 16, 11, 13, 15, 2, 1, 4, 3, 17, 20]] + \
305
+ # kp2ds.copy()[[0, 5, 6, 8, 10, 5, 7, 9, 12, 14, 16, 11, 13, 15, 2, 1, 4, 3, 18, 21]]) / 2
306
+ kp2ds = kp2ds.copy()
307
+ # import ipdb; ipdb.set_trace()
308
+ kp2ds[[1,2,3,4,5,6,7,8,9,10,11,12,13,18,19], 2] = 0
309
+ if not draw_head:
310
+ kp2ds[[0,14,15,16,17], 2] = 0
311
+ kp2ds_body = kp2ds
312
+ # kp2ds_body = kp2ds_body[:18]
313
+
314
+ # kp2ds_lhand = kp2ds.copy()[91:112]
315
+ # kp2ds_rhand = kp2ds.copy()[112:133]
316
+
317
+ limbSeq = [
318
+ # [2, 3],
319
+ # [2, 6], # shoulders
320
+ # [3, 4],
321
+ # [4, 5], # left arm
322
+ # [6, 7],
323
+ # [7, 8], # right arm
324
+ # [2, 9],
325
+ # [9, 10],
326
+ # [10, 11], # right leg
327
+ # [2, 12],
328
+ # [12, 13],
329
+ # [13, 14], # left leg
330
+ # [2, 1],
331
+ [1, 15],
332
+ [15, 17],
333
+ [1, 16],
334
+ [16, 18], # face (nose, eyes, ears)
335
+ # [14, 19],
336
+ # [11, 20], # foot
337
+ ]
338
+
339
+ colors = [
340
+ # [255, 0, 0],
341
+ # [255, 85, 0],
342
+ # [255, 170, 0],
343
+ # [255, 255, 0],
344
+ # [170, 255, 0],
345
+ # [85, 255, 0],
346
+ # [0, 255, 0],
347
+ # [0, 255, 85],
348
+ # [0, 255, 170],
349
+ # [0, 255, 255],
350
+ # [0, 170, 255],
351
+ # [0, 85, 255],
352
+ # [0, 0, 255],
353
+ # [85, 0, 255],
354
+ [170, 0, 255],
355
+ [255, 0, 255],
356
+ [255, 0, 170],
357
+ [255, 0, 85],
358
+ # foot
359
+ # [200, 200, 0],
360
+ # [100, 100, 0],
361
+ ]
362
+
363
+ H, W, C = img.shape
364
+ stickwidth = max(int(min(H, W) / stick_width_norm), 1)
365
+
366
+ for _idx, ((k1_index, k2_index), color) in enumerate(zip(limbSeq, colors)):
367
+ keypoint1 = kp2ds_body[k1_index - 1]
368
+ keypoint2 = kp2ds_body[k2_index - 1]
369
+
370
+ if keypoint1[-1] < threshold or keypoint2[-1] < threshold:
371
+ continue
372
+
373
+ Y = np.array([keypoint1[0], keypoint2[0]])
374
+ X = np.array([keypoint1[1], keypoint2[1]])
375
+ mX = np.mean(X)
376
+ mY = np.mean(Y)
377
+ length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
378
+ angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
379
+ polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1)
380
+ cv2.fillConvexPoly(img, polygon, [int(float(c) * 0.6) for c in color])
381
+
382
+ for _idx, (keypoint, color) in enumerate(zip(kp2ds_body, colors)):
383
+ if keypoint[-1] < threshold:
384
+ continue
385
+ x, y = keypoint[0], keypoint[1]
386
+ # cv2.circle(canvas, (int(x), int(y)), 4, color, thickness=-1)
387
+ cv2.circle(img, (int(x), int(y)), stickwidth, color, thickness=-1)
388
+
389
+ if draw_hand:
390
+ img = draw_handpose(img, kp2ds_lhand, hand_score_th=threshold)
391
+ img = draw_handpose(img, kp2ds_rhand, hand_score_th=threshold)
392
+
393
+ kp2ds_body[:, 0] /= W
394
+ kp2ds_body[:, 1] /= H
395
+
396
+ if data_to_json is not None:
397
+ if idx == -1:
398
+ data_to_json.append(
399
+ {
400
+ "image_id": "frame_{:05d}.jpg".format(len(data_to_json) + 1),
401
+ "height": H,
402
+ "width": W,
403
+ "category_id": 1,
404
+ "keypoints_body": kp2ds_body.tolist(),
405
+ "keypoints_left_hand": kp2ds_lhand.tolist(),
406
+ "keypoints_right_hand": kp2ds_rhand.tolist(),
407
+ }
408
+ )
409
+ else:
410
+ data_to_json[idx] = {
411
+ "image_id": "frame_{:05d}.jpg".format(idx + 1),
412
+ "height": H,
413
+ "width": W,
414
+ "category_id": 1,
415
+ "keypoints_body": kp2ds_body.tolist(),
416
+ "keypoints_left_hand": kp2ds_lhand.tolist(),
417
+ "keypoints_right_hand": kp2ds_rhand.tolist(),
418
+ }
419
+ return img
420
+
421
+
422
+ def draw_nose(
423
+ img,
424
+ kp2ds,
425
+ threshold=0.6,
426
+ data_to_json=None,
427
+ idx=-1,
428
+ kp2ds_lhand=None,
429
+ kp2ds_rhand=None,
430
+ draw_hand=False,
431
+ stick_width_norm=200,
432
+ ):
433
+ """
434
+ Draw keypoints and connections representing hand pose on a given canvas.
435
+
436
+ Args:
437
+ canvas (np.ndarray): A 3D numpy array representing the canvas (image) on which to draw the hand pose.
438
+ keypoints (List[Keypoint]| None): A list of Keypoint objects representing the hand keypoints to be drawn
439
+ or None if no keypoints are present.
440
+
441
+ Returns:
442
+ np.ndarray: A 3D numpy array representing the modified canvas with the drawn hand pose.
443
+
444
+ Note:
445
+ The function expects the x and y coordinates of the keypoints to be normalized between 0 and 1.
446
+ """
447
+
448
+ new_kep_list = [
449
+ "Nose",
450
+ "Neck",
451
+ "RShoulder",
452
+ "RElbow",
453
+ "RWrist", # No.4
454
+ "LShoulder",
455
+ "LElbow",
456
+ "LWrist", # No.7
457
+ "RHip",
458
+ "RKnee",
459
+ "RAnkle", # No.10
460
+ "LHip",
461
+ "LKnee",
462
+ "LAnkle", # No.13
463
+ "REye",
464
+ "LEye",
465
+ "REar",
466
+ "LEar",
467
+ "LToe",
468
+ "RToe",
469
+ ]
470
+ # kp2ds_body = (kp2ds.copy()[[0, 6, 6, 8, 10, 5, 7, 9, 12, 14, 16, 11, 13, 15, 2, 1, 4, 3, 17, 20]] + \
471
+ # kp2ds.copy()[[0, 5, 6, 8, 10, 5, 7, 9, 12, 14, 16, 11, 13, 15, 2, 1, 4, 3, 18, 21]]) / 2
472
+ kp2ds = kp2ds.copy()
473
+ kp2ds[1:, 2] = 0
474
+ # kp2ds[0, 2] = 1
475
+ kp2ds_body = kp2ds
476
+ # kp2ds_body = kp2ds_body[:18]
477
+
478
+ # kp2ds_lhand = kp2ds.copy()[91:112]
479
+ # kp2ds_rhand = kp2ds.copy()[112:133]
480
+
481
+ limbSeq = [
482
+ # [2, 3],
483
+ # [2, 6], # shoulders
484
+ # [3, 4],
485
+ # [4, 5], # left arm
486
+ # [6, 7],
487
+ # [7, 8], # right arm
488
+ # [2, 9],
489
+ # [9, 10],
490
+ # [10, 11], # right leg
491
+ # [2, 12],
492
+ # [12, 13],
493
+ # [13, 14], # left leg
494
+ # [2, 1],
495
+ [1, 15],
496
+ [15, 17],
497
+ [1, 16],
498
+ [16, 18], # face (nose, eyes, ears)
499
+ # [14, 19],
500
+ # [11, 20], # foot
501
+ ]
502
+
503
+ colors = [
504
+ # [255, 0, 0],
505
+ # [255, 85, 0],
506
+ # [255, 170, 0],
507
+ # [255, 255, 0],
508
+ # [170, 255, 0],
509
+ # [85, 255, 0],
510
+ # [0, 255, 0],
511
+ # [0, 255, 85],
512
+ # [0, 255, 170],
513
+ # [0, 255, 255],
514
+ # [0, 170, 255],
515
+ # [0, 85, 255],
516
+ # [0, 0, 255],
517
+ # [85, 0, 255],
518
+ [170, 0, 255],
519
+ # [255, 0, 255],
520
+ # [255, 0, 170],
521
+ # [255, 0, 85],
522
+ # foot
523
+ # [200, 200, 0],
524
+ # [100, 100, 0],
525
+ ]
526
+
527
+ H, W, C = img.shape
528
+ stickwidth = max(int(min(H, W) / stick_width_norm), 1)
529
+
530
+ # for _idx, ((k1_index, k2_index), color) in enumerate(zip(limbSeq, colors)):
531
+ # keypoint1 = kp2ds_body[k1_index - 1]
532
+ # keypoint2 = kp2ds_body[k2_index - 1]
533
+
534
+ # if keypoint1[-1] < threshold or keypoint2[-1] < threshold:
535
+ # continue
536
+
537
+ # Y = np.array([keypoint1[0], keypoint2[0]])
538
+ # X = np.array([keypoint1[1], keypoint2[1]])
539
+ # mX = np.mean(X)
540
+ # mY = np.mean(Y)
541
+ # length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
542
+ # angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
543
+ # polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1)
544
+ # cv2.fillConvexPoly(img, polygon, [int(float(c) * 0.6) for c in color])
545
+
546
+ for _idx, (keypoint, color) in enumerate(zip(kp2ds_body, colors)):
547
+ if keypoint[-1] < threshold:
548
+ continue
549
+ x, y = keypoint[0], keypoint[1]
550
+ # cv2.circle(canvas, (int(x), int(y)), 4, color, thickness=-1)
551
+ cv2.circle(img, (int(x), int(y)), stickwidth, color, thickness=-1)
552
+
553
+ if draw_hand:
554
+ img = draw_handpose(img, kp2ds_lhand, hand_score_th=threshold)
555
+ img = draw_handpose(img, kp2ds_rhand, hand_score_th=threshold)
556
+
557
+ kp2ds_body[:, 0] /= W
558
+ kp2ds_body[:, 1] /= H
559
+
560
+ if data_to_json is not None:
561
+ if idx == -1:
562
+ data_to_json.append(
563
+ {
564
+ "image_id": "frame_{:05d}.jpg".format(len(data_to_json) + 1),
565
+ "height": H,
566
+ "width": W,
567
+ "category_id": 1,
568
+ "keypoints_body": kp2ds_body.tolist(),
569
+ "keypoints_left_hand": kp2ds_lhand.tolist(),
570
+ "keypoints_right_hand": kp2ds_rhand.tolist(),
571
+ }
572
+ )
573
+ else:
574
+ data_to_json[idx] = {
575
+ "image_id": "frame_{:05d}.jpg".format(idx + 1),
576
+ "height": H,
577
+ "width": W,
578
+ "category_id": 1,
579
+ "keypoints_body": kp2ds_body.tolist(),
580
+ "keypoints_left_hand": kp2ds_lhand.tolist(),
581
+ "keypoints_right_hand": kp2ds_rhand.tolist(),
582
+ }
583
+ return img
584
+
585
+
586
+ def draw_aapose(
587
+ img,
588
+ kp2ds,
589
+ threshold=0.6,
590
+ data_to_json=None,
591
+ idx=-1,
592
+ kp2ds_lhand=None,
593
+ kp2ds_rhand=None,
594
+ draw_hand=False,
595
+ stick_width_norm=200,
596
+ draw_head=True
597
+ ):
598
+ """
599
+ Draw keypoints and connections representing hand pose on a given canvas.
600
+
601
+ Args:
602
+ canvas (np.ndarray): A 3D numpy array representing the canvas (image) on which to draw the hand pose.
603
+ keypoints (List[Keypoint]| None): A list of Keypoint objects representing the hand keypoints to be drawn
604
+ or None if no keypoints are present.
605
+
606
+ Returns:
607
+ np.ndarray: A 3D numpy array representing the modified canvas with the drawn hand pose.
608
+
609
+ Note:
610
+ The function expects the x and y coordinates of the keypoints to be normalized between 0 and 1.
611
+ """
612
+
613
+ new_kep_list = [
614
+ "Nose",
615
+ "Neck",
616
+ "RShoulder",
617
+ "RElbow",
618
+ "RWrist", # No.4
619
+ "LShoulder",
620
+ "LElbow",
621
+ "LWrist", # No.7
622
+ "RHip",
623
+ "RKnee",
624
+ "RAnkle", # No.10
625
+ "LHip",
626
+ "LKnee",
627
+ "LAnkle", # No.13
628
+ "REye",
629
+ "LEye",
630
+ "REar",
631
+ "LEar",
632
+ "LToe",
633
+ "RToe",
634
+ ]
635
+ # kp2ds_body = (kp2ds.copy()[[0, 6, 6, 8, 10, 5, 7, 9, 12, 14, 16, 11, 13, 15, 2, 1, 4, 3, 17, 20]] + \
636
+ # kp2ds.copy()[[0, 5, 6, 8, 10, 5, 7, 9, 12, 14, 16, 11, 13, 15, 2, 1, 4, 3, 18, 21]]) / 2
637
+ kp2ds = kp2ds.copy()
638
+ if not draw_head:
639
+ kp2ds[[0,14,15,16,17], 2] = 0
640
+ kp2ds_body = kp2ds
641
+
642
+ # kp2ds_lhand = kp2ds.copy()[91:112]
643
+ # kp2ds_rhand = kp2ds.copy()[112:133]
644
+
645
+ limbSeq = [
646
+ [2, 3],
647
+ [2, 6], # shoulders
648
+ [3, 4],
649
+ [4, 5], # left arm
650
+ [6, 7],
651
+ [7, 8], # right arm
652
+ [2, 9],
653
+ [9, 10],
654
+ [10, 11], # right leg
655
+ [2, 12],
656
+ [12, 13],
657
+ [13, 14], # left leg
658
+ [2, 1],
659
+ [1, 15],
660
+ [15, 17],
661
+ [1, 16],
662
+ [16, 18], # face (nose, eyes, ears)
663
+ [14, 19],
664
+ [11, 20], # foot
665
+ ]
666
+
667
+ colors = [
668
+ [255, 0, 0],
669
+ [255, 85, 0],
670
+ [255, 170, 0],
671
+ [255, 255, 0],
672
+ [170, 255, 0],
673
+ [85, 255, 0],
674
+ [0, 255, 0],
675
+ [0, 255, 85],
676
+ [0, 255, 170],
677
+ [0, 255, 255],
678
+ [0, 170, 255],
679
+ [0, 85, 255],
680
+ [0, 0, 255],
681
+ [85, 0, 255],
682
+ [170, 0, 255],
683
+ [255, 0, 255],
684
+ [255, 0, 170],
685
+ [255, 0, 85],
686
+ # foot
687
+ [200, 200, 0],
688
+ [100, 100, 0],
689
+ ]
690
+
691
+ H, W, C = img.shape
692
+ stickwidth = max(int(min(H, W) / stick_width_norm), 1)
693
+
694
+ for _idx, ((k1_index, k2_index), color) in enumerate(zip(limbSeq, colors)):
695
+ keypoint1 = kp2ds_body[k1_index - 1]
696
+ keypoint2 = kp2ds_body[k2_index - 1]
697
+
698
+ if keypoint1[-1] < threshold or keypoint2[-1] < threshold:
699
+ continue
700
+
701
+ Y = np.array([keypoint1[0], keypoint2[0]])
702
+ X = np.array([keypoint1[1], keypoint2[1]])
703
+ mX = np.mean(X)
704
+ mY = np.mean(Y)
705
+ length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
706
+ angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
707
+ polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1)
708
+ cv2.fillConvexPoly(img, polygon, [int(float(c) * 0.6) for c in color])
709
+
710
+ for _idx, (keypoint, color) in enumerate(zip(kp2ds_body, colors)):
711
+ if keypoint[-1] < threshold:
712
+ continue
713
+ x, y = keypoint[0], keypoint[1]
714
+ # cv2.circle(canvas, (int(x), int(y)), 4, color, thickness=-1)
715
+ cv2.circle(img, (int(x), int(y)), stickwidth, color, thickness=-1)
716
+
717
+ if draw_hand:
718
+ img = draw_handpose(img, kp2ds_lhand, hand_score_th=threshold)
719
+ img = draw_handpose(img, kp2ds_rhand, hand_score_th=threshold)
720
+
721
+ kp2ds_body[:, 0] /= W
722
+ kp2ds_body[:, 1] /= H
723
+
724
+ if data_to_json is not None:
725
+ if idx == -1:
726
+ data_to_json.append(
727
+ {
728
+ "image_id": "frame_{:05d}.jpg".format(len(data_to_json) + 1),
729
+ "height": H,
730
+ "width": W,
731
+ "category_id": 1,
732
+ "keypoints_body": kp2ds_body.tolist(),
733
+ "keypoints_left_hand": kp2ds_lhand.tolist(),
734
+ "keypoints_right_hand": kp2ds_rhand.tolist(),
735
+ }
736
+ )
737
+ else:
738
+ data_to_json[idx] = {
739
+ "image_id": "frame_{:05d}.jpg".format(idx + 1),
740
+ "height": H,
741
+ "width": W,
742
+ "category_id": 1,
743
+ "keypoints_body": kp2ds_body.tolist(),
744
+ "keypoints_left_hand": kp2ds_lhand.tolist(),
745
+ "keypoints_right_hand": kp2ds_rhand.tolist(),
746
+ }
747
+ return img
748
+
749
+
750
+ def draw_aapose_new(
751
+ img,
752
+ kp2ds,
753
+ threshold=0.6,
754
+ data_to_json=None,
755
+ idx=-1,
756
+ kp2ds_lhand=None,
757
+ kp2ds_rhand=None,
758
+ draw_hand=False,
759
+ stickwidth_type='v2',
760
+ draw_head=True
761
+ ):
762
+ """
763
+ Draw keypoints and connections representing hand pose on a given canvas.
764
+
765
+ Args:
766
+ canvas (np.ndarray): A 3D numpy array representing the canvas (image) on which to draw the hand pose.
767
+ keypoints (List[Keypoint]| None): A list of Keypoint objects representing the hand keypoints to be drawn
768
+ or None if no keypoints are present.
769
+
770
+ Returns:
771
+ np.ndarray: A 3D numpy array representing the modified canvas with the drawn hand pose.
772
+
773
+ Note:
774
+ The function expects the x and y coordinates of the keypoints to be normalized between 0 and 1.
775
+ """
776
+
777
+ new_kep_list = [
778
+ "Nose",
779
+ "Neck",
780
+ "RShoulder",
781
+ "RElbow",
782
+ "RWrist", # No.4
783
+ "LShoulder",
784
+ "LElbow",
785
+ "LWrist", # No.7
786
+ "RHip",
787
+ "RKnee",
788
+ "RAnkle", # No.10
789
+ "LHip",
790
+ "LKnee",
791
+ "LAnkle", # No.13
792
+ "REye",
793
+ "LEye",
794
+ "REar",
795
+ "LEar",
796
+ "LToe",
797
+ "RToe",
798
+ ]
799
+ # kp2ds_body = (kp2ds.copy()[[0, 6, 6, 8, 10, 5, 7, 9, 12, 14, 16, 11, 13, 15, 2, 1, 4, 3, 17, 20]] + \
800
+ # kp2ds.copy()[[0, 5, 6, 8, 10, 5, 7, 9, 12, 14, 16, 11, 13, 15, 2, 1, 4, 3, 18, 21]]) / 2
801
+ kp2ds = kp2ds.copy()
802
+ if not draw_head:
803
+ kp2ds[[0,14,15,16,17], 2] = 0
804
+ kp2ds_body = kp2ds
805
+
806
+ # kp2ds_lhand = kp2ds.copy()[91:112]
807
+ # kp2ds_rhand = kp2ds.copy()[112:133]
808
+
809
+ limbSeq = [
810
+ [2, 3],
811
+ [2, 6], # shoulders
812
+ [3, 4],
813
+ [4, 5], # left arm
814
+ [6, 7],
815
+ [7, 8], # right arm
816
+ [2, 9],
817
+ [9, 10],
818
+ [10, 11], # right leg
819
+ [2, 12],
820
+ [12, 13],
821
+ [13, 14], # left leg
822
+ [2, 1],
823
+ [1, 15],
824
+ [15, 17],
825
+ [1, 16],
826
+ [16, 18], # face (nose, eyes, ears)
827
+ [14, 19],
828
+ [11, 20], # foot
829
+ ]
830
+
831
+ colors = [
832
+ [255, 0, 0],
833
+ [255, 85, 0],
834
+ [255, 170, 0],
835
+ [255, 255, 0],
836
+ [170, 255, 0],
837
+ [85, 255, 0],
838
+ [0, 255, 0],
839
+ [0, 255, 85],
840
+ [0, 255, 170],
841
+ [0, 255, 255],
842
+ [0, 170, 255],
843
+ [0, 85, 255],
844
+ [0, 0, 255],
845
+ [85, 0, 255],
846
+ [170, 0, 255],
847
+ [255, 0, 255],
848
+ [255, 0, 170],
849
+ [255, 0, 85],
850
+ # foot
851
+ [200, 200, 0],
852
+ [100, 100, 0],
853
+ ]
854
+
855
+ H, W, C = img.shape
856
+
857
+
858
+ if stickwidth_type == 'v1':
859
+ stickwidth = max(int(min(H, W) / 200), 1)
860
+ elif stickwidth_type == 'v2':
861
+ stickwidth = max(int(min(H, W) / 200) - 1, 1)
862
+ else:
863
+ raise ValueError(f"Unsupported stickwidth_type: {stickwidth_type}")
864
+
865
+ for _idx, ((k1_index, k2_index), color) in enumerate(zip(limbSeq, colors)):
866
+ keypoint1 = kp2ds_body[k1_index - 1]
867
+ keypoint2 = kp2ds_body[k2_index - 1]
868
+
869
+ if keypoint1[-1] < threshold or keypoint2[-1] < threshold:
870
+ continue
871
+
872
+ Y = np.array([keypoint1[0], keypoint2[0]])
873
+ X = np.array([keypoint1[1], keypoint2[1]])
874
+ mX = np.mean(X)
875
+ mY = np.mean(Y)
876
+ length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
877
+ angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
878
+ polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1)
879
+ cv2.fillConvexPoly(img, polygon, [int(float(c) * 0.6) for c in color])
880
+
881
+ for _idx, (keypoint, color) in enumerate(zip(kp2ds_body, colors)):
882
+ if keypoint[-1] < threshold:
883
+ continue
884
+ x, y = keypoint[0], keypoint[1]
885
+ # cv2.circle(canvas, (int(x), int(y)), 4, color, thickness=-1)
886
+ cv2.circle(img, (int(x), int(y)), stickwidth, color, thickness=-1)
887
+
888
+ if draw_hand:
889
+ img = draw_handpose_new(img, kp2ds_lhand, stickwidth_type=stickwidth_type, hand_score_th=threshold)
890
+ img = draw_handpose_new(img, kp2ds_rhand, stickwidth_type=stickwidth_type, hand_score_th=threshold)
891
+
892
+ kp2ds_body[:, 0] /= W
893
+ kp2ds_body[:, 1] /= H
894
+
895
+ if data_to_json is not None:
896
+ if idx == -1:
897
+ data_to_json.append(
898
+ {
899
+ "image_id": "frame_{:05d}.jpg".format(len(data_to_json) + 1),
900
+ "height": H,
901
+ "width": W,
902
+ "category_id": 1,
903
+ "keypoints_body": kp2ds_body.tolist(),
904
+ "keypoints_left_hand": kp2ds_lhand.tolist(),
905
+ "keypoints_right_hand": kp2ds_rhand.tolist(),
906
+ }
907
+ )
908
+ else:
909
+ data_to_json[idx] = {
910
+ "image_id": "frame_{:05d}.jpg".format(idx + 1),
911
+ "height": H,
912
+ "width": W,
913
+ "category_id": 1,
914
+ "keypoints_body": kp2ds_body.tolist(),
915
+ "keypoints_left_hand": kp2ds_lhand.tolist(),
916
+ "keypoints_right_hand": kp2ds_rhand.tolist(),
917
+ }
918
+ return img
919
+
920
+
921
+ def draw_bbox(img, bbox, color=(255, 0, 0)):
922
+ img = load_image(img)
923
+ bbox = [int(bbox_tmp) for bbox_tmp in bbox]
924
+ cv2.rectangle(img, (bbox[0], bbox[1]), (bbox[2], bbox[3]), color, 2)
925
+ return img
926
+
927
+
928
+ def draw_kp2ds(img, kp2ds, threshold=0, color=(255, 0, 0), skeleton=None, reverse=False):
929
+ img = load_image(img, reverse)
930
+
931
+ if skeleton is not None:
932
+ if skeleton == "coco17":
933
+ skeleton_list = [
934
+ [6, 8],
935
+ [8, 10],
936
+ [5, 7],
937
+ [7, 9],
938
+ [11, 13],
939
+ [13, 15],
940
+ [12, 14],
941
+ [14, 16],
942
+ [5, 6],
943
+ [6, 12],
944
+ [12, 11],
945
+ [11, 5],
946
+ ]
947
+ color_list = [
948
+ (255, 0, 0),
949
+ (0, 255, 0),
950
+ (0, 0, 255),
951
+ (255, 255, 0),
952
+ (255, 0, 255),
953
+ (0, 255, 255),
954
+ ]
955
+ elif skeleton == "cocowholebody":
956
+ skeleton_list = [
957
+ [6, 8],
958
+ [8, 10],
959
+ [5, 7],
960
+ [7, 9],
961
+ [11, 13],
962
+ [13, 15],
963
+ [12, 14],
964
+ [14, 16],
965
+ [5, 6],
966
+ [6, 12],
967
+ [12, 11],
968
+ [11, 5],
969
+ [15, 17],
970
+ [15, 18],
971
+ [15, 19],
972
+ [16, 20],
973
+ [16, 21],
974
+ [16, 22],
975
+ [91, 92, 93, 94, 95],
976
+ [91, 96, 97, 98, 99],
977
+ [91, 100, 101, 102, 103],
978
+ [91, 104, 105, 106, 107],
979
+ [91, 108, 109, 110, 111],
980
+ [112, 113, 114, 115, 116],
981
+ [112, 117, 118, 119, 120],
982
+ [112, 121, 122, 123, 124],
983
+ [112, 125, 126, 127, 128],
984
+ [112, 129, 130, 131, 132],
985
+ ]
986
+ color_list = [
987
+ (255, 0, 0),
988
+ (0, 255, 0),
989
+ (0, 0, 255),
990
+ (255, 255, 0),
991
+ (255, 0, 255),
992
+ (0, 255, 255),
993
+ ]
994
+ else:
995
+ color_list = [color]
996
+ for _idx, _skeleton in enumerate(skeleton_list):
997
+ for i in range(len(_skeleton) - 1):
998
+ cv2.line(
999
+ img,
1000
+ (int(kp2ds[_skeleton[i], 0]), int(kp2ds[_skeleton[i], 1])),
1001
+ (int(kp2ds[_skeleton[i + 1], 0]), int(kp2ds[_skeleton[i + 1], 1])),
1002
+ color_list[_idx % len(color_list)],
1003
+ 3,
1004
+ )
1005
+
1006
+ for _idx, kp2d in enumerate(kp2ds):
1007
+ if kp2d[2] > threshold:
1008
+ cv2.circle(img, (int(kp2d[0]), int(kp2d[1])), 3, color, -1)
1009
+ # cv2.putText(img,
1010
+ # str(_idx),
1011
+ # (int(kp2d[0, i, 0])*1,
1012
+ # int(kp2d[0, i, 1])*1),
1013
+ # cv2.FONT_HERSHEY_SIMPLEX,
1014
+ # 0.75,
1015
+ # color,
1016
+ # 2
1017
+ # )
1018
+
1019
+ return img
1020
+
1021
+
1022
+ def draw_mask(img, mask, background=0, return_rgba=False):
1023
+ img = load_image(img)
1024
+ h, w, _ = img.shape
1025
+ if type(background) == int:
1026
+ background = np.ones((h, w, 3)).astype(np.uint8) * 255 * background
1027
+ background = cv2.resize(background, (w, h))
1028
+ img_rgba = np.concatenate([img, mask], -1)
1029
+ return alphaMerge(img_rgba, background, 0, 0, return_rgba=True)
1030
+
1031
+
1032
+ def draw_pcd(pcd_list, save_path=None):
1033
+ fig = plt.figure()
1034
+ ax = fig.add_subplot(111, projection="3d")
1035
+
1036
+ color_list = ["r", "g", "b", "y", "p"]
1037
+
1038
+ for _idx, _pcd in enumerate(pcd_list):
1039
+ ax.scatter(_pcd[:, 0], _pcd[:, 1], _pcd[:, 2], c=color_list[_idx], marker="o")
1040
+
1041
+ ax.set_xlabel("X")
1042
+ ax.set_ylabel("Y")
1043
+ ax.set_zlabel("Z")
1044
+
1045
+ if save_path is not None:
1046
+ plt.savefig(save_path)
1047
+ else:
1048
+ plt.savefig("tmp.png")
1049
+
1050
+
1051
+ def load_image(img, reverse=False):
1052
+ if type(img) == str:
1053
+ img = cv2.imread(img)
1054
+ if reverse:
1055
+ img = img.astype(np.float32)
1056
+ img = img[:, :, ::-1]
1057
+ img = img.astype(np.uint8)
1058
+ return img
1059
+
1060
+
1061
+ def draw_skeleten(meta):
1062
+ kps = []
1063
+ for i, kp in enumerate(meta["keypoints_body"]):
1064
+ if kp is None:
1065
+ # if kp is None:
1066
+ kps.append([0, 0, 0])
1067
+ else:
1068
+ kps.append([*kp, 1])
1069
+ kps = np.array(kps)
1070
+
1071
+ kps[:, 0] *= meta["width"]
1072
+ kps[:, 1] *= meta["height"]
1073
+ pose_img = np.zeros([meta["height"], meta["width"], 3], dtype=np.uint8)
1074
+
1075
+ pose_img = draw_aapose(
1076
+ pose_img,
1077
+ kps,
1078
+ draw_hand=True,
1079
+ kp2ds_lhand=meta["keypoints_left_hand"],
1080
+ kp2ds_rhand=meta["keypoints_right_hand"],
1081
+ )
1082
+ return pose_img
1083
+
1084
+
1085
+ def draw_skeleten_with_pncc(pncc: np.ndarray, meta: Dict) -> np.ndarray:
1086
+ """
1087
+ Args:
1088
+ pncc: [H,W,3]
1089
+ meta: required keys: keypoints_body: [N, 3] keypoints_left_hand, keypoints_right_hand
1090
+ Return:
1091
+ np.ndarray [H, W, 3]
1092
+ """
1093
+ # preprocess keypoints
1094
+ kps = []
1095
+ for i, kp in enumerate(meta["keypoints_body"]):
1096
+ if kp is None:
1097
+ # if kp is None:
1098
+ kps.append([0, 0, 0])
1099
+ elif i in [14, 15, 16, 17]:
1100
+ kps.append([0, 0, 0])
1101
+ else:
1102
+ kps.append([*kp])
1103
+ kps = np.stack(kps)
1104
+
1105
+ kps[:, 0] *= pncc.shape[1]
1106
+ kps[:, 1] *= pncc.shape[0]
1107
+
1108
+ # draw neck
1109
+ canvas = np.zeros_like(pncc)
1110
+ if kps[0][2] > 0.6 and kps[1][2] > 0.6:
1111
+ canvas = draw_ellipse_by_2kp(canvas, kps[0], kps[1], [0, 0, 255])
1112
+
1113
+ # draw pncc
1114
+ mask = (pncc > 0).max(axis=2)
1115
+ canvas[mask] = pncc[mask]
1116
+ pncc = canvas
1117
+
1118
+ # draw other skeleten
1119
+ kps[0] = 0
1120
+
1121
+ meta["keypoints_left_hand"][:, 0] *= meta["width"]
1122
+ meta["keypoints_left_hand"][:, 1] *= meta["height"]
1123
+
1124
+ meta["keypoints_right_hand"][:, 0] *= meta["width"]
1125
+ meta["keypoints_right_hand"][:, 1] *= meta["height"]
1126
+ pose_img = draw_aapose(
1127
+ pncc,
1128
+ kps,
1129
+ draw_hand=True,
1130
+ kp2ds_lhand=meta["keypoints_left_hand"],
1131
+ kp2ds_rhand=meta["keypoints_right_hand"],
1132
+ )
1133
+ return pose_img
1134
+
1135
+
1136
+ FACE_CUSTOM_STYLE = {
1137
+ "eyeball": {"indexs": [68, 69], "color": [255, 255, 255], "connect": False},
1138
+ "left_eyebrow": {"indexs": [17, 18, 19, 20, 21], "color": [0, 255, 0]},
1139
+ "right_eyebrow": {"indexs": [22, 23, 24, 25, 26], "color": [0, 0, 255]},
1140
+ "left_eye": {"indexs": [36, 37, 38, 39, 40, 41], "color": [255, 255, 0], "close": True},
1141
+ "right_eye": {"indexs": [42, 43, 44, 45, 46, 47], "color": [255, 0, 255], "close": True},
1142
+ "mouth_outside": {"indexs": list(range(48, 60)), "color": [100, 255, 50], "close": True},
1143
+ "mouth_inside": {"indexs": [60, 61, 62, 63, 64, 65, 66, 67], "color": [255, 100, 50], "close": True},
1144
+ }
1145
+
1146
+
1147
+ def draw_face_kp(img, kps, thickness=2, style=FACE_CUSTOM_STYLE):
1148
+ """
1149
+ Args:
1150
+ img: [H, W, 3]
1151
+ kps: [70, 2]
1152
+ """
1153
+ img = img.copy()
1154
+ for key, item in style.items():
1155
+ pts = np.array(kps[item["indexs"]]).astype(np.int32)
1156
+ connect = item.get("connect", True)
1157
+ color = item["color"]
1158
+ close = item.get("close", False)
1159
+ if connect:
1160
+ cv2.polylines(img, [pts], close, color, thickness=thickness)
1161
+ else:
1162
+ for kp in pts:
1163
+ kp = np.array(kp).astype(np.int32)
1164
+ cv2.circle(img, kp, thickness * 2, color=color, thickness=-1)
1165
+ return img
1166
+
1167
+
1168
+ def draw_traj(metas: List[AAPoseMeta], threshold=0.6):
1169
+
1170
+ colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \
1171
+ [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \
1172
+ [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85], [100, 255, 50], [255, 100, 50],
1173
+ # foot
1174
+ [200, 200, 0],
1175
+ [100, 100, 0]
1176
+ ]
1177
+ limbSeq = [
1178
+ [1, 2], [1, 5], # shoulders
1179
+ [2, 3], [3, 4], # left arm
1180
+ [5, 6], [6, 7], # right arm
1181
+ [1, 8], [8, 9], [9, 10], # right leg
1182
+ [1, 11], [11, 12], [12, 13], # left leg
1183
+ # face (nose, eyes, ears)
1184
+ [13, 18], [10, 19] # foot
1185
+ ]
1186
+
1187
+ face_seq = [[1, 0], [0, 14], [14, 16], [0, 15], [15, 17]]
1188
+ kp_body = np.array([meta.kps_body for meta in metas])
1189
+ kp_body_p = np.array([meta.kps_body_p for meta in metas])
1190
+
1191
+
1192
+ face_seq = random.sample(face_seq, 2)
1193
+
1194
+ kp_lh = np.array([meta.kps_lhand for meta in metas])
1195
+ kp_rh = np.array([meta.kps_rhand for meta in metas])
1196
+
1197
+ kp_lh_p = np.array([meta.kps_lhand_p for meta in metas])
1198
+ kp_rh_p = np.array([meta.kps_rhand_p for meta in metas])
1199
+
1200
+ # kp_lh = np.concatenate([kp_lh, kp_lh_p], axis=-1)
1201
+ # kp_rh = np.concatenate([kp_rh, kp_rh_p], axis=-1)
1202
+
1203
+ new_limbSeq = []
1204
+ key_point_list = []
1205
+ for _idx, ((k1_index, k2_index)) in enumerate(limbSeq):
1206
+
1207
+ vis = (kp_body_p[:, k1_index] > threshold) * (kp_body_p[:, k2_index] > threshold) * 1
1208
+ if vis.sum() * 1.0 / vis.shape[0] > 0.4:
1209
+ new_limbSeq.append([k1_index, k2_index])
1210
+
1211
+ for _idx, ((k1_index, k2_index)) in enumerate(limbSeq):
1212
+
1213
+ keypoint1 = kp_body[:, k1_index - 1]
1214
+ keypoint2 = kp_body[:, k2_index - 1]
1215
+ interleave = random.randint(4, 7)
1216
+ randind = random.randint(0, interleave - 1)
1217
+ # randind = random.rand(range(interleave), sampling_num)
1218
+
1219
+ Y = np.array([keypoint1[:, 0], keypoint2[:, 0]])
1220
+ X = np.array([keypoint1[:, 1], keypoint2[:, 1]])
1221
+
1222
+ vis = (keypoint1[:, -1] > threshold) * (keypoint2[:, -1] > threshold) * 1
1223
+
1224
+ # for randidx in randind:
1225
+ t = randind / interleave
1226
+ x = (1-t)*Y[0, :] + t*Y[1, :]
1227
+ y = (1-t)*X[0, :] + t*X[1, :]
1228
+
1229
+ # np.array([1])
1230
+ x = x.astype(int)
1231
+ y = y.astype(int)
1232
+
1233
+ new_array = np.array([x, y, vis]).T
1234
+
1235
+ key_point_list.append(new_array)
1236
+
1237
+ indx_lh = random.randint(0, kp_lh.shape[1] - 1)
1238
+ lh = kp_lh[:, indx_lh, :]
1239
+ lh_p = kp_lh_p[:, indx_lh:indx_lh+1]
1240
+ lh = np.concatenate([lh, lh_p], axis=-1)
1241
+
1242
+ indx_rh = random.randint(0, kp_rh.shape[1] - 1)
1243
+ rh = kp_rh[:, indx_rh, :]
1244
+ rh_p = kp_rh_p[:, indx_rh:indx_rh+1]
1245
+ rh = np.concatenate([rh, rh_p], axis=-1)
1246
+
1247
+
1248
+
1249
+ lh[:, -1] = (lh[:, -1] > threshold) * 1
1250
+ rh[:, -1] = (rh[:, -1] > threshold) * 1
1251
+
1252
+ # print(rh.shape, new_array.shape)
1253
+ # exit()
1254
+ key_point_list.append(lh.astype(int))
1255
+ key_point_list.append(rh.astype(int))
1256
+
1257
+
1258
+ key_points_list = np.stack(key_point_list)
1259
+ num_points = len(key_points_list)
1260
+ sample_colors = random.sample(colors, num_points)
1261
+
1262
+ stickwidth = max(int(min(metas[0].width, metas[0].height) / 150), 2)
1263
+
1264
+ image_list_ori = []
1265
+ for i in range(key_points_list.shape[-2]):
1266
+ _image_vis = np.zeros((metas[0].width, metas[0].height, 3))
1267
+ points = key_points_list[:, i, :]
1268
+ for idx, point in enumerate(points):
1269
+ x, y, vis = point
1270
+ if vis == 1:
1271
+ cv2.circle(_image_vis, (x, y), stickwidth, sample_colors[idx], thickness=-1)
1272
+
1273
+ image_list_ori.append(_image_vis)
1274
+
1275
+ return image_list_ori
1276
+
1277
+ return [np.zeros([meta.width, meta.height, 3], dtype=np.uint8) for meta in metas]
1278
+
1279
+
1280
+ if __name__ == "__main__":
1281
+ meta = {
1282
+ "image_id": "00472.jpg",
1283
+ "height": 540,
1284
+ "width": 414,
1285
+ "category_id": 1,
1286
+ "keypoints_body": [
1287
+ [0.5084776947463768, 0.11350188078703703],
1288
+ [0.504467655495169, 0.20419560185185184],
1289
+ [0.3982016153381642, 0.198046875],
1290
+ [0.3841664779589372, 0.34869068287037036],
1291
+ [0.3901815368357488, 0.4670536747685185],
1292
+ [0.610733695652174, 0.2103443287037037],
1293
+ [0.6167487545289855, 0.3517650462962963],
1294
+ [0.6448190292874396, 0.4762767650462963],
1295
+ [0.4523371452294686, 0.47320240162037036],
1296
+ [0.4503321256038647, 0.6776475694444445],
1297
+ [0.47639738073671495, 0.8544234664351852],
1298
+ [0.5766483620169082, 0.47320240162037036],
1299
+ [0.5666232638888888, 0.6761103877314815],
1300
+ [0.534542949879227, 0.863646556712963],
1301
+ [0.4864224788647343, 0.09505570023148148],
1302
+ [0.5285278910024155, 0.09351851851851851],
1303
+ [0.46236224335748793, 0.10581597222222222],
1304
+ [0.5586031853864735, 0.10274160879629629],
1305
+ [0.4994551064311594, 0.9405056423611111],
1306
+ [0.4152442821557971, 0.9312825520833333],
1307
+ ],
1308
+ "keypoints_left_hand": [
1309
+ [267.78515625, 263.830078125, 1.2840936183929443],
1310
+ [265.294921875, 269.640625, 1.2546794414520264],
1311
+ [263.634765625, 277.111328125, 1.2863062620162964],
1312
+ [262.8046875, 285.412109375, 1.267038345336914],
1313
+ [261.14453125, 292.8828125, 1.280144453048706],
1314
+ [273.595703125, 281.26171875, 1.2592815160751343],
1315
+ [271.10546875, 291.22265625, 1.3256099224090576],
1316
+ [265.294921875, 294.54296875, 1.2368024587631226],
1317
+ [261.14453125, 294.54296875, 0.9771889448165894],
1318
+ [274.42578125, 282.091796875, 1.250044584274292],
1319
+ [269.4453125, 291.22265625, 1.2571144104003906],
1320
+ [264.46484375, 292.8828125, 1.177802324295044],
1321
+ [260.314453125, 292.052734375, 0.9283463358879089],
1322
+ [273.595703125, 282.091796875, 1.1834490299224854],
1323
+ [269.4453125, 290.392578125, 1.188171625137329],
1324
+ [265.294921875, 290.392578125, 1.192609429359436],
1325
+ [261.974609375, 289.5625, 0.9366656541824341],
1326
+ [271.935546875, 281.26171875, 1.0946396589279175],
1327
+ [268.615234375, 287.072265625, 0.9906131029129028],
1328
+ [265.294921875, 287.90234375, 1.0219476222991943],
1329
+ [262.8046875, 287.072265625, 0.9240120053291321],
1330
+ ],
1331
+ "keypoints_right_hand": [
1332
+ [161.53515625, 258.849609375, 1.2069408893585205],
1333
+ [168.17578125, 263.0, 1.1846840381622314],
1334
+ [173.986328125, 269.640625, 1.1435924768447876],
1335
+ [173.986328125, 277.94140625, 1.1802611351013184],
1336
+ [173.986328125, 286.2421875, 1.2599592208862305],
1337
+ [165.685546875, 275.451171875, 1.0633569955825806],
1338
+ [167.345703125, 286.2421875, 1.1693341732025146],
1339
+ [169.8359375, 291.22265625, 1.2698509693145752],
1340
+ [170.666015625, 294.54296875, 1.0619274377822876],
1341
+ [160.705078125, 276.28125, 1.0995020866394043],
1342
+ [163.1953125, 287.90234375, 1.2735884189605713],
1343
+ [166.515625, 291.22265625, 1.339503526687622],
1344
+ [169.005859375, 294.54296875, 1.0835273265838623],
1345
+ [157.384765625, 277.111328125, 1.0866981744766235],
1346
+ [161.53515625, 287.072265625, 1.2468621730804443],
1347
+ [164.025390625, 289.5625, 1.2817761898040771],
1348
+ [166.515625, 292.052734375, 1.099466323852539],
1349
+ [155.724609375, 277.111328125, 1.1065717935562134],
1350
+ [159.044921875, 285.412109375, 1.1924479007720947],
1351
+ [160.705078125, 287.072265625, 1.1304771900177002],
1352
+ [162.365234375, 287.90234375, 1.0040509700775146],
1353
+ ],
1354
+ }
1355
+ demo_meta = AAPoseMeta(meta)
1356
+ res = draw_traj([demo_meta]*5)
1357
+ cv2.imwrite("traj.png", res[0][..., ::-1])
wan/modules/animate/preprocess/pose2d.py ADDED
@@ -0,0 +1,430 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import os
3
+ import cv2
4
+ from typing import Union, List
5
+
6
+ import numpy as np
7
+ import torch
8
+ import onnxruntime
9
+
10
+ from pose2d_utils import (
11
+ read_img,
12
+ box_convert_simple,
13
+ bbox_from_detector,
14
+ crop,
15
+ keypoints_from_heatmaps,
16
+ load_pose_metas_from_kp2ds_seq
17
+ )
18
+
19
+
20
+ class SimpleOnnxInference(object):
21
+ def __init__(self, checkpoint, device='cuda', reverse_input=False, **kwargs):
22
+ if isinstance(device, str):
23
+ device = torch.device(device)
24
+ if device.type == 'cuda':
25
+ device = '{}:{}'.format(device.type, device.index)
26
+ providers = [("CUDAExecutionProvider", {"device_id": device[-1:] if device[-1] in [str(_i) for _i in range(10)] else "0"}), "CPUExecutionProvider"]
27
+ else:
28
+ providers = ["CPUExecutionProvider"]
29
+ self.device = device
30
+ if not os.path.exists(checkpoint):
31
+ raise RuntimeError("{} is not existed!".format(checkpoint))
32
+
33
+ if os.path.isdir(checkpoint):
34
+ checkpoint = os.path.join(checkpoint, 'end2end.onnx')
35
+
36
+ self.session = onnxruntime.InferenceSession(checkpoint,
37
+ providers=providers
38
+ )
39
+ self.input_name = self.session.get_inputs()[0].name
40
+ self.output_name = self.session.get_outputs()[0].name
41
+ self.input_resolution = self.session.get_inputs()[0].shape[2:] if not reverse_input else self.session.get_inputs()[0].shape[2:][::-1]
42
+ self.input_resolution = np.array(self.input_resolution)
43
+
44
+
45
+ def __call__(self, *args, **kwargs):
46
+ return self.forward(*args, **kwargs)
47
+
48
+
49
+ def get_output_names(self):
50
+ output_names = []
51
+ for node in self.session.get_outputs():
52
+ output_names.append(node.name)
53
+ return output_names
54
+
55
+
56
+ def set_device(self, device):
57
+ if isinstance(device, str):
58
+ device = torch.device(device)
59
+ if device.type == 'cuda':
60
+ device = '{}:{}'.format(device.type, device.index)
61
+ providers = [("CUDAExecutionProvider", {"device_id": device[-1:] if device[-1] in [str(_i) for _i in range(10)] else "0"}), "CPUExecutionProvider"]
62
+ else:
63
+ providers = ["CPUExecutionProvider"]
64
+ self.session.set_providers(["CUDAExecutionProvider"])
65
+ self.device = device
66
+
67
+
68
+ class Yolo(SimpleOnnxInference):
69
+ def __init__(self, checkpoint, device='cuda', threshold_conf=0.05, threshold_multi_persons=0.1, input_resolution=(640, 640), threshold_iou=0.5, threshold_bbox_shape_ratio=0.4, cat_id=[1], select_type='max', strict=True, sorted_func=None, **kwargs):
70
+ super(Yolo, self).__init__(checkpoint, device=device, **kwargs)
71
+ self.session.set_providers(["CUDAExecutionProvider"])
72
+ model_inputs = self.session.get_inputs()
73
+ input_shape = model_inputs[0].shape
74
+
75
+ self.input_width = 640
76
+ self.input_height = 640
77
+
78
+ self.threshold_multi_persons = threshold_multi_persons
79
+ self.threshold_conf = threshold_conf
80
+ self.threshold_iou = threshold_iou
81
+ self.threshold_bbox_shape_ratio = threshold_bbox_shape_ratio
82
+ self.input_resolution = input_resolution
83
+ self.cat_id = cat_id
84
+ self.select_type = select_type
85
+ self.strict = strict
86
+ self.sorted_func = sorted_func
87
+
88
+
89
+ def preprocess(self, input_image):
90
+ """
91
+ Preprocesses the input image before performing inference.
92
+
93
+ Returns:
94
+ image_data: Preprocessed image data ready for inference.
95
+ """
96
+ img = read_img(input_image)
97
+ # Get the height and width of the input image
98
+ img_height, img_width = img.shape[:2]
99
+ # Resize the image to match the input shape
100
+ img = cv2.resize(img, (self.input_resolution[1], self.input_resolution[0]))
101
+ # Normalize the image data by dividing it by 255.0
102
+ image_data = np.array(img) / 255.0
103
+ # Transpose the image to have the channel dimension as the first dimension
104
+ image_data = np.transpose(image_data, (2, 0, 1)) # Channel first
105
+ # Expand the dimensions of the image data to match the expected input shape
106
+ # image_data = np.expand_dims(image_data, axis=0).astype(np.float32)
107
+ image_data = image_data.astype(np.float32)
108
+ # Return the preprocessed image data
109
+ return image_data, np.array([img_height, img_width])
110
+
111
+
112
+ def postprocess(self, output, shape_raw, cat_id=[1]):
113
+ """
114
+ Performs post-processing on the model's output to extract bounding boxes, scores, and class IDs.
115
+
116
+ Args:
117
+ input_image (numpy.ndarray): The input image.
118
+ output (numpy.ndarray): The output of the model.
119
+
120
+ Returns:
121
+ numpy.ndarray: The input image with detections drawn on it.
122
+ """
123
+ # Transpose and squeeze the output to match the expected shape
124
+
125
+ outputs = np.squeeze(output)
126
+ if len(outputs.shape) == 1:
127
+ outputs = outputs[None]
128
+ if output.shape[-1] != 6 and output.shape[1] == 84:
129
+ outputs = np.transpose(outputs)
130
+
131
+ # Get the number of rows in the outputs array
132
+ rows = outputs.shape[0]
133
+
134
+ # Calculate the scaling factors for the bounding box coordinates
135
+ x_factor = shape_raw[1] / self.input_width
136
+ y_factor = shape_raw[0] / self.input_height
137
+
138
+ # Lists to store the bounding boxes, scores, and class IDs of the detections
139
+ boxes = []
140
+ scores = []
141
+ class_ids = []
142
+
143
+ if outputs.shape[-1] == 6:
144
+ max_scores = outputs[:, 4]
145
+ classid = outputs[:, -1]
146
+
147
+ threshold_conf_masks = max_scores >= self.threshold_conf
148
+ classid_masks = classid[threshold_conf_masks] != 3.14159
149
+
150
+ max_scores = max_scores[threshold_conf_masks][classid_masks]
151
+ classid = classid[threshold_conf_masks][classid_masks]
152
+
153
+ boxes = outputs[:, :4][threshold_conf_masks][classid_masks]
154
+ boxes[:, [0, 2]] *= x_factor
155
+ boxes[:, [1, 3]] *= y_factor
156
+ boxes[:, 2] = boxes[:, 2] - boxes[:, 0]
157
+ boxes[:, 3] = boxes[:, 3] - boxes[:, 1]
158
+ boxes = boxes.astype(np.int32)
159
+
160
+ else:
161
+ classes_scores = outputs[:, 4:]
162
+ max_scores = np.amax(classes_scores, -1)
163
+ threshold_conf_masks = max_scores >= self.threshold_conf
164
+
165
+ classid = np.argmax(classes_scores[threshold_conf_masks], -1)
166
+
167
+ classid_masks = classid!=3.14159
168
+
169
+ classes_scores = classes_scores[threshold_conf_masks][classid_masks]
170
+ max_scores = max_scores[threshold_conf_masks][classid_masks]
171
+ classid = classid[classid_masks]
172
+
173
+ xywh = outputs[:, :4][threshold_conf_masks][classid_masks]
174
+
175
+ x = xywh[:, 0:1]
176
+ y = xywh[:, 1:2]
177
+ w = xywh[:, 2:3]
178
+ h = xywh[:, 3:4]
179
+
180
+ left = ((x - w / 2) * x_factor)
181
+ top = ((y - h / 2) * y_factor)
182
+ width = (w * x_factor)
183
+ height = (h * y_factor)
184
+ boxes = np.concatenate([left, top, width, height], axis=-1).astype(np.int32)
185
+
186
+ boxes = boxes.tolist()
187
+ scores = max_scores.tolist()
188
+ class_ids = classid.tolist()
189
+
190
+ # Apply non-maximum suppression to filter out overlapping bounding boxes
191
+ indices = cv2.dnn.NMSBoxes(boxes, scores, self.threshold_conf, self.threshold_iou)
192
+ # Iterate over the selected indices after non-maximum suppression
193
+
194
+ results = []
195
+ for i in indices:
196
+ # Get the box, score, and class ID corresponding to the index
197
+ box = box_convert_simple(boxes[i], 'xywh2xyxy')
198
+ score = scores[i]
199
+ class_id = class_ids[i]
200
+ results.append(box + [score] + [class_id])
201
+
202
+
203
+ # Return the detections as rows of [x0, y0, x1, y1, score, class_id]
204
+ return np.array(results)
205
+
206
+
207
+ def process_results(self, results, shape_raw, cat_id=[1], single_person=True):
208
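+ # Pick a primary person from the detections: select_type 'max' keeps the largest box, 'center' keeps the box closest to the image center; extra people of comparable size are appended only when single_person is False.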
+ if isinstance(results, tuple):
209
+ det_results = results[0]
210
+ else:
211
+ det_results = results
212
+
213
+ person_results = []
214
+ person_count = 0
215
+ if len(results):
216
+ max_idx = -1
217
+ max_bbox_size = shape_raw[0] * shape_raw[1] * -10
218
+ max_bbox_shape = -1
219
+
220
+ bboxes = []
221
+ idx_list = []
222
+ for i in range(results.shape[0]):
223
+ bbox = results[i]
224
+ if (bbox[-1] + 1 in cat_id) and (bbox[-2] > self.threshold_conf):
225
+ idx_list.append(i)
226
+ bbox_shape = max((bbox[2] - bbox[0]), ((bbox[3] - bbox[1])))
227
+ if bbox_shape > max_bbox_shape:
228
+ max_bbox_shape = bbox_shape
229
+
230
+ results = results[idx_list]
231
+
232
+ for i in range(results.shape[0]):
233
+ bbox = results[i]
234
+ bboxes.append(bbox)
235
+ if self.select_type == 'max':
236
+ bbox_size = (bbox[2] - bbox[0]) * ((bbox[3] - bbox[1]))
237
+ elif self.select_type == 'center':
238
+ bbox_size = (abs((bbox[2] + bbox[0]) / 2 - shape_raw[1]/2)) * -1
239
+ bbox_shape = max((bbox[2] - bbox[0]), ((bbox[3] - bbox[1])))
240
+ if bbox_size > max_bbox_size:
241
+ if (self.strict or max_idx != -1) and bbox_shape < max_bbox_shape * self.threshold_bbox_shape_ratio:
242
+ continue
243
+ max_bbox_size = bbox_size
244
+ max_bbox_shape = bbox_shape
245
+ max_idx = i
246
+
247
+ if self.sorted_func is not None and len(bboxes) > 0:
248
+ max_idx = self.sorted_func(bboxes, shape_raw)
249
+ bbox = bboxes[max_idx]
250
+ if self.select_type == 'max':
251
+ max_bbox_size = (bbox[2] - bbox[0]) * ((bbox[3] - bbox[1]))
252
+ elif self.select_type == 'center':
253
+ max_bbox_size = (abs((bbox[2] + bbox[0]) / 2 - shape_raw[1]/2)) * -1
254
+
255
+ if max_idx != -1:
256
+ person_count = 1
257
+
258
+ if max_idx != -1:
259
+ person = {}
260
+ person['bbox'] = results[max_idx, :5]
261
+ person['track_id'] = int(0)
262
+ person_results.append(person)
263
+
264
+ for i in range(results.shape[0]):
265
+ bbox = results[i]
266
+ if (bbox[-1] + 1 in cat_id) and (bbox[-2] > self.threshold_conf):
267
+ if self.select_type == 'max':
268
+ bbox_size = (bbox[2] - bbox[0]) * ((bbox[3] - bbox[1]))
269
+ elif self.select_type == 'center':
270
+ bbox_size = (abs((bbox[2] + bbox[0]) / 2 - shape_raw[1]/2)) * -1
271
+ if i != max_idx and bbox_size > max_bbox_size * self.threshold_multi_persons and bbox_size < max_bbox_size:
272
+ person_count += 1
273
+ if not single_person:
274
+ person = {}
275
+ person['bbox'] = results[i, :5]
276
+ person['track_id'] = int(person_count - 1)
277
+ person_results.append(person)
278
+ return person_results
279
+ else:
280
+ return None
281
+
282
+
283
+ def postprocess_threading(self, outputs, shape_raw, person_results, i, single_person=True, **kwargs):
284
+ result = self.postprocess(outputs[i], shape_raw[i], cat_id=self.cat_id)
285
+ result = self.process_results(result, shape_raw[i], cat_id=self.cat_id, single_person=single_person)
286
+ if result is not None and len(result) != 0:
287
+ person_results[i] = result
288
+
289
+
290
+ def forward(self, img, shape_raw, **kwargs):
291
+ """
292
+ Runs the ONNX detector on a batch of preprocessed images and post-processes the detections.
293
+
294
+ Returns:
295
+ person_results: A list (one entry per image) of person dicts, each with 'bbox' ([x0, y0, x1, y1, score]) and 'track_id'.
296
+ """
297
+ if isinstance(img, torch.Tensor):
298
+ img = img.cpu().numpy()
299
+ shape_raw = shape_raw.cpu().numpy()
300
+
301
+ outputs = self.session.run(None, {self.session.get_inputs()[0].name: img})[0]
302
+ person_results = [[{'bbox': np.array([0., 0., 1.*shape_raw[i][1], 1.*shape_raw[i][0], -1]), 'track_id': -1}] for i in range(len(outputs))]
303
+
304
+ for i in range(len(outputs)):
305
+ self.postprocess_threading(outputs, shape_raw, person_results, i, **kwargs)
306
+ return person_results
307
+
308
+
309
+ class ViTPose(SimpleOnnxInference):
310
+ def __init__(self, checkpoint, device='cuda', **kwargs):
311
+ super(ViTPose, self).__init__(checkpoint, device=device)
312
+ self.session.set_providers(["CUDAExecutionProvider"])
313
+
314
+ def forward(self, img, center, scale, **kwargs):
315
+ heatmaps = self.session.run([], {self.session.get_inputs()[0].name: img})[0]
316
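+ # bbox_from_detector normalizes the scale by a factor of 200, so it is scaled back here before the heatmap peaks are mapped to image coordinates.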
+ points, prob = keypoints_from_heatmaps(heatmaps=heatmaps,
317
+ center=center,
318
+ scale=scale*200,
319
+ unbiased=True,
320
+ use_udp=False)
321
+ return np.concatenate([points, prob], axis=2)
322
+
323
+
324
+ @staticmethod
325
+ def preprocess(img, bbox=None, input_resolution=(256, 192), rescale=1.25, mask=None, **kwargs):
326
+ if bbox is None or bbox[-1] <= 0 or (bbox[2] - bbox[0]) < 10 or (bbox[3] - bbox[1]) < 10:
327
+ bbox = np.array([0, 0, img.shape[1], img.shape[0]])
328
+
329
+ bbox_xywh = bbox
330
+ if mask is not None:
331
+ img = np.where(mask>128, img, mask)
332
+
333
+ if isinstance(input_resolution, int):
334
+ center, scale = bbox_from_detector(bbox_xywh, (input_resolution, input_resolution), rescale=rescale)
335
+ img, new_shape, old_xy, new_xy = crop(img, center, scale, (input_resolution, input_resolution))
336
+ else:
337
+ center, scale = bbox_from_detector(bbox_xywh, input_resolution, rescale=rescale)
338
+ img, new_shape, old_xy, new_xy = crop(img, center, scale, (input_resolution[0], input_resolution[1]))
339
+
340
+ IMG_NORM_MEAN = np.array([0.485, 0.456, 0.406])
341
+ IMG_NORM_STD = np.array([0.229, 0.224, 0.225])
342
+ img_norm = (img / 255. - IMG_NORM_MEAN) / IMG_NORM_STD
343
+ img_norm = img_norm.transpose(2, 0, 1).astype(np.float32)
344
+ return img_norm, np.array(center), np.array(scale)
345
+
346
+
347
+ class Pose2d:
348
+ def __init__(self, checkpoint, detector_checkpoint=None, device='cuda', **kwargs):
349
+
350
+ if detector_checkpoint is not None:
351
+ self.detector = Yolo(detector_checkpoint, device)
352
+ else:
353
+ self.detector = None
354
+
355
+ self.model = ViTPose(checkpoint, device)
356
+ self.device = device
357
+
358
+ def load_images(self, inputs):
359
+ """
360
+ Load images from various input types.
361
+
362
+ Args:
363
+ inputs (Union[str, np.ndarray, List[np.ndarray]]): Input can be file path,
364
+ single image array, or list of image arrays
365
+
366
+ Returns:
367
+ List[np.ndarray]: List of RGB image arrays
368
+
369
+ Raises:
370
+ ValueError: If file format is unsupported or image cannot be read
371
+ """
372
+ if isinstance(inputs, str):
373
+ if inputs.lower().endswith(('.mp4', '.avi', '.mov', '.mkv')):
374
+ cap = cv2.VideoCapture(inputs)
375
+ frames = []
376
+ while True:
377
+ ret, frame = cap.read()
378
+ if not ret:
379
+ break
380
+ frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
381
+ cap.release()
382
+ images = frames
383
+ elif inputs.lower().endswith(('.jpg', '.jpeg', '.png', '.bmp')):
384
+ img = cv2.imread(inputs)
385
+ if img is None:
386
+ raise ValueError(f"Cannot read image: {inputs}")
387
+ images = [cv2.cvtColor(img, cv2.COLOR_BGR2RGB)]
388
+ else:
389
+ raise ValueError(f"Unsupported file format: {inputs}")
390
+
391
+ elif isinstance(inputs, np.ndarray):
392
+ images = [cv2.cvtColor(image, cv2.COLOR_BGR2RGB) for image in inputs]
393
+ elif isinstance(inputs, list):
394
+ images = [cv2.cvtColor(image, cv2.COLOR_BGR2RGB) for image in inputs]
395
+ return images
396
+
397
+ def __call__(
398
+ self,
399
+ inputs: Union[str, np.ndarray, List[np.ndarray]],
400
+ return_image: bool = False,
401
+ **kwargs
402
+ ):
403
+ """
404
+ Process input and estimate 2D keypoints.
405
+
406
+ Args:
407
+ inputs (Union[str, np.ndarray, List[np.ndarray]]): Input can be file path,
408
+ single image array, or list of image arrays
409
+ **kwargs: Additional arguments for processing
410
+
411
+ Returns:
412
+ np.ndarray: Array of detected 2D keypoints for all input images
413
+ """
414
+ images = self.load_images(inputs)
415
+ H, W = images[0].shape[:2]
416
+ if self.detector is not None:
417
+ bboxes = []
418
+ for _image in images:
419
+ img, shape = self.detector.preprocess(_image)
420
+ bboxes.append(self.detector(img[None], shape[None])[0][0]["bbox"])
421
+ else:
422
+ bboxes = [None] * len(images)
423
+
424
+ kp2ds = []
425
+ for _image, _bbox in zip(images, bboxes):
426
+ img, center, scale = self.model.preprocess(_image, _bbox)
427
+ kp2ds.append(self.model(img[None], center[None], scale[None]))
428
+ kp2ds = np.concatenate(kp2ds, 0)
429
+ metas = load_pose_metas_from_kp2ds_seq(kp2ds, width=W, height=H)
430
+ return metas
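For reference, a minimal usage sketch of the `Pose2d` wrapper above. The checkpoint filenames mirror the layout expected by `preprocess_data.py` below, while the checkpoint root and video path are illustrative assumptions:

```python
# Hedged sketch: whole-body 2D pose estimation on a driving video with the Pose2d wrapper.
# "ckpts/..." and "driving_video.mp4" are placeholders, not files shipped in this upload.
from pose2d import Pose2d

pose2d = Pose2d(
    checkpoint="ckpts/pose2d/vitpose_h_wholebody.onnx",   # ViTPose whole-body ONNX weights
    detector_checkpoint="ckpts/det/yolov10m.onnx",        # YOLO person detector ONNX weights
)

metas = pose2d("driving_video.mp4")                       # one meta dict per frame
print(len(metas), metas[0]["keypoints_body"].shape)       # body keypoints are normalized to [0, 1]
```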
wan/modules/animate/preprocess/pose2d_utils.py ADDED
@@ -0,0 +1,1159 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import os.path as osp
+ import warnings
3
+ import cv2
4
+ import numpy as np
5
+ from typing import List
6
+ from PIL import Image
7
+
8
+
9
+ def box_convert_simple(box, convert_type='xyxy2xywh'):
10
+ if convert_type == 'xyxy2xywh':
11
+ return [box[0], box[1], box[2] - box[0], box[3] - box[1]]
12
+ elif convert_type == 'xywh2xyxy':
13
+ return [box[0], box[1], box[2] + box[0], box[3] + box[1]]
14
+ elif convert_type == 'xyxy2ctwh':
15
+ return [(box[0] + box[2]) / 2, (box[1] + box[3]) / 2, box[2] - box[0], box[3] - box[1]]
16
+ elif convert_type == 'ctwh2xyxy':
17
+ return [box[0] - box[2] // 2, box[1] - box[3] // 2, box[0] + (box[2] - box[2] // 2), box[1] + (box[3] - box[3] // 2)]
18
+
19
+ def read_img(image, convert='RGB', check_exist=False):
20
+ if isinstance(image, str):
21
+ if check_exist and not osp.exists(image):
22
+ return None
23
+ try:
24
+ img = Image.open(image)
25
+ if convert:
26
+ img = img.convert(convert)
27
+ except:
28
+ raise IOError('File error: ', image)
29
+ return np.asarray(img)
30
+ else:
31
+ if isinstance(image, np.ndarray):
32
+ if convert:
33
+ return image[..., ::-1]
34
+ else:
35
+ if convert:
36
+ image = image.convert(convert)
37
+ return np.asarray(image)
38
+
39
+ class AAPoseMeta:
40
+ def __init__(self, meta=None, kp2ds=None):
41
+ self.image_id = ""
42
+ self.height = 0
43
+ self.width = 0
44
+
45
+ self.kps_body: np.ndarray = None
46
+ self.kps_lhand: np.ndarray = None
47
+ self.kps_rhand: np.ndarray = None
48
+ self.kps_face: np.ndarray = None
49
+ self.kps_body_p: np.ndarray = None
50
+ self.kps_lhand_p: np.ndarray = None
51
+ self.kps_rhand_p: np.ndarray = None
52
+ self.kps_face_p: np.ndarray = None
53
+
54
+
55
+ if meta is not None:
56
+ self.load_from_meta(meta)
57
+ elif kp2ds is not None:
58
+ self.load_from_kp2ds(kp2ds)
59
+
60
+ def is_valid(self, kp, p, threshold):
61
+ x, y = kp
62
+ if x < 0 or y < 0 or x > self.width or y > self.height or p < threshold:
63
+ return False
64
+ else:
65
+ return True
66
+
67
+ def get_bbox(self, kp, kp_p, threshold=0.5):
68
+ kps = kp[kp_p > threshold]
69
+ if kps.size == 0:
70
+ return 0, 0, 0, 0
71
+ x0, y0 = kps.min(axis=0)
72
+ x1, y1 = kps.max(axis=0)
73
+ return x0, y0, x1, y1
74
+
75
+ def crop(self, x0, y0, x1, y1):
76
+ all_kps = [self.kps_body, self.kps_lhand, self.kps_rhand, self.kps_face]
77
+ for kps in all_kps:
78
+ if kps is not None:
79
+ kps[:, 0] -= x0
80
+ kps[:, 1] -= y0
81
+ self.width = x1 - x0
82
+ self.height = y1 - y0
83
+ return self
84
+
85
+ def resize(self, width, height):
86
+ scale_x = width / self.width
87
+ scale_y = height / self.height
88
+ all_kps = [self.kps_body, self.kps_lhand, self.kps_rhand, self.kps_face]
89
+ for kps in all_kps:
90
+ if kps is not None:
91
+ kps[:, 0] *= scale_x
92
+ kps[:, 1] *= scale_y
93
+ self.width = width
94
+ self.height = height
95
+ return self
96
+
97
+
98
+ def get_kps_body_with_p(self, normalize=False):
99
+ kps_body = self.kps_body.copy()
100
+ if normalize:
101
+ kps_body = kps_body / np.array([self.width, self.height])
102
+
103
+ return np.concatenate([kps_body, self.kps_body_p[:, None]], axis=1)
104
+
105
+ @staticmethod
106
+ def from_kps_face(kps_face: np.ndarray, height: int, width: int):
107
+
108
+ pose_meta = AAPoseMeta()
109
+ pose_meta.kps_face = kps_face[:, :2]
110
+ if kps_face.shape[1] == 3:
111
+ pose_meta.kps_face_p = kps_face[:, 2]
112
+ else:
113
+ pose_meta.kps_face_p = kps_face[:, 0] * 0 + 1
114
+ pose_meta.height = height
115
+ pose_meta.width = width
116
+ return pose_meta
117
+
118
+ @staticmethod
119
+ def from_kps_body(kps_body: np.ndarray, height: int, width: int):
120
+
121
+ pose_meta = AAPoseMeta()
122
+ pose_meta.kps_body = kps_body[:, :2]
123
+ pose_meta.kps_body_p = kps_body[:, 2]
124
+ pose_meta.height = height
125
+ pose_meta.width = width
126
+ return pose_meta
127
+ @staticmethod
128
+ def from_humanapi_meta(meta):
129
+ pose_meta = AAPoseMeta()
130
+ width, height = meta["width"], meta["height"]
131
+ pose_meta.width = width
132
+ pose_meta.height = height
133
+ pose_meta.kps_body = meta["keypoints_body"][:, :2] * (width, height)
134
+ pose_meta.kps_body_p = meta["keypoints_body"][:, 2]
135
+ pose_meta.kps_lhand = meta["keypoints_left_hand"][:, :2] * (width, height)
136
+ pose_meta.kps_lhand_p = meta["keypoints_left_hand"][:, 2]
137
+ pose_meta.kps_rhand = meta["keypoints_right_hand"][:, :2] * (width, height)
138
+ pose_meta.kps_rhand_p = meta["keypoints_right_hand"][:, 2]
139
+ if 'keypoints_face' in meta:
140
+ pose_meta.kps_face = meta["keypoints_face"][:, :2] * (width, height)
141
+ pose_meta.kps_face_p = meta["keypoints_face"][:, 2]
142
+ return pose_meta
143
+
144
+ def load_from_meta(self, meta, norm_body=True, norm_hand=False):
145
+
146
+ self.image_id = meta.get("image_id", "00000.png")
147
+ self.height = meta["height"]
148
+ self.width = meta["width"]
149
+ kps_body_p = []
150
+ kps_body = []
151
+ for kp in meta["keypoints_body"]:
152
+ if kp is None:
153
+ kps_body.append([0, 0])
154
+ kps_body_p.append(0)
155
+ else:
156
+ kps_body.append(kp)
157
+ kps_body_p.append(1)
158
+
159
+ self.kps_body = np.array(kps_body)
160
+ self.kps_body[:, 0] *= self.width
161
+ self.kps_body[:, 1] *= self.height
162
+ self.kps_body_p = np.array(kps_body_p)
163
+
164
+ self.kps_lhand = np.array(meta["keypoints_left_hand"])[:, :2]
165
+ self.kps_lhand_p = np.array(meta["keypoints_left_hand"])[:, 2]
166
+ self.kps_rhand = np.array(meta["keypoints_right_hand"])[:, :2]
167
+ self.kps_rhand_p = np.array(meta["keypoints_right_hand"])[:, 2]
168
+
169
+ @staticmethod
170
+ def load_from_kp2ds(kp2ds: List[np.ndarray], width: int, height: int):
171
+ """Build an AAPoseMeta from a single 133x3 whole-body keypoint array.
172
+
173
+ Args:
174
+ kp2ds (np.ndarray): 133 keypoints in (x, y, score) format.
175
+ width (int): Image width in pixels.
176
+ height (int): Image height in pixels.
177
+
178
+ Returns:
179
+ AAPoseMeta: Pose meta with body, hand and face keypoints split out.
180
+ """
181
+ pose_meta = AAPoseMeta()
182
+ pose_meta.width = width
183
+ pose_meta.height = height
184
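+ # Remap the 133 COCO-WholeBody keypoints to a 20-joint OpenPose-style body: averaging the two index lists synthesizes the neck as the shoulder midpoint (5, 6) and the last two joints as left/right toe midpoints (17/18, 20/21).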
+ kps_body = (kp2ds[[0, 6, 6, 8, 10, 5, 7, 9, 12, 14, 16, 11, 13, 15, 2, 1, 4, 3, 17, 20]] + kp2ds[[0, 5, 6, 8, 10, 5, 7, 9, 12, 14, 16, 11, 13, 15, 2, 1, 4, 3, 18, 21]]) / 2
185
+ kps_lhand = kp2ds[91:112]
186
+ kps_rhand = kp2ds[112:133]
187
+ kps_face = np.concatenate([kp2ds[23:23+68], kp2ds[1:3]], axis=0)
188
+ pose_meta.kps_body = kps_body[:, :2]
189
+ pose_meta.kps_body_p = kps_body[:, 2]
190
+ pose_meta.kps_lhand = kps_lhand[:, :2]
191
+ pose_meta.kps_lhand_p = kps_lhand[:, 2]
192
+ pose_meta.kps_rhand = kps_rhand[:, :2]
193
+ pose_meta.kps_rhand_p = kps_rhand[:, 2]
194
+ pose_meta.kps_face = kps_face[:, :2]
195
+ pose_meta.kps_face_p = kps_face[:, 2]
196
+ return pose_meta
197
+
198
+ @staticmethod
199
+ def from_dwpose(dwpose_det_res, height, width):
200
+ pose_meta = AAPoseMeta()
201
+ pose_meta.kps_body = dwpose_det_res["bodies"]["candidate"]
202
+ pose_meta.kps_body_p = dwpose_det_res["bodies"]["score"]
203
+ pose_meta.kps_body[:, 0] *= width
204
+ pose_meta.kps_body[:, 1] *= height
205
+
206
+ pose_meta.kps_lhand, pose_meta.kps_rhand = dwpose_det_res["hands"]
207
+ pose_meta.kps_lhand[:, 0] *= width
208
+ pose_meta.kps_lhand[:, 1] *= height
209
+ pose_meta.kps_rhand[:, 0] *= width
210
+ pose_meta.kps_rhand[:, 1] *= height
211
+ pose_meta.kps_lhand_p, pose_meta.kps_rhand_p = dwpose_det_res["hands_score"]
212
+
213
+ pose_meta.kps_face = dwpose_det_res["faces"][0]
214
+ pose_meta.kps_face[:, 0] *= width
215
+ pose_meta.kps_face[:, 1] *= height
216
+ pose_meta.kps_face_p = dwpose_det_res["faces_score"][0]
217
+ return pose_meta
218
+
219
+ def save_json(self):
220
+ pass
221
+
222
+ def draw_aapose(self, img, threshold=0.5, stick_width_norm=200, draw_hand=True, draw_head=True):
223
+ from .human_visualization import draw_aapose_by_meta
224
+ return draw_aapose_by_meta(img, self, threshold, stick_width_norm, draw_hand, draw_head)
225
+
226
+
227
+ def translate(self, x0, y0):
228
+ all_kps = [self.kps_body, self.kps_lhand, self.kps_rhand, self.kps_face]
229
+ for kps in all_kps:
230
+ if kps is not None:
231
+ kps[:, 0] -= x0
232
+ kps[:, 1] -= y0
233
+
234
+ def scale(self, sx, sy):
235
+ all_kps = [self.kps_body, self.kps_lhand, self.kps_rhand, self.kps_face]
236
+ for kps in all_kps:
237
+ if kps is not None:
238
+ kps[:, 0] *= sx
239
+ kps[:, 1] *= sy
240
+
241
+ def padding_resize2(self, height=512, width=512):
242
+ """kps will be changed inplace
243
+
244
+ """
245
+
246
+ all_kps = [self.kps_body, self.kps_lhand, self.kps_rhand, self.kps_face]
247
+
248
+ ori_height, ori_width = self.height, self.width
249
+
250
+ if (ori_height / ori_width) > (height / width):
251
+ new_width = int(height / ori_height * ori_width)
252
+ padding = int((width - new_width) / 2)
253
+ padding_width = padding
254
+ padding_height = 0
255
+ scale = height / ori_height
256
+
257
+ for kps in all_kps:
258
+ if kps is not None:
259
+ kps[:, 0] = kps[:, 0] * scale + padding
260
+ kps[:, 1] = kps[:, 1] * scale
261
+
262
+ else:
263
+ new_height = int(width / ori_width * ori_height)
264
+ padding = int((height - new_height) / 2)
265
+ padding_width = 0
266
+ padding_height = padding
267
+ scale = width / ori_width
268
+ for kps in all_kps:
269
+ if kps is not None:
270
+ kps[:, 1] = kps[:, 1] * scale + padding
271
+ kps[:, 0] = kps[:, 0] * scale
272
+
273
+
274
+ self.width = width
275
+ self.height = height
276
+ return self
277
+
278
+
279
+ def transform_preds(coords, center, scale, output_size, use_udp=False):
280
+ """Get final keypoint predictions from heatmaps and apply scaling and
281
+ translation to map them back to the image.
282
+
283
+ Note:
284
+ num_keypoints: K
285
+
286
+ Args:
287
+ coords (np.ndarray[K, ndims]):
288
+
289
+ * If ndims=2, coords are the predicted keypoint locations.
290
+ * If ndims=4, coords are composed of (x, y, scores, tags)
291
+ * If ndims=5, coords are composed of (x, y, scores, tags,
292
+ flipped_tags)
293
+
294
+ center (np.ndarray[2, ]): Center of the bounding box (x, y).
295
+ scale (np.ndarray[2, ]): Scale of the bounding box
296
+ wrt [width, height].
297
+ output_size (np.ndarray[2, ] | list(2,)): Size of the
298
+ destination heatmaps.
299
+ use_udp (bool): Use unbiased data processing
300
+
301
+ Returns:
302
+ np.ndarray: Predicted coordinates in the images.
303
+ """
304
+ assert coords.shape[1] in (2, 4, 5)
305
+ assert len(center) == 2
306
+ assert len(scale) == 2
307
+ assert len(output_size) == 2
308
+
309
+ # Recover the scale which is normalized by a factor of 200.
310
+ # scale = scale * 200.0
311
+
312
+ if use_udp:
313
+ scale_x = scale[0] / (output_size[0] - 1.0)
314
+ scale_y = scale[1] / (output_size[1] - 1.0)
315
+ else:
316
+ scale_x = scale[0] / output_size[0]
317
+ scale_y = scale[1] / output_size[1]
318
+
319
+ target_coords = np.ones_like(coords)
320
+ target_coords[:, 0] = coords[:, 0] * scale_x + center[0] - scale[0] * 0.5
321
+ target_coords[:, 1] = coords[:, 1] * scale_y + center[1] - scale[1] * 0.5
322
+
323
+ return target_coords
324
+
325
+
326
+ def _calc_distances(preds, targets, mask, normalize):
327
+ """Calculate the normalized distances between preds and target.
328
+
329
+ Note:
330
+ batch_size: N
331
+ num_keypoints: K
332
+ dimension of keypoints: D (normally, D=2 or D=3)
333
+
334
+ Args:
335
+ preds (np.ndarray[N, K, D]): Predicted keypoint location.
336
+ targets (np.ndarray[N, K, D]): Groundtruth keypoint location.
337
+ mask (np.ndarray[N, K]): Visibility of the target. False for invisible
338
+ joints, and True for visible. Invisible joints will be ignored for
339
+ accuracy calculation.
340
+ normalize (np.ndarray[N, D]): Typical value is heatmap_size
341
+
342
+ Returns:
343
+ np.ndarray[K, N]: The normalized distances. \
344
+ If target keypoints are missing, the distance is -1.
345
+ """
346
+ N, K, _ = preds.shape
347
+ # set mask=0 when normalize==0
348
+ _mask = mask.copy()
349
+ _mask[np.where((normalize == 0).sum(1))[0], :] = False
350
+ distances = np.full((N, K), -1, dtype=np.float32)
351
+ # handle invalid values
352
+ normalize[np.where(normalize <= 0)] = 1e6
353
+ distances[_mask] = np.linalg.norm(
354
+ ((preds - targets) / normalize[:, None, :])[_mask], axis=-1)
355
+ return distances.T
356
+
357
+
358
+ def _distance_acc(distances, thr=0.5):
359
+ """Return the percentage below the distance threshold, while ignoring
360
+ distances values with -1.
361
+
362
+ Note:
363
+ batch_size: N
364
+ Args:
365
+ distances (np.ndarray[N, ]): The normalized distances.
366
+ thr (float): Threshold of the distances.
367
+
368
+ Returns:
369
+ float: Percentage of distances below the threshold. \
370
+ If all target keypoints are missing, return -1.
371
+ """
372
+ distance_valid = distances != -1
373
+ num_distance_valid = distance_valid.sum()
374
+ if num_distance_valid > 0:
375
+ return (distances[distance_valid] < thr).sum() / num_distance_valid
376
+ return -1
377
+
378
+
379
+ def _get_max_preds(heatmaps):
380
+ """Get keypoint predictions from score maps.
381
+
382
+ Note:
383
+ batch_size: N
384
+ num_keypoints: K
385
+ heatmap height: H
386
+ heatmap width: W
387
+
388
+ Args:
389
+ heatmaps (np.ndarray[N, K, H, W]): model predicted heatmaps.
390
+
391
+ Returns:
392
+ tuple: A tuple containing aggregated results.
393
+
394
+ - preds (np.ndarray[N, K, 2]): Predicted keypoint location.
395
+ - maxvals (np.ndarray[N, K, 1]): Scores (confidence) of the keypoints.
396
+ """
397
+ assert isinstance(heatmaps,
398
+ np.ndarray), ('heatmaps should be numpy.ndarray')
399
+ assert heatmaps.ndim == 4, 'batch_images should be 4-ndim'
400
+
401
+ N, K, _, W = heatmaps.shape
402
+ heatmaps_reshaped = heatmaps.reshape((N, K, -1))
403
+ idx = np.argmax(heatmaps_reshaped, 2).reshape((N, K, 1))
404
+ maxvals = np.amax(heatmaps_reshaped, 2).reshape((N, K, 1))
405
+
406
+ preds = np.tile(idx, (1, 1, 2)).astype(np.float32)
407
+ preds[:, :, 0] = preds[:, :, 0] % W
408
+ preds[:, :, 1] = preds[:, :, 1] // W
409
+
410
+ preds = np.where(np.tile(maxvals, (1, 1, 2)) > 0.0, preds, -1)
411
+ return preds, maxvals
412
+
413
+
414
+ def _get_max_preds_3d(heatmaps):
415
+ """Get keypoint predictions from 3D score maps.
416
+
417
+ Note:
418
+ batch size: N
419
+ num keypoints: K
420
+ heatmap depth size: D
421
+ heatmap height: H
422
+ heatmap width: W
423
+
424
+ Args:
425
+ heatmaps (np.ndarray[N, K, D, H, W]): model predicted heatmaps.
426
+
427
+ Returns:
428
+ tuple: A tuple containing aggregated results.
429
+
430
+ - preds (np.ndarray[N, K, 3]): Predicted keypoint location.
431
+ - maxvals (np.ndarray[N, K, 1]): Scores (confidence) of the keypoints.
432
+ """
433
+ assert isinstance(heatmaps, np.ndarray), \
434
+ ('heatmaps should be numpy.ndarray')
435
+ assert heatmaps.ndim == 5, 'heatmaps should be 5-ndim'
436
+
437
+ N, K, D, H, W = heatmaps.shape
438
+ heatmaps_reshaped = heatmaps.reshape((N, K, -1))
439
+ idx = np.argmax(heatmaps_reshaped, 2).reshape((N, K, 1))
440
+ maxvals = np.amax(heatmaps_reshaped, 2).reshape((N, K, 1))
441
+
442
+ preds = np.zeros((N, K, 3), dtype=np.float32)
443
+ _idx = idx[..., 0]
444
+ preds[..., 2] = _idx // (H * W)
445
+ preds[..., 1] = (_idx // W) % H
446
+ preds[..., 0] = _idx % W
447
+
448
+ preds = np.where(maxvals > 0.0, preds, -1)
449
+ return preds, maxvals
450
+
451
+
452
+ def pose_pck_accuracy(output, target, mask, thr=0.05, normalize=None):
453
+ """Calculate the pose accuracy of PCK for each individual keypoint and the
454
+ averaged accuracy across all keypoints from heatmaps.
455
+
456
+ Note:
457
+ PCK metric measures accuracy of the localization of the body joints.
458
+ The distances between predicted positions and the ground-truth ones
459
+ are typically normalized by the bounding box size.
460
+ The threshold (thr) of the normalized distance is commonly set
461
+ as 0.05, 0.1 or 0.2 etc.
462
+
463
+ - batch_size: N
464
+ - num_keypoints: K
465
+ - heatmap height: H
466
+ - heatmap width: W
467
+
468
+ Args:
469
+ output (np.ndarray[N, K, H, W]): Model output heatmaps.
470
+ target (np.ndarray[N, K, H, W]): Groundtruth heatmaps.
471
+ mask (np.ndarray[N, K]): Visibility of the target. False for invisible
472
+ joints, and True for visible. Invisible joints will be ignored for
473
+ accuracy calculation.
474
+ thr (float): Threshold of PCK calculation. Default 0.05.
475
+ normalize (np.ndarray[N, 2]): Normalization factor for H&W.
476
+
477
+ Returns:
478
+ tuple: A tuple containing keypoint accuracy.
479
+
480
+ - np.ndarray[K]: Accuracy of each keypoint.
481
+ - float: Averaged accuracy across all keypoints.
482
+ - int: Number of valid keypoints.
483
+ """
484
+ N, K, H, W = output.shape
485
+ if K == 0:
486
+ return None, 0, 0
487
+ if normalize is None:
488
+ normalize = np.tile(np.array([[H, W]]), (N, 1))
489
+
490
+ pred, _ = _get_max_preds(output)
491
+ gt, _ = _get_max_preds(target)
492
+ return keypoint_pck_accuracy(pred, gt, mask, thr, normalize)
493
+
494
+
495
+ def keypoint_pck_accuracy(pred, gt, mask, thr, normalize):
496
+ """Calculate the pose accuracy of PCK for each individual keypoint and the
497
+ averaged accuracy across all keypoints for coordinates.
498
+
499
+ Note:
500
+ PCK metric measures accuracy of the localization of the body joints.
501
+ The distances between predicted positions and the ground-truth ones
502
+ are typically normalized by the bounding box size.
503
+ The threshold (thr) of the normalized distance is commonly set
504
+ as 0.05, 0.1 or 0.2 etc.
505
+
506
+ - batch_size: N
507
+ - num_keypoints: K
508
+
509
+ Args:
510
+ pred (np.ndarray[N, K, 2]): Predicted keypoint location.
511
+ gt (np.ndarray[N, K, 2]): Groundtruth keypoint location.
512
+ mask (np.ndarray[N, K]): Visibility of the target. False for invisible
513
+ joints, and True for visible. Invisible joints will be ignored for
514
+ accuracy calculation.
515
+ thr (float): Threshold of PCK calculation.
516
+ normalize (np.ndarray[N, 2]): Normalization factor for H&W.
517
+
518
+ Returns:
519
+ tuple: A tuple containing keypoint accuracy.
520
+
521
+ - acc (np.ndarray[K]): Accuracy of each keypoint.
522
+ - avg_acc (float): Averaged accuracy across all keypoints.
523
+ - cnt (int): Number of valid keypoints.
524
+ """
525
+ distances = _calc_distances(pred, gt, mask, normalize)
526
+
527
+ acc = np.array([_distance_acc(d, thr) for d in distances])
528
+ valid_acc = acc[acc >= 0]
529
+ cnt = len(valid_acc)
530
+ avg_acc = valid_acc.mean() if cnt > 0 else 0
531
+ return acc, avg_acc, cnt
532
+
533
+
534
+ def keypoint_auc(pred, gt, mask, normalize, num_step=20):
535
+ """Calculate the pose accuracy of PCK for each individual keypoint and the
536
+ averaged accuracy across all keypoints for coordinates.
537
+
538
+ Note:
539
+ - batch_size: N
540
+ - num_keypoints: K
541
+
542
+ Args:
543
+ pred (np.ndarray[N, K, 2]): Predicted keypoint location.
544
+ gt (np.ndarray[N, K, 2]): Groundtruth keypoint location.
545
+ mask (np.ndarray[N, K]): Visibility of the target. False for invisible
546
+ joints, and True for visible. Invisible joints will be ignored for
547
+ accuracy calculation.
548
+ normalize (float): Normalization factor.
549
+
550
+ Returns:
551
+ float: Area under curve.
552
+ """
553
+ nor = np.tile(np.array([[normalize, normalize]]), (pred.shape[0], 1))
554
+ x = [1.0 * i / num_step for i in range(num_step)]
555
+ y = []
556
+ for thr in x:
557
+ _, avg_acc, _ = keypoint_pck_accuracy(pred, gt, mask, thr, nor)
558
+ y.append(avg_acc)
559
+
560
+ auc = 0
561
+ for i in range(num_step):
562
+ auc += 1.0 / num_step * y[i]
563
+ return auc
564
+
565
+
566
+ def keypoint_nme(pred, gt, mask, normalize_factor):
567
+ """Calculate the normalized mean error (NME).
568
+
569
+ Note:
570
+ - batch_size: N
571
+ - num_keypoints: K
572
+
573
+ Args:
574
+ pred (np.ndarray[N, K, 2]): Predicted keypoint location.
575
+ gt (np.ndarray[N, K, 2]): Groundtruth keypoint location.
576
+ mask (np.ndarray[N, K]): Visibility of the target. False for invisible
577
+ joints, and True for visible. Invisible joints will be ignored for
578
+ accuracy calculation.
579
+ normalize_factor (np.ndarray[N, 2]): Normalization factor.
580
+
581
+ Returns:
582
+ float: normalized mean error
583
+ """
584
+ distances = _calc_distances(pred, gt, mask, normalize_factor)
585
+ distance_valid = distances[distances != -1]
586
+ return distance_valid.sum() / max(1, len(distance_valid))
587
+
588
+
589
+ def keypoint_epe(pred, gt, mask):
590
+ """Calculate the end-point error.
591
+
592
+ Note:
593
+ - batch_size: N
594
+ - num_keypoints: K
595
+
596
+ Args:
597
+ pred (np.ndarray[N, K, 2]): Predicted keypoint location.
598
+ gt (np.ndarray[N, K, 2]): Groundtruth keypoint location.
599
+ mask (np.ndarray[N, K]): Visibility of the target. False for invisible
600
+ joints, and True for visible. Invisible joints will be ignored for
601
+ accuracy calculation.
602
+
603
+ Returns:
604
+ float: Average end-point error.
605
+ """
606
+
607
+ distances = _calc_distances(
608
+ pred, gt, mask,
609
+ np.ones((pred.shape[0], pred.shape[2]), dtype=np.float32))
610
+ distance_valid = distances[distances != -1]
611
+ return distance_valid.sum() / max(1, len(distance_valid))
612
+
613
+
614
+ def _taylor(heatmap, coord):
615
+ """Distribution aware coordinate decoding method.
616
+
617
+ Note:
618
+ - heatmap height: H
619
+ - heatmap width: W
620
+
621
+ Args:
622
+ heatmap (np.ndarray[H, W]): Heatmap of a particular joint type.
623
+ coord (np.ndarray[2,]): Coordinates of the predicted keypoints.
624
+
625
+ Returns:
626
+ np.ndarray[2,]: Updated coordinates.
627
+ """
628
+ H, W = heatmap.shape[:2]
629
+ px, py = int(coord[0]), int(coord[1])
630
+ if 1 < px < W - 2 and 1 < py < H - 2:
631
+ dx = 0.5 * (heatmap[py][px + 1] - heatmap[py][px - 1])
632
+ dy = 0.5 * (heatmap[py + 1][px] - heatmap[py - 1][px])
633
+ dxx = 0.25 * (
634
+ heatmap[py][px + 2] - 2 * heatmap[py][px] + heatmap[py][px - 2])
635
+ dxy = 0.25 * (
636
+ heatmap[py + 1][px + 1] - heatmap[py - 1][px + 1] -
637
+ heatmap[py + 1][px - 1] + heatmap[py - 1][px - 1])
638
+ dyy = 0.25 * (
639
+ heatmap[py + 2 * 1][px] - 2 * heatmap[py][px] +
640
+ heatmap[py - 2 * 1][px])
641
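+ # One Newton step on the heatmap around the argmax: offset = -inv(Hessian) @ gradient, with central-difference derivatives.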
+ derivative = np.array([[dx], [dy]])
642
+ hessian = np.array([[dxx, dxy], [dxy, dyy]])
643
+ if dxx * dyy - dxy**2 != 0:
644
+ hessianinv = np.linalg.inv(hessian)
645
+ offset = -hessianinv @ derivative
646
+ offset = np.squeeze(np.array(offset.T), axis=0)
647
+ coord += offset
648
+ return coord
649
+
650
+
651
+ def post_dark_udp(coords, batch_heatmaps, kernel=3):
652
+ """DARK post-processing, implemented with UDP. Paper ref: Huang et al. The
653
+ Devil is in the Details: Delving into Unbiased Data Processing for Human
654
+ Pose Estimation (CVPR 2020). Zhang et al. Distribution-Aware Coordinate
655
+ Representation for Human Pose Estimation (CVPR 2020).
656
+
657
+ Note:
658
+ - batch size: B
659
+ - num keypoints: K
660
+ - num persons: N
661
+ - height of heatmaps: H
662
+ - width of heatmaps: W
663
+
664
+ B=1 for bottom_up paradigm where all persons share the same heatmap.
665
+ B=N for top_down paradigm where each person has its own heatmaps.
666
+
667
+ Args:
668
+ coords (np.ndarray[N, K, 2]): Initial coordinates of human pose.
669
+ batch_heatmaps (np.ndarray[B, K, H, W]): batch_heatmaps
670
+ kernel (int): Gaussian kernel size (K) for modulation.
671
+
672
+ Returns:
673
+ np.ndarray([N, K, 2]): Refined coordinates.
674
+ """
675
+ if not isinstance(batch_heatmaps, np.ndarray):
676
+ batch_heatmaps = batch_heatmaps.cpu().numpy()
677
+ B, K, H, W = batch_heatmaps.shape
678
+ N = coords.shape[0]
679
+ assert (B == 1 or B == N)
680
+ for heatmaps in batch_heatmaps:
681
+ for heatmap in heatmaps:
682
+ cv2.GaussianBlur(heatmap, (kernel, kernel), 0, heatmap)
683
+ np.clip(batch_heatmaps, 0.001, 50, batch_heatmaps)
684
+ np.log(batch_heatmaps, batch_heatmaps)
685
+
686
+ batch_heatmaps_pad = np.pad(
687
+ batch_heatmaps, ((0, 0), (0, 0), (1, 1), (1, 1)),
688
+ mode='edge').flatten()
689
+
690
+ index = coords[..., 0] + 1 + (coords[..., 1] + 1) * (W + 2)
691
+ index += (W + 2) * (H + 2) * np.arange(0, B * K).reshape(-1, K)
692
+ index = index.astype(int).reshape(-1, 1)
693
+ i_ = batch_heatmaps_pad[index]
694
+ ix1 = batch_heatmaps_pad[index + 1]
695
+ iy1 = batch_heatmaps_pad[index + W + 2]
696
+ ix1y1 = batch_heatmaps_pad[index + W + 3]
697
+ ix1_y1_ = batch_heatmaps_pad[index - W - 3]
698
+ ix1_ = batch_heatmaps_pad[index - 1]
699
+ iy1_ = batch_heatmaps_pad[index - 2 - W]
700
+
701
+ dx = 0.5 * (ix1 - ix1_)
702
+ dy = 0.5 * (iy1 - iy1_)
703
+ derivative = np.concatenate([dx, dy], axis=1)
704
+ derivative = derivative.reshape(N, K, 2, 1)
705
+ dxx = ix1 - 2 * i_ + ix1_
706
+ dyy = iy1 - 2 * i_ + iy1_
707
+ dxy = 0.5 * (ix1y1 - ix1 - iy1 + i_ + i_ - ix1_ - iy1_ + ix1_y1_)
708
+ hessian = np.concatenate([dxx, dxy, dxy, dyy], axis=1)
709
+ hessian = hessian.reshape(N, K, 2, 2)
710
+ hessian = np.linalg.inv(hessian + np.finfo(np.float32).eps * np.eye(2))
711
+ coords -= np.einsum('ijmn,ijnk->ijmk', hessian, derivative).squeeze()
712
+ return coords
713
+
714
+
715
+ def _gaussian_blur(heatmaps, kernel=11):
716
+ """Modulate heatmap distribution with Gaussian.
717
+ sigma = 0.3*((kernel_size-1)*0.5-1)+0.8
718
+ sigma~=3 if k=17
719
+ sigma=2 if k=11;
720
+ sigma~=1.5 if k=7;
721
+ sigma~=1 if k=3;
722
+
723
+ Note:
724
+ - batch_size: N
725
+ - num_keypoints: K
726
+ - heatmap height: H
727
+ - heatmap width: W
728
+
729
+ Args:
730
+ heatmaps (np.ndarray[N, K, H, W]): model predicted heatmaps.
731
+ kernel (int): Gaussian kernel size (K) for modulation, which should
732
+ match the heatmap gaussian sigma when training.
733
+ K=17 for sigma=3 and k=11 for sigma=2.
734
+
735
+ Returns:
736
+ np.ndarray ([N, K, H, W]): Modulated heatmap distribution.
737
+ """
738
+ assert kernel % 2 == 1
739
+
740
+ border = (kernel - 1) // 2
741
+ batch_size = heatmaps.shape[0]
742
+ num_joints = heatmaps.shape[1]
743
+ height = heatmaps.shape[2]
744
+ width = heatmaps.shape[3]
745
+ for i in range(batch_size):
746
+ for j in range(num_joints):
747
+ origin_max = np.max(heatmaps[i, j])
748
+ dr = np.zeros((height + 2 * border, width + 2 * border),
749
+ dtype=np.float32)
750
+ dr[border:-border, border:-border] = heatmaps[i, j].copy()
751
+ dr = cv2.GaussianBlur(dr, (kernel, kernel), 0)
752
+ heatmaps[i, j] = dr[border:-border, border:-border].copy()
753
+ heatmaps[i, j] *= origin_max / np.max(heatmaps[i, j])
754
+ return heatmaps
755
+
756
+
757
+ def keypoints_from_regression(regression_preds, center, scale, img_size):
758
+ """Get final keypoint predictions from regression vectors and transform
759
+ them back to the image.
760
+
761
+ Note:
762
+ - batch_size: N
763
+ - num_keypoints: K
764
+
765
+ Args:
766
+ regression_preds (np.ndarray[N, K, 2]): model prediction.
767
+ center (np.ndarray[N, 2]): Center of the bounding box (x, y).
768
+ scale (np.ndarray[N, 2]): Scale of the bounding box
769
+ wrt height/width.
770
+ img_size (list(img_width, img_height)): model input image size.
771
+
772
+ Returns:
773
+ tuple:
774
+
775
+ - preds (np.ndarray[N, K, 2]): Predicted keypoint location in images.
776
+ - maxvals (np.ndarray[N, K, 1]): Scores (confidence) of the keypoints.
777
+ """
778
+ N, K, _ = regression_preds.shape
779
+ preds, maxvals = regression_preds, np.ones((N, K, 1), dtype=np.float32)
780
+
781
+ preds = preds * img_size
782
+
783
+ # Transform back to the image
784
+ for i in range(N):
785
+ preds[i] = transform_preds(preds[i], center[i], scale[i], img_size)
786
+
787
+ return preds, maxvals
788
+
789
+
790
+ def keypoints_from_heatmaps(heatmaps,
791
+ center,
792
+ scale,
793
+ unbiased=False,
794
+ post_process='default',
795
+ kernel=11,
796
+ valid_radius_factor=0.0546875,
797
+ use_udp=False,
798
+ target_type='GaussianHeatmap'):
799
+ """Get final keypoint predictions from heatmaps and transform them back to
800
+ the image.
801
+
802
+ Note:
803
+ - batch size: N
804
+ - num keypoints: K
805
+ - heatmap height: H
806
+ - heatmap width: W
807
+
808
+ Args:
809
+ heatmaps (np.ndarray[N, K, H, W]): model predicted heatmaps.
810
+ center (np.ndarray[N, 2]): Center of the bounding box (x, y).
811
+ scale (np.ndarray[N, 2]): Scale of the bounding box
812
+ wrt height/width.
813
+ post_process (str/None): Choice of methods to post-process
814
+ heatmaps. Currently supported: None, 'default', 'unbiased',
815
+ 'megvii'.
816
+ unbiased (bool): Option to use unbiased decoding. Mutually
817
+ exclusive with megvii.
818
+ Note: this arg is deprecated and unbiased=True can be replaced
819
+ by post_process='unbiased'
820
+ Paper ref: Zhang et al. Distribution-Aware Coordinate
821
+ Representation for Human Pose Estimation (CVPR 2020).
822
+ kernel (int): Gaussian kernel size (K) for modulation, which should
823
+ match the heatmap gaussian sigma when training.
824
+ K=17 for sigma=3 and k=11 for sigma=2.
825
+ valid_radius_factor (float): The radius factor of the positive area
826
+ in classification heatmap for UDP.
827
+ use_udp (bool): Use unbiased data processing.
828
+ target_type (str): 'GaussianHeatmap' or 'CombinedTarget'.
829
+ GaussianHeatmap: Classification target with gaussian distribution.
830
+ CombinedTarget: The combination of classification target
831
+ (response map) and regression target (offset map).
832
+ Paper ref: Huang et al. The Devil is in the Details: Delving into
833
+ Unbiased Data Processing for Human Pose Estimation (CVPR 2020).
834
+
835
+ Returns:
836
+ tuple: A tuple containing keypoint predictions and scores.
837
+
838
+ - preds (np.ndarray[N, K, 2]): Predicted keypoint location in images.
839
+ - maxvals (np.ndarray[N, K, 1]): Scores (confidence) of the keypoints.
840
+ """
841
+ # Avoid being affected
842
+ heatmaps = heatmaps.copy()
843
+
844
+ # detect conflicts
845
+ if unbiased:
846
+ assert post_process not in [False, None, 'megvii']
847
+ if post_process in ['megvii', 'unbiased']:
848
+ assert kernel > 0
849
+ if use_udp:
850
+ assert not post_process == 'megvii'
851
+
852
+ # normalize configs
853
+ if post_process is False:
854
+ warnings.warn(
855
+ 'post_process=False is deprecated, '
856
+ 'please use post_process=None instead', DeprecationWarning)
857
+ post_process = None
858
+ elif post_process is True:
859
+ if unbiased is True:
860
+ warnings.warn(
861
+ 'post_process=True, unbiased=True is deprecated,'
862
+ " please use post_process='unbiased' instead",
863
+ DeprecationWarning)
864
+ post_process = 'unbiased'
865
+ else:
866
+ warnings.warn(
867
+ 'post_process=True, unbiased=False is deprecated, '
868
+ "please use post_process='default' instead",
869
+ DeprecationWarning)
870
+ post_process = 'default'
871
+ elif post_process == 'default':
872
+ if unbiased is True:
873
+ warnings.warn(
874
+ 'unbiased=True is deprecated, please use '
875
+ "post_process='unbiased' instead", DeprecationWarning)
876
+ post_process = 'unbiased'
877
+
878
+ # start processing
879
+ if post_process == 'megvii':
880
+ heatmaps = _gaussian_blur(heatmaps, kernel=kernel)
881
+
882
+ N, K, H, W = heatmaps.shape
883
+ if use_udp:
884
+ if target_type.lower() == 'GaussianHeatMap'.lower():
885
+ preds, maxvals = _get_max_preds(heatmaps)
886
+ preds = post_dark_udp(preds, heatmaps, kernel=kernel)
887
+ elif target_type.lower() == 'CombinedTarget'.lower():
888
+ for person_heatmaps in heatmaps:
889
+ for i, heatmap in enumerate(person_heatmaps):
890
+ kt = 2 * kernel + 1 if i % 3 == 0 else kernel
891
+ cv2.GaussianBlur(heatmap, (kt, kt), 0, heatmap)
892
+ # valid radius is in direct proportion to the height of heatmap.
893
+ valid_radius = valid_radius_factor * H
894
+ offset_x = heatmaps[:, 1::3, :].flatten() * valid_radius
895
+ offset_y = heatmaps[:, 2::3, :].flatten() * valid_radius
896
+ heatmaps = heatmaps[:, ::3, :]
897
+ preds, maxvals = _get_max_preds(heatmaps)
898
+ index = preds[..., 0] + preds[..., 1] * W
899
+ index += W * H * np.arange(0, N * K / 3)
900
+ index = index.astype(int).reshape(N, K // 3, 1)
901
+ preds += np.concatenate((offset_x[index], offset_y[index]), axis=2)
902
+ else:
903
+ raise ValueError('target_type should be either '
904
+ "'GaussianHeatmap' or 'CombinedTarget'")
905
+ else:
906
+ preds, maxvals = _get_max_preds(heatmaps)
907
+ if post_process == 'unbiased': # alleviate biased coordinate
908
+ # apply Gaussian distribution modulation.
909
+ heatmaps = np.log(
910
+ np.maximum(_gaussian_blur(heatmaps, kernel), 1e-10))
911
+ for n in range(N):
912
+ for k in range(K):
913
+ preds[n][k] = _taylor(heatmaps[n][k], preds[n][k])
914
+ elif post_process is not None:
915
+ # add +/-0.25 shift to the predicted locations for higher acc.
916
+ for n in range(N):
917
+ for k in range(K):
918
+ heatmap = heatmaps[n][k]
919
+ px = int(preds[n][k][0])
920
+ py = int(preds[n][k][1])
921
+ if 1 < px < W - 1 and 1 < py < H - 1:
922
+ diff = np.array([
923
+ heatmap[py][px + 1] - heatmap[py][px - 1],
924
+ heatmap[py + 1][px] - heatmap[py - 1][px]
925
+ ])
926
+ preds[n][k] += np.sign(diff) * .25
927
+ if post_process == 'megvii':
928
+ preds[n][k] += 0.5
929
+
930
+ # Transform back to the image
931
+ for i in range(N):
932
+ preds[i] = transform_preds(
933
+ preds[i], center[i], scale[i], [W, H], use_udp=use_udp)
934
+
935
+ if post_process == 'megvii':
936
+ maxvals = maxvals / 255.0 + 0.5
937
+
938
+ return preds, maxvals
939
+
940
+
941
+ def keypoints_from_heatmaps3d(heatmaps, center, scale):
942
+ """Get final keypoint predictions from 3d heatmaps and transform them back
943
+ to the image.
944
+
945
+ Note:
946
+ - batch size: N
947
+ - num keypoints: K
948
+ - heatmap depth size: D
949
+ - heatmap height: H
950
+ - heatmap width: W
951
+
952
+ Args:
953
+ heatmaps (np.ndarray[N, K, D, H, W]): model predicted heatmaps.
954
+ center (np.ndarray[N, 2]): Center of the bounding box (x, y).
955
+ scale (np.ndarray[N, 2]): Scale of the bounding box
956
+ wrt height/width.
957
+
958
+ Returns:
959
+ tuple: A tuple containing keypoint predictions and scores.
960
+
961
+ - preds (np.ndarray[N, K, 3]): Predicted 3d keypoint location \
962
+ in images.
963
+ - maxvals (np.ndarray[N, K, 1]): Scores (confidence) of the keypoints.
964
+ """
965
+ N, K, D, H, W = heatmaps.shape
966
+ preds, maxvals = _get_max_preds_3d(heatmaps)
967
+ # Transform back to the image
968
+ for i in range(N):
969
+ preds[i, :, :2] = transform_preds(preds[i, :, :2], center[i], scale[i],
970
+ [W, H])
971
+ return preds, maxvals
972
+
973
+
974
+ def multilabel_classification_accuracy(pred, gt, mask, thr=0.5):
975
+ """Get multi-label classification accuracy.
976
+
977
+ Note:
978
+ - batch size: N
979
+ - label number: L
980
+
981
+ Args:
982
+ pred (np.ndarray[N, L, 2]): model predicted labels.
983
+ gt (np.ndarray[N, L, 2]): ground-truth labels.
984
+ mask (np.ndarray[N, 1] or np.ndarray[N, L] ): reliability of
985
+ ground-truth labels.
986
+
987
+ Returns:
988
+ float: multi-label classification accuracy.
989
+ """
990
+ # we only compute accuracy on the samples with ground-truth of all labels.
991
+ valid = (mask > 0).min(axis=1) if mask.ndim == 2 else (mask > 0)
992
+ pred, gt = pred[valid], gt[valid]
993
+
994
+ if pred.shape[0] == 0:
995
+ acc = 0.0 # when no sample is with gt labels, set acc to 0.
996
+ else:
997
+ # The classification of a sample is regarded as correct
998
+ # only if it's correct for all labels.
999
+ acc = (((pred - thr) * (gt - thr)) > 0).all(axis=1).mean()
1000
+ return acc
1001
+
1002
+
1003
+
1004
+ def get_transform(center, scale, res, rot=0):
1005
+ """Generate transformation matrix."""
1006
+ # res: (height, width), (rows, cols)
1007
+ crop_aspect_ratio = res[0] / float(res[1])
1008
+ h = 200 * scale
1009
+ w = h / crop_aspect_ratio
1010
+ t = np.zeros((3, 3))
1011
+ t[0, 0] = float(res[1]) / w
1012
+ t[1, 1] = float(res[0]) / h
1013
+ t[0, 2] = res[1] * (-float(center[0]) / w + .5)
1014
+ t[1, 2] = res[0] * (-float(center[1]) / h + .5)
1015
+ t[2, 2] = 1
1016
+ if not rot == 0:
1017
+ rot = -rot # To match direction of rotation from cropping
1018
+ rot_mat = np.zeros((3, 3))
1019
+ rot_rad = rot * np.pi / 180
1020
+ sn, cs = np.sin(rot_rad), np.cos(rot_rad)
1021
+ rot_mat[0, :2] = [cs, -sn]
1022
+ rot_mat[1, :2] = [sn, cs]
1023
+ rot_mat[2, 2] = 1
1024
+ # Need to rotate around center
1025
+ t_mat = np.eye(3)
1026
+ t_mat[0, 2] = -res[1] / 2
1027
+ t_mat[1, 2] = -res[0] / 2
1028
+ t_inv = t_mat.copy()
1029
+ t_inv[:2, 2] *= -1
1030
+ t = np.dot(t_inv, np.dot(rot_mat, np.dot(t_mat, t)))
1031
+ return t
1032
+
1033
+
1034
+ def transform(pt, center, scale, res, invert=0, rot=0):
1035
+ """Transform pixel location to different reference."""
1036
+ t = get_transform(center, scale, res, rot=rot)
1037
+ if invert:
1038
+ t = np.linalg.inv(t)
1039
+ new_pt = np.array([pt[0] - 1, pt[1] - 1, 1.]).T
1040
+ new_pt = np.dot(t, new_pt)
1041
+ return np.array([round(new_pt[0]), round(new_pt[1])], dtype=int) + 1
1042
+
1043
+
1044
+ def bbox_from_detector(bbox, input_resolution=(224, 224), rescale=1.25):
1045
+ """
1046
+ Get center and scale of bounding box from bounding box.
1047
+ The expected format is [min_x, min_y, max_x, max_y].
1048
+ """
1049
+ CROP_IMG_HEIGHT, CROP_IMG_WIDTH = input_resolution
1050
+ CROP_ASPECT_RATIO = CROP_IMG_HEIGHT / float(CROP_IMG_WIDTH)
1051
+
1052
+ # center
1053
+ center_x = (bbox[0] + bbox[2]) / 2.0
1054
+ center_y = (bbox[1] + bbox[3]) / 2.0
1055
+ center = np.array([center_x, center_y])
1056
+
1057
+ # scale
1058
+ bbox_w = bbox[2] - bbox[0]
1059
+ bbox_h = bbox[3] - bbox[1]
1060
+ bbox_size = max(bbox_w * CROP_ASPECT_RATIO, bbox_h)
1061
+
1062
+ scale = np.array([bbox_size / CROP_ASPECT_RATIO, bbox_size]) / 200.0
1063
+ # scale = bbox_size / 200.0
1064
+ # adjust bounding box tightness
1065
+ scale *= rescale
1066
+ return center, scale
1067
+
1068
+
1069
+ def crop(img, center, scale, res):
1070
+ """
1071
+ Crop image according to the supplied bounding box.
1072
+ res: [rows, cols]
1073
+ """
1074
+ # Upper left point
1075
+ ul = np.array(transform([1, 1], center, max(scale), res, invert=1)) - 1
1076
+ # Bottom right point
1077
+ br = np.array(transform([res[1] + 1, res[0] + 1], center, max(scale), res, invert=1)) - 1
1078
+
1079
+ # Padding so that when rotated proper amount of context is included
1080
+ pad = int(np.linalg.norm(br - ul) / 2 - float(br[1] - ul[1]) / 2)
1081
+
1082
+ new_shape = [br[1] - ul[1], br[0] - ul[0]]
1083
+ if len(img.shape) > 2:
1084
+ new_shape += [img.shape[2]]
1085
+ new_img = np.zeros(new_shape, dtype=np.float32)
1086
+
1087
+ # Range to fill new array
1088
+ new_x = max(0, -ul[0]), min(br[0], len(img[0])) - ul[0]
1089
+ new_y = max(0, -ul[1]), min(br[1], len(img)) - ul[1]
1090
+ # Range to sample from original image
1091
+ old_x = max(0, ul[0]), min(len(img[0]), br[0])
1092
+ old_y = max(0, ul[1]), min(len(img), br[1])
1093
+ try:
1094
+ new_img[new_y[0]:new_y[1], new_x[0]:new_x[1]] = img[old_y[0]:old_y[1], old_x[0]:old_x[1]]
1095
+ except Exception as e:
1096
+ print(e)
1097
+
1098
+ new_img = cv2.resize(new_img, (res[1], res[0])) # (cols, rows)
1099
+ return new_img, new_shape, (old_x, old_y), (new_x, new_y) # , ul, br
1100
+
1101
+
1102
+ def split_kp2ds_for_aa(kp2ds, ret_face=False):
1103
+ kp2ds_body = (kp2ds[[0, 6, 6, 8, 10, 5, 7, 9, 12, 14, 16, 11, 13, 15, 2, 1, 4, 3, 17, 20]] + kp2ds[[0, 5, 6, 8, 10, 5, 7, 9, 12, 14, 16, 11, 13, 15, 2, 1, 4, 3, 18, 21]]) / 2
1104
+ kp2ds_lhand = kp2ds[91:112]
1105
+ kp2ds_rhand = kp2ds[112:133]
1106
+ kp2ds_face = kp2ds[22:91]
1107
+ if ret_face:
1108
+ return kp2ds_body.copy(), kp2ds_lhand.copy(), kp2ds_rhand.copy(), kp2ds_face.copy()
1109
+ return kp2ds_body.copy(), kp2ds_lhand.copy(), kp2ds_rhand.copy()
1110
+
1111
+ def load_pose_metas_from_kp2ds_seq_list(kp2ds_seq, width, height):
1112
+ metas = []
1113
+ for kps in kp2ds_seq:
1114
+ if len(kps) != 1:
1115
+ return None
1116
+ kps = kps[0].copy()
1117
+ kps[:, 0] /= width
1118
+ kps[:, 1] /= height
1119
+ kp2ds_body, kp2ds_lhand, kp2ds_rhand, kp2ds_face = split_kp2ds_for_aa(kps, ret_face=True)
1120
+
1121
+ if kp2ds_body[:, :2].min(axis=1).max() < 0:
1122
+ kp2ds_body = last_kp2ds_body
1123
+ last_kp2ds_body = kp2ds_body
1124
+
1125
+ meta = {
1126
+ "width": width,
1127
+ "height": height,
1128
+ "keypoints_body": kp2ds_body.tolist(),
1129
+ "keypoints_left_hand": kp2ds_lhand.tolist(),
1130
+ "keypoints_right_hand": kp2ds_rhand.tolist(),
1131
+ "keypoints_face": kp2ds_face.tolist(),
1132
+ }
1133
+ metas.append(meta)
1134
+ return metas
1135
+
1136
+
1137
+ def load_pose_metas_from_kp2ds_seq(kp2ds_seq, width, height):
1138
+ metas = []
1139
+ for kps in kp2ds_seq:
1140
+ kps = kps.copy()
1141
+ kps[:, 0] /= width
1142
+ kps[:, 1] /= height
1143
+ kp2ds_body, kp2ds_lhand, kp2ds_rhand, kp2ds_face = split_kp2ds_for_aa(kps, ret_face=True)
1144
+
1145
+ # If every body keypoint has a negative coordinate (detection off-screen), reuse the previous frame's body keypoints
1146
+ if kp2ds_body[:, :2].min(axis=1).max() < 0:
1147
+ kp2ds_body = last_kp2ds_body
1148
+ last_kp2ds_body = kp2ds_body
1149
+
1150
+ meta = {
1151
+ "width": width,
1152
+ "height": height,
1153
+ "keypoints_body": kp2ds_body,
1154
+ "keypoints_left_hand": kp2ds_lhand,
1155
+ "keypoints_right_hand": kp2ds_rhand,
1156
+ "keypoints_face": kp2ds_face,
1157
+ }
1158
+ metas.append(meta)
1159
+ return metas
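As a quick sanity check of the heatmap decoding path, here is a hedged sketch that feeds synthetic heatmaps through `keypoints_from_heatmaps`. It mirrors how `ViTPose.forward` calls it (scale multiplied by 200, `unbiased=True`); the center/scale values are made up, and it assumes it is run from this `preprocess` directory so the flat import resolves:

```python
# Synthetic example: decode keypoint peaks from hand-built heatmaps (values are illustrative).
import numpy as np
from pose2d_utils import keypoints_from_heatmaps

N, K, H, W = 1, 133, 64, 48
heatmaps = np.zeros((N, K, H, W), dtype=np.float32)
heatmaps[:, :, 32, 24] = 1.0                      # put every keypoint peak at (x=24, y=32)

center = np.array([[256.0, 256.0]])               # bbox center in the original image (assumed)
scale = np.array([[1.0, 1.5]]) * 200.0            # callers recover the 200-normalized scale

preds, maxvals = keypoints_from_heatmaps(heatmaps, center, scale, unbiased=True, use_udp=False)
print(preds.shape, maxvals.shape)                 # (1, 133, 2), (1, 133, 1)
```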
wan/modules/animate/preprocess/preprocess_data.py ADDED
@@ -0,0 +1,121 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import os
3
+ import argparse
4
+ from process_pipepline import ProcessPipeline
5
+
6
+
7
+ def _parse_args():
8
+ parser = argparse.ArgumentParser(
9
+ description="The preprocessing pipeline for Wan-animate."
10
+ )
11
+
12
+ parser.add_argument(
13
+ "--ckpt_path",
14
+ type=str,
15
+ default=None,
16
+ help="The path to the preprocessing model's checkpoint directory. ")
17
+
18
+ parser.add_argument(
19
+ "--video_path",
20
+ type=str,
21
+ default=None,
22
+ help="The path to the driving video.")
23
+ parser.add_argument(
24
+ "--refer_path",
25
+ type=str,
26
+ default=None,
27
+ help="The path to the reference image.")
28
+ parser.add_argument(
29
+ "--save_path",
30
+ type=str,
31
+ default=None,
32
+ help="The path to save the processed results.")
33
+
34
+ parser.add_argument(
35
+ "--resolution_area",
36
+ type=int,
37
+ nargs=2,
38
+ default=[1280, 720],
39
+ help="The target resolution for processing, specified as [width, height]. To handle different aspect ratios, the video is resized to have a total area equivalent to width * height, while preserving the original aspect ratio."
40
+ )
41
+ parser.add_argument(
42
+ "--fps",
43
+ type=int,
44
+ default=30,
45
+ help="The target FPS for processing the driving video. Set to -1 to use the video's original FPS."
46
+ )
47
+
48
+ parser.add_argument(
49
+ "--replace_flag",
50
+ action="store_true",
51
+ default=False,
52
+ help="Whether to use replacement mode.")
53
+ parser.add_argument(
54
+ "--retarget_flag",
55
+ action="store_true",
56
+ default=False,
57
+ help="Whether to use pose retargeting. Currently only supported in animation mode")
58
+ parser.add_argument(
59
+ "--use_flux",
60
+ action="store_true",
61
+ default=False,
62
+ help="Whether to use image editing in pose retargeting. Recommended if the character in the reference image or the first frame of the driving video is not in a standard, front-facing pose")
63
+
64
+ # Parameters for the mask strategy in replacement mode. These control the mask's size and shape. Refer to https://arxiv.org/pdf/2502.06145
65
+ parser.add_argument(
66
+ "--iterations",
67
+ type=int,
68
+ default=3,
69
+ help="Number of iterations for mask dilation."
70
+ )
71
+ parser.add_argument(
72
+ "--k",
73
+ type=int,
74
+ default=7,
75
+ help="Number of kernel size for mask dilation."
76
+ )
77
+ parser.add_argument(
78
+ "--w_len",
79
+ type=int,
80
+ default=1,
81
+ help="The number of subdivisions for the grid along the 'w' dimension. A higher value results in a more detailed contour. A value of 1 means no subdivision is performed."
82
+ )
83
+ parser.add_argument(
84
+ "--h_len",
85
+ type=int,
86
+ default=1,
87
+ help="The number of subdivisions for the grid along the 'h' dimension. A higher value results in a more detailed contour. A value of 1 means no subdivision is performed."
88
+ )
89
+ args = parser.parse_args()
90
+
91
+ return args
92
+
93
+
94
+ if __name__ == '__main__':
95
+ args = _parse_args()
96
+ args_dict = vars(args)
97
+ print(args_dict)
98
+
99
+ assert len(args.resolution_area) == 2, "resolution_area should be a list of two integers [width, height]"
100
+ assert not args.use_flux or args.retarget_flag, "Image editing with FLUX can only be used when pose retargeting is enabled."
101
+
102
+ pose2d_checkpoint_path = os.path.join(args.ckpt_path, 'pose2d/vitpose_h_wholebody.onnx')
103
+ det_checkpoint_path = os.path.join(args.ckpt_path, 'det/yolov10m.onnx')
104
+
105
+ sam2_checkpoint_path = os.path.join(args.ckpt_path, 'sam2/sam2_hiera_large.pt') if args.replace_flag else None
106
+ flux_kontext_path = os.path.join(args.ckpt_path, 'FLUX.1-Kontext-dev') if args.use_flux else None
107
+ process_pipeline = ProcessPipeline(det_checkpoint_path=det_checkpoint_path, pose2d_checkpoint_path=pose2d_checkpoint_path, sam_checkpoint_path=sam2_checkpoint_path, flux_kontext_path=flux_kontext_path)
108
+ os.makedirs(args.save_path, exist_ok=True)
109
+ process_pipeline(video_path=args.video_path,
110
+ refer_image_path=args.refer_path,
111
+ output_path=args.save_path,
112
+ resolution_area=args.resolution_area,
113
+ fps=args.fps,
114
+ iterations=args.iterations,
115
+ k=args.k,
116
+ w_len=args.w_len,
117
+ h_len=args.h_len,
118
+ retarget_flag=args.retarget_flag,
119
+ use_flux=args.use_flux,
120
+ replace_flag=args.replace_flag)
121
+
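A hedged sketch of how this entry point might be invoked in animation mode with pose retargeting; all paths are illustrative assumptions, and only flags defined above are used:

```python
# Illustrative invocation of the preprocessing script (paths are placeholders).
import subprocess

subprocess.run([
    "python", "wan/modules/animate/preprocess/preprocess_data.py",
    "--ckpt_path", "ckpts/process_checkpoint",   # assumed directory holding pose2d/, det/, etc.
    "--video_path", "driving_video.mp4",
    "--refer_path", "reference.png",
    "--save_path", "outputs/preprocess",
    "--resolution_area", "1280", "720",
    "--fps", "30",
    "--retarget_flag",
], check=True)
```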
wan/modules/animate/preprocess/process_pipepline.py ADDED
@@ -0,0 +1,354 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import os
3
+ import numpy as np
4
+ import shutil
5
+ import torch
6
+ from diffusers import FluxKontextPipeline
7
+ import cv2
8
+ from loguru import logger
9
+ from PIL import Image
10
+ try:
11
+ import moviepy.editor as mpy
12
+ except ImportError:
13
+ import moviepy as mpy
14
+
15
+ from decord import VideoReader
16
+ from pose2d import Pose2d
17
+ from pose2d_utils import AAPoseMeta
18
+ from utils import resize_by_area, get_frame_indices, padding_resize, get_face_bboxes, get_aug_mask, get_mask_body_img
19
+ from human_visualization import draw_aapose_by_meta_new
20
+ from retarget_pose import get_retarget_pose
21
+ import sam2.modeling.sam.transformer as transformer
22
+ transformer.USE_FLASH_ATTN = False
23
+ transformer.MATH_KERNEL_ON = True
24
+ transformer.OLD_GPU = True
25
+ from sam_utils import build_sam2_video_predictor
26
+
27
+
28
+ class ProcessPipeline():
29
+ def __init__(self, det_checkpoint_path, pose2d_checkpoint_path, sam_checkpoint_path, flux_kontext_path):
30
+ self.pose2d = Pose2d(checkpoint=pose2d_checkpoint_path, detector_checkpoint=det_checkpoint_path)
31
+
32
+ model_cfg = "sam2_hiera_l.yaml"
33
+ if sam_checkpoint_path is not None:
34
+ self.predictor = build_sam2_video_predictor(model_cfg, sam_checkpoint_path)
35
+ if flux_kontext_path is not None:
36
+ self.flux_kontext = FluxKontextPipeline.from_pretrained(flux_kontext_path, torch_dtype=torch.bfloat16).to("cuda")
37
+
38
+ def __call__(self, video_path, refer_image_path, output_path, resolution_area=[1280, 720], fps=30, iterations=3, k=7, w_len=1, h_len=1, retarget_flag=False, use_flux=False, replace_flag=False):
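+ # replace_flag=True (replacement mode): writes src_ref.png plus src_face/src_pose/src_bg/src_mask videos, using SAM2 for subject masks.
+ # replace_flag=False (animation mode): writes src_ref.png plus src_face/src_pose videos, with optional pose retargeting and FLUX-based editing.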
39
+ if replace_flag:
40
+
41
+ video_reader = VideoReader(video_path)
42
+ frame_num = len(video_reader)
43
+ print('frame_num: {}'.format(frame_num))
44
+
45
+ video_fps = video_reader.get_avg_fps()
46
+ print('video_fps: {}'.format(video_fps))
47
+ print('fps: {}'.format(fps))
48
+
49
+ # TODO: Maybe we can switch to PyAV later, which can get accurate frame num
50
+ duration = video_reader.get_frame_timestamp(-1)[-1]
51
+ expected_frame_num = int(duration * video_fps + 0.5)
52
+ ratio = abs((frame_num - expected_frame_num)/frame_num)
53
+ if ratio > 0.1:
54
+ print("Warning: The difference between the actual number of frames and the expected number of frames is two large")
55
+ frame_num = expected_frame_num
56
+
57
+ if fps == -1:
58
+ fps = video_fps
59
+
60
+ target_num = int(frame_num / video_fps * fps)
61
+ print('target_num: {}'.format(target_num))
62
+ idxs = get_frame_indices(frame_num, video_fps, target_num, fps)
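+ # get_frame_indices maps the target-FPS timeline back onto source frame indices so the clip is resampled to `fps`.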
63
+ frames = video_reader.get_batch(idxs).asnumpy()
64
+
65
+ frames = [resize_by_area(frame, resolution_area[0] * resolution_area[1], divisor=16) for frame in frames]
66
+ height, width = frames[0].shape[:2]
67
+ logger.info(f"Processing pose meta")
68
+
69
+
70
+ tpl_pose_metas = self.pose2d(frames)
71
+
72
+ face_images = []
73
+ for idx, meta in enumerate(tpl_pose_metas):
74
+ face_bbox_for_image = get_face_bboxes(meta['keypoints_face'][:, :2], scale=1.3,
75
+ image_shape=(frames[0].shape[0], frames[0].shape[1]))
76
+
77
+ x1, x2, y1, y2 = face_bbox_for_image
78
+ face_image = frames[idx][y1:y2, x1:x2]
79
+ face_image = cv2.resize(face_image, (512, 512))
80
+ face_images.append(face_image)
81
+
82
+ logger.info(f"Processing reference image: {refer_image_path}")
83
+ refer_img = cv2.imread(refer_image_path)
84
+ src_ref_path = os.path.join(output_path, 'src_ref.png')
85
+ shutil.copy(refer_image_path, src_ref_path)
86
+ refer_img = refer_img[..., ::-1]
87
+
88
+ refer_img = padding_resize(refer_img, height, width)
89
+ logger.info(f"Processing template video: {video_path}")
90
+ tpl_retarget_pose_metas = [AAPoseMeta.from_humanapi_meta(meta) for meta in tpl_pose_metas]
91
+ cond_images = []
92
+
93
+ for idx, meta in enumerate(tpl_retarget_pose_metas):
94
+ canvas = np.zeros_like(refer_img)
95
+ conditioning_image = draw_aapose_by_meta_new(canvas, meta)
96
+ cond_images.append(conditioning_image)
97
+ masks = self.get_mask(frames, 400, tpl_pose_metas)
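+ # get_mask (below) segments the subject with SAM2 in chunks of up to 400 frames, prompting a few keyframes with body keypoints.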
98
+
99
+ bg_images = []
100
+ aug_masks = []
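+ # Dilate and grid-coarsen each subject mask, then blank the masked region out of the frame to form the background video.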
101
+
102
+ for frame, mask in zip(frames, masks):
103
+ if iterations > 0:
104
+ _, each_mask = get_mask_body_img(frame, mask, iterations=iterations, k=k)
105
+ each_aug_mask = get_aug_mask(each_mask, w_len=w_len, h_len=h_len)
106
+ else:
107
+ each_aug_mask = mask
108
+
109
+ each_bg_image = frame * (1 - each_aug_mask[:, :, None])
110
+ bg_images.append(each_bg_image)
111
+ aug_masks.append(each_aug_mask)
112
+
113
+ src_face_path = os.path.join(output_path, 'src_face.mp4')
114
+ mpy.ImageSequenceClip(face_images, fps=fps).write_videofile(src_face_path)
115
+
116
+ src_pose_path = os.path.join(output_path, 'src_pose.mp4')
117
+ mpy.ImageSequenceClip(cond_images, fps=fps).write_videofile(src_pose_path)
118
+
119
+ src_bg_path = os.path.join(output_path, 'src_bg.mp4')
120
+ mpy.ImageSequenceClip(bg_images, fps=fps).write_videofile(src_bg_path)
121
+
122
+ aug_masks_new = [np.stack([mask * 255, mask * 255, mask * 255], axis=2) for mask in aug_masks]
123
+ src_mask_path = os.path.join(output_path, 'src_mask.mp4')
124
+ mpy.ImageSequenceClip(aug_masks_new, fps=fps).write_videofile(src_mask_path)
125
+ return True
126
+ else:
127
+ logger.info(f"Processing reference image: {refer_image_path}")
128
+ refer_img = cv2.imread(refer_image_path)
129
+ src_ref_path = os.path.join(output_path, 'src_ref.png')
130
+ shutil.copy(refer_image_path, src_ref_path)
131
+ refer_img = refer_img[..., ::-1]
132
+
133
+ refer_img = resize_by_area(refer_img, resolution_area[0] * resolution_area[1], divisor=16)
134
+
135
+ refer_pose_meta = self.pose2d([refer_img])[0]
136
+
137
+
138
+ logger.info(f"Processing template video: {video_path}")
139
+ video_reader = VideoReader(video_path)
140
+ frame_num = len(video_reader)
141
+ print('frame_num: {}'.format(frame_num))
142
+
143
+ video_fps = video_reader.get_avg_fps()
144
+ print('video_fps: {}'.format(video_fps))
145
+ print('fps: {}'.format(fps))
146
+
147
+ # TODO: Maybe we can switch to PyAV later, which can get accurate frame num
148
+ duration = video_reader.get_frame_timestamp(-1)[-1]
149
+ expected_frame_num = int(duration * video_fps + 0.5)
150
+ ratio = abs((frame_num - expected_frame_num)/frame_num)
151
+ if ratio > 0.1:
152
+ print("Warning: The difference between the actual number of frames and the expected number of frames is two large")
153
+ frame_num = expected_frame_num
154
+
155
+ if fps == -1:
156
+ fps = video_fps
157
+
158
+ target_num = int(frame_num / video_fps * fps)
159
+ print('target_num: {}'.format(target_num))
160
+ idxs = get_frame_indices(frame_num, video_fps, target_num, fps)
161
+ frames = video_reader.get_batch(idxs).asnumpy()
162
+
163
+ logger.info(f"Processing pose meta")
164
+
165
+ tpl_pose_meta0 = self.pose2d(frames[:1])[0]
166
+ tpl_pose_metas = self.pose2d(frames)
167
+
168
+ face_images = []
169
+ for idx, meta in enumerate(tpl_pose_metas):
170
+ face_bbox_for_image = get_face_bboxes(meta['keypoints_face'][:, :2], scale=1.3,
171
+ image_shape=(frames[0].shape[0], frames[0].shape[1]))
172
+
173
+ x1, x2, y1, y2 = face_bbox_for_image
174
+ face_image = frames[idx][y1:y2, x1:x2]
175
+ face_image = cv2.resize(face_image, (512, 512))
176
+ face_images.append(face_image)
177
+
178
+ if retarget_flag:
179
+ if use_flux:
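+ # FLUX-Kontext edits both the reference image and a driving frame into a canonical, front-facing pose; poses are then re-estimated on the edited images so that retargeting compares matching poses.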
180
+ tpl_prompt, refer_prompt = self.get_editing_prompts(tpl_pose_metas, refer_pose_meta)
181
+ refer_input = Image.fromarray(refer_img)
182
+ refer_edit = self.flux_kontext(
183
+ image=refer_input,
184
+ height=refer_img.shape[0],
185
+ width=refer_img.shape[1],
186
+ prompt=refer_prompt,
187
+ guidance_scale=2.5,
188
+ num_inference_steps=28,
189
+ ).images[0]
190
+
191
+ refer_edit = Image.fromarray(padding_resize(np.array(refer_edit), refer_img.shape[0], refer_img.shape[1]))
192
+ refer_edit_path = os.path.join(output_path, 'refer_edit.png')
193
+ refer_edit.save(refer_edit_path)
194
+ refer_edit_pose_meta = self.pose2d([np.array(refer_edit)])[0]
195
+
196
+ tpl_img = frames[1]
197
+ tpl_input = Image.fromarray(tpl_img)
198
+
199
+ tpl_edit = self.flux_kontext(
200
+ image=tpl_input,
201
+ height=tpl_img.shape[0],
202
+ width=tpl_img.shape[1],
203
+ prompt=tpl_prompt,
204
+ guidance_scale=2.5,
205
+ num_inference_steps=28,
206
+ ).images[0]
207
+
208
+ tpl_edit = Image.fromarray(padding_resize(np.array(tpl_edit), tpl_img.shape[0], tpl_img.shape[1]))
209
+ tpl_edit_path = os.path.join(output_path, 'tpl_edit.png')
210
+ tpl_edit.save(tpl_edit_path)
211
+ tpl_edit_pose_meta0 = self.pose2d([np.array(tpl_edit)])[0]
212
+ tpl_retarget_pose_metas = get_retarget_pose(tpl_pose_meta0, refer_pose_meta, tpl_pose_metas, tpl_edit_pose_meta0, refer_edit_pose_meta)
213
+ else:
214
+ tpl_retarget_pose_metas = get_retarget_pose(tpl_pose_meta0, refer_pose_meta, tpl_pose_metas, None, None)
215
+ else:
216
+ tpl_retarget_pose_metas = [AAPoseMeta.from_humanapi_meta(meta) for meta in tpl_pose_metas]
217
+
218
+ cond_images = []
219
+ for idx, meta in enumerate(tpl_retarget_pose_metas):
220
+ if retarget_flag:
221
+ canvas = np.zeros_like(refer_img)
222
+ conditioning_image = draw_aapose_by_meta_new(canvas, meta)
223
+ else:
224
+ canvas = np.zeros_like(frames[0])
225
+ conditioning_image = draw_aapose_by_meta_new(canvas, meta)
226
+ conditioning_image = padding_resize(conditioning_image, refer_img.shape[0], refer_img.shape[1])
227
+
228
+ cond_images.append(conditioning_image)
229
+
230
+ src_face_path = os.path.join(output_path, 'src_face.mp4')
231
+ mpy.ImageSequenceClip(face_images, fps=fps).write_videofile(src_face_path)
232
+
233
+ src_pose_path = os.path.join(output_path, 'src_pose.mp4')
234
+ mpy.ImageSequenceClip(cond_images, fps=fps).write_videofile(src_pose_path)
235
+ return True
236
+
237
+ def get_editing_prompts(self, tpl_pose_metas, refer_pose_meta):
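+ # Pick FLUX editing prompts according to which limbs (arms/legs) are confidently visible in the driving video and whether each image is landscape or portrait.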
238
+ arm_visible = False
239
+ leg_visible = False
240
+ for tpl_pose_meta in tpl_pose_metas:
241
+ tpl_keypoints = tpl_pose_meta['keypoints_body']
242
+ if tpl_keypoints[3].all() != 0 or tpl_keypoints[4].all() != 0 or tpl_keypoints[6].all() != 0 or tpl_keypoints[7].all() != 0:
243
+ if (tpl_keypoints[3][0] <= 1 and tpl_keypoints[3][1] <= 1 and tpl_keypoints[3][2] >= 0.75) or (tpl_keypoints[4][0] <= 1 and tpl_keypoints[4][1] <= 1 and tpl_keypoints[4][2] >= 0.75) or \
244
+ (tpl_keypoints[6][0] <= 1 and tpl_keypoints[6][1] <= 1 and tpl_keypoints[6][2] >= 0.75) or (tpl_keypoints[7][0] <= 1 and tpl_keypoints[7][1] <= 1 and tpl_keypoints[7][2] >= 0.75):
245
+ arm_visible = True
246
+ if tpl_keypoints[9].all() != 0 or tpl_keypoints[12].all() != 0 or tpl_keypoints[10].all() != 0 or tpl_keypoints[13].all() != 0:
247
+ if (tpl_keypoints[9][0] <= 1 and tpl_keypoints[9][1] <= 1 and tpl_keypoints[9][2] >= 0.75) or (tpl_keypoints[12][0] <= 1 and tpl_keypoints[12][1] <= 1 and tpl_keypoints[12][2] >= 0.75) or \
248
+ (tpl_keypoints[10][0] <= 1 and tpl_keypoints[10][1] <= 1 and tpl_keypoints[10][2] >= 0.75) or (tpl_keypoints[13][0] <= 1 and tpl_keypoints[13][1] <= 1 and tpl_keypoints[13][2] >= 0.75):
249
+ leg_visible = True
250
+ if arm_visible and leg_visible:
251
+ break
252
+
253
+ if leg_visible:
254
+ if tpl_pose_meta['width'] > tpl_pose_meta['height']:
255
+ tpl_prompt = "Change the person to a standard T-pose (facing forward with arms extended). The person is standing. Feet and Hands are visible in the image."
256
+ else:
257
+ tpl_prompt = "Change the person to a standard pose with the face oriented forward and arms extending straight down by the sides. The person is standing. Feet and Hands are visible in the image."
258
+
259
+ if refer_pose_meta['width'] > refer_pose_meta['height']:
260
+ refer_prompt = "Change the person to a standard T-pose (facing forward with arms extended). The person is standing. Feet and Hands are visible in the image."
261
+ else:
262
+ refer_prompt = "Change the person to a standard pose with the face oriented forward and arms extending straight down by the sides. The person is standing. Feet and Hands are visible in the image."
263
+ elif arm_visible:
264
+ if tpl_pose_meta['width'] > tpl_pose_meta['height']:
265
+ tpl_prompt = "Change the person to a standard T-pose (facing forward with arms extended). Hands are visible in the image."
266
+ else:
267
+ tpl_prompt = "Change the person to a standard pose with the face oriented forward and arms extending straight down by the sides. Hands are visible in the image."
268
+
269
+ if refer_pose_meta['width'] > refer_pose_meta['height']:
270
+ refer_prompt = "Change the person to a standard T-pose (facing forward with arms extended). Hands are visible in the image."
271
+ else:
272
+ refer_prompt = "Change the person to a standard pose with the face oriented forward and arms extending straight down by the sides. Hands are visible in the image."
273
+ else:
274
+ tpl_prompt = "Change the person to face forward."
275
+ refer_prompt = "Change the person to face forward."
276
+
277
+ return tpl_prompt, refer_prompt
278
+
279
+
280
+ def get_mask(self, frames, th_step, kp2ds_all):
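+ # Chunk the video into segments of up to `th_step` frames; in each segment, prompt SAM2 with body keypoints on a few keyframes and propagate the mask to every frame.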
281
+ frame_num = len(frames)
282
+ if frame_num < th_step:
283
+ num_step = 1
284
+ else:
285
+ num_step = (frame_num + th_step) // th_step
286
+
287
+ all_mask = []
288
+ for index in range(num_step):
289
+ each_frames = frames[index * th_step:(index + 1) * th_step]
290
+
291
+ kp2ds = kp2ds_all[index * th_step:(index + 1) * th_step]
292
+ if len(each_frames) > 4:
293
+ key_frame_num = 4
294
+ elif 4 >= len(each_frames) > 0:
295
+ key_frame_num = 1
296
+ else:
297
+ continue
298
+
299
+ key_frame_step = len(kp2ds) // key_frame_num
300
+ key_frame_index_list = list(range(0, len(kp2ds), key_frame_step))
301
+
302
+ key_points_index = [0, 1, 2, 5, 8, 11, 10, 13]
303
+ key_frame_body_points_list = []
304
+ for key_frame_index in key_frame_index_list:
305
+ keypoints_body_list = []
306
+ body_key_points = kp2ds[key_frame_index]['keypoints_body']
307
+ for each_index in key_points_index:
308
+ each_keypoint = body_key_points[each_index]
309
+ if each_keypoint is None:
310
+ continue
311
+ keypoints_body_list.append(each_keypoint)
312
+
313
+ keypoints_body = np.array(keypoints_body_list)[:, :2]
314
+ wh = np.array([[kp2ds[0]['width'], kp2ds[0]['height']]])
315
+ points = (keypoints_body * wh).astype(np.int32)
316
+ key_frame_body_points_list.append(points)
317
+
318
+ inference_state = self.predictor.init_state_v2(frames=each_frames)
319
+ self.predictor.reset_state(inference_state)
320
+ ann_obj_id = 1
321
+ for ann_frame_idx, points in zip(key_frame_index_list, key_frame_body_points_list):
322
+ labels = np.array([1] * points.shape[0], np.int32)
323
+ _, out_obj_ids, out_mask_logits = self.predictor.add_new_points(
324
+ inference_state=inference_state,
325
+ frame_idx=ann_frame_idx,
326
+ obj_id=ann_obj_id,
327
+ points=points,
328
+ labels=labels,
329
+ )
330
+
331
+ video_segments = {}
332
+ for out_frame_idx, out_obj_ids, out_mask_logits in self.predictor.propagate_in_video(inference_state):
333
+ video_segments[out_frame_idx] = {
334
+ out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
335
+ for i, out_obj_id in enumerate(out_obj_ids)
336
+ }
337
+
338
+ for out_frame_idx in range(len(video_segments)):
339
+ for out_obj_id, out_mask in video_segments[out_frame_idx].items():
340
+ out_mask = out_mask[0].astype(np.uint8)
341
+ all_mask.append(out_mask)
342
+
343
+ return all_mask
344
+
345
+ def convert_list_to_array(self, metas):
346
+ metas_list = []
347
+ for meta in metas:
348
+ for key, value in meta.items():
349
+ if type(value) is list:
350
+ value = np.array(value)
351
+ meta[key] = value
352
+ metas_list.append(meta)
353
+ return metas_list
354
+
wan/modules/animate/preprocess/retarget_pose.py ADDED
@@ -0,0 +1,847 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import os
3
+ import cv2
4
+ import numpy as np
5
+ import json
6
+ from tqdm import tqdm
7
+ import math
8
+ from typing import NamedTuple, List
9
+ import copy
10
+ from pose2d_utils import AAPoseMeta
11
+
12
+
13
+ # load skeleton name and bone lines
14
+ keypoint_list = [
15
+ "Nose",
16
+ "Neck",
17
+ "RShoulder",
18
+ "RElbow",
19
+ "RWrist", # No.4
20
+ "LShoulder",
21
+ "LElbow",
22
+ "LWrist", # No.7
23
+ "RHip",
24
+ "RKnee",
25
+ "RAnkle", # No.10
26
+ "LHip",
27
+ "LKnee",
28
+ "LAnkle", # No.13
29
+ "REye",
30
+ "LEye",
31
+ "REar",
32
+ "LEar",
33
+ "LToe",
34
+ "RToe",
35
+ ]
36
+
37
+
38
+ limbSeq = [
39
+ [2, 3], [2, 6], # shoulders
40
+ [3, 4], [4, 5], # left arm
41
+ [6, 7], [7, 8], # right arm
42
+ [2, 9], [9, 10], [10, 11], # right leg
43
+ [2, 12], [12, 13], [13, 14], # left leg
44
+ [2, 1], [1, 15], [15, 17], [1, 16], [16, 18], # face (nose, eyes, ears)
45
+ [14, 19], # left foot
46
+ [11, 20] # right foot
47
+ ]
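+ # limbSeq entries are 1-based indices into keypoint_list (e.g. [2, 3] connects Neck and RShoulder).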
48
+
49
+ eps = 0.01
50
+
51
+ class Keypoint(NamedTuple):
52
+ x: float
53
+ y: float
54
+ score: float = 1.0
55
+ id: int = -1
56
+
57
+
58
+ # for each limb, calculate src & dst bone's length
59
+ # and calculate their ratios
60
+ def get_length(skeleton, limb):
61
+
62
+ k1_index, k2_index = limb
63
+
64
+ H, W = skeleton['height'], skeleton['width']
65
+ keypoints = skeleton['keypoints_body']
66
+ keypoint1 = keypoints[k1_index - 1]
67
+ keypoint2 = keypoints[k2_index - 1]
68
+
69
+ if keypoint1 is None or keypoint2 is None:
70
+ return None, None, None
71
+
72
+ X = np.array([keypoint1[0], keypoint2[0]]) * float(W)
73
+ Y = np.array([keypoint1[1], keypoint2[1]]) * float(H)
74
+ length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
75
+
76
+ return X, Y, length
77
+
78
+
79
+
80
+ def get_handpose_meta(keypoints, delta, src_H, src_W):
81
+
82
+ new_keypoints = []
83
+
84
+ for idx, keypoint in enumerate(keypoints):
85
+ if keypoint is None:
86
+ new_keypoints.append(None)
87
+ continue
88
+ if keypoint.score == 0:
89
+ new_keypoints.append(None)
90
+ continue
91
+
92
+ x, y = keypoint.x, keypoint.y
93
+ x = int(x * src_W + delta[0])
94
+ y = int(y * src_H + delta[1])
95
+
96
+ new_keypoints.append(
97
+ Keypoint(
98
+ x=x,
99
+ y=y,
100
+ score=keypoint.score,
101
+ ))
102
+
103
+ return new_keypoints
104
+
105
+
106
+ def deal_hand_keypoints(hand_res, r_ratio, l_ratio, hand_score_th = 0.5):
107
+
108
+ left_hand = []
109
+ right_hand = []
110
+
111
+ left_delta_x = hand_res['left'][0][0] * (l_ratio - 1)
112
+ left_delta_y = hand_res['left'][0][1] * (l_ratio - 1)
113
+
114
+ right_delta_x = hand_res['right'][0][0] * (r_ratio - 1)
115
+ right_delta_y = hand_res['right'][0][1] * (r_ratio - 1)
116
+
117
+ length = len(hand_res['left'])
118
+
119
+ for i in range(length):
120
+ # left hand
121
+ if hand_res['left'][i][2] < hand_score_th:
122
+ left_hand.append(
123
+ Keypoint(
124
+ x=-1,
125
+ y=-1,
126
+ score=0,
127
+ )
128
+ )
129
+ else:
130
+ left_hand.append(
131
+ Keypoint(
132
+ x=hand_res['left'][i][0] * l_ratio - left_delta_x,
133
+ y=hand_res['left'][i][1] * l_ratio - left_delta_y,
134
+ score = hand_res['left'][i][2]
135
+ )
136
+ )
137
+
138
+ # right hand
139
+ if hand_res['right'][i][2] < hand_score_th:
140
+ right_hand.append(
141
+ Keypoint(
142
+ x=-1,
143
+ y=-1,
144
+ score=0,
145
+ )
146
+ )
147
+ else:
148
+ right_hand.append(
149
+ Keypoint(
150
+ x=hand_res['right'][i][0] * r_ratio - right_delta_x,
151
+ y=hand_res['right'][i][1] * r_ratio - right_delta_y,
152
+ score = hand_res['right'][i][2]
153
+ )
154
+ )
155
+
156
+ return right_hand, left_hand
157
+
158
+
159
+ def get_scaled_pose(canvas, src_canvas, keypoints, keypoints_hand, bone_ratio_list, delta_ground_x, delta_ground_y,
160
+ rescaled_src_ground_x, body_flag, id, scale_min, threshold = 0.4):
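+ # Rescale every limb of the driving skeleton by its per-bone ratio, re-anchor the result to the reference ground/neck point, undo the global scale, and shift the hand keypoints along with their wrists.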
161
+
162
+ H, W = canvas
163
+ src_H, src_W = src_canvas
164
+
165
+ new_length_list = [ ]
166
+ angle_list = [ ]
167
+
168
+ # keypoints from 0-1 to H/W range
169
+ for idx in range(len(keypoints)):
170
+ if keypoints[idx] is None or len(keypoints[idx]) == 0:
171
+ continue
172
+
173
+ keypoints[idx] = [keypoints[idx][0] * src_W, keypoints[idx][1] * src_H, keypoints[idx][2]]
174
+
175
+ # first traverse, get new_length_list and angle_list
176
+ for idx, (k1_index, k2_index) in enumerate(limbSeq):
177
+ keypoint1 = keypoints[k1_index - 1]
178
+ keypoint2 = keypoints[k2_index - 1]
179
+
180
+ if keypoint1 is None or keypoint2 is None or len(keypoint1) == 0 or len(keypoint2) == 0:
181
+ new_length_list.append(None)
182
+ angle_list.append(None)
183
+ continue
184
+
185
+ Y = np.array([keypoint1[0], keypoint2[0]]) #* float(W)
186
+ X = np.array([keypoint1[1], keypoint2[1]]) #* float(H)
187
+
188
+ length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
189
+
190
+ new_length = length * bone_ratio_list[idx]
191
+ angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
192
+
193
+ new_length_list.append(new_length)
194
+ angle_list.append(angle)
195
+
196
+ # Keep foot length within 0.5x calf length
197
+ foot_lower_leg_ratio = 0.5
198
+ if new_length_list[8] != None and new_length_list[18] != None:
199
+ if new_length_list[18] > new_length_list[8] * foot_lower_leg_ratio:
200
+ new_length_list[18] = new_length_list[8] * foot_lower_leg_ratio
201
+
202
+ if new_length_list[11] != None and new_length_list[17] != None:
203
+ if new_length_list[17] > new_length_list[11] * foot_lower_leg_ratio:
204
+ new_length_list[17] = new_length_list[11] * foot_lower_leg_ratio
205
+
206
+ # second traverse, calculate new keypoints
207
+ rescale_keypoints = keypoints.copy()
208
+
209
+ for idx, (k1_index, k2_index) in enumerate(limbSeq):
210
+ # update dst_keypoints
211
+ start_keypoint = rescale_keypoints[k1_index - 1]
212
+ new_length = new_length_list[idx]
213
+ angle = angle_list[idx]
214
+
215
+ if rescale_keypoints[k1_index - 1] is None or rescale_keypoints[k2_index - 1] is None or \
216
+ len(rescale_keypoints[k1_index - 1]) == 0 or len(rescale_keypoints[k2_index - 1]) == 0:
217
+ continue
218
+
219
+ # calculate end_keypoint
220
+ delta_x = new_length * math.cos(math.radians(angle))
221
+ delta_y = new_length * math.sin(math.radians(angle))
222
+
223
+ end_keypoint_x = start_keypoint[0] - delta_x
224
+ end_keypoint_y = start_keypoint[1] - delta_y
225
+
226
+ # update keypoints
227
+ rescale_keypoints[k2_index - 1] = [end_keypoint_x, end_keypoint_y, rescale_keypoints[k2_index - 1][2]]
228
+
229
+ if id == 0:
230
+ if body_flag == 'full_body' and rescale_keypoints[8] != None and rescale_keypoints[11] != None:
231
+ delta_ground_x_offset_first_frame = (rescale_keypoints[8][0] + rescale_keypoints[11][0]) / 2 - rescaled_src_ground_x
232
+ delta_ground_x += delta_ground_x_offset_first_frame
233
+ elif body_flag == 'half_body' and rescale_keypoints[1] != None:
234
+ delta_ground_x_offset_first_frame = rescale_keypoints[1][0] - rescaled_src_ground_x
235
+ delta_ground_x += delta_ground_x_offset_first_frame
236
+
237
+ # offset all keypoints
238
+ for idx in range(len(rescale_keypoints)):
239
+ if rescale_keypoints[idx] is None or len(rescale_keypoints[idx]) == 0 :
240
+ continue
241
+ rescale_keypoints[idx][0] -= delta_ground_x
242
+ rescale_keypoints[idx][1] -= delta_ground_y
243
+
244
+ # rescale keypoints to original size
245
+ rescale_keypoints[idx][0] /= scale_min
246
+ rescale_keypoints[idx][1] /= scale_min
247
+
248
+ # Scale hand proportions based on body skeletal ratios
249
+ r_ratio = max(bone_ratio_list[0], bone_ratio_list[1]) / scale_min
250
+ l_ratio = max(bone_ratio_list[0], bone_ratio_list[1]) / scale_min
251
+ left_hand, right_hand = deal_hand_keypoints(keypoints_hand, r_ratio, l_ratio, hand_score_th = threshold)
252
+
253
+ left_hand_new = left_hand.copy()
254
+ right_hand_new = right_hand.copy()
255
+
256
+ if rescale_keypoints[4] == None and rescale_keypoints[7] == None:
257
+ pass
258
+
259
+ elif rescale_keypoints[4] == None and rescale_keypoints[7] != None:
260
+ right_hand_delta = np.array(rescale_keypoints[7][:2]) - np.array(keypoints[7][:2])
261
+ right_hand_new = get_handpose_meta(right_hand, right_hand_delta, src_H, src_W)
262
+
263
+ elif rescale_keypoints[4] != None and rescale_keypoints[7] == None:
264
+ left_hand_delta = np.array(rescale_keypoints[4][:2]) - np.array(keypoints[4][:2])
265
+ left_hand_new = get_handpose_meta(left_hand, left_hand_delta, src_H, src_W)
266
+
267
+ else:
268
+ # get left_hand and right_hand offset
269
+ left_hand_delta = np.array(rescale_keypoints[4][:2]) - np.array(keypoints[4][:2])
270
+ right_hand_delta = np.array(rescale_keypoints[7][:2]) - np.array(keypoints[7][:2])
271
+
272
+ if keypoints[4][0] != None and left_hand[0].x != -1:
273
+ left_hand_root_offset = np.array( ( keypoints[4][0] - left_hand[0].x * src_W, keypoints[4][1] - left_hand[0].y * src_H))
274
+ left_hand_delta += left_hand_root_offset
275
+
276
+ if keypoints[7][0] != None and right_hand[0].x != -1:
277
+ right_hand_root_offset = np.array( ( keypoints[7][0] - right_hand[0].x * src_W, keypoints[7][1] - right_hand[0].y * src_H))
278
+ right_hand_delta += right_hand_root_offset
279
+
280
+ dis_left_hand = ((keypoints[4][0] - left_hand[0].x * src_W) ** 2 + (keypoints[4][1] - left_hand[0].y * src_H) ** 2) ** 0.5
281
+ dis_right_hand = ((keypoints[7][0] - left_hand[0].x * src_W) ** 2 + (keypoints[7][1] - left_hand[0].y * src_H) ** 2) ** 0.5
282
+
283
+ if dis_left_hand > dis_right_hand:
284
+ right_hand_new = get_handpose_meta(left_hand, right_hand_delta, src_H, src_W)
285
+ left_hand_new = get_handpose_meta(right_hand, left_hand_delta, src_H, src_W)
286
+ else:
287
+ left_hand_new = get_handpose_meta(left_hand, left_hand_delta, src_H, src_W)
288
+ right_hand_new = get_handpose_meta(right_hand, right_hand_delta, src_H, src_W)
289
+
290
+ # get normalized keypoints_body
291
+ norm_body_keypoints = [ ]
292
+ for body_keypoint in rescale_keypoints:
293
+ if body_keypoint != None:
294
+ norm_body_keypoints.append([body_keypoint[0] / W , body_keypoint[1] / H, body_keypoint[2]])
295
+ else:
296
+ norm_body_keypoints.append(None)
297
+
298
+ frame_info = {
299
+ 'height': H,
300
+ 'width': W,
301
+ 'keypoints_body': norm_body_keypoints,
302
+ 'keypoints_left_hand' : left_hand_new,
303
+ 'keypoints_right_hand' : right_hand_new,
304
+ }
305
+
306
+ return frame_info
307
+
308
+
309
+ def rescale_skeleton(H, W, keypoints, bone_ratio_list):
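+ # Apply the per-bone ratios to a single skeleton in pixel space (no global offset); used to locate the rescaled anchor point of the first driving frame.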
310
+
311
+ rescale_keypoints = keypoints.copy()
312
+
313
+ new_length_list = [ ]
314
+ angle_list = [ ]
315
+
316
+ # keypoints from 0-1 to H/W range
317
+ for idx in range(len(rescale_keypoints)):
318
+ if rescale_keypoints[idx] is None or len(rescale_keypoints[idx]) == 0:
319
+ continue
320
+
321
+ rescale_keypoints[idx] = [rescale_keypoints[idx][0] * W, rescale_keypoints[idx][1] * H]
322
+
323
+ # first traverse, get new_length_list and angle_list
324
+ for idx, (k1_index, k2_index) in enumerate(limbSeq):
325
+ keypoint1 = rescale_keypoints[k1_index - 1]
326
+ keypoint2 = rescale_keypoints[k2_index - 1]
327
+
328
+ if keypoint1 is None or keypoint2 is None or len(keypoint1) == 0 or len(keypoint2) == 0:
329
+ new_length_list.append(None)
330
+ angle_list.append(None)
331
+ continue
332
+
333
+ Y = np.array([keypoint1[0], keypoint2[0]]) #* float(W)
334
+ X = np.array([keypoint1[1], keypoint2[1]]) #* float(H)
335
+
336
+ length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
337
+
338
+
339
+ new_length = length * bone_ratio_list[idx]
340
+ angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
341
+
342
+ new_length_list.append(new_length)
343
+ angle_list.append(angle)
344
+
345
+ # # second traverse, calculate new keypoints
346
+ for idx, (k1_index, k2_index) in enumerate(limbSeq):
347
+ # update dst_keypoints
348
+ start_keypoint = rescale_keypoints[k1_index - 1]
349
+ new_length = new_length_list[idx]
350
+ angle = angle_list[idx]
351
+
352
+ if rescale_keypoints[k1_index - 1] is None or rescale_keypoints[k2_index - 1] is None or \
353
+ len(rescale_keypoints[k1_index - 1]) == 0 or len(rescale_keypoints[k2_index - 1]) == 0:
354
+ continue
355
+
356
+ # calculate end_keypoint
357
+ delta_x = new_length * math.cos(math.radians(angle))
358
+ delta_y = new_length * math.sin(math.radians(angle))
359
+
360
+ end_keypoint_x = start_keypoint[0] - delta_x
361
+ end_keypoint_y = start_keypoint[1] - delta_y
362
+
363
+ # update keypoints
364
+ rescale_keypoints[k2_index - 1] = [end_keypoint_x, end_keypoint_y]
365
+
366
+ return rescale_keypoints
367
+
368
+
369
+ def fix_lack_keypoints_use_sym(skeleton):
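+ # Fill in missing limb keypoints by borrowing the length of the symmetric limb (falling back to torso-based estimates), extending the limb straight down from its parent joint.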
370
+
371
+ keypoints = skeleton['keypoints_body']
372
+ H, W = skeleton['height'], skeleton['width']
373
+
374
+ limb_points_list = [
375
+ [3, 4, 5],
376
+ [6, 7, 8],
377
+ [12, 13, 14, 19],
378
+ [9, 10, 11, 20],
379
+ ]
380
+
381
+ for limb_points in limb_points_list:
382
+ miss_flag = False
383
+ for point in limb_points:
384
+ if keypoints[point - 1] is None:
385
+ miss_flag = True
386
+ continue
387
+ if miss_flag:
388
+ skeleton['keypoints_body'][point - 1] = None
389
+
390
+ repair_limb_seq_left = [
391
+ [3, 4], [4, 5], # left arm
392
+ [12, 13], [13, 14], # left leg
393
+ [14, 19] # left foot
394
+ ]
395
+
396
+ repair_limb_seq_right = [
397
+ [6, 7], [7, 8], # right arm
398
+ [9, 10], [10, 11], # right leg
399
+ [11, 20] # right foot
400
+ ]
401
+
402
+ repair_limb_seq = [repair_limb_seq_left, repair_limb_seq_right]
403
+
404
+ for idx_part, part in enumerate(repair_limb_seq):
405
+ for idx, limb in enumerate(part):
406
+
407
+ k1_index, k2_index = limb
408
+ keypoint1 = keypoints[k1_index - 1]
409
+ keypoint2 = keypoints[k2_index - 1]
410
+
411
+ if keypoint1 != None and keypoint2 is None:
412
+ # reference to symmetric limb
413
+ sym_limb = repair_limb_seq[1-idx_part][idx]
414
+ k1_index_sym, k2_index_sym = sym_limb
415
+ keypoint1_sym = keypoints[k1_index_sym - 1]
416
+ keypoint2_sym = keypoints[k2_index_sym - 1]
417
+ ref_length = 0
418
+
419
+ if keypoint1_sym != None and keypoint2_sym != None:
420
+ X = np.array([keypoint1_sym[0], keypoint2_sym[0]]) * float(W)
421
+ Y = np.array([keypoint1_sym[1], keypoint2_sym[1]]) * float(H)
422
+ ref_length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
423
+ else:
424
+ ref_length_left, ref_length_right = 0, 0
425
+ if keypoints[1] != None and keypoints[8] != None:
426
+ X = np.array([keypoints[1][0], keypoints[8][0]]) * float(W)
427
+ Y = np.array([keypoints[1][1], keypoints[8][1]]) * float(H)
428
+ ref_length_left = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
429
+ if idx <= 1: # arms
430
+ ref_length_left /= 2
431
+
432
+ if keypoints[1] != None and keypoints[11] != None:
433
+ X = np.array([keypoints[1][0], keypoints[11][0]]) * float(W)
434
+ Y = np.array([keypoints[1][1], keypoints[11][1]]) * float(H)
435
+ ref_length_right = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
436
+ if idx <= 1: # arms
437
+ ref_length_right /= 2
438
+ elif idx == 4: # foot
439
+ ref_length_right /= 5
440
+
441
+ ref_length = max(ref_length_left, ref_length_right)
442
+
443
+ if ref_length != 0:
444
+ skeleton['keypoints_body'][k2_index - 1] = [0, 0] #init
445
+ skeleton['keypoints_body'][k2_index - 1][0] = skeleton['keypoints_body'][k1_index - 1][0]
446
+ skeleton['keypoints_body'][k2_index - 1][1] = skeleton['keypoints_body'][k1_index - 1][1] + ref_length / H
447
+ return skeleton
448
+
449
+
450
+ def rescale_shorten_skeleton(ratio_list, src_length_list, dst_length_list):
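+ # Symmetrize the per-bone ratios: paired left/right bones share the larger ratio, and eye/ear ratios are averaged, so mirrored limbs are scaled consistently.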
451
+
452
+ modify_bone_list = [
453
+ [0, 1],
454
+ [2, 4],
455
+ [3, 5],
456
+ [6, 9],
457
+ [7, 10],
458
+ [8, 11],
459
+ [17, 18]
460
+ ]
461
+
462
+ for modify_bone in modify_bone_list:
463
+ new_ratio = max(ratio_list[modify_bone[0]], ratio_list[modify_bone[1]])
464
+ ratio_list[modify_bone[0]] = new_ratio
465
+ ratio_list[modify_bone[1]] = new_ratio
466
+
467
+ if ratio_list[13]!= None and ratio_list[15]!= None:
468
+ ratio_eye_avg = (ratio_list[13] + ratio_list[15]) / 2
469
+ ratio_list[13] = ratio_eye_avg
470
+ ratio_list[15] = ratio_eye_avg
471
+
472
+ if ratio_list[14]!= None and ratio_list[16]!= None:
473
+ ratio_eye_avg = (ratio_list[14] + ratio_list[16]) / 2
474
+ ratio_list[14] = ratio_eye_avg
475
+ ratio_list[16] = ratio_eye_avg
476
+
477
+ return ratio_list, src_length_list, dst_length_list
478
+
479
+
480
+
481
+ def check_full_body(keypoints, threshold = 0.4):
482
+
483
+ body_flag = 'half_body'
484
+
485
+ # 1. If ankle points exist, confidence is greater than the threshold, and points do not exceed the frame, return full_body
486
+ if keypoints[10] != None and keypoints[13] != None and keypoints[8] != None and keypoints[11] != None:
487
+ if (keypoints[10][1] <= 1 and keypoints[13][1] <= 1) and (keypoints[10][2] >= threshold and keypoints[13][2] >= threshold) and \
488
+ (keypoints[8][1] <= 1 and keypoints[11][1] <= 1) and (keypoints[8][2] >= threshold and keypoints[11][2] >= threshold):
489
+ body_flag = 'full_body'
490
+ return body_flag
491
+
492
+ # 2. If hip points exist, return three_quarter_body
493
+ if (keypoints[8] != None and keypoints[11] != None):
494
+ if (keypoints[8][1] <= 1 and keypoints[11][1] <= 1) and (keypoints[8][2] >= threshold and keypoints[11][2] >= threshold):
495
+ body_flag = 'three_quarter_body'
496
+ return body_flag
497
+
498
+ return body_flag
499
+
500
+
501
+ def check_full_body_both(flag1, flag2):
502
+ body_flag_dict = {
503
+ 'full_body': 2,
504
+ 'three_quarter_body' : 1,
505
+ 'half_body': 0
506
+ }
507
+
508
+ body_flag_dict_reverse = {
509
+ 2: 'full_body',
510
+ 1: 'three_quarter_body',
511
+ 0: 'half_body'
512
+ }
513
+
514
+ flag1_num = body_flag_dict[flag1]
515
+ flag2_num = body_flag_dict[flag2]
516
+ flag_both_num = min(flag1_num, flag2_num)
517
+ return body_flag_dict_reverse[flag_both_num]
518
+
519
+
520
+ def write_to_poses(data_to_json, none_idx, dst_shape, bone_ratio_list, delta_ground_x, delta_ground_y, rescaled_src_ground_x, body_flag, scale_min):
521
+ outputs = []
522
+ length = len(data_to_json)
523
+ for id in tqdm(range(length)):
524
+
525
+ src_height, src_width = data_to_json[id]['height'], data_to_json[id]['width']
526
+ width, height = dst_shape
527
+ keypoints = data_to_json[id]['keypoints_body']
528
+ for idx in range(len(keypoints)):
529
+ if idx in none_idx:
530
+ keypoints[idx] = None
531
+ new_keypoints = keypoints.copy()
532
+
533
+ # get hand keypoints
534
+ keypoints_hand = {'left' : data_to_json[id]['keypoints_left_hand'], 'right' : data_to_json[id]['keypoints_right_hand']}
535
+ # Normalize hand coordinates to 0-1 range
536
+ for hand_idx in range(len(data_to_json[id]['keypoints_left_hand'])):
537
+ data_to_json[id]['keypoints_left_hand'][hand_idx][0] = data_to_json[id]['keypoints_left_hand'][hand_idx][0] / src_width
538
+ data_to_json[id]['keypoints_left_hand'][hand_idx][1] = data_to_json[id]['keypoints_left_hand'][hand_idx][1] / src_height
539
+
540
+ for hand_idx in range(len(data_to_json[id]['keypoints_right_hand'])):
541
+ data_to_json[id]['keypoints_right_hand'][hand_idx][0] = data_to_json[id]['keypoints_right_hand'][hand_idx][0] / src_width
542
+ data_to_json[id]['keypoints_right_hand'][hand_idx][1] = data_to_json[id]['keypoints_right_hand'][hand_idx][1] / src_height
543
+
544
+
545
+ frame_info = get_scaled_pose((height, width), (src_height, src_width), new_keypoints, keypoints_hand, bone_ratio_list, delta_ground_x, delta_ground_y, rescaled_src_ground_x, body_flag, id, scale_min)
546
+ outputs.append(frame_info)
547
+
548
+ return outputs
549
+
550
+
551
+ def calculate_scale_ratio(skeleton, skeleton_edit, scale_ratio_flag):
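+ # Estimate how much FLUX editing rescaled the person by comparing head width and shoulder length before and after editing (returns 1 when no compensation is needed).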
552
+ if scale_ratio_flag:
553
+
554
+ headw = max(skeleton['keypoints_body'][0][0], skeleton['keypoints_body'][14][0], skeleton['keypoints_body'][15][0], skeleton['keypoints_body'][16][0], skeleton['keypoints_body'][17][0]) - \
555
+ min(skeleton['keypoints_body'][0][0], skeleton['keypoints_body'][14][0], skeleton['keypoints_body'][15][0], skeleton['keypoints_body'][16][0], skeleton['keypoints_body'][17][0])
556
+ headw_edit = max(skeleton_edit['keypoints_body'][0][0], skeleton_edit['keypoints_body'][14][0], skeleton_edit['keypoints_body'][15][0], skeleton_edit['keypoints_body'][16][0], skeleton_edit['keypoints_body'][17][0]) - \
557
+ min(skeleton_edit['keypoints_body'][0][0], skeleton_edit['keypoints_body'][14][0], skeleton_edit['keypoints_body'][15][0], skeleton_edit['keypoints_body'][16][0], skeleton_edit['keypoints_body'][17][0])
558
+ headw_ratio = headw / headw_edit
559
+
560
+ _, _, shoulder = get_length(skeleton, [6,3])
561
+ _, _, shoulder_edit = get_length(skeleton_edit, [6,3])
562
+ shoulder_ratio = shoulder / shoulder_edit
563
+
564
+ return max(headw_ratio, shoulder_ratio)
565
+
566
+ else:
567
+ return 1
568
+
569
+
570
+
571
+ def retarget_pose(src_skeleton, dst_skeleton, all_src_skeleton, src_skeleton_edit, dst_skeleton_edit, threshold=0.4):
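+ # Compute per-bone scale ratios between the driving (src) and reference (dst) skeletons (using the FLUX-edited poses when available), then retarget every driving frame onto the reference proportions and ground line.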
572
+
573
+ if src_skeleton_edit is not None and dst_skeleton_edit is not None:
574
+ use_edit_for_base = True
575
+ else:
576
+ use_edit_for_base = False
577
+
578
+ src_skeleton_ori = copy.deepcopy(src_skeleton)
579
+
580
+ dst_skeleton_ori_h, dst_skeleton_ori_w = dst_skeleton['height'], dst_skeleton['width']
581
+ if src_skeleton['keypoints_body'][0] != None and src_skeleton['keypoints_body'][10] != None and src_skeleton['keypoints_body'][13] != None and \
582
+ dst_skeleton['keypoints_body'][0] != None and dst_skeleton['keypoints_body'][10] != None and dst_skeleton['keypoints_body'][13] != None and \
583
+ src_skeleton['keypoints_body'][0][2] > 0.5 and src_skeleton['keypoints_body'][10][2] > 0.5 and src_skeleton['keypoints_body'][13][2] > 0.5 and \
584
+ dst_skeleton['keypoints_body'][0][2] > 0.5 and dst_skeleton['keypoints_body'][10][2] > 0.5 and dst_skeleton['keypoints_body'][13][2] > 0.5:
585
+
586
+ src_height = src_skeleton['height'] * abs(
587
+ (src_skeleton['keypoints_body'][10][1] + src_skeleton['keypoints_body'][13][1]) / 2 -
588
+ src_skeleton['keypoints_body'][0][1])
589
+ dst_height = dst_skeleton['height'] * abs(
590
+ (dst_skeleton['keypoints_body'][10][1] + dst_skeleton['keypoints_body'][13][1]) / 2 -
591
+ dst_skeleton['keypoints_body'][0][1])
592
+ scale_min = 1.0 * src_height / dst_height
593
+ elif src_skeleton['keypoints_body'][0] != None and src_skeleton['keypoints_body'][8] != None and src_skeleton['keypoints_body'][11] != None and \
594
+ dst_skeleton['keypoints_body'][0] != None and dst_skeleton['keypoints_body'][8] != None and dst_skeleton['keypoints_body'][11] != None and \
595
+ src_skeleton['keypoints_body'][0][2] > 0.5 and src_skeleton['keypoints_body'][8][2] > 0.5 and src_skeleton['keypoints_body'][11][2] > 0.5 and \
596
+ dst_skeleton['keypoints_body'][0][2] > 0.5 and dst_skeleton['keypoints_body'][8][2] > 0.5 and dst_skeleton['keypoints_body'][11][2] > 0.5:
597
+
598
+ src_height = src_skeleton['height'] * abs(
599
+ (src_skeleton['keypoints_body'][8][1] + src_skeleton['keypoints_body'][11][1]) / 2 -
600
+ src_skeleton['keypoints_body'][0][1])
601
+ dst_height = dst_skeleton['height'] * abs(
602
+ (dst_skeleton['keypoints_body'][8][1] + dst_skeleton['keypoints_body'][11][1]) / 2 -
603
+ dst_skeleton['keypoints_body'][0][1])
604
+ scale_min = 1.0 * src_height / dst_height
605
+ else:
606
+ scale_min = np.sqrt(src_skeleton['height'] * src_skeleton['width']) / np.sqrt(dst_skeleton['height'] * dst_skeleton['width'])
607
+
608
+ if use_edit_for_base:
609
+ scale_ratio_flag = False
610
+ if src_skeleton_edit['keypoints_body'][0] != None and src_skeleton_edit['keypoints_body'][10] != None and src_skeleton_edit['keypoints_body'][13] != None and \
611
+ dst_skeleton_edit['keypoints_body'][0] != None and dst_skeleton_edit['keypoints_body'][10] != None and dst_skeleton_edit['keypoints_body'][13] != None and \
612
+ src_skeleton_edit['keypoints_body'][0][2] > 0.5 and src_skeleton_edit['keypoints_body'][10][2] > 0.5 and src_skeleton_edit['keypoints_body'][13][2] > 0.5 and \
613
+ dst_skeleton_edit['keypoints_body'][0][2] > 0.5 and dst_skeleton_edit['keypoints_body'][10][2] > 0.5 and dst_skeleton_edit['keypoints_body'][13][2] > 0.5:
614
+
615
+ src_height_edit = src_skeleton_edit['height'] * abs(
616
+ (src_skeleton_edit['keypoints_body'][10][1] + src_skeleton_edit['keypoints_body'][13][1]) / 2 -
617
+ src_skeleton_edit['keypoints_body'][0][1])
618
+ dst_height_edit = dst_skeleton_edit['height'] * abs(
619
+ (dst_skeleton_edit['keypoints_body'][10][1] + dst_skeleton_edit['keypoints_body'][13][1]) / 2 -
620
+ dst_skeleton_edit['keypoints_body'][0][1])
621
+ scale_min_edit = 1.0 * src_height_edit / dst_height_edit
622
+ elif src_skeleton_edit['keypoints_body'][0] != None and src_skeleton_edit['keypoints_body'][8] != None and src_skeleton_edit['keypoints_body'][11] != None and \
623
+ dst_skeleton_edit['keypoints_body'][0] != None and dst_skeleton_edit['keypoints_body'][8] != None and dst_skeleton_edit['keypoints_body'][11] != None and \
624
+ src_skeleton_edit['keypoints_body'][0][2] > 0.5 and src_skeleton_edit['keypoints_body'][8][2] > 0.5 and src_skeleton_edit['keypoints_body'][11][2] > 0.5 and \
625
+ dst_skeleton_edit['keypoints_body'][0][2] > 0.5 and dst_skeleton_edit['keypoints_body'][8][2] > 0.5 and dst_skeleton_edit['keypoints_body'][11][2] > 0.5:
626
+
627
+ src_height_edit = src_skeleton_edit['height'] * abs(
628
+ (src_skeleton_edit['keypoints_body'][8][1] + src_skeleton_edit['keypoints_body'][11][1]) / 2 -
629
+ src_skeleton_edit['keypoints_body'][0][1])
630
+ dst_height_edit = dst_skeleton_edit['height'] * abs(
631
+ (dst_skeleton_edit['keypoints_body'][8][1] + dst_skeleton_edit['keypoints_body'][11][1]) / 2 -
632
+ dst_skeleton_edit['keypoints_body'][0][1])
633
+ scale_min_edit = 1.0 * src_height_edit / dst_height_edit
634
+ else:
635
+ scale_min_edit = np.sqrt(src_skeleton_edit['height'] * src_skeleton_edit['width']) / np.sqrt(dst_skeleton_edit['height'] * dst_skeleton_edit['width'])
636
+ scale_ratio_flag = True
637
+
638
+ # Flux may change the scale, compensate for it here
639
+ ratio_src = calculate_scale_ratio(src_skeleton, src_skeleton_edit, scale_ratio_flag)
640
+ ratio_dst = calculate_scale_ratio(dst_skeleton, dst_skeleton_edit, scale_ratio_flag)
641
+
642
+ dst_skeleton_edit['height'] = int(dst_skeleton_edit['height'] * scale_min_edit)
643
+ dst_skeleton_edit['width'] = int(dst_skeleton_edit['width'] * scale_min_edit)
644
+ for idx in range(len(dst_skeleton_edit['keypoints_left_hand'])):
645
+ dst_skeleton_edit['keypoints_left_hand'][idx][0] *= scale_min_edit
646
+ dst_skeleton_edit['keypoints_left_hand'][idx][1] *= scale_min_edit
647
+ for idx in range(len(dst_skeleton_edit['keypoints_right_hand'])):
648
+ dst_skeleton_edit['keypoints_right_hand'][idx][0] *= scale_min_edit
649
+ dst_skeleton_edit['keypoints_right_hand'][idx][1] *= scale_min_edit
650
+
651
+
652
+ dst_skeleton['height'] = int(dst_skeleton['height'] * scale_min)
653
+ dst_skeleton['width'] = int(dst_skeleton['width'] * scale_min)
654
+ for idx in range(len(dst_skeleton['keypoints_left_hand'])):
655
+ dst_skeleton['keypoints_left_hand'][idx][0] *= scale_min
656
+ dst_skeleton['keypoints_left_hand'][idx][1] *= scale_min
657
+ for idx in range(len(dst_skeleton['keypoints_right_hand'])):
658
+ dst_skeleton['keypoints_right_hand'][idx][0] *= scale_min
659
+ dst_skeleton['keypoints_right_hand'][idx][1] *= scale_min
660
+
661
+
662
+ dst_body_flag = check_full_body(dst_skeleton['keypoints_body'], threshold)
663
+ src_body_flag = check_full_body(src_skeleton_ori['keypoints_body'], threshold)
664
+ body_flag = check_full_body_both(dst_body_flag, src_body_flag)
665
+ #print('body_flag: ', body_flag)
666
+
667
+ if use_edit_for_base:
668
+ src_skeleton_edit = fix_lack_keypoints_use_sym(src_skeleton_edit)
669
+ dst_skeleton_edit = fix_lack_keypoints_use_sym(dst_skeleton_edit)
670
+ else:
671
+ src_skeleton = fix_lack_keypoints_use_sym(src_skeleton)
672
+ dst_skeleton = fix_lack_keypoints_use_sym(dst_skeleton)
673
+
674
+ none_idx = []
675
+ for idx in range(len(dst_skeleton['keypoints_body'])):
676
+ if dst_skeleton['keypoints_body'][idx] == None or src_skeleton['keypoints_body'][idx] == None:
677
+ src_skeleton['keypoints_body'][idx] = None
678
+ dst_skeleton['keypoints_body'][idx] = None
679
+ none_idx.append(idx)
680
+
681
+ # get bone ratio list
682
+ ratio_list, src_length_list, dst_length_list = [], [], []
683
+ for idx, limb in enumerate(limbSeq):
684
+ if use_edit_for_base:
685
+ src_X, src_Y, src_length = get_length(src_skeleton_edit, limb)
686
+ dst_X, dst_Y, dst_length = get_length(dst_skeleton_edit, limb)
687
+
688
+ if src_X is None or src_Y is None or dst_X is None or dst_Y is None:
689
+ ratio = -1
690
+ else:
691
+ ratio = 1.0 * dst_length * ratio_dst / src_length / ratio_src
692
+
693
+ else:
694
+ src_X, src_Y, src_length = get_length(src_skeleton, limb)
695
+ dst_X, dst_Y, dst_length = get_length(dst_skeleton, limb)
696
+
697
+ if src_X is None or src_Y is None or dst_X is None or dst_Y is None:
698
+ ratio = -1
699
+ else:
700
+ ratio = 1.0 * dst_length / src_length
701
+
702
+ ratio_list.append(ratio)
703
+ src_length_list.append(src_length)
704
+ dst_length_list.append(dst_length)
705
+
706
+ for idx, ratio in enumerate(ratio_list):
707
+ if ratio == -1:
708
+ if ratio_list[0] != -1 and ratio_list[1] != -1:
709
+ ratio_list[idx] = (ratio_list[0] + ratio_list[1]) / 2
710
+
711
+ # Consider adding constraints when Flux fails to correct head pose, causing neck issues.
712
+ # if ratio_list[12] > (ratio_list[0]+ratio_list[1])/2*1.25:
713
+ # ratio_list[12] = (ratio_list[0]+ratio_list[1])/2*1.25
714
+
715
+ ratio_list, src_length_list, dst_length_list = rescale_shorten_skeleton(ratio_list, src_length_list, dst_length_list)
716
+
717
+ rescaled_src_skeleton_ori = rescale_skeleton(src_skeleton_ori['height'], src_skeleton_ori['width'],
718
+ src_skeleton_ori['keypoints_body'], ratio_list)
719
+
720
+ # get global translation offset_x and offset_y
721
+ if body_flag == 'full_body':
722
+ #print('use foot mark.')
723
+ dst_ground_y = max(dst_skeleton['keypoints_body'][10][1], dst_skeleton['keypoints_body'][13][1]) * dst_skeleton[
724
+ 'height']
725
+ # The midpoint between toe and ankle
726
+ if dst_skeleton['keypoints_body'][18] != None and dst_skeleton['keypoints_body'][19] != None:
727
+ right_foot_mid = (dst_skeleton['keypoints_body'][10][1] + dst_skeleton['keypoints_body'][19][1]) / 2
728
+ left_foot_mid = (dst_skeleton['keypoints_body'][13][1] + dst_skeleton['keypoints_body'][18][1]) / 2
729
+ dst_ground_y = max(left_foot_mid, right_foot_mid) * dst_skeleton['height']
730
+
731
+ rescaled_src_ground_y = max(rescaled_src_skeleton_ori[10][1], rescaled_src_skeleton_ori[13][1])
732
+ delta_ground_y = rescaled_src_ground_y - dst_ground_y
733
+
734
+ dst_ground_x = (dst_skeleton['keypoints_body'][8][0] + dst_skeleton['keypoints_body'][11][0]) * dst_skeleton[
735
+ 'width'] / 2
736
+ rescaled_src_ground_x = (rescaled_src_skeleton_ori[8][0] + rescaled_src_skeleton_ori[11][0]) / 2
737
+ delta_ground_x = rescaled_src_ground_x - dst_ground_x
738
+ delta_x, delta_y = delta_ground_x, delta_ground_y
739
+
740
+ else:
741
+ #print('use neck mark.')
742
+ # use neck keypoint as mark
743
+ src_neck_y = rescaled_src_skeleton_ori[1][1]
744
+ dst_neck_y = dst_skeleton['keypoints_body'][1][1]
745
+ delta_neck_y = src_neck_y - dst_neck_y * dst_skeleton['height']
746
+
747
+ src_neck_x = rescaled_src_skeleton_ori[1][0]
748
+ dst_neck_x = dst_skeleton['keypoints_body'][1][0]
749
+ delta_neck_x = src_neck_x - dst_neck_x * dst_skeleton['width']
750
+ delta_x, delta_y = delta_neck_x, delta_neck_y
751
+ rescaled_src_ground_x = src_neck_x
752
+
753
+
754
+ dst_shape = (dst_skeleton_ori_w, dst_skeleton_ori_h)
755
+ output = write_to_poses(all_src_skeleton, none_idx, dst_shape, ratio_list, delta_x, delta_y,
756
+ rescaled_src_ground_x, body_flag, scale_min)
757
+ return output
758
+
759
+
760
+ def get_retarget_pose(tpl_pose_meta0, refer_pose_meta, tpl_pose_metas, tql_edit_pose_meta0, refer_edit_pose_meta):
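+ # Convert the raw pose dicts to plain lists (hand keypoints scaled to pixel coordinates), run retarget_pose over the whole driving sequence, and wrap the results as AAPoseMeta objects.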
761
+
762
+ for key, value in tpl_pose_meta0.items():
763
+ if type(value) is np.ndarray:
764
+ if key in ['keypoints_left_hand', 'keypoints_right_hand']:
765
+ value = value * np.array([[tpl_pose_meta0["width"], tpl_pose_meta0["height"], 1.0]])
766
+ if not isinstance(value, list):
767
+ value = value.tolist()
768
+ tpl_pose_meta0[key] = value
769
+
770
+ for key, value in refer_pose_meta.items():
771
+ if type(value) is np.ndarray:
772
+ if key in ['keypoints_left_hand', 'keypoints_right_hand']:
773
+ value = value * np.array([[refer_pose_meta["width"], refer_pose_meta["height"], 1.0]])
774
+ if not isinstance(value, list):
775
+ value = value.tolist()
776
+ refer_pose_meta[key] = value
777
+
778
+ tpl_pose_metas_new = []
779
+ for meta in tpl_pose_metas:
780
+ for key, value in meta.items():
781
+ if type(value) is np.ndarray:
782
+ if key in ['keypoints_left_hand', 'keypoints_right_hand']:
783
+ value = value * np.array([[meta["width"], meta["height"], 1.0]])
784
+ if not isinstance(value, list):
785
+ value = value.tolist()
786
+ meta[key] = value
787
+ tpl_pose_metas_new.append(meta)
788
+
789
+ if tql_edit_pose_meta0 is not None:
790
+ for key, value in tql_edit_pose_meta0.items():
791
+ if type(value) is np.ndarray:
792
+ if key in ['keypoints_left_hand', 'keypoints_right_hand']:
793
+ value = value * np.array([[tql_edit_pose_meta0["width"], tql_edit_pose_meta0["height"], 1.0]])
794
+ if not isinstance(value, list):
795
+ value = value.tolist()
796
+ tql_edit_pose_meta0[key] = value
797
+
798
+ if refer_edit_pose_meta is not None:
799
+ for key, value in refer_edit_pose_meta.items():
800
+ if type(value) is np.ndarray:
801
+ if key in ['keypoints_left_hand', 'keypoints_right_hand']:
802
+ value = value * np.array([[refer_edit_pose_meta["width"], refer_edit_pose_meta["height"], 1.0]])
803
+ if not isinstance(value, list):
804
+ value = value.tolist()
805
+ refer_edit_pose_meta[key] = value
806
+
807
+ retarget_tpl_pose_metas = retarget_pose(tpl_pose_meta0, refer_pose_meta, tpl_pose_metas_new, tql_edit_pose_meta0, refer_edit_pose_meta)
808
+
809
+ pose_metas = []
810
+ for meta in retarget_tpl_pose_metas:
811
+ pose_meta = AAPoseMeta()
812
+ width, height = meta["width"], meta["height"]
813
+ pose_meta.width = width
814
+ pose_meta.height = height
815
+ pose_meta.kps_body = np.array(meta["keypoints_body"])[:, :2] * (width, height)
816
+ pose_meta.kps_body_p = np.array(meta["keypoints_body"])[:, 2]
817
+
818
+ kps_lhand = []
819
+ kps_lhand_p = []
820
+ for each_kps_lhand in meta["keypoints_left_hand"]:
821
+ if each_kps_lhand is not None:
822
+ kps_lhand.append([each_kps_lhand.x, each_kps_lhand.y])
823
+ kps_lhand_p.append(each_kps_lhand.score)
824
+ else:
825
+ kps_lhand.append([None, None])
826
+ kps_lhand_p.append(0.0)
827
+
828
+ pose_meta.kps_lhand = np.array(kps_lhand)
829
+ pose_meta.kps_lhand_p = np.array(kps_lhand_p)
830
+
831
+ kps_rhand = []
832
+ kps_rhand_p = []
833
+ for each_kps_rhand in meta["keypoints_right_hand"]:
834
+ if each_kps_rhand is not None:
835
+ kps_rhand.append([each_kps_rhand.x, each_kps_rhand.y])
836
+ kps_rhand_p.append(each_kps_rhand.score)
837
+ else:
838
+ kps_rhand.append([None, None])
839
+ kps_rhand_p.append(0.0)
840
+
841
+ pose_meta.kps_rhand = np.array(kps_rhand)
842
+ pose_meta.kps_rhand_p = np.array(kps_rhand_p)
843
+
844
+ pose_metas.append(pose_meta)
845
+
846
+ return pose_metas
847
+
wan/modules/animate/preprocess/sam_utils.py ADDED
@@ -0,0 +1,155 @@
1
+ # Copyright (c) 2025. Your modifications here.
2
+ # This file wraps and extends sam2.utils.misc for custom modifications.
3
+
4
+ from sam2.utils import misc as sam2_misc
5
+ from sam2.utils.misc import *
6
+ from PIL import Image
7
+ import numpy as np
8
+ import torch
9
+ from tqdm import tqdm
10
+ import os
11
+
12
+ import logging
13
+
14
+ import torch
15
+ from hydra import compose
16
+ from hydra.utils import instantiate
17
+ from omegaconf import OmegaConf
18
+
19
+ from sam2.utils.misc import AsyncVideoFrameLoader, _load_img_as_tensor
20
+ from sam2.build_sam import _load_checkpoint
21
+
22
+
23
+ def _load_img_v2_as_tensor(img, image_size):
24
+ img_pil = Image.fromarray(img.astype(np.uint8))
25
+ img_np = np.array(img_pil.convert("RGB").resize((image_size, image_size)))
26
+ if img_np.dtype == np.uint8: # np.uint8 is expected for JPEG images
27
+ img_np = img_np / 255.0
28
+ else:
29
+ raise RuntimeError(f"Unknown image dtype: {img_np.dtype}")
30
+ img = torch.from_numpy(img_np).permute(2, 0, 1)
31
+ video_width, video_height = img_pil.size # the original video size
32
+ return img, video_height, video_width
33
+
34
+ def load_video_frames(
35
+ video_path,
36
+ image_size,
37
+ offload_video_to_cpu,
38
+ img_mean=(0.485, 0.456, 0.406),
39
+ img_std=(0.229, 0.224, 0.225),
40
+ async_loading_frames=False,
41
+ frame_names=None,
42
+ ):
43
+ """
44
+ Load the video frames from a directory of JPEG files ("<frame_index>.jpg" format).
45
+
46
+ The frames are resized to image_size x image_size and are loaded to GPU if
47
+ `offload_video_to_cpu` is `False` and to CPU if `offload_video_to_cpu` is `True`.
48
+
49
+ You can load a frame asynchronously by setting `async_loading_frames` to `True`.
50
+ """
51
+ if isinstance(video_path, str) and os.path.isdir(video_path):
52
+ jpg_folder = video_path
53
+ else:
54
+ raise NotImplementedError("Only JPEG frames are supported at this moment")
55
+ if frame_names is None:
56
+ frame_names = [
57
+ p
58
+ for p in os.listdir(jpg_folder)
59
+ if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG", ".png"]
60
+ ]
61
+ frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
62
+
63
+ num_frames = len(frame_names)
64
+ if num_frames == 0:
65
+ raise RuntimeError(f"no images found in {jpg_folder}")
66
+ img_paths = [os.path.join(jpg_folder, frame_name) for frame_name in frame_names]
67
+ img_mean = torch.tensor(img_mean, dtype=torch.float32)[:, None, None]
68
+ img_std = torch.tensor(img_std, dtype=torch.float32)[:, None, None]
69
+
70
+ if async_loading_frames:
71
+ lazy_images = AsyncVideoFrameLoader(
72
+ img_paths, image_size, offload_video_to_cpu, img_mean, img_std
73
+ )
74
+ return lazy_images, lazy_images.video_height, lazy_images.video_width
75
+
76
+ images = torch.zeros(num_frames, 3, image_size, image_size, dtype=torch.float32)
77
+ for n, img_path in enumerate(tqdm(img_paths, desc="frame loading (JPEG)")):
78
+ images[n], video_height, video_width = _load_img_as_tensor(img_path, image_size)
79
+ if not offload_video_to_cpu:
80
+ images = images.cuda()
81
+ img_mean = img_mean.cuda()
82
+ img_std = img_std.cuda()
83
+ # normalize by mean and std
84
+ images -= img_mean
85
+ images /= img_std
86
+ return images, video_height, video_width
87
+
88
+
89
+ def load_video_frames_v2(
90
+ frames,
91
+ image_size,
92
+ offload_video_to_cpu,
93
+ img_mean=(0.485, 0.456, 0.406),
94
+ img_std=(0.229, 0.224, 0.225),
95
+ async_loading_frames=False,
96
+ frame_names=None,
97
+ ):
98
+ """
99
+ Load the video frames from a directory of JPEG files ("<frame_index>.jpg" format).
100
+
101
+ The frames are resized to image_size x image_size and are loaded to GPU if
102
+ `offload_video_to_cpu` is `False` and to CPU if `offload_video_to_cpu` is `True`.
103
+
104
+ You can load a frame asynchronously by setting `async_loading_frames` to `True`.
105
+ """
106
+ num_frames = len(frames)
107
+ img_mean = torch.tensor(img_mean, dtype=torch.float32)[:, None, None]
108
+ img_std = torch.tensor(img_std, dtype=torch.float32)[:, None, None]
109
+
110
+ images = torch.zeros(num_frames, 3, image_size, image_size, dtype=torch.float32)
111
+ for n, frame in enumerate(tqdm(frames, desc="video frame")):
112
+ images[n], video_height, video_width = _load_img_v2_as_tensor(frame, image_size)
113
+ if not offload_video_to_cpu:
114
+ images = images.cuda()
115
+ img_mean = img_mean.cuda()
116
+ img_std = img_std.cuda()
117
+ # normalize by mean and std
118
+ images -= img_mean
119
+ images /= img_std
120
+ return images, video_height, video_width
121
+
122
+ def build_sam2_video_predictor(
123
+ config_file,
124
+ ckpt_path=None,
125
+ device="cuda",
126
+ mode="eval",
127
+ hydra_overrides_extra=[],
128
+ apply_postprocessing=True,
129
+ ):
130
+ hydra_overrides = [
131
+ "++model._target_=video_predictor.SAM2VideoPredictor",
132
+ ]
133
+ if apply_postprocessing:
134
+ hydra_overrides_extra = hydra_overrides_extra.copy()
135
+ hydra_overrides_extra += [
136
+ # dynamically fall back to multi-mask if the single mask is not stable
137
+ "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
138
+ "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
139
+ "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
140
+ # binarize the sigmoid mask logits on interacted frames with clicks in the memory encoder so that the encoded masks are exactly as what users see from clicking
141
+ "++model.binarize_mask_from_pts_for_mem_enc=true",
142
+ # fill small holes in the low-res masks up to `fill_hole_area` (before resizing them to the original video resolution)
143
+ "++model.fill_hole_area=8",
144
+ ]
145
+
146
+ hydra_overrides.extend(hydra_overrides_extra)
147
+ # Read config and init model
148
+ cfg = compose(config_name=config_file, overrides=hydra_overrides)
149
+ OmegaConf.resolve(cfg)
150
+ model = instantiate(cfg.model, _recursive_=True)
151
+ _load_checkpoint(model, ckpt_path)
152
+ model = model.to(device)
153
+ if mode == "eval":
154
+ model.eval()
155
+ return model
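
A short usage sketch may help readers see how these wrappers fit together. The config name, checkpoint path, and dummy frames below are placeholders (nothing in this commit ships them); only the two functions defined above are assumed, plus the upstream SAM2 package.

```python
# Sketch only: placeholder config/checkpoint names; assumes the preprocess
# directory is on sys.path (video_predictor.py imports sam_utils the same way).
import numpy as np
from sam_utils import build_sam2_video_predictor, load_video_frames_v2

predictor = build_sam2_video_predictor(
    config_file="sam2_hiera_l.yaml",              # placeholder SAM2 config name
    ckpt_path="checkpoints/sam2_hiera_large.pt",  # placeholder checkpoint path
    device="cuda",
)

# Frames decoded elsewhere (e.g. with cv2) as uint8 H x W x 3 arrays; the loader
# resizes them to image_size, normalizes with ImageNet statistics, and stacks them.
frames = [np.zeros((480, 640, 3), dtype=np.uint8) for _ in range(16)]
images, video_h, video_w = load_video_frames_v2(
    frames, image_size=predictor.image_size, offload_video_to_cpu=True)
```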
wan/modules/animate/preprocess/utils.py ADDED
@@ -0,0 +1,226 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import os
3
+ import cv2
4
+ import math
5
+ import random
6
+ import numpy as np
7
+
8
+ def get_mask_boxes(mask):
9
+ """
10
+
11
+ Args:
12
+ mask: [h, w]
13
+ Returns:
14
+
15
+ """
16
+ y_coords, x_coords = np.nonzero(mask)
17
+ x_min = x_coords.min()
18
+ x_max = x_coords.max()
19
+ y_min = y_coords.min()
20
+ y_max = y_coords.max()
21
+ bbox = np.array([x_min, y_min, x_max, y_max]).astype(np.int32)
22
+ return bbox
23
+
24
+
25
+ def get_aug_mask(body_mask, w_len=10, h_len=20):
26
+ body_bbox = get_mask_boxes(body_mask)
27
+
28
+ bbox_wh = body_bbox[2:4] - body_bbox[0:2]
29
+ w_slice = np.int32(bbox_wh[0] / w_len)
30
+ h_slice = np.int32(bbox_wh[1] / h_len)
31
+
32
+ for each_w in range(body_bbox[0], body_bbox[2], w_slice):
33
+ w_start = min(each_w, body_bbox[2])
34
+ w_end = min((each_w + w_slice), body_bbox[2])
35
+ # print(w_start, w_end)
36
+ for each_h in range(body_bbox[1], body_bbox[3], h_slice):
37
+ h_start = min(each_h, body_bbox[3])
38
+ h_end = min((each_h + h_slice), body_bbox[3])
39
+ if body_mask[h_start:h_end, w_start:w_end].sum() > 0:
40
+ body_mask[h_start:h_end, w_start:w_end] = 1
41
+
42
+ return body_mask
43
+
44
+ def get_mask_body_img(img_copy, hand_mask, k=7, iterations=1):
45
+ kernel = np.ones((k, k), np.uint8)
46
+ dilation = cv2.dilate(hand_mask, kernel, iterations=iterations)
47
+ mask_hand_img = img_copy * (1 - dilation[:, :, None])
48
+
49
+ return mask_hand_img, dilation
50
+
51
+
52
+ def get_face_bboxes(kp2ds, scale, image_shape, ratio_aug):
53
+ h, w = image_shape
54
+ kp2ds_face = kp2ds.copy()[23:91, :2]
55
+
56
+ min_x, min_y = np.min(kp2ds_face, axis=0)
57
+ max_x, max_y = np.max(kp2ds_face, axis=0)
58
+
59
+
60
+ initial_width = max_x - min_x
61
+ initial_height = max_y - min_y
62
+
63
+ initial_area = initial_width * initial_height
64
+
65
+ expanded_area = initial_area * scale
66
+
67
+ new_width = np.sqrt(expanded_area * (initial_width / initial_height))
68
+ new_height = np.sqrt(expanded_area * (initial_height / initial_width))
69
+
70
+ delta_width = (new_width - initial_width) / 2
71
+ delta_height = (new_height - initial_height) / 4
72
+
73
+ if ratio_aug:
74
+ if random.random() > 0.5:
75
+ delta_width += random.uniform(0, initial_width // 10)
76
+ else:
77
+ delta_height += random.uniform(0, initial_height // 10)
78
+
79
+ expanded_min_x = max(min_x - delta_width, 0)
80
+ expanded_max_x = min(max_x + delta_width, w)
81
+ expanded_min_y = max(min_y - 3 * delta_height, 0)
82
+ expanded_max_y = min(max_y + delta_height, h)
83
+
84
+ return [int(expanded_min_x), int(expanded_max_x), int(expanded_min_y), int(expanded_max_y)]
85
+
86
+
87
+ def calculate_new_size(orig_w, orig_h, target_area, divisor=64):
88
+
89
+ target_ratio = orig_w / orig_h
90
+
91
+ def check_valid(w, h):
92
+
93
+ if w <= 0 or h <= 0:
94
+ return False
95
+ return (w * h <= target_area and
96
+ w % divisor == 0 and
97
+ h % divisor == 0)
98
+
99
+ def get_ratio_diff(w, h):
100
+
101
+ return abs(w / h - target_ratio)
102
+
103
+ def round_to_64(value, round_up=False, divisor=64):
104
+
105
+ if round_up:
106
+ return divisor * ((value + (divisor - 1)) // divisor)
107
+ return divisor * (value // divisor)
108
+
109
+ possible_sizes = []
110
+
111
+ max_area_h = int(np.sqrt(target_area / target_ratio))
112
+ max_area_w = int(max_area_h * target_ratio)
113
+
114
+ max_h = round_to_64(max_area_h, round_up=True, divisor=divisor)
115
+ max_w = round_to_64(max_area_w, round_up=True, divisor=divisor)
116
+
117
+ for h in range(divisor, max_h + divisor, divisor):
118
+ ideal_w = h * target_ratio
119
+
120
+ w_down = round_to_64(ideal_w, divisor=divisor)
121
+ w_up = round_to_64(ideal_w, round_up=True, divisor=divisor)
122
+
123
+ for w in [w_down, w_up]:
124
+ if check_valid(w, h):
125
+ possible_sizes.append((w, h, get_ratio_diff(w, h)))
126
+
127
+ if not possible_sizes:
128
+ raise ValueError("Can not find suitable size")
129
+
130
+ possible_sizes.sort(key=lambda x: (-x[0] * x[1], x[2]))
131
+
132
+ best_w, best_h, _ = possible_sizes[0]
133
+ return int(best_w), int(best_h)
134
+
135
+
136
+ def resize_by_area(image, target_area, keep_aspect_ratio=True, divisor=64, padding_color=(0, 0, 0)):
137
+ h, w = image.shape[:2]
138
+ try:
139
+ new_w, new_h = calculate_new_size(w, h, target_area, divisor)
140
+ except:
141
+ aspect_ratio = w / h
142
+
143
+ if keep_aspect_ratio:
144
+ new_h = math.sqrt(target_area / aspect_ratio)
145
+ new_w = target_area / new_h
146
+ else:
147
+ new_w = new_h = math.sqrt(target_area)
148
+
149
+ new_w, new_h = int((new_w // divisor) * divisor), int((new_h // divisor) * divisor)
150
+
151
+ interpolation = cv2.INTER_AREA if (new_w * new_h < w * h) else cv2.INTER_LINEAR
152
+
153
+ resized_image = padding_resize(image, height=new_h, width=new_w, padding_color=padding_color,
154
+ interpolation=interpolation)
155
+ return resized_image
156
+
157
+
158
+ def padding_resize(img_ori, height=512, width=512, padding_color=(0, 0, 0), interpolation=cv2.INTER_LINEAR):
159
+ ori_height = img_ori.shape[0]
160
+ ori_width = img_ori.shape[1]
161
+ channel = img_ori.shape[2]
162
+
163
+ img_pad = np.zeros((height, width, channel))
164
+ if channel == 1:
165
+ img_pad[:, :, 0] = padding_color[0]
166
+ else:
167
+ img_pad[:, :, 0] = padding_color[0]
168
+ img_pad[:, :, 1] = padding_color[1]
169
+ img_pad[:, :, 2] = padding_color[2]
170
+
171
+ if (ori_height / ori_width) > (height / width):
172
+ new_width = int(height / ori_height * ori_width)
173
+ img = cv2.resize(img_ori, (new_width, height), interpolation=interpolation)
174
+ padding = int((width - new_width) / 2)
175
+ if len(img.shape) == 2:
176
+ img = img[:, :, np.newaxis]
177
+ img_pad[:, padding: padding + new_width, :] = img
178
+ else:
179
+ new_height = int(width / ori_width * ori_height)
180
+ img = cv2.resize(img_ori, (width, new_height), interpolation=interpolation)
181
+ padding = int((height - new_height) / 2)
182
+ if len(img.shape) == 2:
183
+ img = img[:, :, np.newaxis]
184
+ img_pad[padding: padding + new_height, :, :] = img
185
+
186
+ img_pad = np.uint8(img_pad)
187
+
188
+ return img_pad
189
+
190
+
191
+ def get_frame_indices(frame_num, video_fps, clip_length, train_fps):
192
+
193
+ start_frame = 0
194
+ times = np.arange(0, clip_length) / train_fps
195
+ frame_indices = start_frame + np.round(times * video_fps).astype(int)
196
+ frame_indices = np.clip(frame_indices, 0, frame_num - 1)
197
+
198
+ return frame_indices.tolist()
199
+
200
+
201
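+ # NOTE: this redefinition shadows the get_face_bboxes defined earlier in this file;
+ # this variant takes keypoints in normalized coordinates (scaled here by (w, h)) and drops the ratio_aug option.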
+ def get_face_bboxes(kp2ds, scale, image_shape):
202
+ h, w = image_shape
203
+ kp2ds_face = kp2ds.copy()[1:] * (w, h)
204
+
205
+ min_x, min_y = np.min(kp2ds_face, axis=0)
206
+ max_x, max_y = np.max(kp2ds_face, axis=0)
207
+
208
+ initial_width = max_x - min_x
209
+ initial_height = max_y - min_y
210
+
211
+ initial_area = initial_width * initial_height
212
+
213
+ expanded_area = initial_area * scale
214
+
215
+ new_width = np.sqrt(expanded_area * (initial_width / initial_height))
216
+ new_height = np.sqrt(expanded_area * (initial_height / initial_width))
217
+
218
+ delta_width = (new_width - initial_width) / 2
219
+ delta_height = (new_height - initial_height) / 4
220
+
221
+ expanded_min_x = max(min_x - delta_width, 0)
222
+ expanded_max_x = min(max_x + delta_width, w)
223
+ expanded_min_y = max(min_y - 3 * delta_height, 0)
224
+ expanded_max_y = min(max_y + delta_height, h)
225
+
226
+ return [int(expanded_min_x), int(expanded_max_x), int(expanded_min_y), int(expanded_max_y)]
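
For orientation, a minimal sketch of the two most commonly reused helpers in this file, run on a dummy frame; the exact output size depends on the divisor-constrained search in `calculate_new_size`, so the printed shape is indicative only.

```python
# Illustration with a dummy frame; shapes are indicative, not guaranteed.
import numpy as np
from utils import resize_by_area, get_frame_indices  # assumes the preprocess dir is on sys.path

frame = np.zeros((720, 1280, 3), dtype=np.uint8)

# Fit the frame into ~1280*720 pixels with both sides divisible by 64,
# padding with black where the aspect ratio cannot be matched exactly.
resized = resize_by_area(frame, target_area=1280 * 720, divisor=64)
print(resized.shape)  # e.g. (704, 1280, 3)

# Map 77 output frames at 16 fps onto a 300-frame source recorded at 30 fps.
indices = get_frame_indices(frame_num=300, video_fps=30, clip_length=77, train_fps=16)
print(indices[:5])  # [0, 2, 4, 6, 8] (rounded to the nearest source frame)
```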
wan/modules/animate/preprocess/video_predictor.py ADDED
@@ -0,0 +1,157 @@
1
+ # Copyright (c) 2025. Your modifications here.
2
+ # A wrapper for sam2 functions
3
+ from collections import OrderedDict
4
+ import torch
5
+ from tqdm import tqdm
6
+
7
+ from sam2.modeling.sam2_base import NO_OBJ_SCORE, SAM2Base
8
+ from sam2.sam2_video_predictor import SAM2VideoPredictor as _SAM2VideoPredictor
9
+ from sam2.utils.misc import concat_points, fill_holes_in_mask_scores
10
+
11
+ from sam_utils import load_video_frames_v2, load_video_frames
12
+
13
+
14
+ class SAM2VideoPredictor(_SAM2VideoPredictor):
15
+ def __init__(self, *args, **kwargs):
16
+
17
+ super().__init__(*args, **kwargs)
18
+
19
+ @torch.inference_mode()
20
+ def init_state(
21
+ self,
22
+ video_path,
23
+ offload_video_to_cpu=False,
24
+ offload_state_to_cpu=False,
25
+ async_loading_frames=False,
26
+ frame_names=None
27
+ ):
28
+ """Initialize a inference state."""
29
+ images, video_height, video_width = load_video_frames(
30
+ video_path=video_path,
31
+ image_size=self.image_size,
32
+ offload_video_to_cpu=offload_video_to_cpu,
33
+ async_loading_frames=async_loading_frames,
34
+ frame_names=frame_names
35
+ )
36
+ inference_state = {}
37
+ inference_state["images"] = images
38
+ inference_state["num_frames"] = len(images)
39
+ # whether to offload the video frames to CPU memory
40
+ # turning on this option saves the GPU memory with only a very small overhead
41
+ inference_state["offload_video_to_cpu"] = offload_video_to_cpu
42
+ # whether to offload the inference state to CPU memory
43
+ # turning on this option saves the GPU memory at the cost of a lower tracking fps
44
+ # (e.g. in a test case of 768x768 model, fps dropped from 27 to 24 when tracking one object
45
+ # and from 24 to 21 when tracking two objects)
46
+ inference_state["offload_state_to_cpu"] = offload_state_to_cpu
47
+ # the original video height and width, used for resizing final output scores
48
+ inference_state["video_height"] = video_height
49
+ inference_state["video_width"] = video_width
50
+ inference_state["device"] = torch.device("cuda")
51
+ if offload_state_to_cpu:
52
+ inference_state["storage_device"] = torch.device("cpu")
53
+ else:
54
+ inference_state["storage_device"] = torch.device("cuda")
55
+ # inputs on each frame
56
+ inference_state["point_inputs_per_obj"] = {}
57
+ inference_state["mask_inputs_per_obj"] = {}
58
+ # visual features on a small number of recently visited frames for quick interactions
59
+ inference_state["cached_features"] = {}
60
+ # values that don't change across frames (so we only need to hold one copy of them)
61
+ inference_state["constants"] = {}
62
+ # mapping between client-side object id and model-side object index
63
+ inference_state["obj_id_to_idx"] = OrderedDict()
64
+ inference_state["obj_idx_to_id"] = OrderedDict()
65
+ inference_state["obj_ids"] = []
66
+ # A storage to hold the model's tracking results and states on each frame
67
+ inference_state["output_dict"] = {
68
+ "cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
69
+ "non_cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
70
+ }
71
+ # Slice (view) of each object tracking results, sharing the same memory with "output_dict"
72
+ inference_state["output_dict_per_obj"] = {}
73
+ # A temporary storage to hold new outputs when the user interacts with a frame
74
+ # to add clicks or mask (it's merged into "output_dict" before propagation starts)
75
+ inference_state["temp_output_dict_per_obj"] = {}
76
+ # Frames that already hold consolidated outputs from click or mask inputs
77
+ # (we directly use their consolidated outputs during tracking)
78
+ inference_state["consolidated_frame_inds"] = {
79
+ "cond_frame_outputs": set(), # set containing frame indices
80
+ "non_cond_frame_outputs": set(), # set containing frame indices
81
+ }
82
+ # metadata for each tracking frame (e.g. which direction it's tracked)
83
+ inference_state["tracking_has_started"] = False
84
+ inference_state["frames_already_tracked"] = {}
85
+ # Warm up the visual backbone and cache the image feature on frame 0
86
+ self._get_image_feature(inference_state, frame_idx=0, batch_size=1)
87
+ return inference_state
88
+
89
+ @torch.inference_mode()
90
+ def init_state_v2(
91
+ self,
92
+ frames,
93
+ offload_video_to_cpu=False,
94
+ offload_state_to_cpu=False,
95
+ async_loading_frames=False,
96
+ frame_names=None
97
+ ):
98
+ """Initialize a inference state."""
99
+ images, video_height, video_width = load_video_frames_v2(
100
+ frames=frames,
101
+ image_size=self.image_size,
102
+ offload_video_to_cpu=offload_video_to_cpu,
103
+ async_loading_frames=async_loading_frames,
104
+ frame_names=frame_names
105
+ )
106
+ inference_state = {}
107
+ inference_state["images"] = images
108
+ inference_state["num_frames"] = len(images)
109
+ # whether to offload the video frames to CPU memory
110
+ # turning on this option saves the GPU memory with only a very small overhead
111
+ inference_state["offload_video_to_cpu"] = offload_video_to_cpu
112
+ # whether to offload the inference state to CPU memory
113
+ # turning on this option saves the GPU memory at the cost of a lower tracking fps
114
+ # (e.g. in a test case of 768x768 model, fps dropped from 27 to 24 when tracking one object
115
+ # and from 24 to 21 when tracking two objects)
116
+ inference_state["offload_state_to_cpu"] = offload_state_to_cpu
117
+ # the original video height and width, used for resizing final output scores
118
+ inference_state["video_height"] = video_height
119
+ inference_state["video_width"] = video_width
120
+ inference_state["device"] = torch.device("cuda")
121
+ if offload_state_to_cpu:
122
+ inference_state["storage_device"] = torch.device("cpu")
123
+ else:
124
+ inference_state["storage_device"] = torch.device("cuda")
125
+ # inputs on each frame
126
+ inference_state["point_inputs_per_obj"] = {}
127
+ inference_state["mask_inputs_per_obj"] = {}
128
+ # visual features on a small number of recently visited frames for quick interactions
129
+ inference_state["cached_features"] = {}
130
+ # values that don't change across frames (so we only need to hold one copy of them)
131
+ inference_state["constants"] = {}
132
+ # mapping between client-side object id and model-side object index
133
+ inference_state["obj_id_to_idx"] = OrderedDict()
134
+ inference_state["obj_idx_to_id"] = OrderedDict()
135
+ inference_state["obj_ids"] = []
136
+ # A storage to hold the model's tracking results and states on each frame
137
+ inference_state["output_dict"] = {
138
+ "cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
139
+ "non_cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
140
+ }
141
+ # Slice (view) of each object tracking results, sharing the same memory with "output_dict"
142
+ inference_state["output_dict_per_obj"] = {}
143
+ # A temporary storage to hold new outputs when the user interacts with a frame
144
+ # to add clicks or mask (it's merged into "output_dict" before propagation starts)
145
+ inference_state["temp_output_dict_per_obj"] = {}
146
+ # Frames that already hold consolidated outputs from click or mask inputs
147
+ # (we directly use their consolidated outputs during tracking)
148
+ inference_state["consolidated_frame_inds"] = {
149
+ "cond_frame_outputs": set(), # set containing frame indices
150
+ "non_cond_frame_outputs": set(), # set containing frame indices
151
+ }
152
+ # metadata for each tracking frame (e.g. which direction it's tracked)
153
+ inference_state["tracking_has_started"] = False
154
+ inference_state["frames_already_tracked"] = {}
155
+ # Warm up the visual backbone and cache the image feature on frame 0
156
+ self._get_image_feature(inference_state, frame_idx=0, batch_size=1)
157
+ return inference_state
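
A hedged end-to-end sketch of where `init_state_v2` slots in. The prompting and propagation calls (`add_new_points_or_box`, `propagate_in_video`) come from the upstream SAM2 predictor that this class subclasses and may differ slightly between sam2 releases; config and checkpoint names are placeholders.

```python
# Sketch under the assumption that the upstream SAM2 prompting API is available.
import numpy as np
import torch
from sam_utils import build_sam2_video_predictor

predictor = build_sam2_video_predictor(
    "sam2_hiera_l.yaml", "checkpoints/sam2_hiera_large.pt", device="cuda")

frames = [np.zeros((480, 640, 3), dtype=np.uint8) for _ in range(8)]

with torch.inference_mode():
    state = predictor.init_state_v2(frames)   # in-memory frames, no JPEG folder needed
    # Seed object 1 with a single positive click on frame 0 (upstream SAM2 method).
    predictor.add_new_points_or_box(
        state, frame_idx=0, obj_id=1,
        points=np.array([[320, 240]], dtype=np.float32),
        labels=np.array([1], dtype=np.int32))
    for frame_idx, obj_ids, mask_logits in predictor.propagate_in_video(state):
        masks = (mask_logits > 0.0).cpu()     # per-frame binary masks
```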
wan/modules/animate/xlm_roberta.py ADDED
@@ -0,0 +1,170 @@
1
+ # Modified from transformers.models.xlm_roberta.modeling_xlm_roberta
2
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+ __all__ = ['XLMRoberta', 'xlm_roberta_large']
8
+
9
+
10
+ class SelfAttention(nn.Module):
11
+
12
+ def __init__(self, dim, num_heads, dropout=0.1, eps=1e-5):
13
+ assert dim % num_heads == 0
14
+ super().__init__()
15
+ self.dim = dim
16
+ self.num_heads = num_heads
17
+ self.head_dim = dim // num_heads
18
+ self.eps = eps
19
+
20
+ # layers
21
+ self.q = nn.Linear(dim, dim)
22
+ self.k = nn.Linear(dim, dim)
23
+ self.v = nn.Linear(dim, dim)
24
+ self.o = nn.Linear(dim, dim)
25
+ self.dropout = nn.Dropout(dropout)
26
+
27
+ def forward(self, x, mask):
28
+ """
29
+ x: [B, L, C].
30
+ """
31
+ b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
32
+
33
+ # compute query, key, value
34
+ q = self.q(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
35
+ k = self.k(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
36
+ v = self.v(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
37
+
38
+ # compute attention
39
+ p = self.dropout.p if self.training else 0.0
40
+ x = F.scaled_dot_product_attention(q, k, v, mask, p)
41
+ x = x.permute(0, 2, 1, 3).reshape(b, s, c)
42
+
43
+ # output
44
+ x = self.o(x)
45
+ x = self.dropout(x)
46
+ return x
47
+
48
+
49
+ class AttentionBlock(nn.Module):
50
+
51
+ def __init__(self, dim, num_heads, post_norm, dropout=0.1, eps=1e-5):
52
+ super().__init__()
53
+ self.dim = dim
54
+ self.num_heads = num_heads
55
+ self.post_norm = post_norm
56
+ self.eps = eps
57
+
58
+ # layers
59
+ self.attn = SelfAttention(dim, num_heads, dropout, eps)
60
+ self.norm1 = nn.LayerNorm(dim, eps=eps)
61
+ self.ffn = nn.Sequential(
62
+ nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim),
63
+ nn.Dropout(dropout))
64
+ self.norm2 = nn.LayerNorm(dim, eps=eps)
65
+
66
+ def forward(self, x, mask):
67
+ if self.post_norm:
68
+ x = self.norm1(x + self.attn(x, mask))
69
+ x = self.norm2(x + self.ffn(x))
70
+ else:
71
+ x = x + self.attn(self.norm1(x), mask)
72
+ x = x + self.ffn(self.norm2(x))
73
+ return x
74
+
75
+
76
+ class XLMRoberta(nn.Module):
77
+ """
78
+ XLMRobertaModel with no pooler and no LM head.
79
+ """
80
+
81
+ def __init__(self,
82
+ vocab_size=250002,
83
+ max_seq_len=514,
84
+ type_size=1,
85
+ pad_id=1,
86
+ dim=1024,
87
+ num_heads=16,
88
+ num_layers=24,
89
+ post_norm=True,
90
+ dropout=0.1,
91
+ eps=1e-5):
92
+ super().__init__()
93
+ self.vocab_size = vocab_size
94
+ self.max_seq_len = max_seq_len
95
+ self.type_size = type_size
96
+ self.pad_id = pad_id
97
+ self.dim = dim
98
+ self.num_heads = num_heads
99
+ self.num_layers = num_layers
100
+ self.post_norm = post_norm
101
+ self.eps = eps
102
+
103
+ # embeddings
104
+ self.token_embedding = nn.Embedding(vocab_size, dim, padding_idx=pad_id)
105
+ self.type_embedding = nn.Embedding(type_size, dim)
106
+ self.pos_embedding = nn.Embedding(max_seq_len, dim, padding_idx=pad_id)
107
+ self.dropout = nn.Dropout(dropout)
108
+
109
+ # blocks
110
+ self.blocks = nn.ModuleList([
111
+ AttentionBlock(dim, num_heads, post_norm, dropout, eps)
112
+ for _ in range(num_layers)
113
+ ])
114
+
115
+ # norm layer
116
+ self.norm = nn.LayerNorm(dim, eps=eps)
117
+
118
+ def forward(self, ids):
119
+ """
120
+ ids: [B, L] of torch.LongTensor.
121
+ """
122
+ b, s = ids.shape
123
+ mask = ids.ne(self.pad_id).long()
124
+
125
+ # embeddings
126
+ x = self.token_embedding(ids) + \
127
+ self.type_embedding(torch.zeros_like(ids)) + \
128
+ self.pos_embedding(self.pad_id + torch.cumsum(mask, dim=1) * mask)
129
+ if self.post_norm:
130
+ x = self.norm(x)
131
+ x = self.dropout(x)
132
+
133
+ # blocks
134
+ mask = torch.where(
135
+ mask.view(b, 1, 1, s).gt(0), 0.0,
136
+ torch.finfo(x.dtype).min)
137
+ for block in self.blocks:
138
+ x = block(x, mask)
139
+
140
+ # output
141
+ if not self.post_norm:
142
+ x = self.norm(x)
143
+ return x
144
+
145
+
146
+ def xlm_roberta_large(pretrained=False,
147
+ return_tokenizer=False,
148
+ device='cpu',
149
+ **kwargs):
150
+ """
151
+ XLMRobertaLarge adapted from Huggingface.
152
+ """
153
+ # params
154
+ cfg = dict(
155
+ vocab_size=250002,
156
+ max_seq_len=514,
157
+ type_size=1,
158
+ pad_id=1,
159
+ dim=1024,
160
+ num_heads=16,
161
+ num_layers=24,
162
+ post_norm=True,
163
+ dropout=0.1,
164
+ eps=1e-5)
165
+ cfg.update(**kwargs)
166
+
167
+ # init a model on device
168
+ with torch.device(device):
169
+ model = XLMRoberta(**cfg)
170
+ return model
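
A minimal shape check for the encoder above. The weights are random here, so the features are meaningless; this only demonstrates the expected input/output shapes.

```python
# Shape check only: random weights, CPU is sufficient.
import torch

model = xlm_roberta_large(device='cpu').eval()

ids = torch.full((2, 16), model.pad_id, dtype=torch.long)   # start from all padding
ids[:, :5] = torch.randint(5, 1000, (2, 5))                 # a few real token ids per row

with torch.no_grad():
    feats = model(ids)
print(feats.shape)  # torch.Size([2, 16, 1024])
```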
wan/modules/attention.py ADDED
@@ -0,0 +1,256 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import warnings
3
+ import torch
4
+ from typing import Optional, Tuple
5
+
6
+ try:
7
+ import flash_attn_interface
8
+ FLASH_ATTN_3_AVAILABLE = True
9
+ except ModuleNotFoundError:
10
+ FLASH_ATTN_3_AVAILABLE = False
11
+
12
+ try:
13
+ import flash_attn
14
+ FLASH_ATTN_2_AVAILABLE = True
15
+ except ModuleNotFoundError:
16
+ FLASH_ATTN_2_AVAILABLE = False
17
+
18
+
19
+ __all__ = [
20
+ 'flash_attention',
21
+ 'attention',
22
+ ]
23
+
24
+
25
+ # ---------------------------
26
+ # Custom op + fake kernel
27
+ # ---------------------------
28
+ from typing import Optional, Sequence
29
+
32
+
33
+ @torch.library.custom_op("wan::flash_attention", mutates_args=())
34
+ def _wan_flash_attention_op(
35
+ q: torch.Tensor,
36
+ k: torch.Tensor,
37
+ v: torch.Tensor,
38
+ q_lens: Optional[torch.Tensor] = None,
39
+ k_lens: Optional[torch.Tensor] = None,
40
+ dropout_p: float = 0.0,
41
+ softmax_scale: Optional[float] = None,
42
+ q_scale: Optional[float] = None,
43
+ causal: bool = False,
44
+ # IMPORTANT: schema-friendly default (None), not a tuple
45
+ window_size: Optional[Sequence[int]] = None,
46
+ deterministic: bool = False,
47
+ dtype: torch.dtype = torch.bfloat16,
48
+ version: Optional[int] = None,
49
+ ) -> torch.Tensor:
50
+ half_dtypes = (torch.float16, torch.bfloat16)
51
+ assert dtype in half_dtypes
52
+ assert q.size(-1) <= 256
53
+
54
+ # normalize window_size to a 2-tuple for FA2 API
55
+ if window_size is None:
56
+ ws = (-1, -1)
57
+ else:
58
+ ws = tuple(window_size)
59
+ if len(ws) != 2:
60
+ raise ValueError(f"window_size must have length 2; got {window_size!r}")
61
+
62
+ b, lq, nheads = q.shape[0], q.shape[1], q.shape[2]
63
+ lk = k.shape[1]
64
+ out_dtype = q.dtype
65
+
66
+ def half(x: torch.Tensor) -> torch.Tensor:
67
+ return x if x.dtype in half_dtypes else x.to(dtype)
68
+
69
+ # --- preprocess (unchanged) ---
70
+ if q_lens is None:
71
+ q_flat = half(q.flatten(0, 1))
72
+ q_lens = torch.tensor([lq] * b, dtype=torch.int32)
73
+ else:
74
+ q_flat = half(torch.cat([u[:v] for u, v in zip(q, q_lens)]))
75
+
76
+ if k_lens is None:
77
+ k_flat = half(k.flatten(0, 1))
78
+ v_flat = half(v.flatten(0, 1))
79
+ k_lens = torch.tensor([lk] * b, dtype=torch.int32)
80
+ else:
81
+ k_flat = half(torch.cat([u[:v] for u, v in zip(k, k_lens)]))
82
+ v_flat = half(torch.cat([u[:v] for u, v in zip(v, k_lens)]))
83
+
84
+ q_flat = q_flat.to(v_flat.dtype); k_flat = k_flat.to(v_flat.dtype)
85
+ if q_scale is not None:
86
+ q_flat = q_flat * q_scale
87
+
88
+ if version is not None and version == 3 and not FLASH_ATTN_3_AVAILABLE:
89
+ warnings.warn('Flash attention 3 is not available; falling back to flash attention 2.')
90
+
91
+ if FLASH_ATTN_3_AVAILABLE:
92
+ ret = flash_attn_interface.flash_attn_varlen_func(
93
+ q=q_flat,
94
+ k=k_flat,
95
+ v=v_flat,
96
+ cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(0, dtype=torch.int32).to(q_flat.device, non_blocking=True),
97
+ cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(0, dtype=torch.int32).to(k_flat.device, non_blocking=True),
98
+ seqused_q=None,
99
+ seqused_k=None,
100
+ max_seqlen_q=lq,
101
+ max_seqlen_k=lk,
102
+ softmax_scale=softmax_scale,
103
+ causal=causal,
104
+ deterministic=deterministic,
105
+ )
106
+ out0 = ret[0] if isinstance(ret, (tuple, list)) else ret
107
+ total_q = b * lq
108
+ if out0.dim() != 3:
109
+ raise RuntimeError(f"Unexpected FA3 output rank {out0.dim()} shape={tuple(out0.shape)}")
110
+ if out0.shape[0] == total_q:
111
+ out_flat = out0
112
+ elif out0.shape[0] == nheads and out0.shape[1] == total_q:
113
+ out_flat = out0.transpose(0, 1).contiguous()
114
+ else:
115
+ raise RuntimeError(f"Unexpected FA3 output shape {tuple(out0.shape)}")
116
+ out = out_flat.unflatten(0, (b, lq))
117
+
118
+ elif FLASH_ATTN_2_AVAILABLE:
119
+ out = flash_attn.flash_attn_varlen_func(
120
+ q=q_flat,
121
+ k=k_flat,
122
+ v=v_flat,
123
+ cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(0, dtype=torch.int32).to(q_flat.device, non_blocking=True),
124
+ cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(0, dtype=torch.int32).to(q_flat.device, non_blocking=True),
125
+ max_seqlen_q=lq,
126
+ max_seqlen_k=lk,
127
+ dropout_p=dropout_p,
128
+ softmax_scale=softmax_scale,
129
+ causal=causal,
130
+ window_size=ws, # <- pass 2-tuple
131
+ deterministic=deterministic,
132
+ ).unflatten(0, (b, lq))
133
+ else:
134
+ q_s = q.transpose(1, 2).to(dtype)
135
+ k_s = k.transpose(1, 2).to(dtype)
136
+ v_s = v.transpose(1, 2).to(dtype)
137
+ out = torch.nn.functional.scaled_dot_product_attention(
138
+ q_s, k_s, v_s, attn_mask=None, is_causal=causal, dropout_p=dropout_p
139
+ ).transpose(1, 2).contiguous()
140
+
141
+ return out.to(out_dtype)
142
+
143
+ @_wan_flash_attention_op.register_fake
144
+ def _wan_flash_attention_op_fake(
145
+ q,
146
+ k,
147
+ v,
148
+ q_lens=None,
149
+ k_lens=None,
150
+ dropout_p: float = 0.0,
151
+ softmax_scale=None,
152
+ q_scale=None,
153
+ causal: bool = False,
154
+ window_size: Optional[Sequence[int]] = None,
155
+ deterministic: bool = False,
156
+ dtype: torch.dtype = torch.bfloat16,
157
+ version: Optional[int] = None,
158
+ ):
159
+ # Match output shape: (B, Lq, Nq, Dh_v) and keep the SAME fake device as `q`
160
+ B, Lq, Nq, _ = q.shape
161
+ Dh_v = v.shape[-1]
162
+ return q.new_empty((B, Lq, Nq, Dh_v), dtype=q.dtype)
163
+
164
+
165
+
166
+ # ---------------------------
167
+ # Public API (unchanged signature)
168
+ # ---------------------------
169
+ def flash_attention(
170
+ q,
171
+ k,
172
+ v,
173
+ q_lens=None,
174
+ k_lens=None,
175
+ dropout_p=0.,
176
+ softmax_scale=None,
177
+ q_scale=None,
178
+ causal=False,
179
+ window_size=(-1, -1),
180
+ deterministic=False,
181
+ dtype=torch.bfloat16,
182
+ version=None,
183
+ ):
184
+ """
185
+ q: [B, Lq, Nq, C1].
186
+ k: [B, Lk, Nk, C1].
187
+ v: [B, Lk, Nk, C2]. Nq must be divisible by Nk.
188
+ q_lens: [B].
189
+ k_lens: [B].
190
+ dropout_p: float. Dropout probability.
191
+ softmax_scale: float. The scaling of QK^T before applying softmax.
192
+ causal: bool. Whether to apply causal attention mask.
193
+ window_size: (left, right). If not (-1, -1), apply sliding window local attention (only honored by the flash-attn 2 path).
194
+ deterministic: bool. If True, slightly slower and uses more memory.
195
+ dtype: torch.dtype. The dtype q/k/v are cast to when they are not already float16/bfloat16.
196
+ """
197
+ # Simply delegate to the custom op so Dynamo/AOT treats it as a single node;
198
+ # our eager kernel inside _wan_flash_attention_op keeps the original behavior.
199
+ return _wan_flash_attention_op(
200
+ q, k, v,
201
+ q_lens=q_lens,
202
+ k_lens=k_lens,
203
+ dropout_p=dropout_p,
204
+ softmax_scale=softmax_scale,
205
+ q_scale=q_scale,
206
+ causal=causal,
207
+ window_size=window_size,
208
+ deterministic=deterministic,
209
+ dtype=dtype,
210
+ version=version,
211
+ )
212
+
213
+
214
+ def attention(
215
+ q,
216
+ k,
217
+ v,
218
+ q_lens=None,
219
+ k_lens=None,
220
+ dropout_p=0.,
221
+ softmax_scale=None,
222
+ q_scale=None,
223
+ causal=False,
224
+ window_size=(-1, -1),
225
+ deterministic=False,
226
+ dtype=torch.bfloat16,
227
+ fa_version=None,
228
+ ):
229
+ if FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE:
230
+ return flash_attention(
231
+ q=q,
232
+ k=k,
233
+ v=v,
234
+ q_lens=q_lens,
235
+ k_lens=k_lens,
236
+ dropout_p=dropout_p,
237
+ softmax_scale=softmax_scale,
238
+ q_scale=q_scale,
239
+ causal=causal,
240
+ window_size=window_size,
241
+ deterministic=deterministic,
242
+ dtype=dtype,
243
+ version=fa_version,
244
+ )
245
+ else:
246
+ if q_lens is not None or k_lens is not None:
247
+ warnings.warn(
248
+ 'Padding mask is disabled when using scaled_dot_product_attention. It can have a significant impact on performance.'
249
+ )
250
+ q_ = q.transpose(1, 2).to(dtype)
251
+ k_ = k.transpose(1, 2).to(dtype)
252
+ v_ = v.transpose(1, 2).to(dtype)
253
+ out = torch.nn.functional.scaled_dot_product_attention(
254
+ q_, k_, v_, attn_mask=None, is_causal=causal, dropout_p=dropout_p
255
+ )
256
+ return out.transpose(1, 2).contiguous()
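
A quick smoke test of the public entry point. Shapes follow the [B, L, N, D] layout documented in `flash_attention`; when neither flash-attn 2 nor 3 is importable, the call falls through to `scaled_dot_product_attention` (with the padding-mask warning above).

```python
# Smoke test; use a CUDA device if flash-attn is installed, otherwise the
# SDPA fallback also runs on CPU.
import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'
b, lq, lk, heads, dim = 1, 128, 128, 8, 64
q = torch.randn(b, lq, heads, dim, dtype=torch.bfloat16, device=device)
k = torch.randn(b, lk, heads, dim, dtype=torch.bfloat16, device=device)
v = torch.randn(b, lk, heads, dim, dtype=torch.bfloat16, device=device)

out = attention(q, k, v, causal=False)
print(out.shape)  # torch.Size([1, 128, 8, 64])
```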
wan/modules/model.py ADDED
@@ -0,0 +1,546 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import math
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
7
+ from diffusers.models.modeling_utils import ModelMixin
8
+
9
+ from .attention import flash_attention
10
+
11
+ __all__ = ['WanModel']
12
+
13
+
14
+ def sinusoidal_embedding_1d(dim, position):
15
+ # preprocess
16
+ assert dim % 2 == 0
17
+ half = dim // 2
18
+ position = position.type(torch.float64)
19
+
20
+ # calculation
21
+ sinusoid = torch.outer(
22
+ position, torch.pow(10000, -torch.arange(half).to(position).div(half)))
23
+ x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
24
+ return x
25
+
26
+
27
+ @torch.amp.autocast('cuda', enabled=False)
28
+ def rope_params(max_seq_len, dim, theta=10000):
29
+ assert dim % 2 == 0
30
+ freqs = torch.outer(
31
+ torch.arange(max_seq_len),
32
+ 1.0 / torch.pow(theta,
33
+ torch.arange(0, dim, 2).to(torch.float64).div(dim)))
34
+ freqs = torch.polar(torch.ones_like(freqs), freqs)
35
+ return freqs
36
+
37
+
38
+ @torch.amp.autocast('cuda', enabled=False)
39
+ def rope_apply(x, grid_sizes, freqs):
40
+ n, c = x.size(2), x.size(3) // 2
41
+
42
+ # split freqs
43
+ freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
44
+
45
+ # loop over samples
46
+ output = []
47
+ for i, (f, h, w) in enumerate(grid_sizes.tolist()):
48
+ seq_len = f * h * w
49
+
50
+ # precompute multipliers
51
+ x_i = torch.view_as_complex(x[i, :seq_len].to(torch.float64).reshape(
52
+ seq_len, n, -1, 2))
53
+ freqs_i = torch.cat([
54
+ freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
55
+ freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
56
+ freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
57
+ ],
58
+ dim=-1).reshape(seq_len, 1, -1)
59
+
60
+ # apply rotary embedding
61
+ x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
62
+ x_i = torch.cat([x_i, x[i, seq_len:]])
63
+
64
+ # append to collection
65
+ output.append(x_i)
66
+ return torch.stack(output).float()
67
+
68
+
69
+ class WanRMSNorm(nn.Module):
70
+
71
+ def __init__(self, dim, eps=1e-5):
72
+ super().__init__()
73
+ self.dim = dim
74
+ self.eps = eps
75
+ self.weight = nn.Parameter(torch.ones(dim))
76
+
77
+ def forward(self, x):
78
+ r"""
79
+ Args:
80
+ x(Tensor): Shape [B, L, C]
81
+ """
82
+ return self._norm(x.float()).type_as(x) * self.weight
83
+
84
+ def _norm(self, x):
85
+ return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
86
+
87
+
88
+ class WanLayerNorm(nn.LayerNorm):
89
+
90
+ def __init__(self, dim, eps=1e-6, elementwise_affine=False):
91
+ super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps)
92
+
93
+ def forward(self, x):
94
+ r"""
95
+ Args:
96
+ x(Tensor): Shape [B, L, C]
97
+ """
98
+ return super().forward(x.float()).type_as(x)
99
+
100
+
101
+ class WanSelfAttention(nn.Module):
102
+
103
+ def __init__(self,
104
+ dim,
105
+ num_heads,
106
+ window_size=(-1, -1),
107
+ qk_norm=True,
108
+ eps=1e-6):
109
+ assert dim % num_heads == 0
110
+ super().__init__()
111
+ self.dim = dim
112
+ self.num_heads = num_heads
113
+ self.head_dim = dim // num_heads
114
+ self.window_size = window_size
115
+ self.qk_norm = qk_norm
116
+ self.eps = eps
117
+
118
+ # layers
119
+ self.q = nn.Linear(dim, dim)
120
+ self.k = nn.Linear(dim, dim)
121
+ self.v = nn.Linear(dim, dim)
122
+ self.o = nn.Linear(dim, dim)
123
+ self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
124
+ self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
125
+
126
+ def forward(self, x, seq_lens, grid_sizes, freqs):
127
+ r"""
128
+ Args:
129
+ x(Tensor): Shape [B, L, num_heads, C / num_heads]
130
+ seq_lens(Tensor): Shape [B]
131
+ grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
132
+ freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
133
+ """
134
+ b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
135
+
136
+ # query, key, value function
137
+ def qkv_fn(x):
138
+ q = self.norm_q(self.q(x)).view(b, s, n, d)
139
+ k = self.norm_k(self.k(x)).view(b, s, n, d)
140
+ v = self.v(x).view(b, s, n, d)
141
+ return q, k, v
142
+
143
+ q, k, v = qkv_fn(x)
144
+
145
+ x = flash_attention(
146
+ q=rope_apply(q, grid_sizes, freqs),
147
+ k=rope_apply(k, grid_sizes, freqs),
148
+ v=v,
149
+ k_lens=seq_lens,
150
+ window_size=self.window_size)
151
+
152
+ # output
153
+ x = x.flatten(2)
154
+ x = self.o(x)
155
+ return x
156
+
157
+
158
+ class WanCrossAttention(WanSelfAttention):
159
+
160
+ def forward(self, x, context, context_lens):
161
+ r"""
162
+ Args:
163
+ x(Tensor): Shape [B, L1, C]
164
+ context(Tensor): Shape [B, L2, C]
165
+ context_lens(Tensor): Shape [B]
166
+ """
167
+ b, n, d = x.size(0), self.num_heads, self.head_dim
168
+
169
+ # compute query, key, value
170
+ q = self.norm_q(self.q(x)).view(b, -1, n, d)
171
+ k = self.norm_k(self.k(context)).view(b, -1, n, d)
172
+ v = self.v(context).view(b, -1, n, d)
173
+
174
+ # compute attention
175
+ x = flash_attention(q, k, v, k_lens=context_lens)
176
+
177
+ # output
178
+ x = x.flatten(2)
179
+ x = self.o(x)
180
+ return x
181
+
182
+
183
+ class WanAttentionBlock(nn.Module):
184
+
185
+ def __init__(self,
186
+ dim,
187
+ ffn_dim,
188
+ num_heads,
189
+ window_size=(-1, -1),
190
+ qk_norm=True,
191
+ cross_attn_norm=False,
192
+ eps=1e-6):
193
+ super().__init__()
194
+ self.dim = dim
195
+ self.ffn_dim = ffn_dim
196
+ self.num_heads = num_heads
197
+ self.window_size = window_size
198
+ self.qk_norm = qk_norm
199
+ self.cross_attn_norm = cross_attn_norm
200
+ self.eps = eps
201
+
202
+ # layers
203
+ self.norm1 = WanLayerNorm(dim, eps)
204
+ self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm,
205
+ eps)
206
+ self.norm3 = WanLayerNorm(
207
+ dim, eps,
208
+ elementwise_affine=True) if cross_attn_norm else nn.Identity()
209
+ self.cross_attn = WanCrossAttention(dim, num_heads, (-1, -1), qk_norm,
210
+ eps)
211
+ self.norm2 = WanLayerNorm(dim, eps)
212
+ self.ffn = nn.Sequential(
213
+ nn.Linear(dim, ffn_dim), nn.GELU(approximate='tanh'),
214
+ nn.Linear(ffn_dim, dim))
215
+
216
+ # modulation
217
+ self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
218
+
219
+ def forward(
220
+ self,
221
+ x,
222
+ e,
223
+ seq_lens,
224
+ grid_sizes,
225
+ freqs,
226
+ context,
227
+ context_lens,
228
+ ):
229
+ r"""
230
+ Args:
231
+ x(Tensor): Shape [B, L, C]
232
+ e(Tensor): Shape [B, L1, 6, C]
233
+ seq_lens(Tensor): Shape [B], length of each sequence in batch
234
+ grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
235
+ freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
236
+ """
237
+ assert e.dtype == torch.float32
238
+ with torch.amp.autocast('cuda', dtype=torch.float32):
239
+ e = (self.modulation.unsqueeze(0) + e).chunk(6, dim=2)
240
+ assert e[0].dtype == torch.float32
241
+
242
+ # self-attention
243
+ y = self.self_attn(
244
+ self.norm1(x).float() * (1 + e[1].squeeze(2)) + e[0].squeeze(2),
245
+ seq_lens, grid_sizes, freqs)
246
+ with torch.amp.autocast('cuda', dtype=torch.float32):
247
+ x = x + y * e[2].squeeze(2)
248
+
249
+ # cross-attention & ffn function
250
+ def cross_attn_ffn(x, context, context_lens, e):
251
+ x = x + self.cross_attn(self.norm3(x), context, context_lens)
252
+ y = self.ffn(
253
+ self.norm2(x).float() * (1 + e[4].squeeze(2)) + e[3].squeeze(2))
254
+ with torch.amp.autocast('cuda', dtype=torch.float32):
255
+ x = x + y * e[5].squeeze(2)
256
+ return x
257
+
258
+ x = cross_attn_ffn(x, context, context_lens, e)
259
+ return x
260
+
261
+
262
+ class Head(nn.Module):
263
+
264
+ def __init__(self, dim, out_dim, patch_size, eps=1e-6):
265
+ super().__init__()
266
+ self.dim = dim
267
+ self.out_dim = out_dim
268
+ self.patch_size = patch_size
269
+ self.eps = eps
270
+
271
+ # layers
272
+ out_dim = math.prod(patch_size) * out_dim
273
+ self.norm = WanLayerNorm(dim, eps)
274
+ self.head = nn.Linear(dim, out_dim)
275
+
276
+ # modulation
277
+ self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)
278
+
279
+ def forward(self, x, e):
280
+ r"""
281
+ Args:
282
+ x(Tensor): Shape [B, L1, C]
283
+ e(Tensor): Shape [B, L1, C]
284
+ """
285
+ assert e.dtype == torch.float32
286
+ with torch.amp.autocast('cuda', dtype=torch.float32):
287
+ e = (self.modulation.unsqueeze(0) + e.unsqueeze(2)).chunk(2, dim=2)
288
+ x = (
289
+ self.head(
290
+ self.norm(x) * (1 + e[1].squeeze(2)) + e[0].squeeze(2)))
291
+ return x
292
+
293
+
294
+ class WanModel(ModelMixin, ConfigMixin):
295
+ r"""
296
+ Wan diffusion backbone supporting both text-to-video and image-to-video.
297
+ """
298
+
299
+ ignore_for_config = [
300
+ 'patch_size', 'cross_attn_norm', 'qk_norm', 'text_dim', 'window_size'
301
+ ]
302
+ _no_split_modules = ['WanAttentionBlock']
303
+
304
+ @register_to_config
305
+ def __init__(self,
306
+ model_type='t2v',
307
+ patch_size=(1, 2, 2),
308
+ text_len=512,
309
+ in_dim=16,
310
+ dim=2048,
311
+ ffn_dim=8192,
312
+ freq_dim=256,
313
+ text_dim=4096,
314
+ out_dim=16,
315
+ num_heads=16,
316
+ num_layers=32,
317
+ window_size=(-1, -1),
318
+ qk_norm=True,
319
+ cross_attn_norm=True,
320
+ eps=1e-6):
321
+ r"""
322
+ Initialize the diffusion model backbone.
323
+
324
+ Args:
325
+ model_type (`str`, *optional*, defaults to 't2v'):
326
+ Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video)
327
+ patch_size (`tuple`, *optional*, defaults to (1, 2, 2)):
328
+ 3D patch dimensions for video embedding (t_patch, h_patch, w_patch)
329
+ text_len (`int`, *optional*, defaults to 512):
330
+ Fixed length for text embeddings
331
+ in_dim (`int`, *optional*, defaults to 16):
332
+ Input video channels (C_in)
333
+ dim (`int`, *optional*, defaults to 2048):
334
+ Hidden dimension of the transformer
335
+ ffn_dim (`int`, *optional*, defaults to 8192):
336
+ Intermediate dimension in feed-forward network
337
+ freq_dim (`int`, *optional*, defaults to 256):
338
+ Dimension for sinusoidal time embeddings
339
+ text_dim (`int`, *optional*, defaults to 4096):
340
+ Input dimension for text embeddings
341
+ out_dim (`int`, *optional*, defaults to 16):
342
+ Output video channels (C_out)
343
+ num_heads (`int`, *optional*, defaults to 16):
344
+ Number of attention heads
345
+ num_layers (`int`, *optional*, defaults to 32):
346
+ Number of transformer blocks
347
+ window_size (`tuple`, *optional*, defaults to (-1, -1)):
348
+ Window size for local attention (-1 indicates global attention)
349
+ qk_norm (`bool`, *optional*, defaults to True):
350
+ Enable query/key normalization
351
+ cross_attn_norm (`bool`, *optional*, defaults to False):
352
+ Enable cross-attention normalization
353
+ eps (`float`, *optional*, defaults to 1e-6):
354
+ Epsilon value for normalization layers
355
+ """
356
+
357
+ super().__init__()
358
+
359
+ assert model_type in ['t2v', 'i2v', 'ti2v', 's2v']
360
+ self.model_type = model_type
361
+
362
+ self.patch_size = patch_size
363
+ self.text_len = text_len
364
+ self.in_dim = in_dim
365
+ self.dim = dim
366
+ self.ffn_dim = ffn_dim
367
+ self.freq_dim = freq_dim
368
+ self.text_dim = text_dim
369
+ self.out_dim = out_dim
370
+ self.num_heads = num_heads
371
+ self.num_layers = num_layers
372
+ self.window_size = window_size
373
+ self.qk_norm = qk_norm
374
+ self.cross_attn_norm = cross_attn_norm
375
+ self.eps = eps
376
+
377
+ # embeddings
378
+ self.patch_embedding = nn.Conv3d(
379
+ in_dim, dim, kernel_size=patch_size, stride=patch_size)
380
+ self.text_embedding = nn.Sequential(
381
+ nn.Linear(text_dim, dim), nn.GELU(approximate='tanh'),
382
+ nn.Linear(dim, dim))
383
+
384
+ self.time_embedding = nn.Sequential(
385
+ nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
386
+ self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))
387
+
388
+ # blocks
389
+ self.blocks = nn.ModuleList([
390
+ WanAttentionBlock(dim, ffn_dim, num_heads, window_size, qk_norm,
391
+ cross_attn_norm, eps) for _ in range(num_layers)
392
+ ])
393
+
394
+ # head
395
+ self.head = Head(dim, out_dim, patch_size, eps)
396
+
397
+ # buffers (don't use register_buffer otherwise dtype will be changed in to())
398
+ assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
399
+ d = dim // num_heads
400
+ self.freqs = torch.cat([
401
+ rope_params(1024, d - 4 * (d // 6)),
402
+ rope_params(1024, 2 * (d // 6)),
403
+ rope_params(1024, 2 * (d // 6))
404
+ ],
405
+ dim=1)
406
+
407
+ # initialize weights
408
+ self.init_weights()
409
+
410
+ def forward(
411
+ self,
412
+ x,
413
+ t,
414
+ context,
415
+ seq_len,
416
+ y=None,
417
+ ):
418
+ r"""
419
+ Forward pass through the diffusion model
420
+
421
+ Args:
422
+ x (List[Tensor]):
423
+ List of input video tensors, each with shape [C_in, F, H, W]
424
+ t (Tensor):
425
+ Diffusion timesteps tensor of shape [B]
426
+ context (List[Tensor]):
427
+ List of text embeddings each with shape [L, C]
428
+ seq_len (`int`):
429
+ Maximum sequence length for positional encoding
430
+ y (List[Tensor], *optional*):
431
+ Conditional video inputs for image-to-video mode, same shape as x
432
+
433
+ Returns:
434
+ List[Tensor]:
435
+ List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]
436
+ """
437
+ if self.model_type == 'i2v':
438
+ assert y is not None
439
+ # params
440
+ device = self.patch_embedding.weight.device
441
+ if self.freqs.device != device:
442
+ self.freqs = self.freqs.to(device)
443
+
444
+ if y is not None:
445
+ x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
446
+
447
+ # embeddings
448
+ x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
449
+ grid_sizes = torch.stack(
450
+ [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
451
+ x = [u.flatten(2).transpose(1, 2) for u in x]
452
+ seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
453
+ assert seq_lens.max() <= seq_len
454
+ x = torch.cat([
455
+ torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
456
+ dim=1) for u in x
457
+ ])
458
+
459
+ # time embeddings
460
+ if t.dim() == 1:
461
+ t = t.expand(t.size(0), seq_len)
462
+ with torch.amp.autocast('cuda', dtype=torch.float32):
463
+ bt = t.size(0)
464
+ t = t.flatten()
465
+ e = self.time_embedding(
466
+ sinusoidal_embedding_1d(self.freq_dim,
467
+ t).unflatten(0, (bt, seq_len)).float())
468
+ e0 = self.time_projection(e).unflatten(2, (6, self.dim))
469
+ assert e.dtype == torch.float32 and e0.dtype == torch.float32
470
+
471
+ # context
472
+ context_lens = None
473
+ context = self.text_embedding(
474
+ torch.stack([
475
+ torch.cat(
476
+ [u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
477
+ for u in context
478
+ ]))
479
+
480
+ # arguments
481
+ kwargs = dict(
482
+ e=e0,
483
+ seq_lens=seq_lens,
484
+ grid_sizes=grid_sizes,
485
+ freqs=self.freqs,
486
+ context=context,
487
+ context_lens=context_lens)
488
+
489
+ for block in self.blocks:
490
+ x = block(x, **kwargs)
491
+
492
+ # head
493
+ x = self.head(x, e)
494
+
495
+ # unpatchify
496
+ x = self.unpatchify(x, grid_sizes)
497
+ return [u.float() for u in x]
498
+
499
+ def unpatchify(self, x, grid_sizes):
500
+ r"""
501
+ Reconstruct video tensors from patch embeddings.
502
+
503
+ Args:
504
+ x (List[Tensor]):
505
+ List of patchified features, each with shape [L, C_out * prod(patch_size)]
506
+ grid_sizes (Tensor):
507
+ Original spatial-temporal grid dimensions before patching,
508
+ shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches)
509
+
510
+ Returns:
511
+ List[Tensor]:
512
+ Reconstructed video tensors with shape [C_out, F, H / 8, W / 8]
513
+ """
514
+
515
+ c = self.out_dim
516
+ out = []
517
+ for u, v in zip(x, grid_sizes.tolist()):
518
+ u = u[:math.prod(v)].view(*v, *self.patch_size, c)
519
+ u = torch.einsum('fhwpqrc->cfphqwr', u)
520
+ u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
521
+ out.append(u)
522
+ return out
523
+
524
+ def init_weights(self):
525
+ r"""
526
+ Initialize model parameters using Xavier initialization.
527
+ """
528
+
529
+ # basic init
530
+ for m in self.modules():
531
+ if isinstance(m, nn.Linear):
532
+ nn.init.xavier_uniform_(m.weight)
533
+ if m.bias is not None:
534
+ nn.init.zeros_(m.bias)
535
+
536
+ # init embeddings
537
+ nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1))
538
+ for m in self.text_embedding.modules():
539
+ if isinstance(m, nn.Linear):
540
+ nn.init.normal_(m.weight, std=.02)
541
+ for m in self.time_embedding.modules():
542
+ if isinstance(m, nn.Linear):
543
+ nn.init.normal_(m.weight, std=.02)
544
+
545
+ # init output layer
546
+ nn.init.zeros_(self.head.head.weight)
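
To sanity-check the block wiring above without downloading any weights, a toy configuration can be pushed through the forward pass. The dims below are tiny and arbitrary, chosen only so that `dim // num_heads` satisfies the RoPE split constraints; real checkpoints are far larger.

```python
# Toy config for a shape check only. If flash-attn is installed, move the model
# and tensors to CUDA first; otherwise the SDPA fallback used here runs on CPU.
import torch

model = WanModel(model_type='t2v', in_dim=16, dim=128, ffn_dim=256,
                 freq_dim=256, text_dim=4096, out_dim=16,
                 num_heads=8, num_layers=2).eval()

latents = [torch.randn(16, 1, 32, 32)]   # one sample: [C_in, F, H, W]
context = [torch.randn(77, 4096)]        # one text embedding of length 77
t = torch.tensor([500.0])                # diffusion timestep per sample

with torch.no_grad():
    out = model(latents, t=t, context=context, seq_len=256)  # 1 * 16 * 16 patches
print(out[0].shape)  # torch.Size([16, 1, 32, 32])
```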
wan/modules/s2v/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ from .audio_encoder import AudioEncoder
3
+ from .model_s2v import WanModel_S2V
4
+
5
+ __all__ = ['WanModel_S2V', 'AudioEncoder']