#define C10_UNUSED_DISPATCH_CUDA_WORKAROUND C10_UNUSED
#endif // defined(__CUDACC__) && CUDA_VERSION <= 10100
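+// NB: AT_QINT_PRIVATE_CASE_TYPE now takes the kernel NAME so each case can
+// check at::should_include_kernel_dtype(NAME, enum_type) and raise AT_ERROR
+// for dtypes that were not selected for this kernel. The condition is a
+// compile-time constant: `if constexpr` under C++17, at::guts::if_constexpr
+// in the #else fallback below.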
+#if defined __cpp_if_constexpr
#define AT_QINT_PRIVATE_CASE_TYPE( \
- enum_type, type, underlying_enum, underlying_type, ...) \
+ NAME, enum_type, type, underlying_enum, underlying_type, ...) \
case enum_type: { \
+ if constexpr (!at::should_include_kernel_dtype(NAME, enum_type)) { \
+ AT_ERROR("dtype '", toString(enum_type), "' not selected for kernel tag ", #NAME); \
+ } \
using scalar_t = type; \
using underlying_t C10_UNUSED_DISPATCH_CUDA_WORKAROUND = \
scalar_t::underlying; \
/* TODO: Use [[maybe_unused]] once C++17 is the minimum standard */ \
return __VA_ARGS__(); \
}
+#else
+#define AT_QINT_PRIVATE_CASE_TYPE( \
+ NAME, enum_type, type, underlying_enum, underlying_type, ...) \
+ case enum_type: { \
+ at::guts::if_constexpr<(!at::should_include_kernel_dtype(NAME, enum_type))>( \
+ [] { \
+ AT_ERROR("dtype '" #enum_type "' not selected for kernel tag " #NAME); \
+ } \
+ ); \
+ using scalar_t = type; \
+ using underlying_t C10_UNUSED_DISPATCH_CUDA_WORKAROUND = \
+ scalar_t::underlying; \
+ const auto& SCALAR_TYPE C10_UNUSED_DISPATCH_CUDA_WORKAROUND = enum_type; \
+ const auto& UNDERLYING_TYPE C10_UNUSED_DISPATCH_CUDA_WORKAROUND = \
+ toUnderlying(enum_type); \
+ (void)SCALAR_TYPE; /* Suppress unused-var compiler warning */ \
+  /* TODO: Use [[maybe_unused]] once C++17 is the minimum standard */ \
+ return __VA_ARGS__(); \
+ }
+#endif
+#if defined __cpp_if_constexpr
+#define AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \
+ NAME, enum_type, type, underlying_type, bitwidth, qmin, qmax, ...) \
+ case enum_type: { \
+ if constexpr (!at::should_include_kernel_dtype(NAME, enum_type)) { \
+ AT_ERROR("dtype '", toString(enum_type), "' not selected for kernel tag ", #NAME); \
+ } \
+ using scalar_t = type; \
+ using underlying_t C10_UNUSED_DISPATCH_CUDA_WORKAROUND = \
+ scalar_t::underlying; \
+ const auto& SCALAR_TYPE C10_UNUSED_DISPATCH_CUDA_WORKAROUND = enum_type; \
+ const auto& UNDERLYING_TYPE C10_UNUSED_DISPATCH_CUDA_WORKAROUND = \
+ toUnderlying(enum_type); \
+ int bit_width = bitwidth; \
+ int64_t quant_min = qmin; \
+ int64_t quant_max = qmax; \
+ (void)bit_width; /* Suppress unused variable warning */ \
+ (void)quant_min; /* Suppress unused variable warning */ \
+ (void)quant_max; /* Suppress unused variable warning */ \
+ return __VA_ARGS__(); \
+ }
+#else
#define AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \
- enum_type, type, underlying_type, bitwidth, qmin, qmax, ...) \
+ NAME, enum_type, type, underlying_type, bitwidth, qmin, qmax, ...) \
case enum_type: { \
+ at::guts::if_constexpr<(!at::should_include_kernel_dtype(NAME, enum_type))>( \
+ [] { \
+ AT_ERROR("dtype '" #enum_type "' not selected for kernel tag " #NAME); \
+ } \
+ ); \
using scalar_t = type; \
using underlying_t C10_UNUSED_DISPATCH_CUDA_WORKAROUND = \
scalar_t::underlying; \
(void)quant_max; /* Suppress unused variable warning */ \
return __VA_ARGS__(); \
}
+#endif
namespace detail {
RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \
switch (_st) { \
AT_QINT_PRIVATE_CASE_TYPE( \
- at::kQInt8, at::qint8, at::kChar, int8_t, __VA_ARGS__) \
+ NAME, at::kQInt8, at::qint8, at::kChar, int8_t, __VA_ARGS__) \
AT_QINT_PRIVATE_CASE_TYPE( \
- at::kQUInt8, at::quint8, at::kByte, uint8_t, __VA_ARGS__) \
+ NAME, at::kQUInt8, at::quint8, at::kByte, uint8_t, __VA_ARGS__) \
AT_QINT_PRIVATE_CASE_TYPE( \
- at::kQInt32, at::qint32, at::kInt, int, __VA_ARGS__) \
+ NAME, at::kQInt32, at::qint32, at::kInt, int, __VA_ARGS__) \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
} \
RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \
switch (_st) { \
AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \
- at::kQInt8, at::qint8, int8_t, CHAR_BIT, SCHAR_MIN, SCHAR_MAX, __VA_ARGS__) \
+ NAME, at::kQInt8, at::qint8, int8_t, CHAR_BIT, SCHAR_MIN, SCHAR_MAX, __VA_ARGS__) \
AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \
- at::kQUInt8, at::quint8, uint8_t, CHAR_BIT, 0, UCHAR_MAX, __VA_ARGS__) \
+ NAME, at::kQUInt8, at::quint8, uint8_t, CHAR_BIT, 0, UCHAR_MAX, __VA_ARGS__) \
AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \
- at::kQInt32, at::qint32, int, CHAR_BIT * sizeof(int), INT_MIN, INT_MAX, __VA_ARGS__) \
+ NAME, at::kQInt32, at::qint32, int, CHAR_BIT * sizeof(int), INT_MIN, INT_MAX, __VA_ARGS__) \
AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \
- at::kQUInt4x2, at::quint4x2, uint8_t, 4, 0, 15, __VA_ARGS__) \
+ NAME, at::kQUInt4x2, at::quint4x2, uint8_t, 4, 0, 15, __VA_ARGS__) \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
} \
#endif
}
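+// The adaptive average-pool kernel is now templated on the quantized scalar
+// type T; the AT_DISPATCH_QINT_TYPES switch moves out to the 2d/3d wrappers
+// below, which instantiate it with scalar_t.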
+template <typename T>
void _qadaptive_avg_pool_kernel(
- const std::string& fn_name,
const Tensor& qx,
Tensor& qy,
int64_t b,
int64_t istrideD, // Set to 1 for 2d
int64_t istrideH,
int64_t istrideW) {
- AT_DISPATCH_QINT_TYPES(qx.scalar_type(), fn_name, [&]() {
- scalar_t* idata = static_cast<scalar_t*>(qx.data_ptr());
- scalar_t* odata = static_cast<scalar_t*>(qy.data_ptr());
+
+ T* idata = static_cast<T*>(qx.data_ptr());
+ T* odata = static_cast<T*>(qy.data_ptr());
auto* i_p =
- reinterpret_cast<typename scalar_t::underlying*>(idata + b * istrideB);
+ reinterpret_cast<typename T::underlying*>(idata + b * istrideB);
float input_scale = qx.q_scale();
float output_scale = qy.q_scale();
int iendH = (int)std::ceil((float)((oh + 1) * isizeH) / osizeH);
int kH = iendH - istartH;
for (int64_t ow = 0; ow < osizeW; ow++) {
- auto* o_p = reinterpret_cast<typename scalar_t::underlying*>(
+ auto* o_p = reinterpret_cast<typename T::underlying*>(
odata +
b * osizeD * osizeH * osizeW * sizeC +
od * osizeH * osizeW * sizeC +
// Note: If AVX is not available, `do_avg_pool_on_AVX_n` is a no-op.
// In that case, the following loop takes over
// TODO: more vectorization with loop interleaving
- do_avg_pool_on_AVX_n<scalar_t>(
+ do_avg_pool_on_AVX_n<T>(
internal_i_p,
o_p,
c,
}
}
// clamp
- o_p[c] = at::native::quantize_val<scalar_t>(1.0f / multiplier,
+ o_p[c] = at::native::quantize_val<T>(1.0f / multiplier,
output_zero_point,
acc_int32).val_;
} // c
} // oh
} // ow
} // od
- });
}
void qadaptive_avg_pool2d_nhwc_kernel(
int64_t istrideC,
int64_t istrideH,
int64_t istrideW) {
- _qadaptive_avg_pool_kernel("adaptive_avg_pool2d_nhwc",
- qx,
- qy,
- b,
- sizeC,
- /*isizeD=*/1,
- isizeH,
- isizeW,
- /*osizeD=*/1,
- osizeH,
- osizeW,
- istrideB,
- istrideC,
- /*istrideD=*/1,
- istrideH,
- istrideW);
+ AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "adaptive_avg_pool2d_nhwc", [&]() {
+ _qadaptive_avg_pool_kernel<scalar_t>(
+ qx,
+ qy,
+ b,
+ sizeC,
+ /*isizeD=*/1,
+ isizeH,
+ isizeW,
+ /*osizeD=*/1,
+ osizeH,
+ osizeW,
+ istrideB,
+ istrideC,
+ /*istrideD=*/1,
+ istrideH,
+ istrideW);
+  });
}
void qadaptive_avg_pool3d_ndhwc_kernel(
int64_t istrideD,
int64_t istrideH,
int64_t istrideW) {
- _qadaptive_avg_pool_kernel("adaptive_avg_pool3d_ndhwc",
- qx,
- qy,
- b,
- sizeC,
- isizeD,
- isizeH,
- isizeW,
- osizeD,
- osizeH,
- osizeW,
- istrideB,
- istrideC,
- istrideD,
- istrideH,
- istrideW);
+ AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "adaptive_avg_pool3d_ndhwc", [&]() {
+ _qadaptive_avg_pool_kernel<scalar_t>(
+ qx,
+ qy,
+ b,
+ sizeC,
+ isizeD,
+ isizeH,
+ isizeW,
+ osizeD,
+ osizeH,
+ osizeW,
+ istrideB,
+ istrideC,
+ istrideD,
+ istrideH,
+ istrideW);
+  });
}
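+// Same refactor for the NHWC average-pool kernel: it is templated on the
+// quantized type, and qavg_pool2d_nhwc_kernel / qavg_pool3d_nhwc_kernel
+// perform the dtype dispatch.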
+template <typename T>
void _qavg_pool_nhwc_kernel(
- const std::string& fn_name,
const Tensor& qx,
Tensor& qy,
int64_t b,
int padD,
bool count_include_pad,
c10::optional<int64_t> divisor_override) {
- AT_DISPATCH_QINT_TYPES(qx.scalar_type(), fn_name, [&]() {
- scalar_t* idata = static_cast<scalar_t*>(qx.data_ptr());
- scalar_t* odata = static_cast<scalar_t*>(qy.data_ptr());
- int strideC = 1;
- int strideW = strideC * nInputPlane;
- int istrideH = strideW * inputWidth;
- int istrideD = istrideH * inputHeight;
- int istrideB = istrideD * inputDepth;
- int ostrideH = strideW * outputWidth;
- int ostrideD = ostrideH * outputHeight;
- int ostrideB = ostrideD * outputDepth;
- auto* i_p =
- reinterpret_cast<typename scalar_t::underlying*>(idata + b * istrideB);
-
- // lift these operations outside the loop to reduce access overheads
- float input_scale = qx.q_scale();
- float output_scale = qy.q_scale();
- int input_zero_point = qx.q_zero_point();
- int output_zero_point = qy.q_zero_point();
- int64_t divisor_override_factor =
- divisor_override.has_value() ? divisor_override.value() : 0;
-
- for (int od = 0; od < outputDepth; od++) {
- for (int oh = 0; oh < outputHeight; oh++) {
- for (int ow = 0; ow < outputWidth; ow++) {
- auto* o_p = reinterpret_cast<typename scalar_t::underlying*>(
- odata + b * ostrideB + od * ostrideD + oh * ostrideH +
- ow * strideW);
- int dstart = od * dD - padD;
- int hstart = oh * dH - padH;
- int wstart = ow * dW - padW;
-
- int dend = std::min(dstart + kD, (int)inputDepth + padD);
- int hend = std::min(hstart + kH, (int)inputHeight + padH);
- int wend = std::min(wstart + kW, (int)inputWidth + padW);
- int pool_size = (dend - dstart) * (hend - hstart) * (wend - wstart);
-
- dstart = std::max(dstart, 0);
- hstart = std::max(hstart, 0);
- wstart = std::max(wstart, 0);
- dend = std::min(dend, (int)inputDepth);
- hend = std::min(hend, (int)inputHeight);
- wend = std::min(wend, (int)inputWidth);
-
- int size = (dend - dstart) * (hend - hstart) * (wend - wstart);
- int divide_size = count_include_pad ? pool_size : size;
- int divide_factor =
- divisor_override_factor ? divisor_override_factor : divide_size;
- float multiplier = input_scale / output_scale / divide_factor;
- int input_zero_point_m_size = -input_zero_point * size;
-
- int c_start = 0;
-
- // For int8 quantization, we implicitly use int32 as accumulation
- // Or else, it will go to the slow path
- // TODO: support 16bit, 32bit, and etc.
- do_avg_pool_nhwc_on_AVX_n<scalar_t>(
- i_p,
- o_p,
- c_start,
- input_zero_point_m_size,
- output_zero_point,
- multiplier,
- dstart,
- dend,
- hstart,
- hend,
- wstart,
- wend,
- inputDepth,
- inputHeight,
- inputWidth,
- nInputPlane);
-
- // 1) The following loop handles the remaining channels
- // 2) It also handles the Non-AVX2 path
- for (int c = c_start; c < nInputPlane; ++c) {
- int32_t acc_int32 = input_zero_point_m_size;
- for (int64_t id = dstart; id < dend; id++) {
- for (int64_t ih = hstart; ih < hend; ih++) {
- for (int64_t iw = wstart; iw < wend; iw++) {
- auto val =
- *(i_p + id * istrideD + ih * istrideH + iw * strideW +
- c * strideC);
- acc_int32 += val;
- }
+ T* idata = static_cast<T*>(qx.data_ptr());
+ T* odata = static_cast<T*>(qy.data_ptr());
+ int strideC = 1;
+ int strideW = strideC * nInputPlane;
+ int istrideH = strideW * inputWidth;
+ int istrideD = istrideH * inputHeight;
+ int istrideB = istrideD * inputDepth;
+ int ostrideH = strideW * outputWidth;
+ int ostrideD = ostrideH * outputHeight;
+ int ostrideB = ostrideD * outputDepth;
+ auto* i_p =
+ reinterpret_cast<typename T::underlying*>(idata + b * istrideB);
+
+  // Lift these operations outside the loop to reduce access overhead
+ float input_scale = qx.q_scale();
+ float output_scale = qy.q_scale();
+ int input_zero_point = qx.q_zero_point();
+ int output_zero_point = qy.q_zero_point();
+ int64_t divisor_override_factor =
+ divisor_override.has_value() ? divisor_override.value() : 0;
+
+ for (int od = 0; od < outputDepth; od++) {
+ for (int oh = 0; oh < outputHeight; oh++) {
+ for (int ow = 0; ow < outputWidth; ow++) {
+ auto* o_p = reinterpret_cast<typename T::underlying*>(
+ odata + b * ostrideB + od * ostrideD + oh * ostrideH +
+ ow * strideW);
+ int dstart = od * dD - padD;
+ int hstart = oh * dH - padH;
+ int wstart = ow * dW - padW;
+
+ int dend = std::min(dstart + kD, (int)inputDepth + padD);
+ int hend = std::min(hstart + kH, (int)inputHeight + padH);
+ int wend = std::min(wstart + kW, (int)inputWidth + padW);
+ int pool_size = (dend - dstart) * (hend - hstart) * (wend - wstart);
+
+ dstart = std::max(dstart, 0);
+ hstart = std::max(hstart, 0);
+ wstart = std::max(wstart, 0);
+ dend = std::min(dend, (int)inputDepth);
+ hend = std::min(hend, (int)inputHeight);
+ wend = std::min(wend, (int)inputWidth);
+
+ int size = (dend - dstart) * (hend - hstart) * (wend - wstart);
+ int divide_size = count_include_pad ? pool_size : size;
+ int divide_factor =
+ divisor_override_factor ? divisor_override_factor : divide_size;
+ float multiplier = input_scale / output_scale / divide_factor;
+ int input_zero_point_m_size = -input_zero_point * size;
+
+ int c_start = 0;
+
+          // For int8 quantization we implicitly accumulate in int32;
+          // otherwise the computation falls back to the slow path below.
+          // TODO: support 16-bit, 32-bit, etc.
+ do_avg_pool_nhwc_on_AVX_n<T>(
+ i_p,
+ o_p,
+ c_start,
+ input_zero_point_m_size,
+ output_zero_point,
+ multiplier,
+ dstart,
+ dend,
+ hstart,
+ hend,
+ wstart,
+ wend,
+ inputDepth,
+ inputHeight,
+ inputWidth,
+ nInputPlane);
+
+ // 1) The following loop handles the remaining channels
+          // 2) It also handles the non-AVX2 path
+ for (int c = c_start; c < nInputPlane; ++c) {
+ int32_t acc_int32 = input_zero_point_m_size;
+ for (int64_t id = dstart; id < dend; id++) {
+ for (int64_t ih = hstart; ih < hend; ih++) {
+ for (int64_t iw = wstart; iw < wend; iw++) {
+ auto val =
+ *(i_p + id * istrideD + ih * istrideH + iw * strideW +
+ c * strideC);
+ acc_int32 += val;
}
}
- double acc_fp = acc_int32 * 1.0;
- // clamp
- o_p[c] = at::native::quantize_val<scalar_t>(
- 1.0f / multiplier, output_zero_point, acc_fp)
- .val_;
- } // c
- } // ow
- } // oh
- } // od
- });
+ }
+ double acc_fp = acc_int32 * 1.0;
+ // clamp
+ o_p[c] = at::native::quantize_val<T>(
+ 1.0f / multiplier, output_zero_point, acc_fp)
+ .val_;
+ } // c
+ } // ow
+ } // oh
+ } // od
}
void qavg_pool2d_nhwc_kernel(
int padH,
bool count_include_pad,
c10::optional<int64_t> divisor_override) {
- _qavg_pool_nhwc_kernel(
- "avg_pool2d_nhwc",
+ AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "avg_pool2d_nhwc", [&]() {
+ _qavg_pool_nhwc_kernel<scalar_t>(
qx,
qy,
b,
0,
count_include_pad,
divisor_override);
+  });
}
void qavg_pool3d_nhwc_kernel(
int padD,
bool count_include_pad,
c10::optional<int64_t> divisor_override) {
- _qavg_pool_nhwc_kernel(
- "avg_pool3d_nhwc",
+ AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "avg_pool3d_nhwc", [&]() {
+ _qavg_pool_nhwc_kernel<scalar_t>(
qx,
qy,
b,
padD,
count_include_pad,
divisor_override);
+  });
}
template <typename T>