[QNN] Convolution 2D Implementation. (#3580)
authorAnimesh Jain <anijain@umich.edu>
Wed, 4 Sep 2019 17:05:22 +0000 (11:05 -0600)
committerZhi <5145158+zhiics@users.noreply.github.com>
Wed, 4 Sep 2019 17:05:22 +0000 (10:05 -0700)
Rebasing. Empty commit.

Clang-format styling.

docs/langref/relay_op.rst
include/tvm/relay/qnn/attrs.h
python/tvm/relay/qnn/op/qnn.py
src/relay/pass/pattern_util.h
src/relay/qnn/op/convolution.cc [new file with mode: 0644]
src/relay/qnn/util.h
tests/python/relay/test_qnn_conv2d.py [new file with mode: 0644]

index 4fad352..57325b5 100644 (file)
@@ -211,6 +211,7 @@ This level supports dialect operators.
    :nosignatures:
 
    tvm.relay.qnn.op.requantize
+   tvm.relay.qnn.op.conv2d
 
 
 Level 1 Definitions
@@ -357,3 +358,4 @@ Level 10 Definitions
 Level 11 Definitions
 --------------------
 .. autofunction:: tvm.relay.qnn.op.requantize
+.. autofunction:: tvm.relay.qnn.op.conv2d
index b8d775c..492ed14 100644 (file)
@@ -125,6 +125,68 @@ struct QnnConcatenateAttrs : public tvm::AttrsNode<QnnConcatenateAttrs> {
   }
 };  // struct QnnConcatenateAttrs
 
+/*! \brief Attribute for QNN Conv2d operator */
+struct QnnConv2DAttrs : public tvm::AttrsNode<QnnConv2DAttrs> {
+  // Traditional conv2d attributes.
+  Array<IndexExpr> strides;
+  Array<IndexExpr> padding;
+  Array<IndexExpr> dilation;
+  int groups;
+  IndexExpr channels;
+  Array<IndexExpr> kernel_size;
+  std::string data_layout;
+  std::string kernel_layout;
+  std::string out_layout;
+  DataType out_dtype;
+
+  // Quantization related attributes.
+  int32_t input_zero_point;
+  int32_t kernel_zero_point;
+
+  TVM_DECLARE_ATTRS(QnnConv2DAttrs, "relay.attrs.QnnConv2DAttrs") {
+    TVM_ATTR_FIELD(strides).set_default(Array<IndexExpr>({1, 1}))
+        .describe("Specifies the strides of the convolution.");
+    TVM_ATTR_FIELD(padding).set_default(Array<IndexExpr>({0, 0}))
+        .describe("If padding is non-zero, then the input is implicitly zero-padded"
+                  "on both sides for padding number of points");
+    TVM_ATTR_FIELD(dilation).set_default(Array<IndexExpr>({1, 1}))
+        .describe("Specifies the dilation rate to use for dilated convolution.");
+    TVM_ATTR_FIELD(groups).set_default(1)
+        .describe("Controls the connections between inputs and outputs."
+                  "At groups=1, all inputs are convolved to all outputs."
+                  "At groups=2, the operation becomes equivalent to having two convolution"
+                  "layers side by side, each seeing half the input channels, and producing"
+                  "half the output channels, and both subsequently concatenated.");
+    TVM_ATTR_FIELD(channels)
+        .describe("The number of output channels in the convolution."
+                  " If it is not set, inferred by shape of the weight.")
+        .set_default(NullValue<IndexExpr>());
+    TVM_ATTR_FIELD(kernel_size)
+        .describe("Specifies the dimensions of the convolution window.")
+        .set_default(NullValue<Array<IndexExpr> >());
+    TVM_ATTR_FIELD(data_layout).set_default("NCHW")
+        .describe("Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc."
+                  "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
+                  "dimensions respectively. Convolution is applied on the 'H' and"
+                  "'W' dimensions.");
+    TVM_ATTR_FIELD(kernel_layout).set_default("OIHW")
+        .describe("Dimension ordering of weight. Can be 'OIHW', 'OIHW16o16i', etc."
+                  "'O', 'I', 'H', 'W' stands for num_filter, input_channel, height, and width"
+                  "dimensions respectively.");
+    TVM_ATTR_FIELD(out_layout).set_default("")
+        .describe("Dimension ordering of output. Can be 'NCHW', 'NHWC', etc."
+                  "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
+                  "dimensions respectively. Default to be same as input layout.");
+    TVM_ATTR_FIELD(out_dtype)
+        .set_default(NullValue<DataType>())
+        .describe("Output data type, set to explicit type under mixed precision setting");
+    TVM_ATTR_FIELD(input_zero_point)
+        .describe("The zero point of the input tensor.");
+    TVM_ATTR_FIELD(kernel_zero_point)
+        .describe("The zero point of the kernel tensor.");
+  }
+};
+
 }  // namespace qnn
 }  // namespace relay
 }  // namespace tvm
index 7eb0408..354e4d2 100644 (file)
@@ -183,3 +183,83 @@ def concatenate(data,
                              output_scale,
                              output_zero_point,
                              axis)
+
+
+def conv2d(data,
+           kernel,
+           input_zero_point,
+           kernel_zero_point,
+           strides=(1, 1),
+           padding=(0, 0),
+           dilation=(1, 1),
+           groups=1,
+           channels=None,
+           kernel_size=None,
+           data_layout="NCHW",
+           kernel_layout="OIHW",
+           out_layout="",
+           out_dtype="int32"):
+    r"""Quantized 2D convolution.
+
+    This operator convolves quantized data with quantized kernel. The scale of
+    the output quantized tensor is the product of the kernel_scale and
+    input_scale of the input quantized tensors. The zero point of the output
+    quantized tensor is 0. By default, the dtype of output is int32. Please also
+    refer to Requantize operator to understand how to scale back the int32
+    output to (u)int8.
+
+    Parameters
+    ----------
+    data : tvm.relay.Expr
+        The input data to the operator.
+
+    kernel : tvm.relay.Expr
+        The kernel expressions.
+
+    input_zero_point: int
+           The zero point of the data distribution.
+
+    kernel_zero_point: int
+           The zero point of the quantized_kernel distribution.
+
+    strides : tuple of int, optional
+        The strides of convolution.
+
+    padding : tuple of int, optional
+        The padding of convolution on both sides of inputs before convolution.
+
+    dilation : tuple of int, optional
+        Specifies the dilation rate to be used for dilated convolution.
+
+    groups : int, optional
+        Number of groups for grouped convolution.
+
+    channels : int, optional
+        Number of output channels of this convolution.
+
+    kernel_size : tuple of int, optional
+        The spatial of the convolution kernel.
+
+    data_layout : str, optional
+        Layout of the input.
+
+    kernel_layout : str, optional
+        Layout of the kernel.
+
+    out_layout : str, optional
+        Layout of the output, by default, out_layout is the same as data_layout
+
+    out_dtype : str, optional
+        Specifies the output data type for mixed precision conv2d.
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The computed result.
+    """
+
+    return _make.conv2d(data, kernel,
+                        input_zero_point, kernel_zero_point,
+                        strides, padding, dilation,
+                        groups, channels, kernel_size,
+                        data_layout, kernel_layout, out_layout, out_dtype)
index 18e5df3..682e2e3 100644 (file)
@@ -415,6 +415,71 @@ static inline Expr Full(Expr fill_value,
   return CallNode::make(op, {fill_value}, Attrs(attrs), {});
 }
 
+static inline Expr Conv2D(Expr data, Expr weight, Array<IndexExpr> strides,
+                          Array<IndexExpr> padding, Array<IndexExpr> dilation, int groups,
+                          IndexExpr channels, Array<IndexExpr> kernel_size, std::string data_layout,
+                          std::string kernel_layout, std::string out_layout, DataType out_dtype) {
+  auto attrs = make_node<Conv2DAttrs>();
+  attrs->strides = std::move(strides);
+  attrs->padding = std::move(padding);
+  attrs->dilation = std::move(dilation);
+  attrs->groups = groups;
+  attrs->channels = std::move(channels);
+  attrs->kernel_size = std::move(kernel_size);
+  attrs->data_layout = std::move(data_layout);
+  attrs->kernel_layout = std::move(kernel_layout);
+  attrs->out_layout = std::move(out_layout);
+  attrs->out_dtype = std::move(out_dtype);
+  static const Op& op = Op::Get("nn.conv2d");
+  return CallNode::make(op, {data, weight}, Attrs(attrs), {});
+}
+
+static inline Expr Sum(Expr data, Array<Integer> axis, bool keepdims, bool exclude) {
+  auto attrs = make_node<ReduceAttrs>();
+  attrs->axis = std::move(axis);
+  attrs->keepdims = keepdims;
+  attrs->exclude = exclude;
+  static const Op& op = Op::Get("sum");
+  return CallNode::make(op, {data}, Attrs(attrs), {});
+}
+
+static inline Expr Reshape(Expr data, Array<Integer> newshape) {
+  auto attrs = make_node<ReshapeAttrs>();
+  attrs->newshape = std::move(newshape);
+  attrs->reverse = false;
+  static const Op& op = Op::Get("reshape");
+  return CallNode::make(op, {data}, Attrs(attrs), {});
+}
+
+static inline Expr AvgPool2D(Expr data, Array<IndexExpr> pool_size, Array<IndexExpr> strides,
+                             Array<IndexExpr> padding, std::string layout, bool ceil_mode,
+                             bool count_include_pad) {
+  auto attrs = make_node<AvgPool2DAttrs>();
+  attrs->pool_size = std::move(pool_size);
+  attrs->strides = std::move(strides);
+  attrs->padding = std::move(padding);
+  attrs->layout = std::move(layout);
+  attrs->ceil_mode = ceil_mode;
+  attrs->count_include_pad = count_include_pad;
+  static const Op& op = Op::Get("nn.avg_pool2d");
+  return CallNode::make(op, {data}, Attrs(attrs), {});
+}
+
+static inline Expr Pad(Expr data, Array<Array<IndexExpr>> pad_width, double pad_value) {
+  auto attrs = make_node<PadAttrs>();
+  attrs->pad_value = pad_value;
+  attrs->pad_width = std::move(pad_width);
+  static const Op& op = Op::Get("nn.pad");
+  return CallNode::make(op, {data}, Attrs(attrs), {});
+}
+
+static inline Expr Tile(Expr data, Array<Integer> reps) {
+  auto attrs = make_node<TileAttrs>();
+  attrs->reps = reps;
+  static const Op& op = Op::Get("tile");
+  return CallNode::make(op, {data}, Attrs(attrs), {});
+}
+
 Expr MakeConcatenate(Expr data, int axis);
 
 Expr MakeStridedSlice(Expr data, Array<Integer> begin, Array<Integer> end, Array<Integer> strides);
diff --git a/src/relay/qnn/op/convolution.cc b/src/relay/qnn/op/convolution.cc
new file mode 100644 (file)
index 0000000..6e1d13e
--- /dev/null
@@ -0,0 +1,485 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *  Copyright (c) 2019 by Contributors
+ * \file src/relay/qnn/op/convolution.cc
+ * \brief Property def of qnn convolution operator.
+ */
+#include <tvm/data_layout.h>
+#include <tvm/ir_pass.h>
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/base.h>
+#include <tvm/relay/op.h>
+#include <tvm/relay/qnn/attrs.h>
+#include <tvm/relay/transform.h>
+#include "../../op/nn/convolution.h"
+#include "../../pass/pattern_util.h"
+#include "../util.h"
+
+namespace tvm {
+namespace relay {
+namespace qnn {
+
+// relay.op.qnn.conv2d
+TVM_REGISTER_NODE_TYPE(QnnConv2DAttrs);
+
+// Workload - batch_size, in_channels, out_channels, kernel_h, kernel_w
+using WorkloadType = std::tuple<int, int, int, int, int>;
+
+/*
+ * \brief Get the conv parameters like batch_size, kernel_height etc.
+ * \param ref_call The original callnode.
+ * \param param The qnn conv2d attributes.
+ * \return A tuple of workload.
+ */
+WorkloadType GetWorkload(const Array<tvm::relay::Type>& arg_types, const QnnConv2DAttrs* param) {
+  // Get conv parameters.
+  auto get_shape = [](const Type& type) {
+    auto input_tt = type.as<TensorTypeNode>();
+    CHECK(input_tt != nullptr) << "Type information missing."
+                               << " Please run infer_type pass.";
+    return input_tt->shape;
+  };
+
+  const auto in_shape = get_shape(arg_types[0]);
+  int batch_size, in_channels;
+  if (param->data_layout == "NCHW") {
+    batch_size = get_const_int(in_shape[0]);
+    in_channels = get_const_int(in_shape[1]);
+  } else if (param->data_layout == "NHWC") {
+    batch_size = get_const_int(in_shape[0]);
+    in_channels = get_const_int(in_shape[3]);
+  } else {
+    LOG(FATAL) << "qnn.conv2d does not support " << param->data_layout << " layout";
+  }
+
+  const auto kernel_shape = get_shape(arg_types[1]);
+  int out_channels, kernel_h, kernel_w;
+  if (param->kernel_layout == "OIHW") {
+    out_channels = get_const_int(kernel_shape[0]);
+    kernel_h = get_const_int(kernel_shape[2]);
+    kernel_w = get_const_int(kernel_shape[3]);
+  } else if (param->kernel_layout == "HWIO") {
+    kernel_h = get_const_int(kernel_shape[0]);
+    kernel_w = get_const_int(kernel_shape[1]);
+    out_channels = get_const_int(kernel_shape[3]);
+  } else if (param->kernel_layout == "HWOI") {
+    kernel_h = get_const_int(kernel_shape[0]);
+    kernel_w = get_const_int(kernel_shape[1]);
+    out_channels = get_const_int(kernel_shape[2]);
+  } else {
+    LOG(FATAL) << "qnn.conv2d does not support " << param->kernel_layout << " layout";
+  }
+  return std::make_tuple(batch_size, in_channels, out_channels, kernel_h, kernel_w);
+}
+
+/*
+ * \brief Fallback to simpler lowering for dilation or depthwise conv.
+ * \param data The input expr.
+ * \param weight The weight expr.
+ * \param zp_data The data zero point expr.
+ * \param zp_kernel The kernel zero point expr.
+ * \param param The qnn conv2d attributes.
+ * \return The fallback lowered sequence of Relay expr.
+ * \note In case of dilation, normal lowering would require a dilated pool.
+ *       Since, we don't have dilated pool, we fallback to a simpler sequence of
+ *       Relay operations. This will potentially lead to performance degradation
+ *       as the convolution is called on int32 tensors instead of int8 tensors.
+ */
+Expr Conv2DFallBack(const Expr& data, const Expr& weight, const Expr& zp_data,
+                    const Expr& zp_kernel, const QnnConv2DAttrs* param) {
+  auto shifted_data = data;
+  if (param->input_zero_point != 0) {
+    shifted_data = Subtract(Cast(data, Int(32)), zp_data);
+  }
+
+  auto shifted_kernel = weight;
+  if (param->kernel_zero_point != 0) {
+    shifted_kernel = Subtract(Cast(weight, Int(32)), zp_kernel);
+  }
+
+  return Conv2D(shifted_data, shifted_kernel, param->strides, param->padding, param->dilation,
+                param->groups, param->channels, param->kernel_size, param->data_layout,
+                param->kernel_layout, param->out_layout, param->out_dtype);
+}
+
+/*
+ * \brief Pad the input data.
+ * \param data The input expr.
+ * \return The padded input expr.
+ * \note For quantized convolution, the input has to be padded with zero point
+ *       instead of zero. This might lead to performance degradation as pad
+ *       cannot be fused with conv in Relay. In case we see performance
+ *       degradation, we can change the conv2D API to accept a pad_const value.
+ */
+Expr Conv2DPadInput(const Expr& data, const QnnConv2DAttrs* param) {
+  // 1) Pad the input data
+  auto padded_data = data;
+  auto pad_h_value = get_const_int(param->padding[0]);
+  auto pad_w_value = get_const_int(param->padding[1]);
+  if (pad_h_value != 0 || pad_w_value != 0) {
+    Array<IndexExpr> pad_n({0, 0});
+    Array<IndexExpr> pad_c({0, 0});
+    Array<IndexExpr> pad_h({param->padding[0], param->padding[0]});
+    Array<IndexExpr> pad_w({param->padding[1], param->padding[1]});
+
+    Array<Array<IndexExpr>> pad_width;
+    if (param->data_layout == "NCHW") {
+      pad_width = {pad_n, pad_c, pad_h, pad_w};
+    } else if (param->data_layout == "NHWC") {
+      pad_width = {pad_n, pad_h, pad_w, pad_c};
+    } else {
+      LOG(FATAL) << "qnn.conv2d does not support " << param->data_layout << " layout";
+    }
+    padded_data = Pad(data, pad_width, param->input_zero_point);
+  }
+  return padded_data;
+}
+
+/*
+ * \brief Calculates the first term in the qnn.conv2d lowering sequence.
+ * \param data The input expr.
+ * \param weight The weight expr.
+ * \param param The qnn conv2d attributes.
+ * \return The sequence of Relay operatos for term1.
+ * \note The term1 is
+ *       Sigma(c,r,s) QW(k, c, r, s) * QA(n, c, h + r, w + s)
+ *       This is just conv2d on int tensors.
+ */
+Expr Conv2DFirstTerm(const Expr& padded_data, const Expr& weight, const QnnConv2DAttrs* param) {
+  // Lowering for Term 1
+  Array<IndexExpr> padding({0, 0});
+  return Conv2D(padded_data, weight, param->strides, padding, param->dilation, param->groups,
+                param->channels, param->kernel_size, param->data_layout, param->kernel_layout,
+                param->out_layout, param->out_dtype);
+}
+
+/*
+ * \brief Calculates the second term in the qnn.conv2d lowering sequence.
+ * \param padded_data The padded data expr.
+ * \param zp_kernel The kernel zero point expr.
+ * \param param The qnn conv2d attributes.
+ * \param kernel_h The height of kernel.
+ * \param kernel_w The width of kernel.
+ * \return The sequence of Relay operatos for term2.
+ * \note The term2 looks like this
+ *
+ *       Sigma(c,r,s) zp_w * QA(n, c, h + r, w + s)
+ *
+ *       Second term is not directly represetable by one Relay operator.
+ *       However, deeper analysis shows that we can reduce r,s using avg_pool2d,
+ *       followed by a reduce on the C axis. Using avg_pool2d also gives an
+ *       opportunity to reuse alter_op_layout infrastructure.
+ */
+Expr Conv2DSecondTerm(const Expr& padded_data, const Expr& zp_kernel, const QnnConv2DAttrs* param,
+                      int kernel_h, int kernel_w, int out_channels) {
+  auto casted_t2 = Cast(padded_data, Int(32));
+
+  // We can reduce the H and W axis by using avg_pool2d. However, avg_pool2d averages the sum.
+  // Since, this is integer division (floor), we can first multiply the data by the pool_size and
+  // then perform avg_pool2d. Reversing this causes inaccuracy due to floor division.
+  auto scaled_hw_t2 = Multiply(casted_t2, MakeConstantScalar(Int(32), kernel_h * kernel_w));
+  Array<IndexExpr> padding({0, 0});
+
+  // If the pool_size is 1x1, we don't need avg_pool2d.
+  auto reduced_hw_t2 = scaled_hw_t2;
+  if (kernel_h * kernel_w != 1) {
+    reduced_hw_t2 =
+        AvgPool2D(scaled_hw_t2, param->kernel_size, param->strides, padding, param->data_layout,
+                  false,   // ceil_mode
+                  false);  // count_include_pad
+  }
+
+  // Reduce the C dimension. Find the dimension.
+  Array<Integer> axes_t2;
+  if (param->data_layout == "NCHW") {
+    axes_t2 = {1};
+  } else if (param->data_layout == "NHWC") {
+    axes_t2 = {3};
+  } else {
+    LOG(FATAL) << "qnn.conv2d does not support " << param->data_layout << " layout";
+  }
+  // Keep dims true to retain 4D tensor
+  auto reduced_t2 = Sum(reduced_hw_t2, axes_t2, true, false);
+  auto multiplied_t2 = reduced_t2;
+  if (param->kernel_zero_point != 1) {
+    multiplied_t2 = Multiply(zp_kernel, reduced_t2);
+  }
+
+  // Replicate to go back to NHWC/NCHW. This is not necessarily needed, but it fails AlterOpLayout.
+  // We can remove this once AlterOpLayout refactoring completes -
+  // https://github.com/dmlc/tvm/issues/3670
+  Array<Integer> reps;
+  if (param->data_layout == "NCHW") {
+    reps = {1, out_channels, 1, 1};
+  } else if (param->data_layout == "NHWC") {
+    reps = {1, 1, 1, out_channels};
+  } else {
+    LOG(FATAL) << "qnn.conv2d does not support " << param->data_layout << " layout";
+  }
+  return Tile(multiplied_t2, reps);
+}
+
+/*
+ * \brief Calculates the third term in the qnn.conv2d lowering sequence.
+ * \param weight The weight expr.
+ * \param zp_data The data zero point expr.
+ * \param param The qnn conv2d attributes.
+ * \param batch_size The batch size.
+ * \param out_channels The number of output channels.
+ * \return The sequence of Relay operatos for term3.
+ * \note The term3 looks like this
+ *
+ *       Sigma(c,r,s) zp_a * QW(k, c, r, s)
+ *
+ *       This can be achieved by calling reduce on c, r and s axis, resulting in
+ *       a 1D tensor. The tensor is then reshaped to conform to NHWC/NCHW
+ *       format.
+ */
+Expr Conv2DThirdTerm(const Expr& weight, const Expr& zp_data, const QnnConv2DAttrs* param,
+                     int batch_size, int out_channels) {
+  // Find which dimensions are C, R, S.
+  Array<Integer> axes_t3;
+  if (param->kernel_layout == "OIHW") {
+    // For OIHW kernel layout, IHW are reduce axis
+    axes_t3 = {1, 2, 3};
+  } else if (param->kernel_layout == "HWIO") {
+    axes_t3 = {0, 1, 2};
+  } else if (param->kernel_layout == "HWOI") {
+    axes_t3 = {0, 1, 3};
+  } else {
+    LOG(FATAL) << "qnn.conv2d does not support " << param->kernel_layout << " layout";
+  }
+  auto reduced_t3 = Sum(Cast(weight, Int(32)), axes_t3, false, false);
+
+  // Find the newshape depending on NCHW/NHWC layout.
+  Array<Integer> newshape;
+  if (param->data_layout == "NCHW") {
+    newshape = {batch_size, out_channels, 1, 1};
+  } else if (param->data_layout == "NHWC") {
+    newshape = {batch_size, 1, 1, out_channels};
+  } else {
+    LOG(FATAL) << "qnn.conv2d does not support " << param->data_layout << " layout";
+  }
+  auto reshaped_t3 = Reshape(reduced_t3, newshape);
+
+  if (param->input_zero_point == 1) {
+    return reshaped_t3;
+  }
+  return Multiply(zp_data, reshaped_t3);
+}
+
+/*
+ * \brief Calculates the fourth term in the qnn.conv2d lowering sequence.
+ * \param param The qnn conv2d attributes.
+ * \param batch_size The batch size.
+ * \param in_channels The number of input channels.
+ * \param kernel_h The height of kernel.
+ * \param kernel_w The width of kernel.
+ * \return The sequence of Relay operatos for term4.
+ * \note The term4 looks like this
+ *
+ *       Sigma(c,r,s) zp_a * zp_w
+ *
+ */
+Expr Conv2DFourthTerm(const QnnConv2DAttrs* param, int batch_size, int in_channels, int kernel_h,
+                      int kernel_w) {
+  int scalar_term4 =
+      param->input_zero_point * param->kernel_zero_point * in_channels * kernel_h * kernel_w;
+  return MakeConstantScalar(Int(32), scalar_term4);
+}
+
+/*
+ * \brief Combines different terms of qnn conv2d lowering.
+ * \param term1 The term1 of qnn conv2d lowering.
+ * \param term2 The term2 of qnn conv2d lowering.
+ * \param term3 The term3 of qnn conv2d lowering.
+ * \param term4 The term4 of qnn conv2d lowering.
+ * \param param The qnn conv2d attributes.
+ * \return The combined sequence of relay operations.
+ * \note The combined operation looks like this
+ *
+ *       Sigma(c,r,s) QW(k, c, r, s) * QA(n, c, h + r, w + s)  // Term1
+ *     - Sigma(c,r,s) zp_w * QA(n, c, h + r, w + s)            // Term2
+ *     - Sigma(c,r,s) zp_a * QW(k, c, r, s)                    // Term3
+ *     + Sigma(c,r,s) zp_a * zp_w                              // Term4
+ *
+ */
+Expr Conv2DCombineTerms(const Expr& term1, const Expr& term2, const Expr& term3, const Expr& term4,
+                        const QnnConv2DAttrs* param) {
+  if (param->input_zero_point == 0 && param->kernel_zero_point == 0) {
+    // term 2, 3 and 4 become zero.
+    return term1;
+  } else if (param->input_zero_point == 0 && param->kernel_zero_point != 0) {
+    // term 3 and term 4 become zero.
+    return Subtract(term1, term2);
+  } else if (param->input_zero_point != 0 && param->kernel_zero_point == 0) {
+    // term 2 and term 4 become zero.
+    return Subtract(term1, term3);
+  } else {
+    auto data_term = Subtract(term1, term2);
+    // Putting constant terms together, so that constant folding can fold it.
+    auto const_term = Subtract(term4, term3);
+    return Add(data_term, const_term);
+  }
+}
+
+/*
+ * \brief Forward rewrite the qnn conv2d op.
+ * \param attrs The QNN conv2d attrs.
+ * \param new_args The new mutated args to the call node.
+ * \param arg_types The types of input and output.
+ * \return The sequence of Relay ops for qnn cov2d op.
+ * \node Lowering of the qnn.conv2d operator
+ *       A quantized tensor is represented in following manner
+ *          A = scale_a x (QA - zp_A)
+ *       where QA is quantized tensor, scale_a and zp_A are quantizations
+ *       params.
+ *
+ *       Quantized convlution convolves two quantized tensors and returns a
+ *       quantized tensor of default dtype of int32, with scale equaling to the
+ *       product of scales of input tensors, and a zero point of zero.
+ *
+ *       For symmetric quantization, the zp_* for all tensors is 0. So, the
+ *       lowering of qnn.conv2d is
+ *
+ *          QA(n, ic, oh + r, ow + s) (conv) QW(oc, ic, r, s)
+ *
+ *       For asymmetric computation, we can perform similar unrolling. We can
+ *       find more details at
+ *       https://discuss.tvm.ai/t/tf-lite-quantized-conv2d-operator-conversion/2651/8?u=janimesh
+ *       The computation gets unrolled into following 4 terms
+ *
+ *            Sigma(c,r,s) QW(k, c, r, s) * QA(n, c, h + r, w + s)  // Term1
+ *          - Sigma(c,r,s) zp_w * QA(n, c, h + r, w + s)            // Term2
+ *          - Sigma(c,r,s) zp_a * QW(k, c, r, s)                    // Term3
+ *          + Sigma(c,r,s) zp_a * zp_w                              // Term4
+ *
+ *       Term3 and Term4 can be computed at compile time.
+ *
+ *       Key points to notice:
+ *         1) Padding is done explicitly because the input has to be padded with
+ *         zero point. This might leave some performance opportunity at the
+ *         table. Can be avoided by modifying conv2d API to accept the
+ *         pad_const_value.
+ *         2) Second term is not directly represetable by one Relay operator.
+ *         However, deeper analysis shows that we can reduce r,s using
+ *         avg_pool2d, followed by a reduce on the C axis. Using avg_pool2d also
+ *         gives an opportunity to reuse alter_op_layout infrastructure.
+ *         3) For dilated conv, in current lowering, we need dilated pool. So as
+ *         a workaround, we fall back to simpler lowering using int32 conv if
+ *         the conv is dilated. We fallback also in case of depthwise conv.
+ *
+ *       The whole process can be broken down into following steps
+ *       * Assertion checks for exisiting support, fallback if necessary
+ *       * Pad the input.
+ *       * Get Term1.
+ *       * Get Term2.
+ *       * Get Term3.
+ *       * Get Term4.
+ *       * Combine the terms.
+ */
+Expr QnnConv2DCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
+                           const Array<tvm::relay::Type>& arg_types) {
+  CHECK_EQ(new_args.size(), 2);
+  Expr data = new_args[0];
+  Expr weight = new_args[1];
+  const auto* param = attrs.as<QnnConv2DAttrs>();
+  CHECK(param != nullptr);
+  // Assertion checks for exisiing support.
+  CHECK_EQ(param->padding.size(), 2) << "qnn.conv2d only supports 2D padding";
+  CHECK(param->data_layout == "NCHW" || param->data_layout == "NHWC")
+      << "qnn.conv2d supports only NCHW/NHWC input data layout.";
+  CHECK(param->kernel_layout == "OIHW" || param->kernel_layout == "HWIO" ||
+        param->kernel_layout == "HWOI")
+      << "qnn.conv2d supports only OIHW/HWIO/HWOI kernel data layout.";
+
+  int batch_size, in_channels, out_channels, kernel_h, kernel_w;
+  std::tie(batch_size, in_channels, out_channels, kernel_h, kernel_w) =
+      GetWorkload(arg_types, param);
+  auto zp_data = MakeConstantScalar(Int(32), param->input_zero_point);
+  auto zp_kernel = MakeConstantScalar(Int(32), param->kernel_zero_point);
+
+  // Fallback to int32 conv if there is dilation or depthwise conv2d
+  CHECK_EQ(param->dilation.size(), 2) << "qnn.conv2d only supports 2D dilation";
+  auto dilation_h = get_const_int(param->dilation[0]);
+  auto dilation_w = get_const_int(param->dilation[1]);
+  if (dilation_h != 1 || dilation_w != 1 || param->groups != 1) {
+    return Conv2DFallBack(data, weight, zp_data, zp_kernel, param);
+  }
+
+  auto padded_data = Conv2DPadInput(data, param);
+  auto term1 = Conv2DFirstTerm(padded_data, weight, param);
+  auto term2 = Conv2DSecondTerm(padded_data, zp_kernel, param, kernel_h, kernel_w, out_channels);
+  auto term3 = Conv2DThirdTerm(weight, zp_data, param, batch_size, out_channels);
+  auto term4 = Conv2DFourthTerm(param, batch_size, in_channels, kernel_h, kernel_w);
+  return Conv2DCombineTerms(term1, term2, term3, term4, param);
+}
+
+// Positional relay function to create quantized conv2d operator
+// used by frontend FFI.
+Expr MakeQnnConv2D(Expr data, Expr weight, int32_t input_zero_point, int32_t kernel_zero_point,
+                   Array<IndexExpr> strides, Array<IndexExpr> padding, Array<IndexExpr> dilation,
+                   int groups, IndexExpr channels, Array<IndexExpr> kernel_size,
+                   std::string data_layout, std::string kernel_layout, std::string out_layout,
+                   DataType out_dtype) {
+  auto attrs = make_node<QnnConv2DAttrs>();
+  attrs->strides = std::move(strides);
+  attrs->padding = std::move(padding);
+  attrs->dilation = std::move(dilation);
+  attrs->groups = groups;
+  attrs->channels = std::move(channels);
+  attrs->kernel_size = std::move(kernel_size);
+  attrs->data_layout = std::move(data_layout);
+  attrs->kernel_layout = std::move(kernel_layout);
+  attrs->out_layout = std::move(out_layout);
+  attrs->out_dtype = std::move(out_dtype);
+  attrs->input_zero_point = std::move(input_zero_point);
+  attrs->kernel_zero_point = std::move(kernel_zero_point);
+  static const Op& op = Op::Get("qnn.conv2d");
+  return CallNode::make(op, {data, weight}, Attrs(attrs), {});
+}
+
+RELAY_REGISTER_OP("qnn.conv2d")
+.describe(R"code(2D quantized convolution layer.
+This operator convolves quantized weight with quantized data. The scale of the
+output quantized tensor is the product of the weight_scale and input_scale of
+the input quantized tensors. The zero point of the output quantized tensor is
+0. By default, the dtype of output is int32. Please also refer to Requantize
+operator to understand how to scale back the int32 output to (u)int8.
+- **data**: This depends on the `layout` parameter. Input is 4D array of shape
+            (batch_size, in_channels, height, width) if `layout` is `NCHW`.
+- **weight**: (channels, in_channels, kernel_size[0], kernel_size[1])
+- **out**:  This depends on the `layout` parameter. Output is 4D array of shape
+            (batch_size, channels, out_height, out_width) if `layout` is `NCHW`.
+)code" TVM_ADD_FILELINE)
+.set_attrs_type_key("relay.attrs.QnnConv2DAttrs")
+.set_num_inputs(2)
+.add_argument("data", "Tensor", "The quantized input data tensor.")
+.add_argument("weight", "Tensor", "The quantized weight tensor.")
+.set_support_level(11)
+.add_type_rel("QnnConv2D", Conv2DRel<QnnConv2DAttrs>)
+.set_attr<FTVMLegalize>("FTVMQnnCanonicalize", QnnConv2DCanonicalize);
+
+TVM_REGISTER_API("relay.qnn.op._make.conv2d").set_body_typed(MakeQnnConv2D);
+
+}  // namespace qnn
+}  // namespace relay
+}  // namespace tvm
index 208e04f..9482331 100644 (file)
@@ -86,6 +86,12 @@ static inline Expr Requantize(const Expr& data, const Array<IndexExpr>& input_sh
   return RequantizeLower(data, attrs.operator->(), input_shape, out_dtype);
 }
 
+static inline int64_t get_const_int(const tvm::Expr& x) {
+  auto* value_ptr = as_const_int(x);
+  CHECK(value_ptr) << "Expr is not a constant int";
+  return value_ptr[0];
+}
+
 }  // namespace qnn
 }  // namespace relay
 }  // namespace tvm
diff --git a/tests/python/relay/test_qnn_conv2d.py b/tests/python/relay/test_qnn_conv2d.py
new file mode 100644 (file)
index 0000000..99cb482
--- /dev/null
@@ -0,0 +1,624 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm
+import numpy as np
+from tvm import relay
+from tvm.relay import transform
+from tvm.relay.testing import create_workload
+from tvm.relay.testing import run_infer_type
+from tvm.contrib import graph_runtime
+
+def get_ref_func(data,
+                 kernel,
+                 input_zero_point,
+                 kernel_zero_point,
+                 kernel_size,
+                 padding,
+                 strides,
+                 dilation,
+                 data_layout,
+                 kernel_layout,
+                 out_dtype):
+    casted_data = relay.op.cast(data, "int32")
+    casted_kernel = relay.op.cast(kernel, "int32")
+    shifted_data = relay.op.subtract(casted_data,
+            relay.const(input_zero_point, "int32"))
+    shifted_kernel = relay.op.subtract(casted_kernel,
+            relay.const(kernel_zero_point, "int32"))
+    func = relay.op.nn.conv2d(shifted_data,
+                             shifted_kernel,
+                             padding=padding,
+                             strides=strides,
+                             dilation=dilation,
+                             kernel_size=kernel_size,
+                             out_dtype=out_dtype,
+                             data_layout=data_layout,
+                             kernel_layout=kernel_layout)
+
+    func = relay.Function(relay.analysis.free_vars(func), func)
+    return func
+
+def get_qnn_func(data,
+                 kernel,
+                 input_zero_point,
+                 kernel_zero_point,
+                 kernel_size,
+                 padding,
+                 strides,
+                 dilation,
+                 data_layout,
+                 kernel_layout,
+                 out_dtype):
+    func = relay.qnn.op.conv2d(
+            data, kernel,
+            input_zero_point=input_zero_point,
+            kernel_zero_point=kernel_zero_point,
+            kernel_size=kernel_size,
+            strides=strides,
+            dilation=dilation,
+            padding=padding,
+            out_dtype=out_dtype,
+            data_layout=data_layout,
+            kernel_layout=kernel_layout)
+
+    mod = relay.Function(relay.analysis.free_vars(func), func)
+    mod = relay.Module.from_expr(mod)
+    mod = relay.qnn.transform.CanonicalizeOps()(mod)
+    return mod
+
+def get_funcs(data_shape,
+              data_dtype,
+              kernel_shape,
+              kernel_dtype,
+              input_zero_point,
+              kernel_zero_point,
+              kernel_size,
+              padding,
+              strides,
+              dilation,
+              data_layout,
+              kernel_layout,
+              out_dtype):
+    data = relay.var("data", shape=data_shape,
+            dtype=data_dtype)
+    kernel = relay.var("kernel", shape=kernel_shape,
+            dtype=kernel_dtype)
+    ref_func = get_ref_func(data,
+                            kernel,
+                            input_zero_point,
+                            kernel_zero_point,
+                            kernel_size,
+                            padding,
+                            strides,
+                            dilation,
+                            data_layout,
+                            kernel_layout,
+                            out_dtype)
+    ref_func = run_infer_type(ref_func)
+    qnn_func = get_qnn_func(data,
+                            kernel,
+                            input_zero_point,
+                            kernel_zero_point,
+                            kernel_size,
+                            padding,
+                            strides,
+                            dilation,
+                            data_layout,
+                            kernel_layout,
+                            out_dtype)
+    return (ref_func, qnn_func)
+
+def verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape,
+        kernel_dtype):
+    def get_inputs(data_shape, data_dtype, kernel_shape, kernel_dtype):
+        # Keeping inputs multiple of 4 because of a bug in Average Pool2d
+        # https://discuss.tvm.ai/t/pool2d-gives-bad-output-for-integer-inputs/3377
+        low = -128
+        high = 127
+        if data_dtype == "uint8":
+            low = 0
+            high = 255
+        golden_data = np.random.random_integers(low=low, high=high,
+                size=data_shape).astype(data_dtype)
+        low = -128
+        high = 127
+        if kernel_dtype == "uint8":
+            low = 0
+            high = 255
+        golden_weight = np.random.random_integers(low=low, high=high,
+                size=kernel_shape).astype(kernel_dtype)
+        return (golden_data, golden_weight)
+
+
+    def get_output(func, golden_inputs):
+        with relay.build_config(opt_level=2):
+            golden_data, golden_weight = golden_inputs
+            params = {'kernel': golden_weight}
+            graph, lib, params = relay.build(func, "llvm", params=params)
+            mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0))
+            mod.set_input("data", golden_data)
+            mod.set_input(**params)
+            mod.run()
+            res = mod.get_output(0).asnumpy()
+            return res
+    golden_inputs = get_inputs(data_shape, data_dtype,
+            kernel_shape, kernel_dtype)
+    golden_output = get_output(ref_func, golden_inputs)
+    qnn_output = get_output(qnn_func, golden_inputs)
+    np.testing.assert_equal(qnn_output, golden_output)
+
+def no_zero_point_test():
+    # uint8 input
+    data_shape = (2, 1, 2, 4)
+    data_dtype = 'uint8'
+    kernel_shape = (3, 1, 2, 2)
+    kernel_dtype = 'uint8'
+    ref_func, qnn_func = get_funcs(data_shape=data_shape,
+                                   data_dtype=data_dtype,
+                                   kernel_shape=kernel_shape,
+                                   kernel_dtype=kernel_dtype,
+                                   input_zero_point=0,
+                                   kernel_zero_point=0,
+                                   kernel_size=(2, 2),
+                                   padding=(0, 0),
+                                   strides=(1, 1),
+                                   dilation=(1, 1),
+                                   data_layout="NCHW",
+                                   kernel_layout="OIHW",
+                                   out_dtype="int32")
+    verify(ref_func, qnn_func, data_shape, data_dtype,
+            kernel_shape, kernel_dtype)
+
+    # int8 input
+    data_shape = (2, 1, 2, 4)
+    data_dtype = 'int8'
+    kernel_shape = (3, 1, 2, 2)
+    kernel_dtype = 'int8'
+    ref_func, qnn_func = get_funcs(data_shape=data_shape,
+                                   data_dtype=data_dtype,
+                                   kernel_shape=kernel_shape,
+                                   kernel_dtype=kernel_dtype,
+                                   input_zero_point=0,
+                                   kernel_zero_point=0,
+                                   kernel_size=(2, 2),
+                                   padding=(0, 0),
+                                   strides=(1, 1),
+                                   dilation=(1, 1),
+                                   data_layout="NCHW",
+                                   kernel_layout="OIHW",
+                                   out_dtype="int32")
+    verify(ref_func, qnn_func, data_shape, data_dtype,
+            kernel_shape, kernel_dtype)
+
+def kernel_zero_point_test():
+    # uint8 input
+    data_shape = (2, 4, 2, 4)
+    data_dtype = 'uint8'
+    kernel_shape = (3, 4, 2, 2)
+    kernel_dtype = 'uint8'
+    ref_func, qnn_func = get_funcs(data_shape=data_shape,
+                                   data_dtype=data_dtype,
+                                   kernel_shape=kernel_shape,
+                                   kernel_dtype=kernel_dtype,
+                                   input_zero_point=0,
+                                   kernel_zero_point=1,
+                                   kernel_size=(2, 2),
+                                   padding=(0, 0),
+                                   strides=(1, 1),
+                                   dilation=(1, 1),
+                                   data_layout="NCHW",
+                                   kernel_layout="OIHW",
+                                   out_dtype="int32")
+    verify(ref_func, qnn_func, data_shape, data_dtype,
+            kernel_shape, kernel_dtype)
+
+    # int8 input
+    data_shape = (2, 1, 2, 4)
+    data_dtype = 'int8'
+    kernel_shape = (3, 1, 2, 2)
+    kernel_dtype = 'int8'
+    ref_func, qnn_func = get_funcs(data_shape=data_shape,
+                                   data_dtype=data_dtype,
+                                   kernel_shape=kernel_shape,
+                                   kernel_dtype=kernel_dtype,
+                                   input_zero_point=0,
+                                   kernel_zero_point=5,
+                                   kernel_size=(2, 2),
+                                   padding=(0, 0),
+                                   strides=(1, 1),
+                                   dilation=(1, 1),
+                                   data_layout="NCHW",
+                                   kernel_layout="OIHW",
+                                   out_dtype="int32")
+    verify(ref_func, qnn_func, data_shape, data_dtype,
+            kernel_shape, kernel_dtype)
+
+
+def input_zero_point_test():
+    # uint8 input
+    data_shape = (2, 4, 2, 4)
+    data_dtype = 'uint8'
+    kernel_shape = (3, 4, 2, 2)
+    kernel_dtype = 'uint8'
+    ref_func, qnn_func = get_funcs(data_shape=data_shape,
+                                   data_dtype=data_dtype,
+                                   kernel_shape=kernel_shape,
+                                   kernel_dtype=kernel_dtype,
+                                   input_zero_point=5,
+                                   kernel_zero_point=0,
+                                   kernel_size=(2, 2),
+                                   padding=(0, 0),
+                                   strides=(1, 1),
+                                   dilation=(1, 1),
+                                   data_layout="NCHW",
+                                   kernel_layout="OIHW",
+                                   out_dtype="int32")
+    verify(ref_func, qnn_func, data_shape, data_dtype,
+            kernel_shape, kernel_dtype)
+
+    # int8 input
+    data_shape = (2, 4, 2, 4)
+    data_dtype = 'int8'
+    kernel_shape = (3, 4, 2, 2)
+    kernel_dtype = 'int8'
+    ref_func, qnn_func = get_funcs(data_shape=data_shape,
+                                   data_dtype=data_dtype,
+                                   kernel_shape=kernel_shape,
+                                   kernel_dtype=kernel_dtype,
+                                   input_zero_point=5,
+                                   kernel_zero_point=0,
+                                   kernel_size=(2, 2),
+                                   padding=(0, 0),
+                                   strides=(1, 1),
+                                   dilation=(1, 1),
+                                   data_layout="NCHW",
+                                   kernel_layout="OIHW",
+                                   out_dtype="int32")
+    verify(ref_func, qnn_func, data_shape, data_dtype,
+            kernel_shape, kernel_dtype)
+
+def both_zero_point_test():
+    # uint8 input
+    data_shape = (2, 4, 2, 4)
+    data_dtype = 'uint8'
+    kernel_shape = (3, 4, 2, 2)
+    kernel_dtype = 'uint8'
+    ref_func, qnn_func = get_funcs(data_shape=data_shape,
+                                   data_dtype=data_dtype,
+                                   kernel_shape=kernel_shape,
+                                   kernel_dtype=kernel_dtype,
+                                   input_zero_point=5,
+                                   kernel_zero_point=3,
+                                   kernel_size=(2, 2),
+                                   padding=(0, 0),
+                                   strides=(1, 1),
+                                   dilation=(1, 1),
+                                   data_layout="NCHW",
+                                   kernel_layout="OIHW",
+                                   out_dtype="int32")
+    verify(ref_func, qnn_func, data_shape, data_dtype,
+            kernel_shape, kernel_dtype)
+
+    # int8 input
+    data_shape = (2, 4, 2, 4)
+    data_dtype = 'int8'
+    kernel_shape = (3, 4, 2, 2)
+    kernel_dtype = 'int8'
+    ref_func, qnn_func = get_funcs(data_shape=data_shape,
+                                   data_dtype=data_dtype,
+                                   kernel_shape=kernel_shape,
+                                   kernel_dtype=kernel_dtype,
+                                   input_zero_point=5,
+                                   kernel_zero_point=3,
+                                   kernel_size=(2, 2),
+                                   padding=(0, 0),
+                                   strides=(1, 1),
+                                   dilation=(1, 1),
+                                   data_layout="NCHW",
+                                   kernel_layout="OIHW",
+                                   out_dtype="int32")
+    verify(ref_func, qnn_func, data_shape, data_dtype,
+            kernel_shape, kernel_dtype)
+
+def layout_test():
+    # uint8 input
+    data_shape = (2, 2, 4, 4) # NHWC
+    data_dtype = 'uint8'
+    kernel_shape = (2, 2, 4, 3) # HWIO
+    kernel_dtype = 'uint8'
+    ref_func, qnn_func = get_funcs(data_shape=data_shape,
+                                   data_dtype=data_dtype,
+                                   kernel_shape=kernel_shape,
+                                   kernel_dtype=kernel_dtype,
+                                   input_zero_point=5,
+                                   kernel_zero_point=3,
+                                   kernel_size=(2, 2),
+                                   padding=(0, 0),
+                                   strides=(1, 1),
+                                   dilation=(1, 1),
+                                   data_layout="NHWC",
+                                   kernel_layout="HWIO",
+                                   out_dtype="int32")
+    verify(ref_func, qnn_func, data_shape, data_dtype,
+            kernel_shape, kernel_dtype)
+
+    # NHWC and HWIO layout. Used in depthwise conv.
+    data_shape = (2, 2, 4, 1) # NHWC
+    data_dtype = 'uint8'
+    kernel_shape = (2, 2, 1, 1) # HWOI
+    kernel_dtype = 'uint8'
+    ref_func, qnn_func = get_funcs(data_shape=data_shape,
+                                   data_dtype=data_dtype,
+                                   kernel_shape=kernel_shape,
+                                   kernel_dtype=kernel_dtype,
+                                   input_zero_point=5,
+                                   kernel_zero_point=3,
+                                   kernel_size=(2, 2),
+                                   padding=(0, 0),
+                                   strides=(1, 1),
+                                   dilation=(1, 1),
+                                   data_layout="NHWC",
+                                   kernel_layout="HWOI",
+                                   out_dtype="int32")
+    verify(ref_func, qnn_func, data_shape, data_dtype,
+            kernel_shape, kernel_dtype)
+
+
+
+def padding_test():
+    # uint8 input
+    data_shape = (1, 4, 2, 2)
+    data_dtype = 'uint8'
+    kernel_shape = (3, 4, 2, 2)
+    kernel_dtype = 'uint8'
+    ref_func, qnn_func = get_funcs(data_shape=data_shape,
+                                   data_dtype=data_dtype,
+                                   kernel_shape=kernel_shape,
+                                   kernel_dtype=kernel_dtype,
+                                   input_zero_point=8,
+                                   kernel_zero_point=5,
+                                   kernel_size=(2, 2),
+                                   padding=(1, 1),
+                                   strides=(1, 1),
+                                   dilation=(1, 1),
+                                   data_layout="NCHW",
+                                   kernel_layout="OIHW",
+                                   out_dtype="int32")
+    verify(ref_func, qnn_func, data_shape, data_dtype,
+            kernel_shape, kernel_dtype)
+
+    # Try different layout
+    data_shape = (2, 2, 4, 4) # NHWC
+    data_dtype = 'uint8'
+    kernel_shape = (2, 2, 4, 3) # HWIO
+    kernel_dtype = 'uint8'
+    ref_func, qnn_func = get_funcs(data_shape=data_shape,
+                                   data_dtype=data_dtype,
+                                   kernel_shape=kernel_shape,
+                                   kernel_dtype=kernel_dtype,
+                                   input_zero_point=8,
+                                   kernel_zero_point=3,
+                                   kernel_size=(2, 2),
+                                   padding=(1, 1),
+                                   strides=(1, 1),
+                                   dilation=(1, 1),
+                                   data_layout="NHWC",
+                                   kernel_layout="HWIO",
+                                   out_dtype="int32")
+    verify(ref_func, qnn_func, data_shape, data_dtype,
+            kernel_shape, kernel_dtype)
+
+def dilation_test():
+    # uint8 input
+    data_shape = (2, 4, 4, 4)
+    data_dtype = 'uint8'
+    kernel_shape = (3, 4, 2, 2)
+    kernel_dtype = 'uint8'
+    ref_func, qnn_func = get_funcs(data_shape=data_shape,
+                                   data_dtype=data_dtype,
+                                   kernel_shape=kernel_shape,
+                                   kernel_dtype=kernel_dtype,
+                                   input_zero_point=5,
+                                   kernel_zero_point=3,
+                                   kernel_size=(2, 2),
+                                   padding=(0, 0),
+                                   strides=(1, 1),
+                                   dilation=(2, 2),
+                                   data_layout="NCHW",
+                                   kernel_layout="OIHW",
+                                   out_dtype="int32")
+    verify(ref_func, qnn_func, data_shape, data_dtype,
+            kernel_shape, kernel_dtype)
+
+
+def const_folding_test():
+    data_shape = (2, 4, 2, 4)
+    data_dtype = 'uint8'
+    kernel_shape = (3, 4, 2, 2)
+    kernel_dtype = 'uint8'
+
+    golden_weight = np.random.random_integers(low=0, high=255,
+            size=kernel_shape).astype(kernel_dtype)
+    data = relay.var("data", shape=data_shape,
+            dtype=data_dtype)
+    kernel = relay.const(golden_weight)
+    qnn_func = get_qnn_func(data,
+                            kernel,
+                            input_zero_point=8,
+                            kernel_zero_point=3,
+                            kernel_size=(2, 2),
+                            padding=(0, 0),
+                            strides=(1, 1),
+                            dilation=(1, 1),
+                            data_layout="NCHW",
+                            kernel_layout="OIHW",
+                            out_dtype="int32")
+    folded_mod = transform.FoldConstant()(qnn_func)
+    folded_func = folded_mod["main"]
+    assert "reshape" not in folded_func.astext()
+
+def kernel_size_1x1_test():
+    # uint8 input
+    data_shape = (2, 4, 2, 4)
+    data_dtype = 'uint8'
+    kernel_shape = (3, 4, 1, 1)
+    kernel_dtype = 'uint8'
+    ref_func, qnn_func = get_funcs(data_shape=data_shape,
+                                   data_dtype=data_dtype,
+                                   kernel_shape=kernel_shape,
+                                   kernel_dtype=kernel_dtype,
+                                   input_zero_point=5,
+                                   kernel_zero_point=3,
+                                   kernel_size=(1, 1),
+                                   padding=(0, 0),
+                                   strides=(1, 1),
+                                   dilation=(1, 1),
+                                   data_layout="NCHW",
+                                   kernel_layout="OIHW",
+                                   out_dtype="int32")
+    assert 'avg_pool2d' not in qnn_func.astext()
+    verify(ref_func, qnn_func, data_shape, data_dtype,
+            kernel_shape, kernel_dtype)
+
+def tflite_large_irregular_test():
+    # uint8 input
+    data_shape = (1, 1024, 1, 1)
+    data_dtype = 'uint8'
+    kernel_shape = (1001, 1024, 1, 1)
+    kernel_dtype = 'uint8'
+    ref_func, qnn_func = get_funcs(data_shape=data_shape,
+                                   data_dtype=data_dtype,
+                                   kernel_shape=kernel_shape,
+                                   kernel_dtype=kernel_dtype,
+                                   input_zero_point=127,
+                                   kernel_zero_point=127,
+                                   kernel_size=(1, 1),
+                                   padding=(0, 0),
+                                   strides=(1, 1),
+                                   dilation=(1, 1),
+                                   data_layout="NCHW",
+                                   kernel_layout="OIHW",
+                                   out_dtype="int32")
+    golden_data = np.full(data_shape, 127).astype('uint8')
+    golden_weight = np.full(kernel_shape, 127).astype('uint8')
+
+    with relay.build_config(opt_level=2):
+        params = {'kernel': golden_weight}
+        graph, lib, params = relay.build(qnn_func, "llvm", params=params)
+        mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0))
+        mod.set_input("data", golden_data)
+        mod.set_input(**params)
+        mod.run()
+        qnn_output = mod.get_output(0).asnumpy()
+    golden_output = np.full((1, 1001, 1, 1), 0).astype('uint8')
+    np.testing.assert_equal(qnn_output, golden_output)
+
+def tflite_output_multiplier_greater_than_one():
+    # uint8 input
+    data_shape = (2, 1, 2, 4)
+    data_dtype = 'uint8'
+    kernel_shape = (3, 1, 2, 2)
+    kernel_dtype = 'uint8'
+    ref_func, qnn_func = get_funcs(data_shape=data_shape,
+                                   data_dtype=data_dtype,
+                                   kernel_shape=kernel_shape,
+                                   kernel_dtype=kernel_dtype,
+                                   input_zero_point=128,
+                                   kernel_zero_point=128,
+                                   kernel_size=(2, 2),
+                                   padding=(0, 0),
+                                   strides=(2, 2),
+                                   dilation=(1, 1),
+                                   data_layout="NCHW",
+                                   kernel_layout="OIHW",
+                                   out_dtype="int32")
+    golden_data = 128 + np.array((1, 1, 1, 1,
+                                  2, 2, 2, 2,
+                                  1, 2, 3, 4,
+                                  1, 2, 3, 4)).reshape(data_shape).astype('uint8')
+    golden_weight = 128 + np.array((1, 2, 3, 4,
+                                    -1, 1, -1, 1,
+                                    -1, -1, 1, 1)).reshape(kernel_shape)
+    golden_weight = golden_weight.astype('uint8')
+
+    with relay.build_config(opt_level=2):
+        params = {'kernel': golden_weight}
+        graph, lib, params = relay.build(qnn_func, "llvm", params=params)
+        mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0))
+        mod.set_input("data", golden_data)
+        mod.set_input(**params)
+        mod.run()
+        qnn_output = mod.get_output(0).asnumpy()
+    golden_output = np.array((17, 17,
+                              0, 0,
+                              2, 2,
+                              16, 36,
+                              2, 2,
+                              0, 0)).reshape(2, 3, 1, 2)
+    np.testing.assert_equal(qnn_output, golden_output)
+
+def tflite_anistropic_strides():
+    # uint8 input
+    data_shape = (1, 1, 3, 6)
+    data_dtype = 'uint8'
+    kernel_shape = (1, 1, 2, 2)
+    kernel_dtype = 'uint8'
+    ref_func, qnn_func = get_funcs(data_shape=data_shape,
+                                   data_dtype=data_dtype,
+                                   kernel_shape=kernel_shape,
+                                   kernel_dtype=kernel_dtype,
+                                   input_zero_point=127,
+                                   kernel_zero_point=127,
+                                   kernel_size=(2, 2),
+                                   padding=(0, 0),
+                                   strides=(1, 3),
+                                   dilation=(1, 1),
+                                   data_layout="NCHW",
+                                   kernel_layout="OIHW",
+                                   out_dtype="int32")
+    golden_data = np.array((133, 131, 129, 125, 123, 121,
+                            135, 133, 131, 123, 121, 119,
+                            137, 135, 133, 121, 119, 117)).reshape(data_shape)
+    golden_data = golden_data.astype('uint8')
+    golden_weight = np.array((129, 131, 133, 135)).reshape(kernel_shape)
+    golden_weight = golden_weight.astype('uint8')
+
+    with relay.build_config(opt_level=2):
+        params = {'kernel': golden_weight}
+        graph, lib, params = relay.build(qnn_func, "llvm", params=params)
+        mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0))
+        mod.set_input("data", golden_data)
+        mod.set_input(**params)
+        mod.run()
+        qnn_output = mod.get_output(0).asnumpy()
+    golden_output = np.array((124, -92, 164, -132)).reshape(1, 1, 2, 2)
+    np.testing.assert_equal(qnn_output, golden_output)
+
+if __name__ == "__main__":
+    no_zero_point_test()
+    input_zero_point_test()
+    kernel_zero_point_test()
+    both_zero_point_test()
+    layout_test()
+    padding_test()
+    dilation_test()
+    const_folding_test()
+    kernel_size_1x1_test()
+    tflite_large_irregular_test()
+    tflite_output_multiplier_greater_than_one()
+    tflite_anistropic_strides()