# Adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/fmha.py
import torch
import torch.nn as nn

import flash_attn_cuda

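# Layout assumed by the wrappers below (not stated in this file, inferred from the
# flash_attn_cuda interface): qkv is packed as (total_tokens, 3, nheads, head_dim),
# cu_seqlens is an int32 tensor of cumulative sequence lengths with shape (batch + 1,),
# and max_s is the maximum sequence length in the batch.
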
def _flash_attn_forward(qkv, cu_seqlens, dropout_p, max_s, softmax_scale, causal, return_softmax):
    context, softmax_lse, *rest = flash_attn_cuda.fwd(qkv, cu_seqlens, dropout_p, max_s, softmax_scale,
                                                      False, causal, return_softmax, None)
    # if context.isnan().any() or softmax_lse.isnan().any():
    #     breakpoint()
    S_dmask = rest[0] if return_softmax else None
    return context, softmax_lse, S_dmask


def _flash_attn_backward(dout, qkv, out, S_dmask, softmax_lse, cu_seqlens, dropout_p, max_s,
                         softmax_scale, causal):
    dqkv, dp, softmax_d = flash_attn_cuda.bwd(dout, qkv, out, S_dmask, softmax_lse, cu_seqlens, dropout_p,
                                              softmax_scale, max_s, False, causal, None)
    # if dqkv.isnan().any() or softmax_d.isnan().any():
    #     breakpoint()
    return dqkv


class FlashAttnFun(torch.autograd.Function):

    @staticmethod
    def forward(ctx, qkv, cu_seqlens, dropout_p, max_s, softmax_scale, causal):
        # Save rng_state because the backward pass will regenerate the dropout mask
        rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
        if softmax_scale is None:
            softmax_scale = qkv.shape[-1] ** (-0.5)
        context, softmax_lse, S_dmask = _flash_attn_forward(
            qkv, cu_seqlens, dropout_p, max_s, softmax_scale, causal=causal, return_softmax=False
        )
        ctx.save_for_backward(qkv, context, S_dmask, softmax_lse, cu_seqlens, rng_state)
        ctx.dropout_p = dropout_p
        ctx.max_s = max_s
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        return context

    @staticmethod
    def backward(ctx, dout):
        qkv, context, S_dmask, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors
        if rng_state is not None:
            cur_rng_state = torch.cuda.get_rng_state()
            torch.cuda.set_rng_state(rng_state)
        # S_dmask is None, temporarily use another tensor just to get it running
        dqkv = _flash_attn_backward(
            dout, qkv, context, context, softmax_lse, cu_seqlens, ctx.dropout_p,
            ctx.max_s, ctx.softmax_scale, ctx.causal
        )
        if rng_state is not None:
            torch.cuda.set_rng_state(cur_rng_state)
        return dqkv, None, None, None, None, None, None


# We duplicate code to return both the output and the softmax for testing
# Returning both makes backward a bit slower, so we want to keep using the other version for speed.
class FlashAttnFunWithS(torch.autograd.Function):

    @staticmethod
    def forward(ctx, qkv, cu_seqlens, dropout_p, max_s, softmax_scale, causal):
        # Save rng_state because the backward pass will regenerate the dropout mask
        rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
        if softmax_scale is None:
            softmax_scale = qkv.shape[-1] ** (-0.5)
        context, softmax_lse, S_dmask = _flash_attn_forward(
            qkv, cu_seqlens, dropout_p, max_s, softmax_scale, causal=causal, return_softmax=True
        )
        ctx.save_for_backward(qkv, context, S_dmask, softmax_lse, cu_seqlens, rng_state)
        ctx.dropout_p = dropout_p
        ctx.max_s = max_s
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        return context, S_dmask, softmax_lse

    @staticmethod
    def backward(ctx, dout, _dS_dmask_ignored, _dsoftmax_sum_ignored):
        qkv, context, S_dmask, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors
        if rng_state is not None:
            cur_rng_state = torch.cuda.get_rng_state()
            torch.cuda.set_rng_state(rng_state)
        dqkv = _flash_attn_backward(
            dout, qkv, context, S_dmask, softmax_lse, cu_seqlens, ctx.dropout_p,
            ctx.max_s, ctx.softmax_scale, ctx.causal
        )
        if rng_state is not None:
            torch.cuda.set_rng_state(cur_rng_state)
        return dqkv, None, None, None, None, None


def flash_attn_func(qkv, cu_seqlens, dropout_p, max_s, softmax_scale=None, causal=False,
                    return_attn_probs=False):
    """dropout_p should be set to 0.0 during evaluation.

    Arguments:
        qkv: packed query/key/value tensor.
        cu_seqlens: cumulative sequence lengths used to index into qkv.
        dropout_p: attention dropout probability.
        max_s: maximum sequence length in the batch.
        softmax_scale: scaling applied to QK^T before the softmax; defaults to
            qkv.shape[-1] ** (-0.5) when None.
        causal: whether to apply a causal attention mask.
        return_attn_probs: if True, also return S_dmask and softmax_lse
            (intended for testing; the backward pass is a bit slower).
    """
    func = FlashAttnFun if not return_attn_probs else FlashAttnFunWithS
    return func.apply(qkv, cu_seqlens, dropout_p, max_s, softmax_scale, causal)
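
# Example usage (a minimal sketch, not from the original file; the shapes assume the
# packed-QKV layout described above, with fp16 tensors on a CUDA device):
#
#     batch, seqlen, nheads, head_dim = 2, 128, 8, 64
#     qkv = torch.randn(batch * seqlen, 3, nheads, head_dim,
#                       device='cuda', dtype=torch.float16, requires_grad=True)
#     cu_seqlens = torch.arange(0, (batch + 1) * seqlen, step=seqlen,
#                               device='cuda', dtype=torch.int32)
#     out = flash_attn_func(qkv, cu_seqlens, dropout_p=0.1, max_s=seqlen, causal=True)
#     out.sum().backward()  # the gradient w.r.t. the packed qkv lands in qkv.grad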