[BYOC][ACL] Support asymmetric per-layer quantized operators (#6109)
authorlhutton1 <35535092+lhutton1@users.noreply.github.com>
Wed, 29 Jul 2020 16:18:57 +0000 (17:18 +0100)
committerGitHub <noreply@github.com>
Wed, 29 Jul 2020 16:18:57 +0000 (09:18 -0700)
* [BYOC][ACL] Support asymmetric per-layer quantization

Adds support for asymmetric per-layer quantization in the ACL runtime. This includes support for qnn.conv2d, nn.max_pool2d and reshape. These changes are reflected in the codegen and runtime tests.

Change-Id: I8f610bd37af1e3740fd48c2d502bcc4727d9d712

* Address comments

Change-Id: I4f9e3e7dbf6053066927cf07c4c19ecc88572e9d

* Fix tutorial

Change-Id: I4371e9d97a120fb7776db40ffcde60f46927af4d

* Improve test infrastructure

* Docstring for generate_trials
* Output params on error

Change-Id: Ib2e2b1fcdf05cdc77f7f4fb4b46395f28c129957

12 files changed:
docs/deploy/arm_compute_lib.rst
python/tvm/relay/op/contrib/arm_compute_lib.py
python/tvm/relay/qnn/op/layout_conversions.py
src/relay/backend/contrib/arm_compute_lib/codegen.cc
src/runtime/contrib/arm_compute_lib/acl_runtime.cc
src/runtime/contrib/arm_compute_lib/acl_utils.cc
src/runtime/contrib/arm_compute_lib/acl_utils.h
tests/python/contrib/test_arm_compute_lib/infrastructure.py
tests/python/contrib/test_arm_compute_lib/test_conv2d.py
tests/python/contrib/test_arm_compute_lib/test_network.py
tests/python/contrib/test_arm_compute_lib/test_pooling.py
tests/python/contrib/test_arm_compute_lib/test_reshape.py

index 28abc9c..6dc8df0 100644 (file)
@@ -121,6 +121,33 @@ networks refer to the tests: `tests/python/contrib/test_arm_compute_lib`. Here y
 `infrastructure.py` to use the remote device you have setup.
 
 
+Operator support
+----------------
++--------------+-------------------------------------------------------------------------+
+| Relay Node   | Remarks                                                                 |
++==============+=========================================================================+
+| nn.conv2d    | fp32:                                                                   |
+|              |   Simple: nn.conv2d                                                     |
+|              |   Composite: nn.pad?, nn.conv2d, nn.bias_add?, nn.relu?                 |
+|              |                                                                         |
+|              | (only groups = 1 supported)                                             |
++--------------+-------------------------------------------------------------------------+
+| qnn.conv2d   | uint8:                                                                  |
+|              |   Composite: nn.pad?, qnn.conv2d, nn.bias_add?, nn.relu?, qnn.requantize|
+|              |                                                                         |
+|              | (only groups = 1 supported)                                             |
++--------------+-------------------------------------------------------------------------+
+| nn.max_pool2d| fp32, uint8                                                             |
++--------------+-------------------------------------------------------------------------+
+| reshape      | fp32, uint8                                                             |
++--------------+-------------------------------------------------------------------------+
+
+.. note::
+    A composite operator is a series of operators that map to a single Arm Compute Library operator. From the
+    viewpoint of Arm Compute Library it can be seen as a single fused operator. '?' denotes an optional operator
+    in the series of operators that make up a composite operator, as illustrated below.
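+
+    For example, the quantized convolution composite corresponds to the following chain of Relay
+    operators (a sketch; operator attributes and quantization parameters are omitted):
+
+    .. code-block:: text
+
+        nn.pad? -> qnn.conv2d -> nn.bias_add? -> nn.relu? -> qnn.requantize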
+
+
 Adding a new operator
 ---------------------
 Adding a new operator requires changes to a series of places. This section will give a hint on
index e5b2af5..2f031b3 100644 (file)
@@ -81,6 +81,23 @@ def arm_compute_lib_pattern_table():
         pattern = pattern.optional(is_op('nn.relu'))
         return pattern
 
+    def qnn_conv_pattern():
+        """Create a quantized convolution pattern.
+
+        Returns
+        -------
+        pattern : dataflow_pattern.AltPattern
+            Denotes the convolution pattern.
+        """
+        pattern = is_op('nn.pad')(wildcard()) | wildcard()
+        pattern = is_op('qnn.conv2d')(
+            pattern, is_constant(), is_constant(), is_constant(), is_constant(), is_constant())
+        pattern = pattern.optional(lambda x: is_op('nn.bias_add')(x, is_constant()))
+        pattern = pattern.optional(is_op('nn.relu'))
+        pattern = is_op('qnn.requantize')(
+            pattern, wildcard(), wildcard(), is_constant(), is_constant())
+        return pattern
+
     def check_conv(extract):
         """Check conv pattern is supported by ACL."""
         call = extract
@@ -88,7 +105,17 @@ def arm_compute_lib_pattern_table():
             call = call.args[0]
         return conv2d(call.attrs, call.args)
 
-    return [('arm_compute_lib.conv2d', conv_pattern(), check_conv)]
+    def check_qnn_conv(extract):
+        """Check qnn conv pattern is supported by ACL."""
+        if extract.attrs.out_dtype != "uint8":
+            return False
+        call = extract
+        while call.op.name != "qnn.conv2d":
+            call = call.args[0]
+        return qnn_conv2d(call.attrs, call.args)
+
+    return [('arm_compute_lib.conv2d', conv_pattern(), check_conv),
+            ('arm_compute_lib.qnn_conv2d', qnn_conv_pattern(), check_qnn_conv)]
 
 
 def _register_external_op_helper(op_name, supported=True):
@@ -115,7 +142,24 @@ def conv2d(attrs, args):
     if len(data_typ.shape) != 4 or data_typ.shape[0] != 1 or data_typ.dtype != "float32":
         return False
     kernel_typ = args[1].checked_type
-    if kernel_typ.dtype != "float32":
+    if len(kernel_typ.shape) != 4 or kernel_typ.dtype != "float32":
+        return False
+    return True
+
+
+def qnn_conv2d(attrs, args):
+    """Check if the external ACL codegen for qnn.conv2d should be used."""
+    if attrs.groups != 1:
+        return False
+    if attrs.data_layout != "NHWC":
+        return False
+    if attrs.out_dtype != "int32" and attrs.out_dtype != "":
+        return False
+    data_typ = args[0].checked_type
+    if len(data_typ.shape) != 4 or data_typ.shape[0] != 1 or data_typ.dtype != "uint8":
+        return False
+    kernel_typ = args[1].checked_type
+    if len(kernel_typ.shape) != 4 or kernel_typ.dtype != "uint8":
         return False
     return True
 
@@ -126,6 +170,6 @@ def max_pool2d(attrs, args):
     if attrs.layout != "NHWC":
         return False
     typ = args[0].checked_type
-    if typ.dtype != "float32":
+    if typ.dtype not in ["float32", "uint8"]:
         return False
     return True
index caa4c56..391714a 100644 (file)
@@ -20,6 +20,8 @@ from __future__ import absolute_import
 
 from tvm.relay.op import op as reg
 
+from ...op.strategy.generic import is_depthwise_conv2d
+
 
 @reg.register_convert_op_layout("qnn.conv2d")
 def convert_qnn_conv2d(attrs, inputs, tinfos, desired_layouts):
@@ -51,11 +53,20 @@ def convert_qnn_conv2d(attrs, inputs, tinfos, desired_layouts):
     new_attrs = dict(attrs)
     new_attrs['data_layout'] = desired_data_layout
 
+    if desired_kernel_layout != "default":
+        new_attrs['kernel_layout'] = desired_kernel_layout
+        return relay.qnn.op.conv2d(*inputs, **new_attrs)
+
     if desired_data_layout == 'NCHW':
-        if desired_kernel_layout != "default":
-            new_attrs['kernel_layout'] = desired_kernel_layout
+        new_attrs['kernel_layout'] = 'OIHW'
+        return relay.qnn.op.conv2d(*inputs, **new_attrs)
+    if desired_data_layout == 'NHWC':
+        # Check for depthwise convolution.
+        if is_depthwise_conv2d(inputs[0].shape, attrs['data_layout'], inputs[1].shape,
+                               attrs['kernel_layout'], attrs['groups']):
+            new_attrs['kernel_layout'] = 'HWOI'
         else:
-            new_attrs['kernel_layout'] = 'OIHW'
+            new_attrs['kernel_layout'] = 'HWIO'
         return relay.qnn.op.conv2d(*inputs, **new_attrs)
 
     raise ValueError('Layout %s is not yet supported' % desired_data_layout)
index 8edbc15..88de3ed 100644 (file)
@@ -50,6 +50,18 @@ class ACLJSONSerializer : public backend::contrib::JSONSerializer {
   ACLJSONSerializer(const std::string& symbol, const Expr& expr) : JSONSerializer(symbol, expr) {}
 
   /*!
+   * \brief A series of operators that form a composite
+   * convolution. Supports both nn.conv2d and qnn.conv2d.
+   */
+  struct CompositeConvNode {
+    const CallNode* pad = nullptr;
+    const CallNode* conv = nullptr;
+    const CallNode* bias = nullptr;
+    const CallNode* activation = nullptr;
+    const CallNode* requantize = nullptr;
+  };
+
+  /*!
    * \brief Visit call nodes and generate appropriate JSON node.
    *
    * \param cn The current call node.
@@ -68,7 +80,7 @@ class ACLJSONSerializer : public backend::contrib::JSONSerializer {
     CHECK(comp.defined()) << "Arm Compute Library JSON runtime only supports composite functions.";
     const std::string name = comp.value();
     std::shared_ptr<JSONGraphNode> json_node;
-    if (name == "arm_compute_lib.conv2d") {
+    if (name == "arm_compute_lib.conv2d" || name == "arm_compute_lib.qnn_conv2d") {
       json_node = CreateCompositeConvJSONNode(cn);
     } else {
       LOG(FATAL) << "Unrecognized Arm Compute Library pattern: " << name;
@@ -78,57 +90,86 @@ class ACLJSONSerializer : public backend::contrib::JSONSerializer {
 
  private:
   /*!
-   * \brief Create a JSON representation of a composite convolution.
+   * \brief Extract convolution nodes from a composite function.
    *
-   * \param call The call to be represented.
-   * \return A JSON representation of a specific operator.
+   * \param cn The call node of the composite function.
+   * \return Extracted composite convolution nodes.
    */
-  std::shared_ptr<JSONGraphNode> CreateCompositeConvJSONNode(const CallNode* cn) {
-    const std::string name = "nn.conv2d";
-    const CallNode* pad = nullptr;
-    const CallNode* conv = nullptr;
-    const CallNode* bias = nullptr;
-    bool has_activation = false;
-
-    // Unpack composite function
+  static CompositeConvNode UnpackCompositeConvolution(const CallNode* cn) {
+    CompositeConvNode nodes{};
     const auto* fn = cn->op.as<FunctionNode>();
     CHECK(fn);
+
+    // Traverse composite convolution function from child to parent
     const auto* current_call = fn->body.as<CallNode>();
+    if (backend::IsOp(current_call, "qnn.requantize")) {
+      nodes.requantize = current_call;
+      current_call = current_call->args[0].as<CallNode>();
+    }
     if (backend::IsOp(current_call, "nn.relu")) {
-      has_activation = true;
+      nodes.activation = current_call;
       current_call = current_call->args[0].as<CallNode>();
     }
     if (backend::IsOp(current_call, "nn.bias_add")) {
-      bias = current_call;
+      nodes.bias = current_call;
       current_call = current_call->args[0].as<CallNode>();
     }
-    CHECK(backend::IsOp(current_call, "nn.conv2d"));
-    conv = current_call;
+    // Enforce a convolution node exists at this point during traversal
+    if (nodes.requantize) {
+      CHECK(backend::IsOp(current_call, "qnn.conv2d"));
+    } else {
+      CHECK(backend::IsOp(current_call, "nn.conv2d"));
+    }
+    nodes.conv = current_call;
     if (!current_call->args.empty() && current_call->args[0]->IsInstance<CallNode>()) {
       current_call = current_call->args[0].as<CallNode>();
       if (backend::IsOp(current_call, "nn.pad")) {
-        pad = current_call;
+        nodes.pad = current_call;
       }
     }
+    return nodes;
+  }
+
+  /*!
+   * \brief Create a JSON representation of a composite convolution.
+   *
+   * \param cn The call to be represented.
+   * \return A JSON representation of a specific operator.
+   */
+  std::shared_ptr<JSONGraphNode> CreateCompositeConvJSONNode(const CallNode* cn) {
+    CompositeConvNode nodes = UnpackCompositeConvolution(cn);
+    std::string name = "nn.conv2d";
 
-    const auto* conv_attr = conv->attrs.as<Conv2DAttrs>();
+    const auto* conv_attr = nodes.conv->attrs.as<Conv2DAttrs>();
     CHECK(conv_attr);
     CHECK(conv_attr->kernel_layout == "OHWI")
         << "Kernel layout must be OHWI, has the module been pre-processed correctly?";
 
+    // Inputs must be added in the same order they appear in the relay graph.
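+    // For qnn.conv2d the resulting order is: data, weights, input zero-point, kernel zero-point,
+    // input scale, kernel scale, optional bias, output scale, output zero-point.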
     std::vector<JSONGraphNodeEntry> inputs;
     inputs.push_back(VisitExpr(cn->args[0])[0]);
-    inputs.push_back(VisitExpr(conv->args[1])[0]);
-    if (bias) {
-      inputs.push_back(VisitExpr(bias->args[1])[0]);
+    inputs.push_back(VisitExpr(nodes.conv->args[1])[0]);
+    if (nodes.requantize) {
+      name = "qnn.conv2d";
+      inputs.push_back(VisitExpr(nodes.conv->args[2])[0]);  // input zero-point
+      inputs.push_back(VisitExpr(nodes.conv->args[3])[0]);  // kernel zero-point
+      inputs.push_back(VisitExpr(nodes.conv->args[4])[0]);  // input scale
+      inputs.push_back(VisitExpr(nodes.conv->args[5])[0]);  // kernel scale
+    }
+    if (nodes.bias) {
+      inputs.push_back(VisitExpr(nodes.bias->args[1])[0]);
+    }
+    if (nodes.requantize) {
+      inputs.push_back(VisitExpr(nodes.requantize->args[3])[0]);  // output scale
+      inputs.push_back(VisitExpr(nodes.requantize->args[4])[0]);  // output zero-point
     }
 
     auto json_node = std::make_shared<JSONGraphNode>(name, "kernel", inputs, 1);
-    SetCallNodeAttribute(json_node, conv);
+    SetCallNodeAttribute(json_node, nodes.conv);
 
     // Override attributes
-    if (pad) {
-      const auto* pad_attr = pad->attrs.as<PadAttrs>();
+    if (nodes.pad) {
+      const auto* pad_attr = nodes.pad->attrs.as<PadAttrs>();
       CHECK(pad_attr);
       auto p = pad_attr->pad_width;
       // Convert to TVM layout for now, conversion to ACL layout takes place in runtime.
@@ -141,7 +182,7 @@ class ACLJSONSerializer : public backend::contrib::JSONSerializer {
       padding_attr.emplace_back(padding);
       json_node->SetAttr("padding", padding_attr);
     }
-    if (has_activation) {
+    if (nodes.activation) {
       std::vector<std::string> activation_type = {"relu"};
       std::vector<dmlc::any> act_attr;
       act_attr.emplace_back(activation_type);
@@ -161,7 +202,8 @@ class ACLJSONSerializer : public backend::contrib::JSONSerializer {
  */
 IRModule PreProcessModule(const IRModule& mod) {
   IRModule preprocessed_module;
-  tvm::Map<String, Array<String>> desired_layouts = {{"nn.conv2d", {"NHWC", "OHWI"}}};
+  tvm::Map<String, Array<String>> desired_layouts = {{"nn.conv2d", {"NHWC", "OHWI"}},
+                                                     {"qnn.conv2d", {"NHWC", "OHWI"}}};
   preprocessed_module = transform::ConvertLayout(desired_layouts)(mod);
   preprocessed_module = transform::FoldConstant()(preprocessed_module);
   return preprocessed_module;
index e8cdef7..2498dcf 100644 (file)
@@ -25,7 +25,6 @@
 #include <tvm/runtime/ndarray.h>
 #include <tvm/runtime/registry.h>
 
-#include "../../file_util.h"
 #include "../json/json_node.h"
 #include "../json/json_runtime.h"
 
@@ -114,21 +113,8 @@ class ACLRuntime : public JSONRuntimeBase {
    * per engine.
    */
   void BuildEngine() {
-    std::shared_ptr<arm_compute::MemoryManagerOnDemand> mm = MakeMemoryManager();
+    std::shared_ptr<arm_compute::MemoryManagerOnDemand> mm = MakeACLMemoryManager();
     int num_pools = 0;
-
-    for (size_t i = 0; i < input_nodes_.size(); ++i) {
-      uint32_t nid = input_nodes_[i];
-      const auto& node = nodes_[nid];
-      if (node.GetOpType() == "input") {
-        layer_.inputs.push_back(MakeTensor(node));
-      } else if (node.GetOpType() == "const") {
-        uint32_t eid = EntryID(nid, 0);
-        void* data = data_entry_[eid]->data;
-        layer_.const_inputs.push_back(MakeTensor(node, data));
-      }
-    }
-
     bool found_kernel_node = false;
     for (size_t nid = 0; nid < nodes_.size(); ++nid) {
       const auto& node = nodes_[nid];
@@ -139,7 +125,7 @@ class ACLRuntime : public JSONRuntimeBase {
       if (node.GetOpType() == "kernel") {
         found_kernel_node = true;
         auto op_name = node.GetOpName();
-        if ("nn.conv2d" == op_name) {
+        if ("nn.conv2d" == op_name || "qnn.conv2d" == op_name) {
           CreateConvolution2DLayer(&layer_, node, mm);
           num_pools++;
         } else if ("nn.max_pool2d" == op_name) {
@@ -163,24 +149,65 @@ class ACLRuntime : public JSONRuntimeBase {
   struct CachedLayer {
     std::shared_ptr<arm_compute::IFunction> function;
     std::vector<arm_compute::Tensor> inputs;
-    std::vector<arm_compute::Tensor> const_inputs;
     std::vector<arm_compute::Tensor> outputs;
   };
 
   /*!
+   * \brief Create an ACL tensor given the JSON representation. If scale
+   * and offset are given, then create a quantized ACL tensor.
+   *
+   * \param tensor The tensor to represent.
+   * \param scale (optional) The scale of the tensor as an input.
+   * \param offset (optional) The offset of the tensor as an input.
+   * \return ACL Tensor.
+   */
+  arm_compute::Tensor MakeACLTensorFromJSONEntry(const JSONGraphNodeEntry& tensor,
+                                                 JSONGraphNodeEntry* scale = nullptr,
+                                                 JSONGraphNodeEntry* offset = nullptr) {
+    JSONGraphNode node = nodes_[tensor.id_];
+    void* node_data = nullptr;
+    if (node.GetOpType() == "const") {
+      node_data = data_entry_[EntryID(tensor)]->data;
+    }
+    return MakeACLTensorFromJSONNode(node, scale, offset, node_data);
+  }
+
+  /*!
+   * \brief Create an ACL tensor given the JSON representation. If scale
+   * and offset are given, then create a quantized ACL tensor.
+   *
+   * \param node The tensor to represent.
+   * \param scale (optional) The scale of the tensor as an input.
+   * \param offset (optional) The offset of the tensor as an input.
+   * \param data (optional) Constant data of input node.
+   * \return ACL Tensor.
+   */
+  arm_compute::Tensor MakeACLTensorFromJSONNode(const JSONGraphNode& node,
+                                                JSONGraphNodeEntry* scale = nullptr,
+                                                JSONGraphNodeEntry* offset = nullptr,
+                                                void* data = nullptr) {
+    const DLTensor* scale_data = nullptr;
+    const DLTensor* offset_data = nullptr;
+    if (scale && offset) {
+      scale_data = data_entry_[EntryID(*scale)];
+      offset_data = data_entry_[EntryID(*offset)];
+    }
+    return MakeACLTensor(node, data, scale_data, offset_data);
+  }
+
+  /*!
    * \brief Create a 2D convolution layer.
    *
    * \param layer The ACL layer to build. Containing inputs, outputs and the ACL function.
    * \param node The JSON representation of the operator.
    * \param mm The ACL conv2d layer can request auxiliary memory from TVM.
    */
-  static void CreateConvolution2DLayer(
-      CachedLayer* layer, const JSONGraphNode& node,
-      const std::shared_ptr<arm_compute::MemoryManagerOnDemand>& mm) {
+  void CreateConvolution2DLayer(CachedLayer* layer, const JSONGraphNode& node,
+                                const std::shared_ptr<arm_compute::MemoryManagerOnDemand>& mm) {
     std::vector<std::string> padding = node.GetAttr<std::vector<std::string>>("padding");
     std::vector<std::string> strides = node.GetAttr<std::vector<std::string>>("strides");
     std::vector<std::string> dilation = node.GetAttr<std::vector<std::string>>("dilation");
-    arm_compute::PadStrideInfo pad_stride_info = ToACLPadStride(padding, strides);
+    arm_compute::PadStrideInfo pad_stride_info = MakeACLPadStride(padding, strides);
 
     int groups = std::stoi(node.GetAttr<std::vector<std::string>>("groups")[0]);
     CHECK(groups == 1) << "Arm Compute Library NEON convolution only supports group size of 1.";
@@ -198,13 +225,35 @@ class ACLRuntime : public JSONRuntimeBase {
 
     arm_compute::Size2D dilation_2d(std::stoi(dilation[0]), std::stoi(dilation[1]));
 
-    layer->outputs.push_back(MakeOutputTensor(node.GetOpShape()[0]));
+    // Collect inputs and outputs, handling both nn.conv2d and qnn.conv2d cases.
+    std::vector<JSONGraphNodeEntry> inputs = node.GetInputs();
+    size_t num_inputs = inputs.size();
+    bool has_bias;
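+    // Expected qnn.conv2d input order (as serialized by the ACL codegen): data, weights,
+    // input zero-point, kernel zero-point, input scale, kernel scale, optional bias,
+    // output scale, output zero-point.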
+    if (node.GetOpName() == "qnn.conv2d") {
+      CHECK(num_inputs >= 8U && num_inputs <= 9U)
+          << "Quantized convolution requires 9 inputs with a bias, 8 inputs without.";
+      has_bias = num_inputs == 9;
+      layer->inputs.push_back(MakeACLTensorFromJSONEntry(inputs[0], &inputs[4], &inputs[2]));
+      layer->inputs.push_back(MakeACLTensorFromJSONEntry(inputs[1], &inputs[5], &inputs[3]));
+      if (has_bias) {
+        layer->inputs.push_back(MakeACLTensorFromJSONEntry(inputs[6]));
+      }
+      layer->outputs.push_back(
+          MakeACLTensorFromJSONNode(node, &inputs[6 + has_bias], &inputs[7 + has_bias]));
+    } else {
+      CHECK(num_inputs >= 2U && num_inputs <= 3U)
+          << "Convolution requires 3 inputs with a bias, 2 inputs without.";
+      has_bias = num_inputs == 3;
+      for (const auto& i : inputs) {
+        layer->inputs.push_back(MakeACLTensorFromJSONEntry(i));
+      }
+      layer->outputs.push_back(MakeACLTensorFromJSONNode(node));
+    }
 
     auto function = std::make_shared<arm_compute::NEConvolutionLayer>(mm);
-    function->configure(&layer->inputs[0], &layer->const_inputs[0],
-                        layer->const_inputs.size() > 1 ? &layer->const_inputs[1] : nullptr,
-                        &layer->outputs[0], pad_stride_info, arm_compute::WeightsInfo(),
-                        dilation_2d, act_info);
+    function->configure(&layer->inputs[0], &layer->inputs[1],
+                        has_bias ? &layer->inputs[2] : nullptr, &layer->outputs[0], pad_stride_info,
+                        arm_compute::WeightsInfo(), dilation_2d, act_info);
     layer->function = function;
   }
 
@@ -216,10 +265,10 @@ class ACLRuntime : public JSONRuntimeBase {
    * \param layer The ACL layer to build. Containing inputs, outputs and the ACL function.
    * \param node The JSON representation of the operator.
    */
-  static void CreatePoolingLayer(CachedLayer* layer, const JSONGraphNode& node) {
+  void CreatePoolingLayer(CachedLayer* layer, const JSONGraphNode& node) {
     std::vector<std::string> padding = node.GetAttr<std::vector<std::string>>("padding");
     std::vector<std::string> strides = node.GetAttr<std::vector<std::string>>("strides");
-    arm_compute::PadStrideInfo pad_stride_info = ToACLPadStride(padding, strides);
+    arm_compute::PadStrideInfo pad_stride_info = MakeACLPadStride(padding, strides);
 
     auto attr_pool_size = node.GetAttr<std::vector<std::string>>("pool_size");
     int pool_size_h = std::stoi(attr_pool_size[0]);
@@ -236,7 +285,8 @@ class ACLRuntime : public JSONRuntimeBase {
         arm_compute::PoolingLayerInfo(pool_type, arm_compute::Size2D(pool_size_h, pool_size_w),
                                       arm_compute::DataLayout::NHWC, pad_stride_info);
 
-    layer->outputs.push_back(MakeOutputTensor(node.GetOpShape()[0]));
+    layer->inputs.push_back(MakeACLTensorFromJSONEntry(node.GetInputs()[0]));
+    layer->outputs.push_back(MakeACLTensorFromJSONNode(node));
 
     auto function = std::make_shared<arm_compute::NEPoolingLayer>();
     function->configure(&layer->inputs[0], &layer->outputs[0], pool_info);
@@ -249,8 +299,9 @@ class ACLRuntime : public JSONRuntimeBase {
    * \param layer The ACL layer to build. Containing inputs, outputs and the ACL function.
    * \param node The JSON representation of the operator.
    */
-  static void CreateReshapeLayer(CachedLayer* layer, const JSONGraphNode& node) {
-    layer->outputs.push_back(MakeOutputTensor(node.GetOpShape()[0]));
+  void CreateReshapeLayer(CachedLayer* layer, const JSONGraphNode& node) {
+    layer->inputs.push_back(MakeACLTensorFromJSONEntry(node.GetInputs()[0]));
+    layer->outputs.push_back(MakeACLTensorFromJSONNode(node));
     auto function = std::make_shared<arm_compute::NEReshapeLayer>();
     function->configure(&layer->inputs[0], &layer->outputs[0]);
     layer->function = function;
index ad278ba..98c9cda 100644 (file)
@@ -38,10 +38,12 @@ void CheckACLError(const arm_compute::Status& status) {
   CHECK(status.error_code() == arm_compute::ErrorCode::OK) << "ACL: " << status.error_description();
 }
 
-arm_compute::Tensor MakeTensor(const JSONGraphNode& tensor_rep, void* data) {
-  CHECK(tensor_rep.GetOpType() == "input" || tensor_rep.GetOpType() == "const");
+arm_compute::Tensor MakeACLTensor(const JSONGraphNode& tensor_rep, void* data,
+                                  const DLTensor* scale, const DLTensor* offset) {
   arm_compute::Tensor tensor;
-  arm_compute::TensorInfo info = MakeTensorInfo(tensor_rep.GetOpShape()[0]);
+  std::vector<int64_t> shape = tensor_rep.GetOpShape()[0];
+  DLDataType dtype = tensor_rep.GetOpDataType()[0];
+  arm_compute::TensorInfo info = MakeACLTensorInfo(shape, dtype, scale, offset);
   tensor.allocator()->init(info);
   if (data != nullptr) {
     CheckACLError(tensor.allocator()->import_memory(data));
@@ -49,34 +51,37 @@ arm_compute::Tensor MakeTensor(const JSONGraphNode& tensor_rep, void* data) {
   return tensor;
 }
 
-arm_compute::Tensor MakeOutputTensor(const std::vector<int64_t>& shape) {
-  arm_compute::Tensor tensor;
-  tensor.allocator()->init(MakeTensorInfo(shape));
-  return tensor;
-}
-
-arm_compute::TensorInfo MakeTensorInfo(const std::vector<int64_t>& shape) {
-  arm_compute::TensorShape acl_shape = MakeTensorShape(shape);
-  return arm_compute::TensorInfo(acl_shape, 1, arm_compute::DataType::F32,
-                                 arm_compute::DataLayout::NHWC);
-}
-
-arm_compute::TensorShape MakeTensorShape(const std::vector<int64_t>& shape) {
+arm_compute::TensorInfo MakeACLTensorInfo(const std::vector<int64_t>& shape,
+                                          const DLDataType& dtype, const DLTensor* scale,
+                                          const DLTensor* offset) {
   arm_compute::TensorShape acl_shape;
   for (unsigned int i = shape.size(); i > 0; --i) {
     acl_shape.set(shape.size() - i, shape[i - 1]);
   }
-  return acl_shape;
+  arm_compute::DataType acl_dtype = MakeACLDataType(dtype);
+  arm_compute::TensorInfo info(acl_shape, 1, acl_dtype, arm_compute::DataLayout::NHWC);
+
+  // If scale and offset provided create quantized ACL tensor.
+  if (scale != nullptr && offset != nullptr) {
+    std::vector<float> scale_data = GetVectorFromDLTensor<float>(scale);
+    std::vector<int> offset_data = GetVectorFromDLTensor<int>(offset);
+    CHECK(scale_data.size() == 1 && offset_data.size() == 1)
+        << "Currently only per-layer quantization is supported in the Arm Compute Library runtime.";
+    arm_compute::QuantizationInfo qinfo(scale_data[0], offset_data[0]);
+    info.set_quantization_info(qinfo);
+  }
+
+  return info;
 }
 
-std::shared_ptr<arm_compute::MemoryManagerOnDemand> MakeMemoryManager() {
+std::shared_ptr<arm_compute::MemoryManagerOnDemand> MakeACLMemoryManager() {
   auto lifetime_mgr = std::make_shared<arm_compute::OffsetLifetimeManager>();
   auto pool_mgr = std::make_shared<arm_compute::PoolManager>();
   return std::make_shared<arm_compute::MemoryManagerOnDemand>(lifetime_mgr, pool_mgr);
 }
 
-arm_compute::PadStrideInfo ToACLPadStride(const std::vector<std::string>& pad,
-                                          const std::vector<std::string>& stride) {
+arm_compute::PadStrideInfo MakeACLPadStride(const std::vector<std::string>& pad,
+                                            const std::vector<std::string>& stride) {
   int pad_0 = 0, pad_1 = 0, pad_2 = 0, pad_3 = 0;
   int stride_0 = std::stoi(stride[0]), stride_1 = std::stoi(stride[1]);
   size_t size = pad.size();
@@ -108,6 +113,30 @@ arm_compute::PadStrideInfo ToACLPadStride(const std::vector<std::string>& pad,
                                     arm_compute::DimensionRoundingType::FLOOR);
 }
 
+arm_compute::DataType MakeACLDataType(const DLDataType& data_type) {
+  if (data_type.code == DLDataTypeCode::kDLFloat && data_type.bits == 32) {
+    return arm_compute::DataType::F32;
+  } else if (data_type.code == DLDataTypeCode::kDLUInt && data_type.bits == 8) {
+    return arm_compute::DataType::QASYMM8;
+  } else if (data_type.code == DLDataTypeCode::kDLInt && data_type.bits == 32) {
+    return arm_compute::DataType::S32;
+  } else {
+    LOG(FATAL) << "Datatype " << data_type << " unsupported by ACL runtime";
+    return arm_compute::DataType::UNKNOWN;
+  }
+}
+
+template <typename T>
+std::vector<T> GetVectorFromDLTensor(const DLTensor* tensor) {
+  CHECK(tensor) << "Cannot convert a nullptr";
+  int len = 1;
+  for (int i = 0; i < tensor->ndim; i++) {
+    len *= tensor->shape[i];
+  }
+  T* data = static_cast<T*>(tensor->data);
+  return std::vector<T>(data, data + len);
+}
+
 }  // namespace contrib
 }  // namespace runtime
 }  // namespace tvm
index 6a92780..80c6f0b 100644 (file)
@@ -58,35 +58,27 @@ void CheckACLError(const arm_compute::Status& status);
  *
  * \param tensor_rep A JSON tensor representation.
  * \param data (optional) Initialize the tensor with memory.
+ * \param scale (optional) The quantization scale.
+ * \param offset (optional) The quantization offset.
  * \return arm_compute::Tensor.
  */
-arm_compute::Tensor MakeTensor(const JSONGraphNode& tensor_rep, void* data = nullptr);
-
-/*!
- * \brief Make an acl tensor from type and shape, without having a JSON representation.
- *
- * \param shape The shape of the tensor to create.
- * \return arm_compute::Tensor.
- */
-arm_compute::Tensor MakeOutputTensor(const std::vector<int64_t>& shape);
+arm_compute::Tensor MakeACLTensor(const JSONGraphNode& tensor_rep, void* data = nullptr,
+                                  const DLTensor* scale = nullptr,
+                                  const DLTensor* offset = nullptr);
 
 /*!
  * \brief Make an acl tensor info object from JSON tensor
  * representation.
  *
  * \param shape The shape of the tensor to create.
+ * \param dtype The data type of the tensor to create.
+ * \param scale (optional) The quantization scale.
+ * \param offset (optional) The quantization offset.
  * \return arm_compute::TensorInfo.
  */
-arm_compute::TensorInfo MakeTensorInfo(const std::vector<int64_t>& shape);
-
-/*!
- * \brief Convert vector object to acl TensorShape.
- * \note This requires reversing the given vector.
- *
- * \param shape The shape of the tensor as a vector.
- * \return arm_compute::TensorShape.
- */
-arm_compute::TensorShape MakeTensorShape(const std::vector<int64_t>& shape);
+arm_compute::TensorInfo MakeACLTensorInfo(const std::vector<int64_t>& shape,
+                                          const DLDataType& dtype, const DLTensor* scale = nullptr,
+                                          const DLTensor* offset = nullptr);
 
 /*!
  * \brief Create a memory manager for use with a layer that
@@ -94,7 +86,7 @@ arm_compute::TensorShape MakeTensorShape(const std::vector<int64_t>& shape);
  *
  * \return reference counted memory manager.
  */
-std::shared_ptr<arm_compute::MemoryManagerOnDemand> MakeMemoryManager();
+std::shared_ptr<arm_compute::MemoryManagerOnDemand> MakeACLMemoryManager();
 
 /*!
  * \brief Convert TVM padding and stride format to acl PadStrideInfo.
@@ -103,8 +95,27 @@ std::shared_ptr<arm_compute::MemoryManagerOnDemand> MakeMemoryManager();
  * \param stride The stride vector.
  * \return arm_compute::PadStrideInfo
  */
-arm_compute::PadStrideInfo ToACLPadStride(const std::vector<std::string>& pad,
-                                          const std::vector<std::string>& stride);
+arm_compute::PadStrideInfo MakeACLPadStride(const std::vector<std::string>& pad,
+                                            const std::vector<std::string>& stride);
+
+/*!
+ * \brief Convert DLDataType to arm_compute::DataType.
+ *
+ * \param data_type The data type to convert.
+ * \return arm_compute::DataType.
+ */
+arm_compute::DataType MakeACLDataType(const DLDataType& data_type);
+
+/*!
+ * \brief Get a vector from DLTensor data.
+ * \note Performs a copy of data.
+ *
+ * \tparam T The type of the vector.
+ * \param tensor The tensor to convert.
+ * \return Vector of type T.
+ */
+template <typename T>
+std::vector<T> GetVectorFromDLTensor(const DLTensor* tensor);
 
 }  // namespace contrib
 }  // namespace runtime
index ea486b0..5ed2763 100644 (file)
@@ -17,6 +17,8 @@
 from itertools import zip_longest, combinations
 import json
 
+import numpy as np
+
 import tvm
 from tvm import relay
 from tvm import rpc
@@ -154,15 +156,30 @@ def update_lib(lib, device, cross_compile):
     return lib
 
 
-def verify(answers, atol, rtol):
+def verify(answers, atol, rtol, verify_saturation=False, params=None):
     """Compare the array of answers. Each entry is a list of outputs."""
+    if params is None:
+        params = {}
+
     if len(answers) < 2:
         raise RuntimeError(
             f"No results to compare: expected at least two, found {len(answers)}")
     for answer in zip_longest(*answers):
         for outs in combinations(answer, 2):
-            tvm.testing.assert_allclose(
-               outs[0].asnumpy(), outs[1].asnumpy(), rtol=rtol, atol=atol)
+            if verify_saturation:
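+                # Treat an output as saturated when 25% or more of its values clip to the
+                # uint8 extremes (0 or 255).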
+                assert np.count_nonzero(outs[0].asnumpy() == 255) < 0.25 * outs[0].asnumpy().size, \
+                    "Output is saturated: {}".format(outs[0])
+                assert np.count_nonzero(outs[0].asnumpy() == 0) < 0.25 * outs[0].asnumpy().size, \
+                    "Output is saturated: {}".format(outs[0])
+            try:
+                tvm.testing.assert_allclose(
+                   outs[0].asnumpy(), outs[1].asnumpy(), rtol=rtol, atol=atol)
+            except AssertionError as e:
+                err_msg = "Results not within the acceptable tolerance.\n"
+                if params:
+                    err_msg += f"The test failed with the following parameters: {params}\n"
+                err_msg += str(e)
+                raise AssertionError(err_msg)
 
 
 def extract_acl_modules(module):
@@ -195,3 +212,45 @@ def verify_codegen(module, known_good_codegen, num_acl_modules,
             f"The JSON produced by codegen does not match the expected result. \n" \
             f"Actual={codegen_str} \n" \
             f"Expected={known_good_codegen_str}"
+
+
+def generate_trials(space, r_factor=3):
+    """Generates a series of trials.
+
+    This algorithm generates a series of pseudo-random trials given a
+    space of options to test. A trial is generated by pulling a value from
+    each option in the space. On some occasions the values are shuffled to
+    ensure a different trial on each r_factor iteration. The algorithm ensures
+    that each value from an option is used at least once. The total number of
+    trials is r_factor multiplied by the number of values in the largest
+    option.
+
+    Parameters
+    ----------
+    space: List[List[Any]]
+        A list of different options with varying values to test.
+    r_factor: (optional) int
+        The repeat factor.
+
+    Returns
+    -------
+    A list of trials specifying values for each option.
+
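+    Examples
+    --------
+    For example (illustrative only), generate_trials([[1, 2, 3], ["a", "b"]], r_factor=2)
+    returns 2 * 3 = 6 trials, each a list such as [2, "a"], with every value of each
+    option used at least once.
+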
+    """
+    np.random.seed(0)
+    max_len = 1
+    for option in space:
+        max_len = max(max_len, len(option))
+
+    num_trials = r_factor * max_len
+    trials = []
+    for i in range(num_trials):
+        trial = []
+        for option in space:
+            if i % len(option) == 0:
+                np.random.shuffle(option)
+            trial.append(option[i % len(option)])
+
+        trials.append(trial)
+
+    return trials
index 8765878..c407466 100644 (file)
@@ -22,13 +22,13 @@ import tvm
 from tvm import relay
 
 from .infrastructure import skip_runtime_test, skip_codegen_test, build_and_run, \
-    verify, verify_codegen
+    verify, verify_codegen, generate_trials
 from .infrastructure import Device
 
 
-def _get_model(shape, kernel_size, padding, strides,
-               dilation, groups, dtype, channels,
-               var_names, has_bias=False, has_activation=False, has_pad=False):
+def _get_model(shape, kernel_h, kernel_w, padding, strides,
+               dilation, groups, dtype, channels, var_names,
+               has_bias=False, has_activation=False, has_pad=False):
     """Return a model and any parameters it may have"""
     a = relay.var(next(var_names), shape=shape, dtype=dtype)
     if has_pad:
@@ -40,20 +40,21 @@ def _get_model(shape, kernel_size, padding, strides,
             padding = (padding[0], padding[1], padding[0], padding[1])
         shape = (shape[0], shape[1] + padding[0] * 2,
                  shape[2] + padding[1] * 2, shape[3])
-    weight_shape = (kernel_size, kernel_size, shape[3] // groups, channels)
+    weight_shape = (kernel_h, kernel_w, shape[3] // groups, channels)
     w = tvm.nd.array(np.random.uniform(-128, 127, weight_shape).astype(dtype))
     weights = relay.const(w, dtype)
     out = relay.nn.conv2d(
         a,
         weights,
-        kernel_size=(kernel_size, kernel_size),
+        kernel_size=(kernel_h, kernel_w),
         data_layout="NHWC",
         kernel_layout="HWIO",
-        dilation=(1, 1),
+        dilation=dilation,
         strides=strides,
         padding=padding,
         groups=groups,
-        channels=channels
+        channels=channels,
+        out_dtype=dtype
     )
     params = {"w": w}
     if has_bias:
@@ -66,59 +67,171 @@ def _get_model(shape, kernel_size, padding, strides,
     return out, params
 
 
-def _get_expected_codegen(shape, kernel_size, padding, strides,
+def _get_qnn_params(input_zp, input_sc, kernel_zp, kernel_sc, kernel_h, kernel_w, channels):
+    """Get output qnn parameters given input and kernel parameters."""
+    input_max = input_sc * (255 - input_zp)
+    input_min = - input_sc * input_zp
+    kernel_max = kernel_sc * (255 - kernel_zp)
+    kernel_min = - kernel_sc * kernel_zp
+    output_limits = [kernel_max * kernel_h * kernel_w * channels * input_max,
+                     kernel_min * kernel_h * kernel_w * channels * input_max,
+                     kernel_min * kernel_h * kernel_w * channels * input_min,
+                     kernel_max * kernel_h * kernel_w * channels * input_min]
+    output_max = max(output_limits)
+    output_min = min(output_limits)
+    output_sc = (output_max - output_min) / 255
+    output_zp = - int(output_min / output_sc)
+    return output_zp, output_sc
+
+
+def _get_qnn_model(shape, kernel_h, kernel_w,
+                   padding, strides, dilation, groups, dtype,
+                   channels, input_zp, input_sc,
+                   kernel_zp, kernel_sc, output_zp,
+                   output_sc, var_names, has_bias=False,
+                   has_activation=False, has_pad=False):
+    """Return a model and any parameters it may have."""
+    a = relay.var(next(var_names), shape=shape, dtype=dtype)
+    if has_pad:
+        p = ((0, 0), (padding[0], padding[0]), (padding[1], padding[1]), (0, 0))
+        a = relay.nn.pad(a, pad_width=p, pad_value=input_zp, pad_mode="constant")
+        padding = (0, 0, 0, 0)
+    else:
+        if len(padding) == 2:
+            padding = (padding[0], padding[1], padding[0], padding[1])
+        shape = (shape[0], shape[1] + padding[0] * 2,
+                 shape[2] + padding[1] * 2, shape[3])
+    weight_shape = (kernel_h, kernel_w, shape[3] // groups, channels)
+    w = tvm.nd.array(np.random.uniform(0, 255, weight_shape).astype(dtype))
+    weights = relay.const(w, dtype)
+    out = relay.qnn.op.conv2d(
+        a,
+        weights,
+        input_zero_point=relay.const(input_zp, "int32"),
+        kernel_zero_point=relay.const(kernel_zp, "int32"),
+        input_scale=relay.const(input_sc, "float32"),
+        kernel_scale=relay.const(kernel_sc, "float32"),
+        kernel_size=(kernel_h, kernel_w),
+        data_layout="NHWC",
+        kernel_layout="HWIO",
+        dilation=dilation,
+        strides=strides,
+        padding=padding,
+        groups=groups,
+        channels=channels,
+        out_dtype="int32"
+    )
+    params = {"w": w}
+    if has_bias:
+        b = tvm.nd.array(np.random.uniform(0, 255, weight_shape[3]).astype("int32"))
+        biasc = relay.const(b, "int32")
+        out = relay.nn.bias_add(out, biasc, axis=3)
+        params['b'] = b
+    if has_activation:
+        out = relay.nn.relu(out)
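+    # Requantize from the int32 accumulator domain (scale = input_sc * kernel_sc, zero
+    # point 0) back to uint8 using the derived output quantization parameters.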
+    req = relay.qnn.op.requantize(
+        out,
+        relay.const(input_sc * kernel_sc, 'float32'),  # input scale
+        relay.const(0, 'int32'),  # input zero point
+        relay.const(output_sc, 'float32'),  # output scale
+        relay.const(output_zp, 'int32'),  # output zero point
+        out_dtype="uint8"
+    )
+    return req, params
+
+
+def _get_expected_codegen(shape, kernel_h, kernel_w, padding, strides,
                           dilation, groups, dtype, channels,
                           has_bias=False, has_activation=False):
     if len(padding) == 2:
         padding = (padding[0], padding[1], padding[0], padding[1])
-    weight_shape = (channels, kernel_size, kernel_size, shape[3] // groups)
-    output_height = ((shape[1] - kernel_size + padding[0] + padding[2]) / strides[0]) + 1
-    output_width = ((shape[2] - kernel_size + padding[1] + padding[3]) / strides[1]) + 1
+    weight_shape = (channels, kernel_h, kernel_w, shape[3] // groups)
+    output_height = ((shape[1] - kernel_h + padding[0] + padding[2]) / strides[0]) + 1
+    output_width = ((shape[2] - kernel_w + padding[1] + padding[3]) / strides[1]) + 1
     output_shape = (1, int(output_height), int(output_width), channels)
+    out_dtype = "int32" if dtype == "uint8" else "float32"
 
     node = {
-            "op": "kernel",
-            "name": "nn.conv2d",
-            "inputs": [[0, 0, 0], [1, 0, 0]],
-            "attrs": {
-                "groups": [["1"]],
-                "num_inputs": str(3 if has_bias else 2),
-                "num_outputs": "1",
-                "data_layout": [["NHWC"]],
-                "kernel_layout": [["OHWI"]],
-                "channels": [["1"]],
-                "dilation": [["1", "1"]],
-                "out_layout": [[""]],
-                "out_dtype": [[""]],
-                "kernel_size": [[str(kernel_size), str(kernel_size)]],
-                "shape": [[list(output_shape)]],
-                "dtype": [[dtype]],
-                "padding": [[str(p) for p in padding]],
-                "strides": [[str(s) for s in strides]]
-            },
-        }
+        "op": "kernel",
+        "name": "nn.conv2d",
+        "inputs": [],
+        "attrs": {
+            "groups": [["1"]],
+            "num_outputs": "1",
+            "data_layout": [["NHWC"]],
+            "kernel_layout": [["OHWI"]],
+            "channels": [[str(channels)]],
+            "dilation": [[str(dilation[0]), str(dilation[1])]],
+            "out_layout": [[""]],
+            "out_dtype": [[out_dtype]],
+            "kernel_size": [[str(kernel_h), str(kernel_w)]],
+            "shape": [[list(output_shape)]],
+            "dtype": [[dtype]],
+            "padding": [[str(p) for p in padding]],
+            "strides": [[str(s) for s in strides]]
+        },
+    }
 
     if has_activation:
         node["attrs"]["activation_type"] = [["relu"]]
 
-    input = {
+    inputs = [{
         "op": "input",
         "name": "",
-        "attrs": {"shape": [[list(shape)]], "dtype": [["float32"]]}}
-    kernel = {
+        "attrs": {
+            "shape": [[list(shape)]],
+            "dtype": [[str(dtype)]]
+        }}, {
         "op": "const",
         "name": "",
-        "attrs": {"shape": [[list(weight_shape)]], "dtype": [["float32"]]}}
+        "attrs": {
+            "shape": [[list(weight_shape)]],
+            "dtype": [[str(dtype)]]
+        }}]
+
+    # qnn.conv2d params, input and kernel
+    if dtype == "uint8":
+        node["name"] = "qnn.conv2d"
+        for param_dtype in ["int32", "float32"]:
+            for _ in range(2):
+                inputs.append({
+                    "op": "const",
+                    "name": "",
+                    "attrs": {
+                        "shape": [[[]]],
+                        "dtype": [[param_dtype]]
+                    }
+                })
 
     if has_bias:
-        bias = {
+        bias_dtype = "int32" if dtype == "uint8" else "float32"
+        inputs.append({
             "op": "const",
             "name": "",
-            "attrs": {"shape": [[[weight_shape[0]]]], "dtype": [["float32"]]}}
-        node["inputs"].append([2, 0, 0])
-        return [input, kernel, bias, node]
-    else:
-        return [input, kernel, node]
+            "attrs": {
+                "shape": [[[weight_shape[0]]]],
+                "dtype": [[bias_dtype]]}
+        })
+
+    # qnn.conv2d params, output
+    if dtype == "uint8":
+        for param_dtype in ["float32", "int32"]:
+            inputs.append({
+                "op": "const",
+                "name": "",
+                "attrs": {
+                    "shape": [[[]]],
+                    "dtype": [[param_dtype]]
+                }
+            })
+
+    input_idx = 0
+    for _ in range(len(inputs)):
+        node["inputs"].append([input_idx, 0, 0])
+        input_idx += 1
+    node["attrs"]["num_inputs"] = str(len(inputs))
+    inputs.append(node)
+    return inputs
 
 
 def test_conv2d():
@@ -128,50 +241,31 @@ def test_conv2d():
     device = Device()
     np.random.seed(0)
 
-    shape = (1, 14, 14, 32)
+    kernel_hs = [1, 2, 3, 5]
+    kernel_ws = [1, 2, 3, 5]
+    pad = [(1, 1), (2, 2), (2, 1)]
+    strides = [(1, 1), (2, 2)]
+    dilation = [(1, 1)]
+    out_channels = [4, 7, 16]
+    input_shapes = [(10, 10, 14), (12, 15, 16), (20, 20, 20)]
+    # composite operator (pad, bias, activation)
+    composite = [(False, False, False), (False, True, False), (False, False, True),
+                 (False, True, True), (True, False, False)]
     dtype = "float32"
+    trials = generate_trials([kernel_hs, kernel_ws, pad, strides, dilation, out_channels,
+                              input_shapes, composite], 3)
 
-    inputs = {
-        "a": tvm.nd.array(np.random.uniform(-128, 127, shape).astype(dtype)),
-    }
-
-    for kernel_size in [1, 2, 3]:
-        outputs = []
-        func, params = _get_model(shape, kernel_size,
-                                  (0, 0), (1, 1), 1, 1,
-                                  dtype, 1, iter(inputs))
-        for acl in [False, True]:
-            outputs.append(build_and_run(func, inputs, 1,
-                                         params, device,
-                                         enable_acl=acl)[0])
-        verify(outputs, atol=0.002, rtol=0.01)
-
-    for pad_ksize in [((1, 1), 3), ((2, 2), 5), ((2, 1), 3)]:
+    for kernel_h, kernel_w, pad, stride, dilation, out_channels, input_shapes, composite in trials:
+        groups = 1
+        shape = (1, *input_shapes)
         outputs = []
-        func, params = _get_model(shape, pad_ksize[1], pad_ksize[0],
-                                  (1, 1), 1, 1, dtype, 1, iter(inputs))
-        for acl in [False, True]:
-            outputs.append(build_and_run(func, inputs, 1,
-                                         params, device,
-                                         enable_acl=acl)[0])
-        verify(outputs, atol=0.002, rtol=0.01)
-
-    for strides in [(1, 1), (2, 2)]:
-        outputs = []
-        func, params = _get_model(shape, 2, (0, 0), strides,
-                                  1, 1, dtype, 1, iter(inputs))
-        for acl in [False, True]:
-            outputs.append(build_and_run(func, inputs, 1,
-                                         params, device,
-                                         enable_acl=acl)[0])
-        verify(outputs, atol=0.002, rtol=0.01)
+        inputs = {
+            "a": tvm.nd.array(np.random.uniform(-128, 127, shape).astype(dtype)),
+        }
 
-    # Test composite convolution: (has_pad, has_bias, has_activation).
-    for composite in [(False, True, False), (False, False, True), (False, True, True),
-                      (True, False, False)]:
-        outputs = []
-        func, params = _get_model(shape, 2, (1, 1), (1, 1),
-                                  1, 1, dtype, 1, iter(inputs),
+        func, params = _get_model(shape, kernel_h, kernel_w,
+                                  pad, stride, dilation, groups,
+                                  dtype, out_channels, iter(inputs),
                                   has_pad=composite[0],
                                   has_bias=composite[1],
                                   has_activation=composite[2])
@@ -179,26 +273,47 @@ def test_conv2d():
             outputs.append(build_and_run(func, inputs, 1,
                                          params, device,
                                          enable_acl=acl)[0])
-        verify(outputs, atol=0.002, rtol=0.01)
+
+        params = {
+            "shape": shape,
+            "groups": groups,
+            "kernel size": (kernel_h, kernel_w),
+            "padding": pad,
+            "stride": stride,
+            "dilation": dilation,
+            "out channels": out_channels,
+            "composite operators (pad, bias, activation)": composite
+        }
+        verify(outputs, atol=0.002, rtol=0.01, params=params)
 
 
 def test_codegen_conv2d():
     if skip_codegen_test():
         return
 
-    shape = (1, 25, 25, 1)
+    np.random.seed(0)
+
+    kernel_hs = [1, 2, 3, 5]
+    kernel_ws = [1, 2, 3, 5]
+    pad = [(1, 1), (2, 2), (2, 1)]
+    strides = [(1, 1), (2, 2)]
+    dilation = [(1, 1)]
+    out_channels = [4, 7, 16]
+    input_shapes = [(10, 10, 14), (12, 15, 16), (20, 20, 20)]
+    # composite operator (pad, bias, activation)
+    composite = [(False, False, False), (False, True, False), (False, False, True),
+                 (False, True, True), (True, False, False)]
     dtype = "float32"
-    inputs = {"a"}
+    trials = generate_trials([kernel_hs, kernel_ws, pad, strides, dilation, out_channels,
+                              input_shapes, composite], 3)
+
+    for kernel_h, kernel_w, pad, stride, dilation, out_channels, input_shapes, composite in trials:
+        groups = 1
+        shape = (1, *input_shapes)
+        inputs = {"a"}
+
+        args = (shape, kernel_h, kernel_w, pad, stride, dilation, groups, dtype, out_channels)
 
-    for pad_ksize in [((1, 1), 3), ((2, 1), 3)]:
-        args = (shape, pad_ksize[1], pad_ksize[0], (1, 1), 1, 1, dtype, 1)
-        func, params = _get_model(*args, var_names=iter(inputs))
-        exp_codegen = _get_expected_codegen(*args)
-        verify_codegen(func, exp_codegen, 1)
-    # Test composite convolution: (has_pad, has_bias, has_activation).
-    for composite in [(False, True, False), (False, False, True), (False, True, True),
-                      (True, False, False)]:
-        args = (shape, 2, (1, 1), (1, 1), 1, 1, dtype, 1)
         func, params = _get_model(*args, var_names=iter(inputs),
                                   has_pad=composite[0],
                                   has_bias=composite[1],
@@ -209,6 +324,128 @@ def test_codegen_conv2d():
         verify_codegen(func, exp_codegen, 1)
 
 
+def test_qnn_conv2d():
+    if skip_runtime_test():
+        return
+
+    device = Device()
+    np.random.seed(0)
+
+    kernel_hs = [1, 2, 3, 5]
+    kernel_ws = [1, 2, 3, 5]
+    pad = [(1, 1), (2, 2)]
+    strides = [(1, 1), (2, 2)]
+    dilation = [(1, 1)]
+    out_channels = [4, 7, 16]
+    input_shapes = [(10, 10, 14), (12, 15, 16), (20, 20, 20)]
+    # composite operator (pad, bias, activation)
+    composite = [(False, False, False), (False, True, False), (False, False, True),
+                 (False, True, True), (True, False, False)]
+    dtype = "uint8"
+    trials = generate_trials([kernel_hs, kernel_ws, pad, strides, dilation, out_channels,
+                              input_shapes, composite], 3)
+
+    for kernel_h, kernel_w, pad, stride, dilation, out_channels, input_shapes, composite in trials:
+        groups = 1
+        shape = (1, *input_shapes)
+        outputs = []
+        inputs = {
+            "a": tvm.nd.array(np.random.uniform(0, 255, shape).astype(dtype))
+        }
+
+        input_zp = 100
+        input_sc = 0.5
+        kernel_zp = 25
+        kernel_sc = 0.03
+        output_zp, output_sc = _get_qnn_params(input_zp, input_sc,
+                                               kernel_zp, kernel_sc,
+                                               kernel_h, kernel_w, shape[3])
+
+        func, params = _get_qnn_model(shape, kernel_h, kernel_w,
+                                      pad, stride, dilation, groups,
+                                      dtype, out_channels,
+                                      input_zp, input_sc,
+                                      kernel_zp, kernel_sc,
+                                      output_zp, output_sc,
+                                      iter(inputs),
+                                      has_pad=composite[0],
+                                      has_bias=composite[1],
+                                      has_activation=composite[2])
+        for acl in [False, True]:
+            outputs.append(build_and_run(func, inputs, 1,
+                                         params, device,
+                                         enable_acl=acl)[0])
+
+        params = {
+            "shape": shape,
+            "groups": groups,
+            "kernel size": (kernel_h, kernel_w),
+            "padding": pad,
+            "stride": stride,
+            "dilation": dilation,
+            "out channels": out_channels,
+            "composite operators (pad, bias, activation)": composite,
+            "input scale": input_sc,
+            "input zero point": input_zp,
+            "kernel scale": kernel_sc,
+            "kernel zero point": kernel_zp,
+            "output scale": output_sc,
+            "output zero point": output_zp
+        }
+        verify(outputs, atol=1, rtol=0, params=params)
+
+
+def test_codegen_qnn_conv2d():
+    if skip_codegen_test():
+        return
+
+    np.random.seed(0)
+
+    kernel_hs = [1, 2, 3, 5]
+    kernel_ws = [1, 2, 3, 5]
+    pad = [(1, 1), (2, 2), (2, 1)]
+    strides = [(1, 1), (2, 2)]
+    dilation = [(1, 1)]
+    out_channels = [4, 7, 16]
+    input_shapes = [(10, 10, 14), (12, 15, 16), (20, 20, 20)]
+    # composite operator (pad, bias, activation)
+    composite = [(False, False, False), (False, True, False), (False, False, True),
+                 (False, True, True), (True, False, False)]
+    dtype = "uint8"
+    trials = generate_trials([kernel_hs, kernel_ws, pad, strides, dilation, out_channels,
+                              input_shapes, composite], 3)
+
+    for kernel_h, kernel_w, pad, stride, dilation, out_channels, input_shapes, composite in trials:
+        groups = 1
+        shape = (1, *input_shapes)
+        inputs = {"a"}
+
+        input_zp = 100
+        input_sc = 0.5
+        kernel_zp = 25
+        kernel_sc = 0.03
+        output_zp, output_sc = _get_qnn_params(input_zp, input_sc,
+                                               kernel_zp, kernel_sc,
+                                               kernel_h, kernel_w, shape[3])
+
+        args = (shape, kernel_h, kernel_w, pad, stride, dilation, groups, dtype, out_channels)
+
+        func, params = _get_qnn_model(*args,
+                                      input_zp=input_zp, input_sc=input_sc,
+                                      kernel_zp=kernel_zp, kernel_sc=kernel_sc,
+                                      output_zp=output_zp, output_sc=output_sc,
+                                      var_names=iter(inputs),
+                                      has_pad=composite[0],
+                                      has_bias=composite[1],
+                                      has_activation=composite[2])
+        exp_codegen = _get_expected_codegen(*args,
+                                            has_bias=composite[1],
+                                            has_activation=composite[2])
+        verify_codegen(func, exp_codegen, 1)
+
+
 if __name__ == "__main__":
     test_conv2d()
+    test_qnn_conv2d()
     test_codegen_conv2d()
+    test_codegen_qnn_conv2d()
index 8648a01..1ba6ca7 100644 (file)
@@ -24,12 +24,17 @@ from .infrastructure import skip_runtime_test, build_and_run, verify
 from .infrastructure import Device
 
 
-def _build_and_run_keras_network(mod, params, inputs, device, tvm_ops, acl_partitions):
-    """Helper function to build and run a network from the Keras frontend."""
+def _build_and_run_network(mod, params, inputs, device, tvm_ops, acl_partitions, atol, rtol):
+    """Helper function to build and run a network."""
     data = {}
     np.random.seed(0)
-    for name, shape in inputs.items():
-        data[name] = np.random.uniform(-128, 127, shape).astype("float32")
+
+    for name, (shape, dtype) in inputs.items():
+        if dtype == "uint8":
+            low, high = 0, 255
+        else:
+            low, high = -127, 128
+        data[name] = np.random.uniform(low, high, shape).astype(dtype)
 
     outputs = []
     for acl in [False, True]:
@@ -37,7 +42,40 @@ def _build_and_run_keras_network(mod, params, inputs, device, tvm_ops, acl_parti
                                      device, enable_acl=acl,
                                      tvm_ops=tvm_ops,
                                      acl_partitions=acl_partitions)[0])
-    verify(outputs, atol=0.002, rtol=0.01)
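+    # Saturation checks are skipped for whole networks, where a large share of
+    # quantized output values can legitimately sit at the type's min/max.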
+    verify(outputs, atol=atol, rtol=rtol, verify_saturation=False)
+
+
+def _get_tflite_model(tflite_model_path, inputs_dict):
+    """Convert TFlite graph to relay."""
+    import tflite.Model
+
+    with open(tflite_model_path, 'rb') as f:
+        tflite_model_buffer = f.read()
+
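+    # tflite packages expose GetRootAsModel in different locations depending on
+    # the version, so fall back to the alternative layout on AttributeError.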
+    try:
+        tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buffer, 0)
+    except AttributeError:
+        tflite_model = tflite.Model.GetRootAsModel(tflite_model_buffer, 0)
+    shape_dict = {}
+    dtype_dict = {}
+    for name, (input_shape, input_dtype) in inputs_dict.items():
+        shape_dict[name] = input_shape
+        dtype_dict[name] = input_dtype
+
+    return relay.frontend.from_tflite(
+        tflite_model,
+        shape_dict=shape_dict,
+        dtype_dict=dtype_dict
+    )
+
+
+def _get_keras_model(keras_model, inputs_dict):
+    """Convert Keras graph to relay."""
+    inputs = {}
+    for name, (shape, _) in inputs_dict.items():
+        inputs[name] = shape
+    return relay.frontend.from_keras(keras_model, inputs, layout="NHWC")
 
 
 def test_vgg16():
@@ -50,12 +88,13 @@ def test_vgg16():
         from keras.applications import VGG16
         vgg16 = VGG16(include_top=True, weights='imagenet',
                       input_shape=(224, 224, 3), classes=1000)
-        inputs = {vgg16.input_names[0]: (1, 224, 224, 3)}
-        mod, params = relay.frontend.from_keras(vgg16, inputs, layout="NHWC")
+        inputs = {vgg16.input_names[0]: ((1, 224, 224, 3), "float32")}
+        mod, params = _get_keras_model(vgg16, inputs)
         return mod, params, inputs
 
-    _build_and_run_keras_network(*get_model(), device=device,
-                                 tvm_ops=10, acl_partitions=18)
+    _build_and_run_network(*get_model(), device=device,
+                           tvm_ops=10, acl_partitions=18,
+                           atol=0.002, rtol=0.01)
 
 
 def test_mobilenet():
@@ -68,14 +107,42 @@ def test_mobilenet():
         from keras.applications import MobileNet
         mobilenet = MobileNet(include_top=True, weights='imagenet',
                               input_shape=(224, 224, 3), classes=1000)
-        inputs = {mobilenet.input_names[0]: (1, 224, 224, 3)}
-        mod, params = relay.frontend.from_keras(mobilenet, inputs, layout="NHWC")
+        inputs = {mobilenet.input_names[0]: ((1, 224, 224, 3), "float32")}
+        mod, params = _get_keras_model(mobilenet, inputs)
+        return mod, params, inputs
+
+    _build_and_run_network(*get_model(), device=device,
+                           tvm_ops=74, acl_partitions=17,
+                           atol=0.002, rtol=0.01)
+
+
+def test_quantized_mobilenet():
+    if skip_runtime_test():
+        return
+
+    import tvm.relay.testing.tf as tf_testing
+
+    device = Device()
+
+    def get_model():
+        model_path = tf_testing.get_workload_official(
+            "https://storage.googleapis.com/download.tensorflow.org/" \
+            "models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz",
+            "mobilenet_v1_1.0_224_quant.tflite",
+        )
+        inputs = {"input": ((1, 224, 224, 3), "uint8")}
+        mod, params = _get_tflite_model(
+            model_path,
+            inputs_dict=inputs
+        )
         return mod, params, inputs
 
-    _build_and_run_keras_network(*get_model(), device=device,
-                                 tvm_ops=74, acl_partitions=17)
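+    # An absolute tolerance of 8 is used for the uint8 network output, since TVM
+    # and ACL may round intermediate requantization results differently.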
+    _build_and_run_network(*get_model(), device=device,
+                           tvm_ops=45, acl_partitions=16,
+                           atol=8, rtol=0)
 
 
 if __name__ == "__main__":
     test_vgg16()
     test_mobilenet()
+    test_quantized_mobilenet()
index aac7795..4d48f79 100644 (file)
@@ -26,17 +26,17 @@ from .infrastructure import skip_runtime_test, skip_codegen_test, build_and_run,
 from .infrastructure import Device
 
 
-def _get_model(shape, typef, sizes, strides, padding,
+def _get_model(shape, dtype, typef, sizes, strides, padding,
                ceil_mode, var_names):
     """Return a model and any parameters it may have."""
-    var = relay.var(next(var_names), shape=shape, dtype="float32")
+    var = relay.var(next(var_names), shape=shape, dtype=dtype)
     pool = typef(var, pool_size=sizes, strides=strides, padding=padding,
                  ceil_mode=ceil_mode, layout="NHWC")
     return pool
 
 
-def _get_expected_codegen(shape, typef, sizes, strides, padding,
-                          ceil_mode):
+def _get_expected_codegen(shape, dtype, typef, sizes, strides,
+                          padding, ceil_mode):
     if len(padding) == 2:
         padding = (padding[1], padding[1], padding[0], padding[0])
     output_height = ((shape[1] - sizes[0] + padding[0] + padding[2]) / strides[0]) + 1
@@ -52,7 +52,7 @@ def _get_expected_codegen(shape, typef, sizes, strides, padding,
             "num_outputs": "1",
             "layout": [["NHWC"]],
             "shape": [[list(output_shape)]],
-            "dtype": [["float32"]],
+            "dtype": [[dtype]],
             "padding": [[str(p) for p in padding]],
             "strides": [[str(s) for s in strides]],
             "pool_size": [[str(s) for s in sizes]],
@@ -63,7 +63,7 @@ def _get_expected_codegen(shape, typef, sizes, strides, padding,
     input = {
         "op": "input",
         "name": "",
-        "attrs": {"shape": [[list(shape)]], "dtype": [["float32"]]}}
+        "attrs": {"shape": [[list(shape)]], "dtype": [[dtype]]}}
     return [input, node]
 
 
@@ -74,22 +74,33 @@ def test_pooling():
     device = Device()
     np.random.seed(0)
 
-    for size in [(2, 2), (3, 3)]:
-        for stride in [(2, 2)]:
-            shape = (1, size[0] + stride[0] * 5,
-                     size[1] + stride[1] * 5, 16)
-
-            inputs = {
-                "a": tvm.nd.array(np.random.uniform(-1, 1, shape).astype("float32")),
-            }
-
-            outputs = []
-            func = _get_model(shape, relay.nn.max_pool2d, size,
-                              stride, (0, 0), True, iter(inputs))
-            for acl in [False, True]:
-                outputs.append(build_and_run(func, inputs, 1, None, device,
-                                             enable_acl=acl)[0])
-            verify(outputs, atol=0.001, rtol=0.001)
+    for dtype, low, high, atol, rtol in [("float32", -127, 128, 0.001, 0.001),
+                                         ("uint8", 0, 255, 0, 0)]:
+        for size in [(2, 2), (3, 3)]:
+            for stride in [(2, 2)]:
+                shape = (1, size[0] + stride[0] * 5,
+                         size[1] + stride[1] * 5, 16)
+                pad = (0, 0)
+
+                inputs = {
+                    "a": tvm.nd.array(np.random.uniform(low, high, shape).astype(dtype)),
+                }
+
+                outputs = []
+                func = _get_model(shape, dtype, relay.nn.max_pool2d, size,
+                                  stride, pad, True, iter(inputs))
+                for acl in [False, True]:
+                    outputs.append(build_and_run(func, inputs, 1, None, device,
+                                                 enable_acl=acl)[0])
+
+                params = {
+                    "size": size,
+                    "stride": stride,
+                    "shape": shape,
+                    "pooling type": "max",
+                    "dtype": dtype,
+                    "padding": pad
+                }
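+                # Max pooling only selects existing input values, so uint8 results
+                # are expected to match exactly (atol/rtol of 0).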
+                verify(outputs, atol=atol, rtol=rtol, params=params)
 
 
 def test_codegen_pooling():
@@ -98,15 +109,16 @@ def test_codegen_pooling():
 
     inputs = {"a"}
 
-    for size in [(2, 2), (3, 3)]:
-        for stride in [(2, 2)]:
-            shape = (1, size[0] + stride[0] * 5,
-                     size[1] + stride[1] * 5, 16)
-            args = (shape, relay.nn.max_pool2d, size,
-                    stride, (0, 0), True)
-            func = _get_model(*args, iter(inputs))
-            exp_codegen = _get_expected_codegen(*args)
-            verify_codegen(func, exp_codegen, 1)
+    for dtype in ["float32", "uint8"]:
+        for size in [(2, 2), (3, 3)]:
+            for stride in [(2, 2)]:
+                shape = (1, size[0] + stride[0] * 5,
+                         size[1] + stride[1] * 5, 16)
+                args = (shape, dtype, relay.nn.max_pool2d, size,
+                        stride, (0, 0), True)
+                func = _get_model(*args, iter(inputs))
+                exp_codegen = _get_expected_codegen(*args)
+                verify_codegen(func, exp_codegen, 1)
 
 
 if __name__ == "__main__":
index cb9f295..8ab9437 100644 (file)
@@ -26,14 +26,14 @@ from .infrastructure import skip_runtime_test, skip_codegen_test, build_and_run,
 from .infrastructure import Device
 
 
-def _get_model(input_shape, output_shape, var_names):
+def _get_model(input_shape, output_shape, dtype, var_names):
     """Return a model and any parameters it may have."""
-    a = relay.var(next(var_names), shape=input_shape, dtype="float32")
+    a = relay.var(next(var_names), shape=input_shape, dtype=dtype)
     reshape = relay.reshape(a, output_shape)
     return reshape
 
 
-def _get_expected_codegen(input_shape, output_shape):
+def _get_expected_codegen(input_shape, output_shape, dtype):
     node = {
         "op": "kernel",
         "name": "reshape",
@@ -43,7 +43,7 @@ def _get_expected_codegen(input_shape, output_shape):
             "num_outputs": "1",
             "newshape": [[str(s) for s in output_shape]],
             "shape": [[list(output_shape)]],
-            "dtype": [["float32"]],
+            "dtype": [[dtype]],
             "reverse": [["0"]]
         },
     }
@@ -51,7 +51,7 @@ def _get_expected_codegen(input_shape, output_shape):
     input = {
         "op": "input",
         "name": "",
-        "attrs": {"shape": [[list(input_shape)]], "dtype": [["float32"]]}}
+        "attrs": {"shape": [[list(input_shape)]], "dtype": [[dtype]]}}
 
     return [input, node]
 
@@ -63,18 +63,25 @@ def test_reshape():
     device = Device()
     np.random.seed(0)
 
-    inputs = {
-        "a": tvm.nd.array(
-            np.random.uniform(-128, 127, (1, 1, 1, 1000)).astype("float32"))
-    }
+    for dtype, low, high, atol, rtol in [("float32", -127, 128, 0.001, 0.001),
+                                         ("uint8", 0, 255, 0, 0)]:
+        inputs = {
+            "a": tvm.nd.array(
+                np.random.uniform(low, high, (1, 1, 1, 1000)).astype(dtype))
+        }
+
+        for new_shape in [(1, 1000), (10, 10, 10)]:
+            outputs = []
+            func = _get_model(inputs["a"].shape, new_shape, dtype, iter(inputs))
+            for acl in [False, True]:
+                outputs.append(build_and_run(func, inputs, 1, None, device,
+                                             enable_acl=acl)[0])
 
-    for shape in [(1, 1000), (10, 10, 10)]:
-        outputs = []
-        func = _get_model(inputs["a"].shape, shape, iter(inputs))
-        for acl in [False, True]:
-            outputs.append(build_and_run(func, inputs, 1, None, device,
-                                         enable_acl=acl)[0])
-        verify(outputs, atol=1e-7, rtol=1e-7)
+            params = {
+                "new shape": inputs["a"].shape,
+                "shape": new_shape,
+                "dtype": dtype,
+            }
+            verify(outputs, atol=1e-7, rtol=1e-7, params=params)
 
 
 def test_codegen_reshape():
@@ -83,12 +90,12 @@ def test_codegen_reshape():
 
     shape = (1, 1, 1, 1000)
     inputs = {"a"}
-
-    for new_shape in [(1, 1000), (10, 10, 10)]:
-        args = (shape, new_shape)
-        func = _get_model(*args, iter(inputs))
-        exp_codegen = _get_expected_codegen(*args)
-        verify_codegen(func, exp_codegen, 1)
+    for dtype in ["float32", "uint8"]:
+        for new_shape in [(1, 1000), (10, 10, 10)]:
+            args = (shape, new_shape, dtype)
+            func = _get_model(*args, iter(inputs))
+            exp_codegen = _get_expected_codegen(*args)
+            verify_codegen(func, exp_codegen, 1)
 
 
 if __name__ == "__main__":