diff --git a/flash_attn/modules/mha.py b/flash_attn/modules/mha.py
index a6f3ec7..cb277c4 100644
--- a/flash_attn/modules/mha.py
+++ b/flash_attn/modules/mha.py
@@ -472,13 +472,15 @@ class ParallelMHA(nn.Module):
     """
 
     def __init__(self, embed_dim, num_heads, process_group, bias=True, dropout=0.0,
-                 softmax_scale=None, causal=False, rotary_emb_dim=0, rotary_emb_scale_base=0,
+                 softmax_scale=None, causal=False, layer_idx=None, rotary_emb_dim=0,
+                 rotary_emb_scale_base=0,
                  use_flash_attn=False, checkpointing=False, device=None, dtype=None) -> None:
         factory_kwargs = {'device': device, 'dtype': dtype}
         super().__init__()
         self.process_group = process_group
         self.embed_dim = embed_dim
         self.causal = causal
+        self.layer_idx = layer_idx
         self.rotary_emb_dim = rotary_emb_dim
         self.use_flash_attn = use_flash_attn
         self.checkpointing = checkpointing
diff --git a/flash_attn/ops/fused_dense.py b/flash_attn/ops/fused_dense.py
index 919bdf5..d3bbe22 100644
--- a/flash_attn/ops/fused_dense.py
+++ b/flash_attn/ops/fused_dense.py
@@ -32,24 +32,28 @@ class FusedDenseFunc(torch.autograd.Function):
         ctx.process_group = process_group
 
         if torch.is_autocast_enabled():
-            dtype = torch.get_autocast_gpu_dtype()
-            x, weight = [a.to(dtype=dtype) for a in [x, weight]]
-            bias = bias.to(dtype=dtype) if bias is not None else None
-
+            x = x.to(dtype=torch.get_autocast_gpu_dtype())
         x = x.contiguous()
+        if process_group is not None:
+            # We want to kick off the all_gather early, before weight dtype conversion
+            total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
+        else:
+            total_x = x
+
+        if torch.is_autocast_enabled():
+            weight = weight.to(dtype=torch.get_autocast_gpu_dtype())
+            bias = bias.to(dtype=torch.get_autocast_gpu_dtype()) if bias is not None else None
         weight = weight.contiguous()
+        if process_group is not None:
+            handle_x.wait()
+        batch_shape = total_x.shape[:-1]
+        batch_dim = batch_shape.numel()
+        assert batch_dim <= 64 * 1024, 'fused_dense only supports dimension at most 64k'
+        output = F.linear(total_x, weight, bias)
         if ctx.compute_weight_gradient:
             ctx.save_for_backward(x, weight)
         else:
             ctx.save_for_backward(weight)
-        batch_shape, n = x.shape[:-1], x.shape[-1]
-        batch_dim = batch_shape.numel()
-        assert batch_dim <= 64 * 1024, 'fused_dense only supports dimension at most 64k'
-        if process_group is not None:
-            total_x, _ = all_gather_raw(x, process_group)
-        else:
-            total_x = x
-        output = F.linear(total_x, weight, bias)
         return output if not return_residual else (output, x)
 
     @staticmethod
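The point of the FusedDenseFunc hunk above is to launch the all-gather asynchronously so the sequence-parallel communication overlaps with the weight dtype conversion instead of running after it. Below is a minimal sketch of that pattern, assuming plain torch.distributed in place of flash-attn's all_gather_raw helper; the function name linear_with_overlapped_all_gather and the all_gather_into_tensor-based gather are illustrative assumptions, not the library's code.

    import torch
    import torch.distributed as dist
    import torch.nn.functional as F

    def linear_with_overlapped_all_gather(x, weight, bias, process_group):
        # Convert x first so the gathered tensor already has the autocast dtype.
        if torch.is_autocast_enabled():
            x = x.to(dtype=torch.get_autocast_gpu_dtype())
        x = x.contiguous()
        # Kick off the all-gather asynchronously: the transfer proceeds on the
        # network while the weight/bias dtype conversions below run.
        world_size = dist.get_world_size(process_group)
        total_x = torch.empty(world_size * x.shape[0], *x.shape[1:],
                              dtype=x.dtype, device=x.device)
        handle = dist.all_gather_into_tensor(total_x, x, group=process_group,
                                             async_op=True)
        if torch.is_autocast_enabled():
            weight = weight.to(dtype=torch.get_autocast_gpu_dtype())
            if bias is not None:
                bias = bias.to(dtype=torch.get_autocast_gpu_dtype())
        # Only wait right before the matmul that consumes the gathered input.
        handle.wait()
        return F.linear(total_x, weight, bias)

The same reordering is applied to FusedDenseGeluDenseFunc in the next hunk.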
@@ -188,32 +192,42 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function):
         2: recompute gelu_in and gelu_out in the bwd
         """
         assert -1 <= heuristic <= 4
-        if torch.is_autocast_enabled():
-            dtype = torch.get_autocast_gpu_dtype()
-            x, weight1, weight2 = [a.to(dtype=dtype) for a in [x, weight1, weight2]]
-            bias1 = bias1.to(dtype=dtype) if bias1 is not None else None
-            bias2 = bias2.to(dtype=dtype) if bias2 is not None else None
         if not save_pre_act:
             checkpoint_lvl = 2
         assert checkpoint_lvl in [0, 1, 2]
         ctx.return_residual = return_residual
         ctx.process_group = process_group
+        ctx.checkpoint_lvl = checkpoint_lvl
+        ctx.heuristic = heuristic
+
+        if torch.is_autocast_enabled():
+            x = x.to(dtype=torch.get_autocast_gpu_dtype())
         x = x.contiguous()
+        if process_group is not None:
+            # We want to kick off the all_gather early, before weight dtype conversion
+            total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
+        else:
+            total_x = x
+
+        if torch.is_autocast_enabled():
+            dtype = torch.get_autocast_gpu_dtype()
+            weight1, weight2 = [a.to(dtype=dtype) for a in [weight1, weight2]]
+            bias1 = bias1.to(dtype=dtype) if bias1 is not None else None
+            bias2 = bias2.to(dtype=dtype) if bias2 is not None else None
         weight1 = weight1.contiguous()
         bias1 = bias1.contiguous() if bias1 is not None else None
         weight2 = weight2.contiguous()
         bias2 = bias2.contiguous() if bias2 is not None else None
         if process_group is not None:
-            total_x, _ = all_gather_raw(x, process_group)
-            else:
-            total_x = x
+            handle_x.wait()
         batch_shape, n = total_x.shape[:-1], total_x.shape[-1]
         batch_dim = batch_shape.numel()
         assert batch_dim <= 64 * 1024, 'fused_dense only supports dimension at most 64k'
         if heuristic == -1:
             gelu_in = F.linear(total_x, weight1, bias1)
             output1 = F.gelu(gelu_in, approximate='tanh')
-            # gelu_in = F.linear(total_x.reshape(batch_dim, n), weight1) # This is before adding bias1
+            # This is before adding bias1
+            # gelu_in = F.linear(total_x.reshape(batch_dim, n), weight1)
             # with torch.jit.fuser('fuser2'):
             #     output1 = bias_gelu(gelu_in, bias1)
         else:
@@ -223,8 +237,6 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function):
             if save_pre_act:
                 gelu_in = rest[0]
         output2 = F.linear(output1, weight2, bias2)
-        ctx.checkpoint_lvl = checkpoint_lvl
-        ctx.heuristic = heuristic
         if checkpoint_lvl == 0:
             ctx.save_for_backward(x, weight1, weight2, gelu_in, output1)
         elif checkpoint_lvl == 1:
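For context, checkpoint_lvl controls how much of the GeLU block is saved for backward; per the docstring above, level 2 recomputes both gelu_in and gelu_out in the bwd. Below is a self-contained illustration of the level-2 behavior, assuming non-None biases and plain PyTorch ops in place of the fused kernels; the class name and the torch.autograd.grad-based recompute are illustrative assumptions, not flash-attn's actual backward.

    import torch
    import torch.nn.functional as F

    class DenseGeluDenseCkptLvl2(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, weight1, bias1, weight2, bias2):
            gelu_in = F.linear(x, weight1, bias1)
            output1 = F.gelu(gelu_in, approximate='tanh')
            # Level 2: save only inputs and parameters; neither gelu_in nor
            # output1 is kept, trading backward recompute for activation memory.
            ctx.save_for_backward(x, weight1, bias1, weight2, bias2)
            return F.linear(output1, weight2, bias2)

        @staticmethod
        def backward(ctx, grad_output2):
            x, weight1, bias1, weight2, bias2 = ctx.saved_tensors
            # Recompute gelu_in and the GeLU output in the bwd, then backprop
            # through the recomputed graph.
            with torch.enable_grad():
                leaves = [t.detach().requires_grad_()
                          for t in (x, weight1, bias1, weight2, bias2)]
                x_, w1, b1, w2, b2 = leaves
                gelu_in = F.linear(x_, w1, b1)
                output1 = F.gelu(gelu_in, approximate='tanh')
                output2 = F.linear(output1, w2, b2)
            return torch.autograd.grad(output2, leaves, grad_output2)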