Make operator() const-correct and add missing static functions. (#936)

* Make operator() const-correct and add missing static functions.

Currently, `*Converter::operator()` requires a mutable object to invoke,
and there are missing `static result_type convert(source_type const &
source)` overloads for certain partial specializations of `*Converter`
objects. This commit makes `operator()` const-correct and adds missing
function overloads where appropriate.

* minor changes

* format

---------

Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
This commit is contained in:
Gregory Meyer (gregjm) 2023-05-09 13:33:01 -07:00 committed by GitHub
parent 24c8b7d8a2
commit b250faccd3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -81,7 +81,7 @@ struct NumericConverter {
}
CUTLASS_HOST_DEVICE
result_type operator()(source_type const &s) {
result_type operator()(source_type const &s) const {
return convert(s);
}
};
@ -107,7 +107,7 @@ struct NumericConverter<int32_t, float, FloatRoundStyle::round_to_nearest> {
}
CUTLASS_DEVICE
result_type operator()(source_type const &s) {
result_type operator()(source_type const &s) const {
return convert(s);
}
};
@ -126,7 +126,7 @@ struct NumericConverter<int32_t, float, FloatRoundStyle::round_toward_zero> {
}
CUTLASS_DEVICE
result_type operator()(source_type const &s) {
result_type operator()(source_type const &s) const {
return convert(s);
}
};
@ -145,7 +145,7 @@ struct NumericConverter<int32_t, float, FloatRoundStyle::round_to_nearest> {
return (result_type)std::nearbyint(s);
}
result_type operator()(source_type const &s) {
result_type operator()(source_type const &s) const {
return convert(s);
}
};
@ -162,7 +162,7 @@ struct NumericConverter<int32_t, float, FloatRoundStyle::round_toward_zero> {
return (result_type)std::nearbyint(s);
}
result_type operator()(source_type const &s) {
result_type operator()(source_type const &s) const {
return convert(s);
}
};
@ -192,7 +192,7 @@ struct NumericConverter<int8_t, float, FloatRoundStyle::round_to_nearest> {
}
CUTLASS_DEVICE
result_type operator()(source_type const &s) {
result_type operator()(source_type const &s) const {
return convert(s);
}
};
@ -214,7 +214,7 @@ struct NumericConverter<int8_t, float, FloatRoundStyle::round_toward_zero> {
}
CUTLASS_DEVICE
result_type operator()(source_type const &s) {
result_type operator()(source_type const &s) const {
return convert(s);
}
};
@ -241,7 +241,7 @@ struct NumericConverter<int8_t, float, FloatRoundStyle::round_to_nearest> {
return static_cast<result_type>(intermediate);
}
result_type operator()(source_type const &s) {
result_type operator()(source_type const &s) const {
return convert(s);
}
};
@ -266,7 +266,7 @@ struct NumericConverter<int8_t, float, FloatRoundStyle::round_toward_zero> {
return static_cast<result_type>(intermediate);
}
result_type operator()(source_type const &s) {
result_type operator()(source_type const &s) const {
return convert(s);
}
};
@ -290,7 +290,7 @@ struct NumericConverter<T, T, Round> {
}
CUTLASS_HOST_DEVICE
result_type operator()(source_type const &s) {
result_type operator()(source_type const &s) const {
return convert(s);
}
};
@ -318,7 +318,7 @@ struct NumericConverter<float, half_t, Round> {
}
CUTLASS_HOST_DEVICE
result_type operator()(source_type const &s) {
result_type operator()(source_type const &s) const {
return convert(s);
}
};
@ -340,7 +340,7 @@ struct NumericConverter<half_t, float, FloatRoundStyle::round_to_nearest> {
}
CUTLASS_HOST_DEVICE
result_type operator()(source_type const &s) {
result_type operator()(source_type const &s) const {
return convert(s);
}
};
@ -409,7 +409,7 @@ struct NumericConverter<half_t, float, FloatRoundStyle::round_toward_zero> {
}
CUTLASS_HOST_DEVICE
result_type operator()(source_type const &s) {
result_type operator()(source_type const &s) const {
return convert(s);
}
};
@ -435,7 +435,7 @@ struct NumericConverter<float, bfloat16_t, Round> {
}
CUTLASS_HOST_DEVICE
result_type operator()(source_type const &s) {
result_type operator()(source_type const &s) const {
return convert(s);
}
};
@ -452,7 +452,7 @@ struct NumericConverter<bfloat16_t, float, FloatRoundStyle::round_to_nearest> {
}
CUTLASS_HOST_DEVICE
result_type operator()(source_type const &s) {
result_type operator()(source_type const &s) const {
return convert(s);
}
};
@ -482,7 +482,7 @@ struct NumericConverter<bfloat16_t, float, FloatRoundStyle::round_half_ulp_trunc
}
CUTLASS_HOST_DEVICE
result_type operator()(source_type const &s) {
result_type operator()(source_type const &s) const {
return convert(s);
}
};
@ -503,7 +503,7 @@ struct NumericConverter<bfloat16_t, float, FloatRoundStyle::round_toward_zero> {
}
CUTLASS_HOST_DEVICE
result_type operator()(source_type const &s) {
result_type operator()(source_type const &s) const {
return convert(s);
}
};
@ -529,7 +529,7 @@ struct NumericConverter<float, tfloat32_t, Round> {
}
CUTLASS_HOST_DEVICE
result_type operator()(source_type const &s) {
result_type operator()(source_type const &s) const {
return convert(s);
}
};
@ -579,7 +579,7 @@ struct NumericConverter<tfloat32_t, float, FloatRoundStyle::round_to_nearest> {
}
CUTLASS_HOST_DEVICE
result_type operator()(source_type const &s) {
result_type operator()(source_type const &s) const {
return convert(s);
}
};
@ -596,7 +596,7 @@ struct NumericConverter<tfloat32_t, float, FloatRoundStyle::round_half_ulp_trunc
}
CUTLASS_HOST_DEVICE
result_type operator()(source_type const &s) {
result_type operator()(source_type const &s) const {
return convert(s);
}
};
@ -621,7 +621,7 @@ struct NumericConverter<tfloat32_t, float, FloatRoundStyle::round_half_ulp_trunc
}
CUTLASS_HOST_DEVICE
result_type operator()(source_type const &s) {
result_type operator()(source_type const &s) const {
return convert(s);
}
};
@ -639,7 +639,7 @@ struct NumericConverter<tfloat32_t, float, FloatRoundStyle::round_toward_zero> {
}
CUTLASS_HOST_DEVICE
result_type operator()(source_type const &s) {
result_type operator()(source_type const &s) const {
return convert(s);
}
};
@ -682,7 +682,7 @@ struct NumericConverterFastF32 {
}
CUTLASS_HOST_DEVICE
result_type operator()(source_type const &s) {
result_type operator()(source_type const &s) const {
return convert(s);
}
};
@ -715,7 +715,7 @@ struct NumericConverterClamp {
}
CUTLASS_HOST_DEVICE
result_type operator()(source_type const &s) {
result_type operator()(source_type const &s) const {
return convert(s);
}
};
@ -732,12 +732,16 @@ struct NumericConverterClamp<cutlass::half_t, S> {
using source_type = S;
CUTLASS_HOST_DEVICE
result_type operator()(source_type const &s) {
return static_cast<cutlass::half_t>(s);
static result_type convert(source_type const &source) {
return static_cast<cutlass::half_t>(source);
}
CUTLASS_HOST_DEVICE
result_type operator()(source_type const &s) const {
return convert(s);
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
//
// Conversion operator for Array
@ -782,7 +786,7 @@ struct NumericArrayConverter {
}
CUTLASS_HOST_DEVICE
result_type operator()(source_type const &s) {
result_type operator()(source_type const &s) const {
return convert(s);
}
};
@ -804,20 +808,23 @@ struct NumericArrayConverter<T, T, N, Round, Transform> {
"Unary Operator not supported.");
CUTLASS_HOST_DEVICE
result_type operator()(source_type const &s) {
if( platform::is_same<Transform, cutlass::transform::thread::UnaryTransform::Identity>::value )
{
return s;
static result_type convert(source_type const &source) {
if (platform::is_same<Transform, cutlass::transform::thread::UnaryTransform::Identity>::value) {
return source;
} else {
result_type result;
for (int i = 0; i < N; ++i) {
result[i] = conj(s[i]);
result[i] = conj(source[i]);
}
return result;
}
}
};
CUTLASS_HOST_DEVICE
result_type operator()(source_type const &s) const {
return convert(s);
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
@ -846,7 +853,7 @@ struct NumericArrayConverter<half_t, float, 2, FloatRoundStyle::round_to_nearest
}
CUTLASS_HOST_DEVICE
result_type operator()(source_type const &s) {
result_type operator()(source_type const &s) const {
return convert(s);
}
};
@ -876,7 +883,7 @@ struct NumericArrayConverter<float, half_t, 2, Round> {
}
CUTLASS_HOST_DEVICE
result_type operator()(source_type const &s) {
result_type operator()(source_type const &s) const {
return convert(s);
}
};
@ -918,7 +925,7 @@ struct NumericArrayConverter<half_t, float, N, Round> {
}
CUTLASS_HOST_DEVICE
result_type operator()(source_type const &s) {
result_type operator()(source_type const &s) const {
return convert(s);
}
};
@ -959,12 +966,11 @@ struct NumericArrayConverter<float, half_t, N, Round> {
}
CUTLASS_HOST_DEVICE
result_type operator()(source_type const &s) {
result_type operator()(source_type const &s) const {
return convert(s);
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
@ -989,7 +995,7 @@ struct NumericArrayConverter<bfloat16_t, float, 2, FloatRoundStyle::round_to_nea
}
CUTLASS_HOST_DEVICE
result_type operator()(source_type const &s) {
result_type operator()(source_type const &s) const {
return convert(s);
}
};
@ -1031,7 +1037,7 @@ struct NumericArrayConverter<bfloat16_t, float, N, Round> {
}
CUTLASS_HOST_DEVICE
result_type operator()(source_type const &s) {
result_type operator()(source_type const &s) const {
return convert(s);
}
};
@ -1067,7 +1073,7 @@ struct NumericArrayConverter<int8_t, int, 1, Round> {
}
CUTLASS_HOST_DEVICE
result_type operator()(source_type const &s) {
result_type operator()(source_type const &s) const {
return convert(s);
}
};
@ -1096,7 +1102,7 @@ struct NumericArrayConverter<int8_t, int, 2, Round> {
}
CUTLASS_HOST_DEVICE
result_type operator()(source_type const &s) {
result_type operator()(source_type const &s) const {
return convert(s);
}
};
@ -1127,7 +1133,7 @@ struct NumericArrayConverter<int8_t, int, 4, Round> {
}
CUTLASS_HOST_DEVICE
result_type operator()(source_type const &s) {
result_type operator()(source_type const &s) const {
return convert(s);
}
};
@ -1163,7 +1169,7 @@ struct NumericArrayConverter<int8_t, int, N, Round> {
}
CUTLASS_HOST_DEVICE
result_type operator()(source_type const &s) {
result_type operator()(source_type const &s) const {
return convert(s);
}
};
@ -1190,7 +1196,7 @@ struct NumericArrayConverter<uint8_t, int, 1, Round> {
}
CUTLASS_HOST_DEVICE
result_type operator()(source_type const &s) {
result_type operator()(source_type const &s) const {
return convert(s);
}
};
@ -1219,7 +1225,7 @@ struct NumericArrayConverter<uint8_t, int, 2, Round> {
}
CUTLASS_HOST_DEVICE
result_type operator()(source_type const &s) {
result_type operator()(source_type const &s) const {
return convert(s);
}
};
@ -1250,7 +1256,7 @@ struct NumericArrayConverter<uint8_t, int, 4, Round> {
}
CUTLASS_HOST_DEVICE
result_type operator()(source_type const &s) {
result_type operator()(source_type const &s) const {
return convert(s);
}
};
@ -1286,7 +1292,7 @@ struct NumericArrayConverter<uint8_t, int, N, Round> {
}
CUTLASS_HOST_DEVICE
result_type operator()(source_type const &s) {
result_type operator()(source_type const &s) const {
return convert(s);
}
};
@ -1349,7 +1355,7 @@ struct NumericArrayConverter<float, float_e4m3_t, 4, Round> {
}
CUTLASS_HOST_DEVICE
result_type operator()(source_type const &s) {
result_type operator()(source_type const &s) const {
return convert(s);
}
};
@ -1397,7 +1403,7 @@ struct NumericArrayConverter<float_e4m3_t, float, 4, Round> {
}
CUTLASS_HOST_DEVICE
result_type operator()(source_type const &s) {
result_type operator()(source_type const &s) const {
return convert(s);
}
};
@ -1452,7 +1458,7 @@ struct NumericArrayConverter<float, float_e5m2_t, 4, Round> {
}
CUTLASS_HOST_DEVICE
result_type operator()(source_type const &s) {
result_type operator()(source_type const &s) const {
return convert(s);
}
};
@ -1500,7 +1506,7 @@ struct NumericArrayConverter<float_e5m2_t, float, 4, Round> {
}
CUTLASS_HOST_DEVICE
result_type operator()(source_type const &s) {
result_type operator()(source_type const &s) const {
return convert(s);
}
};
@ -1551,7 +1557,7 @@ struct NumericArrayConverter<half_t, float_e4m3_t, 4, Round> {
}
CUTLASS_HOST_DEVICE
result_type operator()(source_type const &s) {
result_type operator()(source_type const &s) const {
return convert(s);
}
};
@ -1600,7 +1606,7 @@ struct NumericArrayConverter<float_e4m3_t, half_t, 4, Round> {
}
CUTLASS_HOST_DEVICE
result_type operator()(source_type const &s) {
result_type operator()(source_type const &s) const {
return convert(s);
}
};
@ -1645,7 +1651,7 @@ struct NumericArrayConverter<half_t, float_e5m2_t, 4, Round> {
}
CUTLASS_HOST_DEVICE
result_type operator()(source_type const &s) {
result_type operator()(source_type const &s) const {
return convert(s);
}
};
@ -1694,7 +1700,7 @@ struct NumericArrayConverter<float_e5m2_t, half_t, 4, Round> {
}
CUTLASS_HOST_DEVICE
result_type operator()(source_type const &s) {
result_type operator()(source_type const &s) const {
return convert(s);
}
};
@ -1748,7 +1754,7 @@ struct NumericArrayConverter<bfloat16_t, float_e4m3_t, 4, Round> {
}
CUTLASS_HOST_DEVICE
result_type operator()(source_type const &s) {
result_type operator()(source_type const &s) const {
return convert(s);
}
};
@ -1794,7 +1800,7 @@ struct NumericArrayConverter<float_e4m3_t, bfloat16_t, 4, Round> {
}
CUTLASS_HOST_DEVICE
result_type operator()(source_type const &s) {
result_type operator()(source_type const &s) const {
return convert(s);
}
};
@ -1842,7 +1848,7 @@ struct NumericArrayConverter<bfloat16_t, float_e5m2_t, 4, Round> {
}
CUTLASS_HOST_DEVICE
result_type operator()(source_type const &s) {
result_type operator()(source_type const &s) const {
return convert(s);
}
};
@ -1888,7 +1894,7 @@ struct NumericArrayConverter<float_e5m2_t, bfloat16_t, 4, Round> {
}
CUTLASS_HOST_DEVICE
result_type operator()(source_type const &s) {
result_type operator()(source_type const &s) const {
return convert(s);
}
};
@ -1925,7 +1931,7 @@ struct NumericArrayConverter<float_e4m3_t, float_e5m2_t, 4, Round> {
}
CUTLASS_HOST_DEVICE
result_type operator()(source_type const &s) {
result_type operator()(source_type const &s) const {
return convert(s);
}
};
@ -1956,7 +1962,7 @@ struct NumericArrayConverter<float_e5m2_t, float_e4m3_t, 4, Round> {
}
CUTLASS_HOST_DEVICE
result_type operator()(source_type const &s) {
result_type operator()(source_type const &s) const {
return convert(s);
}
};
@ -1986,8 +1992,13 @@ struct NumericArrayConverter<float_e4m3_t, float_e4m3_t, 4, Round> {
static FloatRoundStyle const round_style = Round;
CUTLASS_HOST_DEVICE
result_type operator()(source_type const &s) {
return s;
static result_type convert(source_type const &source) {
return source;
}
CUTLASS_HOST_DEVICE
result_type operator()(source_type const &s) const {
return convert(s);
}
};
@ -2004,8 +2015,13 @@ struct NumericArrayConverter<float_e5m2_t, float_e5m2_t, 4, Round> {
static FloatRoundStyle const round_style = Round;
CUTLASS_HOST_DEVICE
result_type operator()(source_type const &s) {
return s;
static result_type convert(source_type const &source) {
return source;
}
CUTLASS_HOST_DEVICE
result_type operator()(source_type const &s) const {
return convert(s);
}
};
@ -2063,7 +2079,7 @@ public:
}
CUTLASS_HOST_DEVICE
result_type operator()(source_type const &s) {
result_type operator()(source_type const &s) const {
return convert(s);
}
};
@ -2165,7 +2181,7 @@ struct NumericArrayConverter<int8_t, float, N, Round> {
}
CUTLASS_HOST_DEVICE
result_type operator()(source_type const &s) {
result_type operator()(source_type const &s) const {
return convert(s);
}
};
@ -2206,7 +2222,7 @@ struct NumericArrayConverter<int4b_t, int, 8, Round> {
}
CUTLASS_HOST_DEVICE
result_type operator()(source_type const &s) {
result_type operator()(source_type const &s) const {
return convert(s);
}
};
@ -2242,7 +2258,7 @@ struct NumericArrayConverter<int4b_t, int, N, Round> {
}
CUTLASS_HOST_DEVICE
result_type operator()(source_type const &s) {
result_type operator()(source_type const &s) const {
return convert(s);
}
};
@ -2277,7 +2293,7 @@ struct NumericArrayConverter<uint4b_t, int, 8, Round> {
}
CUTLASS_HOST_DEVICE
result_type operator()(source_type const &s) {
result_type operator()(source_type const &s) const {
return convert(s);
}
};
@ -2313,7 +2329,7 @@ struct NumericArrayConverter<uint4b_t, int, N, Round> {
}
CUTLASS_HOST_DEVICE
result_type operator()(source_type const &s) {
result_type operator()(source_type const &s) const {
return convert(s);
}
};
@ -2341,7 +2357,7 @@ struct FastNumericArrayConverter {
}
CUTLASS_DEVICE
result_type operator()(source_type const &s) { return convert(s); }
result_type operator()(source_type const &s) const { return convert(s); }
};
/// Partial specialization for Array<float> <= Array<int>
@ -2365,7 +2381,7 @@ struct FastNumericArrayConverter<float, T, N, Round> {
}
CUTLASS_DEVICE
result_type operator()(source_type const &s) { return convert(s); }
result_type operator()(source_type const &s) const { return convert(s); }
};
/// Partial specialization for Array<int8_t, 4> <= Array<float, 4>
@ -2393,7 +2409,7 @@ struct FastNumericArrayConverter<int8_t, float, 4, Round> {
}
CUTLASS_DEVICE
result_type operator()(source_type const &s) { return convert(s); }
result_type operator()(source_type const &s) const { return convert(s); }
};
/// Partial specialization for Array<int8_t> <= Array<float>
@ -2425,7 +2441,7 @@ struct FastNumericArrayConverter<int8_t, float, N, Round> {
}
CUTLASS_DEVICE
result_type operator()(source_type const &s) { return convert(s); }
result_type operator()(source_type const &s) const { return convert(s); }
};
/////////////////////////////////////////////////////////////////////////////////////////////////