[ Misc ] Remove fp8_shard_indexer from Col/Row Parallel Linear (Simplify Weight Loading) (#5928)
Co-authored-by: Robert Shaw <rshaw@neuralmagic>
This commit is contained in:
parent 6a2d659d28
commit b185230744
@@ -269,10 +269,6 @@ class ColumnParallelLinear(LinearBase):
             self.register_parameter("bias", None)
 
     def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
-        # Special case for Fp8 scales.
-        fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
-                                           None)
-
         tp_rank = get_tensor_model_parallel_rank()
         output_dim = getattr(param, "output_dim", None)
         param_data = param.data
@@ -281,11 +277,11 @@ class ColumnParallelLinear(LinearBase):
             start_idx = tp_rank * shard_size
             loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                  shard_size)
-        # Special case for Fp8 scales.
-        elif fp8_scales_shard_indexer is not None:
-            param_data, loaded_weight = fp8_scales_shard_indexer(param_data,
-                                                                 loaded_weight,
-                                                                 shard_id=0)
+
+        # Special case for loading scales off disk, which often do not
+        # have a shape (such as in the case of AutoFP8).
+        if len(loaded_weight.shape) == 0:
+            loaded_weight = loaded_weight.reshape(1)
 
         assert param_data.shape == loaded_weight.shape
         param_data.copy_(loaded_weight)
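For context, the sharding branch kept above is unchanged: it narrows the full checkpoint tensor to this rank's slice along output_dim. A minimal standalone sketch of that step (not vLLM code; tp_size, tp_rank, and full_weight are illustrative names, and shard_size is derived from the full tensor here rather than from param_data):

# Sketch: how narrow() selects the tensor-parallel shard along output_dim.
import torch

tp_size, tp_rank = 2, 1
output_dim = 0

# Full weight as stored in the checkpoint: 8 output rows, 4 input columns.
full_weight = torch.arange(8 * 4, dtype=torch.float32).reshape(8, 4)

# In the real loader, shard_size comes from param_data.shape[output_dim];
# here we derive it from the full tensor and the TP world size.
shard_size = full_weight.shape[output_dim] // tp_size
start_idx = tp_rank * shard_size
shard = full_weight.narrow(output_dim, start_idx, shard_size)

assert shard.shape == (shard_size, 4)  # rank 1 receives rows 4..7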
@@ -751,10 +747,6 @@ class RowParallelLinear(LinearBase):
             self.register_parameter("bias", None)
 
     def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
-        # Special case for Fp8 scales.
-        fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
-                                           None)
-
         tp_rank = get_tensor_model_parallel_rank()
         input_dim = getattr(param, "input_dim", None)
         param_data = param.data
@@ -764,13 +756,9 @@ class RowParallelLinear(LinearBase):
             loaded_weight = loaded_weight.narrow(input_dim, start_idx,
                                                  shard_size)
 
-        # Special case for Fp8 scales.
-        elif fp8_scales_shard_indexer is not None:
-            param_data, loaded_weight = fp8_scales_shard_indexer(param_data,
-                                                                 loaded_weight,
-                                                                 shard_id=0)
-
-        if fp8_scales_shard_indexer is None and len(loaded_weight.shape) == 0:
+        # Special case for loading scales off disk, which often do not
+        # have a shape (such as in the case of AutoFP8).
+        if len(loaded_weight.shape) == 0:
             loaded_weight = loaded_weight.reshape(1)
 
         assert param_data.shape == loaded_weight.shape
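The special case that replaces the fp8_scales_shard_indexer path covers per-tensor scales serialized without a shape (0-d tensors, e.g. as produced by AutoFP8): they are reshaped to (1,) so the shape assert and copy_ succeed. A minimal standalone sketch (not vLLM code; the shape-(1,) "weight_scale"-style parameter is an illustrative assumption):

# Sketch: why a 0-d scale must be reshaped before the assert and copy_.
import torch
from torch.nn import Parameter

param = Parameter(torch.empty(1), requires_grad=False)  # scale registered as shape (1,)
loaded_weight = torch.tensor(0.05)                      # 0-d scalar read from the checkpoint

# Without this, the assert below fails: torch.Size([]) != torch.Size([1]).
if len(loaded_weight.shape) == 0:
    loaded_weight = loaded_weight.reshape(1)

assert param.data.shape == loaded_weight.shape
param.data.copy_(loaded_weight)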