[QNN] Conv2D with dilation support. (#4796)
authorAnimesh Jain <anijain@umich.edu>
Mon, 3 Feb 2020 02:56:45 +0000 (18:56 -0800)
committerGitHub <noreply@github.com>
Mon, 3 Feb 2020 02:56:45 +0000 (10:56 +0800)
src/relay/qnn/op/convolution.cc
tests/python/relay/test_op_qnn_conv2d.py

index 5ebd9b9..91739e6 100644 (file)
@@ -130,17 +130,17 @@ WorkloadType GetWorkload(const Array<tvm::relay::Type>& arg_types, const Conv2DA
 }
 
 /*
- * \brief Fallback to simpler lowering for dilation or grouped conv.
+ * \brief Fallback to simpler lowering for dilation (when non-zero kernel point) or grouped conv.
  * \param data The input expr.
  * \param weight The weight expr.
  * \param input_zero_point The input zero point expr.
  * \param kernel_zero_point The kernel zero point expr.
  * \param param The qnn conv2d attributes.
  * \return The fallback lowered sequence of Relay expr.
- * \note In case of dilation, normal lowering would require a dilated pool.
- *       Since, we don't have dilated pool, we fallback to a simpler sequence of
- *       Relay operations. This will potentially lead to performance degradation
- *       as the convolution is called on int32 tensors instead of int8 tensors.
+ * \note In case of dilation with non-zero kernel zero point, normal lowering would require a
+ * dilated pool. Since, we don't have dilated pool, we fallback to a simpler sequence of Relay
+ * operations. This will potentially lead to performance degradation as the convolution is called on
+ * int32 tensors instead of int8 tensors.
  */
 Expr Conv2DFallBack(const Expr& data, const Expr& weight, const Expr& input_zero_point,
                     const Expr& kernel_zero_point, const Conv2DAttrs* param) {
@@ -598,12 +598,16 @@ Expr QnnConv2DCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
   auto input_zero_point_int = GetScalarFromConstant<int>(input_zero_point);
   auto kernel_zero_point_int = GetScalarFromConstant<int>(kernel_zero_point);
 
-  // Fallback to int32 conv if there is dilation or grouped conv2d
+  // Fallback to int32 conv if there is dilation with non-zero kernel point or grouped conv2d
+  // For dilated conv, if the kernel zero point is non-zero, the pooling operator also has to
+  // traverse the elements in dilated manner. Currently, we do not have strided pool. So, in case of
+  // dilated conv with non-zero kernel point, we fall back to simpler but slow lowering.
 
   CHECK_EQ(param->dilation.size(), 2) << "qnn.conv2d only supports 2D dilation";
   auto dilation_h = get_const_int(param->dilation[0]);
   auto dilation_w = get_const_int(param->dilation[1]);
-  if (dilation_h != 1 || dilation_w != 1 || (param->groups != 1 && !is_depthwise(param))) {
+  if ((kernel_zero_point_int != 0 && (dilation_h != 1 || dilation_w != 1)) ||
+      (param->groups != 1 && !is_depthwise(param))) {
     return Conv2DFallBack(data, weight, input_zero_point, kernel_zero_point, param);
   } else if (is_depthwise(param)) {
     CHECK_NE(channel_multiplier, -1);
index 9631ffc..ced12c8 100644 (file)
@@ -495,7 +495,7 @@ def test_padding():
 def test_dilation():
     with TempOpAttr("qnn.conv2d", "FTVMQnnLegalize", legalize_qnn_conv2d):
 
-        # uint8 input
+        # Non-zero kernel point - fall back to simpler lowering.
         data_shape = (2, 4, 4, 4)
         data_dtype = 'uint8'
         kernel_shape = (3, 4, 2, 2)
@@ -518,6 +518,29 @@ def test_dilation():
         verify(ref_func, qnn_func, data_shape, data_dtype,
                 kernel_shape, kernel_dtype)
 
+        # Zero kernel point
+        data_shape = (2, 4, 4, 4)
+        data_dtype = 'uint8'
+        kernel_shape = (3, 4, 2, 2)
+        kernel_dtype = 'uint8'
+        ref_func, qnn_func = get_funcs(data_shape=data_shape,
+                                       data_dtype=data_dtype,
+                                       kernel_shape=kernel_shape,
+                                       kernel_dtype=kernel_dtype,
+                                       input_zero_point=0,
+                                       kernel_zero_point=0,
+                                       input_scale=1.0,
+                                       kernel_scale=1.0,
+                                       kernel_size=(2, 2),
+                                       padding=(0, 0),
+                                       strides=(1, 1),
+                                       dilation=(2, 2),
+                                       data_layout="NCHW",
+                                       kernel_layout="OIHW",
+                                       out_dtype="int32")
+        verify(ref_func, qnn_func, data_shape, data_dtype,
+                kernel_shape, kernel_dtype)
+
 
 def test_const_folding():
     with TempOpAttr("qnn.conv2d", "FTVMQnnLegalize", legalize_qnn_conv2d):