[QNN] Lowering for Depthwise Convolution. (#4351)
authorAnimesh Jain <anijain@umich.edu>
Thu, 21 Nov 2019 05:22:25 +0000 (21:22 -0800)
committerZhi <5145158+zhiics@users.noreply.github.com>
Thu, 21 Nov 2019 05:22:25 +0000 (21:22 -0800)
src/relay/pass/pattern_util.h
src/relay/qnn/op/convolution.cc
tests/python/relay/test_op_qnn_conv2d.py
topi/python/topi/x86/conv2d_alter_op.py

index 525008b..0921e13 100644 (file)
@@ -503,6 +503,8 @@ static inline Expr Tile(Expr data, Array<Integer> reps) {
 
 Expr MakeConcatenate(Expr data, int axis);
 
+Expr MakeRepeat(Expr data, int repeats, int axis);
+
 Expr MakeStridedSlice(Expr data, Array<Integer> begin, Array<Integer> end, Array<Integer> strides);
 
 Expr MakeStack(Expr data, int axis);
index 9cbb415..13eacda 100644 (file)
@@ -39,9 +39,7 @@ namespace qnn {
 // relay.op.qnn.conv2d
 TVM_REGISTER_NODE_TYPE(QnnConv2DAttrs);
 
-bool QnnConv2DRel(const Array<Type>& types,
-                  int num_inputs,
-                  const Attrs& attrs,
+bool QnnConv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                   const TypeReporter& reporter) {
   CHECK_EQ(types.size(), 3);
   const auto* data = types[0].as<TensorTypeNode>();
@@ -50,17 +48,22 @@ bool QnnConv2DRel(const Array<Type>& types,
   const auto* param = attrs.as<QnnConv2DAttrs>();
   CHECK(param != nullptr) << "QnnConv2DAttrs cannot be nullptr.";
   CHECK(data->dtype == Int(8) || data->dtype == UInt(8))
-    << "Expected qnn conv2d type(int8, uint8) for input but was " <<  data->dtype;
+      << "Expected qnn conv2d type(int8, uint8) for input but was " << data->dtype;
   CHECK(weight->dtype == Int(8) || weight->dtype == UInt(8))
-    << "Expected qnn conv2d type(int8, uint8) for weight but was " <<  weight->dtype;
+      << "Expected qnn conv2d type(int8, uint8) for weight but was " << weight->dtype;
   CHECK(param->out_dtype == Int(16) || param->out_dtype == Int(32))
-    << "Expected qnn conv2d type(int32, int16) for output but was " <<  param->out_dtype;
+      << "Expected qnn conv2d type(int32, int16) for output but was " << param->out_dtype;
   CHECK(param->out_dtype.bits() > 0) << "Output dtype bits should be greater than 0.";
   return Conv2DRel<QnnConv2DAttrs>(types, num_inputs, attrs, reporter);
 }
 
-// Workload - batch_size, in_channels, out_channels, kernel_h, kernel_w
-using WorkloadType = std::tuple<int, int, int, int, int>;
+bool is_depthwise(const QnnConv2DAttrs* param) {
+  return param->channels.defined() && tvm::ir::Equal(param->channels, param->groups) &&
+         param->groups != 1;
+}
+
+// Workload - batch_size, in_channels, out_channels, kernel_h, kernel_w, channel_multiplier
+using WorkloadType = std::tuple<int, int, int, int, int, int>;
 
 /*
  * \brief Get the conv parameters like batch_size, kernel_height etc.
@@ -84,26 +87,39 @@ WorkloadType GetWorkload(const Array<tvm::relay::Type>& arg_types, const QnnConv
 
   const auto kernel_shape = get_shape(arg_types[1]);
   int out_channels, kernel_h, kernel_w;
+  int channel_multiplier = -1;
+  bool depthwise = is_depthwise(param);
   if (param->kernel_layout == "OIHW") {
     out_channels = get_const_int(kernel_shape[0]);
     kernel_h = get_const_int(kernel_shape[2]);
     kernel_w = get_const_int(kernel_shape[3]);
+    if (depthwise) {
+      channel_multiplier = get_const_int(kernel_shape[1]);
+    }
   } else if (param->kernel_layout == "HWIO") {
     kernel_h = get_const_int(kernel_shape[0]);
     kernel_w = get_const_int(kernel_shape[1]);
     out_channels = get_const_int(kernel_shape[3]);
+    if (depthwise) {
+      channel_multiplier = get_const_int(kernel_shape[2]);
+    }
   } else if (param->kernel_layout == "HWOI") {
     kernel_h = get_const_int(kernel_shape[0]);
     kernel_w = get_const_int(kernel_shape[1]);
     out_channels = get_const_int(kernel_shape[2]);
+    if (depthwise) {
+      channel_multiplier = get_const_int(kernel_shape[3]);
+    }
   } else {
     LOG(FATAL) << "qnn.conv2d does not support " << param->kernel_layout << " layout";
   }
-  return std::make_tuple(batch_size, in_channels, out_channels, kernel_h, kernel_w);
+
+  return std::make_tuple(batch_size, in_channels, out_channels, kernel_h, kernel_w,
+                         channel_multiplier);
 }
 
 /*
- * \brief Fallback to simpler lowering for dilation or depthwise conv.
+ * \brief Fallback to simpler lowering for dilation or grouped conv.
  * \param data The input expr.
  * \param weight The weight expr.
  * \param param The qnn conv2d attributes.
@@ -167,6 +183,129 @@ Expr Conv2DPadInput(const Expr& data, const QnnConv2DAttrs* param) {
 }
 
 /*
+ * \brief Calculates the second term in the qnn.conv2d depthwise lowering sequence.
+ * \param padded_data The padded data expr.
+ * \param param The qnn conv2d attributes.
+ * \param kernel_h The height of kernel.
+ * \param kernel_w The width of kernel.
+ * \param channel_multiplier The channel/depth multiplier.
+ * \return The sequence of Relay operators for term2.
+ * \note The term2 looks like this
+ *
+ *       Sigma(r, s) zp_w * Qa(n, oc/cm, oh + r, ow + s)
+ *
+ *       Second term is not directly representable by one Relay operator.
+ *       However, deeper analysis shows that we can reduce r,s using avg_pool2d,
+ *       followed by repeat on the C axis by cm times.
+ */
+Expr DepthwiseConv2DSecondTerm(const Expr& padded_data, const QnnConv2DAttrs* param, int kernel_h,
+                               int kernel_w, int channel_multiplier) {
+  // Constant Expr for the kernel zero point.
+  auto zp_kernel = MakeConstantScalar(Int(32), param->kernel_zero_point);
+
+  auto casted_t2 = Cast(padded_data, Int(32));
+
+  // We can reduce the H and W axis by using avg_pool2d. However, avg_pool2d averages the sum.
+  // Since, this is integer division (floor), we can first multiply the data by the pool_size and
+  // then perform avg_pool2d. Reversing this causes inaccuracy due to floor division. If the
+  // pool_size is 1x1, we don't need avg_pool2d.
+  auto reduced_t2 = casted_t2;
+  if (kernel_h * kernel_w != 1) {
+    auto scaled_hw_t2 = Multiply(casted_t2, MakeConstantScalar(Int(32), kernel_h * kernel_w));
+    Array<IndexExpr> padding({0, 0});
+    reduced_t2 =
+        AvgPool2D(scaled_hw_t2, param->kernel_size, param->strides, padding, param->data_layout,
+                  false,   // ceil_mode
+                  false);  // count_include_pad
+  }
+
+  auto multiplied_t2 = reduced_t2;
+  if (param->kernel_zero_point != 1) {
+    multiplied_t2 = Multiply(zp_kernel, reduced_t2);
+  }
+
+  // Reduce the C dimension. Find the dimension.
+  int axis_t2 = 0;
+  if (param->data_layout == "NCHW") {
+    axis_t2 = 1;
+  } else if (param->data_layout == "NHWC") {
+    axis_t2 = 3;
+  } else {
+    LOG(FATAL) << "qnn.conv2d does not support " << param->data_layout << " layout";
+  }
+  auto repeated_t2 = multiplied_t2;
+  if (channel_multiplier != 1) {
+    repeated_t2 = MakeRepeat(multiplied_t2, channel_multiplier, axis_t2);
+  }
+  return repeated_t2;
+}
+
+/*
+ * \brief Calculates the third term in the qnn.conv2d depthwise lowering sequence.
+ * \param weight The weight expr.
+ * \param param The qnn conv2d attributes.
+ * \param out_channels The number of output channels.
+ * \param channel_multiplier The channel/depth multiplier.
+ * \return The sequence of Relay operatos for term3.
+ * \note The term3 looks like this
+ *
+ *       Sigma(r, s) zp_a * Qw(oc/m, oc%m, r, s)
+ *
+ *       This can be achieved by calling reduce on r and s axis. The tensor can be then reshaped to
+ *       (1, oc, 1, 1) as (oc/m, oc%m) are just contiguous memory locations.
+ */
+Expr DepthwiseConv2DThirdTerm(const Expr& weight, const QnnConv2DAttrs* param, int out_channels,
+                              int channel_multiplier) {
+  // Constant expr for input zero point.
+  auto zp_data = MakeConstantScalar(Int(32), param->input_zero_point);
+
+  // Find which dimensions are R, S.
+  Array<Integer> axes_t3;
+  if (param->kernel_layout == "OIHW") {
+    // For OIHW kernel layout, HW are reduce axis
+    axes_t3 = {2, 3};
+  } else if (param->kernel_layout == "HWIO") {
+    axes_t3 = {0, 1};
+  } else if (param->kernel_layout == "HWOI") {
+    axes_t3 = {0, 1};
+  } else {
+    LOG(FATAL) << "qnn.conv2d does not support " << param->kernel_layout << " layout";
+  }
+  auto reduced_t3 = Sum(Cast(weight, Int(32)), axes_t3, false, false);
+
+  // Find the newshape depending on NCHW/NHWC layout.
+  Array<Integer> newshape;
+  if (param->data_layout == "NCHW") {
+    newshape = {1, out_channels * channel_multiplier, 1, 1};
+  } else if (param->data_layout == "NHWC") {
+    newshape = {1, 1, 1, out_channels * channel_multiplier};
+  } else {
+    LOG(FATAL) << "qnn.conv2d does not support " << param->data_layout << " layout";
+  }
+  auto reshaped_t3 = Reshape(reduced_t3, newshape);
+
+  if (param->input_zero_point == 1) {
+    return reshaped_t3;
+  }
+  return Multiply(zp_data, reshaped_t3);
+}
+
+/*
+ * \brief Calculates the fourth term in the qnn.conv2d depthwise lowering sequence.
+ * \param param The qnn conv2d attributes.
+ * \param kernel_h The height of kernel.
+ * \param kernel_w The width of kernel.
+ * \return The sequence of Relay operators for term4.
+ * \note The term4 looks like this
+ *
+ *       Sigma(r, s) zp_a * zp_w
+ */
+Expr DepthwiseConv2DFourthTerm(const QnnConv2DAttrs* param, int kernel_h, int kernel_w) {
+  int scalar_term4 = param->input_zero_point * param->kernel_zero_point * kernel_h * kernel_w;
+  return MakeConstantScalar(Int(32), scalar_term4);
+}
+
+/*
  * \brief Calculates the first term in the qnn.conv2d lowering sequence.
  * \param data The input expr.
  * \param weight The weight expr.
@@ -210,7 +349,6 @@ Expr Conv2DSecondTerm(const Expr& padded_data, const QnnConv2DAttrs* param, int
   // We can reduce the H and W axis by using avg_pool2d. However, avg_pool2d averages the sum.
   // Since, this is integer division (floor), we can first multiply the data by the pool_size and
   // then perform avg_pool2d. Reversing this causes inaccuracy due to floor division.
-  auto scaled_hw_t2 = Multiply(casted_t2, MakeConstantScalar(Int(32), kernel_h * kernel_w));
   Array<IndexExpr> padding({0, 0});
 
   // Reduce the C dimension. Find the dimension.
@@ -223,11 +361,12 @@ Expr Conv2DSecondTerm(const Expr& padded_data, const QnnConv2DAttrs* param, int
     LOG(FATAL) << "qnn.conv2d does not support " << param->data_layout << " layout";
   }
   // Keep dims true to retain 4D tensor
-  auto reduced_c_t2 = Sum(scaled_hw_t2, axes_t2, true, false);
+  auto reduced_c_t2 = Sum(casted_t2, axes_t2, true, false);
 
   // If the pool_size is 1x1, we don't need avg_pool2d.
   auto reduced_t2 = reduced_c_t2;
   if (kernel_h * kernel_w != 1) {
+    reduced_c_t2 = Multiply(reduced_c_t2, MakeConstantScalar(Int(32), kernel_h * kernel_w));
     reduced_t2 =
         AvgPool2D(reduced_c_t2, param->kernel_size, param->strides, padding, param->data_layout,
                   false,   // ceil_mode
@@ -245,7 +384,6 @@ Expr Conv2DSecondTerm(const Expr& padded_data, const QnnConv2DAttrs* param, int
  * \brief Calculates the third term in the qnn.conv2d lowering sequence.
  * \param weight The weight expr.
  * \param param The qnn conv2d attributes.
- * \param batch_size The batch size.
  * \param out_channels The number of output channels.
  * \return The sequence of Relay operatos for term3.
  * \note The term3 looks like this
@@ -256,8 +394,7 @@ Expr Conv2DSecondTerm(const Expr& padded_data, const QnnConv2DAttrs* param, int
  *       a 1D tensor. The tensor is then reshaped to conform to NHWC/NCHW
  *       format.
  */
-Expr Conv2DThirdTerm(const Expr& weight, const QnnConv2DAttrs* param, int batch_size,
-                     int out_channels) {
+Expr Conv2DThirdTerm(const Expr& weight, const QnnConv2DAttrs* param, int out_channels) {
   // Constant expr for input zero point.
   auto zp_data = MakeConstantScalar(Int(32), param->input_zero_point);
 
@@ -278,9 +415,9 @@ Expr Conv2DThirdTerm(const Expr& weight, const QnnConv2DAttrs* param, int batch_
   // Find the newshape depending on NCHW/NHWC layout.
   Array<Integer> newshape;
   if (param->data_layout == "NCHW") {
-    newshape = {batch_size, out_channels, 1, 1};
+    newshape = {1, out_channels, 1, 1};
   } else if (param->data_layout == "NHWC") {
-    newshape = {batch_size, 1, 1, out_channels};
+    newshape = {1, 1, 1, out_channels};
   } else {
     LOG(FATAL) << "qnn.conv2d does not support " << param->data_layout << " layout";
   }
@@ -295,7 +432,6 @@ Expr Conv2DThirdTerm(const Expr& weight, const QnnConv2DAttrs* param, int batch_
 /*
  * \brief Calculates the fourth term in the qnn.conv2d lowering sequence.
  * \param param The qnn conv2d attributes.
- * \param batch_size The batch size.
  * \param in_channels The number of input channels.
  * \param kernel_h The height of kernel.
  * \param kernel_w The width of kernel.
@@ -305,8 +441,7 @@ Expr Conv2DThirdTerm(const Expr& weight, const QnnConv2DAttrs* param, int batch_
  *       Sigma(c,r,s) zp_a * zp_w
  *
  */
-Expr Conv2DFourthTerm(const QnnConv2DAttrs* param, int batch_size, int in_channels, int kernel_h,
-                      int kernel_w) {
+Expr Conv2DFourthTerm(const QnnConv2DAttrs* param, int in_channels, int kernel_h, int kernel_w) {
   int scalar_term4 =
       param->input_zero_point * param->kernel_zero_point * in_channels * kernel_h * kernel_w;
   return MakeConstantScalar(Int(32), scalar_term4);
@@ -391,7 +526,20 @@ Expr Conv2DCombineTerms(const Expr& term1, const Expr& term2, const Expr& term3,
  *         gives an opportunity to reuse alter_op_layout infrastructure.
  *         3) For dilated conv, in current lowering, we need dilated pool. So as
  *         a workaround, we fall back to simpler lowering using int32 conv if
- *         the conv is dilated. We fallback also in case of depthwise conv.
+ *         the conv is dilated. We fallback also in case of grouped conv.
+ *
+ *       For depthwise, we can similarly unroll the computation. The intial compute is as follows
+ *       wehere cm = channel_multiplier
+ *
+ *       Qc(n, oc, oh, ow) = Sigma(r, s) (Qw(oc/m, oc%/m, r, s) - zp_w)
+ *                                     * (Qa(n, oc/cm, oh + r, ow + s) - zp_a)
+ *
+ *       This can be written as
+ *
+ *            Sigma(r, s) Qw(oc/m, oc%/m, r, s) * Qa(n, oc/cm, oh + r, ow + s)
+ *          - Sigma(r, s) zp_w * Qa(n, oc/cm, oh + r, ow + s)
+ *          - Sigma(r, s) zp_a * Qw(oc/m, oc%m, r, s)
+ *          - Sigma(r, s) zp_a * zp_w
  *
  *       The whole process can be broken down into following steps
  *       * Assertion checks for existing support, fallback if necessary
@@ -417,23 +565,33 @@ Expr QnnConv2DCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
         param->kernel_layout == "HWOI")
       << "qnn.conv2d supports only OIHW/HWIO/HWOI kernel data layout.";
 
-  int batch_size, in_channels, out_channels, kernel_h, kernel_w;
-  std::tie(batch_size, in_channels, out_channels, kernel_h, kernel_w) =
+  int batch_size, in_channels, out_channels, kernel_h, kernel_w, channel_multiplier;
+  std::tie(batch_size, in_channels, out_channels, kernel_h, kernel_w, channel_multiplier) =
       GetWorkload(arg_types, param);
 
-  // Fallback to int32 conv if there is dilation or depthwise conv2d
+  // Fallback to int32 conv if there is dilation or grouped conv2d
+
   CHECK_EQ(param->dilation.size(), 2) << "qnn.conv2d only supports 2D dilation";
   auto dilation_h = get_const_int(param->dilation[0]);
   auto dilation_w = get_const_int(param->dilation[1]);
-  if (dilation_h != 1 || dilation_w != 1 || param->groups != 1) {
+  if (dilation_h != 1 || dilation_w != 1 || (param->groups != 1 && !is_depthwise(param))) {
     return Conv2DFallBack(data, weight, param);
+  } else if (is_depthwise(param)) {
+    CHECK_NE(channel_multiplier, -1);
+    auto padded_data = Conv2DPadInput(data, param);
+    auto term1 = Conv2DFirstTerm(padded_data, weight, param);
+    auto term2 =
+        DepthwiseConv2DSecondTerm(padded_data, param, kernel_h, kernel_w, channel_multiplier);
+    auto term3 = DepthwiseConv2DThirdTerm(weight, param, out_channels, channel_multiplier);
+    auto term4 = DepthwiseConv2DFourthTerm(param, kernel_h, kernel_w);
+    return Conv2DCombineTerms(term1, term2, term3, term4, param);
   }
 
   auto padded_data = Conv2DPadInput(data, param);
   auto term1 = Conv2DFirstTerm(padded_data, weight, param);
   auto term2 = Conv2DSecondTerm(padded_data, param, kernel_h, kernel_w, out_channels);
-  auto term3 = Conv2DThirdTerm(weight, param, batch_size, out_channels);
-  auto term4 = Conv2DFourthTerm(param, batch_size, in_channels, kernel_h, kernel_w);
+  auto term3 = Conv2DThirdTerm(weight, param, out_channels);
+  auto term4 = Conv2DFourthTerm(param, in_channels, kernel_h, kernel_w);
   return Conv2DCombineTerms(term1, term2, term3, term4, param);
 }
 
index 40fc993..eda47e6 100644 (file)
@@ -42,7 +42,9 @@ def get_ref_func(data,
                  dilation,
                  data_layout,
                  kernel_layout,
-                 out_dtype):
+                 out_dtype,
+                 groups,
+                 channels=None):
     casted_data = relay.op.cast(data, "int32")
     casted_kernel = relay.op.cast(kernel, "int32")
     shifted_data = relay.op.subtract(casted_data,
@@ -54,6 +56,8 @@ def get_ref_func(data,
                              padding=padding,
                              strides=strides,
                              dilation=dilation,
+                             groups=groups,
+                             channels=channels,
                              kernel_size=kernel_size,
                              out_dtype=out_dtype,
                              data_layout=data_layout,
@@ -74,7 +78,9 @@ def get_qnn_func(data,
                  dilation,
                  data_layout,
                  kernel_layout,
-                 out_dtype):
+                 out_dtype,
+                 groups,
+                 channels=None):
     func = relay.qnn.op.conv2d(
             data, kernel,
             input_zero_point=input_zero_point,
@@ -86,6 +92,8 @@ def get_qnn_func(data,
             dilation=dilation,
             padding=padding,
             out_dtype=out_dtype,
+            groups=groups,
+            channels=channels,
             data_layout=data_layout,
             kernel_layout=kernel_layout)
 
@@ -107,7 +115,9 @@ def get_funcs(data_shape,
               dilation,
               data_layout,
               kernel_layout,
-              out_dtype):
+              out_dtype,
+              groups=1,
+              channels=None):
     data = relay.var("data", shape=data_shape,
             dtype=data_dtype)
     kernel = relay.var("kernel", shape=kernel_shape,
@@ -124,8 +134,11 @@ def get_funcs(data_shape,
                             dilation,
                             data_layout,
                             kernel_layout,
-                            out_dtype)
+                            out_dtype,
+                            groups,
+                            channels)
     ref_func = run_infer_type(ref_func)
+    ref_func = relay.Module.from_expr(ref_func)
     qnn_func = get_qnn_func(data,
                             kernel,
                             input_zero_point,
@@ -138,7 +151,9 @@ def get_funcs(data_shape,
                             dilation,
                             data_layout,
                             kernel_layout,
-                            out_dtype)
+                            out_dtype,
+                            groups,
+                            channels)
     return (ref_func, qnn_func)
 
 def verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape,
@@ -151,14 +166,14 @@ def verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape,
         if data_dtype == "uint8":
             low = 0
             high = 255
-        golden_data = np.random.random_integers(low=low, high=high,
+        golden_data = np.random.randint(low=low, high=high,
                 size=data_shape).astype(data_dtype)
         low = -128
         high = 127
         if kernel_dtype == "uint8":
             low = 0
             high = 255
-        golden_weight = np.random.random_integers(low=low, high=high,
+        golden_weight = np.random.randint(low=low, high=high,
                 size=kernel_shape).astype(kernel_dtype)
         return (golden_data, golden_weight)
 
@@ -512,7 +527,7 @@ def test_const_folding():
         kernel_shape = (3, 4, 2, 2)
         kernel_dtype = 'uint8'
 
-        golden_weight = np.random.random_integers(low=0, high=255,
+        golden_weight = np.random.randint(low=0, high=255,
                 size=kernel_shape).astype(kernel_dtype)
         data = relay.var("data", shape=data_shape,
                 dtype=data_dtype)
@@ -529,7 +544,8 @@ def test_const_folding():
                                 dilation=(1, 1),
                                 data_layout="NCHW",
                                 kernel_layout="OIHW",
-                                out_dtype="int32")
+                                out_dtype="int32",
+                                groups=1)
         folded_mod = transform.FoldConstant()(qnn_func)
         folded_func = folded_mod["main"]
         assert "reshape" not in folded_func.astext()
@@ -724,6 +740,112 @@ def test_broadcast_layout():
         with relay.build_config(opt_level=3):
             graph, lib, params = relay.build(mod, "llvm -mcpu=skylake-avx512")
 
+def test_depthwise_depth_multiplier():
+    with TempOpAttr("qnn.conv2d", "FTVMQnnLegalize", legalize_qnn_conv2d):
+
+        # uint8 input, NCHW and OIHW
+        # Depthwise multiplier = 1
+        data_shape = (2, 4, 16, 16)
+        data_dtype = 'uint8'
+        kernel_shape = (4, 1, 3, 3)
+        kernel_dtype = 'uint8'
+        ref_func, qnn_func = get_funcs(data_shape=data_shape,
+                                       data_dtype=data_dtype,
+                                       kernel_shape=kernel_shape,
+                                       kernel_dtype=kernel_dtype,
+                                       input_zero_point=5,
+                                       kernel_zero_point=3,
+                                       input_scale=1.0,
+                                       kernel_scale=1.0,
+                                       kernel_size=(3, 3),
+                                       padding=(0, 0),
+                                       strides=(1, 1),
+                                       dilation=(1, 1),
+                                       data_layout="NCHW",
+                                       kernel_layout="OIHW",
+                                       out_dtype="int32",
+                                       groups=4,
+                                       channels=4)
+        verify(ref_func, qnn_func, data_shape, data_dtype,
+                kernel_shape, kernel_dtype)
+        
+        
+        # Depthwise multiplier = 2
+        data_shape = (10, 4, 16, 16)
+        data_dtype = 'uint8'
+        kernel_shape = (4, 2, 3, 3)
+        kernel_dtype = 'uint8'
+        ref_func, qnn_func = get_funcs(data_shape=data_shape,
+                                       data_dtype=data_dtype,
+                                       kernel_shape=kernel_shape,
+                                       kernel_dtype=kernel_dtype,
+                                       input_zero_point=5,
+                                       kernel_zero_point=3,
+                                       input_scale=1.0,
+                                       kernel_scale=1.0,
+                                       kernel_size=(3, 3),
+                                       padding=(0, 0),
+                                       strides=(1, 1),
+                                       dilation=(1, 1),
+                                       data_layout="NCHW",
+                                       kernel_layout="OIHW",
+                                       out_dtype="int32",
+                                       groups=8,
+                                       channels=8)
+        verify(ref_func, qnn_func, data_shape, data_dtype,
+                kernel_shape, kernel_dtype)
+        
+        # uint8 input, NHWC and HWOI
+        # Depthwise multiplier = 1
+        data_shape = (2, 16, 16, 4)
+        data_dtype = 'uint8'
+        kernel_shape = (3, 3, 4, 1)
+        kernel_dtype = 'uint8'
+        ref_func, qnn_func = get_funcs(data_shape=data_shape,
+                                       data_dtype=data_dtype,
+                                       kernel_shape=kernel_shape,
+                                       kernel_dtype=kernel_dtype,
+                                       input_zero_point=5,
+                                       kernel_zero_point=3,
+                                       input_scale=1.0,
+                                       kernel_scale=1.0,
+                                       kernel_size=(3, 3),
+                                       padding=(0, 0),
+                                       strides=(1, 1),
+                                       dilation=(1, 1),
+                                       data_layout="NHWC",
+                                       kernel_layout="HWOI",
+                                       out_dtype="int32",
+                                       groups=4,
+                                       channels=4)
+        verify(ref_func, qnn_func, data_shape, data_dtype,
+                kernel_shape, kernel_dtype)
+        
+        # Depthwise multiplier = 2
+        data_shape = (2, 16, 16, 4)
+        data_dtype = 'uint8'
+        kernel_shape = (3, 3, 4, 2)
+        kernel_dtype = 'uint8'
+        ref_func, qnn_func = get_funcs(data_shape=data_shape,
+                                       data_dtype=data_dtype,
+                                       kernel_shape=kernel_shape,
+                                       kernel_dtype=kernel_dtype,
+                                       input_zero_point=5,
+                                       kernel_zero_point=3,
+                                       input_scale=1.0,
+                                       kernel_scale=1.0,
+                                       kernel_size=(3, 3),
+                                       padding=(0, 0),
+                                       strides=(1, 1),
+                                       dilation=(1, 1),
+                                       data_layout="NHWC",
+                                       kernel_layout="HWOI",
+                                       out_dtype="int32",
+                                       groups=8,
+                                       channels=8)
+        verify(ref_func, qnn_func, data_shape, data_dtype,
+                kernel_shape, kernel_dtype)
+
 if __name__ == "__main__":
     test_no_zero_point()
     test_input_zero_point()
@@ -738,3 +860,4 @@ if __name__ == "__main__":
     test_broadcast_layout()
     test_tflite_output_multiplier_greater_than_one()
     test_tflite_anistropic_strides()
+    test_depthwise_depth_multiplier()
index f596bc0..a02f919 100644 (file)
@@ -197,6 +197,11 @@ def _conv2d_legalize(attrs, inputs, arg_types):
     if not (dilation[0] == 1 and dilation[1] == 1):
         return None
 
+    # No legalization for depthwise convolutions yet.
+    groups = attrs.get_int("groups")
+    if groups != 1:
+        return None
+
     # Collect the input tensors.
     data_tensor, kernel_tensor = arg_types[0], arg_types[1]
     data_dtype = data_tensor.dtype