diff --git a/hopper/benchmark_flash_attention.py b/hopper/benchmark_flash_attention.py new file mode 100644 index 0000000..9e81530 --- /dev/null +++ b/hopper/benchmark_flash_attention.py @@ -0,0 +1,281 @@ +# Install the newest triton version with +# pip install "git+https://github.com/openai/triton.git#egg=triton&subdirectory=python" +import pickle +import math +import time +import torch +import torch.nn as nn +import torch.nn.functional as F + +from einops import rearrange, repeat + +from flash_attn.utils.benchmark import benchmark_all, benchmark_forward, benchmark_backward +from flash_attn.utils.benchmark import benchmark_fwd_bwd, benchmark_combined + +from flash_attn import flash_attn_qkvpacked_func +from flash_attn_interface import flash_attn_func + +try: + from triton.ops.flash_attention import attention as attention_triton +except ImportError: + attention_triton = None + +try: + import xformers.ops as xops +except ImportError: + xops = None + +try: + import cudnn +except ImportError: + cudnn = None + + +def flops(batch, seqlen, headdim, nheads, causal, mode="fwd"): + assert mode in ["fwd", "bwd", "fwd_bwd"] + f = 4 * batch * seqlen**2 * nheads * headdim // (2 if causal else 1) + return f if mode == "fwd" else (2.5 * f if mode == "bwd" else 3.5 * f) + +def efficiency(flop, time): + return (flop / time / 10**12) if not math.isnan(time) else 0.0 + + +def convert_to_cudnn_type(torch_type): + if torch_type == torch.float16: + return cudnn.data_type.HALF + elif torch_type == torch.bfloat16: + return cudnn.data_type.BFLOAT16 + elif torch_type == torch.float32: + return cudnn.data_type.FLOAT + elif torch_type == torch.int32: + return cudnn.data_type.INT32 + elif torch_type == torch.int64: + return cudnn.data_type.INT64 + else: + raise ValueError("Unsupported tensor data type.") + + +def cudnn_spda_setup(q, k, v, causal=False): + b, nheads, seqlen_q, headdim = q.shape + _, _, seqlen_k, _ = k.shape + assert v.shape == (b, nheads, seqlen_k, headdim) + assert cudnn is not None, 'CUDNN is not available' + q_gpu, k_gpu, v_gpu = q, k, v + o_gpu = torch.empty_like(q_gpu) + stats_gpu = torch.empty(b, nheads, seqlen_q, 1, dtype=torch.float32, device=q.device) + graph = cudnn.pygraph( + io_data_type=convert_to_cudnn_type(q.dtype), + intermediate_data_type=cudnn.data_type.FLOAT, + compute_data_type=cudnn.data_type.FLOAT, + ) + q = graph.tensor_like(q_gpu.detach()) + k = graph.tensor_like(k_gpu.detach()) + v = graph.tensor_like(v_gpu.detach()) + + o, stats = graph.sdpa( + name="sdpa", + q=q, + k=k, + v=v, + is_inference=False, + attn_scale=1.0 / math.sqrt(headdim), + use_causal_mask=causal, + ) + + o.set_output(True).set_dim(o_gpu.shape).set_stride(o_gpu.stride()) + stats.set_output(True).set_data_type(cudnn.data_type.FLOAT) + + graph.validate() + graph.build_operation_graph() + graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK]) + graph.check_support() + graph.build_plans() + + variant_pack = { + q: q_gpu, + k: k_gpu, + v: v_gpu, + o: o_gpu, + stats: stats_gpu, + } + + workspace = torch.empty(graph.get_workspace_size(), device="cuda", dtype=torch.uint8) + + def run(*args, **kwargs): + graph.execute(variant_pack, workspace) + return o_gpu + + return run + + +def attention_pytorch(qkv, dropout_p=0.0, causal=True): + """ + Arguments: + qkv: (batch_size, seqlen, 3, nheads, head_dim) + dropout_p: float + Output: + output: (batch_size, seqlen, nheads, head_dim) + """ + batch_size, seqlen, _, nheads, d = qkv.shape + q, k, v = qkv.unbind(dim=2) + q = rearrange(q, 'b t h d -> (b h) t d') 
+ k = rearrange(k, 'b s h d -> (b h) d s') + softmax_scale = 1.0 / math.sqrt(d) + # Preallocate attn_weights for `baddbmm` + scores = torch.empty(batch_size * nheads, seqlen, seqlen, dtype=qkv.dtype, device=qkv.device) + scores = rearrange(torch.baddbmm(scores, q, k, beta=0, alpha=softmax_scale), + '(b h) t s -> b h t s', h=nheads) + if causal: + # "triu_tril_cuda_template" not implemented for 'BFloat16' + # So we have to construct the mask in float + causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1) + # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess) + scores = scores + causal_mask.to(dtype=scores.dtype) + attention = torch.softmax(scores, dim=-1) + attention_drop = F.dropout(attention, dropout_p) + output = torch.einsum('bhts,bshd->bthd', attention_drop , v) + return output.to(dtype=qkv.dtype) + + +def time_fwd_bwd(func, *args, **kwargs): + time.sleep(1) # Sleep to avoid residual power throttling from the previous benchmark + time_f, time_b = benchmark_fwd_bwd(func, *args, **kwargs) + return time_f[1].mean, time_b[1].mean + + +repeats = 30 +device = 'cuda' +dtype = torch.float16 + +# Ideally, seq-len should be divisible by 132 to avoid wave quantization. +# However, the existing Triton implementation doesn't support seq-len like 8448. +bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4096), (2, 8192)] +# bs_seqlen_vals = [(2, 8192)] +causal_vals = [False] +# headdim_vals = [64, 128] +headdim_vals = [128] +dim = 128 +dropout_p = 0.0 + +methods = (["Flash2", "Pytorch", "Flash3"] + + (["Triton"] if attention_triton is not None else []) + + (["xformers.c"] if xops is not None else []) + + (["xformers.f"] if xops is not None else []) + + (["cudnn"] if cudnn is not None else [])) + +time_f = {} +time_b = {} +time_f_b = {} +speed_f = {} +speed_b = {} +speed_f_b = {} +for causal in causal_vals: + for headdim in headdim_vals: + for batch_size, seqlen in bs_seqlen_vals: + config = (causal, headdim, batch_size, seqlen) + nheads = dim // headdim + qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim, device=device, dtype=dtype, + requires_grad=True) + f, b = time_fwd_bwd( + flash_attn_qkvpacked_func, qkv, dropout_p, causal=causal, repeats=repeats, verbose=False + ) + time_f[config, "Flash2"] = f + time_b[config, "Flash2"] = b + + try: + qkv = qkv.detach().requires_grad_(True) + f, b = time_fwd_bwd( + attention_pytorch, qkv, dropout_p, causal=causal, repeats=repeats, verbose=False + ) + res_baseline = attention_pytorch(qkv, dropout_p, causal=causal) + except: # Skip if OOM + f, b = float('nan'), float('nan') + time_f[config, "Pytorch"] = f + time_b[config, "Pytorch"] = b + + q, k, v = [torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype, + requires_grad=True) for _ in range(3)] + f, b = time_fwd_bwd(flash_attn_func, q, k, v, causal=causal, repeats=repeats, verbose=False) + res = flash_attn_func(q, k, v, causal=causal) + + time_f[config, "Flash3"] = f + time_b[config, "Flash3"] = b + + if cudnn is not None: + time.sleep(1) # Sleep to avoid residual power throttling from the previous benchmark + res = benchmark_forward( + cudnn_spda_setup(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), causal=causal), + repeats=repeats, verbose=False + ) + f = res[1].mean + time_f[config, "cudnn"] = f + time_b[config, "cudnn"] = math.inf + + if attention_triton is not None: + q, k, v = [torch.randn(batch_size, nheads, seqlen, headdim, device=device, dtype=dtype, + requires_grad=True) for _ in 
range(3)] + # Try both values of sequence_parallel and pick the faster one + try: + f, b = time_fwd_bwd( + attention_triton, q, k, v, causal, headdim**(-0.5), + False, repeats=repeats, verbose=False + ) + except: + f, b = float('nan'), float('inf') + try: + _, b0 = time_fwd_bwd( + attention_triton, q, k, v, causal, headdim**(-0.5), + True, repeats=repeats, verbose=False + ) + except: + b0 = float('inf') + time_f[config, "Triton"] = f + time_b[config, "Triton"] = min(b, b0) if min(b, b0) < float('inf') else float('nan') + + if xops is not None: + q, k, v = [torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype, + requires_grad=True) for _ in range(3)] + f, b = time_fwd_bwd( + xops.memory_efficient_attention, q, k, v, + attn_bias=xops.LowerTriangularMask() if causal else None, + op=(xops.fmha.cutlass.FwOp, xops.fmha.cutlass.BwOp) + ) + time_f[config, "xformers.c"] = f + time_b[config, "xformers.c"] = b + + if xops is not None: + q, k, v = [torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype, + requires_grad=True) for _ in range(3)] + f, b = time_fwd_bwd( + xops.memory_efficient_attention, q, k, v, + attn_bias=xops.LowerTriangularMask() if causal else None, + op=(xops.fmha.flash.FwOp, xops.fmha.flash.BwOp) + ) + time_f[config, "xformers.f"] = f + time_b[config, "xformers.f"] = b + + print(f"### causal={causal}, headdim={headdim}, batch_size={batch_size}, seqlen={seqlen} ###") + for method in methods: + time_f_b[config, method] = time_f[config, method] + time_b[config, method] + speed_f[config, method] = efficiency( + flops(batch_size, seqlen, headdim, nheads, causal, mode="fwd"), + time_f[config, method] + ) + speed_b[config, method] = efficiency( + flops(batch_size, seqlen, headdim, nheads, causal, mode="bwd"), + time_b[config, method] + ) + speed_f_b[config, method] = efficiency( + flops(batch_size, seqlen, headdim, nheads, causal, mode="fwd_bwd"), + time_f_b[config, method] + ) + #print (time_f[config,method]) + print( + f"{method} fwd: {speed_f[config, method]:.2f} TFLOPs/s, " + f"bwd: {speed_b[config, method]:.2f} TFLOPs/s, " + f"fwd + bwd: {speed_f_b[config, method]:.2f} TFLOPs/s" + ) + + +# with open('flash2_attn_time.plk', 'wb') as fp: +# pickle.dump((speed_f, speed_b, speed_f_b), fp, protocol=pickle.HIGHEST_PROTOCOL) diff --git a/hopper/benchmark_flash_attention_fp8.py b/hopper/benchmark_flash_attention_fp8.py new file mode 100644 index 0000000..7d9e234 --- /dev/null +++ b/hopper/benchmark_flash_attention_fp8.py @@ -0,0 +1,339 @@ +# Install the newest triton version with +# pip install "git+https://github.com/openai/triton.git#egg=triton&subdirectory=python" +import pickle +import math +import time +import torch +import torch.nn as nn +import torch.nn.functional as F + +from einops import rearrange, repeat + +from flash_attn.utils.benchmark import benchmark_all, benchmark_forward, benchmark_backward +from flash_attn.utils.benchmark import benchmark_fwd_bwd, benchmark_combined + +from flash_attn import flash_attn_qkvpacked_func +from flash_attn_interface import flash_attn_func + +try: + from triton_fused_attention import attention as attention_triton +except ImportError: + attention_triton = None + +try: + import xformers.ops as xops +except ImportError: + xops = None + +try: + import cudnn +except ImportError: + cudnn = None + + +def convert_to_cudnn_type(torch_type): + if torch_type == torch.float16: + return cudnn.data_type.HALF + elif torch_type == torch.bfloat16: + return cudnn.data_type.BFLOAT16 + elif torch_type == torch.float32: + 
return cudnn.data_type.FLOAT + elif torch_type == torch.int32: + return cudnn.data_type.INT32 + elif torch_type == torch.int64: + return cudnn.data_type.INT64 + elif torch_type == torch.float8_e4m3fn: + return cudnn.data_type.FP8_E4M3 + elif torch_type == torch.float8_e4m3fn: + return cudnn.data_type.FP8_E5M2 + else: + raise ValueError("Unsupported tensor data type.") + +def cudnn_spda_setup(qkv, seqlen_q, seqlen_k, causal=False): + b, _, _, nheads, headdim = qkv.shape + assert cudnn is not None, 'CUDNN is not available' + o_gpu = torch.zeros(b, seqlen_q, nheads, headdim, dtype=qkv.dtype, device=qkv.device) + o_gpu_transposed = torch.as_strided( + o_gpu, + [b, nheads, seqlen_q, headdim], + [nheads * seqlen_q * headdim, headdim, nheads * headdim, 1], + ) + stats_gpu = torch.empty(b, nheads, seqlen_q, 1, dtype=torch.float32, device=qkv.device) + amax_s_gpu = torch.empty(1, 1, 1, 1, dtype=torch.float32, device=qkv.device) + amax_o_gpu = torch.empty(1, 1, 1, 1, dtype=torch.float32, device=qkv.device) + graph = cudnn.pygraph( + io_data_type=convert_to_cudnn_type(qkv.dtype), + intermediate_data_type=cudnn.data_type.FLOAT, + compute_data_type=cudnn.data_type.FLOAT, + ) + new_q = torch.as_strided( + qkv, + [b, nheads, seqlen_q, headdim], + [seqlen_q * nheads * headdim * 3, headdim, headdim * nheads * 3, 1], + storage_offset=0, + ) + q = graph.tensor( + name = "Q", + dim = list(new_q.shape), + stride = list(new_q.stride()), + data_type=convert_to_cudnn_type(qkv.dtype) + ) + new_k = torch.as_strided( + qkv, + [b, nheads, seqlen_k, headdim], + [seqlen_k * nheads * headdim * 3, headdim, headdim * nheads * 3, 1], + storage_offset=nheads * headdim, + ) + k = graph.tensor( + name = "K", + dim = list(new_k.shape), + stride = list(new_k.stride()), + data_type=convert_to_cudnn_type(qkv.dtype) + ) + new_v = torch.as_strided( + qkv, + [b, nheads, seqlen_k, headdim], + [seqlen_k * nheads * headdim * 3, headdim, headdim * nheads * 3, 1], + storage_offset=nheads * headdim * 2, + ) + v = graph.tensor( + name = "V", + dim = list(new_v.shape), + stride = list(new_v.stride()), + data_type=convert_to_cudnn_type(qkv.dtype) + ) + + def get_default_scale_tensor(): + return graph.tensor( + dim = [1, 1, 1, 1], + stride = [1, 1, 1, 1], + data_type=cudnn.data_type.FLOAT + ) + + default_scale_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float32, device="cuda") + descale_q = get_default_scale_tensor() + descale_k = get_default_scale_tensor() + descale_v = get_default_scale_tensor() + descale_s = get_default_scale_tensor() + scale_s = get_default_scale_tensor() + scale_o = get_default_scale_tensor() + + o, _, amax_s, amax_o = graph.sdpa_fp8( + q=q, + k=k, + v=v, + descale_q=descale_q, + descale_k=descale_k, + descale_v=descale_v, + descale_s=descale_s, + scale_s=scale_s, + scale_o=scale_o, + is_inference=True, + attn_scale=1.0 / math.sqrt(headdim), + use_causal_mask=causal, + name="sdpa", + ) + + o.set_output(True).set_dim(o_gpu_transposed.shape).set_stride(o_gpu_transposed.stride()) + + amax_s.set_output(False).set_dim(amax_s_gpu.shape).set_stride(amax_s_gpu.stride()) + amax_o.set_output(False).set_dim(amax_o_gpu.shape).set_stride(amax_o_gpu.stride()) + # stats.set_output(True).set_data_type(cudnn.data_type.FLOAT) + + graph.validate() + graph.build_operation_graph() + graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK]) + graph.check_support() + graph.build_plans() + + variant_pack = { + q: new_q, + k: new_k, + v: new_v, + descale_q: default_scale_gpu, + descale_k: default_scale_gpu, + descale_v: 
default_scale_gpu, + descale_s: default_scale_gpu, + scale_s: default_scale_gpu, + scale_o: default_scale_gpu, + o: o_gpu_transposed, + amax_s: amax_s_gpu, + amax_o: amax_o_gpu, + } + + workspace = torch.empty(graph.get_workspace_size(), device="cuda", dtype=torch.uint8) + + def run(*args, **kwargs): + graph.execute(variant_pack, workspace) + return o_gpu, amax_o_gpu + + return run + + +def attention_pytorch(qkv, dropout_p=0.0, causal=True): + """ + Arguments: + qkv: (batch_size, seqlen, 3, nheads, head_dim) + dropout_p: float + Output: + output: (batch_size, seqlen, nheads, head_dim) + """ + batch_size, seqlen, _, nheads, d = qkv.shape + q, k, v = qkv.unbind(dim=2) + q = rearrange(q, 'b t h d -> (b h) t d') + k = rearrange(k, 'b s h d -> (b h) d s') + softmax_scale = 1.0 / math.sqrt(d) + # Preallocate attn_weights for `baddbmm` + scores = torch.empty(batch_size * nheads, seqlen, seqlen, dtype=qkv.dtype, device=qkv.device) + scores = rearrange(torch.baddbmm(scores, q, k, beta=0, alpha=softmax_scale), + '(b h) t s -> b h t s', h=nheads) + if causal: + # "triu_tril_cuda_template" not implemented for 'BFloat16' + # So we have to construct the mask in float + causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1) + # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess) + scores = scores + causal_mask.to(dtype=scores.dtype) + attention = torch.softmax(scores, dim=-1) + attention_drop = F.dropout(attention, dropout_p) + output = torch.einsum('bhts,bshd->bthd', attention_drop , v) + return output.to(dtype=qkv.dtype) + +def flops(batch, seqlen, headdim, nheads, causal, mode="fwd"): + assert mode in ["fwd", "bwd", "fwd_bwd"] + f = 4 * batch * seqlen**2 * nheads * headdim // (2 if causal else 1) + return f if mode == "fwd" else (2.5 * f if mode == "bwd" else 3.5 * f) + +def efficiency(flop, time): + return (flop / time / 10**12) if not math.isnan(time) else 0.0 + +def time_fwd(func, *args, **kwargs): + time.sleep(1) # Sleep to avoid residual power throttling from the previous benchmark + time_f = benchmark_forward(func, *args, **kwargs) + return time_f[1].mean + + +torch.manual_seed(0) + +repeats = 30 +device = 'cuda' +# dtype = torch.float16 +dtype = torch.float8_e4m3fn + +#bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4224), (2, 8448), (1, 8448 * 2)] +bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4096), (2, 8192), (1, 8192 * 2)] +#bs_seqlen_vals = [(4, 4224), (2, 8448), (1, 8448 * 2)] +# bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4096), (2, 8192), (1, 8192 * 2)] +# bs_seqlen_vals = [(4, 8448)] +causal_vals = [False, True] +#headdim_vals = [64, 128, 256] +headdim_vals = [128,256] +dim = 2048 +# dim = 128 +dropout_p = 0.0 + +methods = (["Pytorch","Flash3", "cuDNN"] + + (["Triton"] if attention_triton is not None else []) + # + (["xformers.c"] if xops is not None else []) + # + (["xformers.f"] if xops is not None else []) + ) + +time_f = {} +time_b = {} +time_f_b = {} +speed_f = {} +speed_b = {} +speed_f_b = {} +for causal in causal_vals: + for headdim in headdim_vals: + for batch_size, seqlen in bs_seqlen_vals: + torch.cuda.empty_cache() + config = (causal, headdim, batch_size, seqlen) + nheads = dim // headdim + q, k, v = [torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=torch.float16, + requires_grad=False) for _ in range(3)] + qkv = torch.stack([q, k, v], dim=2) + qkv = qkv.to(torch.float16) + + f = time_fwd(attention_pytorch, qkv, dropout_p, causal=causal, 
repeats=repeats, verbose=False) + time_f[config, "Pytorch"] = f + res_baseline = attention_pytorch(qkv, dropout_p, causal=causal) + + if attention_triton is not None: + q_transposed = q.transpose(1, 2).contiguous().to(torch.float8_e4m3fn) + k_transposed = k.transpose(1, 2).contiguous().to(torch.float8_e4m3fn) + v_transposed = v.transpose(1, 2).contiguous().permute(0, 1, 3, 2).to(torch.float8_e4m3fn) + scale = 1 / math.sqrt(headdim) + f = time_fwd( + attention_triton, q_transposed, k_transposed, v_transposed, + causal, scale, repeats=5, verbose=False, desc='Triton' + ) + f = time_fwd( + attention_triton, q_transposed, k_transposed, v_transposed, + causal, scale, repeats=repeats, verbose=False, desc='Triton' + ) + time_f[config, "Triton"] = f + res = attention_triton( + q_transposed, k_transposed, v_transposed.permute(0, 1, 3, 2), + causal, scale + ).half().transpose(1, 2) + torch.testing.assert_close(res, res_baseline, atol=0.5, rtol=0.5) + + out = torch.empty_like(q) + q, k, v = q.to(dtype), k.to(dtype), v.to(dtype) + + v_transposed = v.transpose(1,3).contiguous().clone() + #v_transposed = v.transpose(1,3).clone() + time.sleep(1) + f = time_fwd(flash_attn_func, q, k, v_transposed, causal=causal, repeats=repeats, verbose=False) + # res = flash_attn_func(q, k, v, causal=causal, is_fp16_acc=False) + # torch.testing.assert_close(res.half(), res_baseline, atol=0.05, rtol=0.05) + + time_f[config, "Flash3"] = f + + if cudnn is not None: + qkv_fp8 = qkv.to(dtype) + time.sleep(1) # Sleep to avoid residual power throttling from the previous benchmark + f = time_fwd( + cudnn_spda_setup( + qkv_fp8, seqlen, seqlen, + causal=causal + ), + repeats=repeats, verbose=False + ) + time_f[config, "cuDNN"] = f + # res, amax_o = cudnn_spda_setup( + # qkv_fp8, seqlen, seqlen, + # causal=causal + # )() + # res = res.half() + # TODO: CUDNN has numerics issues when + # num_heads=16, dim=128, seq_len=1024, batch_size=2 + # or larger sizes. 
+ # res_cpu = res.cpu().reshape(-1) + # res_baseline_cpu = res_baseline.cpu().reshape(-1) + # print(amax_o) + # print(res) + # print(res_baseline) + # for i in range(len(res_cpu)): + # item = res_cpu[i] + # item_baseline = res_baseline_cpu[i] + # if abs(item - item_baseline) > 0.5: + # print(i) + # print(item) + # print(item_baseline) + # torch.testing.assert_close(res, res_baseline, atol=0.05, rtol=0.05) + + print(f"### causal={causal}, headdim={headdim}, batch_size={batch_size}, seqlen={seqlen} ###") + for method in methods: + speed_f[config, method] = efficiency( + flops(batch_size, seqlen, headdim, nheads, causal, mode="fwd"), + time_f[config, method] + ) + #print (time_f[config,method]) + print( + f"{method} fwd: {speed_f[config, method]:.2f} TFLOPs/s, {time_f[config, method] * 1e3} ms, " + ) + + +# with open('flash3_attn_time.plk', 'wb') as fp: +# pickle.dump((time_f, time_b, time_f_b), fp, protocol=pickle.HIGHEST_PROTOCOL) diff --git a/hopper/epilogue_fwd_sm90_tma.hpp b/hopper/epilogue_fwd_sm90_tma.hpp index 8523438..192e555 100644 --- a/hopper/epilogue_fwd_sm90_tma.hpp +++ b/hopper/epilogue_fwd_sm90_tma.hpp @@ -20,7 +20,8 @@ using namespace cute; template struct CollectiveEpilogueFwd { - using Element = typename Ktraits::Element; + using PrecType = typename Ktraits::Element; + using Element = decltype(cute::conditional_return>(cutlass::half_t{}, PrecType{})); static constexpr int kBlockM = Ktraits::kBlockM; static constexpr int kBlockN = Ktraits::kBlockN; static constexpr int kHeadDim = Ktraits::kHeadDim; diff --git a/hopper/flash_api.cpp b/hopper/flash_api.cpp index 397ed4c..c684a19 100644 --- a/hopper/flash_api.cpp +++ b/hopper/flash_api.cpp @@ -249,7 +249,13 @@ void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream, bool force_split } } } else { - // run_mha_fwd_(params, stream); + if (params.d == 64) { + run_mha_fwd_(params, stream); + } else if (params.d == 128) { + run_mha_fwd_(params, stream); + } else { + run_mha_fwd_(params, stream); + } } } @@ -266,12 +272,8 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size TORCH_CHECK(is_sm90, "FlashAttention only supports Hopper GPUs or newer."); auto q_dtype = q.dtype(); - TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, - "FlashAttention only support fp16 and bf16 data type for now"); - // TODO: will add e4m3 later - // TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kFloat8_e4m3fn, - // "FlashAttention only support fp16 and bf16 data type"); - // "FlashAttention only support fp16 and fp8 (e4m3) data type for now"); + TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16 || q_dtype == torch::kFloat8_e4m3fn, + "FlashAttention only support fp16, bf16 and fp8 (e4m3) data type for now"); TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); @@ -301,29 +303,50 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og); CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size_og); + if (q_dtype == torch::kFloat8_e4m3fn) { + CHECK_SHAPE(v, batch_size, head_size_og, num_heads_k, seqlen_k); + } else { CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size_og); + } at::Tensor q_padded, k_padded, v_padded; - if (head_size_og % 8 != 0) { + if (q_dtype == torch::kFloat8_e4m3fn) + { + if (head_size_og % 16 != 0) { + q_padded = torch::nn::functional::pad(q, 
torch::nn::functional::PadFuncOptions({0, 16 - head_size_og % 16})); + k_padded = torch::nn::functional::pad(k, torch::nn::functional::PadFuncOptions({0, 16 - head_size_og % 16})); + } else { + q_padded = q; + k_padded = k; + } + if (seqlen_k % 16 != 0) { + v_padded = torch::nn::functional::pad(v, torch::nn::functional::PadFuncOptions({0, 16 - seqlen_k % 16})); + } else { + v_padded = v; + } + } + else { + if (head_size_og % 8 != 0) { q_padded = torch::nn::functional::pad(q, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8})); k_padded = torch::nn::functional::pad(k, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8})); v_padded = torch::nn::functional::pad(v, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8})); - } else { + } else { q_padded = q; k_padded = k; v_padded = v; + } } at::Tensor out; if (out_.has_value()) { out = out_.value(); - TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs"); + //TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs"); CHECK_DEVICE(out); TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension"); CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_og); if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); } } else { - out = torch::empty_like(q_padded); + out = q_dtype == torch::kFloat8_e4m3fn ? torch::empty_like(q_padded, at::kHalf) : torch::empty_like(q_padded); } auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; diff --git a/hopper/flash_attn_interface.py b/hopper/flash_attn_interface.py index d88ab78..07e3366 100644 --- a/hopper/flash_attn_interface.py +++ b/hopper/flash_attn_interface.py @@ -15,7 +15,7 @@ def maybe_contiguous(x): return x.contiguous() if x is not None and x.stride(-1) != 1 else x def _flash_attn_forward(q, k, v, softmax_scale, causal): - q, k, v = [maybe_contiguous(x) for x in (q, k, v)] + # q, k, v = [maybe_contiguous(x) for x in (q, k, v)] out, q, k, v, out_padded, softmax_lse, S_dmask = flashattn_hopper_cuda.fwd( q, k, @@ -41,7 +41,7 @@ def _flash_attn_backward( causal ): # dq, dk, dv are allocated by us so they should already be contiguous - dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)] + #dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)] dq, dk, dv, softmax_d, = flashattn_hopper_cuda.bwd( dout, q, diff --git a/hopper/flash_fwd_hdim128_fp8_sm90.cu b/hopper/flash_fwd_hdim128_fp8_sm90.cu new file mode 100644 index 0000000..68dd61b --- /dev/null +++ b/hopper/flash_fwd_hdim128_fp8_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim128(params, stream); +} diff --git a/hopper/flash_fwd_hdim256_fp8_sm90.cu b/hopper/flash_fwd_hdim256_fp8_sm90.cu new file mode 100644 index 0000000..42fe6bb --- /dev/null +++ b/hopper/flash_fwd_hdim256_fp8_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. 
+ +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim256(params, stream); +} diff --git a/hopper/flash_fwd_hdim64_fp8_sm90.cu b/hopper/flash_fwd_hdim64_fp8_sm90.cu new file mode 100644 index 0000000..e331295 --- /dev/null +++ b/hopper/flash_fwd_hdim64_fp8_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim64(params, stream); +} diff --git a/hopper/flash_fwd_launch_template.h b/hopper/flash_fwd_launch_template.h index cd7adb3..0ca2e4c 100644 --- a/hopper/flash_fwd_launch_template.h +++ b/hopper/flash_fwd_launch_template.h @@ -21,6 +21,7 @@ template void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { using Element = typename Kernel_traits::Element; + using ElementO = decltype(cute::conditional_return>(cutlass::half_t{}, Element{})); using TileShape_MNK = typename Kernel_traits::TileShape_MNK; using ClusterShape = typename Kernel_traits::ClusterShape_MNK; @@ -127,10 +128,14 @@ void run_mha_fwd_hdim128(Flash_fwd_params ¶ms, cudaStream_t stream) { SEQLEN_SWITCH(params.cu_seqlens_q, Seqlen_traits, [&] { // Only use Cluster if number of tiles along seqlen_q is even and not Is_causal BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 128) % 2 == 0 && !Is_causal && !Seqlen_traits::kUseVarSeqLen, UseCluster, [&] { - run_flash_fwd< - Flash_fwd_kernel_traits, - Is_causal, Seqlen_traits - >(params, stream); + if constexpr (is_same_v) { + //run_flash_fwd, Is_causal>(params, stream); + //run_flash_fwd, Is_causal>(params, stream); + run_flash_fwd, Is_causal, Seqlen_traits>(params, stream); + //run_flash_fwd, Is_causal>(params, stream); + } else { + run_flash_fwd, Is_causal, Seqlen_traits>(params, stream); + } }); }); }); @@ -143,10 +148,11 @@ void run_mha_fwd_hdim256(Flash_fwd_params ¶ms, cudaStream_t stream) { SEQLEN_SWITCH(params.cu_seqlens_q, Seqlen_traits, [&] { // Only use Cluster if number of tiles along seqlen_q is even BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 128) % 2 == 0 && !Is_causal && !Seqlen_traits::kUseVarSeqLen, UseCluster, [&] { - run_flash_fwd< - Flash_fwd_kernel_traits, - Is_causal, Seqlen_traits - >(params, stream); + if constexpr (is_same_v) { + run_flash_fwd, Is_causal, Seqlen_traits>(params, stream); + } else { + run_flash_fwd, Is_causal, Seqlen_traits>(params, stream); + } }); }); }); diff --git a/hopper/kernel_traits.h b/hopper/kernel_traits.h index 90ee3cc..0335a9c 100644 --- a/hopper/kernel_traits.h +++ b/hopper/kernel_traits.h @@ -25,6 +25,7 @@ struct SharedStorageQKVO { cute::array_aligned> smem_o; }; struct { + cute::uint64_t tma_load_mbar[4]; // 4 TMA barriers pre-allocated for usage. cutlass::arch::ClusterTransactionBarrier barrier_Q; cutlass::arch::ClusterBarrier barrier_O; typename cutlass::PipelineTmaAsync::SharedStorage pipeline_k; @@ -40,6 +41,7 @@ struct Flash_fwd_kernel_traits { using Element = elem_type; using ElementAccum = float; using index_t = int64_t; + using ElementO = decltype(cute::conditional_return>(cutlass::half_t{}, Element{})); // The number of threads. 
static constexpr int kNWarps = kNWarps_; @@ -69,9 +71,11 @@ struct Flash_fwd_kernel_traits { decltype(cute::GMMA::ss_op_selector()) >{}, AtomLayoutMNK{})); + using TiledMma1 = decltype(cute::make_tiled_mma( cute::GMMA::rs_op_selector(TileShape_MNK{})), - GMMA::Major::K, GMMA::Major::MN>(), + GMMA::Major::K, cute::conditional_return>( + GMMA::Major::K, GMMA::Major::MN)>(), AtomLayoutMNK{})); using SmemLayoutAtomQ = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int{}))); - using SmemLayoutAtomV = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); - using SmemLayoutV = - decltype(tile_to_shape(SmemLayoutAtomV{}, + using SmemLayoutVFp16 = + decltype(tile_to_shape(SmemLayoutAtomVFp16{}, make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int{}))); - using SmemLayoutAtomO = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>()); + using SmemLayoutVFp8 = + decltype(tile_to_shape(SmemLayoutAtomVFp8{}, + make_shape(shape<2>(TileShape_MNK{}), shape<1>(TileShape_MNK{}), Int{}))); + using SmemLayoutV = decltype(cute::conditional_return>(SmemLayoutVFp8{}, SmemLayoutVFp16{})); + // Note this is the transpose in terms of the view, not in terms of memory. + using SmemLayoutVtFp16 = + decltype(cute::composition(SmemLayoutVFp16{}, + make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{}), Int{}), + make_stride(get<1>(TileShape_MNK{}), _1{}, Int{})))); + + using SmemLayoutVt = decltype(cute::conditional_return>(SmemLayoutVFp8{}, SmemLayoutVtFp16{})); + + using SmemLayoutAtomO = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 2>(TileShape_MNK{}))); - using SmemCopyAtomQ = Copy_Atom; + using SmemCopyAtomQ = Copy_Atom; - using SharedStorage = SharedStorageQKVO; using MainloopPipeline = typename cutlass::PipelineTmaAsync; diff --git a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp index 2de15fb..f9c8e2a 100644 --- a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp +++ b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp @@ -43,12 +43,30 @@ struct CollectiveMainloopFwd { using SmemLayoutK = decltype(tile_to_shape(SmemLayoutAtomK{}, make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int{}))); - using SmemLayoutV = SmemLayoutK; + + using SmemLayoutAtomVFp8 = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>()); + using SmemLayoutVFp8 = + decltype(tile_to_shape(SmemLayoutAtomVFp8{}, + make_shape(shape<2>(TileShape_MNK{}), shape<1>(TileShape_MNK{}), Int{}))); + + using SmemLayoutVFp16 = SmemLayoutK; // Note this is the transpose in terms of the view, not in terms of memory. 
- using SmemLayoutVt = - decltype(cute::composition(SmemLayoutV{}, + using SmemLayoutVtFp16 = + decltype(cute::composition(SmemLayoutVFp16{}, make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{}), Int{}), - make_stride(get<1>(TileShape_MNK{}), _1{}, Int{})))); + make_stride(get<1>(TileShape_MNK{}), _1{}, Int{})))); + + using SmemLayoutV = decltype(cute::conditional_return>(SmemLayoutVFp8{}, SmemLayoutVFp16{})); + using SmemLayoutVt = decltype(cute::conditional_return>(SmemLayoutVFp8{}, SmemLayoutVtFp16{})); + + // Dummy S layout for getting the shape for GEMM-II. + using SmemLayoutAtomS = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>()); + using SmemLayoutS = + decltype(tile_to_shape(SmemLayoutAtomS{}, + make_shape(shape<0>(TileShape_MNK{}), shape<1>(TileShape_MNK{})))); + // using SmemLayoutAtomVt = cute::GMMA::Layout_MN_SW128_Atom; // using SmemLayoutVt = // decltype(tile_to_shape(SmemLayoutAtomVt{}, @@ -85,6 +103,19 @@ struct CollectiveMainloopFwd { take<0, 2>(SmemLayoutK{}), select<1, 2>(TileShape_MNK{}), size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any + // + using TileShapeVFP8 = decltype(make_shape(cute::get<2>(TileShape_MNK{}), cute::get<1>(TileShape_MNK{}))); + using TileShapeVFP16 = decltype(make_shape(cute::get<1>(TileShape_MNK{}), cute::get<2>(TileShape_MNK{}))); + using TileShapeV = decltype(cute::conditional_return>(TileShapeVFP8{}, TileShapeVFP16{})); + using TMA_VFP8 = decltype(make_tma_copy( + GmemTiledCopyKV{}, + make_tensor(make_gmem_ptr(static_cast(nullptr)), repeat_like(StrideQKV{}, int32_t(0)), StrideQKV{}), + take<0, 2>(SmemLayoutV{}), + TileShapeV{}, + size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any + + using TMA_V = decltype(cute::conditional_return>(TMA_VFP8{}, TMA_KV{})); + static constexpr int NumMmaThreads = size(typename Ktraits::TiledMma0{}); using MainloopPipeline = typename Ktraits::MainloopPipeline; @@ -97,6 +128,7 @@ struct CollectiveMainloopFwd { static constexpr bool UseSchedulerBarrier = kHeadDim <= 128; + // Host side kernel arguments struct Arguments { Element const* ptr_Q; @@ -115,7 +147,8 @@ struct CollectiveMainloopFwd { typename Seqlen_traits::LayoutT layout_V; cutlass::FastDivmod qhead_per_khead_divmod; TMA_Q tma_load_Q; - TMA_KV tma_load_K, tma_load_V; + TMA_KV tma_load_K; + TMA_V tma_load_V; float const softmax_scale_log2; }; @@ -136,12 +169,15 @@ struct CollectiveMainloopFwd { SmemLayoutK{}(_, _, _0{}), select<1, 2>(TileShape_MNK{}), size<0>(ClusterShape{})); // mcast along M mode for this N load, if any - Tensor mV = make_tensor(make_gmem_ptr(args.ptr_V), args.layout_V); - TMA_KV tma_load_V = make_tma_copy( + auto gmemLayoutVFp16 = args.shape_K; + auto gmemLayoutVFp8 = select<1, 0, 2, 3>(gmemLayoutVFp16); + auto gmemLayoutV = cute::conditional_return>(gmemLayoutVFp8, gmemLayoutVFp16); + Tensor mV = make_tensor(make_gmem_ptr(args.ptr_V), gmemLayoutV, args.layout_V.stride()); + TMA_V tma_load_V = make_tma_copy( GmemTiledCopyKV{}, mV, SmemLayoutV{}(_, _, _0{}), - select<1, 2>(TileShape_MNK{}), + cute::conditional_return>(select<2, 1>(TileShape_MNK{}), select<1, 2>(TileShape_MNK{})), size<0>(ClusterShape{})); // mcast along M mode for this N load, if any return {args.layout_Q, args.layout_K, args.layout_V, cutlass::FastDivmod(cute::ceil_div(get<2>(args.layout_Q.shape()), get<2>(args.layout_K.shape()))), @@ -198,7 +234,10 @@ struct CollectiveMainloopFwd { Tensor mQ = 
mainloop_params.tma_load_Q.get_tma_tensor(mainloop_params.layout_Q.shape()); Tensor mK = mainloop_params.tma_load_K.get_tma_tensor(mainloop_params.layout_K.shape()); - Tensor mV = mainloop_params.tma_load_V.get_tma_tensor(mainloop_params.layout_V.shape()); + auto gmemLayoutVFp16 = mainloop_params.shape_K; + auto gmemLayoutVFp8 = select<1, 0, 2, 3>(gmemLayoutVFp16); + auto gmemLayoutV = cute::conditional_return>(gmemLayoutVFp8, gmemLayoutVFp16); + Tensor mV = mainloop_params.tma_load_V.get_tma_tensor(gmemLayoutV); auto [m_block, bidh, bidb] = block_coord; int bidh_kv = mainloop_params.qhead_per_khead_divmod.divide(bidh); @@ -207,12 +246,34 @@ struct CollectiveMainloopFwd { uint32_t block_rank_in_cluster = cute::block_rank_in_cluster(); constexpr uint32_t cluster_shape_x = get<0>(ClusterShape()); uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; - Tensor gQ = seqlen_traits_q.get_local_tile_tensor( - mQ, select<0, 2>(TileShape_MNK{}), bidh, bidb)(_, _, m_block); // (M, K) - Tensor gK = seqlen_traits_k.get_local_tile_tensor( - mK, select<1, 2>(TileShape_MNK{}), bidh_kv, bidb); // (N, K, _) - Tensor gV = seqlen_traits_k.get_local_tile_tensor( - mV, select<1, 2>(TileShape_MNK{}), bidh_kv, bidb); // (N, K, _) + + Tensor gQ = local_tile(mQ(_, _, bidh, bidb), select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{})); // (M, K) + Tensor gK = local_tile(mK(_, _, bidh_kv, bidb), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _) + Tensor gV = local_tile(mV(_, _, bidh_kv, bidb), TileShapeV{}, cute::conditional_return>(make_coord(_0{}, _), make_coord(_, _0{}))); // (N, K, _) + +#if 0 + if (threadIdx.x == 0 && blockIdx.x == 0) { + print ("\n"); + print (gV); + print ("\n"); + print (gK); + print ("\n"); + print ("\n"); + print (sV); + print ("\n"); + print (sK); + print ("\n"); + print (gmemLayoutVFp8); + print ("\n"); + print (gmemLayoutVFp16); + } + + // Tensor gQ = seqlen_traits_q.get_local_tile_tensor( + // mQ, select<0, 2>(TileShape_MNK{}), bidh, bidb)(_, _, m_block); // (M, K) + // Tensor gK = seqlen_traits_k.get_local_tile_tensor( + // mK, select<1, 2>(TileShape_MNK{}), bidh_kv, bidb); // (N, K, _) + // Tensor gV = seqlen_traits_k.get_local_tile_tensor( + // mV, select<1, 2>(TileShape_MNK{}), bidh_kv, bidb); // (N, K, _) Tensor sQ_x = make_tensor(sQ.data(), make_layout(sQ.layout(), Layout<_1>{})); Tensor gQ_x = make_tensor(gQ.data(), make_layout(gQ.layout(), Layout<_1>{})); @@ -369,6 +430,13 @@ struct CollectiveMainloopFwd { // Note: S becomes P. Tensor tOrV = threadMma1.partition_fragment_B(sVt); + // Dummy sS to just get the shape correctly for GEMM-II. 
+ Tensor sS = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), SmemLayoutS{}); + Tensor tOrS = threadMma1.partition_fragment_A(sS); + Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{})); + ReorgCFp8toAFp8 reg2reg; + auto tOrPLayout = ReshapeTStoTP()(tSrS, tOrS); + auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) { auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); pipeline.consumer_wait(smem_pipe_read, barrier_token); @@ -382,7 +450,6 @@ struct CollectiveMainloopFwd { cutlass::ConsumerToken barrier_token = static_cast(shared_storage.barrier_Q.try_wait(work_idx % 2)); if (barrier_token == cutlass::BarrierStatus::WaitAgain) { shared_storage.barrier_Q.wait(work_idx % 2); } - Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{})); consumer_wait(pipeline_k, smem_pipe_read_k); warp_scheduler_barrier_sync(); flash::gemm(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read_k.index()), tSrS); @@ -424,7 +491,11 @@ struct CollectiveMainloopFwd { } softmax.template online_softmax(tSrS, mainloop_params.softmax_scale_log2); - Tensor tOrP = make_tensor(convert_type(tSrS).data(), convert_layout_acc_Aregs(tSrS.layout())); + auto tSrSPrec = convert_type(tSrS); + if constexpr (is_same_v) { + reg2reg(tSrSPrec); + } + Tensor tOrP = make_tensor(tSrSPrec.data(), tOrPLayout); Tensor scores_scale = make_fragment_like(softmax.row_max); clear(scores_scale); @@ -456,7 +527,11 @@ struct CollectiveMainloopFwd { pipeline_v.consumer_release(smem_pipe_read_v); // release V ++smem_pipe_read_k; ++smem_pipe_read_v; - cute::copy(make_tensor(convert_type(tSrS).data(), convert_layout_acc_Aregs(tSrS.layout())), tOrP); + auto tSrSPrec = convert_type(tSrS); + if constexpr (is_same_v) { + reg2reg(tSrSPrec); + } + cute::copy(make_tensor(tSrSPrec.data(), tOrPLayout), tOrP); } #pragma unroll 1 @@ -479,7 +554,11 @@ struct CollectiveMainloopFwd { ++smem_pipe_read_k; ++smem_pipe_read_v; // softmax.rescale_o(tOrO, scores_scale); - cute::copy(make_tensor(convert_type(tSrS).data(), convert_layout_acc_Aregs(tSrS.layout())), tOrP); + auto tSrSPrec = convert_type(tSrS); + if constexpr (is_same_v) { + reg2reg(tSrSPrec); + } + cute::copy(make_tensor(tSrSPrec.data(), tOrPLayout), tOrP); } // Tell warp 0 that smem_q is ready cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); diff --git a/hopper/setup.py b/hopper/setup.py index 5d01a02..2d3d01b 100644 --- a/hopper/setup.py +++ b/hopper/setup.py @@ -116,6 +116,9 @@ if not SKIP_CUDA_BUILD: "flash_fwd_hdim128_bf16_sm90.cu", "flash_fwd_hdim256_fp16_sm90.cu", "flash_fwd_hdim256_bf16_sm90.cu", + "flash_fwd_hdim64_fp8_sm90.cu", + "flash_fwd_hdim128_fp8_sm90.cu", + "flash_fwd_hdim256_fp8_sm90.cu", "flash_bwd_hdim64_fp16_sm90.cu", "flash_bwd_hdim128_fp16_sm90.cu", "flash_bwd_hdim256_fp16_sm90.cu", diff --git a/hopper/test_flash_attn.py b/hopper/test_flash_attn.py index 55ec486..a37954e 100644 --- a/hopper/test_flash_attn.py +++ b/hopper/test_flash_attn.py @@ -170,7 +170,7 @@ def test_flash_attn_output( (113, 211), (108, 256), (256, 512), - (384, 256), + (384, 256), (512, 256), (640, 128), (1024, 1024), @@ -261,49 +261,87 @@ def test_flash_attn_varlen_output( reorder_ops=True, ) + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") + print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") + + 
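The prints added above, and the assertion at the end of the new FP8 test below, follow the same acceptance convention: compare the kernel output against an attention reference computed with an fp32 upcast, and require at most twice the numerical error of a plain PyTorch implementation run at the working precision. A minimal self-contained sketch of that check; the helper name `check_against_reference` is introduced here only for illustration, and `out`, `out_ref`, `out_pt` correspond to the tensors of the same names in the tests:

```python
import torch

def check_against_reference(out: torch.Tensor, out_ref: torch.Tensor, out_pt: torch.Tensor):
    # out     : kernel output under test
    # out_ref : reference attention computed with an fp32 upcast
    # out_pt  : the same reference computed at the working precision (no upcast)
    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
    # Accept if FlashAttention's numerical error is at most twice the error
    # of the low-precision PyTorch implementation.
    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()
```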
+@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn]) +# @pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) +# @pytest.mark.parametrize("mha_type", ["gqa"]) +@pytest.mark.parametrize("causal", [False, True]) +# @pytest.mark.parametrize("causal", [True]) +# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) +# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192]) +# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192]) +# @pytest.mark.parametrize('d', [56, 80]) +@pytest.mark.parametrize("d", [64, 128, 256]) +#@pytest.mark.parametrize("d", [128]) +# @pytest.mark.parametrize("d", [256]) +@pytest.mark.parametrize( + "seqlen_q,seqlen_k", + [ + (64, 128), + (128, 128), + (256, 256), + (113, 203), + (128, 217), + (113, 211), + (108, 256), + (256, 512), + (384, 256), + (640, 128), + (512, 256), + (1024, 1024), + (1023, 1024), + (1024, 1023), + (2048, 2048), + ], +) +def test_flash_attn_output_fp8( + seqlen_q, seqlen_k, d, causal, mha_type, dtype +): + device = "cuda" + # set seed + torch.random.manual_seed(0) + # batch_size = 40 + # nheads = 16 + batch_size = 9 + nheads = 6 + nheads_kv = 6 if mha_type == "mha" else (2 if mha_type == "gqa" else 1) + # batch_size = 1 + # nheads = 1 + q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=torch.float16, requires_grad=True) + k = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=torch.float16, requires_grad=True) + v = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=torch.float16, requires_grad=True) + out, lse = flash_attn_func(q.to(dtype), k.to(dtype), v.to(dtype).transpose(1,3).contiguous().clone(), causal=causal) + q = q.to(dtype).to(torch.float16) + k = k.to(dtype).to(torch.float16) + v = v.to(dtype).to(torch.float16) + out_ref, attn_ref = attention_ref( + q, + k, + v, + None, + None, + causal=causal, + ) + out_pt, attn_pt = attention_ref( + q, + k, + v, + None, + None, + causal=causal, + upcast=False, + reorder_ops=True, + ) + print(f"Output max diff: {(out - out_ref).abs().max().item()}") print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") - # g = torch.randn_like(out) - # if d <= 128: - # ( - # dq_unpad, - # dk_unpad, - # dv_unpad, - # ) = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g) - # dk = dk_pad_fn(dk_unpad) - # dv = dk_pad_fn(dv_unpad) - # ( - # dq_ref, - # dk_ref, - # dv_ref, - # ) = torch.autograd.grad(out_ref, (q, k, v), g) - # ( - # dq_pt, - # dk_pt, - # dv_pt, - # ) = torch.autograd.grad(out_pt, (q, k, v), g) - # dq = dq_pad_fn(dq_unpad) - # print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}") - # print(f"dK max diff: {(dk - dk_ref).abs().max().item()}") - # print(f"dV max diff: {(dv - dv_ref).abs().max().item()}") - # print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}") - # print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}") - # print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}") - # print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}") - # print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}") - # print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}") - # print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}") - # print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}") - # print(f"dV Pytorch mean diff: {(dv_pt - 
dv_ref).abs().mean().item()}") - - # Check that FlashAttention's numerical error is at most twice the numerical error - # of a Pytorch implementation. assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() - - # if d <= 128: - # assert (dq - dq_ref).abs().max().item() < 1e-4 or (dq - dq_ref).abs().max().item() <= 3 * (dq_pt - dq_ref).abs().max().item() - # assert (dk - dk_ref).abs().max().item() < 1e-4 or (dk - dk_ref).abs().max().item() <= 3 * (dk_pt - dk_ref).abs().max().item() - # assert (dk - dk_ref).abs().max().item() < 1e-4 or (dv - dv_ref).abs().max().item() <= 3 * (dv_pt - dv_ref).abs().max().item() diff --git a/hopper/utils.h b/hopper/utils.h index 90116f8..21e7cc6 100644 --- a/hopper/utils.h +++ b/hopper/utils.h @@ -228,6 +228,88 @@ __forceinline__ __device__ void copy(TiledCopy tiled_copy, Tensor CUTLASS_DEVICE auto operator()(Fragment &accum) { + + using namespace cute; + + // First update `mi` to the max per-row + // + auto VT = shape<0>(accum); // number of vector elements per tile. + auto MT = shape<1>(accum); // number of tiles along M. + auto NT = shape<2>(accum); // number of tiles along N. + + auto data = accum.data(); + int n = 0; + +#pragma unroll + for (int i = 0; i < MT; ++i) { + + // Traverse 2-rows + 2-cols (2x2) simultaneously. + +#pragma unroll + for (int k = 0; k < NT * size<2>(VT) / 2; ++k) { + + auto upper = *reinterpret_cast(&data[n]); + auto lower = *reinterpret_cast(&data[n + 4]); + + auto upper0 = __byte_perm(upper, lower, selectorEx0); + auto lower0 = __byte_perm(upper, lower, selectorEx1); + upper0 = + __shfl_sync(uint32_t(-1), upper0, upper_map[threadIdx.x % 4], 4); + lower0 = + __shfl_sync(uint32_t(-1), lower0, lower_map[threadIdx.x % 4], 4); + + uint32_t *data_32bit = reinterpret_cast(&data[n]); + data_32bit[0] = __byte_perm(upper0, lower0, selectorEx4); + data_32bit[1] = __byte_perm(upper0, lower0, selectorEx5); + n += 8; + } + } + } +}; + + +// Reshape Utility for converting the layout from accumulator of GEMM-I +// to Operand A of GEMM-II. +struct ReshapeTStoTP { + template + CUTLASS_DEVICE auto operator()(FragmentC &&tC, FragmentQ &&tQ) { + + // get the layout of one row of Q. + auto layoutQRow = make_layout_like(tQ(_, 0, _).layout()); + // get the layout of M dimension of C. + auto layoutCM = get<1>(tC.layout()); + return make_layout(get<0>(layoutQRow), layoutCM, get<1>(layoutQRow)); + } +}; template