alexnasa committed on
Commit
1caa0d9
·
verified ·
1 Parent(s): a9a475b

Update generate.py

Browse files
Files changed (1) hide show
  1. generate.py +236 -236
generate.py CHANGED
@@ -1,236 +1,236 @@
1
- # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
- import argparse
3
- import logging
4
- import os
5
- import sys
6
- import warnings
7
- from datetime import datetime
8
-
9
- warnings.filterwarnings('ignore')
10
-
11
- import random
12
-
13
- import torch
14
- import torch.distributed as dist
15
- from PIL import Image
16
-
17
- import wan
18
- from wan.configs import MAX_AREA_CONFIGS, SIZE_CONFIGS, SUPPORTED_SIZES, WAN_CONFIGS
19
- from wan.distributed.util import init_distributed_group
20
- from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
21
- from wan.utils.utils import merge_video_audio, save_video, str2bool
22
-
23
-
24
# Texts and paths used by more than one task are spelled out once here.
_BOXING_CATS_PROMPT = "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."
_BEACH_CAT_PROMPT = "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside."
_EXAMPLE_IMAGE = "examples/i2v_input.JPG"

# Example inputs keyed by task name. _validate_args reads these as fallbacks
# for arguments that were left unset.
EXAMPLE_PROMPT = {
    "t2v-A14B": {"prompt": _BOXING_CATS_PROMPT},
    "i2v-A14B": {"prompt": _BEACH_CAT_PROMPT, "image": _EXAMPLE_IMAGE},
    "ti2v-5B": {"prompt": _BOXING_CATS_PROMPT},
    "animate-14B": {
        "prompt": "视频中的人在做动作",
        "video": "",
        "pose": "",
        "mask": "",
    },
    "s2v-14B": {
        "prompt": _BEACH_CAT_PROMPT,
        "image": _EXAMPLE_IMAGE,
        "audio": "examples/talk.wav",
        "tts_prompt_audio": "examples/zero_shot_prompt.wav",
        "tts_prompt_text": "希望你以后能够做的比我还好呦。",
        "tts_text": "收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。",
    },
}
60
-
61
-
62
def _validate_args(args):
    """Check the required settings on ``args`` and fill unset ones in place.

    Unset prompt/image/audio/TTS fields are taken from ``EXAMPLE_PROMPT`` for
    the task, unset sampling fields from ``WAN_CONFIGS[args.task]``, and a
    negative ``base_seed`` is replaced by a random non-negative integer.

    Raises:
        AssertionError: if ``ckpt_dir`` is missing, the task is unknown, an
            ``i2v-A14B`` run has no image, or (for tasks without ``s2v`` in
            their name) the size is not listed in ``SUPPORTED_SIZES``.
    """
    # Basic check
    assert args.ckpt_dir is not None, "Please specify the checkpoint directory."
    assert args.task in WAN_CONFIGS, f"Unsupport task: {args.task}"
    assert args.task in EXAMPLE_PROMPT, f"Unsupport task: {args.task}"

    task = args.task
    example = EXAMPLE_PROMPT[task]
    has_audio_example = "audio" in example

    if args.prompt is None:
        args.prompt = example["prompt"]
    if args.image is None and "image" in example:
        args.image = example["image"]
    if args.audio is None and args.enable_tts is False and has_audio_example:
        args.audio = example["audio"]
    # NOTE: the guard tests for the "audio" key although the values read are
    # the tts_* keys, and all three TTS fields are overwritten together even
    # when only one of the two checked fields was missing.
    tts_incomplete = args.tts_prompt_audio is None or args.tts_text is None
    if tts_incomplete and args.enable_tts is True and has_audio_example:
        args.tts_prompt_audio = example["tts_prompt_audio"]
        args.tts_prompt_text = example["tts_prompt_text"]
        args.tts_text = example["tts_text"]

    if task == "i2v-A14B":
        assert args.image is not None, "Please specify the image path for i2v."

    # Sampling settings left as None inherit the task configuration's value.
    cfg = WAN_CONFIGS[task]
    for field in ("sample_steps", "sample_shift", "sample_guide_scale", "frame_num"):
        if getattr(args, field) is None:
            setattr(args, field, getattr(cfg, field))

    if args.base_seed < 0:
        args.base_seed = random.randint(0, sys.maxsize)

    # Size check
    if "s2v" not in task:
        assert args.size in SUPPORTED_SIZES[task], f"Unsupport size {args.size} for task {args.task}, supported sizes are: {', '.join(SUPPORTED_SIZES[args.task])}"
103
-
104
-
105
class _Args:
    """Empty attribute holder; ``_parse_args`` attaches the settings to an instance."""
107
-
108
def _parse_args():
    """Return an ``_Args`` instance carrying this file's fixed settings.

    No command line is read: every value is hard-coded in the table below.
    The instance is passed through ``_validate_args`` before being returned,
    which fills in the entries left as ``None``.
    """
    settings = {
        # core generation options
        "task": "animate-14B",
        "size": "720*1280",  # alternative: "1280*720"
        "frame_num": None,
        "ckpt_dir": "./Wan2.2-Animate-14B/",
        "offload_model": True,
        "ulysses_size": 1,
        "t5_fsdp": False,
        "t5_cpu": False,
        "dit_fsdp": False,
        "prompt": None,
        "use_prompt_extend": False,
        "prompt_extend_method": "local_qwen",  # "dashscope" or "local_qwen"
        "prompt_extend_model": None,
        "prompt_extend_target_lang": "zh",  # "zh" or "en"
        "base_seed": 0,
        "image": None,
        "sample_solver": "unipc",  # "unipc" or "dpm++"
        "sample_steps": None,
        "sample_shift": None,
        "sample_guide_scale": None,
        "convert_model_dtype": False,
        # animate
        "refert_num": 1,
        # s2v-only
        "num_clip": None,
        "audio": None,
        "enable_tts": False,
        "tts_prompt_audio": None,
        "tts_prompt_text": None,
        "tts_text": None,
        "pose_video": None,
        "start_from_ref": False,
        "infer_frames": 80,
    }

    args = _Args()
    for field, value in settings.items():
        setattr(args, field, value)

    _validate_args(args)
    return args
151
-
152
-
153
-
154
def _init_logging(rank):
    """Configure the root logger according to the process rank.

    Rank 0 logs at INFO level to stdout with a timestamped format; every
    other rank is configured at ERROR level with the default settings.
    """
    if rank != 0:
        logging.basicConfig(level=logging.ERROR)
        return

    # set format
    stdout_handler = logging.StreamHandler(stream=sys.stdout)
    logging.basicConfig(
        level=logging.INFO,
        format="[%(asctime)s] %(levelname)s: %(message)s",
        handlers=[stdout_handler],
    )
164
-
165
def load_model(use_relighting_lora=False,
               checkpoint_dir="./Wan2.2-Animate-14B/",
               device_id=0):
    """Construct the ``wan.WanAnimate`` pipeline for the ``animate-14B`` task.

    The pipeline is built as rank 0 with FSDP, sequence parallelism, T5-on-CPU
    and dtype conversion all switched off.

    Args:
        use_relighting_lora: Forwarded unchanged to ``wan.WanAnimate``.
        checkpoint_dir: Directory passed as ``checkpoint_dir``. The default is
            the path that used to be hard-coded here.
        device_id: Value passed as ``device_id``. Defaults to 0, the value
            that used to be hard-coded here.

    Returns:
        The ``wan.WanAnimate`` instance.
    """
    return wan.WanAnimate(
        config=WAN_CONFIGS["animate-14B"],
        checkpoint_dir=checkpoint_dir,
        device_id=device_id,
        rank=0,
        t5_fsdp=False,
        dit_fsdp=False,
        use_sp=False,
        t5_cpu=False,
        convert_model_dtype=False,
        use_relighting_lora=use_relighting_lora,
    )
181
-
182
def generate(wan_animate, preprocess_dir, save_file, replace_flag=False):
    """Generate a video with ``wan_animate`` and save it to ``save_file``.

    Settings come from ``_parse_args()``; the call arguments supply only the
    input directory, the output path and the replace switch.

    Args:
        wan_animate: Object whose ``generate(...)`` method produces the video,
            e.g. the value returned by ``load_model()``.
        preprocess_dir: Forwarded to ``wan_animate.generate`` as
            ``src_root_path``.
        save_file: Output path handed to ``save_video``; only rank 0 writes.
        replace_flag: Forwarded unchanged to ``wan_animate.generate``.

    Returns:
        None.
    """
    args = _parse_args()
    rank = int(os.getenv("RANK", 0))
    _init_logging(rank)

    cfg = WAN_CONFIGS[args.task]

    logging.info(f"Input prompt: {args.prompt}")
    print(f'rank:{rank}')

    logging.info("Generating video ...")
    video = wan_animate.generate(
        src_root_path=preprocess_dir,
        replace_flag=replace_flag,
        refert_num=args.refert_num,
        clip_len=args.frame_num,
        shift=args.sample_shift,
        sample_solver=args.sample_solver,
        sampling_steps=args.sample_steps,
        guide_scale=args.sample_guide_scale,
        seed=args.base_seed,
        offload_model=args.offload_model)

    if rank == 0:
        # video[None] prepends an axis before saving (presumably the batch
        # axis save_video expects -- TODO confirm).
        save_video(
            tensor=video[None],
            save_file=save_file,
            fps=cfg.sample_fps,
            nrow=1,
            normalize=True,
            value_range=(-1, 1))
    # Drop the reference before the synchronisation steps below.
    del video

    # Only synchronise when a CUDA device is present, so a CPU-only process
    # does not fail here after the video has already been saved.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    # NOTE(review): tearing the process group down here means a later
    # generate() call in the same process runs without one -- confirm this is
    # intended.
    if dist.is_initialized():
        dist.barrier()
        dist.destroy_process_group()

    logging.info("Finished.")
236
-
 
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import argparse
3
+ import logging
4
+ import os
5
+ import sys
6
+ import warnings
7
+ from datetime import datetime
8
+
9
+ warnings.filterwarnings('ignore')
10
+
11
+ import random
12
+
13
+ import torch
14
+ import torch.distributed as dist
15
+ from PIL import Image
16
+
17
+ import wan
18
+ from wan.configs import MAX_AREA_CONFIGS, SIZE_CONFIGS, SUPPORTED_SIZES, WAN_CONFIGS
19
+ from wan.distributed.util import init_distributed_group
20
+ from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
21
+ from wan.utils.utils import merge_video_audio, save_video, str2bool
22
+
23
+
24
# Texts and paths used by more than one task are spelled out once here.
_BOXING_CATS_PROMPT = "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."
_BEACH_CAT_PROMPT = "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside."
_EXAMPLE_IMAGE = "examples/i2v_input.JPG"

# Example inputs keyed by task name. _validate_args reads these as fallbacks
# for arguments that were left unset.
EXAMPLE_PROMPT = {
    "t2v-A14B": {"prompt": _BOXING_CATS_PROMPT},
    "i2v-A14B": {"prompt": _BEACH_CAT_PROMPT, "image": _EXAMPLE_IMAGE},
    "ti2v-5B": {"prompt": _BOXING_CATS_PROMPT},
    "animate-14B": {
        "prompt": "视频中的人在做动作",
        "video": "",
        "pose": "",
        "mask": "",
    },
    "s2v-14B": {
        "prompt": _BEACH_CAT_PROMPT,
        "image": _EXAMPLE_IMAGE,
        "audio": "examples/talk.wav",
        "tts_prompt_audio": "examples/zero_shot_prompt.wav",
        "tts_prompt_text": "希望你以后能够做的比我还好呦。",
        "tts_text": "收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。",
    },
}
60
+
61
+
62
def _validate_args(args):
    """Check the required settings on ``args`` and fill unset ones in place.

    Unset prompt/image/audio/TTS fields are taken from ``EXAMPLE_PROMPT`` for
    the task, unset sampling fields from ``WAN_CONFIGS[args.task]``, and a
    negative ``base_seed`` is replaced by a random non-negative integer.

    Raises:
        AssertionError: if ``ckpt_dir`` is missing, the task is unknown, an
            ``i2v-A14B`` run has no image, or (for tasks without ``s2v`` in
            their name) the size is not listed in ``SUPPORTED_SIZES``.
    """
    # Basic check
    assert args.ckpt_dir is not None, "Please specify the checkpoint directory."
    assert args.task in WAN_CONFIGS, f"Unsupport task: {args.task}"
    assert args.task in EXAMPLE_PROMPT, f"Unsupport task: {args.task}"

    task = args.task
    example = EXAMPLE_PROMPT[task]
    has_audio_example = "audio" in example

    if args.prompt is None:
        args.prompt = example["prompt"]
    if args.image is None and "image" in example:
        args.image = example["image"]
    if args.audio is None and args.enable_tts is False and has_audio_example:
        args.audio = example["audio"]
    # NOTE: the guard tests for the "audio" key although the values read are
    # the tts_* keys, and all three TTS fields are overwritten together even
    # when only one of the two checked fields was missing.
    tts_incomplete = args.tts_prompt_audio is None or args.tts_text is None
    if tts_incomplete and args.enable_tts is True and has_audio_example:
        args.tts_prompt_audio = example["tts_prompt_audio"]
        args.tts_prompt_text = example["tts_prompt_text"]
        args.tts_text = example["tts_text"]

    if task == "i2v-A14B":
        assert args.image is not None, "Please specify the image path for i2v."

    # Sampling settings left as None inherit the task configuration's value.
    cfg = WAN_CONFIGS[task]
    for field in ("sample_steps", "sample_shift", "sample_guide_scale", "frame_num"):
        if getattr(args, field) is None:
            setattr(args, field, getattr(cfg, field))

    if args.base_seed < 0:
        args.base_seed = random.randint(0, sys.maxsize)

    # Size check
    if "s2v" not in task:
        assert args.size in SUPPORTED_SIZES[task], f"Unsupport size {args.size} for task {args.task}, supported sizes are: {', '.join(SUPPORTED_SIZES[args.task])}"
103
+
104
+
105
class _Args:
    """Empty attribute holder; ``_parse_args`` attaches the settings to an instance."""
107
+
108
def _parse_args():
    """Return an ``_Args`` instance carrying this file's fixed settings.

    No command line is read: every value is hard-coded in the table below.
    The instance is passed through ``_validate_args`` before being returned,
    which fills in the entries left as ``None``.
    """
    settings = {
        # core generation options
        "task": "animate-14B",
        "size": "720*1280",  # alternative: "1280*720"
        "frame_num": None,
        "ckpt_dir": "./Wan2.2-Animate-14B/",
        "offload_model": True,
        "ulysses_size": 1,
        "t5_fsdp": False,
        "t5_cpu": False,
        "dit_fsdp": False,
        "prompt": None,
        "use_prompt_extend": False,
        "prompt_extend_method": "local_qwen",  # "dashscope" or "local_qwen"
        "prompt_extend_model": None,
        "prompt_extend_target_lang": "zh",  # "zh" or "en"
        "base_seed": 1234,
        "image": None,
        "sample_solver": "unipc",  # "unipc" or "dpm++"
        "sample_steps": None,
        "sample_shift": None,
        "sample_guide_scale": None,
        "convert_model_dtype": False,
        # animate
        "refert_num": 1,
        # s2v-only
        "num_clip": None,
        "audio": None,
        "enable_tts": False,
        "tts_prompt_audio": None,
        "tts_prompt_text": None,
        "tts_text": None,
        "pose_video": None,
        "start_from_ref": False,
        "infer_frames": 80,
    }

    args = _Args()
    for field, value in settings.items():
        setattr(args, field, value)

    _validate_args(args)
    return args
151
+
152
+
153
+
154
def _init_logging(rank):
    """Configure the root logger according to the process rank.

    Rank 0 logs at INFO level to stdout with a timestamped format; every
    other rank is configured at ERROR level with the default settings.
    """
    if rank != 0:
        logging.basicConfig(level=logging.ERROR)
        return

    # set format
    stdout_handler = logging.StreamHandler(stream=sys.stdout)
    logging.basicConfig(
        level=logging.INFO,
        format="[%(asctime)s] %(levelname)s: %(message)s",
        handlers=[stdout_handler],
    )
164
+
165
def load_model(use_relighting_lora=False,
               checkpoint_dir="./Wan2.2-Animate-14B/",
               device_id=0):
    """Construct the ``wan.WanAnimate`` pipeline for the ``animate-14B`` task.

    The pipeline is built as rank 0 with FSDP, sequence parallelism, T5-on-CPU
    and dtype conversion all switched off.

    Args:
        use_relighting_lora: Forwarded unchanged to ``wan.WanAnimate``.
        checkpoint_dir: Directory passed as ``checkpoint_dir``. The default is
            the path that used to be hard-coded here.
        device_id: Value passed as ``device_id``. Defaults to 0, the value
            that used to be hard-coded here.

    Returns:
        The ``wan.WanAnimate`` instance.
    """
    return wan.WanAnimate(
        config=WAN_CONFIGS["animate-14B"],
        checkpoint_dir=checkpoint_dir,
        device_id=device_id,
        rank=0,
        t5_fsdp=False,
        dit_fsdp=False,
        use_sp=False,
        t5_cpu=False,
        convert_model_dtype=False,
        use_relighting_lora=use_relighting_lora,
    )
181
+
182
def generate(wan_animate, preprocess_dir, save_file, replace_flag=False):
    """Generate a video with ``wan_animate`` and save it to ``save_file``.

    Settings come from ``_parse_args()``; the call arguments supply only the
    input directory, the output path and the replace switch.

    Args:
        wan_animate: Object whose ``generate(...)`` method produces the video,
            e.g. the value returned by ``load_model()``.
        preprocess_dir: Forwarded to ``wan_animate.generate`` as
            ``src_root_path``.
        save_file: Output path handed to ``save_video``; only rank 0 writes.
        replace_flag: Forwarded unchanged to ``wan_animate.generate``.

    Returns:
        None.
    """
    args = _parse_args()
    rank = int(os.getenv("RANK", 0))
    _init_logging(rank)

    cfg = WAN_CONFIGS[args.task]

    logging.info(f"Input prompt: {args.prompt}")
    print(f'rank:{rank}')

    logging.info("Generating video ...")
    video = wan_animate.generate(
        src_root_path=preprocess_dir,
        replace_flag=replace_flag,
        refert_num=args.refert_num,
        clip_len=args.frame_num,
        shift=args.sample_shift,
        sample_solver=args.sample_solver,
        sampling_steps=args.sample_steps,
        guide_scale=args.sample_guide_scale,
        seed=args.base_seed,
        offload_model=args.offload_model)

    if rank == 0:
        # video[None] prepends an axis before saving (presumably the batch
        # axis save_video expects -- TODO confirm).
        save_video(
            tensor=video[None],
            save_file=save_file,
            fps=cfg.sample_fps,
            nrow=1,
            normalize=True,
            value_range=(-1, 1))
    # Drop the reference before the synchronisation steps below.
    del video

    # Only synchronise when a CUDA device is present, so a CPU-only process
    # does not fail here after the video has already been saved.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    # NOTE(review): tearing the process group down here means a later
    # generate() call in the same process runs without one -- confirm this is
    # intended.
    if dist.is_initialized():
        dist.barrier()
        dist.destroy_process_group()

    logging.info("Finished.")
236
+