flash-attention/flash_attn/flash_attn_interface.py

# Adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/fmha.py
import torch
import torch.nn as nn

import flash_attn_cuda


def _flash_attn_forward(qkv, cu_seqlens, dropout_p, max_s, softmax_scale, causal, return_softmax):
    # Thin wrapper around the CUDA forward kernel. Returns the attention output
    # ("context"), the softmax log-sum-exp needed by the backward pass, and,
    # when return_softmax is True, the dropout-masked softmax S_dmask (used for testing).
    context, softmax_lse, *rest = flash_attn_cuda.fwd(qkv, cu_seqlens, dropout_p, max_s, softmax_scale,
                                                      False, causal, return_softmax, None)
    # if context.isnan().any() or softmax_lse.isnan().any():
    #     breakpoint()
    S_dmask = rest[0] if return_softmax else None
    return context, softmax_lse, S_dmask


def _flash_attn_backward(dout, qkv, out, S_dmask, softmax_lse, cu_seqlens, dropout_p, max_s,
                         softmax_scale, causal):
    # Thin wrapper around the CUDA backward kernel. Only the gradient with
    # respect to the packed qkv tensor is returned.
    dqkv, dp, softmax_d = flash_attn_cuda.bwd(dout, qkv, out, S_dmask, softmax_lse, cu_seqlens, dropout_p,
                                              softmax_scale, max_s, False, causal, None)
    # if dqkv.isnan().any() or softmax_d.isnan().any():
    #     breakpoint()
    return dqkv


class FlashAttnFun(torch.autograd.Function):

    @staticmethod
    def forward(ctx, qkv, cu_seqlens, dropout_p, max_s, softmax_scale, causal):
        # Save rng_state because the backward pass will regenerate the dropout mask
        rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
        if softmax_scale is None:
            softmax_scale = qkv.shape[-1] ** (-0.5)
        context, softmax_lse, S_dmask = _flash_attn_forward(
            qkv, cu_seqlens, dropout_p, max_s, softmax_scale, causal=causal, return_softmax=False
        )
        ctx.save_for_backward(qkv, context, S_dmask, softmax_lse, cu_seqlens, rng_state)
        ctx.dropout_p = dropout_p
        ctx.max_s = max_s
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        return context

    @staticmethod
    def backward(ctx, dout):
        qkv, context, S_dmask, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors
        if rng_state is not None:
            cur_rng_state = torch.cuda.get_rng_state()
            torch.cuda.set_rng_state(rng_state)
        # S_dmask was not saved in the forward pass (return_softmax=False), so pass
        # `context` as a placeholder tensor just to get the kernel call running.
        dqkv = _flash_attn_backward(
            dout, qkv, context, context, softmax_lse, cu_seqlens, ctx.dropout_p,
            ctx.max_s, ctx.softmax_scale, ctx.causal
        )
        if rng_state is not None:
            torch.cuda.set_rng_state(cur_rng_state)
        # One gradient per forward input: only qkv receives a gradient.
        return dqkv, None, None, None, None, None


# We duplicate code to return both the output and the softmax for testing.
# Returning both makes the backward pass a bit slower, so we keep the version
# above, which returns only the output, for speed.
class FlashAttnFunWithS(torch.autograd.Function):

    @staticmethod
    def forward(ctx, qkv, cu_seqlens, dropout_p, max_s, softmax_scale, causal):
        # Save rng_state because the backward pass will regenerate the dropout mask
        rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
        if softmax_scale is None:
            softmax_scale = qkv.shape[-1] ** (-0.5)
        context, softmax_lse, S_dmask = _flash_attn_forward(
            qkv, cu_seqlens, dropout_p, max_s, softmax_scale, causal=causal, return_softmax=True
        )
        ctx.save_for_backward(qkv, context, S_dmask, softmax_lse, cu_seqlens, rng_state)
        ctx.dropout_p = dropout_p
        ctx.max_s = max_s
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        return context, S_dmask, softmax_lse

    @staticmethod
    def backward(ctx, dout, _dS_dmask_ignored, _dsoftmax_sum_ignored):
        qkv, context, S_dmask, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors
        if rng_state is not None:
            cur_rng_state = torch.cuda.get_rng_state()
            torch.cuda.set_rng_state(rng_state)
        dqkv = _flash_attn_backward(
            dout, qkv, context, S_dmask, softmax_lse, cu_seqlens, ctx.dropout_p,
            ctx.max_s, ctx.softmax_scale, ctx.causal
        )
        if rng_state is not None:
            torch.cuda.set_rng_state(cur_rng_state)
        return dqkv, None, None, None, None, None


def flash_attn_func(qkv, cu_seqlens, dropout_p, max_s, softmax_scale=None, causal=False,
                    return_attn_probs=False):
    """dropout_p should be set to 0.0 during evaluation."""
    func = FlashAttnFun if not return_attn_probs else FlashAttnFunWithS
    return func.apply(qkv, cu_seqlens, dropout_p, max_s, softmax_scale, causal)
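

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original interface). It assumes a
# CUDA device with the flash_attn_cuda extension built, fp16 inputs, and the
# packed layout this module works with: qkv of shape
# (total_tokens, 3, nheads, headdim), an int32 cu_seqlens tensor of cumulative
# sequence lengths, and max_s, the longest sequence length in the batch.
if __name__ == '__main__':
    batch_size, seqlen, nheads, headdim = 2, 128, 8, 64
    device, dtype = 'cuda', torch.float16
    # All sequences share the same length in this toy example, so the packed
    # token dimension is simply batch_size * seqlen.
    qkv = torch.randn(batch_size * seqlen, 3, nheads, headdim,
                      device=device, dtype=dtype, requires_grad=True)
    # Cumulative sequence lengths: [0, seqlen, 2 * seqlen, ...] as int32.
    cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen,
                              device=device, dtype=torch.int32)
    out = flash_attn_func(qkv, cu_seqlens, dropout_p=0.1, max_s=seqlen, causal=True)
    # Exercise the custom backward pass (and the dropout RNG state handling).
    out.sum().backward()
    print(out.shape, qkv.grad.shape)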