Make operator() const-correct and add missing static functions. (#936)
* Make operator() const-correct and add missing static functions. Currently, `*Converter::operator()` requires a mutable object to invoke, and there are missing `static result_type convert(source_type const & source)` overloads for certain partial specializations of `*Converter` objects. This commit makes `operator()` const-correct and adds missing function overloads where appropriate. * minor changes * format --------- Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
This commit is contained in:
parent
24c8b7d8a2
commit
b250faccd3
@ -81,7 +81,7 @@ struct NumericConverter {
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
result_type operator()(source_type const &s) {
|
||||
result_type operator()(source_type const &s) const {
|
||||
return convert(s);
|
||||
}
|
||||
};
|
||||
@ -107,7 +107,7 @@ struct NumericConverter<int32_t, float, FloatRoundStyle::round_to_nearest> {
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
result_type operator()(source_type const &s) {
|
||||
result_type operator()(source_type const &s) const {
|
||||
return convert(s);
|
||||
}
|
||||
};
|
||||
@ -126,7 +126,7 @@ struct NumericConverter<int32_t, float, FloatRoundStyle::round_toward_zero> {
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
result_type operator()(source_type const &s) {
|
||||
result_type operator()(source_type const &s) const {
|
||||
return convert(s);
|
||||
}
|
||||
};
|
||||
@ -145,7 +145,7 @@ struct NumericConverter<int32_t, float, FloatRoundStyle::round_to_nearest> {
|
||||
return (result_type)std::nearbyint(s);
|
||||
}
|
||||
|
||||
result_type operator()(source_type const &s) {
|
||||
result_type operator()(source_type const &s) const {
|
||||
return convert(s);
|
||||
}
|
||||
};
|
||||
@ -162,7 +162,7 @@ struct NumericConverter<int32_t, float, FloatRoundStyle::round_toward_zero> {
|
||||
return (result_type)std::nearbyint(s);
|
||||
}
|
||||
|
||||
result_type operator()(source_type const &s) {
|
||||
result_type operator()(source_type const &s) const {
|
||||
return convert(s);
|
||||
}
|
||||
};
|
||||
@ -192,7 +192,7 @@ struct NumericConverter<int8_t, float, FloatRoundStyle::round_to_nearest> {
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
result_type operator()(source_type const &s) {
|
||||
result_type operator()(source_type const &s) const {
|
||||
return convert(s);
|
||||
}
|
||||
};
|
||||
@ -214,7 +214,7 @@ struct NumericConverter<int8_t, float, FloatRoundStyle::round_toward_zero> {
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
result_type operator()(source_type const &s) {
|
||||
result_type operator()(source_type const &s) const {
|
||||
return convert(s);
|
||||
}
|
||||
};
|
||||
@ -241,7 +241,7 @@ struct NumericConverter<int8_t, float, FloatRoundStyle::round_to_nearest> {
|
||||
return static_cast<result_type>(intermediate);
|
||||
}
|
||||
|
||||
result_type operator()(source_type const &s) {
|
||||
result_type operator()(source_type const &s) const {
|
||||
return convert(s);
|
||||
}
|
||||
};
|
||||
@ -266,7 +266,7 @@ struct NumericConverter<int8_t, float, FloatRoundStyle::round_toward_zero> {
|
||||
return static_cast<result_type>(intermediate);
|
||||
}
|
||||
|
||||
result_type operator()(source_type const &s) {
|
||||
result_type operator()(source_type const &s) const {
|
||||
return convert(s);
|
||||
}
|
||||
};
|
||||
@ -290,7 +290,7 @@ struct NumericConverter<T, T, Round> {
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
result_type operator()(source_type const &s) {
|
||||
result_type operator()(source_type const &s) const {
|
||||
return convert(s);
|
||||
}
|
||||
};
|
||||
@ -318,7 +318,7 @@ struct NumericConverter<float, half_t, Round> {
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
result_type operator()(source_type const &s) {
|
||||
result_type operator()(source_type const &s) const {
|
||||
return convert(s);
|
||||
}
|
||||
};
|
||||
@ -340,7 +340,7 @@ struct NumericConverter<half_t, float, FloatRoundStyle::round_to_nearest> {
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
result_type operator()(source_type const &s) {
|
||||
result_type operator()(source_type const &s) const {
|
||||
return convert(s);
|
||||
}
|
||||
};
|
||||
@ -409,7 +409,7 @@ struct NumericConverter<half_t, float, FloatRoundStyle::round_toward_zero> {
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
result_type operator()(source_type const &s) {
|
||||
result_type operator()(source_type const &s) const {
|
||||
return convert(s);
|
||||
}
|
||||
};
|
||||
@ -435,7 +435,7 @@ struct NumericConverter<float, bfloat16_t, Round> {
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
result_type operator()(source_type const &s) {
|
||||
result_type operator()(source_type const &s) const {
|
||||
return convert(s);
|
||||
}
|
||||
};
|
||||
@ -452,7 +452,7 @@ struct NumericConverter<bfloat16_t, float, FloatRoundStyle::round_to_nearest> {
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
result_type operator()(source_type const &s) {
|
||||
result_type operator()(source_type const &s) const {
|
||||
return convert(s);
|
||||
}
|
||||
};
|
||||
@ -482,7 +482,7 @@ struct NumericConverter<bfloat16_t, float, FloatRoundStyle::round_half_ulp_trunc
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
result_type operator()(source_type const &s) {
|
||||
result_type operator()(source_type const &s) const {
|
||||
return convert(s);
|
||||
}
|
||||
};
|
||||
@ -503,7 +503,7 @@ struct NumericConverter<bfloat16_t, float, FloatRoundStyle::round_toward_zero> {
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
result_type operator()(source_type const &s) {
|
||||
result_type operator()(source_type const &s) const {
|
||||
return convert(s);
|
||||
}
|
||||
};
|
||||
@ -529,7 +529,7 @@ struct NumericConverter<float, tfloat32_t, Round> {
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
result_type operator()(source_type const &s) {
|
||||
result_type operator()(source_type const &s) const {
|
||||
return convert(s);
|
||||
}
|
||||
};
|
||||
@ -579,7 +579,7 @@ struct NumericConverter<tfloat32_t, float, FloatRoundStyle::round_to_nearest> {
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
result_type operator()(source_type const &s) {
|
||||
result_type operator()(source_type const &s) const {
|
||||
return convert(s);
|
||||
}
|
||||
};
|
||||
@ -596,7 +596,7 @@ struct NumericConverter<tfloat32_t, float, FloatRoundStyle::round_half_ulp_trunc
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
result_type operator()(source_type const &s) {
|
||||
result_type operator()(source_type const &s) const {
|
||||
return convert(s);
|
||||
}
|
||||
};
|
||||
@ -621,7 +621,7 @@ struct NumericConverter<tfloat32_t, float, FloatRoundStyle::round_half_ulp_trunc
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
result_type operator()(source_type const &s) {
|
||||
result_type operator()(source_type const &s) const {
|
||||
return convert(s);
|
||||
}
|
||||
};
|
||||
@ -639,7 +639,7 @@ struct NumericConverter<tfloat32_t, float, FloatRoundStyle::round_toward_zero> {
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
result_type operator()(source_type const &s) {
|
||||
result_type operator()(source_type const &s) const {
|
||||
return convert(s);
|
||||
}
|
||||
};
|
||||
@ -682,7 +682,7 @@ struct NumericConverterFastF32 {
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
result_type operator()(source_type const &s) {
|
||||
result_type operator()(source_type const &s) const {
|
||||
return convert(s);
|
||||
}
|
||||
};
|
||||
@ -715,7 +715,7 @@ struct NumericConverterClamp {
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
result_type operator()(source_type const &s) {
|
||||
result_type operator()(source_type const &s) const {
|
||||
return convert(s);
|
||||
}
|
||||
};
|
||||
@ -732,12 +732,16 @@ struct NumericConverterClamp<cutlass::half_t, S> {
|
||||
using source_type = S;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
result_type operator()(source_type const &s) {
|
||||
return static_cast<cutlass::half_t>(s);
|
||||
static result_type convert(source_type const &source) {
|
||||
return static_cast<cutlass::half_t>(source);
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
result_type operator()(source_type const &s) const {
|
||||
return convert(s);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
//
|
||||
// Conversion operator for Array
|
||||
@ -782,7 +786,7 @@ struct NumericArrayConverter {
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
result_type operator()(source_type const &s) {
|
||||
result_type operator()(source_type const &s) const {
|
||||
return convert(s);
|
||||
}
|
||||
};
|
||||
@ -804,20 +808,23 @@ struct NumericArrayConverter<T, T, N, Round, Transform> {
|
||||
"Unary Operator not supported.");
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
result_type operator()(source_type const &s) {
|
||||
if( platform::is_same<Transform, cutlass::transform::thread::UnaryTransform::Identity>::value )
|
||||
{
|
||||
return s;
|
||||
static result_type convert(source_type const &source) {
|
||||
if (platform::is_same<Transform, cutlass::transform::thread::UnaryTransform::Identity>::value) {
|
||||
return source;
|
||||
} else {
|
||||
result_type result;
|
||||
for (int i = 0; i < N; ++i) {
|
||||
result[i] = conj(s[i]);
|
||||
result[i] = conj(source[i]);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
result_type operator()(source_type const &s) const {
|
||||
return convert(s);
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -846,7 +853,7 @@ struct NumericArrayConverter<half_t, float, 2, FloatRoundStyle::round_to_nearest
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
result_type operator()(source_type const &s) {
|
||||
result_type operator()(source_type const &s) const {
|
||||
return convert(s);
|
||||
}
|
||||
};
|
||||
@ -876,7 +883,7 @@ struct NumericArrayConverter<float, half_t, 2, Round> {
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
result_type operator()(source_type const &s) {
|
||||
result_type operator()(source_type const &s) const {
|
||||
return convert(s);
|
||||
}
|
||||
};
|
||||
@ -918,7 +925,7 @@ struct NumericArrayConverter<half_t, float, N, Round> {
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
result_type operator()(source_type const &s) {
|
||||
result_type operator()(source_type const &s) const {
|
||||
return convert(s);
|
||||
}
|
||||
};
|
||||
@ -959,12 +966,11 @@ struct NumericArrayConverter<float, half_t, N, Round> {
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
result_type operator()(source_type const &s) {
|
||||
result_type operator()(source_type const &s) const {
|
||||
return convert(s);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
|
||||
@ -989,7 +995,7 @@ struct NumericArrayConverter<bfloat16_t, float, 2, FloatRoundStyle::round_to_nea
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
result_type operator()(source_type const &s) {
|
||||
result_type operator()(source_type const &s) const {
|
||||
return convert(s);
|
||||
}
|
||||
};
|
||||
@ -1031,7 +1037,7 @@ struct NumericArrayConverter<bfloat16_t, float, N, Round> {
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
result_type operator()(source_type const &s) {
|
||||
result_type operator()(source_type const &s) const {
|
||||
return convert(s);
|
||||
}
|
||||
};
|
||||
@ -1067,7 +1073,7 @@ struct NumericArrayConverter<int8_t, int, 1, Round> {
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
result_type operator()(source_type const &s) {
|
||||
result_type operator()(source_type const &s) const {
|
||||
return convert(s);
|
||||
}
|
||||
};
|
||||
@ -1096,7 +1102,7 @@ struct NumericArrayConverter<int8_t, int, 2, Round> {
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
result_type operator()(source_type const &s) {
|
||||
result_type operator()(source_type const &s) const {
|
||||
return convert(s);
|
||||
}
|
||||
};
|
||||
@ -1127,7 +1133,7 @@ struct NumericArrayConverter<int8_t, int, 4, Round> {
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
result_type operator()(source_type const &s) {
|
||||
result_type operator()(source_type const &s) const {
|
||||
return convert(s);
|
||||
}
|
||||
};
|
||||
@ -1163,7 +1169,7 @@ struct NumericArrayConverter<int8_t, int, N, Round> {
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
result_type operator()(source_type const &s) {
|
||||
result_type operator()(source_type const &s) const {
|
||||
return convert(s);
|
||||
}
|
||||
};
|
||||
@ -1190,7 +1196,7 @@ struct NumericArrayConverter<uint8_t, int, 1, Round> {
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
result_type operator()(source_type const &s) {
|
||||
result_type operator()(source_type const &s) const {
|
||||
return convert(s);
|
||||
}
|
||||
};
|
||||
@ -1219,7 +1225,7 @@ struct NumericArrayConverter<uint8_t, int, 2, Round> {
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
result_type operator()(source_type const &s) {
|
||||
result_type operator()(source_type const &s) const {
|
||||
return convert(s);
|
||||
}
|
||||
};
|
||||
@ -1250,7 +1256,7 @@ struct NumericArrayConverter<uint8_t, int, 4, Round> {
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
result_type operator()(source_type const &s) {
|
||||
result_type operator()(source_type const &s) const {
|
||||
return convert(s);
|
||||
}
|
||||
};
|
||||
@ -1286,7 +1292,7 @@ struct NumericArrayConverter<uint8_t, int, N, Round> {
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
result_type operator()(source_type const &s) {
|
||||
result_type operator()(source_type const &s) const {
|
||||
return convert(s);
|
||||
}
|
||||
};
|
||||
@ -1349,7 +1355,7 @@ struct NumericArrayConverter<float, float_e4m3_t, 4, Round> {
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
result_type operator()(source_type const &s) {
|
||||
result_type operator()(source_type const &s) const {
|
||||
return convert(s);
|
||||
}
|
||||
};
|
||||
@ -1397,7 +1403,7 @@ struct NumericArrayConverter<float_e4m3_t, float, 4, Round> {
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
result_type operator()(source_type const &s) {
|
||||
result_type operator()(source_type const &s) const {
|
||||
return convert(s);
|
||||
}
|
||||
};
|
||||
@ -1452,7 +1458,7 @@ struct NumericArrayConverter<float, float_e5m2_t, 4, Round> {
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
result_type operator()(source_type const &s) {
|
||||
result_type operator()(source_type const &s) const {
|
||||
return convert(s);
|
||||
}
|
||||
};
|
||||
@ -1500,7 +1506,7 @@ struct NumericArrayConverter<float_e5m2_t, float, 4, Round> {
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
result_type operator()(source_type const &s) {
|
||||
result_type operator()(source_type const &s) const {
|
||||
return convert(s);
|
||||
}
|
||||
};
|
||||
@ -1551,7 +1557,7 @@ struct NumericArrayConverter<half_t, float_e4m3_t, 4, Round> {
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
result_type operator()(source_type const &s) {
|
||||
result_type operator()(source_type const &s) const {
|
||||
return convert(s);
|
||||
}
|
||||
};
|
||||
@ -1600,7 +1606,7 @@ struct NumericArrayConverter<float_e4m3_t, half_t, 4, Round> {
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
result_type operator()(source_type const &s) {
|
||||
result_type operator()(source_type const &s) const {
|
||||
return convert(s);
|
||||
}
|
||||
};
|
||||
@ -1645,7 +1651,7 @@ struct NumericArrayConverter<half_t, float_e5m2_t, 4, Round> {
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
result_type operator()(source_type const &s) {
|
||||
result_type operator()(source_type const &s) const {
|
||||
return convert(s);
|
||||
}
|
||||
};
|
||||
@ -1694,7 +1700,7 @@ struct NumericArrayConverter<float_e5m2_t, half_t, 4, Round> {
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
result_type operator()(source_type const &s) {
|
||||
result_type operator()(source_type const &s) const {
|
||||
return convert(s);
|
||||
}
|
||||
};
|
||||
@ -1748,7 +1754,7 @@ struct NumericArrayConverter<bfloat16_t, float_e4m3_t, 4, Round> {
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
result_type operator()(source_type const &s) {
|
||||
result_type operator()(source_type const &s) const {
|
||||
return convert(s);
|
||||
}
|
||||
};
|
||||
@ -1794,7 +1800,7 @@ struct NumericArrayConverter<float_e4m3_t, bfloat16_t, 4, Round> {
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
result_type operator()(source_type const &s) {
|
||||
result_type operator()(source_type const &s) const {
|
||||
return convert(s);
|
||||
}
|
||||
};
|
||||
@ -1842,7 +1848,7 @@ struct NumericArrayConverter<bfloat16_t, float_e5m2_t, 4, Round> {
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
result_type operator()(source_type const &s) {
|
||||
result_type operator()(source_type const &s) const {
|
||||
return convert(s);
|
||||
}
|
||||
};
|
||||
@ -1888,7 +1894,7 @@ struct NumericArrayConverter<float_e5m2_t, bfloat16_t, 4, Round> {
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
result_type operator()(source_type const &s) {
|
||||
result_type operator()(source_type const &s) const {
|
||||
return convert(s);
|
||||
}
|
||||
};
|
||||
@ -1925,7 +1931,7 @@ struct NumericArrayConverter<float_e4m3_t, float_e5m2_t, 4, Round> {
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
result_type operator()(source_type const &s) {
|
||||
result_type operator()(source_type const &s) const {
|
||||
return convert(s);
|
||||
}
|
||||
};
|
||||
@ -1956,7 +1962,7 @@ struct NumericArrayConverter<float_e5m2_t, float_e4m3_t, 4, Round> {
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
result_type operator()(source_type const &s) {
|
||||
result_type operator()(source_type const &s) const {
|
||||
return convert(s);
|
||||
}
|
||||
};
|
||||
@ -1986,8 +1992,13 @@ struct NumericArrayConverter<float_e4m3_t, float_e4m3_t, 4, Round> {
|
||||
static FloatRoundStyle const round_style = Round;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
result_type operator()(source_type const &s) {
|
||||
return s;
|
||||
static result_type convert(source_type const &source) {
|
||||
return source;
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
result_type operator()(source_type const &s) const {
|
||||
return convert(s);
|
||||
}
|
||||
};
|
||||
|
||||
@ -2004,8 +2015,13 @@ struct NumericArrayConverter<float_e5m2_t, float_e5m2_t, 4, Round> {
|
||||
static FloatRoundStyle const round_style = Round;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
result_type operator()(source_type const &s) {
|
||||
return s;
|
||||
static result_type convert(source_type const &source) {
|
||||
return source;
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
result_type operator()(source_type const &s) const {
|
||||
return convert(s);
|
||||
}
|
||||
};
|
||||
|
||||
@ -2063,7 +2079,7 @@ public:
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
result_type operator()(source_type const &s) {
|
||||
result_type operator()(source_type const &s) const {
|
||||
return convert(s);
|
||||
}
|
||||
};
|
||||
@ -2165,7 +2181,7 @@ struct NumericArrayConverter<int8_t, float, N, Round> {
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
result_type operator()(source_type const &s) {
|
||||
result_type operator()(source_type const &s) const {
|
||||
return convert(s);
|
||||
}
|
||||
};
|
||||
@ -2206,7 +2222,7 @@ struct NumericArrayConverter<int4b_t, int, 8, Round> {
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
result_type operator()(source_type const &s) {
|
||||
result_type operator()(source_type const &s) const {
|
||||
return convert(s);
|
||||
}
|
||||
};
|
||||
@ -2242,7 +2258,7 @@ struct NumericArrayConverter<int4b_t, int, N, Round> {
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
result_type operator()(source_type const &s) {
|
||||
result_type operator()(source_type const &s) const {
|
||||
return convert(s);
|
||||
}
|
||||
};
|
||||
@ -2277,7 +2293,7 @@ struct NumericArrayConverter<uint4b_t, int, 8, Round> {
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
result_type operator()(source_type const &s) {
|
||||
result_type operator()(source_type const &s) const {
|
||||
return convert(s);
|
||||
}
|
||||
};
|
||||
@ -2313,7 +2329,7 @@ struct NumericArrayConverter<uint4b_t, int, N, Round> {
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
result_type operator()(source_type const &s) {
|
||||
result_type operator()(source_type const &s) const {
|
||||
return convert(s);
|
||||
}
|
||||
};
|
||||
@ -2341,7 +2357,7 @@ struct FastNumericArrayConverter {
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
result_type operator()(source_type const &s) { return convert(s); }
|
||||
result_type operator()(source_type const &s) const { return convert(s); }
|
||||
};
|
||||
|
||||
/// Partial specialization for Array<float> <= Array<int>
|
||||
@ -2365,7 +2381,7 @@ struct FastNumericArrayConverter<float, T, N, Round> {
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
result_type operator()(source_type const &s) { return convert(s); }
|
||||
result_type operator()(source_type const &s) const { return convert(s); }
|
||||
};
|
||||
|
||||
/// Partial specialization for Array<int8_t, 4> <= Array<float, 4>
|
||||
@ -2393,7 +2409,7 @@ struct FastNumericArrayConverter<int8_t, float, 4, Round> {
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
result_type operator()(source_type const &s) { return convert(s); }
|
||||
result_type operator()(source_type const &s) const { return convert(s); }
|
||||
};
|
||||
|
||||
/// Partial specialization for Array<int8_t> <= Array<float>
|
||||
@ -2425,7 +2441,7 @@ struct FastNumericArrayConverter<int8_t, float, N, Round> {
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
result_type operator()(source_type const &s) { return convert(s); }
|
||||
result_type operator()(source_type const &s) const { return convert(s); }
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
Loading…
Reference in New Issue
Block a user