fMHA: Add backward pass (#844)

* fMHA: Add backward pass

* Better checks for strides/alignments

* Remove fb-internal URL

* torch.Tensor.untyped_storage requires pytorch 2.0+

* minor changes

* make test

---------

Co-authored-by: danthe3rd <danthe3rd>
Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
dan_the_3rd authored on 2023-04-07 02:44:58 +02:00, committed by GitHub
parent e2d439ee7e
commit 9b8166e3f0
9 changed files with 2931 additions and 17 deletions


@@ -37,8 +37,20 @@ cutlass_example_add_executable(
fused_multihead_attention_variable_seqlen.cu
)
cutlass_example_add_executable(
41_fused_multi_head_attention_backward
fused_multi_head_attention_backward.cu
DISABLE_TESTS ON
)
add_custom_target(41_fused_multi_head_attention
DEPENDS 41_fused_multi_head_attention_fixed_seqlen
41_fused_multi_head_attention_variable_seqlen
41_fused_multi_head_attention_backward
)
add_test(
NAME ctest_examples_41_fmha_backward_python
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/fmha_backward_test.py $<TARGET_FILE:41_fused_multi_head_attention_backward>
)


@@ -0,0 +1,199 @@
import argparse
import torch
import sys
import os
from piped_subprocess import PipedSubprocess, TORCH_DTYPE_NAME
import math
parser = argparse.ArgumentParser()
parser.add_argument("example_exe", type=str, help="Path to the 41_fused_multi_head_attention_backward executable")
args = parser.parse_args()
torch.manual_seed(0)
dtype = torch.float16
B, Mq, Mkv, H, K, Kv = 2, 1024, 1024, 5, 128, 128
causal = True
repeat_count = 100
ATOL = {
torch.float: 5e-4,
torch.half: 9.5e-2,
torch.bfloat16: 7e-1,
}[dtype]
RTOL = {
torch.float: 1e-4,
torch.half: 2e-2,
torch.bfloat16: 1e-1,
}[dtype]
assert not (causal and Mq < Mkv), "causal only supports seqlenK <= seqlenQ"
fmha_bw_binary = args.example_exe
if not os.path.isfile(fmha_bw_binary):
print(f"""No such file: `{fmha_bw_binary}`\nDid you forget to run "make 41_fused_multi_head_attention"?""")
sys.exit(1)
def create_lower_triangular_mask():
return torch.triu(torch.full( # type: ignore
[1, Mq, Mkv],
dtype=dtype,
fill_value=float("-inf"),
), diagonal=1)
def ref_mha_bmk(q, k, v, mask):
# Multi-head attention with inputs/outputs in BMK format
q = q.float()
k = k.float()
v = v.float()
q = q * (1 / q.shape[-1] ** 0.5)
attn = q @ k.transpose(-2, -1)
if mask is not None:
attn += mask
attn_max = attn.max(-1, True).values
attn_norm = (attn - attn_max).exp().sum(-1, True)
attn = attn.softmax(-1)
lse = attn_max + attn_norm.log()
lse = lse.squeeze(2)
return attn @ v, lse
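# Minimal sanity sketch (not part of the test protocol; `_s` and `_lse` are just local
# illustration names): the logsumexp returned above lets the backward pass rebuild the
# softmax probabilities without storing them, since softmax(s)_j == exp(s_j - lse) with
# lse = max(s) + log(sum(exp(s - max(s)))).
_s = torch.randn(8)
_lse = _s.max() + (_s - _s.max()).exp().sum().log()
assert torch.allclose(_s.softmax(-1), (_s - _lse).exp())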
def bmhk2bmk(t):
return t.permute((0, 2, 1, 3)).reshape(
[t.shape[0] * t.shape[2], t.shape[1], t.shape[3]]
)
def ref_mha_bmhk(q, k, v, mask):
# Multi-head attention with inputs/outputs in BMHK format
assert q.ndim == 4
out, lse = ref_mha_bmk(bmhk2bmk(q), bmhk2bmk(k), bmhk2bmk(v), mask=mask)
out = out.reshape([q.shape[0], q.shape[2], q.shape[1], v.shape[3]])
return out.permute((0, 2, 1, 3)), lse.reshape([q.shape[0], q.shape[2], q.shape[1]])
def ref_mha_bw_bmhk(q, k, v, mask, lse, out, grad_out, delta):
    lse = lse[:, :, :q.shape[1]]  # [B, H, Mq]: drop the padding added along the query dimension
delta = delta.reshape([-1, delta.shape[-1], 1])
# bmhk -> bmk
q, k, v, out, grad_out = [bmhk2bmk(x).float() for x in (q, k, v, out, grad_out)]
attn_T = k @ q.transpose(-2, -1)
if mask is not None:
attn_T += mask.transpose(-2, -1)
attn_T = attn_T * (1 / q.shape[-1] ** 0.5)
attn_T = attn_T - lse.reshape([-1, 1, lse.shape[-1]])
attn_T = attn_T.exp()
grad_v = attn_T @ grad_out
dov = grad_out @ v.transpose(-2, -1)
tmp = (dov - delta) * attn_T.transpose(-2, -1)
tmp = tmp / (q.shape[-1] ** 0.5)
grad_q = tmp @ k
grad_k = tmp.transpose(-2, -1) @ q
return [x.reshape([B, H, x.shape[1], x.shape[-1]]).permute([0, 2, 1, 3]) for x in [grad_q, grad_k, grad_v]]
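# For reference, the backward above implements the standard attention gradients, with the
# probabilities P = softmax(Q K^T * scale + mask) recomputed from the stored logsumexp
# (P = exp(Q K^T * scale + mask - lse)) instead of being saved in the forward pass:
#   dV = P^T @ dO
#   dP = dO @ V^T
#   dS = (dP - delta) * P,  where delta[i] = sum_j dO[i, j] * O[i, j]
#   dQ = dS @ K * scale
#   dK = dS^T @ Q * scale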
print("initializing tensors...")
query = torch.randn([B, Mq, H, K], dtype=dtype)
key = 3 * torch.randn([B, Mkv, H, K], dtype=dtype)
value = 3 * torch.randn([B, Mkv, H, Kv], dtype=dtype)
mask = create_lower_triangular_mask() if causal else None
# let PyTorch compute gradients
query.requires_grad_(True)
key.requires_grad_(True)
value.requires_grad_(True)
print("computing fw...")
out, lse = ref_mha_bmhk(query, key, value, mask=mask)
out = out.to(dtype).contiguous()
grad_out = 3 * torch.randn([B, Mq, H, Kv], dtype=dtype)
print("computing bw with autograd...")
out.backward(grad_out)
scale = (1 / query.shape[-1] ** 0.5)
# Additional data needed by the kernel
delta = (grad_out.float() * out.float()).sum(-1).transpose(-2, -1).contiguous()
pad_amount = (32 - (lse.shape[2] % 32)) % 32
lse = torch.nn.functional.pad(lse, [0, pad_amount], value=math.inf)
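# `delta` has shape [B, H, Mq]: delta[b, h, m] = sum_k grad_out[b, m, h, k] * out[b, m, h, k].
# The logsumexp is padded along the query dimension to a multiple of 32 with +inf
# (so exp(s - lse) = 0 for the padded rows); this appears to match the padded LSE layout
# the backward kernel expects - see the lse_strideB / lse_strideH values sent below.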
print("computing bw with reference implem...")
gQr, gKr, gVr = ref_mha_bw_bmhk(query, key, value, mask, lse, out, grad_out, delta)
with PipedSubprocess(fmha_bw_binary) as bw_kernel:
# Send kernel arguments
bw_kernel.write(
TORCH_DTYPE_NAME[query.dtype],
"scale", scale,
"head_dim", K,
"head_dim_value", Kv,
"num_queries", Mq,
"num_keys", Mkv,
"num_heads", H,
"custom_mask_type", (1 if causal else 0),
"num_batches", B,
"repeat_count", repeat_count,
)
bw_kernel.writeTensor(query, "query", ["q_strideB", "q_strideM", "q_strideH"])
bw_kernel.writeTensor(key, "key", ["k_strideB", "k_strideM", "k_strideH"])
bw_kernel.writeTensor(value, "value", ["v_strideB", "v_strideM", "v_strideH"])
bw_kernel.writeTensor(lse, "logsumexp", ["lse_strideB", "lse_strideH"])
bw_kernel.writeTensor(out, "output", ["o_strideB", "o_strideM", "o_strideH"])
bw_kernel.writeTensor(grad_out, "grad_output", ["gO_strideB", "gO_strideM", "gO_strideH"])
bw_kernel.writeTensor(delta, "delta", ["delta_strideB", "delta_strideH"])
if bw_kernel.read() != "OK":
print("Got unexpected output")
print(bw_kernel.subp.communicate()[0])
sys.exit(0)
# Read kernel output
gQ = bw_kernel.readTensor("grad_query", ["gQ_strideB", "gQ_strideM", "gQ_strideH"], query.shape).float()
gK = bw_kernel.readTensor("grad_key", ["gK_strideB", "gK_strideM", "gK_strideH"], key.shape).float()
gV = bw_kernel.readTensor("grad_value", ["gV_strideB", "gV_strideM", "gV_strideH"], value.shape).float()
runtime_ms = float(bw_kernel.readNamed("runtime_ms"))
float_ops = B * H * sum([
    # recompute attention scores: S = Q @ K^T
    Mq * Mkv * K * 2,
    # dV = P^T @ dO
    Mkv * Mq * Kv * 2,
    # dP = dO @ V^T
    Mq * Kv * Mkv * 2,
    # dQ = dS @ K
    Mq * K * Mkv * 2,
    # dK = dS^T @ Q
    Mq * K * Mkv * 2,
])
if causal:
    # With a causal mask, roughly half of the attention matrix is masked out
    float_ops //= 2
print(f"""
Fused multi-head attention - backward
batch_size={B}
num_queries={Mq}
num_keys={Mkv}
num_heads={H}
head_dim={K}
head_dim_value={Kv}
Correctness:
grad_query: {"PASS" if torch.allclose(gQ, gQr, rtol=RTOL, atol=ATOL) else "FAIL"} (delta: {(gQ - gQr).abs().max()})
grad_key: {"PASS" if torch.allclose(gK, gKr, rtol=RTOL, atol=ATOL) else "FAIL"} (delta: {(gK - gKr).abs().max()})
grad_value: {"PASS" if torch.allclose(gV, gVr, rtol=RTOL, atol=ATOL) else "FAIL"} (delta: {(gV - gVr).abs().max()})
(atol={ATOL} / rtol={RTOL})
Runtime: {runtime_ms}ms ({(float_ops / (1024 ** 4)) / (runtime_ms / 1000):.4f} TFlops)
""")
assert torch.allclose(query.grad.float(), gQr, rtol=RTOL, atol=ATOL), "Reference implementation does not match PyTorch autograd!"
assert torch.allclose(key.grad.float(), gKr, rtol=RTOL, atol=ATOL), "Reference implementation does not match PyTorch autograd!"
assert torch.allclose(value.grad.float(), gVr, rtol=RTOL, atol=ATOL), "Reference implementation does not match PyTorch autograd!"


@@ -0,0 +1,295 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/////////////////////////////////////////////////////////////////////////////////////////////////
#include <vector>
#include <iostream>
#include <fstream>
#include "kernel_backward.h"
#include "cutlass/util/device_memory.h"
#include "cutlass/util/host_tensor.h"
using Arch = cutlass::arch::Sm80;
static constexpr int kMaxK = 128;
template <typename ArchTag, typename Element, int kMaxK>
struct DefaultKernel {
// Some heuristics to select the best kernel (tested on Sm60, Sm70, Sm80)
// NOTE: Requires quite a lot of shmem for Sm80+,
// so might require tweaking those manually for Sm86/Sm89
static constexpr bool kSupports64x128 =
ArchTag::kMinComputeCapability >= 80 ||
(ArchTag::kMinComputeCapability >= 70 &&
cutlass::sizeof_bits<Element>::value <= 16);
static constexpr int kBlockSizeI = kSupports64x128 && kMaxK > 64 ? 128 : 64;
static constexpr bool kIsHalf = cutlass::sizeof_bits<Element>::value <= 16;
static constexpr bool kOutputInRF = kIsHalf && kMaxK <= kBlockSizeI;
static constexpr bool kPreload = kIsHalf && ArchTag::kMinComputeCapability >= 80 && kOutputInRF;
static constexpr int kBlockSizeJ = kPreload && kMaxK > 64 ? 128 : 64;
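  // Worked example: with the configuration used in this file (Arch = Sm80, kMaxK = 128),
  // these heuristics resolve, for half/bfloat16 inputs, to kSupports64x128 = true,
  // kBlockSizeI = 128, kOutputInRF = true, kPreload = true, kBlockSizeJ = 128.
  // For float inputs, kIsHalf/kOutputInRF/kPreload are false and kBlockSizeJ stays at 64.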
using Kernel = AttentionBackwardKernel<
Arch,
Element,
true, // kIsAligned_
false, // kApplyDropout_
kPreload,// kPreload_
kBlockSizeI, // kBlockSizeI_,
kBlockSizeJ, // kBlockSizeJ_,
kMaxK // kMaxK
>;
};
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace {
template <typename T> struct TypeName;
template <> struct TypeName<float> { static constexpr const char* Name = "f32"; };
template <> struct TypeName<cutlass::half_t> { static constexpr const char* Name = "f16"; };
template <> struct TypeName<cutlass::bfloat16_t> { static constexpr const char* Name = "b16"; };
void readExpect(std::string const& expected) {
std::string read;
std::cin >> read;
if (read != expected) {
std::cerr << "FATAL: Read '" << read << "' but expected '" << expected << "'" << std::endl;
std::exit(1);
}
}
/// Helpers to read from stdin
template <typename Element>
cutlass::HostTensor<Element, cutlass::layout::RowMajor> readTensorOnDevice(std::string const& expectedName) {
readExpect("tensor_begin");
readExpect(std::string(TypeName<Element>::Name) + ":" + expectedName);
uint64_t len = 0;
std::cin >> len;
readExpect("file");
std::string filename;
std::cin >> filename;
cutlass::HostTensor<Element, cutlass::layout::RowMajor> tensor({int64_t(1), int64_t(len / sizeof(Element))});
uint8_t* data = (uint8_t*)tensor.host_data();
std::fstream myFile(filename, std::ios::in | std::ios::binary );
myFile.read((char*)data, len);
readExpect("tensor_end");
tensor.sync_device();
return tensor;
}
int64_t readInt64(std::string const& expectedName) {
readExpect(expectedName);
int64_t s = 0;
std::cin >> s;
return s;
}
float readFloat(std::string const& expectedName) {
readExpect(expectedName);
float s = 0;
std::cin >> s;
return s;
}
// Writing
template <typename Element>
void writeTensor(std::string const& name, cutlass::HostTensor<Element, cutlass::layout::RowMajor>& tensor) {
tensor.sync_host(); // device->host
size_t u8len = tensor.size() * sizeof(Element);
// Python is expected to provide a file name to write to
readExpect("tmpfile");
std::string tmpfile;
std::cin >> tmpfile;
uint8_t* data = (uint8_t*)tensor.host_data();
std::fstream myFile(tmpfile, std::ios::out | std::ios::binary );
myFile.write((char*)data, u8len);
myFile.close();
std::cout << "tensor_begin " << TypeName<Element>::Name << ":" << name << " ";
std::cout << u8len << " file " << tmpfile << " tensor_end" << std::endl;
}
void writeInt64(std::string const& name, int64_t value) {
std::cout << name << " " << value << std::endl;
}
}
/////////////////////////////////////////////////////////////////////////////////////////////////
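// Wire protocol with the Python driver (fmha_backward_test.py / piped_subprocess.py):
// everything is exchanged as whitespace-separated tokens on stdin/stdout, except tensor
// payloads, which go through temporary files provided by the Python side:
//   tensor_begin <dtype>:<name> <num_bytes> file <path> tensor_end
// followed by the named strides. Python first sends the dtype token ("f16", "b16" or
// "f32"), then the scalar parameters, then the input tensors; after printing "OK" this
// program replies with the gradient tensors (Python supplies a temp file path for each)
// and finally "runtime_ms".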
template <typename Element>
int runKernel() {
using Kernel = typename DefaultKernel<Arch, Element, kMaxK>::Kernel;
#define READ_I64(NAME) p.NAME = (decltype(p.NAME))readInt64(#NAME)
#define READ_TENSOR_AND_STRIDES_BMH(DT, NAME, NAME_XS) \
auto storage##NAME = readTensorOnDevice<DT>(#NAME); \
p.NAME##_ptr = storage##NAME.device_data(); \
READ_I64(NAME_XS##_strideB); \
READ_I64(NAME_XS##_strideM); \
READ_I64(NAME_XS##_strideH);
#define CUDA_CHECK(FN) { \
auto cudaError = FN; \
if (cudaError != cudaSuccess) { \
std::cerr << "FATAL: " #FN " failed: " << cudaGetErrorString(cudaError) << std::endl; \
return -1; \
} \
}
typename Kernel::Params p;
p.scale = readFloat("scale");
READ_I64(head_dim);
READ_I64(head_dim_value);
READ_I64(num_queries);
READ_I64(num_keys);
READ_I64(num_heads);
READ_I64(custom_mask_type);
READ_I64(num_batches);
int64_t repeat_count = readInt64("repeat_count");
READ_TENSOR_AND_STRIDES_BMH(Element, query, q);
READ_TENSOR_AND_STRIDES_BMH(Element, key, k);
READ_TENSOR_AND_STRIDES_BMH(Element, value, v);
auto lse = readTensorOnDevice<typename Kernel::lse_scalar_t>("logsumexp");
p.logsumexp_ptr = lse.device_data();
p.lse_strideB = readInt64("lse_strideB");
p.lse_strideH = readInt64("lse_strideH");
// output
auto stOutput = readTensorOnDevice<Element>("output");
p.output_ptr = stOutput.device_data();
READ_I64(o_strideB);
auto o_strideM = readInt64("o_strideM");
if (o_strideM != p.o_strideM()) {
std::cerr << "Invalid `o_strideM`: " << o_strideM << " - expected " << p.o_strideM();
return 2;
}
READ_I64(o_strideH);
READ_TENSOR_AND_STRIDES_BMH(Element, grad_output, gO);
auto stDelta = readTensorOnDevice<typename Kernel::accum_t>("delta");
p.delta_ptr = stDelta.device_data();
READ_I64(delta_strideB);
READ_I64(delta_strideH);
// Allocate workspace
if (p.workspace_size()) {
cudaMalloc(&p.workspace, p.workspace_size());
}
// Allocate outputs in BMHK format
p.gQKV_strideM_multiplier = 1;
p.gQ_strideH = p.head_dim;
p.gQ_strideB = p.gQ_strideM() * p.num_queries;
p.gK_strideH = p.head_dim;
p.gK_strideB = p.gK_strideM() * p.num_keys;
p.gV_strideH = p.head_dim_value;
p.gV_strideB = p.gV_strideM() * p.num_keys;
cutlass::HostTensor<Element, cutlass::layout::RowMajor> gQ({int64_t(1), p.gQ_strideB * p.num_batches});
cutlass::HostTensor<Element, cutlass::layout::RowMajor> gK({int64_t(1), p.gK_strideB * p.num_batches});
cutlass::HostTensor<Element, cutlass::layout::RowMajor> gV({int64_t(1), p.gV_strideB * p.num_batches});
p.grad_query_ptr = gQ.device_data();
p.grad_key_ptr = gK.device_data();
p.grad_value_ptr = gV.device_data();
if (!Kernel::check_supported(p)) {
std::cerr << "FATAL: Kernel does not support these inputs" << std::endl;
return 2;
}
// Run kernel
cudaDeviceSynchronize();
auto kernel_fn = attention_kernel_backward_batched_impl<Kernel>;
size_t smem_bytes = sizeof(typename Kernel::SharedStorage);
CUDA_CHECK(cudaFuncSetAttribute(kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, int(smem_bytes)));
kernel_fn<<<p.getBlocksGrid(), p.getThreadsGrid(), smem_bytes>>>(p);
// Write outputs
std::cout << "OK ";
writeTensor("grad_query", gQ);
writeInt64("gQ_strideB", p.gQ_strideB);
writeInt64("gQ_strideM", p.gQ_strideM());
writeInt64("gQ_strideH", p.gQ_strideH);
writeTensor("grad_key", gK);
writeInt64("gK_strideB", p.gK_strideB);
writeInt64("gK_strideM", p.gK_strideM());
writeInt64("gK_strideH", p.gK_strideH);
writeTensor("grad_value", gV);
writeInt64("gV_strideB", p.gV_strideB);
writeInt64("gV_strideM", p.gV_strideM());
writeInt64("gV_strideH", p.gV_strideH);
// Timing
cudaEvent_t events[2];
for (auto & event : events) {
CUDA_CHECK(cudaEventCreate(&event));
}
CUDA_CHECK(cudaEventRecord(events[0]));
for (int i = 0; i < repeat_count; ++i) {
kernel_fn<<<p.getBlocksGrid(), p.getThreadsGrid(), smem_bytes>>>(p);
}
CUDA_CHECK(cudaEventRecord(events[1]));
CUDA_CHECK(cudaEventSynchronize(events[1]));
// Measure elapsed runtime
float runtime_ms = 0;
CUDA_CHECK(cudaEventElapsedTime(&runtime_ms, events[0], events[1]));
std::cout << "runtime_ms " << runtime_ms / float(repeat_count) << std::endl;
return 0;
}
int main() {
std::ios_base::sync_with_stdio(false);
std::string dtype;
std::cin >> dtype;
std::cerr << "Running kernel with dtype: " << dtype << std::endl;
if (dtype == "f16") {
return runKernel<cutlass::half_t>();
} else if (dtype == "b16") {
return runKernel<cutlass::bfloat16_t>();
} else if (dtype == "f32") {
return runKernel<float>();
} else {
std::cerr << "FATAL: Unknown dtype: " << dtype << std::endl;
return 3;
}
}
/////////////////////////////////////////////////////////////////////////////////////////////////


@@ -264,9 +264,8 @@ class NoOpWarpIteratorScale {
// in pipelined+multistage MMA implementations we keep an array of fragments.
// if we aren't using scaling we don't want to waste registers on fragments
// of scale elements, so ideally this would be sized 0.
// using size 1 is kind of a hack to get around arrays of zero-sized objects
// not being allowed. the compiler is probably smart enough to wipe it out
// anyways.
Since arrays of zero-sized objects are not allowed, we use size 1 instead.
The compiler will most likely optimize it away anyway.
using Fragment = cutlass::Array<char, 1>;
CUTLASS_HOST_DEVICE


@@ -115,10 +115,10 @@
std::cerr << #PTR " is not correctly aligned\n"; \
return false; \
}
#define XFORMERS_CHECK(COND, ERR) \
if (!(COND)) { \
std::cerr << #COND " failed\n"; \
return false; \
#define XFORMERS_CHECK(COND, ERR) \
if (!(COND)) { \
std::cerr << "'" #COND "' failed: " << ERR << "\n"; \
return false; \
}
#endif

File diff suppressed because it is too large


@@ -573,27 +573,33 @@ struct AttentionKernel {
if (kSupportsBias) {
CHECK_ALIGNED_PTR(p.attn_bias_ptr, kAlignmentQ);
XFORMERS_CHECK(
p.bias_strideB % kAlignmentQ == 0,
"attn_bias is not correctly aligned");
p.num_batches <= 1 || p.bias_strideB % kAlignmentQ == 0,
"attn_bias is not correctly aligned (strideB)");
XFORMERS_CHECK(
p.bias_strideH % kAlignmentQ == 0,
"attn_bias is not correctly aligned");
p.num_heads <= 1 || p.bias_strideH % kAlignmentQ == 0,
"attn_bias is not correctly aligned (strideH)");
XFORMERS_CHECK(
p.bias_strideM % kAlignmentQ == 0,
"attn_bias is not correctly aligned");
}
XFORMERS_CHECK(
p.q_strideM % kAlignmentQ == 0, "query is not correctly aligned");
p.q_strideM % kAlignmentQ == 0,
"query is not correctly aligned (strideM)");
XFORMERS_CHECK(
p.k_strideM % kAlignmentK == 0, "key is not correctly aligned");
p.k_strideM % kAlignmentK == 0,
"key is not correctly aligned (strideM)");
XFORMERS_CHECK(
p.v_strideM % kAlignmentV == 0, "value is not correctly aligned");
p.v_strideM % kAlignmentV == 0,
"value is not correctly aligned (strideM)");
XFORMERS_CHECK(
p.q_strideH % kAlignmentQ == 0, "query is not correctly aligned");
p.num_heads <= 1 || p.q_strideH % kAlignmentQ == 0,
"query is not correctly aligned (strideH)");
XFORMERS_CHECK(
p.k_strideH % kAlignmentK == 0, "key is not correctly aligned");
p.num_heads <= 1 || p.k_strideH % kAlignmentK == 0,
"key is not correctly aligned (strideH)");
XFORMERS_CHECK(
p.v_strideH % kAlignmentV == 0, "value is not correctly aligned");
p.num_heads <= 1 || p.v_strideH % kAlignmentV == 0,
"value is not correctly aligned (strideH)");
XFORMERS_CHECK(
p.causal_diagonal_ptr == nullptr || p.custom_mask_type != NoCustomMask,
"`causal_diagonal_ptr` is only useful when `custom_mask_type` is causal");


@@ -0,0 +1,112 @@
from typing import List
import torch
import subprocess
import sys
import tempfile
import os
import numpy as np
TORCH_DTYPE_NAME = {
torch.float32: "f32",
torch.float16: "f16",
torch.bfloat16: "b16"
}
NAME_TORCH_DTYPE = {v: k for k, v in TORCH_DTYPE_NAME.items()}
def _tensor_from_storage(tensor: torch.Tensor, dtype) -> torch.Tensor:
# PyTorch >= 2.0
if hasattr(tensor, 'untyped_storage'):
return torch.tensor([], dtype=dtype).set_(tensor.untyped_storage())
return torch.tensor([], dtype=dtype).set_(tensor.storage().untyped())
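# Either branch reinterprets the raw bytes backing `tensor` as `dtype`:
# `Tensor.untyped_storage()` only exists on PyTorch >= 2.0, while older releases expose
# the same bytes through `tensor.storage().untyped()`, which the fallback above relies on.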
class PipedSubprocess:
def __init__(self, binary: str) -> None:
self.binary = binary
self.tempdir_ctx = tempfile.TemporaryDirectory()
def __enter__(self) -> "PipedSubprocess":
self.subp = subprocess.Popen(self.binary, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=sys.stderr, text=True, bufsize=0)
self.tempdir = self.tempdir_ctx.__enter__()
self.file_counter = 0
return self
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
self.tempdir_ctx.__exit__(exc_type, exc_val, exc_tb)
def temp_filename(self, suffix: str) -> str:
self.file_counter += 1
return os.path.join(self.tempdir, f"{self.file_counter}{suffix}")
def write(self, *args) -> None:
for a in args:
self.subp.stdin.write(str(a) + " ")
def writeTensor(self, tensor: torch.Tensor, name: str, stride_names: List[str]) -> None:
print(f"Py ->C++: {TORCH_DTYPE_NAME[tensor.dtype]}:{name}")
tensor_u8 = _tensor_from_storage(tensor, torch.uint8)
self.write("tensor_begin", f"{TORCH_DTYPE_NAME[tensor.dtype]}:{name}", tensor_u8.shape[0])
filename = self.temp_filename(f"{name}.tensor")
assert tensor.storage_offset() == 0
with open(filename, "wb+") as fd:
fd.write(bytes(tensor_u8.numpy()))
self.write("file", filename)
self.write("tensor_end")
for stride_name, stride_value in zip(stride_names, tensor.stride()):
self.write(stride_name, stride_value)
def readTensor(self, name, stride_name, shape) -> torch.Tensor:
tmpfile = self.temp_filename(f"{name}.tensor")
self.write("tmpfile", tmpfile)
self.readExpect("tensor_begin")
dtype_str, name = self.read().split(":")
print(f"C++->Py : {dtype_str}:{name}")
u8len = int(self.read())
dtype = NAME_TORCH_DTYPE[dtype_str]
self.readExpect("file")
self.readExpect(tmpfile)
with open(tmpfile, "rb") as fd:
data = fd.read(u8len)
# `np.array` is not strictly needed, but avoids a torch warning
tensor_u8 = torch.frombuffer(np.array(data), dtype=torch.uint8, count=u8len)
self.readExpect("tensor_end")
tensor = _tensor_from_storage(tensor_u8, dtype)
strides = []
for sn in stride_name:
self.readExpect(sn)
strides.append(int(self.read()))
        # The C++ side may report one fewer stride than there are dimensions;
        # in that case the innermost (contiguous) dimension has stride 1.
        if len(strides) != len(shape):
            strides.append(1)
        assert len(strides) == len(shape), name
return torch.as_strided(tensor, shape, strides)
def readNamed(self, name: str):
self.readExpect(name)
return self.read()
def readExpect(self, what: str) -> None:
r = self.read()
if r != what:
raise ValueError(f"Read {r} but expected {what}")
def read(self):
read_all = []
# Skip initial whitespace
while True:
r = self.subp.stdout.read(1)
if r not in [' ', "\n"]:
read_all.append(r)
break
# Read data
while True:
r = self.subp.stdout.read(1)
if r in [' ', "\n"]:
break
read_all.append(r)
return ''.join(read_all)


@@ -30,6 +30,11 @@ include(GNUInstallDirs)
find_package(Python3 3.5 COMPONENTS Interpreter REQUIRED)
# Set Python3_EXECUTABLE to be visible from global scope.
# In CMake 3.24, this could be supported by adding the GLOBAL field
# to find_package above (https://cmake.org/cmake/help/latest/command/find_package.html#id7)
set(Python3_EXECUTABLE ${Python3_EXECUTABLE} CACHE INTERNAL "Path to python3 executable")
add_library(cutlass_library_includes INTERFACE)
add_library(nvidia::cutlass::library::includes ALIAS cutlass_library_includes)
set_target_properties(cutlass_library_includes PROPERTIES EXPORT_NAME library::includes)