From cb0daccc414021309b8748cbbcbfee5b2604eaf5 Mon Sep 17 00:00:00 2001
From: Tri Dao
Date: Wed, 16 Aug 2023 23:43:35 -0700
Subject: [PATCH] [FusedDense] Allow Row/ColumnParallelLinear to have uneven
 split

---
 flash_attn/ops/fused_dense.py | 28 +++++++++++++++++++---------
 1 file changed, 19 insertions(+), 9 deletions(-)

diff --git a/flash_attn/ops/fused_dense.py b/flash_attn/ops/fused_dense.py
index 098b538..3353767 100644
--- a/flash_attn/ops/fused_dense.py
+++ b/flash_attn/ops/fused_dense.py
@@ -170,16 +170,21 @@ class ColumnParallelLinear(nn.Linear):
         process_group: ProcessGroup,
         bias: bool = True,
         sequence_parallel=True,
+        multiple_of=1,
         device=None,
         dtype=None,
     ) -> None:
         world_size = torch.distributed.get_world_size(process_group)
-        if out_features % world_size != 0:
-            raise ValueError(
-                f"out_features ({out_features}) must be divisible by " f"world_size ({world_size})"
-            )
+        if out_features % multiple_of:
+            raise ValueError(f"out_features ({out_features}) must be a multiple of {multiple_of}")
+        multiple = out_features // multiple_of
+        # We want to split @multiple across world_size, but it could be an uneven split
+        div = multiple // world_size
+        mod = multiple % world_size
+        # The first @mod ranks get @div + 1 copies, the rest get @div copies
+        local_multiple = div + int(torch.distributed.get_rank(process_group) < mod)
         super().__init__(
-            in_features, out_features // world_size, bias=bias, device=device, dtype=dtype
+            in_features, local_multiple * multiple_of, bias=bias, device=device, dtype=dtype
         )
         self.process_group = process_group
         self.sequence_parallel = sequence_parallel
@@ -205,15 +210,20 @@ class RowParallelLinear(nn.Linear):
         process_group: ProcessGroup,
         bias: bool = True,
         sequence_parallel=True,
+        multiple_of=1,
         device=None,
         dtype=None,
     ) -> None:
         world_size = torch.distributed.get_world_size(process_group)
         rank = torch.distributed.get_rank(process_group)
-        if in_features % world_size != 0:
-            raise ValueError(
-                f"in_features ({in_features}) must be divisible by " f"world_size ({world_size})"
-            )
+        if in_features % multiple_of:
+            raise ValueError(f"in_features ({in_features}) must be a multiple of {multiple_of}")
+        multiple = in_features // multiple_of
+        # We want to split @multiple across world_size, but it could be an uneven split
+        div = multiple // world_size
+        mod = multiple % world_size
+        # The first @mod ranks get @div + 1 copies, the rest get @div copies
+        local_multiple = div + int(torch.distributed.get_rank(process_group) < mod)
         # Only rank 0 will have bias
         super().__init__(
-            in_features // world_size,
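
As a sanity check, the shard-size arithmetic used in both hunks can be run
outside any distributed setup. A minimal sketch follows (not part of the
patch; the helper name shard_features is made up here, and plain integers
world_size and rank stand in for what
torch.distributed.get_world_size(process_group) and
torch.distributed.get_rank(process_group) would return):

    def shard_features(total, world_size, rank, multiple_of=1):
        # Split `total` features into world_size shards, each shard a
        # whole multiple of multiple_of.
        if total % multiple_of:
            raise ValueError(f"total ({total}) must be a multiple of {multiple_of}")
        multiple = total // multiple_of
        div = multiple // world_size
        mod = multiple % world_size
        # The first @mod ranks get @div + 1 copies, the rest get @div copies
        local_multiple = div + int(rank < mod)
        return local_multiple * multiple_of

    # Example: total=640 with multiple_of=64 split over 4 ranks:
    # multiple=10, div=2, mod=2, so ranks 0-1 get 3 blocks and ranks 2-3 get 2.
    shards = [shard_features(640, 4, r, multiple_of=64) for r in range(4)]
    assert shards == [192, 192, 128, 128] and sum(shards) == 640

Each rank's shard stays a whole multiple of multiple_of (callers can pass,
say, a head dimension so no such block is split across ranks), while the
shards still sum exactly to out_features (or in_features for
RowParallelLinear) even when world_size does not divide it evenly.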