[Pytorch Edge] Quantized Ops Dtype Selective (#63680)
author Jacob Szwejbka <jakeszwe@fb.com>
Mon, 13 Sep 2021 17:54:08 +0000 (10:54 -0700)
committer Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Mon, 13 Sep 2021 18:04:07 +0000 (11:04 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63680

Quantized ops were not covered by dtype selectivity. Add the should_include_kernel_dtype check to the quantized dispatch macros, and adjust call sites to be constexpr friendly.
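
For context, a minimal, self-contained sketch of the mechanism (only should_include_kernel_dtype is a real name here; the enum, kernel tag, and kernel function are illustrative, not the actual Dispatch.h code). The selectivity predicate is constexpr, so if constexpr can prune the kernel body for dtypes that were not selected; that in turn requires call sites to pass a kernel tag usable in a constant expression, which is why fn_name changes from static const std::string to static constexpr auto.

    // Simplified model of the dtype-selectivity gate; the real check is
    // at::should_include_kernel_dtype in aten/src/ATen/Dispatch.h.
    #include <cstdio>
    #include <stdexcept>

    enum class ScalarType { QInt8, QUInt8, QInt32 };

    // In a selective build this predicate is generated from the model's op list;
    // here we pretend only QInt8 was selected for the (hypothetical) kernel tag.
    constexpr bool should_include_kernel_dtype(const char* /*kernel_tag*/,
                                               ScalarType t) {
      return t == ScalarType::QInt8;
    }

    // Call sites pass a constexpr kernel tag (a const char*, not a std::string),
    // so it can appear inside the constexpr check.
    static constexpr auto fn_name = "quantize_tensor_per_tensor_affine";

    template <ScalarType kDtype>
    void run_quantized_kernel() {
      // Evaluated at compile time: the kernel body below is never instantiated
      // for dtypes that were not selected for this kernel tag.
      if constexpr (!should_include_kernel_dtype(fn_name, kDtype)) {
        throw std::runtime_error("dtype not selected for this kernel tag");
      } else {
        std::puts("running selected quantized kernel");
      }
    }

    int main() {
      run_quantized_kernel<ScalarType::QInt8>();    // selected: kernel body runs
      try {
        run_quantized_kernel<ScalarType::QInt32>(); // body pruned; errors at run time
      } catch (const std::runtime_error& e) {
        std::printf("caught: %s\n", e.what());
      }
    }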

Test Plan: CI (this covers all model unit tests); also verified that segmentation (a model that uses some of these quantized ops) still works on Instagram.

Reviewed By: dhruvbird, raymondethan

Differential Revision: D30457626

fbshipit-source-id: 5ba850d2b53a18558dfbb1cfaa78d8f53b5dbad8

aten/src/ATen/Dispatch.h
aten/src/ATen/native/quantized/affine_quantizer.cpp
aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp
aten/src/ATen/native/quantized/cuda/affine_quantizer.cu

diff --git a/aten/src/ATen/Dispatch.h b/aten/src/ATen/Dispatch.h
index c7d2314..1dbf1a0 100644
@@ -80,9 +80,13 @@ inline constexpr bool should_include_kernel_dtype(
 #define C10_UNUSED_DISPATCH_CUDA_WORKAROUND C10_UNUSED
 #endif // defined(__CUDACC__) && CUDA_VERSION <= 10100
 
+#if defined __cpp_if_constexpr
 #define AT_QINT_PRIVATE_CASE_TYPE(                                           \
-    enum_type, type, underlying_enum, underlying_type, ...)                  \
+    NAME, enum_type, type, underlying_enum, underlying_type, ...)            \
   case enum_type: {                                                          \
+    if constexpr (!at::should_include_kernel_dtype(NAME, enum_type)) {       \
+      AT_ERROR("dtype '", toString(enum_type), "' not selected for kernel tag ", #NAME); \
+    }                                                                        \
     using scalar_t = type;                                                   \
     using underlying_t C10_UNUSED_DISPATCH_CUDA_WORKAROUND =                 \
         scalar_t::underlying;                                                \
@@ -93,10 +97,57 @@ inline constexpr bool should_include_kernel_dtype(
     /* TODO: Use [[maybe-unused]] when C++17 becomes the standard */         \
     return __VA_ARGS__();                                                    \
   }
+#else
+#define AT_QINT_PRIVATE_CASE_TYPE(                                               \
+    NAME, enum_type, type, underlying_enum, underlying_type, ...)                \
+  case enum_type: {                                                              \
+    at::guts::if_constexpr<(!at::should_include_kernel_dtype(NAME, enum_type))>( \
+      [] {                                                                       \
+        AT_ERROR("dtype '" #enum_type "' not selected for kernel tag " #NAME);   \
+      }                                                                          \
+    );                                                                           \
+    using scalar_t = type;                                                       \
+    using underlying_t C10_UNUSED_DISPATCH_CUDA_WORKAROUND =                     \
+        scalar_t::underlying;                                                    \
+    const auto& SCALAR_TYPE C10_UNUSED_DISPATCH_CUDA_WORKAROUND = enum_type;     \
+    const auto& UNDERLYING_TYPE C10_UNUSED_DISPATCH_CUDA_WORKAROUND =            \
+        toUnderlying(enum_type);                                                 \
+    (void)SCALAR_TYPE;  /* Suppress unused-var compiler warning */               \
+    /* TODO: Use [[maybe-unused]] when C++17 becomes the standard */             \
+    return __VA_ARGS__();                                                        \
+  }
+#endif
 
+#if defined __cpp_if_constexpr
+#define AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE(                                       \
+    NAME, enum_type, type, underlying_type, bitwidth, qmin, qmax, ...)            \
+  case enum_type: {                                                               \
+      if constexpr (!at::should_include_kernel_dtype(NAME, enum_type)) {          \
+      AT_ERROR("dtype '", toString(enum_type), "' not selected for kernel tag ", #NAME); \
+    }                                                                             \
+    using scalar_t = type;                                                        \
+    using underlying_t C10_UNUSED_DISPATCH_CUDA_WORKAROUND =                      \
+        scalar_t::underlying;                                                     \
+    const auto& SCALAR_TYPE C10_UNUSED_DISPATCH_CUDA_WORKAROUND = enum_type;      \
+    const auto& UNDERLYING_TYPE C10_UNUSED_DISPATCH_CUDA_WORKAROUND =             \
+        toUnderlying(enum_type);                                                  \
+    int bit_width = bitwidth;                                                     \
+    int64_t quant_min = qmin;                                                     \
+    int64_t quant_max = qmax;                                                     \
+    (void)bit_width; /* Suppress unused variable warning */                       \
+    (void)quant_min; /* Suppress unused variable warning */                       \
+    (void)quant_max; /* Suppress unused variable warning */                       \
+    return __VA_ARGS__();                                                         \
+  }
+#else
 #define AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE(                                       \
-    enum_type, type, underlying_type, bitwidth, qmin, qmax, ...)                  \
+    NAME, enum_type, type, underlying_type, bitwidth, qmin, qmax, ...)            \
   case enum_type: {                                                               \
+      at::guts::if_constexpr<(!at::should_include_kernel_dtype(NAME, enum_type))>( \
+      [] {                                                                        \
+        AT_ERROR("dtype '" #enum_type "' not selected for kernel tag " #NAME);    \
+      }                                                                           \
+    );                                                                            \
     using scalar_t = type;                                                        \
     using underlying_t C10_UNUSED_DISPATCH_CUDA_WORKAROUND =                      \
         scalar_t::underlying;                                                     \
@@ -111,6 +162,7 @@ inline constexpr bool should_include_kernel_dtype(
     (void)quant_max; /* Suppress unused variable warning */                       \
     return __VA_ARGS__();                                                         \
   }
+#endif
 
 namespace detail {
 
@@ -449,11 +501,11 @@ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {}
     RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st);                                \
     switch (_st) {                                                          \
       AT_QINT_PRIVATE_CASE_TYPE(                                            \
-          at::kQInt8, at::qint8, at::kChar, int8_t, __VA_ARGS__)            \
+          NAME, at::kQInt8, at::qint8, at::kChar, int8_t, __VA_ARGS__)      \
       AT_QINT_PRIVATE_CASE_TYPE(                                            \
-          at::kQUInt8, at::quint8, at::kByte, uint8_t, __VA_ARGS__)         \
+          NAME, at::kQUInt8, at::quint8, at::kByte, uint8_t, __VA_ARGS__)   \
       AT_QINT_PRIVATE_CASE_TYPE(                                            \
-          at::kQInt32, at::qint32, at::kInt, int, __VA_ARGS__)              \
+          NAME, at::kQInt32, at::qint32, at::kInt, int, __VA_ARGS__)        \
       default:                                                              \
         AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'");     \
     }                                                                       \
@@ -467,13 +519,13 @@ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {}
     RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st);                                                   \
     switch (_st) {                                                                             \
       AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE(                                                      \
-          at::kQInt8, at::qint8, int8_t, CHAR_BIT, SCHAR_MIN, SCHAR_MAX, __VA_ARGS__)          \
+          NAME, at::kQInt8, at::qint8, int8_t, CHAR_BIT, SCHAR_MIN, SCHAR_MAX, __VA_ARGS__)    \
       AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE(                                                      \
-          at::kQUInt8, at::quint8, uint8_t, CHAR_BIT, 0, UCHAR_MAX, __VA_ARGS__)               \
+          NAME, at::kQUInt8, at::quint8, uint8_t, CHAR_BIT, 0, UCHAR_MAX, __VA_ARGS__)         \
       AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE(                                                      \
-          at::kQInt32, at::qint32, int, CHAR_BIT * sizeof(int), INT_MIN, INT_MAX, __VA_ARGS__) \
+          NAME, at::kQInt32, at::qint32, int, CHAR_BIT * sizeof(int), INT_MIN, INT_MAX, __VA_ARGS__) \
       AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE(                                                      \
-          at::kQUInt4x2, at::quint4x2, uint8_t, 4, 0, 15, __VA_ARGS__)                         \
+          NAME, at::kQUInt4x2, at::quint4x2, uint8_t, 4, 0, 15, __VA_ARGS__)                   \
       default:                                                                                 \
         AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'");                        \
     }                                                                                          \
diff --git a/aten/src/ATen/native/quantized/affine_quantizer.cpp b/aten/src/ATen/native/quantized/affine_quantizer.cpp
index de60147..6be2bf7 100644
@@ -107,7 +107,7 @@ Tensor& quantize_tensor_per_tensor_affine(
     Tensor& qtensor,
     double scale,
     int64_t zero_point) {
-  static const std::string fn_name = "quantize_tensor_per_tensor_affine";
+  static constexpr auto fn_name = "quantize_tensor_per_tensor_affine";
 
   checkRoundingMode(fn_name);
   checkFloatTensor(fn_name, rtensor);
@@ -138,7 +138,7 @@ Tensor& quantize_tensor_per_channel_affine(
     Tensor scales,
     Tensor zero_points,
     int64_t axis) {
-  static const std::string fn_name = "quantize_tensor_per_channel_affine";
+  static constexpr auto fn_name = "quantize_tensor_per_channel_affine";
 
   checkRoundingMode(fn_name);
   checkFloatTensor(fn_name, rtensor);
@@ -178,7 +178,7 @@ Tensor& quantize_tensor_per_channel_float_qparams(
     Tensor scales,
     Tensor zero_points,
     int64_t axis) {
-  static const std::string fn_name =
+  static constexpr auto fn_name =
       "quantize_tensor_per_channel_float_qparams";
 
   checkRoundingMode(fn_name);
@@ -216,7 +216,7 @@ Tensor& dequantize_tensor_per_tensor_affine(
     Tensor& rtensor,
     double scale,
     int64_t zero_point) {
-  static const std::string fn_name = "dequantize_tensor_per_tensor_affine";
+  static constexpr auto fn_name = "dequantize_tensor_per_tensor_affine";
   checkFloatTensor(fn_name, rtensor);
   checkSameDevice(fn_name, rtensor, qtensor);
   checkSameSize(fn_name, qtensor, rtensor);
@@ -243,7 +243,7 @@ Tensor& dequantize_tensor_per_channel_affine(
     Tensor scales,
     Tensor zero_points,
     int64_t axis) {
-  static const std::string fn_name = "dequantize_tensor_per_channel_affine";
+  static constexpr auto fn_name = "dequantize_tensor_per_channel_affine";
 
   checkFloatTensor(fn_name, rtensor);
   checkSameDevice(fn_name, rtensor, qtensor);
@@ -282,7 +282,7 @@ Tensor& dequantize_tensor_per_channel_float_qparams(
     Tensor scales,
     Tensor zero_points,
     int64_t axis) {
-  static const std::string fn_name = "dequantize_tensor_per_channel_affine";
+  static constexpr auto fn_name = "dequantize_tensor_per_channel_affine";
 
   checkFloatTensor(fn_name, rtensor);
   checkSameDevice(fn_name, rtensor, qtensor);
diff --git a/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp b/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp
index d44193e..f1f80e4 100644
@@ -1518,8 +1518,8 @@ void do_avg_pool_on_AVX_n(
 #endif
 }
 
+template <typename T>
 void _qadaptive_avg_pool_kernel(
-    const std::string& fn_name,
     const Tensor& qx,
     Tensor& qy,
     int64_t b,
@@ -1535,11 +1535,11 @@ void _qadaptive_avg_pool_kernel(
     int64_t istrideD,  // Set to 1 for 2d
     int64_t istrideH,
     int64_t istrideW) {
-  AT_DISPATCH_QINT_TYPES(qx.scalar_type(), fn_name, [&]() {
-    scalar_t* idata = static_cast<scalar_t*>(qx.data_ptr());
-    scalar_t* odata = static_cast<scalar_t*>(qy.data_ptr());
+
+    T* idata = static_cast<T*>(qx.data_ptr());
+    T* odata = static_cast<T*>(qy.data_ptr());
     auto* i_p =
-        reinterpret_cast<typename scalar_t::underlying*>(idata + b * istrideB);
+        reinterpret_cast<typename T::underlying*>(idata + b * istrideB);
 
     float input_scale = qx.q_scale();
     float output_scale = qy.q_scale();
@@ -1555,7 +1555,7 @@ void _qadaptive_avg_pool_kernel(
         int iendH = (int)std::ceil((float)((oh + 1) * isizeH) / osizeH);
         int kH = iendH - istartH;
         for (int64_t ow = 0; ow < osizeW; ow++) {
-          auto* o_p = reinterpret_cast<typename scalar_t::underlying*>(
+          auto* o_p = reinterpret_cast<typename T::underlying*>(
               odata +
               b * osizeD * osizeH * osizeW * sizeC +
               od * osizeH * osizeW * sizeC +
@@ -1579,7 +1579,7 @@ void _qadaptive_avg_pool_kernel(
           // Note: If AVX is not available, `do_avg_pool_on_AVX_n is a noop.
           //       In that case, the following loop takes over
           // TODO: more vectorization with loop interleaving
-          do_avg_pool_on_AVX_n<scalar_t>(
+          do_avg_pool_on_AVX_n<T>(
               internal_i_p,
               o_p,
               c,
@@ -1615,14 +1615,13 @@ void _qadaptive_avg_pool_kernel(
               }
             }
             // clamp
-            o_p[c] = at::native::quantize_val<scalar_t>(1.0f / multiplier,
+            o_p[c] = at::native::quantize_val<T>(1.0f / multiplier,
                                                         output_zero_point,
                                                         acc_int32).val_;
           } // c
         } // oh
       } // ow
     } // od
-  });
 }
 
 void qadaptive_avg_pool2d_nhwc_kernel(
@@ -1638,22 +1637,25 @@ void qadaptive_avg_pool2d_nhwc_kernel(
     int64_t istrideC,
     int64_t istrideH,
     int64_t istrideW) {
-  _qadaptive_avg_pool_kernel("adaptive_avg_pool2d_nhwc",
-                             qx,
-                             qy,
-                             b,
-                             sizeC,
-                             /*isizeD=*/1,
-                             isizeH,
-                             isizeW,
-                             /*osizeD=*/1,
-                             osizeH,
-                             osizeW,
-                             istrideB,
-                             istrideC,
-                             /*istrideD=*/1,
-                             istrideH,
-                             istrideW);
+    AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "adaptive_avg_pool2d_nhwc", [&]() {
+        _qadaptive_avg_pool_kernel<scalar_t>(
+          qx,
+          qy,
+          b,
+          sizeC,
+          /*isizeD=*/1,
+          isizeH,
+          isizeW,
+          /*osizeD=*/1,
+          osizeH,
+          osizeW,
+          istrideB,
+          istrideC,
+          /*istrideD=*/1,
+          istrideH,
+          istrideW);
+      }
+    );
 }
 
 void qadaptive_avg_pool3d_ndhwc_kernel(
@@ -1672,26 +1674,29 @@ void qadaptive_avg_pool3d_ndhwc_kernel(
     int64_t istrideD,
     int64_t istrideH,
     int64_t istrideW) {
-  _qadaptive_avg_pool_kernel("adaptive_avg_pool3d_ndhwc",
-                             qx,
-                             qy,
-                             b,
-                             sizeC,
-                             isizeD,
-                             isizeH,
-                             isizeW,
-                             osizeD,
-                             osizeH,
-                             osizeW,
-                             istrideB,
-                             istrideC,
-                             istrideD,
-                             istrideH,
-                             istrideW);
+  AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "adaptive_avg_pool3d_ndhwc", [&]() {
+    _qadaptive_avg_pool_kernel<scalar_t>(
+      qx,
+      qy,
+      b,
+      sizeC,
+      isizeD,
+      isizeH,
+      isizeW,
+      osizeD,
+      osizeH,
+      osizeW,
+      istrideB,
+      istrideC,
+      istrideD,
+      istrideH,
+      istrideW);
+    }
+  );
 }
 
+template <typename T>
 void _qavg_pool_nhwc_kernel(
-    const std::string& fn_name,
     const Tensor& qx,
     Tensor& qy,
     int64_t b,
@@ -1713,104 +1718,102 @@ void _qavg_pool_nhwc_kernel(
     int padD,
     bool count_include_pad,
     c10::optional<int64_t> divisor_override) {
-  AT_DISPATCH_QINT_TYPES(qx.scalar_type(), fn_name, [&]() {
-    scalar_t* idata = static_cast<scalar_t*>(qx.data_ptr());
-    scalar_t* odata = static_cast<scalar_t*>(qy.data_ptr());
-    int strideC = 1;
-    int strideW = strideC * nInputPlane;
-    int istrideH = strideW * inputWidth;
-    int istrideD = istrideH * inputHeight;
-    int istrideB = istrideD * inputDepth;
-    int ostrideH = strideW * outputWidth;
-    int ostrideD = ostrideH * outputHeight;
-    int ostrideB = ostrideD * outputDepth;
-    auto* i_p =
-        reinterpret_cast<typename scalar_t::underlying*>(idata + b * istrideB);
-
-    // lift these operations outside the loop to reduce access overheads
-    float input_scale = qx.q_scale();
-    float output_scale = qy.q_scale();
-    int input_zero_point = qx.q_zero_point();
-    int output_zero_point = qy.q_zero_point();
-    int64_t divisor_override_factor =
-        divisor_override.has_value() ? divisor_override.value() : 0;
-
-    for (int od = 0; od < outputDepth; od++) {
-      for (int oh = 0; oh < outputHeight; oh++) {
-        for (int ow = 0; ow < outputWidth; ow++) {
-          auto* o_p = reinterpret_cast<typename scalar_t::underlying*>(
-              odata + b * ostrideB + od * ostrideD + oh * ostrideH +
-              ow * strideW);
-          int dstart = od * dD - padD;
-          int hstart = oh * dH - padH;
-          int wstart = ow * dW - padW;
-
-          int dend = std::min(dstart + kD, (int)inputDepth + padD);
-          int hend = std::min(hstart + kH, (int)inputHeight + padH);
-          int wend = std::min(wstart + kW, (int)inputWidth + padW);
-          int pool_size = (dend - dstart) * (hend - hstart) * (wend - wstart);
-
-          dstart = std::max(dstart, 0);
-          hstart = std::max(hstart, 0);
-          wstart = std::max(wstart, 0);
-          dend = std::min(dend, (int)inputDepth);
-          hend = std::min(hend, (int)inputHeight);
-          wend = std::min(wend, (int)inputWidth);
-
-          int size = (dend - dstart) * (hend - hstart) * (wend - wstart);
-          int divide_size = count_include_pad ? pool_size : size;
-          int divide_factor =
-              divisor_override_factor ? divisor_override_factor : divide_size;
-          float multiplier = input_scale / output_scale / divide_factor;
-          int input_zero_point_m_size = -input_zero_point * size;
-
-          int c_start = 0;
-
-          // For int8 quantization, we implicitly use int32 as accumulation
-          // Or else, it will go to the slow path
-          // TODO: support 16bit, 32bit, and etc.
-          do_avg_pool_nhwc_on_AVX_n<scalar_t>(
-              i_p,
-              o_p,
-              c_start,
-              input_zero_point_m_size,
-              output_zero_point,
-              multiplier,
-              dstart,
-              dend,
-              hstart,
-              hend,
-              wstart,
-              wend,
-              inputDepth,
-              inputHeight,
-              inputWidth,
-              nInputPlane);
-
-          // 1) The following loop handles the remaining channels
-          // 2) It also handles the Non-AVX2 path
-          for (int c = c_start; c < nInputPlane; ++c) {
-            int32_t acc_int32 = input_zero_point_m_size;
-            for (int64_t id = dstart; id < dend; id++) {
-              for (int64_t ih = hstart; ih < hend; ih++) {
-                for (int64_t iw = wstart; iw < wend; iw++) {
-                  auto val =
-                      *(i_p + id * istrideD + ih * istrideH + iw * strideW +
-                        c * strideC);
-                  acc_int32 += val;
-                }
+  T* idata = static_cast<T*>(qx.data_ptr());
+  T* odata = static_cast<T*>(qy.data_ptr());
+  int strideC = 1;
+  int strideW = strideC * nInputPlane;
+  int istrideH = strideW * inputWidth;
+  int istrideD = istrideH * inputHeight;
+  int istrideB = istrideD * inputDepth;
+  int ostrideH = strideW * outputWidth;
+  int ostrideD = ostrideH * outputHeight;
+  int ostrideB = ostrideD * outputDepth;
+  auto* i_p =
+      reinterpret_cast<typename T::underlying*>(idata + b * istrideB);
+
+  // lift these operations outside the loop to reduce access overheads
+  float input_scale = qx.q_scale();
+  float output_scale = qy.q_scale();
+  int input_zero_point = qx.q_zero_point();
+  int output_zero_point = qy.q_zero_point();
+  int64_t divisor_override_factor =
+      divisor_override.has_value() ? divisor_override.value() : 0;
+
+  for (int od = 0; od < outputDepth; od++) {
+    for (int oh = 0; oh < outputHeight; oh++) {
+      for (int ow = 0; ow < outputWidth; ow++) {
+        auto* o_p = reinterpret_cast<typename T::underlying*>(
+            odata + b * ostrideB + od * ostrideD + oh * ostrideH +
+            ow * strideW);
+        int dstart = od * dD - padD;
+        int hstart = oh * dH - padH;
+        int wstart = ow * dW - padW;
+
+        int dend = std::min(dstart + kD, (int)inputDepth + padD);
+        int hend = std::min(hstart + kH, (int)inputHeight + padH);
+        int wend = std::min(wstart + kW, (int)inputWidth + padW);
+        int pool_size = (dend - dstart) * (hend - hstart) * (wend - wstart);
+
+        dstart = std::max(dstart, 0);
+        hstart = std::max(hstart, 0);
+        wstart = std::max(wstart, 0);
+        dend = std::min(dend, (int)inputDepth);
+        hend = std::min(hend, (int)inputHeight);
+        wend = std::min(wend, (int)inputWidth);
+
+        int size = (dend - dstart) * (hend - hstart) * (wend - wstart);
+        int divide_size = count_include_pad ? pool_size : size;
+        int divide_factor =
+            divisor_override_factor ? divisor_override_factor : divide_size;
+        float multiplier = input_scale / output_scale / divide_factor;
+        int input_zero_point_m_size = -input_zero_point * size;
+
+        int c_start = 0;
+
+        // For int8 quantization, we implicitly use int32 as accumulation
+        // Or else, it will go to the slow path
+        // TODO: support 16bit, 32bit, and etc.
+        do_avg_pool_nhwc_on_AVX_n<T>(
+            i_p,
+            o_p,
+            c_start,
+            input_zero_point_m_size,
+            output_zero_point,
+            multiplier,
+            dstart,
+            dend,
+            hstart,
+            hend,
+            wstart,
+            wend,
+            inputDepth,
+            inputHeight,
+            inputWidth,
+            nInputPlane);
+
+        // 1) The following loop handles the remaining channels
+        // 2) It also handles the Non-AVX2 path
+        for (int c = c_start; c < nInputPlane; ++c) {
+          int32_t acc_int32 = input_zero_point_m_size;
+          for (int64_t id = dstart; id < dend; id++) {
+            for (int64_t ih = hstart; ih < hend; ih++) {
+              for (int64_t iw = wstart; iw < wend; iw++) {
+                auto val =
+                    *(i_p + id * istrideD + ih * istrideH + iw * strideW +
+                      c * strideC);
+                acc_int32 += val;
               }
             }
-            double acc_fp = acc_int32 * 1.0;
-            // clamp
-            o_p[c] = at::native::quantize_val<scalar_t>(
-                         1.0f / multiplier, output_zero_point, acc_fp)
-                         .val_;
-          } // c
-        } // ow
-      } // oh
-    } // od
-  });
+          }
+          double acc_fp = acc_int32 * 1.0;
+          // clamp
+          o_p[c] = at::native::quantize_val<T>(
+                        1.0f / multiplier, output_zero_point, acc_fp)
+                        .val_;
+        } // c
+      } // ow
+    } // oh
+  } // od
 }
 
 void qavg_pool2d_nhwc_kernel(
@@ -1830,8 +1833,8 @@ void qavg_pool2d_nhwc_kernel(
     int padH,
     bool count_include_pad,
     c10::optional<int64_t> divisor_override) {
-  _qavg_pool_nhwc_kernel(
-      "avg_pool2d_nhwc",
+  AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "avg_pool2d_nhwc", [&]() {
+    _qavg_pool_nhwc_kernel<scalar_t>(
       qx,
       qy,
       b,
@@ -1853,6 +1856,8 @@ void qavg_pool2d_nhwc_kernel(
       0,
       count_include_pad,
       divisor_override);
+    }
+  );
 }
 
 void qavg_pool3d_nhwc_kernel(
@@ -1877,8 +1882,8 @@ void qavg_pool3d_nhwc_kernel(
     int padD,
     bool count_include_pad,
     c10::optional<int64_t> divisor_override) {
-  _qavg_pool_nhwc_kernel(
-      "avg_pool3d_nhwc",
+  AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "avg_pool3d_nhwc", [&]() {
+    _qavg_pool_nhwc_kernel<scalar_t>(
       qx,
       qy,
       b,
@@ -1900,6 +1905,8 @@ void qavg_pool3d_nhwc_kernel(
       padD,
       count_include_pad,
       divisor_override);
+    }
+  );
 }
 
 template <typename T>
diff --git a/aten/src/ATen/native/quantized/cuda/affine_quantizer.cu b/aten/src/ATen/native/quantized/cuda/affine_quantizer.cu
index f8e41e5..c265066 100644
@@ -81,7 +81,7 @@ void quantize_tensor_per_channel_affine_cuda(
     const Tensor& scales,
     const Tensor& zero_points,
     int64_t axis) {
-  static const std::string fn_name = "quantize_tensor_per_channel_affine_cuda";
+  static constexpr auto fn_name = "quantize_tensor_per_channel_affine_cuda";
   std::vector<int64_t> expected_shape(rtensor.dim(), 1);
   expected_shape[axis] = rtensor.size(axis);
 
@@ -124,7 +124,7 @@ void dequantize_tensor_per_channel_affine_cuda(
     const Tensor& scales,
     const Tensor& zero_points,
     int64_t axis) {
-  static const std::string fn_name = "dequantize_tensor_per_channel_affine_cuda";
+  static constexpr auto fn_name = "dequantize_tensor_per_channel_affine_cuda";
   std::vector<int64_t> expected_shape(rtensor.dim(), 1);
   expected_shape[axis] = rtensor.size(axis);
 
@@ -160,7 +160,7 @@ void quantize_tensor_per_channel_float_qparams_cuda(
     const Tensor& scales,
     const Tensor& zero_points,
     int64_t axis) {
-  static const std::string fn_name = "quantize_tensor_per_channel_float_qparams_cuda";
+  static constexpr auto fn_name = "quantize_tensor_per_channel_float_qparams_cuda";
   std::vector<int64_t> expected_shape(rtensor.dim(), 1);
   expected_shape[axis] = rtensor.size(axis);
 
@@ -208,7 +208,7 @@ void dequantize_tensor_per_channel_float_qparams_cuda(
     const Tensor& scales,
     const Tensor& zero_points,
     int64_t axis) {
-  static const std::string fn_name = "dequantize_tensor_per_channel_float_qparams_cuda";
+  static constexpr auto fn_name = "dequantize_tensor_per_channel_float_qparams_cuda";
   std::vector<int64_t> expected_shape(rtensor.dim(), 1);
   expected_shape[axis] = rtensor.size(axis);