[BYOC][ETHOSN] Add support for quantized convolution (#6335)
authormbaret <55580676+mbaret@users.noreply.github.com>
Thu, 27 Aug 2020 21:03:45 +0000 (22:03 +0100)
committerGitHub <noreply@github.com>
Thu, 27 Aug 2020 21:03:45 +0000 (06:03 +0900)
* [BYOC][ETHOSN] Add support for quantized convolution

This PR adds support for quantized convolution. This
includes mapping it via a composite function and all
the necessary methods to convert from Relay to the
APIs in Support Library.

Co-authored-by: Leo Blonk <Leo.Blonk@arm.com>
Co-authored-by: Tristan O'Connor <tristan.oconnor@arm.com>
* Fix padding change

Change-Id: I0794b0ac6190478e2d1b858ad0dd90f37fc0207b

* Add docs to Tvm2Npu methods

Change-Id: Iab865619b449a3d0dd6bb0dbdcb198acd529fc4e

* Remove generate tests

Change-Id: I51f90499f7ce82a1ce49f0731d3d50627e1d0225

Co-authored-by: Leo Blonk <Leo.Blonk@arm.com>
Co-authored-by: Tristan O'Connor <tristan.oconnor@arm.com>
python/tvm/relay/op/contrib/ethosn.py
src/relay/backend/contrib/ethosn/codegen.cc
src/relay/backend/contrib/ethosn/codegen_ethosn.h
src/relay/backend/contrib/ethosn/ethosn_api.cc
src/relay/backend/contrib/ethosn/ethosn_api.h
tests/python/contrib/test_ethosn/infrastructure.py
tests/python/contrib/test_ethosn/test_conv2d.py [new file with mode: 0644]

index de70297..a93b0e5 100644 (file)
@@ -18,7 +18,9 @@
 """Arm(R) Ethos(TM) -N NPU supported operators."""
 from enum import Enum
 import tvm.ir
+from ...dataflow_pattern import wildcard, is_op, is_constant
 from ... import qnn as _qnn
+from .register import register_pattern_table
 from . import _ethosn as support
 
 
@@ -40,6 +42,30 @@ def ethosn_available():
     return Available.SW_AND_HW if hw else Available.SW_ONLY
 
 
+@register_pattern_table("ethos-n")
+def pattern_table():
+    """Get the Ethos-N compiler pattern table."""
+    def qnn_conv_pattern():
+        pattern = is_op('nn.pad')(wildcard()) | wildcard()
+        pattern = is_op('qnn.conv2d')(
+            pattern, is_constant(), is_constant(), is_constant(), is_constant(), is_constant())
+        pattern = is_op('nn.bias_add')(pattern, is_constant())
+        pattern = is_op('qnn.requantize')(
+            pattern, is_constant(), is_constant(), is_constant(), is_constant())
+        return pattern
+
+    def check_conv2d(extract):
+        """Check if a conv2d is supported by Ethos-N."""
+        if not ethosn_available():
+            return False
+
+        return support.conv2d(extract)
+
+    return [
+        ("ethos-n.qnn_conv2d", qnn_conv_pattern(), check_conv2d),
+    ]
+
+
 @tvm.ir.register_op_attr("qnn.concatenate", "target.ethos-n")
 def qnn_concatenate(attrs, args):
     """Check if a concatenate is supported by Ethos-N."""
index f66eb94..58cd5bf 100644 (file)
@@ -50,6 +50,16 @@ bool IsEthosnOp(const Call& call, const std::string& op_name) {
   }
 }
 
+bool IsEthosnFunc(const Call& call, const std::string& op_name) {
+  if (call->op->IsInstance<FunctionNode>()) {
+    Function func = Downcast<Function>(call->op);
+    CHECK(func.defined());
+    auto name_node = func->GetAttr<String>(attr::kComposite);
+    return name_node.value() == op_name;
+  }
+  return false;
+}
+
 std::map<Expr, std::vector<sl::TensorInfo>> InferTensorsVisitor::Infer(const Expr& expr) {
   tensor_table_.clear();
   CHECK(expr->checked_type().defined());
@@ -69,7 +79,11 @@ void InferTensorsVisitor::InferCall(const CallNode* cn) {
   EthosnError err;
   Call call = GetRef<Call>(cn);
   // Determine call -> NPU mapping
-  if (IsEthosnOp(call, "qnn.concatenate")) {
+  if (IsEthosnFunc(call, "ethos-n.qnn_conv2d")) {
+    ConvolutionParams params;
+    err += EthosnAPI::QnnConv2d(cn->op.as<FunctionNode>()->body, &params);
+    tensor_table_[cn->args[0]] = {params.activation_info};
+  } else if (IsEthosnOp(call, "qnn.concatenate")) {
     ConcatenateParams params;
     err = EthosnAPI::Concatenate(call, &params);
     tensor_table_[cn->args[0]] = params.input_infos;
@@ -181,7 +195,10 @@ sl::TensorsAndId ConstructNetworkVisitor::HandleCall(const CallNode* cn) {
   sl::TensorAndId<sl::Operand> tensor;
   sl::TensorsAndId tensors;
   // Determine call -> NPU mapping
-  if (IsEthosnOp(call, "qnn.concatenate")) {
+  if (IsEthosnFunc(call, "ethos-n.qnn_conv2d")) {
+    if ((err = MakeConvolutionLayer(call, &tensor))) ReportFatalError(call, err);
+    return MakeOps(tensor);
+  } else if (IsEthosnOp(call, "qnn.concatenate")) {
     if ((err = MakeConcatenateLayer(call, &tensor))) ReportFatalError(call, err);
     return MakeOps(tensor);
   } else if (IsEthosnOp(call, "split")) {
@@ -227,6 +244,28 @@ void ConstructNetworkVisitor::VisitLeaf(const Expr& expr) {
   if (!expr->IsInstance<FunctionNode>()) MixedModeVisitor::VisitLeaf(expr);
 }
 
+EthosnError ConstructNetworkVisitor::MakeConvolutionLayer(const Call& call,
+                                                          sl::TensorAndId<sl::Operand>* out) {
+  ConvolutionParams params;
+  if (auto err = EthosnAPI::QnnConv2d(call->op.as<FunctionNode>()->body, &params)) {
+    return err;
+  }
+
+  auto activation = operand_table_[call->args[0]][0];
+  auto weights = AddConstant(network_, params.weights_info, params.raw_weights).tensor;
+  auto bias = AddConstant(network_, params.bias_info, params.raw_bias).tensor;
+  try {
+    if (params.is_depthwise) {
+      *out = AddDepthwiseConvolution(network_, *activation, *bias, *weights, params.conv_info);
+    } else {
+      *out = AddConvolution(network_, *activation, *bias, *weights, params.conv_info);
+    }
+  } catch (const sl::NotSupportedException& e) {
+    return EthosnError(e.what());
+  }
+  return EthosnError();
+}
+
 EthosnError ConstructNetworkVisitor::MakeConcatenateLayer(const Call& call,
                                                           sl::TensorAndId<sl::Operand>* out) {
   ConcatenateParams params;
index 714a22d..7d1fe9c 100644 (file)
@@ -197,6 +197,7 @@ class ConstructNetworkVisitor : public MixedModeVisitor, private ErrorReportingP
   void VisitLeaf(const Expr& expr) final;
 
   // Make a support library operand from a Call
+  EthosnError MakeConvolutionLayer(const Call& call, sl::TensorAndId<sl::Operand>* out);
   EthosnError MakeConcatenateLayer(const Call& call, sl::TensorAndId<sl::Operand>* out);
   EthosnError MakeSplitLayer(const Call& call, sl::TensorsAndId* outs);
 
index d92e35a..b7cac65 100644 (file)
@@ -40,6 +40,105 @@ namespace relay {
 namespace contrib {
 namespace ethosn {
 
+EthosnError EthosnAPI::QnnConv2d(const Expr& expr, ConvolutionParams* params) {
+  Call requantize = Downcast<Call>(expr);
+  Call bias_add = Downcast<Call>(requantize->args[0]);
+  Call conv = Downcast<Call>(bias_add->args[0]);
+  Call pad;
+  if (conv->args[0]->IsInstance<CallNode>() &&
+      Downcast<Call>(conv->args[0])->op == Op::Get("nn.pad"))
+    pad = Downcast<Call>(conv->args[0]);
+  const auto& conv_attr = conv->attrs.as<Conv2DAttrs>();
+
+  // Extract the quantization params from the arguments
+  int input_zero_point;
+  int kernel_zero_point;
+  int output_zero_point;
+  float input_scale;
+  float kernel_scale;
+  float output_scale;
+  EthosnError err = AsConstant<int>(conv->args[2], &input_zero_point);
+  err += AsConstant<int>(conv->args[3], &kernel_zero_point);
+  err += AsConstant<int>(requantize->args[4], &output_zero_point);
+  err += AsConstant<float>(conv->args[4], &input_scale);
+  err += AsConstant<float>(conv->args[5], &kernel_scale);
+  err += AsConstant<float>(requantize->args[3], &output_scale);
+
+  // Convert quantization params
+  sl::QuantizationInfo data_q_info;
+  sl::QuantizationInfo weights_q_info;
+  sl::QuantizationInfo bias_q_info;
+  sl::QuantizationInfo output_q_info;
+  err += Tvm2Npu(input_zero_point, input_scale, &data_q_info);
+  err += Tvm2Npu(kernel_zero_point, kernel_scale, &weights_q_info);
+  err += Tvm2Npu(0, data_q_info.m_Scale * weights_q_info.m_Scale, &bias_q_info);
+  err += Tvm2Npu(output_zero_point, output_scale, &output_q_info);
+
+  // Convert convolution attributes
+  sl::Padding padding;
+  if (pad.defined()) {
+    Tvm2Npu(conv_attr->padding, &padding);
+    // Don't support both standalone operator padding and attribute defined padding
+    if (padding != sl::Padding({0, 0, 0, 0})) {
+      err += EthosnError(
+          ErrStrm() << "both op and attr padding exist, must be either op/attr only or no padding");
+    }
+    err += Tvm2Npu(pad->attrs.as<PadAttrs>()->pad_width, &padding);
+  } else {
+    err += Tvm2Npu(conv_attr->padding, &padding);
+  }
+  sl::Stride stride;
+  err += Tvm2Npu(conv_attr->strides, &stride);
+  // Dilation is not supported
+  std::array<uint32_t, 4> dilation = {1, 1, 1, 1};
+  AsArray(conv_attr->dilation, &dilation);
+  if (conv_attr->dilation.size() != 2 || dilation[0] != 1 || dilation[1] != 1) {
+    err +=
+        EthosnError(ErrStrm() << "dilation=" << conv_attr->dilation << ", dilation must = [1, 1]");
+  }
+  // Create convolution info
+  params->conv_info = sl::ConvolutionInfo(padding, stride, output_q_info);
+
+  // Create data info
+  const TensorTypeNode* data_dtype;
+  if (pad.defined()) {
+    data_dtype = pad->args[0]->checked_type().as<TensorTypeNode>();
+  } else {
+    data_dtype = conv->args[0]->checked_type().as<TensorTypeNode>();
+  }
+  sl::TensorShape activation_tensor_shape;
+  sl::DataType activation_data_type;
+  err += Tvm2Npu(data_dtype->shape, &activation_tensor_shape);
+  err += Tvm2Npu(data_dtype->dtype, &activation_data_type);
+  params->activation_info = sl::TensorInfo(activation_tensor_shape, activation_data_type,
+                                           sl::DataFormat::NHWC, data_q_info);
+
+  // Create weights info
+  params->is_depthwise = conv_attr->channels.defined() &&
+                         tvm::tir::ExprDeepEqual()(conv_attr->channels, conv_attr->groups) &&
+                         conv_attr->groups != 1;
+
+  const auto* weights_dtype = conv->args[1]->checked_type().as<TensorTypeNode>();
+  sl::TensorShape weights_tensor_shape;
+  sl::DataType weights_data_type;
+  sl::DataFormat weights_data_format;
+  // Ignore the error here because weights don't have a batch axis
+  Tvm2Npu(weights_dtype->shape, &weights_tensor_shape);
+  err += Tvm2Npu(weights_dtype->dtype, &weights_data_type);
+  err += Tvm2Npu(params->is_depthwise ? "HWIM" : "HWIO", &weights_data_format);
+  params->weights_info =
+      sl::TensorInfo(weights_tensor_shape, weights_data_type, weights_data_format, weights_q_info);
+  params->raw_weights = conv->args[1].as<ConstantNode>()->data->data;
+
+  // Create bias info
+  params->bias_info = sl::TensorInfo(
+      {1, 1, 1, params->is_depthwise ? weights_tensor_shape[2] : weights_tensor_shape[3]},
+      sl::DataType::INT32_QUANTIZED, sl::DataFormat::NHWC, bias_q_info);
+  params->raw_bias = bias_add->args[1].as<ConstantNode>()->data->data;
+
+  return err;
+}
+
 EthosnError EthosnAPI::Concatenate(const Expr& expr, ConcatenateParams* params) {
   Call call = Downcast<Call>(expr);
   const auto& attrs = call->attrs.as<ConcatenateAttrs>();
@@ -107,6 +206,60 @@ EthosnError EthosnAPI::Split(const Expr& expr, SplitParams* params) {
   return err;
 }
 
+EthosnError EthosnAPI::Tvm2Npu(const Array<IndexExpr>& padding, sl::Padding* npu_padding) {
+  std::array<uint32_t, 4> dim;
+  if (EthosnError err = AsArray<IndexExpr, uint32_t>(padding, &dim)) {
+    return err;
+  }
+  switch (padding.size()) {
+    case 1:
+      *npu_padding = sl::Padding(dim[0], dim[0], dim[0], dim[0]);
+      break;
+    case 2:
+      // Height, width -> top, bottom, left, right
+      *npu_padding = sl::Padding(dim[0], dim[0], dim[1], dim[1]);
+      break;
+    case 4:
+      // Top, left, bottom, right -> top, bottom, left, right
+      *npu_padding = sl::Padding(dim[0], dim[2], dim[1], dim[3]);
+      break;
+    default:
+      return EthosnError(ErrStrm() << "padding tuple size=" << padding.size()
+                                   << ", padding tuple size must be {1, 2, 4}");
+  }
+  return EthosnError();
+}
+
+EthosnError EthosnAPI::Tvm2Npu(const Array<IndexExpr>& strides, sl::Stride* npu_stride) {
+  if (strides.size() != 2) {
+    return EthosnError(ErrStrm() << "stride size=" << strides.size() << ", stride size must = 2");
+  }
+  std::array<uint32_t, 4> dim;
+  if (EthosnError err = AsArray<IndexExpr, uint32_t>(strides, &dim)) {
+    return err;
+  }
+  *npu_stride = sl::Stride(dim[1], dim[0]);
+  return EthosnError();
+}
+
+EthosnError EthosnAPI::Tvm2Npu(const std::string& dformat, sl::DataFormat* data_format) {
+  if (dformat == "NCHW") {
+    *data_format = sl::DataFormat::NCHW;
+    return EthosnError();
+  } else if (dformat == "NHWC") {
+    *data_format = sl::DataFormat::NHWC;
+    return EthosnError();
+  } else if (dformat == "HWIO") {
+    *data_format = sl::DataFormat::HWIO;
+    return EthosnError();
+  } else if (dformat == "HWIM") {
+    *data_format = sl::DataFormat::HWIM;
+    return EthosnError();
+  }
+  return EthosnError(ErrStrm() << "format=" << dformat
+                               << ", format must be {NCHW, NHWC, HWIO, HWIM}");
+}
+
 EthosnError EthosnAPI::Tvm2Npu(const Array<IndexExpr>& shape, sl::TensorShape* npu_shape) {
   EthosnError err = AsArray<IndexExpr, uint32_t>(shape, npu_shape);
   if (npu_shape->front() != 1) {
@@ -128,6 +281,29 @@ EthosnError EthosnAPI::Tvm2Npu(const tvm::DataType& dtype, sl::DataType* data_ty
   return EthosnError(ErrStrm() << "dtype=\'" << dtype << "\', dtype must be either uint8 or int32");
 }
 
+EthosnError EthosnAPI::Tvm2Npu(int32_t zero_point, float scale, sl::QuantizationInfo* npu_qinfo) {
+  *npu_qinfo = sl::QuantizationInfo(zero_point, scale);
+  return EthosnError();
+}
+
+EthosnError EthosnAPI::Tvm2Npu(const Array<Array<Integer>>& padding, sl::Padding* npu_padding) {
+  if (padding.size() != 4) {
+    return EthosnError(ErrStrm() << "padding tuple size=" << padding.size()
+                                 << ", padding tuple size must = 4");
+  }
+  Array<IndexExpr> reduced_padding;
+  reduced_padding.push_back(padding[1][0]);
+  reduced_padding.push_back(padding[1][1]);
+  reduced_padding.push_back(padding[2][0]);
+  reduced_padding.push_back(padding[2][1]);
+  std::array<uint32_t, 4> dim;
+  if (EthosnError err = AsArray<IndexExpr, uint32_t>(reduced_padding, &dim)) {
+    return err;
+  }
+  *npu_padding = sl::Padding(dim[0], dim[1], dim[2], dim[3]);
+  return EthosnError();
+}
+
 // Convert an array of IntImmNodes into ValueT
 // IndexT type of Array indexing variable
 // ValueT type of resulting value
@@ -158,6 +334,20 @@ EthosnError EthosnAPI::AsConstant(const Expr& expr, T* out) {
   return EthosnError();
 }
 
+TVM_REGISTER_GLOBAL("relay.ethos-n.support.conv2d")
+    .set_body([](tvm::TVMArgs args, tvm::TVMRetValue* rv) {
+      Call call = args[0];
+      ConvolutionParams params;
+      auto err = EthosnAPI::QnnConv2d(call, &params);
+      if (params.is_depthwise) {
+        *rv = !err && sl::IsDepthwiseConvolutionSupported(params.bias_info, params.weights_info,
+                                                          params.conv_info, params.activation_info);
+      } else {
+        *rv = !err && sl::IsConvolutionSupported(params.bias_info, params.weights_info,
+                                                 params.conv_info, params.activation_info);
+      }
+    });
+
 TVM_REGISTER_GLOBAL("relay.ethos-n.support.concatenate")
     .set_body([](tvm::TVMArgs args, tvm::TVMRetValue* rv) {
       Call call = args[0];
index 34af7ce..20fe8be 100644 (file)
@@ -44,6 +44,16 @@ namespace ethosn {
 
 namespace sl = ::ethosn::support_library;
 
+struct ConvolutionParams {
+  sl::ConvolutionInfo conv_info;
+  sl::TensorInfo activation_info;
+  sl::TensorInfo weights_info;
+  sl::TensorInfo bias_info;
+  void* raw_weights = nullptr;
+  void* raw_bias = nullptr;
+  bool is_depthwise = false;
+};
+
 struct ConcatenateParams {
   sl::QuantizationInfo qInfo;
   sl::ConcatenationInfo concat_info = sl::ConcatenationInfo(1, qInfo);
@@ -115,6 +125,8 @@ class EthosnError {
  */
 class EthosnAPI {
  public:
+  /*! \brief Extract the Support Library convolution params from an ethos-n.qnn_conv2d func */
+  static EthosnError QnnConv2d(const Expr& expr, ConvolutionParams* params);
   /*! \brief Extract the Support Library concatenate params from a Relay qnn.concatenate call */
   static EthosnError Concatenate(const Expr& expr, ConcatenateParams* params);
   /*! \brief Extract the Support Library split params from a Relay split call */
@@ -125,6 +137,16 @@ class EthosnAPI {
   static EthosnError Tvm2Npu(const Array<IndexExpr>& shape, sl::TensorShape* npu_shape);
   /*! \brief Convert a TVM data type to a SL data type */
   static EthosnError Tvm2Npu(const tvm::DataType& dtype, sl::DataType* data_type);
+  /*! \brief Convert TVM 1D padding to SL padding */
+  static EthosnError Tvm2Npu(const Array<IndexExpr>& padding, sl::Padding* npu_padding);
+  /*! \brief Convert TVM 1D striding to SL striding */
+  static EthosnError Tvm2Npu(const Array<IndexExpr>& strides, sl::Stride* npu_stride);
+  /*! \brief Convert TVM data format to SL data format */
+  static EthosnError Tvm2Npu(const std::string& dformat, sl::DataFormat* data_format);
+  /*! \brief Convert TVM quantization info to SL quantization info */
+  static EthosnError Tvm2Npu(int32_t zero_point, float scale, sl::QuantizationInfo* npu_qinfo);
+  /*! \brief Convert TVM 2D padding to SL padding */
+  static EthosnError Tvm2Npu(const Array<Array<Integer>>& padding, sl::Padding* npu_padding);
 
   // Convert an array of IntImmNodes into ValueT
   // IndexT type of Array indexing variable
index c627833..b43d273 100644 (file)
@@ -94,6 +94,8 @@ def build(mod, params, npu=True, expected_host_ops=0, npu_partitions=1):
                 f = relay.build_module.bind_params_by_name(mod["main"], params)
                 mod = tvm.IRModule()
                 mod["main"] = f
+                pattern = get_pattern_table("ethos-n")
+                mod = relay.transform.MergeComposite(pattern)(mod)
                 mod = relay.transform.AnnotateTarget("ethos-n")(mod)
                 mod = relay.transform.MergeCompilerRegions()(mod)
                 mod = relay.transform.PartitionGraph()(mod)
diff --git a/tests/python/contrib/test_ethosn/test_conv2d.py b/tests/python/contrib/test_ethosn/test_conv2d.py
new file mode 100644 (file)
index 0000000..52e3de9
--- /dev/null
@@ -0,0 +1,204 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Ethos-N integration conv2d tests"""
+
+import numpy as np
+import math
+import tvm
+from tvm import relay
+from tvm.relay.op.contrib.ethosn import ethosn_available
+from . import infrastructure as tei
+
+
+def _get_same_padding(data, kernel, dilation, stride):
+    dilated_kernel_h = dilation[0] * (kernel[0] - 1) + 1
+    dilated_kernel_w = dilation[1] * (kernel[1] - 1) + 1
+    out = int(math.ceil(float(data[0]) / float(stride[0])))
+    pad = max(0, (out - 1) * stride[0] + dilated_kernel_h - data[0])
+    pad_top = pad // 2
+    pad_bottom = pad - pad_top
+
+    out = int(math.ceil(float(data[1]) / float(stride[1])))
+    pad = max(0, (out - 1) * stride[1] + dilated_kernel_w - data[1])
+    pad_left = pad // 2
+    pad_right = pad - pad_left
+    return [pad_top, pad_left, pad_bottom, pad_right]
+
+
+def _get_model(shape, kernel_h, kernel_w,
+               input_zp, input_sc,
+               kernel_zp, kernel_sc,
+               output_zp, output_sc,
+               pad, strides, dilation,
+               groups, dtype,
+               out_channels, weight_format):
+    """Return a model and any parameters it may have"""
+    a = relay.var("a", shape=shape, dtype=dtype)
+    if pad == "op" or pad == "both":
+        p = _get_same_padding((shape[1], shape[2]), (kernel_h, kernel_w), dilation, strides)
+        a = relay.nn.pad(a,
+                         pad_width=[(0, 0), (p[0], p[2]), (p[1], p[3]), (0, 0)],
+                         pad_value=input_zp, pad_mode="constant")
+        shape = (shape[0], shape[1] + p[0] + p[2], shape[2] + p[1] + p[3], shape[3])
+
+    p = _get_same_padding((shape[1], shape[2]), (kernel_h, kernel_w), dilation, strides)
+    if weight_format == "HWIO":
+        weight_shape = (kernel_h, kernel_w, shape[3] // groups, out_channels)
+    else:
+        weight_shape = (kernel_h, kernel_w, out_channels, 1)
+    w = tvm.nd.array(np.random.randint(np.iinfo(dtype).min, high=np.iinfo(dtype).max, size=weight_shape, dtype=dtype))
+    weights = relay.const(w, dtype)
+    conv = relay.qnn.op.conv2d(
+        a,
+        weights,
+        input_zero_point=relay.const(input_zp, "int32"),
+        kernel_zero_point=relay.const(kernel_zp, "int32"),
+        input_scale=relay.const(input_sc, "float32"),
+        kernel_scale=relay.const(kernel_sc, "float32"),
+        kernel_size=(kernel_h, kernel_w),
+        data_layout="NHWC",
+        kernel_layout=weight_format,
+        dilation=dilation,
+        strides=strides,
+        groups=groups,
+        channels=out_channels,
+        padding=p if pad == "attr" or pad == "both" else (0, 0, 0, 0),
+        out_dtype="int32",
+    )
+    b = tvm.nd.array(np.random.randint(0, high=10, size=(out_channels,), dtype="int32"))
+    biasc = relay.const(b, "int32")
+    bias = relay.nn.bias_add(conv, biasc, axis=3)
+    req = relay.qnn.op.requantize(
+        bias,
+        relay.const(input_sc * kernel_sc, 'float32'),  # input zero scale
+        relay.const(0, 'int32'),                       # input zero point
+        relay.const(output_sc, 'float32'),             # output zero scale
+        relay.const(output_zp, 'int32'),               # output zero point
+        out_dtype="uint8"
+    )
+    params = {"w": w,
+              "b": b}
+    return req, params
+
+
+def _get_conv2d_qnn_params(input_zp, input_sc, kernel_zp, kernel_sc, kernel_h, kernel_w, channels):
+    input_max = input_sc * (255 - input_zp)
+    input_min = - input_sc * input_zp
+    kernel_max = kernel_sc * (255 - kernel_zp)
+    kernel_min = - kernel_sc * kernel_zp
+    output_limits = [kernel_max * kernel_h * kernel_w * channels * input_max,
+                     kernel_min * kernel_h * kernel_w * channels * input_max,
+                     kernel_min * kernel_h * kernel_w * channels * input_min,
+                     kernel_max * kernel_h * kernel_w * channels * input_min]
+    output_max = max(output_limits)
+    output_min = min(output_limits)
+    output_sc = (output_max - output_min) / 255
+    output_zp = - int(output_min / output_sc)
+    return output_zp, output_sc
+
+
+def test_conv2d():
+    if not ethosn_available():
+        return
+
+    trials = [
+        [(1, 17, 20, 26), 4, 3, 1, 'attr', (2, 2), (1, 1)],
+        [(1, 30, 27, 30), 5, 5, 3, 'none', (1, 1), (1, 1)],
+        [(1, 14, 28, 11), 6, 2, 2, 'op', (2, 2), (1, 1)],
+        [(1, 9, 20, 30), 7, 1, 5, 'none', (1, 1), (1, 1)],
+        [(1, 21, 21, 22), 8, 5, 1, 'attr', (2, 2), (1, 1)],
+        [(1, 21, 25, 29), 9, 2, 5, 'op', (1, 1), (1, 1)],
+        [(1, 31, 28, 15), 10, 1, 2, 'attr', (2, 2), (1, 1)],
+        [(1, 21, 21, 8), 11, 3, 3, 'none', (1, 1), (1, 1)],
+        [(1, 5, 11, 6), 12, 5, 2, 'op', (2, 2), (1, 1)],
+        [(1, 12, 7, 18), 13, 1, 3, 'op', (1, 1), (1, 1)],
+        [(1, 24, 6, 26), 14, 3, 5, 'none', (2, 2), (1, 1)],
+        [(1, 19, 24, 16), 15, 2, 1, 'attr', (1, 1), (1, 1)],
+    ]
+
+    np.random.seed(0)
+    for depthwise in [False, True]:
+        for shape, out_channels, kernel_h, kernel_w, pad, stride, dilation in trials:
+            if depthwise:
+                out_channels = shape[3]
+                groups = out_channels
+                kernel_w = kernel_h
+                weight_format = "HWOI"
+                stride = (1, 1) if kernel_w == 1 else (2, 2)
+            else:
+                groups = 1
+                weight_format = "HWIO"
+
+            outputs = []
+            inputs = {
+                "a": tvm.nd.array(np.random.randint(0, high=255, size=shape, dtype="uint8")),
+            }
+            input_zp = np.random.randint(0, 255)
+            input_sc = np.random.random() * 2
+            kernel_zp = np.random.randint(0, 255)
+            kernel_sc = np.random.random() * 2
+            output_zp, output_sc = _get_conv2d_qnn_params(input_zp, input_sc,
+                                                          kernel_zp, kernel_sc,
+                                                          kernel_h, kernel_w, shape[3])
+            model, params = _get_model(shape, kernel_h, kernel_w,
+                                       input_zp, input_sc,
+                                       kernel_zp, kernel_sc,
+                                       output_zp, output_sc,
+                                       pad, stride, dilation,
+                                       groups, "uint8",
+                                       out_channels, weight_format)
+            for npu in [False, True]:
+                mod = tei.make_module(model, params)
+                outputs.append(tei.build_and_run(mod, inputs, 1, params, npu=npu))
+
+            tei.verify(outputs, 1)
+
+
+def test_conv2d_failure():
+    if not ethosn_available():
+        return
+
+    trials = [
+        ((1, 4, 4, 4), 1, 1, 0, 1, 0, 1, 0, 1, "none", (1, 1), (1, 1), 1, "uint8", 8, "HWIO",
+         "Overall scale (of the input * weights / output) should be in the range [0, 1)"),
+        ((1, 4, 4, 4), 1, 1, 0, 1, 0, 1, 0, 1, "none", (1, 1), (1, 1), 1, "int8", 8, "HWIO",
+         "dtype='int8', dtype must be either uint8 or int32"),
+        ((1, 4, 4, 4), 2, 2, 0, 1, 0, 1, 0, 2, "both", (1, 1), (1, 1), 1, "uint8", 8, "HWIO",
+         "both op and attr padding exist, must be either op/attr only or no padding"),
+        ((1, 4, 4, 4), 1, 1, 0, 1, 0, 1, 0, 2, "none", (1, 1, 1), (1, 1), 1, "uint8", 8, "HWIO",
+         "stride size=3, stride size must = 2"),
+        ((1, 4, 4, 4), 1, 1, 0, 1, 0, 1, 0, 2, "none", (1, 1), (2, 1), 1, "uint8", 8, "HWIO",
+         "dilation=[2, 1], dilation must = [1, 1]"),
+        ((2, 4, 4, 4), 1, 1, 0, 1, 0, 1, 0, 2, "none", (1, 1), (1, 1), 1, "uint8", 8, "HWIO",
+         "batch size=2, batch size must = 1"),
+    ]
+
+    np.random.seed(0)
+    for shape, kernel_h, kernel_w, input_zp, input_sc, kernel_zp,\
+        kernel_sc, output_zp, output_sc, pad, stride, dilation,\
+        groups, dtype, out_channels, weight_format, err_msg in trials:
+        model, params = _get_model(shape, kernel_h, kernel_w,
+                                   input_zp, input_sc,
+                                   kernel_zp, kernel_sc,
+                                   output_zp, output_sc,
+                                   pad, stride, dilation,
+                                   groups, dtype,
+                                   out_channels, weight_format)
+        model = tei.make_ethosn_composite(model, "ethos-n.qnn_conv2d")
+        mod = tei.make_ethosn_partition(model)
+        tei.test_error(mod, {}, err_msg)