fMHA: Add backward pass (#844)
* fMHA: Add backward pass
* Better checks for strides/alignments
* Remove fb-internal URL
* torch.Tensor.untyped_storage requires pytorch 2.0+
* minor changes
* make test

---------

Co-authored-by: danthe3rd <danthe3rd>
Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
parent e2d439ee7e
commit 9b8166e3f0
@@ -37,8 +37,20 @@ cutlass_example_add_executable(
  fused_multihead_attention_variable_seqlen.cu
)

cutlass_example_add_executable(
  41_fused_multi_head_attention_backward
  fused_multi_head_attention_backward.cu
  DISABLE_TESTS ON
)

add_custom_target(41_fused_multi_head_attention
  DEPENDS 41_fused_multi_head_attention_fixed_seqlen
          41_fused_multi_head_attention_variable_seqlen
          41_fused_multi_head_attention_backward
)

add_test(
  NAME ctest_examples_41_fmha_backward_python
  COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/fmha_backward_test.py $<TARGET_FILE:41_fused_multi_head_attention_backward>
)
examples/41_fused_multi_head_attention/fmha_backward_test.py (new file, 199 lines)
@@ -0,0 +1,199 @@
import argparse
import torch
import sys
import os
from piped_subprocess import PipedSubprocess, TORCH_DTYPE_NAME
import math


parser = argparse.ArgumentParser()
parser.add_argument("example_exe", type=str, help="Path to the 41_fused_multi_head_attention_backward executable")
args = parser.parse_args()

torch.manual_seed(0)
dtype = torch.float16
B, Mq, Mkv, H, K, Kv = 2, 1024, 1024, 5, 128, 128
causal = True
repeat_count = 100

ATOL = {
    torch.float: 5e-4,
    torch.half: 9.5e-2,
    torch.bfloat16: 7e-1,
}[dtype]

RTOL = {
    torch.float: 1e-4,
    torch.half: 2e-2,
    torch.bfloat16: 1e-1,
}[dtype]


assert not (causal and Mq < Mkv), "causal only supports seqlenK <= seqlenQ"

fmha_bw_binary = args.example_exe
if not os.path.isfile(fmha_bw_binary):
    print(f"""No such file: `{fmha_bw_binary}`\nDid you forget to run "make 41_fused_multi_head_attention"?""")
    sys.exit(1)


def create_lower_triangular_mask():
    return torch.triu(torch.full(  # type: ignore
        [1, Mq, Mkv],
        dtype=dtype,
        fill_value=float("-inf"),
    ), diagonal=1)


def ref_mha_bmk(q, k, v, mask):
    # Multi-head attention with inputs/outputs in BMK format
    q = q.float()
    k = k.float()
    v = v.float()

    q = q * (1 / q.shape[-1] ** 0.5)
    attn = q @ k.transpose(-2, -1)
    if mask is not None:
        attn += mask
    attn_max = attn.max(-1, True).values
    attn_norm = (attn - attn_max).exp().sum(-1, True)
    attn = attn.softmax(-1)
    lse = attn_max + attn_norm.log()
    lse = lse.squeeze(2)
    return attn @ v, lse
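
Spelled out (a restatement of the reference code above, not part of the patch): with d = K the head dimension and M the additive mask, `ref_mha_bmk` computes the scores, a numerically stable log-sum-exp per query row, the softmax probabilities, and the output:

```latex
S = \frac{Q K^\top}{\sqrt{d}} + M, \qquad
\mathrm{lse}_i = \max_j S_{ij} + \log \sum_j e^{S_{ij} - \max_j S_{ij}} = \log \sum_j e^{S_{ij}}, \qquad
P_{ij} = e^{S_{ij} - \mathrm{lse}_i} = \mathrm{softmax}(S)_{ij}, \qquad
O = P V
```

The backward kernel consumes `lse` directly, which is why the reference returns it alongside the output.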

def bmhk2bmk(t):
    return t.permute((0, 2, 1, 3)).reshape(
        [t.shape[0] * t.shape[2], t.shape[1], t.shape[3]]
    )


def ref_mha_bmhk(q, k, v, mask):
    # Multi-head attention with inputs/outputs in BMHK format
    assert q.ndim == 4

    out, lse = ref_mha_bmk(bmhk2bmk(q), bmhk2bmk(k), bmhk2bmk(v), mask=mask)
    out = out.reshape([q.shape[0], q.shape[2], q.shape[1], v.shape[3]])
    return out.permute((0, 2, 1, 3)), lse.reshape([q.shape[0], q.shape[2], q.shape[1]])


def ref_mha_bw_bmhk(q, k, v, mask, lse, out, grad_out, delta):
    lse = lse[:, :, :q.shape[1]]  # BMH, unpad Q dimension
    delta = delta.reshape([-1, delta.shape[-1], 1])

    # bmhk -> bmk
    q, k, v, out, grad_out = [bmhk2bmk(x).float() for x in (q, k, v, out, grad_out)]

    attn_T = k @ q.transpose(-2, -1)
    if mask is not None:
        attn_T += mask.transpose(-2, -1)
    attn_T = attn_T * (1 / q.shape[-1] ** 0.5)
    attn_T = attn_T - lse.reshape([-1, 1, lse.shape[-1]])
    attn_T = attn_T.exp()

    grad_v = attn_T @ grad_out

    dov = grad_out @ v.transpose(-2, -1)
    tmp = (dov - delta) * attn_T.transpose(-2, -1)
    tmp = tmp / (q.shape[-1] ** 0.5)

    grad_q = tmp @ k
    grad_k = tmp.transpose(-2, -1) @ q

    return [x.reshape([B, H, x.shape[1], x.shape[-1]]).permute([0, 2, 1, 3]) for x in [grad_q, grad_k, grad_v]]
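
Written out (again a restatement of the code above, not an addition to it): with P and lse as in the forward pass, Delta_i = sum_k dO_ik * O_ik (the `delta` tensor prepared further down), and d = K, the reference backward computes the standard attention gradients:

```latex
dV = P^\top \, dO, \qquad
dP = dO \, V^\top, \qquad
dS = P \circ (dP - \Delta), \qquad
dQ = \tfrac{1}{\sqrt{d}} \, dS \, K, \qquad
dK = \tfrac{1}{\sqrt{d}} \, dS^\top Q
```

In the code, `attn_T` holds P transposed (recomputed from Q, K and `lse`), and the 1/sqrt(d) factor is folded into `tmp` so that it serves both dQ and dK.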

print("initializing tensors...")
query = torch.randn([B, Mq, H, K], dtype=dtype)
key = 3 * torch.randn([B, Mkv, H, K], dtype=dtype)
value = 3 * torch.randn([B, Mkv, H, Kv], dtype=dtype)
mask = create_lower_triangular_mask() if causal else None

# let PyTorch compute gradients
query.requires_grad_(True)
key.requires_grad_(True)
value.requires_grad_(True)

print("computing fw...")
out, lse = ref_mha_bmhk(query, key, value, mask=mask)
out = out.to(dtype).contiguous()
grad_out = 3 * torch.randn([B, Mq, H, Kv], dtype=dtype)

print("computing bw with autograd...")
out.backward(grad_out)
scale = (1 / query.shape[-1] ** 0.5)


# Additional data needed by the kernel
delta = (grad_out.float() * out.float()).sum(-1).transpose(-2, -1).contiguous()
pad_amount = (32 - (lse.shape[2] % 32)) % 32
lse = torch.nn.functional.pad(lse, [0, pad_amount], value=math.inf)
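
The kernel expects the log-sum-exp rows padded to a multiple of 32 with +inf; a quick check of the padding formula above (plain arithmetic, illustration only, not part of the test):

```python
for m in (1024, 1000, 33):
    print(m, (32 - (m % 32)) % 32)  # -> 1024 0, 1000 24, 33 31
```

For the default Mq = 1024 the pad amount is therefore 0.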

print("computing bw with reference implem...")
gQr, gKr, gVr = ref_mha_bw_bmhk(query, key, value, mask, lse, out, grad_out, delta)

with PipedSubprocess(fmha_bw_binary) as bw_kernel:
    # Send kernel arguments
    bw_kernel.write(
        TORCH_DTYPE_NAME[query.dtype],
        "scale", scale,
        "head_dim", K,
        "head_dim_value", Kv,
        "num_queries", Mq,
        "num_keys", Mkv,
        "num_heads", H,
        "custom_mask_type", (1 if causal else 0),
        "num_batches", B,
        "repeat_count", repeat_count,
    )
    bw_kernel.writeTensor(query, "query", ["q_strideB", "q_strideM", "q_strideH"])
    bw_kernel.writeTensor(key, "key", ["k_strideB", "k_strideM", "k_strideH"])
    bw_kernel.writeTensor(value, "value", ["v_strideB", "v_strideM", "v_strideH"])
    bw_kernel.writeTensor(lse, "logsumexp", ["lse_strideB", "lse_strideH"])
    bw_kernel.writeTensor(out, "output", ["o_strideB", "o_strideM", "o_strideH"])
    bw_kernel.writeTensor(grad_out, "grad_output", ["gO_strideB", "gO_strideM", "gO_strideH"])
    bw_kernel.writeTensor(delta, "delta", ["delta_strideB", "delta_strideH"])

    if bw_kernel.read() != "OK":
        print("Got unexpected output")
        print(bw_kernel.subp.communicate()[0])
        sys.exit(0)

    # Read kernel output
    gQ = bw_kernel.readTensor("grad_query", ["gQ_strideB", "gQ_strideM", "gQ_strideH"], query.shape).float()
    gK = bw_kernel.readTensor("grad_key", ["gK_strideB", "gK_strideM", "gK_strideH"], key.shape).float()
    gV = bw_kernel.readTensor("grad_value", ["gV_strideB", "gV_strideM", "gV_strideH"], value.shape).float()
    runtime_ms = float(bw_kernel.readNamed("runtime_ms"))

    float_ops = B * H * sum([
        # att = Q @ K.transpose
        Mq * Mkv * K * 2,
        # att @ dO
        Mkv * Mq * Kv * 2,
        # dov = dO @ V
        Mq * Kv * Mkv * 2,
        # dov @ K
        Mq * K * Mkv * 2,
        # dov @ Q
        Mq * K * Mkv * 2,
    ])
    if causal:
        float_ops //= 2
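
As a sanity check of the FLOP count (illustration only, not part of the test): with the default shapes above (B=2, H=5, Mq=Mkv=1024, K=Kv=128), a causal run is about 6.7 GFLOP, i.e. roughly 0.006 of the 1024^4 unit used in the TFlops report below:

```python
b, h, mq, mkv, k, kv = 2, 5, 1024, 1024, 128, 128
ops = b * h * sum([mq*mkv*k*2, mkv*mq*kv*2, mq*kv*mkv*2, mq*k*mkv*2, mq*k*mkv*2]) // 2
print(ops, ops / 1024**4)  # 6710886400, ~0.0061
```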

    print(f"""
Fused multi-head attention - backward
    batch_size={B}
    num_queries={Mq}
    num_keys={Mkv}
    num_heads={H}
    head_dim={K}
    head_dim_value={Kv}

    Correctness:
        grad_query: {"PASS" if torch.allclose(gQ, gQr, rtol=RTOL, atol=ATOL) else "FAIL"} (delta: {(gQ - gQr).abs().max()})
        grad_key: {"PASS" if torch.allclose(gK, gKr, rtol=RTOL, atol=ATOL) else "FAIL"} (delta: {(gK - gKr).abs().max()})
        grad_value: {"PASS" if torch.allclose(gV, gVr, rtol=RTOL, atol=ATOL) else "FAIL"} (delta: {(gV - gVr).abs().max()})
        (atol={ATOL} / rtol={RTOL})
    Runtime: {runtime_ms}ms ({(float_ops / (1024 ** 4)) / (runtime_ms / 1000):.4f} TFlops)
""")

    assert torch.allclose(query.grad.float(), gQr, rtol=RTOL, atol=ATOL), "Reference implementation does not match PyTorch autograd!"
    assert torch.allclose(key.grad.float(), gKr, rtol=RTOL, atol=ATOL), "Reference implementation does not match PyTorch autograd!"
    assert torch.allclose(value.grad.float(), gVr, rtol=RTOL, atol=ATOL), "Reference implementation does not match PyTorch autograd!"
examples/41_fused_multi_head_attention/fused_multi_head_attention_backward.cu (new file, 295 lines)
@@ -0,0 +1,295 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

/////////////////////////////////////////////////////////////////////////////////////////////////

#include <vector>
#include <iostream>
#include <fstream>

#include "kernel_backward.h"

#include "cutlass/util/device_memory.h"
#include "cutlass/util/host_tensor.h"


using Arch = cutlass::arch::Sm80;
static constexpr int kMaxK = 128;

template <typename ArchTag, typename Element, int kMaxK>
struct DefaultKernel {
  // Some heuristics to select the best kernel (tested on Sm60, Sm70, Sm80)
  // NOTE: Requires quite a lot of shmem for Sm80+,
  // so might require tweaking those manually for Sm86/Sm89

  static constexpr bool kSupports64x128 =
      ArchTag::kMinComputeCapability >= 80 ||
      (ArchTag::kMinComputeCapability >= 70 &&
       cutlass::sizeof_bits<Element>::value <= 16);
  static constexpr int kBlockSizeI = kSupports64x128 && kMaxK > 64 ? 128 : 64;
  static constexpr bool kIsHalf = cutlass::sizeof_bits<Element>::value <= 16;
  static constexpr bool kOutputInRF = kIsHalf && kMaxK <= kBlockSizeI;
  static constexpr bool kPreload = kIsHalf && ArchTag::kMinComputeCapability >= 80 && kOutputInRF;
  static constexpr int kBlockSizeJ = kPreload && kMaxK > 64 ? 128 : 64;

  using Kernel = AttentionBackwardKernel<
      Arch,
      Element,
      true,        // kIsAligned_
      false,       // kApplyDropout_
      kPreload,    // kPreload_
      kBlockSizeI, // kBlockSizeI_,
      kBlockSizeJ, // kBlockSizeJ_,
      kMaxK        // kMaxK
  >;
};
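
For readers skimming the heuristics, the tile-size selection above boils down to the following (a Python sketch mirroring the constexpr logic; the function and argument names are ad hoc, the C++ above is authoritative):

```python
def tile_sizes(min_cc: int, elem_bits: int, max_k: int):
    # mirrors kSupports64x128 / kBlockSizeI / kOutputInRF / kPreload / kBlockSizeJ
    supports_64x128 = min_cc >= 80 or (min_cc >= 70 and elem_bits <= 16)
    block_i = 128 if (supports_64x128 and max_k > 64) else 64
    is_half = elem_bits <= 16
    output_in_rf = is_half and max_k <= block_i
    preload = is_half and min_cc >= 80 and output_in_rf
    block_j = 128 if (preload and max_k > 64) else 64
    return block_i, block_j, output_in_rf, preload

print(tile_sizes(80, 16, 128))  # Sm80, f16, kMaxK=128 -> (128, 128, True, True)
```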

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace {
template <typename T> struct TypeName;
template <> struct TypeName<float> { static constexpr const char* Name = "f32"; };
template <> struct TypeName<cutlass::half_t> { static constexpr const char* Name = "f16"; };
template <> struct TypeName<cutlass::bfloat16_t> { static constexpr const char* Name = "b16"; };

void readExpect(std::string const& expected) {
  std::string read;
  std::cin >> read;
  if (read != expected) {
    std::cerr << "FATAL: Read '" << read << "' but expected '" << expected << "'" << std::endl;
    std::exit(1);
  }
}

/// Helpers to read from stdin
template <typename Element>
cutlass::HostTensor<Element, cutlass::layout::RowMajor> readTensorOnDevice(std::string const& expectedName) {
  readExpect("tensor_begin");
  readExpect(std::string(TypeName<Element>::Name) + ":" + expectedName);
  uint64_t len = 0;
  std::cin >> len;
  readExpect("file");
  std::string filename;
  std::cin >> filename;

  cutlass::HostTensor<Element, cutlass::layout::RowMajor> tensor({int64_t(1), int64_t(len / sizeof(Element))});
  uint8_t* data = (uint8_t*)tensor.host_data();

  std::fstream myFile(filename, std::ios::in | std::ios::binary);
  myFile.read((char*)data, len);
  readExpect("tensor_end");
  tensor.sync_device();
  return tensor;
}

int64_t readInt64(std::string const& expectedName) {
  readExpect(expectedName);
  int64_t s = 0;
  std::cin >> s;
  return s;
}

float readFloat(std::string const& expectedName) {
  readExpect(expectedName);
  float s = 0;
  std::cin >> s;
  return s;
}

// Writing
template <typename Element>
void writeTensor(std::string const& name, cutlass::HostTensor<Element, cutlass::layout::RowMajor>& tensor) {
  tensor.sync_host(); // device->host
  size_t u8len = tensor.size() * sizeof(Element);

  // Python is expected to provide a file name to write to
  readExpect("tmpfile");
  std::string tmpfile;
  std::cin >> tmpfile;

  uint8_t* data = (uint8_t*)tensor.host_data();
  std::fstream myFile(tmpfile, std::ios::out | std::ios::binary);
  myFile.write((char*)data, u8len);
  myFile.close();

  std::cout << "tensor_begin " << TypeName<Element>::Name << ":" << name << " ";
  std::cout << u8len << " file " << tmpfile << " tensor_end" << std::endl;
}

void writeInt64(std::string const& name, int64_t value) {
  std::cout << name << " " << value << std::endl;
}
}

/////////////////////////////////////////////////////////////////////////////////////////////////

template <typename Element>
int runKernel() {
  using Kernel = typename DefaultKernel<Arch, Element, kMaxK>::Kernel;

#define READ_I64(NAME) p.NAME = (decltype(p.NAME))readInt64(#NAME)
#define READ_TENSOR_AND_STRIDES_BMH(DT, NAME, NAME_XS) \
  auto storage##NAME = readTensorOnDevice<DT>(#NAME);  \
  p.NAME##_ptr = storage##NAME.device_data();          \
  READ_I64(NAME_XS##_strideB);                         \
  READ_I64(NAME_XS##_strideM);                         \
  READ_I64(NAME_XS##_strideH);

#define CUDA_CHECK(FN) { \
  auto cudaError = FN; \
  if (cudaError != cudaSuccess) { \
    std::cerr << "FATAL: " #FN " failed: " << cudaGetErrorString(cudaError) << std::endl; \
    return -1; \
  } \
}

  typename Kernel::Params p;
  p.scale = readFloat("scale");
  READ_I64(head_dim);
  READ_I64(head_dim_value);
  READ_I64(num_queries);
  READ_I64(num_keys);
  READ_I64(num_heads);
  READ_I64(custom_mask_type);
  READ_I64(num_batches);
  int64_t repeat_count = readInt64("repeat_count");

  READ_TENSOR_AND_STRIDES_BMH(Element, query, q);
  READ_TENSOR_AND_STRIDES_BMH(Element, key, k);
  READ_TENSOR_AND_STRIDES_BMH(Element, value, v);
  auto lse = readTensorOnDevice<typename Kernel::lse_scalar_t>("logsumexp");
  p.logsumexp_ptr = lse.device_data();
  p.lse_strideB = readInt64("lse_strideB");
  p.lse_strideH = readInt64("lse_strideH");

  // output
  auto stOutput = readTensorOnDevice<Element>("output");
  p.output_ptr = stOutput.device_data();
  READ_I64(o_strideB);
  auto o_strideM = readInt64("o_strideM");
  if (o_strideM != p.o_strideM()) {
    std::cerr << "Invalid `o_strideM`: " << o_strideM << " - expected " << p.o_strideM();
    return 2;
  }
  READ_I64(o_strideH);

  READ_TENSOR_AND_STRIDES_BMH(Element, grad_output, gO);

  auto stDelta = readTensorOnDevice<typename Kernel::accum_t>("delta");
  p.delta_ptr = stDelta.device_data();
  READ_I64(delta_strideB);
  READ_I64(delta_strideH);

  // Allocate workspace
  if (p.workspace_size()) {
    cudaMalloc(&p.workspace, p.workspace_size());
  }

  // Allocate outputs in BMHK format
  p.gQKV_strideM_multiplier = 1;
  p.gQ_strideH = p.head_dim;
  p.gQ_strideB = p.gQ_strideM() * p.num_queries;
  p.gK_strideH = p.head_dim;
  p.gK_strideB = p.gK_strideM() * p.num_keys;
  p.gV_strideH = p.head_dim_value;
  p.gV_strideB = p.gV_strideM() * p.num_keys;

  cutlass::HostTensor<Element, cutlass::layout::RowMajor> gQ({int64_t(1), p.gQ_strideB * p.num_batches});
  cutlass::HostTensor<Element, cutlass::layout::RowMajor> gK({int64_t(1), p.gK_strideB * p.num_batches});
  cutlass::HostTensor<Element, cutlass::layout::RowMajor> gV({int64_t(1), p.gV_strideB * p.num_batches});
  p.grad_query_ptr = gQ.device_data();
  p.grad_key_ptr = gK.device_data();
  p.grad_value_ptr = gV.device_data();

  if (!Kernel::check_supported(p)) {
    std::cerr << "FATAL: Kernel does not support these inputs" << std::endl;
    return 2;
  }

  // Run kernel
  cudaDeviceSynchronize();
  auto kernel_fn = attention_kernel_backward_batched_impl<Kernel>;
  size_t smem_bytes = sizeof(typename Kernel::SharedStorage);
  CUDA_CHECK(cudaFuncSetAttribute(kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, int(smem_bytes)));
  kernel_fn<<<p.getBlocksGrid(), p.getThreadsGrid(), smem_bytes>>>(p);

  // Write outputs
  std::cout << "OK ";
  writeTensor("grad_query", gQ);
  writeInt64("gQ_strideB", p.gQ_strideB);
  writeInt64("gQ_strideM", p.gQ_strideM());
  writeInt64("gQ_strideH", p.gQ_strideH);
  writeTensor("grad_key", gK);
  writeInt64("gK_strideB", p.gK_strideB);
  writeInt64("gK_strideM", p.gK_strideM());
  writeInt64("gK_strideH", p.gK_strideH);
  writeTensor("grad_value", gV);
  writeInt64("gV_strideB", p.gV_strideB);
  writeInt64("gV_strideM", p.gV_strideM());
  writeInt64("gV_strideH", p.gV_strideH);

  // Timing
  cudaEvent_t events[2];
  for (auto & event : events) {
    CUDA_CHECK(cudaEventCreate(&event));
  }
  CUDA_CHECK(cudaEventRecord(events[0]));
  for (int i = 0; i < repeat_count; ++i) {
    kernel_fn<<<p.getBlocksGrid(), p.getThreadsGrid(), smem_bytes>>>(p);
  }
  CUDA_CHECK(cudaEventRecord(events[1]));
  CUDA_CHECK(cudaEventSynchronize(events[1]));
  // Measure elapsed runtime
  float runtime_ms = 0;
  CUDA_CHECK(cudaEventElapsedTime(&runtime_ms, events[0], events[1]));

  std::cout << "runtime_ms " << runtime_ms / float(repeat_count) << std::endl;
  return 0;
}

int main() {
  std::ios_base::sync_with_stdio(false);

  std::string dtype;
  std::cin >> dtype;
  std::cerr << "Running kernel with dtype: " << dtype << std::endl;
  if (dtype == "f16") {
    return runKernel<cutlass::half_t>();
  } else if (dtype == "b16") {
    return runKernel<cutlass::bfloat16_t>();
  } else if (dtype == "f32") {
    return runKernel<float>();
  } else {
    std::cerr << "FATAL: Unknown dtype: " << dtype << std::endl;
    return 3;
  }
}

/////////////////////////////////////////////////////////////////////////////////////////////////
@@ -264,9 +264,8 @@ class NoOpWarpIteratorScale {
  // in pipelined+multistage MMA implementations we keep an array of fragments.
  // if we aren't using scaling we don't want to waste registers on fragments
  // of scale elements, so ideally this would be sized 0.
  // using size 1 is kind of a hack to get around arrays of zero-sized objects
  // not being allowed. the compiler is probably smart enough to wipe it out
  // anyways.
  // Since arrays of zero-sized objects are not allowed, using size as 1.
  // The compiler will most likely wipe it out anyways.
  using Fragment = cutlass::Array<char, 1>;

  CUTLASS_HOST_DEVICE
@@ -115,10 +115,10 @@
    std::cerr << #PTR " is not correctly aligned\n"; \
    return false; \
  }
#define XFORMERS_CHECK(COND, ERR) \
  if (!(COND)) { \
    std::cerr << #COND " failed\n"; \
    return false; \
#define XFORMERS_CHECK(COND, ERR) \
  if (!(COND)) { \
    std::cerr << "'" #COND "' failed: " << ERR << "\n"; \
    return false; \
  }
#endif

examples/41_fused_multi_head_attention/kernel_backward.h (new file, 2286 lines)
File diff suppressed because it is too large
@@ -573,27 +573,33 @@ struct AttentionKernel {
    if (kSupportsBias) {
      CHECK_ALIGNED_PTR(p.attn_bias_ptr, kAlignmentQ);
      XFORMERS_CHECK(
          p.bias_strideB % kAlignmentQ == 0,
          "attn_bias is not correctly aligned");
          p.num_batches <= 1 || p.bias_strideB % kAlignmentQ == 0,
          "attn_bias is not correctly aligned (strideB)");
      XFORMERS_CHECK(
          p.bias_strideH % kAlignmentQ == 0,
          "attn_bias is not correctly aligned");
          p.num_heads <= 1 || p.bias_strideH % kAlignmentQ == 0,
          "attn_bias is not correctly aligned (strideH)");
      XFORMERS_CHECK(
          p.bias_strideM % kAlignmentQ == 0,
          "attn_bias is not correctly aligned");
    }
    XFORMERS_CHECK(
        p.q_strideM % kAlignmentQ == 0, "query is not correctly aligned");
        p.q_strideM % kAlignmentQ == 0,
        "query is not correctly aligned (strideM)");
    XFORMERS_CHECK(
        p.k_strideM % kAlignmentK == 0, "key is not correctly aligned");
        p.k_strideM % kAlignmentK == 0,
        "key is not correctly aligned (strideM)");
    XFORMERS_CHECK(
        p.v_strideM % kAlignmentV == 0, "value is not correctly aligned");
        p.v_strideM % kAlignmentV == 0,
        "value is not correctly aligned (strideM)");
    XFORMERS_CHECK(
        p.q_strideH % kAlignmentQ == 0, "query is not correctly aligned");
        p.num_heads <= 1 || p.q_strideH % kAlignmentQ == 0,
        "query is not correctly aligned (strideH)");
    XFORMERS_CHECK(
        p.k_strideH % kAlignmentK == 0, "key is not correctly aligned");
        p.num_heads <= 1 || p.k_strideH % kAlignmentK == 0,
        "key is not correctly aligned (strideH)");
    XFORMERS_CHECK(
        p.v_strideH % kAlignmentV == 0, "value is not correctly aligned");
        p.num_heads <= 1 || p.v_strideH % kAlignmentV == 0,
        "value is not correctly aligned (strideH)");
    XFORMERS_CHECK(
        p.causal_diagonal_ptr == nullptr || p.custom_mask_type != NoCustomMask,
        "`causal_diagonal_ptr` is only useful when `custom_mask_type` is causal");
examples/41_fused_multi_head_attention/piped_subprocess.py (new file, 112 lines)
@@ -0,0 +1,112 @@
from typing import List
import torch
import subprocess
import sys
import tempfile
import os
import numpy as np


TORCH_DTYPE_NAME = {
    torch.float32: "f32",
    torch.float16: "f16",
    torch.bfloat16: "b16"
}
NAME_TORCH_DTYPE = {v: k for k, v in TORCH_DTYPE_NAME.items()}

def _tensor_from_storage(tensor: torch.Tensor, dtype) -> torch.Tensor:
    # PyTorch >= 2.0
    if hasattr(tensor, 'untyped_storage'):
        return torch.tensor([], dtype=dtype).set_(tensor.untyped_storage())
    return torch.tensor([], dtype=dtype).set_(tensor.storage().untyped())
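
A quick round trip through the helper above (illustration only, not part of the module; assumes PyTorch 1.13+ so that either `untyped_storage()` or `.storage().untyped()` is available): viewing a tensor's bytes as `uint8` and re-viewing them with the original dtype gives back the same values, since both views share the same storage.

```python
x = torch.arange(4, dtype=torch.float16)
as_u8 = _tensor_from_storage(x, torch.uint8)        # 8 raw bytes backing 4 halfs
back = _tensor_from_storage(as_u8, torch.float16)   # same storage, viewed as f16 again
assert as_u8.numel() == x.numel() * x.element_size()
assert torch.equal(back, x)
```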

class PipedSubprocess:
    def __init__(self, binary: str) -> None:
        self.binary = binary
        self.tempdir_ctx = tempfile.TemporaryDirectory()

    def __enter__(self) -> "PipedSubprocess":
        self.subp = subprocess.Popen(self.binary, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=sys.stderr, text=True, bufsize=0)
        self.tempdir = self.tempdir_ctx.__enter__()
        self.file_counter = 0
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        self.tempdir_ctx.__exit__(exc_type, exc_val, exc_tb)

    def temp_filename(self, suffix: str) -> str:
        self.file_counter += 1
        return os.path.join(self.tempdir, f"{self.file_counter}{suffix}")

    def write(self, *args) -> None:
        for a in args:
            self.subp.stdin.write(str(a) + " ")

    def writeTensor(self, tensor: torch.Tensor, name: str, stride_names: List[str]) -> None:
        print(f"Py ->C++: {TORCH_DTYPE_NAME[tensor.dtype]}:{name}")
        tensor_u8 = _tensor_from_storage(tensor, torch.uint8)
        self.write("tensor_begin", f"{TORCH_DTYPE_NAME[tensor.dtype]}:{name}", tensor_u8.shape[0])
        filename = self.temp_filename(f"{name}.tensor")
        assert tensor.storage_offset() == 0
        with open(filename, "wb+") as fd:
            fd.write(bytes(tensor_u8.numpy()))
        self.write("file", filename)
        self.write("tensor_end")

        for stride_name, stride_value in zip(stride_names, tensor.stride()):
            self.write(stride_name, stride_value)

    def readTensor(self, name, stride_name, shape) -> torch.Tensor:
        tmpfile = self.temp_filename(f"{name}.tensor")
        self.write("tmpfile", tmpfile)

        self.readExpect("tensor_begin")
        dtype_str, name = self.read().split(":")
        print(f"C++->Py : {dtype_str}:{name}")
        u8len = int(self.read())
        dtype = NAME_TORCH_DTYPE[dtype_str]

        self.readExpect("file")
        self.readExpect(tmpfile)

        with open(tmpfile, "rb") as fd:
            data = fd.read(u8len)
        # `np.array` is not strictly needed, but avoids a torch warning
        tensor_u8 = torch.frombuffer(np.array(data), dtype=torch.uint8, count=u8len)
        self.readExpect("tensor_end")

        tensor = _tensor_from_storage(tensor_u8, dtype)
        strides = []
        for sn in stride_name:
            self.readExpect(sn)
            strides.append(int(self.read()))
        if len(strides) != len(shape):
            strides.append(1)
        assert len(strides) == len(shape), name
        return torch.as_strided(tensor, shape, strides)

    def readNamed(self, name: str):
        self.readExpect(name)
        return self.read()

    def readExpect(self, what: str) -> None:
        r = self.read()
        if r != what:
            raise ValueError(f"Read {r} but expected {what}")

    def read(self):
        read_all = []
        # Skip initial whitespace
        while True:
            r = self.subp.stdout.read(1)
            if r not in [' ', "\n"]:
                read_all.append(r)
                break
        # Read data
        while True:
            r = self.subp.stdout.read(1)
            if r in [' ', "\n"]:
                break
            read_all.append(r)
        return ''.join(read_all)
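
`readTensor` reassembles an N-d tensor from the flat byte buffer using the strides sent by the C++ side (B/M/H only, with the innermost stride of 1 appended above). A standalone illustration of that reconstruction, not part of the module:

```python
import torch

x = torch.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)  # a contiguous BMHK-like tensor
flat = x.reshape(-1)                                  # what arrives from the temp file
strides = list(x.stride()[:3]) + [1]                  # (M*H*K, H*K, K) sent, plus trailing 1
rebuilt = torch.as_strided(flat, x.shape, strides)
assert torch.equal(rebuilt, x)
```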

@@ -30,6 +30,11 @@ include(GNUInstallDirs)

find_package(Python3 3.5 COMPONENTS Interpreter REQUIRED)

# Set Python3_EXECUTABLE to be visible from global scope.
# In CMake 3.24, this could be supported by adding the GLOBAL field
# to find_package above (https://cmake.org/cmake/help/latest/command/find_package.html#id7)
set(Python3_EXECUTABLE ${Python3_EXECUTABLE} CACHE INTERNAL "Path to python3 executable")

add_library(cutlass_library_includes INTERFACE)
add_library(nvidia::cutlass::library::includes ALIAS cutlass_library_includes)
set_target_properties(cutlass_library_includes PROPERTIES EXPORT_NAME library::includes)