[FusedDense] Allow Row/ColumnParallelLinear to have uneven split

This commit is contained in:
Tri Dao 2023-08-16 23:43:35 -07:00
parent bcfa7c9751
commit cb0daccc41

View File

@ -170,16 +170,21 @@ class ColumnParallelLinear(nn.Linear):
process_group: ProcessGroup,
bias: bool = True,
sequence_parallel=True,
multiple_of=1,
device=None,
dtype=None,
) -> None:
world_size = torch.distributed.get_world_size(process_group)
if out_features % world_size != 0:
raise ValueError(
f"out_features ({out_features}) must be divisible by " f"world_size ({world_size})"
)
if out_features % multiple_of:
raise ValueError(f"out_features ({out_features}) must be a multiple of {multiple_of}")
multiple = out_features // multiple_of
# We want to split @multiple across world_size, but it could be an uneven split
div = multiple // world_size
mod = multiple % world_size
# The first @mod ranks get @div + 1 copies, the rest get @div copies
local_multiple = div + int(torch.distributed.get_rank(process_group) < mod)
super().__init__(
in_features, out_features // world_size, bias=bias, device=device, dtype=dtype
in_features, local_multiple * multiple_of, bias=bias, device=device, dtype=dtype
)
self.process_group = process_group
self.sequence_parallel = sequence_parallel
@ -205,15 +210,20 @@ class RowParallelLinear(nn.Linear):
process_group: ProcessGroup,
bias: bool = True,
sequence_parallel=True,
multiple_of=1,
device=None,
dtype=None,
) -> None:
world_size = torch.distributed.get_world_size(process_group)
rank = torch.distributed.get_rank(process_group)
if in_features % world_size != 0:
raise ValueError(
f"in_features ({in_features}) must be divisible by " f"world_size ({world_size})"
)
if in_features % multiple_of:
raise ValueError(f"in_features ({in_features}) must be a multiple of {multiple_of}")
multiple = in_features // multiple_of
# We want to split @multiple across world_size, but it could be an uneven split
div = multiple // world_size
mod = multiple % world_size
# The first @mod ranks get @div + 1 copies, the rest get @div copies
local_multiple = div + int(torch.distributed.get_rank(process_group) < mod)
# Only rank 0 will have bias
super().__init__(
in_features // world_size,