From 1ec09ebd90c28b936538d193f275c99f06d5e8d1 Mon Sep 17 00:00:00 2001
From: Tri Dao
Date: Sun, 1 Jan 2023 17:06:39 -0800
Subject: [PATCH] [FusedDense] Limit matrix dims to 2M (instead of 64k)

---
 flash_attn/ops/fused_dense.py | 18 +++++++++---------
 1 file changed, 9 insertions(+), 9 deletions(-)

diff --git a/flash_attn/ops/fused_dense.py b/flash_attn/ops/fused_dense.py
index d3bbe22..c8d6e0f 100644
--- a/flash_attn/ops/fused_dense.py
+++ b/flash_attn/ops/fused_dense.py
@@ -46,9 +46,11 @@ class FusedDenseFunc(torch.autograd.Function):
         weight = weight.contiguous()
         if process_group is not None:
             handle_x.wait()
-        batch_shape = total_x.shape[:-1]
+        batch_shape, n = total_x.shape[:-1], total_x.shape[-1]
         batch_dim = batch_shape.numel()
-        assert batch_dim <= 64 * 1024, 'fused_dense only supports dimension at most 64k'
+        # https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174
+        if min(batch_dim, n, *weight.shape) > 65535 * 32:
+            raise RuntimeError('fused_dense only supports matrix dims <= 2M')
         output = F.linear(total_x, weight, bias)
         if ctx.compute_weight_gradient:
             ctx.save_for_backward(x, weight)
@@ -105,11 +107,9 @@ class FusedDenseFunc(torch.autograd.Function):
 def fused_dense_func(x: Tensor, weight: Tensor, bias: Optional[Tensor] = None,
                      return_residual: bool = False,
                      process_group: Optional[ProcessGroup] = None):
-    batch_dim = x.shape[:-1].numel()
     dtype_eligible = (x.dtype in [torch.float16, torch.bfloat16]
                       or (x.dtype == torch.float32 and torch.is_autocast_enabled()))
-    if (x.is_cuda and weight.is_cuda and (bias is None or bias.is_cuda) and batch_dim <= 64 * 1024
-            and dtype_eligible):
+    if x.is_cuda and weight.is_cuda and (bias is None or bias.is_cuda) and dtype_eligible:
         return FusedDenseFunc.apply(x, weight, bias, return_residual, process_group)
     else:
         assert process_group is None
@@ -222,7 +222,9 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function):
             handle_x.wait()
         batch_shape, n = total_x.shape[:-1], total_x.shape[-1]
         batch_dim = batch_shape.numel()
-        assert batch_dim <= 64 * 1024, 'fused_dense only supports dimension at most 64k'
+        # https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174
+        if min(batch_dim, n, *weight1.shape, *weight2.shape) > 65535 * 32:
+            raise RuntimeError('fused_dense only supports matrix dims <= 2M')
         if heuristic == -1:
             gelu_in = F.linear(total_x, weight1, bias1)
             output1 = F.gelu(gelu_in, approximate='tanh')
@@ -348,12 +350,10 @@ def fused_dense_gelu_dense_func(
     checkpoint_lvl: int = 0, heuristic: int = 0,
     process_group: Optional[ProcessGroup] = None
 ):
-    batch_dim = x.shape[:-1].numel()
     dtype_eligible = (x.dtype in [torch.float16, torch.bfloat16]
                       or (x.dtype == torch.float32 and torch.is_autocast_enabled()))
     if (x.is_cuda and weight1.is_cuda and weight2.is_cuda and (bias1 is None or bias1.is_cuda)
-            and (bias2 is None or bias2.is_cuda) and batch_dim <= 64 * 1024
-            and dtype_eligible):
+            and (bias2 is None or bias2.is_cuda) and dtype_eligible):
         return FusedDenseGeluDenseFunc.apply(
             x, weight1, bias1, weight2, bias2, save_pre_act, return_residual,
             checkpoint_lvl, heuristic, process_group
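
Not part of the patch: a minimal standalone sketch of the new bound check, assuming only what the diff shows (the helper name check_matmul_dims and the example tensor sizes are illustrative, not from the repo). The patched condition raises only when every GEMM dimension exceeds 65535 * 32 (~2M), the limit referenced in ATen's Blas.cpp, so a batch_dim of 16 * 8192 = 131072, which the old 64k assert rejected, now passes.

import torch

def check_matmul_dims(x: torch.Tensor, *weights: torch.Tensor) -> None:
    # Mirrors the patched check: gather all GEMM dims and raise only if even the
    # smallest one exceeds 65535 * 32 (~2M), the same bound used in the diff.
    batch_dim = x.shape[:-1].numel()   # e.g. batch_size * seqlen
    n = x.shape[-1]
    dims = [batch_dim, n]
    for w in weights:
        dims.extend(w.shape)
    if min(dims) > 65535 * 32:
        raise RuntimeError('fused_dense only supports matrix dims <= 2M')

x = torch.empty(16, 8192, 128)   # batch_dim = 16 * 8192 = 131072, above the old 64k cap
w = torch.empty(256, 128)
check_matmul_dims(x, w)          # no error under the new ~2M limit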