[BYOC][ACL] Add support for dense (fully connected) layer (#6254)
authorlhutton1 <35535092+lhutton1@users.noreply.github.com>
Thu, 13 Aug 2020 16:21:35 +0000 (17:21 +0100)
committerGitHub <noreply@github.com>
Thu, 13 Aug 2020 16:21:35 +0000 (09:21 -0700)
* [BYOC][ACL] Add support for dense (fully connected) layer

This patch adds the ability to offload dense (or fully connected) operators to ACL.

For fp32 a single dense layer can be offloaded, or the composite variant: nn.dense, nn.bias_add? (ACL does not currently offer fused activation).
For uint8: qnn.dense, nn.bias_add?, qnn.requantize

Change-Id: I83ea00b2aa6bdc5d9ef5cd6d54bbf981e523bd14

* Don't offload dense layer with unsupported datatype

Change-Id: I856eb2298499fdf22c172ba7f85d21033d3cc920

docs/deploy/arm_compute_lib.rst
python/tvm/relay/op/contrib/arm_compute_lib.py
src/relay/backend/contrib/arm_compute_lib/codegen.cc
src/runtime/contrib/arm_compute_lib/acl_runtime.cc
tests/python/contrib/test_arm_compute_lib/test_conv2d.py
tests/python/contrib/test_arm_compute_lib/test_dense.py [new file with mode: 0644]
tests/python/contrib/test_arm_compute_lib/test_network.py
tests/python/contrib/test_arm_compute_lib/test_pooling.py

index eaffc0a..c0b1a7e 100644 (file)
@@ -181,6 +181,13 @@ Operator support
 |              |                                                                         |
 |              | (only groups = 1 supported)                                             |
 +--------------+-------------------------------------------------------------------------+
+| nn.dense     | fp32:                                                                   |
+|              |   Simple: nn.dense                                                      |
+|              |   Composite: nn.dense, nn.bias_add?                                     |
++--------------+-------------------------------------------------------------------------+
+| qnn.dense    | uint8:                                                                  |
+|              |   Composite: qnn.dense, nn.bias_add?, qnn.requantize                    |
++--------------+-------------------------------------------------------------------------+
 | nn.maxpool2d | fp32, uint8                                                             |
 +--------------+-------------------------------------------------------------------------+
 | reshape      | fp32, uint8                                                             |
index 2f031b3..e20f2d1 100644 (file)
@@ -98,6 +98,33 @@ def arm_compute_lib_pattern_table():
             pattern, wildcard(), wildcard(), is_constant(), is_constant())
         return pattern
 
+    def dense_pattern():
+        """Create a dense (fully-connected) pattern.
+
+        Returns
+        -------
+        pattern : dataflow_pattern.AltPattern
+            Denotes the convolution pattern.
+        """
+        pattern = is_op('nn.dense')(wildcard(), is_constant())
+        pattern = pattern.optional(lambda x: is_op('nn.bias_add')(x, is_constant()))
+        return pattern
+
+    def qnn_dense_pattern():
+        """Create a quantized dense (fully-connected) pattern.
+
+        Returns
+        -------
+        pattern : dataflow_pattern.AltPattern
+            Denotes the convolution pattern.
+        """
+        pattern = is_op('qnn.dense')(
+            wildcard(), is_constant(), is_constant(), is_constant(), is_constant(), is_constant())
+        pattern = pattern.optional(lambda x: is_op('nn.bias_add')(x, is_constant()))
+        pattern = is_op('qnn.requantize')(
+            pattern, wildcard(), wildcard(), is_constant(), is_constant())
+        return pattern
+
     def check_conv(extract):
         """Check conv pattern is supported by ACL."""
         call = extract
@@ -114,8 +141,26 @@ def arm_compute_lib_pattern_table():
             call = call.args[0]
         return qnn_conv2d(call.attrs, call.args)
 
+    def check_dense(extract):
+        """Check conv pattern is supported by ACL."""
+        call = extract
+        while call.op.name != "nn.dense":
+            call = call.args[0]
+        return dense(call.attrs, call.args)
+
+    def check_qnn_dense(extract):
+        """Check qnn conv pattern is supported by ACL."""
+        if extract.attrs.out_dtype != "uint8":
+            return False
+        call = extract
+        while call.op.name != "qnn.dense":
+            call = call.args[0]
+        return qnn_dense(call.attrs, call.args)
+
     return [('arm_compute_lib.conv2d', conv_pattern(), check_conv),
-            ('arm_compute_lib.qnn_conv2d', qnn_conv_pattern(), check_qnn_conv)]
+            ('arm_compute_lib.qnn_conv2d', qnn_conv_pattern(), check_qnn_conv),
+            ('arm_compute_lib.dense', dense_pattern(), check_dense),
+            ('arm_compute_lib.qnn_dense', qnn_dense_pattern(), check_qnn_dense)]
 
 
 def _register_external_op_helper(op_name, supported=True):
@@ -164,6 +209,33 @@ def qnn_conv2d(attrs, args):
     return True
 
 
+@tvm.ir.register_op_attr("nn.dense", "target.arm_compute_lib")
+def dense(attrs, args):
+    """Check if the external ACL codegen for dense should be used."""
+    data_typ = args[0].checked_type
+    if data_typ.dtype != "float32":
+        return False
+    kernel_typ = args[1].checked_type
+    if len(kernel_typ.shape) != 2 or kernel_typ.dtype != "float32":
+        return False
+    if attrs.out_dtype != "float32" and attrs.out_dtype != "":
+        return False
+    return True
+
+
+def qnn_dense(attrs, args):
+    """Check if the external ACL codegen for qnn.dense should be used."""
+    data_typ = args[0].checked_type
+    if data_typ.dtype != "uint8":
+        return False
+    kernel_typ = args[1].checked_type
+    if len(kernel_typ.shape) != 2 or kernel_typ.dtype != "uint8":
+        return False
+    if attrs.out_dtype != "int32":
+        return False
+    return True
+
+
 @tvm.ir.register_op_attr("nn.max_pool2d", "target.arm_compute_lib")
 def max_pool2d(attrs, args):
     """Check if the external ACL codegen for maxpool2d should be used."""
index 88de3ed..1132b1c 100644 (file)
@@ -62,6 +62,16 @@ class ACLJSONSerializer : public backend::contrib::JSONSerializer {
   };
 
   /*!
+   * \brief A series of operators that form a composite
+   * dense layer. Supports both nn.dense and qnn.dense.
+   */
+  struct CompositeDenseNode {
+    const CallNode* dense = nullptr;
+    const CallNode* bias = nullptr;
+    const CallNode* requantize = nullptr;
+  };
+
+  /*!
    * \brief Visit call nodes and generate appropriate JSON node.
    *
    * \param cn The current call node.
@@ -82,6 +92,8 @@ class ACLJSONSerializer : public backend::contrib::JSONSerializer {
     std::shared_ptr<JSONGraphNode> json_node;
     if (name == "arm_compute_lib.conv2d" || name == "arm_compute_lib.qnn_conv2d") {
       json_node = CreateCompositeConvJSONNode(cn);
+    } else if (name == "arm_compute_lib.dense" || name == "arm_compute_lib.qnn_dense") {
+      json_node = CreateCompositeDenseJSONNode(cn);
     } else {
       LOG(FATAL) << "Unrecognized Arm Compute Library pattern: " << name;
     }
@@ -190,6 +202,71 @@ class ACLJSONSerializer : public backend::contrib::JSONSerializer {
     }
     return json_node;
   }
+
+  /*!
+   * \brief Extract dense nodes from a composite function.
+   *
+   * \param cn The call node of the composite function.
+   * \return Extracted composite convolution nodes.
+   */
+  static CompositeDenseNode UnpackCompositeDense(const CallNode* cn) {
+    CompositeDenseNode nodes{};
+    const auto* fn = cn->op.as<FunctionNode>();
+    CHECK(fn);
+
+    // Traverse composite dense function from child to parent
+    const auto* current_call = fn->body.as<CallNode>();
+    if (backend::IsOp(current_call, "qnn.requantize")) {
+      nodes.requantize = current_call;
+      current_call = current_call->args[0].as<CallNode>();
+    }
+    if (backend::IsOp(current_call, "nn.bias_add")) {
+      nodes.bias = current_call;
+      current_call = current_call->args[0].as<CallNode>();
+    }
+    // Enforce a dense node exists at this point during traversal
+    if (nodes.requantize) {
+      CHECK(backend::IsOp(current_call, "qnn.dense"));
+    } else {
+      CHECK(backend::IsOp(current_call, "nn.dense"));
+    }
+    nodes.dense = current_call;
+    return nodes;
+  }
+
+  /*!
+   * \brief Create a JSON representation of a composite dense (fully-connected) operator.
+   *
+   * \param cn The call to be represented.
+   * \return A JSON representation of a specific operator.
+   */
+  std::shared_ptr<JSONGraphNode> CreateCompositeDenseJSONNode(const CallNode* cn) {
+    CompositeDenseNode nodes = UnpackCompositeDense(cn);
+    std::string name = "nn.dense";
+
+    // Inputs must be added in the same order they appear in the relay graph.
+    std::vector<JSONGraphNodeEntry> inputs;
+    inputs.push_back(VisitExpr(cn->args[0])[0]);
+    inputs.push_back(VisitExpr(nodes.dense->args[1])[0]);
+    if (nodes.requantize) {
+      name = "qnn.dense";
+      inputs.push_back(VisitExpr(nodes.dense->args[2])[0]);  // input zero-point
+      inputs.push_back(VisitExpr(nodes.dense->args[3])[0]);  // weight zero-point
+      inputs.push_back(VisitExpr(nodes.dense->args[4])[0]);  // input scale
+      inputs.push_back(VisitExpr(nodes.dense->args[5])[0]);  // weight scale
+    }
+    if (nodes.bias) {
+      inputs.push_back(VisitExpr(nodes.bias->args[1])[0]);
+    }
+    if (nodes.requantize) {
+      inputs.push_back(VisitExpr(nodes.requantize->args[3])[0]);  // output scale
+      inputs.push_back(VisitExpr(nodes.requantize->args[4])[0]);  // output zero-point
+    }
+
+    auto json_node = std::make_shared<JSONGraphNode>(name, "kernel", inputs, 1);
+    SetCallNodeAttribute(json_node, nodes.dense);
+    return json_node;
+  }
 };
 
 /*!
index 2498dcf..f62420a 100644 (file)
@@ -31,6 +31,7 @@
 #ifdef TVM_GRAPH_RUNTIME_ARM_COMPUTE_LIB
 #include <arm_compute/core/Types.h>
 #include <arm_compute/runtime/NEON/functions/NEConvolutionLayer.h>
+#include <arm_compute/runtime/NEON/functions/NEFullyConnectedLayer.h>
 #include <arm_compute/runtime/NEON/functions/NEPoolingLayer.h>
 #include <arm_compute/runtime/NEON/functions/NEReshapeLayer.h>
 
@@ -128,6 +129,9 @@ class ACLRuntime : public JSONRuntimeBase {
         if ("nn.conv2d" == op_name || "qnn.conv2d" == op_name) {
           CreateConvolution2DLayer(&layer_, node, mm);
           num_pools++;
+        } else if ("nn.dense" == op_name || "qnn.dense" == op_name) {
+          CreateFullyConnectedLayer(&layer_, node, mm);
+          num_pools++;
         } else if ("nn.max_pool2d" == op_name) {
           CreatePoolingLayer(&layer_, node);
         } else if ("reshape" == op_name) {
@@ -258,6 +262,50 @@ class ACLRuntime : public JSONRuntimeBase {
   }
 
   /*!
+   * \brief Create a fully connected (dense) layer.
+   *
+   * \param layer The ACL layer to build. Containing inputs, outputs and the ACL function.
+   * \param node The JSON representation of the operator.
+   * \param mm The ACL fully connected layer can request auxiliary memory from TVM.
+   */
+  void CreateFullyConnectedLayer(CachedLayer* layer, const JSONGraphNode& node,
+                                 const std::shared_ptr<arm_compute::MemoryManagerOnDemand>& mm) {
+    arm_compute::FullyConnectedLayerInfo fc_info;
+    fc_info.set_weights_trained_layout(arm_compute::DataLayout::NHWC);
+
+    // Collect inputs and outputs, handling both nn.dense and qnn.dense cases.
+    std::vector<JSONGraphNodeEntry> inputs = node.GetInputs();
+    size_t num_inputs = inputs.size();
+    bool has_bias;
+    if (node.GetOpName() == "qnn.dense") {
+      CHECK(num_inputs >= 8U && num_inputs <= 9U)
+          << "Quantized fully connected (dense) layer requires 9 inputs with a bias, 8 inputs "
+             "without.";
+      has_bias = num_inputs == 9;
+      layer->inputs.push_back(MakeACLTensorFromJSONEntry(inputs[0], &inputs[4], &inputs[2]));
+      layer->inputs.push_back(MakeACLTensorFromJSONEntry(inputs[1], &inputs[5], &inputs[3]));
+      if (has_bias) {
+        layer->inputs.push_back(MakeACLTensorFromJSONEntry(inputs[6]));
+      }
+      layer->outputs.push_back(
+          MakeACLTensorFromJSONNode(node, &inputs[6 + has_bias], &inputs[7 + has_bias]));
+    } else {
+      CHECK(num_inputs >= 2U && num_inputs <= 3U)
+          << "Fully connected (dense) layer requires 3 inputs with a bias, 2 inputs without.";
+      has_bias = num_inputs == 3;
+      for (const auto& i : inputs) {
+        layer->inputs.push_back(MakeACLTensorFromJSONEntry(i));
+      }
+      layer->outputs.push_back(MakeACLTensorFromJSONNode(node));
+    }
+
+    auto function = std::make_shared<arm_compute::NEFullyConnectedLayer>(mm);
+    function->configure(&layer->inputs[0], &layer->inputs[1],
+                        has_bias ? &layer->inputs[2] : nullptr, &layer->outputs[0], fc_info);
+    layer->function = function;
+  }
+
+  /*!
    * \brief Create a pooling layer.
    *
    * \note Currently only maxpool is supported.
index c407466..a89f04d 100644 (file)
@@ -392,7 +392,7 @@ def test_qnn_conv2d():
             "output scale": output_sc,
             "output zero point": output_zp
         }
-        verify(outputs, atol=1, rtol=0, params=params)
+        verify(outputs, atol=1, rtol=0, params=params, verify_saturation=True)
 
 
 def test_codegen_qnn_conv2d():
diff --git a/tests/python/contrib/test_arm_compute_lib/test_dense.py b/tests/python/contrib/test_arm_compute_lib/test_dense.py
new file mode 100644 (file)
index 0000000..2208026
--- /dev/null
@@ -0,0 +1,319 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Arm Compute Library integration dense tests."""
+
+import numpy as np
+
+import tvm
+from tvm import relay
+
+from .infrastructure import Device, skip_runtime_test, skip_codegen_test, \
+    build_and_run, verify, verify_codegen, generate_trials
+
+
+def _get_model(shape, weight_shape, units, dtype, var_names,
+               has_bias=False):
+    """Return a model and any parameters it may have"""
+    a = relay.var(next(var_names), shape=shape, dtype=dtype)
+    w = tvm.nd.array(np.random.uniform(-128, 127, weight_shape).astype(dtype))
+    weights = relay.const(w, dtype)
+    out = relay.nn.dense(
+        a,
+        weights,
+        units=units,
+        out_dtype=dtype
+    )
+    params = {"w": w}
+    if has_bias:
+        b = tvm.nd.array(np.random.randint(-128, 127, weight_shape[0]).astype(dtype))
+        biasc = relay.const(b, dtype)
+        out = relay.nn.bias_add(out, biasc)
+        params['b'] = b
+    return out, params
+
+
+def _get_qnn_params(input_zp, input_sc, kernel_zp, kernel_sc,
+                    kernel_h, kernel_w):
+    """Get output qnn parameters given input and kernel parameters."""
+    input_max = input_sc * (255 - input_zp)
+    input_min = - input_sc * input_zp
+    kernel_max = kernel_sc * (255 - kernel_zp)
+    kernel_min = - kernel_sc * kernel_zp
+    output_limits = [kernel_max * kernel_h * kernel_w * input_max,
+                     kernel_min * kernel_h * kernel_w * input_max,
+                     kernel_min * kernel_h * kernel_w * input_min,
+                     kernel_max * kernel_h * kernel_w * input_min]
+    output_max = max(output_limits)
+    output_min = min(output_limits)
+    output_sc = (output_max - output_min) / 255
+    output_zp = - int(output_min / output_sc)
+    return output_zp, output_sc
+
+
+def _get_qnn_model(shape, weight_shape, units, dtype,
+                   input_zp, input_sc, kernel_zp,
+                   kernel_sc, output_zp, output_sc, var_names,
+                   has_bias=False):
+    a = relay.var(next(var_names), shape=shape, dtype=dtype)
+    w = tvm.nd.array(np.random.uniform(-128, 127, weight_shape).astype(dtype))
+    weights = relay.const(w, dtype)
+    out = relay.qnn.op.dense(
+        a,
+        weights,
+        units=units,
+        input_zero_point=relay.const(input_zp, "int32"),
+        kernel_zero_point=relay.const(kernel_zp, "int32"),
+        input_scale=relay.const(input_sc, "float32"),
+        kernel_scale=relay.const(kernel_sc, "float32"),
+        out_dtype="int32"
+    )
+    params = {"w": w}
+    if has_bias:
+        b = tvm.nd.array(np.random.randint(0, 255, weight_shape[0]).astype("int32"))
+        biasc = relay.const(b, "int32")
+        out = relay.nn.bias_add(out, biasc)
+        params['b'] = b
+    out = relay.qnn.op.requantize(
+        out,
+        relay.const(input_sc * kernel_sc, 'float32'),  # input scale
+        relay.const(input_zp * kernel_zp, 'int32'),  # input zero point
+        relay.const(output_sc, 'float32'),  # output scale
+        relay.const(output_zp, 'int32'),  # output zero point
+        out_dtype="uint8"
+    )
+    return out, params
+
+
+def _get_expected_codegen(shape, weight_shape, units, dtype,
+                          has_bias=False):
+    output_shape = (shape[0], units)
+    out_dtype = "int32" if dtype == "uint8" else "float32"
+
+    node = {
+        "op": "kernel",
+        "name": "nn.dense",
+        "inputs": [],
+        "attrs": {
+            "num_outputs": "1",
+            "out_dtype": [[out_dtype]],
+            "shape": [[list(output_shape)]],
+            "dtype": [[dtype]],
+            "units": [[str(units)]]
+        }
+    }
+
+    inputs = [{
+        "op": "input",
+        "name": "",
+        "attrs": {
+            "shape": [[list(shape)]],
+            "dtype": [[str(dtype)]]
+        }}, {
+        "op": "const",
+        "name": "",
+        "attrs": {
+            "shape": [[list(weight_shape)]],
+            "dtype": [[str(dtype)]]
+        }}]
+
+    # qnn.dense params, input and kernel
+    if dtype == "uint8":
+        node["name"] = "qnn.dense"
+        for param_dtype in ["int32", "float32"]:
+            for _ in range(2):
+                inputs.append({
+                    "op": "const",
+                    "name": "",
+                    "attrs": {
+                        "shape": [[[]]],
+                        "dtype": [[param_dtype]]
+                    }
+                })
+
+    if has_bias:
+        bias_dtype = "int32" if dtype == "uint8" else "float32"
+        inputs.append({
+            "op": "const",
+            "name": "",
+            "attrs": {
+                "shape": [[[weight_shape[0]]]],
+                "dtype": [[bias_dtype]]}
+        })
+
+    # qnn.dense params, output
+    if dtype == "uint8":
+        for param_dtype in ["float32", "int32"]:
+            inputs.append({
+                "op": "const",
+                "name": "",
+                "attrs": {
+                    "shape": [[[]]],
+                    "dtype": [[param_dtype]]
+                }
+            })
+
+    input_idx = 0
+    for _ in range(len(inputs)):
+        node["inputs"].append([input_idx, 0, 0])
+        input_idx += 1
+    node["attrs"]["num_inputs"] = str(len(inputs))
+    inputs.append(node)
+    return inputs
+
+
+def test_dense():
+    if skip_runtime_test():
+        return
+
+    device = Device()
+    np.random.seed(0)
+
+    dtype = ["float32"]
+    shape = [((1, 128), (16, 128), 16), ((32, 32), (32, 32), 32), ((1, 64), (1, 64), 1)]
+    composite = [False, True]
+    trials = generate_trials([dtype, shape, composite], 3)
+
+    for dtype, (shape, weight_shape, units), composite in trials:
+        outputs = []
+        inputs = {
+            "a": tvm.nd.array(np.random.uniform(-128, 127, shape).astype(dtype))
+        }
+        func, params = _get_model(shape, weight_shape, units, dtype, var_names=iter(inputs),
+                                  has_bias=composite)
+        for acl in [False, True]:
+            outputs.append(build_and_run(func, inputs, 1, params,
+                                         device, enable_acl=acl)[0])
+
+        config = {
+            "shape": shape,
+            "weight_shape": weight_shape,
+            "units": units,
+            "dtype": dtype,
+            "composite operators (bias)": composite
+        }
+        verify(outputs, atol=0.001, rtol=0.01, params=config)
+
+
+def test_codegen_dense():
+    if skip_codegen_test():
+        return
+
+    np.random.seed(0)
+
+    dtype = ["float32"]
+    shape = [((1, 128), (16, 128), 16), ((32, 32), (32, 32), 32), ((1, 64), (1, 64), 1)]
+    composite = [False, True]
+    trials = generate_trials([dtype, shape, composite], 3)
+
+    for dtype, (shape, weight_shape, units), composite in trials:
+        inputs = {"a"}
+
+        args = (shape, weight_shape, units, dtype)
+
+        func, params = _get_model(*args, var_names=iter(inputs),
+                                  has_bias=composite)
+        exp_codegen = _get_expected_codegen(*args, has_bias=composite)
+        verify_codegen(func, exp_codegen, 1)
+
+
+def test_qnn_dense():
+    if skip_runtime_test():
+        return
+
+    device = Device()
+    np.random.seed(0)
+
+    dtype = ["uint8"]
+    shape = [((1, 128), (16, 128), 16), ((32, 32), (32, 32), 32), ((1, 64), (1, 64), 1)]
+    composite = [False, True]
+    trials = generate_trials([dtype, shape, composite], 3)
+
+    for dtype, (shape, weight_shape, units), composite in trials:
+        outputs = []
+        inputs = {
+            "a": tvm.nd.array(np.random.uniform(0, 255, shape).astype(dtype))
+        }
+        input_zp = 100
+        input_sc = 0.5
+        kernel_zp = 50
+        kernel_sc = 0.03
+        output_zp, output_sc = _get_qnn_params(input_zp, input_sc,
+                                               kernel_zp, kernel_sc,
+                                               weight_shape[0], weight_shape[1])
+
+        func, params = _get_qnn_model(shape, weight_shape, units, dtype,
+                                      input_zp, input_sc, kernel_zp,
+                                      kernel_sc, output_zp, output_sc,
+                                      var_names=iter(inputs), has_bias=composite)
+
+        for acl in [False, True]:
+            outputs.append(build_and_run(func, inputs, 1, params,
+                                         device, enable_acl=acl)[0])
+
+        config = {
+            "shape": shape,
+            "weight_shape": weight_shape,
+            "units": units,
+            "dtype": dtype,
+            "composite operators (bias)": composite,
+            "input scale": input_sc,
+            "input zero point": input_zp,
+            "kernel scale": kernel_sc,
+            "kernel zero point": kernel_zp,
+            "output scale": output_sc,
+            "output zero point": output_zp
+        }
+        verify(outputs, atol=1, rtol=0, params=config, verify_saturation=True)
+
+
+def test_codegen_qnn_dense():
+    if skip_codegen_test():
+        return
+
+    np.random.seed(0)
+
+    dtype = ["uint8"]
+    shape = [((1, 128), (16, 128), 16), ((32, 32), (32, 32), 32), ((1, 64), (1, 64), 1)]
+    composite = [False, True]
+    trials = generate_trials([dtype, shape, composite], 3)
+
+    for dtype, (shape, weight_shape, units), composite in trials:
+        inputs = {"a"}
+        args = (shape, weight_shape, units, dtype)
+
+        input_zp = 100
+        input_sc = 0.5
+        kernel_zp = 25
+        kernel_sc = 0.03
+        output_zp, output_sc = _get_qnn_params(input_zp, input_sc,
+                                               kernel_zp, kernel_sc,
+                                               weight_shape[0], weight_shape[1])
+
+        func, params = _get_qnn_model(*args, var_names=iter(inputs),
+                                      input_zp=input_zp, input_sc=input_sc,
+                                      kernel_zp=kernel_zp, kernel_sc=kernel_sc,
+                                      output_zp=output_zp, output_sc=output_sc,
+                                      has_bias=composite)
+        exp_codegen = _get_expected_codegen(*args, has_bias=composite)
+        verify_codegen(func, exp_codegen, 1)
+
+
+if __name__ == "__main__":
+    test_dense()
+    test_qnn_dense()
+    test_codegen_dense()
+    test_codegen_qnn_dense()
index 1ba6ca7..ceef179 100644 (file)
@@ -93,7 +93,7 @@ def test_vgg16():
         return mod, params, inputs
 
     _build_and_run_network(*get_model(), device=device,
-                           tvm_ops=10, acl_partitions=18,
+                           tvm_ops=4, acl_partitions=21,
                            atol=0.002, rtol=0.01)
 
 
index 4d48f79..c9ae1d9 100644 (file)
@@ -100,7 +100,7 @@ def test_pooling():
                     "dtype": dtype,
                     "padding": pad
                 }
-                verify(outputs, atol=atol, rtol=rtol, params=params)
+                verify(outputs, atol=atol, rtol=rtol, params=params, verify_saturation=True)
 
 
 def test_codegen_pooling():