[Model] RowParallelLinear: pass bias to quant_method.apply (#6327)
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
This commit is contained in:
parent
a921e86392
commit
a5314e8698
@ -83,6 +83,9 @@ def test_target_model_tp_gt_1(baseline_llm_generator, test_llm_generator,
|
|||||||
# cleaned up properly, and its server host thread leaks, causing the
|
# cleaned up properly, and its server host thread leaks, causing the
|
||||||
# second run of the test to fail with internal NCCL error.
|
# second run of the test to fail with internal NCCL error.
|
||||||
"use_async": True,
|
"use_async": True,
|
||||||
|
|
||||||
|
# precision
|
||||||
|
"dtype": "float32",
|
||||||
}])
|
}])
|
||||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
|||||||
@ -715,6 +715,7 @@ class RowParallelLinear(LinearBase):
|
|||||||
self.reduce_results = reduce_results
|
self.reduce_results = reduce_results
|
||||||
|
|
||||||
# Divide the weight matrix along the last dimension.
|
# Divide the weight matrix along the last dimension.
|
||||||
|
self.tp_rank = get_tensor_model_parallel_rank()
|
||||||
self.tp_size = get_tensor_model_parallel_world_size()
|
self.tp_size = get_tensor_model_parallel_world_size()
|
||||||
self.input_size_per_partition = divide(input_size, self.tp_size)
|
self.input_size_per_partition = divide(input_size, self.tp_size)
|
||||||
assert self.quant_method is not None
|
assert self.quant_method is not None
|
||||||
@ -770,18 +771,19 @@ class RowParallelLinear(LinearBase):
|
|||||||
|
|
||||||
# Matrix multiply.
|
# Matrix multiply.
|
||||||
assert self.quant_method is not None
|
assert self.quant_method is not None
|
||||||
output_parallel = self.quant_method.apply(self, input_parallel)
|
# Only fuse bias add into GEMM for rank 0 (this ensures that
|
||||||
|
# bias will not get added more than once in TP>1 case)
|
||||||
|
bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
|
||||||
|
output_parallel = self.quant_method.apply(self,
|
||||||
|
input_parallel,
|
||||||
|
bias=bias_)
|
||||||
if self.reduce_results and self.tp_size > 1:
|
if self.reduce_results and self.tp_size > 1:
|
||||||
output_ = tensor_model_parallel_all_reduce(output_parallel)
|
output = tensor_model_parallel_all_reduce(output_parallel)
|
||||||
else:
|
else:
|
||||||
output_ = output_parallel
|
output = output_parallel
|
||||||
|
|
||||||
|
output_bias = self.bias if self.skip_bias_add else None
|
||||||
|
|
||||||
if not self.skip_bias_add:
|
|
||||||
output = output_ + self.bias if self.bias is not None else output_
|
|
||||||
output_bias = None
|
|
||||||
else:
|
|
||||||
output = output_
|
|
||||||
output_bias = self.bias
|
|
||||||
return output, output_bias
|
return output, output_bias
|
||||||
|
|
||||||
def extra_repr(self) -> str:
|
def extra_repr(self) -> str:
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user