[FusedDense] Allow Row/ColumnParallelLinear to have uneven split
parent bcfa7c9751
commit cb0daccc41
@@ -170,16 +170,21 @@ class ColumnParallelLinear(nn.Linear):
         process_group: ProcessGroup,
         bias: bool = True,
         sequence_parallel=True,
+        multiple_of=1,
         device=None,
         dtype=None,
     ) -> None:
         world_size = torch.distributed.get_world_size(process_group)
-        if out_features % world_size != 0:
-            raise ValueError(
-                f"out_features ({out_features}) must be divisible by " f"world_size ({world_size})"
-            )
+        if out_features % multiple_of:
+            raise ValueError(f"out_features ({out_features}) must be a multiple of {multiple_of}")
+        multiple = out_features // multiple_of
+        # We want to split @multiple across world_size, but it could be an uneven split
+        div = multiple // world_size
+        mod = multiple % world_size
+        # The first @mod ranks get @div + 1 copies, the rest get @div copies
+        local_multiple = div + int(torch.distributed.get_rank(process_group) < mod)
         super().__init__(
-            in_features, out_features // world_size, bias=bias, device=device, dtype=dtype
+            in_features, local_multiple * multiple_of, bias=bias, device=device, dtype=dtype
         )
         self.process_group = process_group
         self.sequence_parallel = sequence_parallel
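The comments in this hunk describe the whole scheme: the sharded dimension is measured in units of multiple_of, and the first mod ranks each take one extra unit. Factored out as a standalone helper, a minimal sketch of the same arithmetic (the name local_shard_size is ours, not part of this commit):

def local_shard_size(dim: int, world_size: int, rank: int, multiple_of: int = 1) -> int:
    """Width of `rank`'s shard when `dim` is split across `world_size` ranks
    in units of `multiple_of`, mirroring the logic added in this commit."""
    if dim % multiple_of:
        raise ValueError(f"dim ({dim}) must be a multiple of {multiple_of}")
    multiple = dim // multiple_of
    div, mod = divmod(multiple, world_size)
    # The first `mod` ranks get `div + 1` units, the rest get `div`
    return (div + int(rank < mod)) * multiple_of

For example, dim=1280 with multiple_of=256 and world_size=3 gives per-rank widths [512, 512, 256]: 1280 is not divisible by 3, but its 5 units of 256 can still be spread across the ranks, and the shards always sum back to the full dimension.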
@@ -205,15 +210,20 @@ class RowParallelLinear(nn.Linear):
         process_group: ProcessGroup,
         bias: bool = True,
         sequence_parallel=True,
+        multiple_of=1,
         device=None,
         dtype=None,
     ) -> None:
         world_size = torch.distributed.get_world_size(process_group)
         rank = torch.distributed.get_rank(process_group)
-        if in_features % world_size != 0:
-            raise ValueError(
-                f"in_features ({in_features}) must be divisible by " f"world_size ({world_size})"
-            )
+        if in_features % multiple_of:
+            raise ValueError(f"in_features ({in_features}) must be a multiple of {multiple_of}")
+        multiple = in_features // multiple_of
+        # We want to split @multiple across world_size, but it could be an uneven split
+        div = multiple // world_size
+        mod = multiple % world_size
+        # The first @mod ranks get @div + 1 copies, the rest get @div copies
+        local_multiple = div + int(torch.distributed.get_rank(process_group) < mod)
         # Only rank 0 will have bias
         super().__init__(
-            in_features // world_size,
+            local_multiple * multiple_of,
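The practical effect of the change: before this commit both classes required the sharded dimension to be divisible by world_size; now it only has to be a multiple of multiple_of, and the per-rank shards may be uneven. A hedged usage sketch, assuming this file is flash_attn/ops/fused_dense.py (as the [FusedDense] tag suggests) and that torch.distributed has already been initialized with 3 ranks; the dimensions are made up:

import torch
from flash_attn.ops.fused_dense import ColumnParallelLinear

# Hypothetical 3-way tensor-parallel group; any ProcessGroup works here.
pg = torch.distributed.group.WORLD
# 1280 % 3 != 0, so this would have raised before this commit. With
# multiple_of=256, the 5 units of 256 split unevenly as 512 / 512 / 256.
linear = ColumnParallelLinear(512, 1280, pg, multiple_of=256)
print(linear.weight.shape)  # rank 0: torch.Size([512, 512]); rank 2: torch.Size([256, 512])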