* Update splitk_gemm.cu

* Update gemm_bias_relu.cu

* Update mma_sm75.h
This commit is contained in:
hwu36 2020-07-13 17:25:52 -04:00 committed by GitHub
parent fd7e058d0c
commit 4dac7490e6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 3 additions and 4 deletions

View File

@@ -205,7 +205,7 @@ int run() {
   cutlass::HostTensor<ElementInputA, LayoutInputA> tensor_a(
       problem_size.mk());  // <- Create matrix A with dimensions M x K
   cutlass::HostTensor<ElementInputB, LayoutInputB> tensor_b(
-      problem_size.nk());  // <- Create matrix B with dimensions N x K
+      problem_size.kn());  // <- Create matrix B with dimensions K x N
   cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_c(
       problem_size.mn());  // <- Create matrix C with dimensions M x N
   cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_d(

View File

@@ -132,7 +132,7 @@ int run() {
   cutlass::HostTensor<ElementInputA, LayoutInputA> tensor_a(
       problem_size.mk());  // <- Create matrix A with dimensions M x K
   cutlass::HostTensor<ElementInputB, LayoutInputB> tensor_b(
-      problem_size.nk());  // <- Create matrix B with dimensions N x K
+      problem_size.kn());  // <- Create matrix B with dimensions K x N
   cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_c_bias(
       {problem_size.m(), 1});  // <- Create matrix C with dimensions M x 1
@@ -234,7 +234,6 @@ int run() {
       tensor_a.device_ref(),
       tensor_b.device_ref(),
       0,
-      tensor_c_bias.device_ref(),
       tensor_ref_d.device_ref());
   // Wait for kernels to finish

View File

@@ -823,7 +823,7 @@ struct Mma<
   int const *C = reinterpret_cast<int const *>(&c);
   int *D = reinterpret_cast<int *>(&d);
-  asm volatile("_mma.m8n8k32.row.col.s32.s4.u4.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n"
+  asm volatile("mma.sync.aligned.m8n8k32.row.col.s32.s4.u4.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n"
       : "=r"(D[0]), "=r"(D[1])
       : "r"(A), "r"(B), "r"(C[0]), "r"(C[1]));