diff --git a/flash_attn/ops/fused_dense.py b/flash_attn/ops/fused_dense.py
index 098b538..3353767 100644
--- a/flash_attn/ops/fused_dense.py
+++ b/flash_attn/ops/fused_dense.py
@@ -170,16 +170,21 @@ class ColumnParallelLinear(nn.Linear):
         process_group: ProcessGroup,
         bias: bool = True,
         sequence_parallel=True,
+        multiple_of=1,
         device=None,
         dtype=None,
     ) -> None:
         world_size = torch.distributed.get_world_size(process_group)
-        if out_features % world_size != 0:
-            raise ValueError(
-                f"out_features ({out_features}) must be divisible by " f"world_size ({world_size})"
-            )
+        if out_features % multiple_of:
+            raise ValueError(f"out_features ({out_features}) must be a multiple of {multiple_of}")
+        multiple = out_features // multiple_of
+        # We want to split @multiple across world_size, but it could be an uneven split
+        div = multiple // world_size
+        mod = multiple % world_size
+        # The first @mod ranks get @div + 1 copies, the rest get @div copies
+        local_multiple = div + int(torch.distributed.get_rank(process_group) < mod)
         super().__init__(
-            in_features, out_features // world_size, bias=bias, device=device, dtype=dtype
+            in_features, local_multiple * multiple_of, bias=bias, device=device, dtype=dtype
         )
         self.process_group = process_group
         self.sequence_parallel = sequence_parallel
@@ -205,15 +210,20 @@ class RowParallelLinear(nn.Linear):
         process_group: ProcessGroup,
         bias: bool = True,
         sequence_parallel=True,
+        multiple_of=1,
         device=None,
         dtype=None,
     ) -> None:
         world_size = torch.distributed.get_world_size(process_group)
         rank = torch.distributed.get_rank(process_group)
-        if in_features % world_size != 0:
-            raise ValueError(
-                f"in_features ({in_features}) must be divisible by " f"world_size ({world_size})"
-            )
+        if in_features % multiple_of:
+            raise ValueError(f"in_features ({in_features}) must be a multiple of {multiple_of}")
+        multiple = in_features // multiple_of
+        # We want to split @multiple across world_size, but it could be an uneven split
+        div = multiple // world_size
+        mod = multiple % world_size
+        # The first @mod ranks get @div + 1 copies, the rest get @div copies
+        local_multiple = div + int(torch.distributed.get_rank(process_group) < mod)
         # Only rank 0 will have bias
         super().__init__(
             in_features // world_size,
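
A minimal standalone sketch (separate from the patch) of the uneven-split rule the new lines implement; the helper name local_out_features and the example sizes are illustrative assumptions, not taken from the source:

def local_out_features(out_features: int, multiple_of: int, world_size: int, rank: int) -> int:
    # Same rule as the patch: split out_features // multiple_of groups of size
    # multiple_of across world_size ranks; the first `mod` ranks get one extra group.
    if out_features % multiple_of:
        raise ValueError(f"out_features ({out_features}) must be a multiple of {multiple_of}")
    multiple = out_features // multiple_of
    div, mod = divmod(multiple, world_size)
    local_multiple = div + int(rank < mod)
    return local_multiple * multiple_of

# Illustrative numbers (assumed): out_features=1280, multiple_of=256, world_size=3
# -> 5 groups over 3 ranks, so the per-rank shards are 512, 512, 256 and sum back to 1280.
assert [local_out_features(1280, 256, 3, r) for r in range(3)] == [512, 512, 256]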