[FusedDense] Allow Row/ColumnParallelLinear to have uneven split
This commit is contained in:
parent
bcfa7c9751
commit
cb0daccc41
@@ -170,16 +170,21 @@ class ColumnParallelLinear(nn.Linear):
|
|||||||
process_group: ProcessGroup,
|
process_group: ProcessGroup,
|
||||||
bias: bool = True,
|
bias: bool = True,
|
||||||
sequence_parallel=True,
|
sequence_parallel=True,
|
||||||
|
multiple_of=1,
|
||||||
device=None,
|
device=None,
|
||||||
dtype=None,
|
dtype=None,
|
||||||
) -> None:
|
) -> None:
|
||||||
world_size = torch.distributed.get_world_size(process_group)
|
world_size = torch.distributed.get_world_size(process_group)
|
||||||
if out_features % world_size != 0:
|
if out_features % multiple_of:
|
||||||
raise ValueError(
|
raise ValueError(f"out_features ({out_features}) must be a multiple of {multiple_of}")
|
||||||
f"out_features ({out_features}) must be divisible by " f"world_size ({world_size})"
|
multiple = out_features // multiple_of
|
||||||
)
|
# We want to split @multiple across world_size, but it could be an uneven split
|
||||||
|
div = multiple // world_size
|
||||||
|
mod = multiple % world_size
|
||||||
|
# The first @mod ranks get @div + 1 copies, the rest get @div copies
|
||||||
|
local_multiple = div + int(torch.distributed.get_rank(process_group) < mod)
|
||||||
super().__init__(
|
super().__init__(
|
||||||
in_features, out_features // world_size, bias=bias, device=device, dtype=dtype
|
in_features, local_multiple * multiple_of, bias=bias, device=device, dtype=dtype
|
||||||
)
|
)
|
||||||
self.process_group = process_group
|
self.process_group = process_group
|
||||||
self.sequence_parallel = sequence_parallel
|
self.sequence_parallel = sequence_parallel
|
||||||
@@ -205,15 +210,20 @@ class RowParallelLinear(nn.Linear):
|
|||||||
process_group: ProcessGroup,
|
process_group: ProcessGroup,
|
||||||
bias: bool = True,
|
bias: bool = True,
|
||||||
sequence_parallel=True,
|
sequence_parallel=True,
|
||||||
|
multiple_of=1,
|
||||||
device=None,
|
device=None,
|
||||||
dtype=None,
|
dtype=None,
|
||||||
) -> None:
|
) -> None:
|
||||||
world_size = torch.distributed.get_world_size(process_group)
|
world_size = torch.distributed.get_world_size(process_group)
|
||||||
rank = torch.distributed.get_rank(process_group)
|
rank = torch.distributed.get_rank(process_group)
|
||||||
if in_features % world_size != 0:
|
if in_features % multiple_of:
|
||||||
raise ValueError(
|
raise ValueError(f"in_features ({in_features}) must be a multiple of {multiple_of}")
|
||||||
f"in_features ({in_features}) must be divisible by " f"world_size ({world_size})"
|
multiple = in_features // multiple_of
|
||||||
)
|
# We want to split @multiple across world_size, but it could be an uneven split
|
||||||
|
div = multiple // world_size
|
||||||
|
mod = multiple % world_size
|
||||||
|
# The first @mod ranks get @div + 1 copies, the rest get @div copies
|
||||||
|
local_multiple = div + int(torch.distributed.get_rank(process_group) < mod)
|
||||||
# Only rank 0 will have bias
|
# Only rank 0 will have bias
|
||||||
super().__init__(
|
super().__init__(
|
||||||
in_features // world_size,
|
in_features // world_size,
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user