Changes For FP8 (#1075)
* adding files for fp8 changes. * removed contiguous check. * enable all tests except odd-seq-lengths, where it crashes now. * undid clang formatting. * change to correct tile size for headdim=128. * fixed odd-seq-len-k. * minor formatting. * minor reformatting. --------- Co-authored-by: Tri Dao <tridao@users.noreply.github.com>
This commit is contained in:
parent
59594f2a67
commit
1899c970c8
281
hopper/benchmark_flash_attention.py
Normal file
281
hopper/benchmark_flash_attention.py
Normal file
@ -0,0 +1,281 @@
|
||||
# Install the newest triton version with
|
||||
# pip install "git+https://github.com/openai/triton.git#egg=triton&subdirectory=python"
|
||||
import pickle
|
||||
import math
|
||||
import time
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from einops import rearrange, repeat
|
||||
|
||||
from flash_attn.utils.benchmark import benchmark_all, benchmark_forward, benchmark_backward
|
||||
from flash_attn.utils.benchmark import benchmark_fwd_bwd, benchmark_combined
|
||||
|
||||
from flash_attn import flash_attn_qkvpacked_func
|
||||
from flash_attn_interface import flash_attn_func
|
||||
|
||||
try:
|
||||
from triton.ops.flash_attention import attention as attention_triton
|
||||
except ImportError:
|
||||
attention_triton = None
|
||||
|
||||
try:
|
||||
import xformers.ops as xops
|
||||
except ImportError:
|
||||
xops = None
|
||||
|
||||
try:
|
||||
import cudnn
|
||||
except ImportError:
|
||||
cudnn = None
|
||||
|
||||
|
||||
def flops(batch, seqlen, headdim, nheads, causal, mode="fwd"):
|
||||
assert mode in ["fwd", "bwd", "fwd_bwd"]
|
||||
f = 4 * batch * seqlen**2 * nheads * headdim // (2 if causal else 1)
|
||||
return f if mode == "fwd" else (2.5 * f if mode == "bwd" else 3.5 * f)
|
||||
|
||||
def efficiency(flop, time):
|
||||
return (flop / time / 10**12) if not math.isnan(time) else 0.0
|
||||
|
||||
|
||||
def convert_to_cudnn_type(torch_type):
|
||||
if torch_type == torch.float16:
|
||||
return cudnn.data_type.HALF
|
||||
elif torch_type == torch.bfloat16:
|
||||
return cudnn.data_type.BFLOAT16
|
||||
elif torch_type == torch.float32:
|
||||
return cudnn.data_type.FLOAT
|
||||
elif torch_type == torch.int32:
|
||||
return cudnn.data_type.INT32
|
||||
elif torch_type == torch.int64:
|
||||
return cudnn.data_type.INT64
|
||||
else:
|
||||
raise ValueError("Unsupported tensor data type.")
|
||||
|
||||
|
||||
def cudnn_spda_setup(q, k, v, causal=False):
|
||||
b, nheads, seqlen_q, headdim = q.shape
|
||||
_, _, seqlen_k, _ = k.shape
|
||||
assert v.shape == (b, nheads, seqlen_k, headdim)
|
||||
assert cudnn is not None, 'CUDNN is not available'
|
||||
q_gpu, k_gpu, v_gpu = q, k, v
|
||||
o_gpu = torch.empty_like(q_gpu)
|
||||
stats_gpu = torch.empty(b, nheads, seqlen_q, 1, dtype=torch.float32, device=q.device)
|
||||
graph = cudnn.pygraph(
|
||||
io_data_type=convert_to_cudnn_type(q.dtype),
|
||||
intermediate_data_type=cudnn.data_type.FLOAT,
|
||||
compute_data_type=cudnn.data_type.FLOAT,
|
||||
)
|
||||
q = graph.tensor_like(q_gpu.detach())
|
||||
k = graph.tensor_like(k_gpu.detach())
|
||||
v = graph.tensor_like(v_gpu.detach())
|
||||
|
||||
o, stats = graph.sdpa(
|
||||
name="sdpa",
|
||||
q=q,
|
||||
k=k,
|
||||
v=v,
|
||||
is_inference=False,
|
||||
attn_scale=1.0 / math.sqrt(headdim),
|
||||
use_causal_mask=causal,
|
||||
)
|
||||
|
||||
o.set_output(True).set_dim(o_gpu.shape).set_stride(o_gpu.stride())
|
||||
stats.set_output(True).set_data_type(cudnn.data_type.FLOAT)
|
||||
|
||||
graph.validate()
|
||||
graph.build_operation_graph()
|
||||
graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])
|
||||
graph.check_support()
|
||||
graph.build_plans()
|
||||
|
||||
variant_pack = {
|
||||
q: q_gpu,
|
||||
k: k_gpu,
|
||||
v: v_gpu,
|
||||
o: o_gpu,
|
||||
stats: stats_gpu,
|
||||
}
|
||||
|
||||
workspace = torch.empty(graph.get_workspace_size(), device="cuda", dtype=torch.uint8)
|
||||
|
||||
def run(*args, **kwargs):
|
||||
graph.execute(variant_pack, workspace)
|
||||
return o_gpu
|
||||
|
||||
return run
|
||||
|
||||
|
||||
def attention_pytorch(qkv, dropout_p=0.0, causal=True):
|
||||
"""
|
||||
Arguments:
|
||||
qkv: (batch_size, seqlen, 3, nheads, head_dim)
|
||||
dropout_p: float
|
||||
Output:
|
||||
output: (batch_size, seqlen, nheads, head_dim)
|
||||
"""
|
||||
batch_size, seqlen, _, nheads, d = qkv.shape
|
||||
q, k, v = qkv.unbind(dim=2)
|
||||
q = rearrange(q, 'b t h d -> (b h) t d')
|
||||
k = rearrange(k, 'b s h d -> (b h) d s')
|
||||
softmax_scale = 1.0 / math.sqrt(d)
|
||||
# Preallocate attn_weights for `baddbmm`
|
||||
scores = torch.empty(batch_size * nheads, seqlen, seqlen, dtype=qkv.dtype, device=qkv.device)
|
||||
scores = rearrange(torch.baddbmm(scores, q, k, beta=0, alpha=softmax_scale),
|
||||
'(b h) t s -> b h t s', h=nheads)
|
||||
if causal:
|
||||
# "triu_tril_cuda_template" not implemented for 'BFloat16'
|
||||
# So we have to construct the mask in float
|
||||
causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1)
|
||||
# TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
|
||||
scores = scores + causal_mask.to(dtype=scores.dtype)
|
||||
attention = torch.softmax(scores, dim=-1)
|
||||
attention_drop = F.dropout(attention, dropout_p)
|
||||
output = torch.einsum('bhts,bshd->bthd', attention_drop , v)
|
||||
return output.to(dtype=qkv.dtype)
|
||||
|
||||
|
||||
def time_fwd_bwd(func, *args, **kwargs):
|
||||
time.sleep(1) # Sleep to avoid residual power throttling from the previous benchmark
|
||||
time_f, time_b = benchmark_fwd_bwd(func, *args, **kwargs)
|
||||
return time_f[1].mean, time_b[1].mean
|
||||
|
||||
|
||||
repeats = 30
|
||||
device = 'cuda'
|
||||
dtype = torch.float16
|
||||
|
||||
# Ideally, seq-len should be divisible by 132 to avoid wave quantization.
|
||||
# However, the existing Triton implementation doesn't support seq-len like 8448.
|
||||
bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4096), (2, 8192)]
|
||||
# bs_seqlen_vals = [(2, 8192)]
|
||||
causal_vals = [False]
|
||||
# headdim_vals = [64, 128]
|
||||
headdim_vals = [128]
|
||||
dim = 128
|
||||
dropout_p = 0.0
|
||||
|
||||
methods = (["Flash2", "Pytorch", "Flash3"]
|
||||
+ (["Triton"] if attention_triton is not None else [])
|
||||
+ (["xformers.c"] if xops is not None else [])
|
||||
+ (["xformers.f"] if xops is not None else [])
|
||||
+ (["cudnn"] if cudnn is not None else []))
|
||||
|
||||
time_f = {}
|
||||
time_b = {}
|
||||
time_f_b = {}
|
||||
speed_f = {}
|
||||
speed_b = {}
|
||||
speed_f_b = {}
|
||||
for causal in causal_vals:
|
||||
for headdim in headdim_vals:
|
||||
for batch_size, seqlen in bs_seqlen_vals:
|
||||
config = (causal, headdim, batch_size, seqlen)
|
||||
nheads = dim // headdim
|
||||
qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim, device=device, dtype=dtype,
|
||||
requires_grad=True)
|
||||
f, b = time_fwd_bwd(
|
||||
flash_attn_qkvpacked_func, qkv, dropout_p, causal=causal, repeats=repeats, verbose=False
|
||||
)
|
||||
time_f[config, "Flash2"] = f
|
||||
time_b[config, "Flash2"] = b
|
||||
|
||||
try:
|
||||
qkv = qkv.detach().requires_grad_(True)
|
||||
f, b = time_fwd_bwd(
|
||||
attention_pytorch, qkv, dropout_p, causal=causal, repeats=repeats, verbose=False
|
||||
)
|
||||
res_baseline = attention_pytorch(qkv, dropout_p, causal=causal)
|
||||
except: # Skip if OOM
|
||||
f, b = float('nan'), float('nan')
|
||||
time_f[config, "Pytorch"] = f
|
||||
time_b[config, "Pytorch"] = b
|
||||
|
||||
q, k, v = [torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype,
|
||||
requires_grad=True) for _ in range(3)]
|
||||
f, b = time_fwd_bwd(flash_attn_func, q, k, v, causal=causal, repeats=repeats, verbose=False)
|
||||
res = flash_attn_func(q, k, v, causal=causal)
|
||||
|
||||
time_f[config, "Flash3"] = f
|
||||
time_b[config, "Flash3"] = b
|
||||
|
||||
if cudnn is not None:
|
||||
time.sleep(1) # Sleep to avoid residual power throttling from the previous benchmark
|
||||
res = benchmark_forward(
|
||||
cudnn_spda_setup(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), causal=causal),
|
||||
repeats=repeats, verbose=False
|
||||
)
|
||||
f = res[1].mean
|
||||
time_f[config, "cudnn"] = f
|
||||
time_b[config, "cudnn"] = math.inf
|
||||
|
||||
if attention_triton is not None:
|
||||
q, k, v = [torch.randn(batch_size, nheads, seqlen, headdim, device=device, dtype=dtype,
|
||||
requires_grad=True) for _ in range(3)]
|
||||
# Try both values of sequence_parallel and pick the faster one
|
||||
try:
|
||||
f, b = time_fwd_bwd(
|
||||
attention_triton, q, k, v, causal, headdim**(-0.5),
|
||||
False, repeats=repeats, verbose=False
|
||||
)
|
||||
except:
|
||||
f, b = float('nan'), float('inf')
|
||||
try:
|
||||
_, b0 = time_fwd_bwd(
|
||||
attention_triton, q, k, v, causal, headdim**(-0.5),
|
||||
True, repeats=repeats, verbose=False
|
||||
)
|
||||
except:
|
||||
b0 = float('inf')
|
||||
time_f[config, "Triton"] = f
|
||||
time_b[config, "Triton"] = min(b, b0) if min(b, b0) < float('inf') else float('nan')
|
||||
|
||||
if xops is not None:
|
||||
q, k, v = [torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype,
|
||||
requires_grad=True) for _ in range(3)]
|
||||
f, b = time_fwd_bwd(
|
||||
xops.memory_efficient_attention, q, k, v,
|
||||
attn_bias=xops.LowerTriangularMask() if causal else None,
|
||||
op=(xops.fmha.cutlass.FwOp, xops.fmha.cutlass.BwOp)
|
||||
)
|
||||
time_f[config, "xformers.c"] = f
|
||||
time_b[config, "xformers.c"] = b
|
||||
|
||||
if xops is not None:
|
||||
q, k, v = [torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype,
|
||||
requires_grad=True) for _ in range(3)]
|
||||
f, b = time_fwd_bwd(
|
||||
xops.memory_efficient_attention, q, k, v,
|
||||
attn_bias=xops.LowerTriangularMask() if causal else None,
|
||||
op=(xops.fmha.flash.FwOp, xops.fmha.flash.BwOp)
|
||||
)
|
||||
time_f[config, "xformers.f"] = f
|
||||
time_b[config, "xformers.f"] = b
|
||||
|
||||
print(f"### causal={causal}, headdim={headdim}, batch_size={batch_size}, seqlen={seqlen} ###")
|
||||
for method in methods:
|
||||
time_f_b[config, method] = time_f[config, method] + time_b[config, method]
|
||||
speed_f[config, method] = efficiency(
|
||||
flops(batch_size, seqlen, headdim, nheads, causal, mode="fwd"),
|
||||
time_f[config, method]
|
||||
)
|
||||
speed_b[config, method] = efficiency(
|
||||
flops(batch_size, seqlen, headdim, nheads, causal, mode="bwd"),
|
||||
time_b[config, method]
|
||||
)
|
||||
speed_f_b[config, method] = efficiency(
|
||||
flops(batch_size, seqlen, headdim, nheads, causal, mode="fwd_bwd"),
|
||||
time_f_b[config, method]
|
||||
)
|
||||
#print (time_f[config,method])
|
||||
print(
|
||||
f"{method} fwd: {speed_f[config, method]:.2f} TFLOPs/s, "
|
||||
f"bwd: {speed_b[config, method]:.2f} TFLOPs/s, "
|
||||
f"fwd + bwd: {speed_f_b[config, method]:.2f} TFLOPs/s"
|
||||
)
|
||||
|
||||
|
||||
# with open('flash2_attn_time.plk', 'wb') as fp:
|
||||
# pickle.dump((speed_f, speed_b, speed_f_b), fp, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
339
hopper/benchmark_flash_attention_fp8.py
Normal file
339
hopper/benchmark_flash_attention_fp8.py
Normal file
@ -0,0 +1,339 @@
|
||||
# Install the newest triton version with
|
||||
# pip install "git+https://github.com/openai/triton.git#egg=triton&subdirectory=python"
|
||||
import pickle
|
||||
import math
|
||||
import time
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from einops import rearrange, repeat
|
||||
|
||||
from flash_attn.utils.benchmark import benchmark_all, benchmark_forward, benchmark_backward
|
||||
from flash_attn.utils.benchmark import benchmark_fwd_bwd, benchmark_combined
|
||||
|
||||
from flash_attn import flash_attn_qkvpacked_func
|
||||
from flash_attn_interface import flash_attn_func
|
||||
|
||||
try:
|
||||
from triton_fused_attention import attention as attention_triton
|
||||
except ImportError:
|
||||
attention_triton = None
|
||||
|
||||
try:
|
||||
import xformers.ops as xops
|
||||
except ImportError:
|
||||
xops = None
|
||||
|
||||
try:
|
||||
import cudnn
|
||||
except ImportError:
|
||||
cudnn = None
|
||||
|
||||
|
||||
def convert_to_cudnn_type(torch_type):
|
||||
if torch_type == torch.float16:
|
||||
return cudnn.data_type.HALF
|
||||
elif torch_type == torch.bfloat16:
|
||||
return cudnn.data_type.BFLOAT16
|
||||
elif torch_type == torch.float32:
|
||||
return cudnn.data_type.FLOAT
|
||||
elif torch_type == torch.int32:
|
||||
return cudnn.data_type.INT32
|
||||
elif torch_type == torch.int64:
|
||||
return cudnn.data_type.INT64
|
||||
elif torch_type == torch.float8_e4m3fn:
|
||||
return cudnn.data_type.FP8_E4M3
|
||||
elif torch_type == torch.float8_e4m3fn:
|
||||
return cudnn.data_type.FP8_E5M2
|
||||
else:
|
||||
raise ValueError("Unsupported tensor data type.")
|
||||
|
||||
def cudnn_spda_setup(qkv, seqlen_q, seqlen_k, causal=False):
|
||||
b, _, _, nheads, headdim = qkv.shape
|
||||
assert cudnn is not None, 'CUDNN is not available'
|
||||
o_gpu = torch.zeros(b, seqlen_q, nheads, headdim, dtype=qkv.dtype, device=qkv.device)
|
||||
o_gpu_transposed = torch.as_strided(
|
||||
o_gpu,
|
||||
[b, nheads, seqlen_q, headdim],
|
||||
[nheads * seqlen_q * headdim, headdim, nheads * headdim, 1],
|
||||
)
|
||||
stats_gpu = torch.empty(b, nheads, seqlen_q, 1, dtype=torch.float32, device=qkv.device)
|
||||
amax_s_gpu = torch.empty(1, 1, 1, 1, dtype=torch.float32, device=qkv.device)
|
||||
amax_o_gpu = torch.empty(1, 1, 1, 1, dtype=torch.float32, device=qkv.device)
|
||||
graph = cudnn.pygraph(
|
||||
io_data_type=convert_to_cudnn_type(qkv.dtype),
|
||||
intermediate_data_type=cudnn.data_type.FLOAT,
|
||||
compute_data_type=cudnn.data_type.FLOAT,
|
||||
)
|
||||
new_q = torch.as_strided(
|
||||
qkv,
|
||||
[b, nheads, seqlen_q, headdim],
|
||||
[seqlen_q * nheads * headdim * 3, headdim, headdim * nheads * 3, 1],
|
||||
storage_offset=0,
|
||||
)
|
||||
q = graph.tensor(
|
||||
name = "Q",
|
||||
dim = list(new_q.shape),
|
||||
stride = list(new_q.stride()),
|
||||
data_type=convert_to_cudnn_type(qkv.dtype)
|
||||
)
|
||||
new_k = torch.as_strided(
|
||||
qkv,
|
||||
[b, nheads, seqlen_k, headdim],
|
||||
[seqlen_k * nheads * headdim * 3, headdim, headdim * nheads * 3, 1],
|
||||
storage_offset=nheads * headdim,
|
||||
)
|
||||
k = graph.tensor(
|
||||
name = "K",
|
||||
dim = list(new_k.shape),
|
||||
stride = list(new_k.stride()),
|
||||
data_type=convert_to_cudnn_type(qkv.dtype)
|
||||
)
|
||||
new_v = torch.as_strided(
|
||||
qkv,
|
||||
[b, nheads, seqlen_k, headdim],
|
||||
[seqlen_k * nheads * headdim * 3, headdim, headdim * nheads * 3, 1],
|
||||
storage_offset=nheads * headdim * 2,
|
||||
)
|
||||
v = graph.tensor(
|
||||
name = "V",
|
||||
dim = list(new_v.shape),
|
||||
stride = list(new_v.stride()),
|
||||
data_type=convert_to_cudnn_type(qkv.dtype)
|
||||
)
|
||||
|
||||
def get_default_scale_tensor():
|
||||
return graph.tensor(
|
||||
dim = [1, 1, 1, 1],
|
||||
stride = [1, 1, 1, 1],
|
||||
data_type=cudnn.data_type.FLOAT
|
||||
)
|
||||
|
||||
default_scale_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float32, device="cuda")
|
||||
descale_q = get_default_scale_tensor()
|
||||
descale_k = get_default_scale_tensor()
|
||||
descale_v = get_default_scale_tensor()
|
||||
descale_s = get_default_scale_tensor()
|
||||
scale_s = get_default_scale_tensor()
|
||||
scale_o = get_default_scale_tensor()
|
||||
|
||||
o, _, amax_s, amax_o = graph.sdpa_fp8(
|
||||
q=q,
|
||||
k=k,
|
||||
v=v,
|
||||
descale_q=descale_q,
|
||||
descale_k=descale_k,
|
||||
descale_v=descale_v,
|
||||
descale_s=descale_s,
|
||||
scale_s=scale_s,
|
||||
scale_o=scale_o,
|
||||
is_inference=True,
|
||||
attn_scale=1.0 / math.sqrt(headdim),
|
||||
use_causal_mask=causal,
|
||||
name="sdpa",
|
||||
)
|
||||
|
||||
o.set_output(True).set_dim(o_gpu_transposed.shape).set_stride(o_gpu_transposed.stride())
|
||||
|
||||
amax_s.set_output(False).set_dim(amax_s_gpu.shape).set_stride(amax_s_gpu.stride())
|
||||
amax_o.set_output(False).set_dim(amax_o_gpu.shape).set_stride(amax_o_gpu.stride())
|
||||
# stats.set_output(True).set_data_type(cudnn.data_type.FLOAT)
|
||||
|
||||
graph.validate()
|
||||
graph.build_operation_graph()
|
||||
graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])
|
||||
graph.check_support()
|
||||
graph.build_plans()
|
||||
|
||||
variant_pack = {
|
||||
q: new_q,
|
||||
k: new_k,
|
||||
v: new_v,
|
||||
descale_q: default_scale_gpu,
|
||||
descale_k: default_scale_gpu,
|
||||
descale_v: default_scale_gpu,
|
||||
descale_s: default_scale_gpu,
|
||||
scale_s: default_scale_gpu,
|
||||
scale_o: default_scale_gpu,
|
||||
o: o_gpu_transposed,
|
||||
amax_s: amax_s_gpu,
|
||||
amax_o: amax_o_gpu,
|
||||
}
|
||||
|
||||
workspace = torch.empty(graph.get_workspace_size(), device="cuda", dtype=torch.uint8)
|
||||
|
||||
def run(*args, **kwargs):
|
||||
graph.execute(variant_pack, workspace)
|
||||
return o_gpu, amax_o_gpu
|
||||
|
||||
return run
|
||||
|
||||
|
||||
def attention_pytorch(qkv, dropout_p=0.0, causal=True):
|
||||
"""
|
||||
Arguments:
|
||||
qkv: (batch_size, seqlen, 3, nheads, head_dim)
|
||||
dropout_p: float
|
||||
Output:
|
||||
output: (batch_size, seqlen, nheads, head_dim)
|
||||
"""
|
||||
batch_size, seqlen, _, nheads, d = qkv.shape
|
||||
q, k, v = qkv.unbind(dim=2)
|
||||
q = rearrange(q, 'b t h d -> (b h) t d')
|
||||
k = rearrange(k, 'b s h d -> (b h) d s')
|
||||
softmax_scale = 1.0 / math.sqrt(d)
|
||||
# Preallocate attn_weights for `baddbmm`
|
||||
scores = torch.empty(batch_size * nheads, seqlen, seqlen, dtype=qkv.dtype, device=qkv.device)
|
||||
scores = rearrange(torch.baddbmm(scores, q, k, beta=0, alpha=softmax_scale),
|
||||
'(b h) t s -> b h t s', h=nheads)
|
||||
if causal:
|
||||
# "triu_tril_cuda_template" not implemented for 'BFloat16'
|
||||
# So we have to construct the mask in float
|
||||
causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1)
|
||||
# TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
|
||||
scores = scores + causal_mask.to(dtype=scores.dtype)
|
||||
attention = torch.softmax(scores, dim=-1)
|
||||
attention_drop = F.dropout(attention, dropout_p)
|
||||
output = torch.einsum('bhts,bshd->bthd', attention_drop , v)
|
||||
return output.to(dtype=qkv.dtype)
|
||||
|
||||
def flops(batch, seqlen, headdim, nheads, causal, mode="fwd"):
|
||||
assert mode in ["fwd", "bwd", "fwd_bwd"]
|
||||
f = 4 * batch * seqlen**2 * nheads * headdim // (2 if causal else 1)
|
||||
return f if mode == "fwd" else (2.5 * f if mode == "bwd" else 3.5 * f)
|
||||
|
||||
def efficiency(flop, time):
|
||||
return (flop / time / 10**12) if not math.isnan(time) else 0.0
|
||||
|
||||
def time_fwd(func, *args, **kwargs):
|
||||
time.sleep(1) # Sleep to avoid residual power throttling from the previous benchmark
|
||||
time_f = benchmark_forward(func, *args, **kwargs)
|
||||
return time_f[1].mean
|
||||
|
||||
|
||||
torch.manual_seed(0)
|
||||
|
||||
repeats = 30
|
||||
device = 'cuda'
|
||||
# dtype = torch.float16
|
||||
dtype = torch.float8_e4m3fn
|
||||
|
||||
#bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4224), (2, 8448), (1, 8448 * 2)]
|
||||
bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4096), (2, 8192), (1, 8192 * 2)]
|
||||
#bs_seqlen_vals = [(4, 4224), (2, 8448), (1, 8448 * 2)]
|
||||
# bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4096), (2, 8192), (1, 8192 * 2)]
|
||||
# bs_seqlen_vals = [(4, 8448)]
|
||||
causal_vals = [False, True]
|
||||
#headdim_vals = [64, 128, 256]
|
||||
headdim_vals = [128,256]
|
||||
dim = 2048
|
||||
# dim = 128
|
||||
dropout_p = 0.0
|
||||
|
||||
methods = (["Pytorch","Flash3", "cuDNN"]
|
||||
+ (["Triton"] if attention_triton is not None else [])
|
||||
# + (["xformers.c"] if xops is not None else [])
|
||||
# + (["xformers.f"] if xops is not None else [])
|
||||
)
|
||||
|
||||
time_f = {}
|
||||
time_b = {}
|
||||
time_f_b = {}
|
||||
speed_f = {}
|
||||
speed_b = {}
|
||||
speed_f_b = {}
|
||||
for causal in causal_vals:
|
||||
for headdim in headdim_vals:
|
||||
for batch_size, seqlen in bs_seqlen_vals:
|
||||
torch.cuda.empty_cache()
|
||||
config = (causal, headdim, batch_size, seqlen)
|
||||
nheads = dim // headdim
|
||||
q, k, v = [torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=torch.float16,
|
||||
requires_grad=False) for _ in range(3)]
|
||||
qkv = torch.stack([q, k, v], dim=2)
|
||||
qkv = qkv.to(torch.float16)
|
||||
|
||||
f = time_fwd(attention_pytorch, qkv, dropout_p, causal=causal, repeats=repeats, verbose=False)
|
||||
time_f[config, "Pytorch"] = f
|
||||
res_baseline = attention_pytorch(qkv, dropout_p, causal=causal)
|
||||
|
||||
if attention_triton is not None:
|
||||
q_transposed = q.transpose(1, 2).contiguous().to(torch.float8_e4m3fn)
|
||||
k_transposed = k.transpose(1, 2).contiguous().to(torch.float8_e4m3fn)
|
||||
v_transposed = v.transpose(1, 2).contiguous().permute(0, 1, 3, 2).to(torch.float8_e4m3fn)
|
||||
scale = 1 / math.sqrt(headdim)
|
||||
f = time_fwd(
|
||||
attention_triton, q_transposed, k_transposed, v_transposed,
|
||||
causal, scale, repeats=5, verbose=False, desc='Triton'
|
||||
)
|
||||
f = time_fwd(
|
||||
attention_triton, q_transposed, k_transposed, v_transposed,
|
||||
causal, scale, repeats=repeats, verbose=False, desc='Triton'
|
||||
)
|
||||
time_f[config, "Triton"] = f
|
||||
res = attention_triton(
|
||||
q_transposed, k_transposed, v_transposed.permute(0, 1, 3, 2),
|
||||
causal, scale
|
||||
).half().transpose(1, 2)
|
||||
torch.testing.assert_close(res, res_baseline, atol=0.5, rtol=0.5)
|
||||
|
||||
out = torch.empty_like(q)
|
||||
q, k, v = q.to(dtype), k.to(dtype), v.to(dtype)
|
||||
|
||||
v_transposed = v.transpose(1,3).contiguous().clone()
|
||||
#v_transposed = v.transpose(1,3).clone()
|
||||
time.sleep(1)
|
||||
f = time_fwd(flash_attn_func, q, k, v_transposed, causal=causal, repeats=repeats, verbose=False)
|
||||
# res = flash_attn_func(q, k, v, causal=causal, is_fp16_acc=False)
|
||||
# torch.testing.assert_close(res.half(), res_baseline, atol=0.05, rtol=0.05)
|
||||
|
||||
time_f[config, "Flash3"] = f
|
||||
|
||||
if cudnn is not None:
|
||||
qkv_fp8 = qkv.to(dtype)
|
||||
time.sleep(1) # Sleep to avoid residual power throttling from the previous benchmark
|
||||
f = time_fwd(
|
||||
cudnn_spda_setup(
|
||||
qkv_fp8, seqlen, seqlen,
|
||||
causal=causal
|
||||
),
|
||||
repeats=repeats, verbose=False
|
||||
)
|
||||
time_f[config, "cuDNN"] = f
|
||||
# res, amax_o = cudnn_spda_setup(
|
||||
# qkv_fp8, seqlen, seqlen,
|
||||
# causal=causal
|
||||
# )()
|
||||
# res = res.half()
|
||||
# TODO: CUDNN has numerics issues when
|
||||
# num_heads=16, dim=128, seq_len=1024, batch_size=2
|
||||
# or larger sizes.
|
||||
# res_cpu = res.cpu().reshape(-1)
|
||||
# res_baseline_cpu = res_baseline.cpu().reshape(-1)
|
||||
# print(amax_o)
|
||||
# print(res)
|
||||
# print(res_baseline)
|
||||
# for i in range(len(res_cpu)):
|
||||
# item = res_cpu[i]
|
||||
# item_baseline = res_baseline_cpu[i]
|
||||
# if abs(item - item_baseline) > 0.5:
|
||||
# print(i)
|
||||
# print(item)
|
||||
# print(item_baseline)
|
||||
# torch.testing.assert_close(res, res_baseline, atol=0.05, rtol=0.05)
|
||||
|
||||
print(f"### causal={causal}, headdim={headdim}, batch_size={batch_size}, seqlen={seqlen} ###")
|
||||
for method in methods:
|
||||
speed_f[config, method] = efficiency(
|
||||
flops(batch_size, seqlen, headdim, nheads, causal, mode="fwd"),
|
||||
time_f[config, method]
|
||||
)
|
||||
#print (time_f[config,method])
|
||||
print(
|
||||
f"{method} fwd: {speed_f[config, method]:.2f} TFLOPs/s, {time_f[config, method] * 1e3} ms, "
|
||||
)
|
||||
|
||||
|
||||
# with open('flash3_attn_time.plk', 'wb') as fp:
|
||||
# pickle.dump((time_f, time_b, time_f_b), fp, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
@ -20,7 +20,8 @@ using namespace cute;
|
||||
template <typename Ktraits, typename Seqlen_traits>
|
||||
struct CollectiveEpilogueFwd {
|
||||
|
||||
using Element = typename Ktraits::Element;
|
||||
using PrecType = typename Ktraits::Element;
|
||||
using Element = decltype(cute::conditional_return<is_same_v<PrecType, cutlass::float_e4m3_t>>(cutlass::half_t{}, PrecType{}));
|
||||
static constexpr int kBlockM = Ktraits::kBlockM;
|
||||
static constexpr int kBlockN = Ktraits::kBlockN;
|
||||
static constexpr int kHeadDim = Ktraits::kHeadDim;
|
||||
|
||||
@ -249,7 +249,13 @@ void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream, bool force_split
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// run_mha_fwd_<cutlass::float_e4m3_t, 128>(params, stream);
|
||||
if (params.d == 64) {
|
||||
run_mha_fwd_<cutlass::float_e4m3_t, 64>(params, stream);
|
||||
} else if (params.d == 128) {
|
||||
run_mha_fwd_<cutlass::float_e4m3_t, 128>(params, stream);
|
||||
} else {
|
||||
run_mha_fwd_<cutlass::float_e4m3_t, 256>(params, stream);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -266,12 +272,8 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
|
||||
TORCH_CHECK(is_sm90, "FlashAttention only supports Hopper GPUs or newer.");
|
||||
|
||||
auto q_dtype = q.dtype();
|
||||
TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
|
||||
"FlashAttention only support fp16 and bf16 data type for now");
|
||||
// TODO: will add e4m3 later
|
||||
// TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kFloat8_e4m3fn,
|
||||
// "FlashAttention only support fp16 and bf16 data type");
|
||||
// "FlashAttention only support fp16 and fp8 (e4m3) data type for now");
|
||||
TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16 || q_dtype == torch::kFloat8_e4m3fn,
|
||||
"FlashAttention only support fp16, bf16 and fp8 (e4m3) data type for now");
|
||||
TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
|
||||
TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
|
||||
|
||||
@ -301,29 +303,50 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
|
||||
|
||||
CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og);
|
||||
CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size_og);
|
||||
if (q_dtype == torch::kFloat8_e4m3fn) {
|
||||
CHECK_SHAPE(v, batch_size, head_size_og, num_heads_k, seqlen_k);
|
||||
} else {
|
||||
CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size_og);
|
||||
}
|
||||
|
||||
at::Tensor q_padded, k_padded, v_padded;
|
||||
if (head_size_og % 8 != 0) {
|
||||
if (q_dtype == torch::kFloat8_e4m3fn)
|
||||
{
|
||||
if (head_size_og % 16 != 0) {
|
||||
q_padded = torch::nn::functional::pad(q, torch::nn::functional::PadFuncOptions({0, 16 - head_size_og % 16}));
|
||||
k_padded = torch::nn::functional::pad(k, torch::nn::functional::PadFuncOptions({0, 16 - head_size_og % 16}));
|
||||
} else {
|
||||
q_padded = q;
|
||||
k_padded = k;
|
||||
}
|
||||
if (seqlen_k % 16 != 0) {
|
||||
v_padded = torch::nn::functional::pad(v, torch::nn::functional::PadFuncOptions({0, 16 - seqlen_k % 16}));
|
||||
} else {
|
||||
v_padded = v;
|
||||
}
|
||||
}
|
||||
else {
|
||||
if (head_size_og % 8 != 0) {
|
||||
q_padded = torch::nn::functional::pad(q, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
|
||||
k_padded = torch::nn::functional::pad(k, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
|
||||
v_padded = torch::nn::functional::pad(v, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
|
||||
} else {
|
||||
} else {
|
||||
q_padded = q;
|
||||
k_padded = k;
|
||||
v_padded = v;
|
||||
}
|
||||
}
|
||||
|
||||
at::Tensor out;
|
||||
if (out_.has_value()) {
|
||||
out = out_.value();
|
||||
TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
|
||||
//TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
|
||||
CHECK_DEVICE(out);
|
||||
TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
|
||||
CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_og);
|
||||
if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); }
|
||||
} else {
|
||||
out = torch::empty_like(q_padded);
|
||||
out = q_dtype == torch::kFloat8_e4m3fn ? torch::empty_like(q_padded, at::kHalf) : torch::empty_like(q_padded);
|
||||
}
|
||||
|
||||
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
|
||||
|
||||
@ -15,7 +15,7 @@ def maybe_contiguous(x):
|
||||
return x.contiguous() if x is not None and x.stride(-1) != 1 else x
|
||||
|
||||
def _flash_attn_forward(q, k, v, softmax_scale, causal):
|
||||
q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
|
||||
# q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
|
||||
out, q, k, v, out_padded, softmax_lse, S_dmask = flashattn_hopper_cuda.fwd(
|
||||
q,
|
||||
k,
|
||||
@ -41,7 +41,7 @@ def _flash_attn_backward(
|
||||
causal
|
||||
):
|
||||
# dq, dk, dv are allocated by us so they should already be contiguous
|
||||
dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
|
||||
#dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
|
||||
dq, dk, dv, softmax_d, = flashattn_hopper_cuda.bwd(
|
||||
dout,
|
||||
q,
|
||||
|
||||
9
hopper/flash_fwd_hdim128_fp8_sm90.cu
Normal file
9
hopper/flash_fwd_hdim128_fp8_sm90.cu
Normal file
@ -0,0 +1,9 @@
|
||||
// Copyright (c) 2024, Tri Dao.
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
|
||||
#include "flash_fwd_launch_template.h"
|
||||
|
||||
template<>
|
||||
void run_mha_fwd_<cutlass::float_e4m3_t, 128>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
run_mha_fwd_hdim128<cutlass::float_e4m3_t>(params, stream);
|
||||
}
|
||||
9
hopper/flash_fwd_hdim256_fp8_sm90.cu
Normal file
9
hopper/flash_fwd_hdim256_fp8_sm90.cu
Normal file
@ -0,0 +1,9 @@
|
||||
// Copyright (c) 2024, Tri Dao.
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
|
||||
#include "flash_fwd_launch_template.h"
|
||||
|
||||
template<>
|
||||
void run_mha_fwd_<cutlass::float_e4m3_t, 256>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
run_mha_fwd_hdim256<cutlass::float_e4m3_t>(params, stream);
|
||||
}
|
||||
9
hopper/flash_fwd_hdim64_fp8_sm90.cu
Normal file
9
hopper/flash_fwd_hdim64_fp8_sm90.cu
Normal file
@ -0,0 +1,9 @@
|
||||
// Copyright (c) 2024, Tri Dao.
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
|
||||
#include "flash_fwd_launch_template.h"
|
||||
|
||||
template<>
|
||||
void run_mha_fwd_<cutlass::float_e4m3_t, 64>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
run_mha_fwd_hdim64<cutlass::float_e4m3_t>(params, stream);
|
||||
}
|
||||
@ -21,6 +21,7 @@
|
||||
template<typename Kernel_traits, bool Is_causal, typename Seqlen_traits>
|
||||
void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
using Element = typename Kernel_traits::Element;
|
||||
using ElementO = decltype(cute::conditional_return<is_same_v<Element, cutlass::float_e4m3_t>>(cutlass::half_t{}, Element{}));
|
||||
using TileShape_MNK = typename Kernel_traits::TileShape_MNK;
|
||||
using ClusterShape = typename Kernel_traits::ClusterShape_MNK;
|
||||
|
||||
@ -127,10 +128,14 @@ void run_mha_fwd_hdim128(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
SEQLEN_SWITCH(params.cu_seqlens_q, Seqlen_traits, [&] {
|
||||
// Only use Cluster if number of tiles along seqlen_q is even and not Is_causal
|
||||
BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 128) % 2 == 0 && !Is_causal && !Seqlen_traits::kUseVarSeqLen, UseCluster, [&] {
|
||||
run_flash_fwd<
|
||||
Flash_fwd_kernel_traits<Headdim, 128, Is_causal ? 128 : 176, 12, 2, false, UseCluster ? 2 : 1, T>,
|
||||
Is_causal, Seqlen_traits
|
||||
>(params, stream);
|
||||
if constexpr (is_same_v<T, cutlass::float_e4m3_t>) {
|
||||
//run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 192, 128, 16, 3, false, !Is_causal && UseCluster ? 2 : 1, T>, Is_causal>(params, stream);
|
||||
//run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 192, 128, 16, 2, false, !Is_causal && UseCluster ? 2 : 1, T>, Is_causal>(params, stream);
|
||||
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 12, 4, false, !Is_causal && UseCluster ? 2 : 1, T>, Is_causal, Seqlen_traits>(params, stream);
|
||||
//run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 12, 4, false, !Is_causal && UseCluster ? 2 : 1, T>, Is_causal>(params, stream);
|
||||
} else {
|
||||
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, Is_causal ? 128 : 176, 12, 2, false, UseCluster ? 2 : 1, T>, Is_causal, Seqlen_traits>(params, stream);
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
@ -143,10 +148,11 @@ void run_mha_fwd_hdim256(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
SEQLEN_SWITCH(params.cu_seqlens_q, Seqlen_traits, [&] {
|
||||
// Only use Cluster if number of tiles along seqlen_q is even
|
||||
BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 128) % 2 == 0 && !Is_causal && !Seqlen_traits::kUseVarSeqLen, UseCluster, [&] {
|
||||
run_flash_fwd<
|
||||
Flash_fwd_kernel_traits<Headdim, 128, 80, 12, 2, false, UseCluster ? 2 : 1, T>,
|
||||
Is_causal, Seqlen_traits
|
||||
>(params, stream);
|
||||
if constexpr (is_same_v<T, cutlass::float_e4m3_t>) {
|
||||
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 12, 3, false, !Is_causal && UseCluster ? 2 : 1, T>, Is_causal, Seqlen_traits>(params, stream);
|
||||
} else {
|
||||
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 80, 12, 2, false, UseCluster ? 2 : 1, T>, Is_causal, Seqlen_traits>(params, stream);
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@ -25,6 +25,7 @@ struct SharedStorageQKVO {
|
||||
cute::array_aligned<OutputType, cute::cosize_v<SmemLayoutO>> smem_o;
|
||||
};
|
||||
struct {
|
||||
cute::uint64_t tma_load_mbar[4]; // 4 TMA barriers pre-allocated for usage.
|
||||
cutlass::arch::ClusterTransactionBarrier barrier_Q;
|
||||
cutlass::arch::ClusterBarrier barrier_O;
|
||||
typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_k;
|
||||
@ -40,6 +41,7 @@ struct Flash_fwd_kernel_traits {
|
||||
using Element = elem_type;
|
||||
using ElementAccum = float;
|
||||
using index_t = int64_t;
|
||||
using ElementO = decltype(cute::conditional_return<is_same_v<Element, cutlass::float_e4m3_t>>(cutlass::half_t{}, Element{}));
|
||||
|
||||
// The number of threads.
|
||||
static constexpr int kNWarps = kNWarps_;
|
||||
@ -69,9 +71,11 @@ struct Flash_fwd_kernel_traits {
|
||||
decltype(cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShape_MNK>())
|
||||
>{},
|
||||
AtomLayoutMNK{}));
|
||||
|
||||
using TiledMma1 = decltype(cute::make_tiled_mma(
|
||||
cute::GMMA::rs_op_selector<Element, Element, ElementAccum, decltype(select<0, 2, 1>(TileShape_MNK{})),
|
||||
GMMA::Major::K, GMMA::Major::MN>(),
|
||||
GMMA::Major::K, cute::conditional_return<is_same_v<Element, cutlass::float_e4m3_t>>(
|
||||
GMMA::Major::K, GMMA::Major::MN)>(),
|
||||
AtomLayoutMNK{}));
|
||||
|
||||
using SmemLayoutAtomQ = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
|
||||
@ -84,19 +88,33 @@ struct Flash_fwd_kernel_traits {
|
||||
decltype(tile_to_shape(SmemLayoutAtomK{},
|
||||
make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));
|
||||
|
||||
using SmemLayoutAtomV = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
|
||||
using SmemLayoutAtomVFp16 = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
|
||||
decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
|
||||
using SmemLayoutV =
|
||||
decltype(tile_to_shape(SmemLayoutAtomV{},
|
||||
using SmemLayoutVFp16 =
|
||||
decltype(tile_to_shape(SmemLayoutAtomVFp16{},
|
||||
make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));
|
||||
|
||||
using SmemLayoutAtomO = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
|
||||
using SmemLayoutAtomVFp8 = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
|
||||
decltype(cute::get<2>(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>());
|
||||
using SmemLayoutVFp8 =
|
||||
decltype(tile_to_shape(SmemLayoutAtomVFp8{},
|
||||
make_shape(shape<2>(TileShape_MNK{}), shape<1>(TileShape_MNK{}), Int<kStages>{})));
|
||||
using SmemLayoutV = decltype(cute::conditional_return<is_same_v<Element, cutlass::float_e4m3_t>>(SmemLayoutVFp8{}, SmemLayoutVFp16{}));
|
||||
// Note this is the transpose in terms of the view, not in terms of memory.
|
||||
using SmemLayoutVtFp16 =
|
||||
decltype(cute::composition(SmemLayoutVFp16{},
|
||||
make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{}), Int<kStages>{}),
|
||||
make_stride(get<1>(TileShape_MNK{}), _1{}, Int<size(SmemLayoutVFp16{}(_, _, _0{}))>{}))));
|
||||
|
||||
using SmemLayoutVt = decltype(cute::conditional_return<is_same_v<Element, cutlass::float_e4m3_t>>(SmemLayoutVFp8{}, SmemLayoutVtFp16{}));
|
||||
|
||||
using SmemLayoutAtomO = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, ElementO,
|
||||
decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
|
||||
using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 2>(TileShape_MNK{})));
|
||||
|
||||
using SmemCopyAtomQ = Copy_Atom<cute::SM75_U32x4_LDSM_N, Element>;
|
||||
using SmemCopyAtomQ = Copy_Atom<cute::SM75_U32x4_LDSM_N, ElementO>;
|
||||
|
||||
using SharedStorage = SharedStorageQKVO<kStages, Element, Element, Element, SmemLayoutQ,
|
||||
using SharedStorage = SharedStorageQKVO<kStages, Element, Element, ElementO, SmemLayoutQ,
|
||||
SmemLayoutK, SmemLayoutV, SmemLayoutO>;
|
||||
|
||||
using MainloopPipeline = typename cutlass::PipelineTmaAsync<kStages>;
|
||||
|
||||
@ -43,12 +43,30 @@ struct CollectiveMainloopFwd {
|
||||
using SmemLayoutK =
|
||||
decltype(tile_to_shape(SmemLayoutAtomK{},
|
||||
make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));
|
||||
using SmemLayoutV = SmemLayoutK;
|
||||
|
||||
using SmemLayoutAtomVFp8 = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
|
||||
decltype(cute::get<2>(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>());
|
||||
using SmemLayoutVFp8 =
|
||||
decltype(tile_to_shape(SmemLayoutAtomVFp8{},
|
||||
make_shape(shape<2>(TileShape_MNK{}), shape<1>(TileShape_MNK{}), Int<kStages>{})));
|
||||
|
||||
using SmemLayoutVFp16 = SmemLayoutK;
|
||||
// Note this is the transpose in terms of the view, not in terms of memory.
|
||||
using SmemLayoutVt =
|
||||
decltype(cute::composition(SmemLayoutV{},
|
||||
using SmemLayoutVtFp16 =
|
||||
decltype(cute::composition(SmemLayoutVFp16{},
|
||||
make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{}), Int<kStages>{}),
|
||||
make_stride(get<1>(TileShape_MNK{}), _1{}, Int<size(SmemLayoutV{}(_, _, _0{}))>{}))));
|
||||
make_stride(get<1>(TileShape_MNK{}), _1{}, Int<size(SmemLayoutVFp16{}(_, _, _0{}))>{}))));
|
||||
|
||||
using SmemLayoutV = decltype(cute::conditional_return<is_same_v<Element, cutlass::float_e4m3_t>>(SmemLayoutVFp8{}, SmemLayoutVFp16{}));
|
||||
using SmemLayoutVt = decltype(cute::conditional_return<is_same_v<Element, cutlass::float_e4m3_t>>(SmemLayoutVFp8{}, SmemLayoutVtFp16{}));
|
||||
|
||||
// Dummy S layout for getting the shape for GEMM-II.
|
||||
using SmemLayoutAtomS = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
|
||||
decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>());
|
||||
using SmemLayoutS =
|
||||
decltype(tile_to_shape(SmemLayoutAtomS{},
|
||||
make_shape(shape<0>(TileShape_MNK{}), shape<1>(TileShape_MNK{}))));
|
||||
|
||||
// using SmemLayoutAtomVt = cute::GMMA::Layout_MN_SW128_Atom<Element>;
|
||||
// using SmemLayoutVt =
|
||||
// decltype(tile_to_shape(SmemLayoutAtomVt{},
|
||||
@ -85,6 +103,19 @@ struct CollectiveMainloopFwd {
|
||||
take<0, 2>(SmemLayoutK{}),
|
||||
select<1, 2>(TileShape_MNK{}),
|
||||
size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any
|
||||
//
|
||||
using TileShapeVFP8 = decltype(make_shape(cute::get<2>(TileShape_MNK{}), cute::get<1>(TileShape_MNK{})));
|
||||
using TileShapeVFP16 = decltype(make_shape(cute::get<1>(TileShape_MNK{}), cute::get<2>(TileShape_MNK{})));
|
||||
using TileShapeV = decltype(cute::conditional_return<is_same_v<Element, cutlass::float_e4m3_t>>(TileShapeVFP8{}, TileShapeVFP16{}));
|
||||
using TMA_VFP8 = decltype(make_tma_copy(
|
||||
GmemTiledCopyKV{},
|
||||
make_tensor(make_gmem_ptr(static_cast<Element const*>(nullptr)), repeat_like(StrideQKV{}, int32_t(0)), StrideQKV{}),
|
||||
take<0, 2>(SmemLayoutV{}),
|
||||
TileShapeV{},
|
||||
size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any
|
||||
|
||||
using TMA_V = decltype(cute::conditional_return<is_same_v<Element, cutlass::float_e4m3_t>>(TMA_VFP8{}, TMA_KV{}));
|
||||
|
||||
|
||||
static constexpr int NumMmaThreads = size(typename Ktraits::TiledMma0{});
|
||||
using MainloopPipeline = typename Ktraits::MainloopPipeline;
|
||||
@ -97,6 +128,7 @@ struct CollectiveMainloopFwd {
|
||||
|
||||
static constexpr bool UseSchedulerBarrier = kHeadDim <= 128;
|
||||
|
||||
|
||||
// Host side kernel arguments
|
||||
struct Arguments {
|
||||
Element const* ptr_Q;
|
||||
@ -115,7 +147,8 @@ struct CollectiveMainloopFwd {
|
||||
typename Seqlen_traits::LayoutT layout_V;
|
||||
cutlass::FastDivmod qhead_per_khead_divmod;
|
||||
TMA_Q tma_load_Q;
|
||||
TMA_KV tma_load_K, tma_load_V;
|
||||
TMA_KV tma_load_K;
|
||||
TMA_V tma_load_V;
|
||||
float const softmax_scale_log2;
|
||||
};
|
||||
|
||||
@ -136,12 +169,15 @@ struct CollectiveMainloopFwd {
|
||||
SmemLayoutK{}(_, _, _0{}),
|
||||
select<1, 2>(TileShape_MNK{}),
|
||||
size<0>(ClusterShape{})); // mcast along M mode for this N load, if any
|
||||
Tensor mV = make_tensor(make_gmem_ptr(args.ptr_V), args.layout_V);
|
||||
TMA_KV tma_load_V = make_tma_copy(
|
||||
auto gmemLayoutVFp16 = args.shape_K;
|
||||
auto gmemLayoutVFp8 = select<1, 0, 2, 3>(gmemLayoutVFp16);
|
||||
auto gmemLayoutV = cute::conditional_return<is_same_v<Element, cutlass::float_e4m3_t>>(gmemLayoutVFp8, gmemLayoutVFp16);
|
||||
Tensor mV = make_tensor(make_gmem_ptr(args.ptr_V), gmemLayoutV, args.layout_V.stride());
|
||||
TMA_V tma_load_V = make_tma_copy(
|
||||
GmemTiledCopyKV{},
|
||||
mV,
|
||||
SmemLayoutV{}(_, _, _0{}),
|
||||
select<1, 2>(TileShape_MNK{}),
|
||||
cute::conditional_return<is_same_v<Element, cutlass::float_e4m3_t>>(select<2, 1>(TileShape_MNK{}), select<1, 2>(TileShape_MNK{})),
|
||||
size<0>(ClusterShape{})); // mcast along M mode for this N load, if any
|
||||
return {args.layout_Q, args.layout_K, args.layout_V,
|
||||
cutlass::FastDivmod(cute::ceil_div(get<2>(args.layout_Q.shape()), get<2>(args.layout_K.shape()))),
|
||||
@ -198,7 +234,10 @@ struct CollectiveMainloopFwd {
|
||||
|
||||
Tensor mQ = mainloop_params.tma_load_Q.get_tma_tensor(mainloop_params.layout_Q.shape());
|
||||
Tensor mK = mainloop_params.tma_load_K.get_tma_tensor(mainloop_params.layout_K.shape());
|
||||
Tensor mV = mainloop_params.tma_load_V.get_tma_tensor(mainloop_params.layout_V.shape());
|
||||
auto gmemLayoutVFp16 = mainloop_params.shape_K;
|
||||
auto gmemLayoutVFp8 = select<1, 0, 2, 3>(gmemLayoutVFp16);
|
||||
auto gmemLayoutV = cute::conditional_return<is_same_v<Element, cutlass::float_e4m3_t>>(gmemLayoutVFp8, gmemLayoutVFp16);
|
||||
Tensor mV = mainloop_params.tma_load_V.get_tma_tensor(gmemLayoutV);
|
||||
|
||||
auto [m_block, bidh, bidb] = block_coord;
|
||||
int bidh_kv = mainloop_params.qhead_per_khead_divmod.divide(bidh);
|
||||
@ -207,12 +246,34 @@ struct CollectiveMainloopFwd {
|
||||
uint32_t block_rank_in_cluster = cute::block_rank_in_cluster();
|
||||
constexpr uint32_t cluster_shape_x = get<0>(ClusterShape());
|
||||
uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
|
||||
Tensor gQ = seqlen_traits_q.get_local_tile_tensor(
|
||||
mQ, select<0, 2>(TileShape_MNK{}), bidh, bidb)(_, _, m_block); // (M, K)
|
||||
Tensor gK = seqlen_traits_k.get_local_tile_tensor(
|
||||
mK, select<1, 2>(TileShape_MNK{}), bidh_kv, bidb); // (N, K, _)
|
||||
Tensor gV = seqlen_traits_k.get_local_tile_tensor(
|
||||
mV, select<1, 2>(TileShape_MNK{}), bidh_kv, bidb); // (N, K, _)
|
||||
|
||||
Tensor gQ = local_tile(mQ(_, _, bidh, bidb), select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{})); // (M, K)
|
||||
Tensor gK = local_tile(mK(_, _, bidh_kv, bidb), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _)
|
||||
Tensor gV = local_tile(mV(_, _, bidh_kv, bidb), TileShapeV{}, cute::conditional_return<is_same_v<Element, cutlass::float_e4m3_t>>(make_coord(_0{}, _), make_coord(_, _0{}))); // (N, K, _)
|
||||
|
||||
#if 0
|
||||
if (threadIdx.x == 0 && blockIdx.x == 0) {
|
||||
print ("\n");
|
||||
print (gV);
|
||||
print ("\n");
|
||||
print (gK);
|
||||
print ("\n");
|
||||
print ("\n");
|
||||
print (sV);
|
||||
print ("\n");
|
||||
print (sK);
|
||||
print ("\n");
|
||||
print (gmemLayoutVFp8);
|
||||
print ("\n");
|
||||
print (gmemLayoutVFp16);
|
||||
}
|
||||
|
||||
// Tensor gQ = seqlen_traits_q.get_local_tile_tensor(
|
||||
// mQ, select<0, 2>(TileShape_MNK{}), bidh, bidb)(_, _, m_block); // (M, K)
|
||||
// Tensor gK = seqlen_traits_k.get_local_tile_tensor(
|
||||
// mK, select<1, 2>(TileShape_MNK{}), bidh_kv, bidb); // (N, K, _)
|
||||
// Tensor gV = seqlen_traits_k.get_local_tile_tensor(
|
||||
// mV, select<1, 2>(TileShape_MNK{}), bidh_kv, bidb); // (N, K, _)
|
||||
|
||||
Tensor sQ_x = make_tensor(sQ.data(), make_layout(sQ.layout(), Layout<_1>{}));
|
||||
Tensor gQ_x = make_tensor(gQ.data(), make_layout(gQ.layout(), Layout<_1>{}));
|
||||
@ -369,6 +430,13 @@ struct CollectiveMainloopFwd {
|
||||
// Note: S becomes P.
|
||||
Tensor tOrV = threadMma1.partition_fragment_B(sVt);
|
||||
|
||||
// Dummy sS to just get the shape correctly for GEMM-II.
|
||||
Tensor sS = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), SmemLayoutS{});
|
||||
Tensor tOrS = threadMma1.partition_fragment_A(sS);
|
||||
Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
|
||||
ReorgCFp8toAFp8 reg2reg;
|
||||
auto tOrPLayout = ReshapeTStoTP()(tSrS, tOrS);
|
||||
|
||||
auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) {
|
||||
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
|
||||
pipeline.consumer_wait(smem_pipe_read, barrier_token);
|
||||
@ -382,7 +450,6 @@ struct CollectiveMainloopFwd {
|
||||
cutlass::ConsumerToken barrier_token = static_cast<cutlass::BarrierStatus>(shared_storage.barrier_Q.try_wait(work_idx % 2));
|
||||
if (barrier_token == cutlass::BarrierStatus::WaitAgain) { shared_storage.barrier_Q.wait(work_idx % 2); }
|
||||
|
||||
Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
|
||||
consumer_wait(pipeline_k, smem_pipe_read_k);
|
||||
warp_scheduler_barrier_sync();
|
||||
flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read_k.index()), tSrS);
|
||||
@ -424,7 +491,11 @@ struct CollectiveMainloopFwd {
|
||||
}
|
||||
|
||||
softmax.template online_softmax</*Is_first=*/true>(tSrS, mainloop_params.softmax_scale_log2);
|
||||
Tensor tOrP = make_tensor(convert_type<Element>(tSrS).data(), convert_layout_acc_Aregs<typename Ktraits::TiledMma1>(tSrS.layout()));
|
||||
auto tSrSPrec = convert_type<Element>(tSrS);
|
||||
if constexpr (is_same_v<Element, cutlass::float_e4m3_t>) {
|
||||
reg2reg(tSrSPrec);
|
||||
}
|
||||
Tensor tOrP = make_tensor(tSrSPrec.data(), tOrPLayout);
|
||||
Tensor scores_scale = make_fragment_like(softmax.row_max);
|
||||
clear(scores_scale);
|
||||
|
||||
@ -456,7 +527,11 @@ struct CollectiveMainloopFwd {
|
||||
pipeline_v.consumer_release(smem_pipe_read_v); // release V
|
||||
++smem_pipe_read_k;
|
||||
++smem_pipe_read_v;
|
||||
cute::copy(make_tensor(convert_type<Element>(tSrS).data(), convert_layout_acc_Aregs<typename Ktraits::TiledMma1>(tSrS.layout())), tOrP);
|
||||
auto tSrSPrec = convert_type<Element>(tSrS);
|
||||
if constexpr (is_same_v<Element, cutlass::float_e4m3_t>) {
|
||||
reg2reg(tSrSPrec);
|
||||
}
|
||||
cute::copy(make_tensor(tSrSPrec.data(), tOrPLayout), tOrP);
|
||||
}
|
||||
|
||||
#pragma unroll 1
|
||||
@ -479,7 +554,11 @@ struct CollectiveMainloopFwd {
|
||||
++smem_pipe_read_k;
|
||||
++smem_pipe_read_v;
|
||||
// softmax.rescale_o(tOrO, scores_scale);
|
||||
cute::copy(make_tensor(convert_type<Element>(tSrS).data(), convert_layout_acc_Aregs<typename Ktraits::TiledMma1>(tSrS.layout())), tOrP);
|
||||
auto tSrSPrec = convert_type<Element>(tSrS);
|
||||
if constexpr (is_same_v<Element, cutlass::float_e4m3_t>) {
|
||||
reg2reg(tSrSPrec);
|
||||
}
|
||||
cute::copy(make_tensor(tSrSPrec.data(), tOrPLayout), tOrP);
|
||||
}
|
||||
// Tell warp 0 that smem_q is ready
|
||||
cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast<int>(FwdNamedBarriers::QueryEmpty) /*id*/);
|
||||
|
||||
@ -116,6 +116,9 @@ if not SKIP_CUDA_BUILD:
|
||||
"flash_fwd_hdim128_bf16_sm90.cu",
|
||||
"flash_fwd_hdim256_fp16_sm90.cu",
|
||||
"flash_fwd_hdim256_bf16_sm90.cu",
|
||||
"flash_fwd_hdim64_fp8_sm90.cu",
|
||||
"flash_fwd_hdim128_fp8_sm90.cu",
|
||||
"flash_fwd_hdim256_fp8_sm90.cu",
|
||||
"flash_bwd_hdim64_fp16_sm90.cu",
|
||||
"flash_bwd_hdim128_fp16_sm90.cu",
|
||||
"flash_bwd_hdim256_fp16_sm90.cu",
|
||||
|
||||
@ -170,7 +170,7 @@ def test_flash_attn_output(
|
||||
(113, 211),
|
||||
(108, 256),
|
||||
(256, 512),
|
||||
(384, 256),
|
||||
(384, 256),
|
||||
(512, 256),
|
||||
(640, 128),
|
||||
(1024, 1024),
|
||||
@ -261,49 +261,87 @@ def test_flash_attn_varlen_output(
|
||||
reorder_ops=True,
|
||||
)
|
||||
|
||||
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
|
||||
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
|
||||
print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
|
||||
print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn])
|
||||
# @pytest.mark.parametrize("dtype", [torch.bfloat16])
|
||||
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
|
||||
# @pytest.mark.parametrize("mha_type", ["gqa"])
|
||||
@pytest.mark.parametrize("causal", [False, True])
|
||||
# @pytest.mark.parametrize("causal", [True])
|
||||
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
|
||||
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
|
||||
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
|
||||
# @pytest.mark.parametrize('d', [56, 80])
|
||||
@pytest.mark.parametrize("d", [64, 128, 256])
|
||||
#@pytest.mark.parametrize("d", [128])
|
||||
# @pytest.mark.parametrize("d", [256])
|
||||
@pytest.mark.parametrize(
|
||||
"seqlen_q,seqlen_k",
|
||||
[
|
||||
(64, 128),
|
||||
(128, 128),
|
||||
(256, 256),
|
||||
(113, 203),
|
||||
(128, 217),
|
||||
(113, 211),
|
||||
(108, 256),
|
||||
(256, 512),
|
||||
(384, 256),
|
||||
(640, 128),
|
||||
(512, 256),
|
||||
(1024, 1024),
|
||||
(1023, 1024),
|
||||
(1024, 1023),
|
||||
(2048, 2048),
|
||||
],
|
||||
)
|
||||
def test_flash_attn_output_fp8(
|
||||
seqlen_q, seqlen_k, d, causal, mha_type, dtype
|
||||
):
|
||||
device = "cuda"
|
||||
# set seed
|
||||
torch.random.manual_seed(0)
|
||||
# batch_size = 40
|
||||
# nheads = 16
|
||||
batch_size = 9
|
||||
nheads = 6
|
||||
nheads_kv = 6 if mha_type == "mha" else (2 if mha_type == "gqa" else 1)
|
||||
# batch_size = 1
|
||||
# nheads = 1
|
||||
q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=torch.float16, requires_grad=True)
|
||||
k = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=torch.float16, requires_grad=True)
|
||||
v = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=torch.float16, requires_grad=True)
|
||||
out, lse = flash_attn_func(q.to(dtype), k.to(dtype), v.to(dtype).transpose(1,3).contiguous().clone(), causal=causal)
|
||||
q = q.to(dtype).to(torch.float16)
|
||||
k = k.to(dtype).to(torch.float16)
|
||||
v = v.to(dtype).to(torch.float16)
|
||||
out_ref, attn_ref = attention_ref(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
None,
|
||||
None,
|
||||
causal=causal,
|
||||
)
|
||||
out_pt, attn_pt = attention_ref(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
None,
|
||||
None,
|
||||
causal=causal,
|
||||
upcast=False,
|
||||
reorder_ops=True,
|
||||
)
|
||||
|
||||
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
|
||||
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
|
||||
print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
|
||||
print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
|
||||
|
||||
# g = torch.randn_like(out)
|
||||
# if d <= 128:
|
||||
# (
|
||||
# dq_unpad,
|
||||
# dk_unpad,
|
||||
# dv_unpad,
|
||||
# ) = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g)
|
||||
# dk = dk_pad_fn(dk_unpad)
|
||||
# dv = dk_pad_fn(dv_unpad)
|
||||
# (
|
||||
# dq_ref,
|
||||
# dk_ref,
|
||||
# dv_ref,
|
||||
# ) = torch.autograd.grad(out_ref, (q, k, v), g)
|
||||
# (
|
||||
# dq_pt,
|
||||
# dk_pt,
|
||||
# dv_pt,
|
||||
# ) = torch.autograd.grad(out_pt, (q, k, v), g)
|
||||
# dq = dq_pad_fn(dq_unpad)
|
||||
# print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
|
||||
# print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
|
||||
# print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
|
||||
# print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
|
||||
# print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
|
||||
# print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
|
||||
# print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
|
||||
# print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
|
||||
# print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
|
||||
# print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
|
||||
# print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
|
||||
# print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")
|
||||
|
||||
# Check that FlashAttention's numerical error is at most twice the numerical error
|
||||
# of a Pytorch implementation.
|
||||
assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()
|
||||
|
||||
# if d <= 128:
|
||||
# assert (dq - dq_ref).abs().max().item() < 1e-4 or (dq - dq_ref).abs().max().item() <= 3 * (dq_pt - dq_ref).abs().max().item()
|
||||
# assert (dk - dk_ref).abs().max().item() < 1e-4 or (dk - dk_ref).abs().max().item() <= 3 * (dk_pt - dk_ref).abs().max().item()
|
||||
# assert (dk - dk_ref).abs().max().item() < 1e-4 or (dv - dv_ref).abs().max().item() <= 3 * (dv_pt - dv_ref).abs().max().item()
|
||||
|
||||
@ -228,6 +228,88 @@ __forceinline__ __device__ void copy(TiledCopy tiled_copy, Tensor<Engine0, Layou
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
//
|
||||
//
|
||||
// Need this register byte permute/shuffle to match register layout of
|
||||
// (FP8 downcasted) accumulator of GEMM-I to FP8 operand A of GEMM-II.
|
||||
struct ReorgCFp8toAFp8 {
|
||||
int selectorEx0;
|
||||
int selectorEx1;
|
||||
int selectorEx4;
|
||||
int selectorEx5;
|
||||
int upper_map[4] = {0, 3, 1, 2};
|
||||
int lower_map[4] = {1, 2, 0, 3};
|
||||
|
||||
CUTLASS_DEVICE ReorgCFp8toAFp8() {
|
||||
int laneId = cutlass::canonical_lane_idx();
|
||||
|
||||
if (laneId % 4 == 0 || laneId % 4 == 3) {
|
||||
selectorEx0 = 0x3210;
|
||||
selectorEx1 = 0x7654;
|
||||
selectorEx4 = 0x5410;
|
||||
selectorEx5 = 0x7632;
|
||||
} else {
|
||||
selectorEx0 = 0x7654;
|
||||
selectorEx1 = 0x3210;
|
||||
selectorEx4 = 0x1054;
|
||||
selectorEx5 = 0x3276;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Fragment> CUTLASS_DEVICE auto operator()(Fragment &accum) {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
// First update `mi` to the max per-row
|
||||
//
|
||||
auto VT = shape<0>(accum); // number of vector elements per tile.
|
||||
auto MT = shape<1>(accum); // number of tiles along M.
|
||||
auto NT = shape<2>(accum); // number of tiles along N.
|
||||
|
||||
auto data = accum.data();
|
||||
int n = 0;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < MT; ++i) {
|
||||
|
||||
// Traverse 2-rows + 2-cols (2x2) simultaneously.
|
||||
|
||||
#pragma unroll
|
||||
for (int k = 0; k < NT * size<2>(VT) / 2; ++k) {
|
||||
|
||||
auto upper = *reinterpret_cast<uint32_t *>(&data[n]);
|
||||
auto lower = *reinterpret_cast<uint32_t *>(&data[n + 4]);
|
||||
|
||||
auto upper0 = __byte_perm(upper, lower, selectorEx0);
|
||||
auto lower0 = __byte_perm(upper, lower, selectorEx1);
|
||||
upper0 =
|
||||
__shfl_sync(uint32_t(-1), upper0, upper_map[threadIdx.x % 4], 4);
|
||||
lower0 =
|
||||
__shfl_sync(uint32_t(-1), lower0, lower_map[threadIdx.x % 4], 4);
|
||||
|
||||
uint32_t *data_32bit = reinterpret_cast<uint32_t *>(&data[n]);
|
||||
data_32bit[0] = __byte_perm(upper0, lower0, selectorEx4);
|
||||
data_32bit[1] = __byte_perm(upper0, lower0, selectorEx5);
|
||||
n += 8;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
// Reshape Utility for converting the layout from accumulator of GEMM-I
|
||||
// to Operand A of GEMM-II.
|
||||
struct ReshapeTStoTP {
|
||||
template <class FragmentC, class FragmentQ>
|
||||
CUTLASS_DEVICE auto operator()(FragmentC &&tC, FragmentQ &&tQ) {
|
||||
|
||||
// get the layout of one row of Q.
|
||||
auto layoutQRow = make_layout_like(tQ(_, 0, _).layout());
|
||||
// get the layout of M dimension of C.
|
||||
auto layoutCM = get<1>(tC.layout());
|
||||
return make_layout(get<0>(layoutQRow), layoutCM, get<1>(layoutQRow));
|
||||
}
|
||||
};
|
||||
|
||||
template <int NumCopyThreads, typename ElemO, typename TMACopyO, typename LayoutO,
|
||||
typename TileShapeO, typename SMemO, typename SeqLenTraits>
|
||||
|
||||
Loading…
Reference in New Issue
Block a user