add two missing files (#636)
Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
parent 9f2e3faa69
commit 97bff52e8c
82  CITATION.cff  Normal file
@@ -0,0 +1,82 @@
cff-version: 1.2.0
title: CUTLASS
message: >-
  If you use this software, please cite using the
  following metadata.
type: software
authors:
  - given-names: Andrew
    email: akerr@nvidia.com
    family-names: Kerr
    affiliation: NVIDIA
  - given-names: Haicheng
    family-names: Wu
    affiliation: NVIDIA
    email: haichengw@nvidia.com
  - given-names: Manish
    family-names: Gupta
    affiliation: Google
    email: manigupta@google.com
  - given-names: Dustyn
    family-names: Blasig
    email: dblasig@nvidia.com
    affiliation: NVIDIA
  - given-names: Pradeep
    family-names: Ramini
    email: prramani@nvidia.com
    affiliation: NVIDIA
  - given-names: Duane
    family-names: Merrill
    email: dumerrill@nvidia.com
    affiliation: NVIDIA
  - given-names: Aniket
    family-names: Shivam
    email: ashivam@nvidia.com
    affiliation: NVIDIA
  - given-names: Piotr
    family-names: Majcher
    email: pmajcher@nvidia.com
    affiliation: NVIDIA
  - given-names: Paul
    family-names: Springer
    email: pspringer@nvidia.com
    affiliation: NVIDIA
  - given-names: Markus
    family-names: Hohnerbach
    affiliation: NVIDIA
    email: mhohnerbach@nvidia.com
  - given-names: Jin
    family-names: Wang
    email: jinw@nvidia.com
    affiliation: NVIDIA
  - given-names: Matt
    family-names: Nicely
    email: mnicely@nvidia.com
    affiliation: NVIDIA
repository-code: 'https://github.com/NVIDIA/cutlass'
abstract: >-
  CUTLASS is a collection of CUDA C++ template
  abstractions for implementing high-performance
  matrix-multiplication (GEMM) and related
  computations at all levels and scales within CUDA.
  It incorporates strategies for hierarchical
  decomposition and data movement similar to those
  used to implement cuBLAS and cuDNN. CUTLASS
  decomposes these "moving parts" into reusable,
  modular software components abstracted by C++
  template classes. These thread-wide, warp-wide,
  block-wide, and device-wide primitives can be
  specialized and tuned via custom tiling sizes, data
  types, and other algorithmic policy. The resulting
  flexibility simplifies their use as building blocks
  within custom kernels and applications.
keywords:
  - 'cutlass, tensor cores, cuda'
license: BSD-3-Clause
license-url: https://github.com/NVIDIA/cutlass/blob/v2.9.0/LICENSE.txt
version: '2.9'
date-released: '2022-04-27'
identifiers:
  - type: url
    value: "https://github.com/NVIDIA/cutlass/tree/v2.9.0"
    description: The GitHub release URL of tag 2.9.0
402  tools/util/include/cutlass/util/device_groupnorm.h  Normal file
@@ -0,0 +1,402 @@
/******************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 *    list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 *    this list of conditions and the following disclaimer in the documentation
 *    and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 *    contributors may be used to endorse or promote products derived from
 *    this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 ******************************************************************************/

#pragma once

/**
 * \file
 * \brief CUDA kernels to do group norm on a device memory tensor with NHWC layout. The tensor is divided into [N, H, W, G, C'] and normalization is performed over [H, W, C'].
 */

#include "cutlass/cutlass.h"
#include "cutlass/layout/tensor.h"
#include "cutlass/numeric_types.h"
#include "cutlass/tensor_coord.h"
#include "cutlass/tensor_ref.h"
#include "device_utils.h"
#include <float.h>

namespace cutlass {

/** \brief interface to do group norm on a device memory tensor with NHWC layout.
 * \tparam T: data type
 */
template <typename T>
void groupnorm(cutlass::Tensor4DCoord input_size,
               const int num_groups,
               const float eps,
               TensorRef<T, layout::TensorNHWC> ref_output,
               TensorRef<T, layout::TensorNHWC> ref_input,
               TensorRef<T, layout::TensorNHWC> ref_gamma,
               TensorRef<T, layout::TensorNHWC> ref_beta,
               cudaStream_t stream);

extern __shared__ char groupnorm_shm[];

// For small prod_dim1_to_last_dim/num_groups, to avoid multiple loads from global memory,
// we store the input in the shared memory.
// grid(num_groups, dim0)
// block(BLOCKSIZE)
// BLOCKSIZE * TVecs_PER_THREAD <= prod_dim1_to_last_dim/num_group
template<typename TVec, typename T, int T_PER_TVec>
__global__ void groupnorm_twopass_store_locally(T* output,
                                                const T* input,
                                                const T* gamma,
                                                const T* beta,
                                                int num_groups,
                                                int prod_dim1_to_last_dim,
                                                int last_dim,
                                                const float eps,
                                                const int TVecs_PER_THREAD)
{
  const int bid = blockIdx.y;  // index of batch
  const int gid = blockIdx.x;  // index of group
  const int tid = threadIdx.x; // index of thread
  const int bdimx = blockDim.x;
  const int s_reduce_elements = prod_dim1_to_last_dim / num_groups;
  const int v_reduce_elements = s_reduce_elements / T_PER_TVec;
  const int s_group_stride = last_dim / num_groups;
  const int v_group_stride = s_group_stride / T_PER_TVec;
  const int offset_of_group = (bid * prod_dim1_to_last_dim + gid * s_group_stride) / T_PER_TVec;
  const TVec* input_TVec_ptr = (const TVec*)(input) + offset_of_group;
  TVec* output_TVec_ptr = (TVec*)(output) + offset_of_group;
  T* local_val = ((T*)groupnorm_shm) + TVecs_PER_THREAD * T_PER_TVec * tid;
  float local_sum[1] = {0.0f};

  // load from global memory into shared memory
#pragma unroll
  for (int i = 0; i < TVecs_PER_THREAD; i += 1) {
    const int current_load_start_idx = (i * bdimx + tid) * T_PER_TVec;
    const int offset_in_group =
      ((current_load_start_idx / s_group_stride) * last_dim + (current_load_start_idx % s_group_stride))
      / T_PER_TVec;
    if (current_load_start_idx < s_reduce_elements) {
      TVec tmp_vec = input_TVec_ptr[offset_in_group];
      T* tmp_vec_ptr = (T*)(&tmp_vec);
      const int local_val_offset = i * T_PER_TVec;
#pragma unroll
      for (int j = 0; j < T_PER_TVec; j++) {
        float tmp = static_cast<float>(tmp_vec_ptr[j]);
        local_sum[0] += tmp;
        local_val[local_val_offset + j] = tmp_vec_ptr[j];
      }
    }
  }
  __shared__ float s_mean, s_variance;

  // reduction for mean
  if (bdimx <= 32) {
    warpReduceSum<float, 1>(local_sum);
  }
  else {
    blockReduceSum<float, 1>(local_sum);
  }
  if (tid == 0) {
    s_mean = local_sum[0] / s_reduce_elements;
  }
  __syncthreads();

  // reduction for std
  local_sum[0] = 0.0f;
#pragma unroll
  for (int i = 0; i < TVecs_PER_THREAD; i += 1) {
    const int current_load_start_idx = (i * bdimx + tid) * T_PER_TVec;
    if (current_load_start_idx < s_reduce_elements) {
      const int local_val_offset = i * T_PER_TVec;
#pragma unroll
      for (int j = 0; j < T_PER_TVec; j++) {
        float tmp = static_cast<float>(local_val[local_val_offset + j]);
        tmp -= s_mean;
        local_sum[0] += tmp * tmp;
      }
    }
  }
  if (bdimx <= 32) {
    warpReduceSum<float, 1>(local_sum);
  }
  else {
    blockReduceSum<float, 1>(local_sum);
  }
  if (tid == 0) {
    s_variance = rsqrtf(local_sum[0] / s_reduce_elements + eps);
  }
  __syncthreads();

  // normalize
  const int gamma_offset_of_group = gid * v_group_stride;
  const TVec* gamma_TVec_ptr = (const TVec*)gamma + gamma_offset_of_group;
  const TVec* beta_TVec_ptr = (const TVec*)beta + gamma_offset_of_group;
#pragma unroll
  for (int i = 0; i < TVecs_PER_THREAD; i += 1) {
    const int current_load_start_idx = (i * bdimx + tid) * T_PER_TVec;
    const int offset_in_group =
      ((current_load_start_idx / s_group_stride) * last_dim + (current_load_start_idx % s_group_stride))
      / T_PER_TVec;
    const int gamma_offset_in_group = (current_load_start_idx % s_group_stride) / T_PER_TVec;
    const int local_val_offset = i * T_PER_TVec;
    if (current_load_start_idx < s_reduce_elements) {
      TVec gamma_val = gamma_TVec_ptr[gamma_offset_in_group];
      TVec beta_val = beta_TVec_ptr[gamma_offset_in_group];
      T* gamma_val_ptr = (T*)(&gamma_val);
      T* beta_val_ptr = (T*)(&beta_val);
      TVec tmp_vec;
      T* tmp_vec_ptr = (T*)(&tmp_vec);
#pragma unroll
      for (int j = 0; j < T_PER_TVec; j++) {
        float tmp = (static_cast<float>(local_val[local_val_offset + j]) - s_mean) * s_variance
                      * static_cast<float>(gamma_val_ptr[j])
                    + static_cast<float>(beta_val_ptr[j]);
        if (sizeof(T) == sizeof(half)) {
          tmp_vec_ptr[j] = T(__float2half_rn(tmp));
        }
        else {
          tmp_vec_ptr[j] = T(tmp);
        }
      }
      output_TVec_ptr[offset_in_group] = tmp_vec;
    }
  }
}

// For large prod_dim1_to_last_dim/num_groups,
// in which the data cannot be stored locally,
// we will load from global memory multiple times,
// grid(num_groups, dim0)
// block(BLOCKSIZE)
// BLOCKSIZE * TVecs_PER_THREAD <= prod_dim1_to_last_dim/num_group
template<typename TVec, typename T, int T_PER_TVec>
__global__ void groupnorm_twopass_multiple_load(T* output,
                                                const T* input,
                                                const T* gamma,
                                                const T* beta,
                                                int num_groups,
                                                int prod_dim1_to_last_dim,
                                                int last_dim,
                                                const float eps,
                                                const int TVecs_PER_THREAD)
{
  const int bid = blockIdx.y;  // index of batch
  const int gid = blockIdx.x;  // index of group
  const int tid = threadIdx.x; // index of thread
  const int bdimx = blockDim.x;
  const int s_reduce_elements = prod_dim1_to_last_dim / num_groups;
  const int v_reduce_elements = s_reduce_elements / T_PER_TVec;
  const int s_group_stride = last_dim / num_groups;
  const int v_group_stride = s_group_stride / T_PER_TVec;
  const int offset_of_group = (bid * prod_dim1_to_last_dim + gid * s_group_stride) / T_PER_TVec;
  const TVec* input_TVec_ptr = (const TVec*)(input) + offset_of_group;
  TVec* output_TVec_ptr = (TVec*)(output) + offset_of_group;
  float local_sum[1] = {0.0f};

#pragma unroll
  for (int i = 0; i < TVecs_PER_THREAD; i += 1) {
    const int current_load_start_idx = (i * bdimx + tid) * T_PER_TVec;
    if (current_load_start_idx < s_reduce_elements) {
      const int offset_in_group =
        ((current_load_start_idx / s_group_stride) * last_dim + (current_load_start_idx % s_group_stride))
        / T_PER_TVec;
      TVec tmp_vec = input_TVec_ptr[offset_in_group];
      T* tmp_vec_ptr = (T*)(&tmp_vec);
#pragma unroll
      for (int j = 0; j < T_PER_TVec; j++) {
        float tmp = static_cast<float>(tmp_vec_ptr[j]);
        local_sum[0] += tmp;
      }
    }
  }
  __shared__ float s_mean, s_variance;

  // reduction for mean
  if (bdimx <= 32) {
    warpReduceSum<float, 1>(local_sum);
  }
  else {
    blockReduceSum<float, 1>(local_sum);
  }
  if (tid == 0) {
    s_mean = local_sum[0] / s_reduce_elements;
  }
  __syncthreads();

  // reduction for std
  local_sum[0] = 0.0f;
#pragma unroll
  for (int i = 0; i < TVecs_PER_THREAD; i += 1) {
    const int current_load_start_idx = (i * bdimx + tid) * T_PER_TVec;
    if (current_load_start_idx < s_reduce_elements) {
      const int offset_in_group =
        ((current_load_start_idx / s_group_stride) * last_dim + (current_load_start_idx % s_group_stride))
        / T_PER_TVec;
      TVec tmp_vec = input_TVec_ptr[offset_in_group];
      T* tmp_vec_ptr = (T*)(&tmp_vec);
#pragma unroll
      for (int j = 0; j < T_PER_TVec; j++) {
        float tmp = static_cast<float>(tmp_vec_ptr[j]);
        tmp -= s_mean;
        local_sum[0] += tmp * tmp;
      }
    }
  }
  if (bdimx <= 32) {
    warpReduceSum<float, 1>(local_sum);
  }
  else {
    blockReduceSum<float, 1>(local_sum);
  }
  if (tid == 0) {
    s_variance = rsqrtf(local_sum[0] / s_reduce_elements + eps);
  }
  __syncthreads();

  // normalize
  const int gamma_offset_of_group = gid * v_group_stride;
  const TVec* gamma_TVec_ptr = (const TVec*)gamma + gamma_offset_of_group;
  const TVec* beta_TVec_ptr = (const TVec*)beta + gamma_offset_of_group;
#pragma unroll
  for (int i = 0; i < TVecs_PER_THREAD; i += 1) {
    const int current_load_start_idx = (i * bdimx + tid) * T_PER_TVec;
    if (current_load_start_idx < s_reduce_elements) {
      const int offset_in_group =
        ((current_load_start_idx / s_group_stride) * last_dim + (current_load_start_idx % s_group_stride))
        / T_PER_TVec;
      const int gamma_offset_in_group = (current_load_start_idx % s_group_stride) / T_PER_TVec;
      TVec gamma_val = gamma_TVec_ptr[gamma_offset_in_group];
      TVec beta_val = beta_TVec_ptr[gamma_offset_in_group];
      T* gamma_val_ptr = (T*)(&gamma_val);
      T* beta_val_ptr = (T*)(&beta_val);
      TVec tmp_vec = input_TVec_ptr[offset_in_group];
      T* tmp_vec_ptr = (T*)(&tmp_vec);
      TVec output_tmp_vec;
      T* output_tmp_vec_ptr = (T*)(&output_tmp_vec);
#pragma unroll
      for (int j = 0; j < T_PER_TVec; j++) {
        float tmp =
          (static_cast<float>(tmp_vec_ptr[j]) - s_mean) * s_variance * static_cast<float>(gamma_val_ptr[j])
          + static_cast<float>(beta_val_ptr[j]);
        if (sizeof(T) == sizeof(half)) {
          output_tmp_vec_ptr[j] = T(__float2half_rn(tmp));
        }
        else {
          output_tmp_vec_ptr[j] = T(tmp);
        }
      }
      output_TVec_ptr[offset_in_group] = output_tmp_vec;
    }
  }
}

//ref_input & ref_output should be [N, H, W, C]
//ref_gamma & ref_beta should be [1, 1, 1, C]
template <typename T>
void groupnorm(cutlass::Tensor4DCoord input_size,
               const int num_groups,
               const float eps,
               TensorRef<T, layout::TensorNHWC> ref_output,
               TensorRef<T, layout::TensorNHWC> ref_input,
               TensorRef<T, layout::TensorNHWC> ref_gamma,
               TensorRef<T, layout::TensorNHWC> ref_beta,
               cudaStream_t stream){
  const int N = input_size.n();
  const int H = input_size.h();
  const int W = input_size.w();
  const int C = input_size.c();
  if (C % num_groups != 0){
    printf("[ERROR] C should be a multiple of num_groups.\n");
  }
  T* output = ref_output.data();
  const T* input = ref_input.data();
  const T* gamma = ref_gamma.data();
  const T* beta = ref_beta.data();

  const int dim0 = N;
  const int last_dim = C;
  const int prod_dim1_to_last_dim = H*W*C;
  const int s_reduce_elements = prod_dim1_to_last_dim / num_groups;
  const int s_group_stride = last_dim / num_groups;
  dim3 grid(num_groups, dim0);
  int threadblock_size = 32;
  if (s_group_stride % 2 == 0) {
    const int T_PER_TVec = 2;
    while (threadblock_size < 1024) {
      if (s_reduce_elements / T_PER_TVec / threadblock_size <= 8)
        break;
      threadblock_size *= 2;
    }
    dim3 block(threadblock_size);
    const int TVec_PER_THREAD = (s_reduce_elements / T_PER_TVec + threadblock_size - 1) / threadblock_size;
    const int shm_size = T_PER_TVec * TVec_PER_THREAD * threadblock_size * sizeof(T);
    // for small s_reduce_elements, specific case for H=W=22, C=1280, num_groups=32;
    // the size of grid & block may have better choice for different cases.
    // ensure shared memory is smaller than 48KB
    if (std::is_same<T, float>::value){
      if (shm_size < 48 * 1024) {
        groupnorm_twopass_store_locally<float2, T, T_PER_TVec><<<grid, block, shm_size, stream>>>(
          output, input, gamma, beta, num_groups, prod_dim1_to_last_dim, last_dim, eps, TVec_PER_THREAD);
      }
      else {
        groupnorm_twopass_multiple_load<float2, T, T_PER_TVec><<<grid, block, 0, stream>>>(
          output, input, gamma, beta, num_groups, prod_dim1_to_last_dim, last_dim, eps, TVec_PER_THREAD);
      }
    }
    else{
      if (shm_size < 48 * 1024) {
        groupnorm_twopass_store_locally<half2, T, T_PER_TVec><<<grid, block, shm_size, stream>>>(
          output, input, gamma, beta, num_groups, prod_dim1_to_last_dim, last_dim, eps, TVec_PER_THREAD);
      }
      else {
        groupnorm_twopass_multiple_load<half2, T, T_PER_TVec><<<grid, block, 0, stream>>>(
          output, input, gamma, beta, num_groups, prod_dim1_to_last_dim, last_dim, eps, TVec_PER_THREAD);
      }
    }
  }
  else {
    const int T_PER_TVec = 1;
    while (threadblock_size < 1024) {
      if (s_reduce_elements / T_PER_TVec / threadblock_size <= 8)
        break;
      threadblock_size *= 2;
    }
    dim3 block(threadblock_size);
    const int TVec_PER_THREAD = (s_reduce_elements / T_PER_TVec + threadblock_size - 1) / threadblock_size;
    const int shm_size = T_PER_TVec * TVec_PER_THREAD * threadblock_size * sizeof(T);
    if (shm_size < 48 * 1024) {
      groupnorm_twopass_store_locally<T, T, T_PER_TVec><<<grid, block, shm_size, stream>>>(
        output, input, gamma, beta, num_groups, prod_dim1_to_last_dim, last_dim, eps, TVec_PER_THREAD);
    }
    else {
      groupnorm_twopass_multiple_load<T, T, T_PER_TVec><<<grid, block, 0, stream>>>(
        output, input, gamma, beta, num_groups, prod_dim1_to_last_dim, last_dim, eps, TVec_PER_THREAD);
    }
  }

}

} //namespace cutlass
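
For reference, the following is a minimal usage sketch of the groupnorm() interface declared in the header above. It is an illustration, not part of the commit: the problem size, group count, and epsilon are placeholder values, and it assumes the cutlass::HostTensor utility from cutlass/util/host_tensor.h for allocating and copying the NHWC tensors.

#include "cutlass/util/device_groupnorm.h"
#include "cutlass/util/host_tensor.h"

void groupnorm_example(cudaStream_t stream) {
  using Element = cutlass::half_t;
  using Layout  = cutlass::layout::TensorNHWC;

  // Hypothetical problem size: N=2, H=W=22, C=1280 split into 32 groups of 40 channels.
  cutlass::Tensor4DCoord input_size(2, 22, 22, 1280);
  int const num_groups = 32;     // C must be divisible by num_groups
  float const eps = 1e-5f;

  cutlass::HostTensor<Element, Layout> input(input_size);
  cutlass::HostTensor<Element, Layout> output(input_size);
  cutlass::HostTensor<Element, Layout> gamma(cutlass::Tensor4DCoord(1, 1, 1, input_size.c()));
  cutlass::HostTensor<Element, Layout> beta(cutlass::Tensor4DCoord(1, 1, 1, input_size.c()));

  // ... fill input, gamma, and beta on the host ...
  input.sync_device();
  gamma.sync_device();
  beta.sync_device();

  // Dispatches to one of the two kernels above based on the shared-memory footprint.
  cutlass::groupnorm(input_size, num_groups, eps,
                     output.device_ref(), input.device_ref(),
                     gamma.device_ref(), beta.device_ref(),
                     stream);

  output.sync_host();  // copy the normalized tensor back for inspection
}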
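
The shared-memory heuristic in the groupnorm() host function can be checked by hand for the H=W=22, C=1280, num_groups=32 shape called out in its comments. The standalone sketch below (host-only arithmetic, not part of the commit) mirrors that dispatch logic and shows why half-precision inputs take the store-locally kernel while float inputs fall back to the multiple-load kernel.

#include <cstdio>

int main() {
  int const H = 22, W = 22, C = 1280, num_groups = 32;
  int const prod_dim1_to_last_dim = H * W * C;                       // 619520 elements per batch
  int const s_reduce_elements = prod_dim1_to_last_dim / num_groups;  // 19360 elements per group
  int const s_group_stride = C / num_groups;                         // 40 channels, even, so vectorize by 2
  int const T_PER_TVec = (s_group_stride % 2 == 0) ? 2 : 1;

  int threadblock_size = 32;
  while (threadblock_size < 1024) {
    if (s_reduce_elements / T_PER_TVec / threadblock_size <= 8)
      break;
    threadblock_size *= 2;                                           // reaches 1024 for this shape
  }
  int const TVec_PER_THREAD =
      (s_reduce_elements / T_PER_TVec + threadblock_size - 1) / threadblock_size;  // 10 vectors per thread
  int const shm_half  = T_PER_TVec * TVec_PER_THREAD * threadblock_size * 2;       // 40960 B, under 48 KB
  int const shm_float = T_PER_TVec * TVec_PER_THREAD * threadblock_size * 4;       // 81920 B, over 48 KB

  std::printf("block=%d vecs/thread=%d shm(half)=%dB shm(float)=%dB\n",
              threadblock_size, TVec_PER_THREAD, shm_half, shm_float);
  return 0;
}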