add two missing files (#636)
Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
This commit is contained in:
parent 9f2e3faa69
commit 97bff52e8c

CITATION.cff (new file, 82 lines)
@@ -0,0 +1,82 @@
cff-version: 1.2.0
title: CUTLASS
message: >-
  If you use this software, please cite using the
  following metadata.
type: software
authors:
  - given-names: Andrew
    email: akerr@nvidia.com
    family-names: Kerr
    affiliation: NVIDIA
  - given-names: Haicheng
    family-names: Wu
    affiliation: NVIDIA
    email: haichengw@nvidia.com
  - given-names: Manish
    family-names: Gupta
    affiliation: Google
    email: manigupta@google.com
  - given-names: Dustyn
    family-names: Blasig
    email: dblasig@nvidia.com
    affiliation: NVIDIA
  - given-names: Pradeep
    family-names: Ramani
    email: prramani@nvidia.com
    affiliation: NVIDIA
  - given-names: Duane
    family-names: Merrill
    email: dumerrill@nvidia.com
    affiliation: NVIDIA
  - given-names: Aniket
    family-names: Shivam
    email: ashivam@nvidia.com
    affiliation: NVIDIA
  - given-names: Piotr
    family-names: Majcher
    email: pmajcher@nvidia.com
    affiliation: NVIDIA
  - given-names: Paul
    family-names: Springer
    email: pspringer@nvidia.com
    affiliation: NVIDIA
  - given-names: Markus
    family-names: Hohnerbach
    affiliation: NVIDIA
    email: mhohnerbach@nvidia.com
  - given-names: Jin
    family-names: Wang
    email: jinw@nvidia.com
    affiliation: NVIDIA
  - given-names: Matt
    family-names: Nicely
    email: mnicely@nvidia.com
    affiliation: NVIDIA
repository-code: 'https://github.com/NVIDIA/cutlass'
abstract: >-
  CUTLASS is a collection of CUDA C++ template
  abstractions for implementing high-performance
  matrix-multiplication (GEMM) and related
  computations at all levels and scales within CUDA.
  It incorporates strategies for hierarchical
  decomposition and data movement similar to those
  used to implement cuBLAS and cuDNN. CUTLASS
  decomposes these "moving parts" into reusable,
  modular software components abstracted by C++
  template classes. These thread-wide, warp-wide,
  block-wide, and device-wide primitives can be
  specialized and tuned via custom tiling sizes, data
  types, and other algorithmic policy. The resulting
  flexibility simplifies their use as building blocks
  within custom kernels and applications.
keywords:
  - 'cutlass, tensor cores, cuda'
license: BSD-3-Clause
license-url: https://github.com/NVIDIA/cutlass/blob/v2.9.0/LICENSE.txt
version: '2.9'
date-released: '2022-04-27'
identifiers:
  - type: url
    value: "https://github.com/NVIDIA/cutlass/tree/v2.9.0"
    description: The GitHub release URL of tag 2.9.0
tools/util/include/cutlass/util/device_groupnorm.h (new file, 402 lines)
@@ -0,0 +1,402 @@
/******************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 *    list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 *    this list of conditions and the following disclaimer in the documentation
 *    and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 *    contributors may be used to endorse or promote products derived from
 *    this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 ******************************************************************************/

#pragma once

/**
 * \file
 * \brief CUDA kernels to perform group normalization on a device memory tensor with NHWC layout.
 *        The tensor is viewed as [N, H, W, G, C'] and normalization is applied over [H, W, C'].
 */
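// Illustrative note (added commentary, not part of the upstream kernels): for an input of shape
// [N, H, W, C] = [2, 22, 22, 1280] with num_groups = 32, the tensor is viewed as [2, 22, 22, 32, 40],
// and each (batch, group) pair normalizes over H * W * C' = 22 * 22 * 40 = 19360 elements.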

#include "cutlass/cutlass.h"
#include "cutlass/layout/tensor.h"
#include "cutlass/numeric_types.h"
#include "cutlass/tensor_coord.h"
#include "cutlass/tensor_ref.h"
#include "device_utils.h"
#include <float.h>

namespace cutlass {

/** \brief interface to do group norm on a device memory tensor with NHWC layout.
 * \tparam T: data type
 */
template <typename T>
void groupnorm(cutlass::Tensor4DCoord input_size,
               const int num_groups,
               const float eps,
               TensorRef<T, layout::TensorNHWC> ref_output,
               TensorRef<T, layout::TensorNHWC> ref_input,
               TensorRef<T, layout::TensorNHWC> ref_gamma,
               TensorRef<T, layout::TensorNHWC> ref_beta,
               cudaStream_t stream);

extern __shared__ char groupnorm_shm[];

// For small prod_dim1_to_last_dim/num_groups, to avoid multiple loads from global memory,
// we store the input in the shared memory.
// grid(num_groups, dim0)
// block(BLOCKSIZE)
// BLOCKSIZE * TVecs_PER_THREAD <= prod_dim1_to_last_dim/num_group
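// Note (added commentary): each thread caches TVecs_PER_THREAD * T_PER_TVec elements of T in
// groupnorm_shm, so this kernel must be launched with
// BLOCKSIZE * TVecs_PER_THREAD * T_PER_TVec * sizeof(T) bytes of dynamic shared memory,
// which is exactly the shm_size computed in the host-side groupnorm() below.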
template<typename TVec, typename T, int T_PER_TVec>
__global__ void groupnorm_twopass_store_locally(T* output,
                                                const T* input,
                                                const T* gamma,
                                                const T* beta,
                                                int num_groups,
                                                int prod_dim1_to_last_dim,
                                                int last_dim,
                                                const float eps,
                                                const int TVecs_PER_THREAD)
{
  const int bid = blockIdx.y;  // index of batch
  const int gid = blockIdx.x;  // index of group
  const int tid = threadIdx.x; // index of thread
  const int bdimx = blockDim.x;
  const int s_reduce_elements = prod_dim1_to_last_dim / num_groups;
  const int v_reduce_elements = s_reduce_elements / T_PER_TVec;
  const int s_group_stride = last_dim / num_groups;
  const int v_group_stride = s_group_stride / T_PER_TVec;
  const int offset_of_group = (bid * prod_dim1_to_last_dim + gid * s_group_stride) / T_PER_TVec;
  const TVec* input_TVec_ptr = (const TVec*)(input) + offset_of_group;
  TVec* output_TVec_ptr = (TVec*)(output) + offset_of_group;
  T* local_val = ((T*)groupnorm_shm) + TVecs_PER_THREAD * T_PER_TVec * tid;
  float local_sum[1] = {0.0f};

  // load from global memory into shared memory
#pragma unroll
  for (int i = 0; i < TVecs_PER_THREAD; i += 1) {
    const int current_load_start_idx = (i * bdimx + tid) * T_PER_TVec;
    const int offset_in_group =
      ((current_load_start_idx / s_group_stride) * last_dim + (current_load_start_idx % s_group_stride))
      / T_PER_TVec;
    if (current_load_start_idx < s_reduce_elements) {
      TVec tmp_vec = input_TVec_ptr[offset_in_group];
      T* tmp_vec_ptr = (T*)(&tmp_vec);
      const int local_val_offset = i * T_PER_TVec;
#pragma unroll
      for (int j = 0; j < T_PER_TVec; j++) {
        float tmp = static_cast<float>(tmp_vec_ptr[j]);
        local_sum[0] += tmp;
        local_val[local_val_offset + j] = tmp_vec_ptr[j];
      }
    }
  }
  __shared__ float s_mean, s_variance;

  // reduction for mean
  if (bdimx <= 32) {
    warpReduceSum<float, 1>(local_sum);
  }
  else {
    blockReduceSum<float, 1>(local_sum);
  }
  if (tid == 0) {
    s_mean = local_sum[0] / s_reduce_elements;
  }
  __syncthreads();

  // reduction for std
  local_sum[0] = 0.0f;
#pragma unroll
  for (int i = 0; i < TVecs_PER_THREAD; i += 1) {
    const int current_load_start_idx = (i * bdimx + tid) * T_PER_TVec;
    if (current_load_start_idx < s_reduce_elements) {
      const int local_val_offset = i * T_PER_TVec;
#pragma unroll
      for (int j = 0; j < T_PER_TVec; j++) {
        float tmp = static_cast<float>(local_val[local_val_offset + j]);
        tmp -= s_mean;
        local_sum[0] += tmp * tmp;
      }
    }
  }
  if (bdimx <= 32) {
    warpReduceSum<float, 1>(local_sum);
  }
  else {
    blockReduceSum<float, 1>(local_sum);
  }
  if (tid == 0) {
    s_variance = rsqrtf(local_sum[0] / s_reduce_elements + eps);
  }
  __syncthreads();

  // normalize
  const int gamma_offset_of_group = gid * v_group_stride;
  const TVec* gamma_TVec_ptr = (const TVec*)gamma + gamma_offset_of_group;
  const TVec* beta_TVec_ptr = (const TVec*)beta + gamma_offset_of_group;
#pragma unroll
  for (int i = 0; i < TVecs_PER_THREAD; i += 1) {
    const int current_load_start_idx = (i * bdimx + tid) * T_PER_TVec;
    const int offset_in_group =
      ((current_load_start_idx / s_group_stride) * last_dim + (current_load_start_idx % s_group_stride))
      / T_PER_TVec;
    const int gamma_offset_in_group = (current_load_start_idx % s_group_stride) / T_PER_TVec;
    const int local_val_offset = i * T_PER_TVec;
    if (current_load_start_idx < s_reduce_elements) {
      TVec gamma_val = gamma_TVec_ptr[gamma_offset_in_group];
      TVec beta_val = beta_TVec_ptr[gamma_offset_in_group];
      T* gamma_val_ptr = (T*)(&gamma_val);
      T* beta_val_ptr = (T*)(&beta_val);
      TVec tmp_vec;
      T* tmp_vec_ptr = (T*)(&tmp_vec);
#pragma unroll
      for (int j = 0; j < T_PER_TVec; j++) {
        float tmp = (static_cast<float>(local_val[local_val_offset + j]) - s_mean) * s_variance
                      * static_cast<float>(gamma_val_ptr[j])
                    + static_cast<float>(beta_val_ptr[j]);
        if (sizeof(T) == sizeof(half)) {
          tmp_vec_ptr[j] = T(__float2half_rn(tmp));
        }
        else {
          tmp_vec_ptr[j] = T(tmp);
        }
      }
      output_TVec_ptr[offset_in_group] = tmp_vec;
    }
  }
}

// For large prod_dim1_to_last_dim/num_groups,
// in which the data cannot be stored locally,
// we will load from global memory multiple times,
// grid(num_groups, dim0)
// block(BLOCKSIZE)
// BLOCKSIZE * TVecs_PER_THREAD <= prod_dim1_to_last_dim/num_group
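// Note (added commentary): unlike groupnorm_twopass_store_locally, this kernel does not cache the
// group in shared memory; it re-reads the input from global memory in the variance and normalize
// passes, and is therefore launched with no dynamic shared memory.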
template<typename TVec, typename T, int T_PER_TVec>
__global__ void groupnorm_twopass_multiple_load(T* output,
                                                const T* input,
                                                const T* gamma,
                                                const T* beta,
                                                int num_groups,
                                                int prod_dim1_to_last_dim,
                                                int last_dim,
                                                const float eps,
                                                const int TVecs_PER_THREAD)
{
  const int bid = blockIdx.y;  // index of batch
  const int gid = blockIdx.x;  // index of group
  const int tid = threadIdx.x; // index of thread
  const int bdimx = blockDim.x;
  const int s_reduce_elements = prod_dim1_to_last_dim / num_groups;
  const int v_reduce_elements = s_reduce_elements / T_PER_TVec;
  const int s_group_stride = last_dim / num_groups;
  const int v_group_stride = s_group_stride / T_PER_TVec;
  const int offset_of_group = (bid * prod_dim1_to_last_dim + gid * s_group_stride) / T_PER_TVec;
  const TVec* input_TVec_ptr = (const TVec*)(input) + offset_of_group;
  TVec* output_TVec_ptr = (TVec*)(output) + offset_of_group;
  float local_sum[1] = {0.0f};

#pragma unroll
  for (int i = 0; i < TVecs_PER_THREAD; i += 1) {
    const int current_load_start_idx = (i * bdimx + tid) * T_PER_TVec;
    if (current_load_start_idx < s_reduce_elements) {
      const int offset_in_group =
        ((current_load_start_idx / s_group_stride) * last_dim + (current_load_start_idx % s_group_stride))
        / T_PER_TVec;
      TVec tmp_vec = input_TVec_ptr[offset_in_group];
      T* tmp_vec_ptr = (T*)(&tmp_vec);
#pragma unroll
      for (int j = 0; j < T_PER_TVec; j++) {
        float tmp = static_cast<float>(tmp_vec_ptr[j]);
        local_sum[0] += tmp;
      }
    }
  }
  __shared__ float s_mean, s_variance;

  // reduction for mean
  if (bdimx <= 32) {
    warpReduceSum<float, 1>(local_sum);
  }
  else {
    blockReduceSum<float, 1>(local_sum);
  }
  if (tid == 0) {
    s_mean = local_sum[0] / s_reduce_elements;
  }
  __syncthreads();

  // reduction for std
  local_sum[0] = 0.0f;
#pragma unroll
  for (int i = 0; i < TVecs_PER_THREAD; i += 1) {
    const int current_load_start_idx = (i * bdimx + tid) * T_PER_TVec;
    if (current_load_start_idx < s_reduce_elements) {
      const int offset_in_group =
        ((current_load_start_idx / s_group_stride) * last_dim + (current_load_start_idx % s_group_stride))
        / T_PER_TVec;
      TVec tmp_vec = input_TVec_ptr[offset_in_group];
      T* tmp_vec_ptr = (T*)(&tmp_vec);
#pragma unroll
      for (int j = 0; j < T_PER_TVec; j++) {
        float tmp = static_cast<float>(tmp_vec_ptr[j]);
        tmp -= s_mean;
        local_sum[0] += tmp * tmp;
      }
    }
  }
  if (bdimx <= 32) {
    warpReduceSum<float, 1>(local_sum);
  }
  else {
    blockReduceSum<float, 1>(local_sum);
  }
  if (tid == 0) {
    s_variance = rsqrtf(local_sum[0] / s_reduce_elements + eps);
  }
  __syncthreads();

  // normalize
  const int gamma_offset_of_group = gid * v_group_stride;
  const TVec* gamma_TVec_ptr = (const TVec*)gamma + gamma_offset_of_group;
  const TVec* beta_TVec_ptr = (const TVec*)beta + gamma_offset_of_group;
#pragma unroll
  for (int i = 0; i < TVecs_PER_THREAD; i += 1) {
    const int current_load_start_idx = (i * bdimx + tid) * T_PER_TVec;
    if (current_load_start_idx < s_reduce_elements) {
      const int offset_in_group =
        ((current_load_start_idx / s_group_stride) * last_dim + (current_load_start_idx % s_group_stride))
        / T_PER_TVec;
      const int gamma_offset_in_group = (current_load_start_idx % s_group_stride) / T_PER_TVec;
      TVec gamma_val = gamma_TVec_ptr[gamma_offset_in_group];
      TVec beta_val = beta_TVec_ptr[gamma_offset_in_group];
      T* gamma_val_ptr = (T*)(&gamma_val);
      T* beta_val_ptr = (T*)(&beta_val);
      TVec tmp_vec = input_TVec_ptr[offset_in_group];
      T* tmp_vec_ptr = (T*)(&tmp_vec);
      TVec output_tmp_vec;
      T* output_tmp_vec_ptr = (T*)(&output_tmp_vec);
#pragma unroll
      for (int j = 0; j < T_PER_TVec; j++) {
        float tmp =
          (static_cast<float>(tmp_vec_ptr[j]) - s_mean) * s_variance * static_cast<float>(gamma_val_ptr[j])
          + static_cast<float>(beta_val_ptr[j]);
        if (sizeof(T) == sizeof(half)) {
          output_tmp_vec_ptr[j] = T(__float2half_rn(tmp));
        }
        else {
          output_tmp_vec_ptr[j] = T(tmp);
        }
      }
      output_TVec_ptr[offset_in_group] = output_tmp_vec;
    }
  }
}

// ref_input & ref_output should be [N, H, W, C]
// ref_gamma & ref_beta should be [1, 1, 1, C]
template <typename T>
void groupnorm(cutlass::Tensor4DCoord input_size,
               const int num_groups,
               const float eps,
               TensorRef<T, layout::TensorNHWC> ref_output,
               TensorRef<T, layout::TensorNHWC> ref_input,
               TensorRef<T, layout::TensorNHWC> ref_gamma,
               TensorRef<T, layout::TensorNHWC> ref_beta,
               cudaStream_t stream) {
  const int N = input_size.n();
  const int H = input_size.h();
  const int W = input_size.w();
  const int C = input_size.c();
  if (C % num_groups != 0) {
    printf("[ERROR] C should be a multiple of num_groups.\n");
  }
  T* output = ref_output.data();
  const T* input = ref_input.data();
  const T* gamma = ref_gamma.data();
  const T* beta = ref_beta.data();

  const int dim0 = N;
  const int last_dim = C;
  const int prod_dim1_to_last_dim = H * W * C;
  const int s_reduce_elements = prod_dim1_to_last_dim / num_groups;
  const int s_group_stride = last_dim / num_groups;
  dim3 grid(num_groups, dim0);
  int threadblock_size = 32;
  if (s_group_stride % 2 == 0) {
    const int T_PER_TVec = 2;
    while (threadblock_size < 1024) {
      if (s_reduce_elements / T_PER_TVec / threadblock_size <= 8)
        break;
      threadblock_size *= 2;
    }
    dim3 block(threadblock_size);
    const int TVec_PER_THREAD = (s_reduce_elements / T_PER_TVec + threadblock_size - 1) / threadblock_size;
    const int shm_size = T_PER_TVec * TVec_PER_THREAD * threadblock_size * sizeof(T);
    // Block size tuned for small s_reduce_elements (specifically the case H=W=22, C=1280, num_groups=32);
    // other cases may benefit from different grid & block sizes.
    // Ensure the requested dynamic shared memory stays below 48KB.
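    // Worked example (added commentary) for the H=W=22, C=1280, num_groups=32 case with half data:
    // s_reduce_elements = 22*22*1280/32 = 19360 and T_PER_TVec = 2, so the while loop above stops at
    // threadblock_size = 1024, TVec_PER_THREAD = ceil(9680/1024) = 10, and
    // shm_size = 2 * 10 * 1024 * 2 = 40960 bytes < 48KB, selecting the store_locally kernel.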
    if (std::is_same<T, float>::value) {
      if (shm_size < 48 * 1024) {
        groupnorm_twopass_store_locally<float2, T, T_PER_TVec><<<grid, block, shm_size, stream>>>(
          output, input, gamma, beta, num_groups, prod_dim1_to_last_dim, last_dim, eps, TVec_PER_THREAD);
      }
      else {
        groupnorm_twopass_multiple_load<float2, T, T_PER_TVec><<<grid, block, 0, stream>>>(
          output, input, gamma, beta, num_groups, prod_dim1_to_last_dim, last_dim, eps, TVec_PER_THREAD);
      }
    }
    else {
      if (shm_size < 48 * 1024) {
        groupnorm_twopass_store_locally<half2, T, T_PER_TVec><<<grid, block, shm_size, stream>>>(
          output, input, gamma, beta, num_groups, prod_dim1_to_last_dim, last_dim, eps, TVec_PER_THREAD);
      }
      else {
        groupnorm_twopass_multiple_load<half2, T, T_PER_TVec><<<grid, block, 0, stream>>>(
          output, input, gamma, beta, num_groups, prod_dim1_to_last_dim, last_dim, eps, TVec_PER_THREAD);
      }
    }
  }
  else {
    const int T_PER_TVec = 1;
    while (threadblock_size < 1024) {
      if (s_reduce_elements / T_PER_TVec / threadblock_size <= 8)
        break;
      threadblock_size *= 2;
    }
    dim3 block(threadblock_size);
    const int TVec_PER_THREAD = (s_reduce_elements / T_PER_TVec + threadblock_size - 1) / threadblock_size;
    const int shm_size = T_PER_TVec * TVec_PER_THREAD * threadblock_size * sizeof(T);
    if (shm_size < 48 * 1024) {
      groupnorm_twopass_store_locally<T, T, T_PER_TVec><<<grid, block, shm_size, stream>>>(
        output, input, gamma, beta, num_groups, prod_dim1_to_last_dim, last_dim, eps, TVec_PER_THREAD);
    }
    else {
      groupnorm_twopass_multiple_load<T, T, T_PER_TVec><<<grid, block, 0, stream>>>(
        output, input, gamma, beta, num_groups, prod_dim1_to_last_dim, last_dim, eps, TVec_PER_THREAD);
    }
  }
}

} // namespace cutlass
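Usage note (added commentary, not part of the commit): a minimal host-side sketch of driving the new
header, assuming cutlass::HostTensor from cutlass/util/host_tensor.h, cutlass::half_t data, the default
stream, and purely illustrative shapes and epsilon.

#include <cuda_runtime.h>
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/device_groupnorm.h"

int main() {
  // Illustrative NHWC extent and group count; C must be divisible by num_groups.
  cutlass::Tensor4DCoord extent(2, 22, 22, 1280);
  int const num_groups = 32;
  float const eps = 1e-5f;

  // Activations plus per-channel scale/bias, allocated on both host and device.
  cutlass::HostTensor<cutlass::half_t, cutlass::layout::TensorNHWC> input(extent);
  cutlass::HostTensor<cutlass::half_t, cutlass::layout::TensorNHWC> output(extent);
  cutlass::HostTensor<cutlass::half_t, cutlass::layout::TensorNHWC> gamma({1, 1, 1, extent.c()});
  cutlass::HostTensor<cutlass::half_t, cutlass::layout::TensorNHWC> beta({1, 1, 1, extent.c()});

  // ... fill input, gamma, and beta on the host ...
  input.sync_device();
  gamma.sync_device();
  beta.sync_device();

  // Launch the two-pass group-norm kernels on the default stream.
  cutlass::groupnorm(extent, num_groups, eps,
                     output.device_ref(), input.device_ref(),
                     gamma.device_ref(), beta.device_ref(),
                     /*stream=*/nullptr);

  cudaDeviceSynchronize();
  output.sync_host();
  return 0;
}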