[FusedDense] Limit matrix dims to 2M (instead of 64k)
parent 714c1b4f0f
commit 1ec09ebd90
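In short: the hard assert batch_dim <= 64 * 1024 on the flattened batch dimension is dropped. The forward passes now raise a RuntimeError only when every matrix dimension exceeds 65535 * 32 = 2,097,120 (roughly 2M), the bound used by the PyTorch Blas.cpp check linked in the diff, and the Python wrappers no longer gate the fused CUDA path on batch_dim at all.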
@@ -46,9 +46,11 @@ class FusedDenseFunc(torch.autograd.Function):
         weight = weight.contiguous()
         if process_group is not None:
             handle_x.wait()
-        batch_shape = total_x.shape[:-1]
+        batch_shape, n = total_x.shape[:-1], total_x.shape[-1]
         batch_dim = batch_shape.numel()
-        assert batch_dim <= 64 * 1024, 'fused_dense only supports dimension at most 64k'
+        # https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174
+        if min(batch_dim, n, *weight.shape) > 65535 * 32:
+            raise RuntimeError('fused_dense only supports matrix dims <= 2M')
         output = F.linear(total_x, weight, bias)
         if ctx.compute_weight_gradient:
             ctx.save_for_backward(x, weight)
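A minimal sketch of the behavior change in this hunk (the helper names are hypothetical, not part of the diff): the old guard rejected any flattened batch dimension above 64k, while the new one raises only when the smallest GEMM dimension exceeds 65535 * 32.

def old_check(batch_dim):
    # Pre-commit: hard cap on the flattened batch dimension alone.
    assert batch_dim <= 64 * 1024, 'fused_dense only supports dimension at most 64k'

def new_check(batch_dim, n, out_features, in_features):
    # Post-commit: raise only if min(...) > 65535 * 32, i.e. only when
    # every dimension of the GEMM exceeds the bound (65535 * 32 = 2,097,120).
    if min(batch_dim, n, out_features, in_features) > 65535 * 32:
        raise RuntimeError('fused_dense only supports matrix dims <= 2M')

# A batch dim of 100_000 failed the old assert but passes the new check,
# since n and the weight dims are well under the 2M bound:
new_check(100_000, 1024, 4096, 1024)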
@@ -105,11 +107,9 @@ class FusedDenseFunc(torch.autograd.Function):
 
 def fused_dense_func(x: Tensor, weight: Tensor, bias: Optional[Tensor] = None,
                      return_residual: bool = False, process_group: Optional[ProcessGroup] = None):
-    batch_dim = x.shape[:-1].numel()
     dtype_eligible = (x.dtype in [torch.float16, torch.bfloat16]
                       or (x.dtype == torch.float32 and torch.is_autocast_enabled()))
-    if (x.is_cuda and weight.is_cuda and (bias is None or bias.is_cuda) and batch_dim <= 64 * 1024
-            and dtype_eligible):
+    if x.is_cuda and weight.is_cuda and (bias is None or bias.is_cuda) and dtype_eligible:
         return FusedDenseFunc.apply(x, weight, bias, return_residual, process_group)
     else:
         assert process_group is None
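A hedged usage sketch of the relaxed dispatch: batch_dim no longer decides whether the fused CUDA path is taken, only device placement and dtype eligibility do. The import path and tensor shapes below are assumptions for illustration, not taken from the diff.

import torch
from flash_attn.ops.fused_dense import fused_dense_func  # assumed import path

x = torch.randn(256, 512, 1024, device='cuda', dtype=torch.float16)  # batch_dim = 256 * 512 = 131072 > 64k
weight = torch.randn(4096, 1024, device='cuda', dtype=torch.float16)
out = fused_dense_func(x, weight)  # previously rejected by the 64k batch_dim gate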
@@ -222,7 +222,9 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function):
             handle_x.wait()
         batch_shape, n = total_x.shape[:-1], total_x.shape[-1]
         batch_dim = batch_shape.numel()
-        assert batch_dim <= 64 * 1024, 'fused_dense only supports dimension at most 64k'
+        # https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174
+        if min(batch_dim, n, *weight1.shape, *weight2.shape) > 65535 * 32:
+            raise RuntimeError('fused_dense only supports matrix dims <= 2M')
         if heuristic == -1:
             gelu_in = F.linear(total_x, weight1, bias1)
             output1 = F.gelu(gelu_in, approximate='tanh')
@@ -348,12 +350,10 @@ def fused_dense_gelu_dense_func(
     checkpoint_lvl: int = 0, heuristic: int = 0,
     process_group: Optional[ProcessGroup] = None
 ):
-    batch_dim = x.shape[:-1].numel()
     dtype_eligible = (x.dtype in [torch.float16, torch.bfloat16]
                       or (x.dtype == torch.float32 and torch.is_autocast_enabled()))
     if (x.is_cuda and weight1.is_cuda and weight2.is_cuda and (bias1 is None or bias1.is_cuda)
-            and (bias2 is None or bias2.is_cuda) and batch_dim <= 64 * 1024
-            and dtype_eligible):
+            and (bias2 is None or bias2.is_cuda) and dtype_eligible):
         return FusedDenseGeluDenseFunc.apply(
             x, weight1, bias1, weight2, bias2,
             save_pre_act, return_residual, checkpoint_lvl, heuristic, process_group
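The same relaxation applies to the fused dense-GELU-dense path. A hedged sketch, assuming the leading positional parameters are (x, weight1, bias1, weight2, bias2) as the .apply call in the hunk suggests, with the same assumed import path as above:

import torch
from flash_attn.ops.fused_dense import fused_dense_gelu_dense_func  # assumed import path

x = torch.randn(128, 1024, 1024, device='cuda', dtype=torch.bfloat16)  # batch_dim = 131072 > 64k
w1 = torch.randn(4096, 1024, device='cuda', dtype=torch.bfloat16)  # fc1: 1024 -> 4096
w2 = torch.randn(1024, 4096, device='cuda', dtype=torch.bfloat16)  # fc2: 4096 -> 1024
out = fused_dense_gelu_dense_func(x, w1, None, w2, None)  # biases may be None per the is_cuda checks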