# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch
import torch.distributed as dist


def init_distributed_group():
    r"""Initialize the sequence parallel process group."""
    if not dist.is_initialized():
        dist.init_process_group(backend='nccl')


def get_rank():
    return dist.get_rank()


def get_world_size():
    return dist.get_world_size()


def all_to_all(x, scatter_dim, gather_dim, group=None, **kwargs):
    """
    `scatter` along one dimension and `gather` along another.
    """
    world_size = get_world_size()
    if world_size > 1:
        # Split the local tensor along `scatter_dim`, exchange one chunk with
        # each rank, then concatenate the received chunks along `gather_dim`.
        inputs = [u.contiguous() for u in x.chunk(world_size, dim=scatter_dim)]
        outputs = [torch.empty_like(u) for u in inputs]
        dist.all_to_all(outputs, inputs, group=group, **kwargs)
        x = torch.cat(outputs, dim=gather_dim).contiguous()
    return x


def all_gather(tensor):
    # Collect the local tensor from every rank into a list.
    world_size = dist.get_world_size()
    if world_size == 1:
        return [tensor]
    tensor_list = [torch.empty_like(tensor) for _ in range(world_size)]
    dist.all_gather(tensor_list, tensor)
    return tensor_list


def gather_forward(input, dim):
    # Skip the collective if there is only one rank.
    world_size = dist.get_world_size()
    if world_size == 1:
        return input
    # Gather the sequence shards from all ranks and concatenate along `dim`.
    output = all_gather(input)
    return torch.cat(output, dim=dim).contiguous()
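

# ---------------------------------------------------------------------------
# Minimal usage sketch, not part of the utilities above. Assumptions: a single
# node launched via `torchrun --nproc_per_node=N` on NCCL-capable GPUs, with
# world_size dividing both seq_len and num_heads; the tensor shapes and the
# function name `sequence_parallel_layout_example` are illustrative. It shows
# the typical sequence-parallel pattern: `all_to_all` switches a tensor between
# "sharded over sequence" and "sharded over heads" layouts, and
# `gather_forward` reassembles the full sequence on every rank.
# ---------------------------------------------------------------------------
def sequence_parallel_layout_example():
    init_distributed_group()
    rank, world_size = get_rank(), get_world_size()
    device = torch.device(f'cuda:{rank}')

    # Each rank holds a shard of the sequence:
    #   [batch, seq_len // world_size, num_heads, head_dim]
    x = torch.randn(2, 128 // world_size, 8, 64, device=device)

    # Re-shard from "split over sequence" to "split over heads", so each rank
    # sees the full sequence for its subset of attention heads:
    #   [batch, seq_len, num_heads // world_size, head_dim]
    x = all_to_all(x, scatter_dim=2, gather_dim=1)

    # ... attention over the full sequence would run here ...

    # Re-shard back to the sequence-parallel layout:
    #   [batch, seq_len // world_size, num_heads, head_dim]
    x = all_to_all(x, scatter_dim=1, gather_dim=2)

    # Collect the full sequence on every rank, e.g. before the final output:
    #   [batch, seq_len, num_heads, head_dim]
    return gather_forward(x, dim=1)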