[FusedDense] Limit matrix dims to 2M (instead of 64k)

Tri Dao 2023-01-01 17:06:39 -08:00
parent 714c1b4f0f
commit 1ec09ebd90
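For context (not stated in the commit itself): the old code asserted only that the flattened batch dimension stayed under 64k, while the new guard compares every matrix dimension against 65535 * 32 = 2,097,120 (the "2M" in the title), following the dimension check linked from PyTorch's CUDA Blas.cpp. A minimal standalone sketch of the new guard logic; the helper name check_fused_dense_dims is hypothetical:

    import torch

    MAX_DIM = 65535 * 32  # 2,097,120 -- the ~2M per-dimension ceiling from the diff

    def check_fused_dense_dims(x: torch.Tensor, weight: torch.Tensor) -> None:
        # Mirrors the new guard in FusedDenseFunc.forward: flatten the leading
        # (batch) dims, then raise only if every matrix dim exceeds the ceiling.
        batch_shape, n = x.shape[:-1], x.shape[-1]
        batch_dim = batch_shape.numel()
        if min(batch_dim, n, *weight.shape) > MAX_DIM:
            raise RuntimeError('fused_dense only supports matrix dims <= 2M')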


@@ -46,9 +46,11 @@ class FusedDenseFunc(torch.autograd.Function):
             weight = weight.contiguous()
         if process_group is not None:
             handle_x.wait()
-        batch_shape = total_x.shape[:-1]
+        batch_shape, n = total_x.shape[:-1], total_x.shape[-1]
         batch_dim = batch_shape.numel()
-        assert batch_dim <= 64 * 1024, 'fused_dense only supports dimension at most 64k'
+        # https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174
+        if min(batch_dim, n, *weight.shape) > 65535 * 32:
+            raise RuntimeError('fused_dense only supports matrix dims <= 2M')
         output = F.linear(total_x, weight, bias)
         if ctx.compute_weight_gradient:
             ctx.save_for_backward(x, weight)
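Note the min(): the RuntimeError fires only when every participating dimension exceeds 2M, so a very long flattened batch is still accepted as long as the feature dimensions stay small. Checking the boundary in plain Python:

    >>> 65535 * 32
    2097120
    >>> # huge batch but small feature dims: the minimum is 1024, so no error
    >>> min(4_000_000, 1024, 1024, 1024) > 65535 * 32
    False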
@@ -105,11 +107,9 @@ class FusedDenseFunc(torch.autograd.Function):
 def fused_dense_func(x: Tensor, weight: Tensor, bias: Optional[Tensor] = None,
                      return_residual: bool = False, process_group: Optional[ProcessGroup] = None):
-    batch_dim = x.shape[:-1].numel()
     dtype_eligible = (x.dtype in [torch.float16, torch.bfloat16]
                       or (x.dtype == torch.float32 and torch.is_autocast_enabled()))
-    if (x.is_cuda and weight.is_cuda and (bias is None or bias.is_cuda) and batch_dim <= 64 * 1024
-            and dtype_eligible):
+    if x.is_cuda and weight.is_cuda and (bias is None or bias.is_cuda) and dtype_eligible:
         return FusedDenseFunc.apply(x, weight, bias, return_residual, process_group)
     else:
         assert process_group is None
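With the batch_dim <= 64 * 1024 test removed from the eligibility check, oversized batches now route to the fused kernel instead of silently falling back to the unfused path. A hedged usage sketch (assumes a CUDA device and that the import path flash_attn.ops.fused_dense matches the file this diff touches):

    import torch
    from flash_attn.ops.fused_dense import fused_dense_func

    x = torch.randn(128 * 1024, 1024, device='cuda', dtype=torch.float16)  # batch_dim = 131072 > 64k
    w = torch.randn(1024, 1024, device='cuda', dtype=torch.float16)
    # Before this commit: fell back to plain F.linear. After: eligible for
    # FusedDenseFunc.apply, since 131072 is well under 65535 * 32.
    out = fused_dense_func(x, w)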
@@ -222,7 +222,9 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function):
             handle_x.wait()
         batch_shape, n = total_x.shape[:-1], total_x.shape[-1]
         batch_dim = batch_shape.numel()
-        assert batch_dim <= 64 * 1024, 'fused_dense only supports dimension at most 64k'
+        # https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174
+        if min(batch_dim, n, *weight1.shape, *weight2.shape) > 65535 * 32:
+            raise RuntimeError('fused_dense only supports matrix dims <= 2M')
         if heuristic == -1:
             gelu_in = F.linear(total_x, weight1, bias1)
             output1 = F.gelu(gelu_in, approximate='tanh')
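The dense -> GELU -> dense variant applies the same guard across both weight matrices. For typical transformer MLP shapes the check passes comfortably; a worked example with hypothetical sizes (d_model = 4096, d_ff = 16384, 128k tokens):

    batch_dim, n = 128 * 1024, 4096
    weight1_shape = (16384, 4096)  # fc1 weight: (out_features, in_features)
    weight2_shape = (4096, 16384)  # fc2 weight
    # min over all dims is 4096, far below 65535 * 32 = 2,097,120: no error
    assert min(batch_dim, n, *weight1_shape, *weight2_shape) <= 65535 * 32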
@@ -348,12 +350,10 @@ def fused_dense_gelu_dense_func(
     checkpoint_lvl: int = 0, heuristic: int = 0,
     process_group: Optional[ProcessGroup] = None
 ):
-    batch_dim = x.shape[:-1].numel()
     dtype_eligible = (x.dtype in [torch.float16, torch.bfloat16]
                       or (x.dtype == torch.float32 and torch.is_autocast_enabled()))
     if (x.is_cuda and weight1.is_cuda and weight2.is_cuda and (bias1 is None or bias1.is_cuda)
-            and (bias2 is None or bias2.is_cuda) and batch_dim <= 64 * 1024
-            and dtype_eligible):
+            and (bias2 is None or bias2.is_cuda) and dtype_eligible):
         return FusedDenseGeluDenseFunc.apply(
             x, weight1, bias1, weight2, bias2,
             save_pre_act, return_residual, checkpoint_lvl, heuristic, process_group