[BYOC][ACL] Improved pooling support (#6248)
authorlhutton1 <35535092+lhutton1@users.noreply.github.com>
Thu, 27 Aug 2020 03:50:38 +0000 (04:50 +0100)
committerGitHub <noreply@github.com>
Thu, 27 Aug 2020 03:50:38 +0000 (20:50 -0700)
* [BYOC][ACL] Improved pooling support

Adds support in ACL for the following relay pooling operators and composite functions:
  * nn.avg_pool2d (fp32), cast + nn.avg_pool2d(uint8) + cast => AVG pool
  * nn.global_max_pool2d => Global MAX pool
  * nn.global_avg_pool2d, cast + nn.global_avg_pool2d(uint8) + cast => Global AVG pool
  * power(2) + nn.avg_pool2d + sqrt => L2 pooling (for fp32 only)

Tests updated to reflect these changes.

Change-Id: I1644b67b60ebb252344eb9695a521d2d958c724e

* Address comments

Change-Id: Ibe8a61b4c42da246ce54701c89ea985b423c8f83

* Fix not checking output saturation

Change-Id: Ia6f3d9db31cfb8c417d8556d29961210fea418b2

* Use defined set of trials

Change-Id: Ib180e3a0cbb84d6fa00c7e1994f58cb62662db15

* Rebase master

Change-Id: I5c932751cd38da06d6f2b397be5d8ab7fdeb169f

docs/deploy/arm_compute_lib.rst
python/tvm/relay/op/contrib/arm_compute_lib.py
src/relay/backend/contrib/arm_compute_lib/codegen.cc
src/runtime/contrib/arm_compute_lib/acl_runtime.cc
src/runtime/contrib/arm_compute_lib/acl_utils.cc
src/runtime/contrib/arm_compute_lib/acl_utils.h
tests/python/contrib/test_arm_compute_lib/infrastructure.py
tests/python/contrib/test_arm_compute_lib/test_conv2d.py
tests/python/contrib/test_arm_compute_lib/test_network.py
tests/python/contrib/test_arm_compute_lib/test_pooling.py
tests/python/contrib/test_arm_compute_lib/test_reshape.py

index 26b42ae..e3399c5 100644 (file)
@@ -188,31 +188,50 @@ An example configuration for `test_config.json`:
 
 Operator support
 ----------------
-+--------------+-------------------------------------------------------------------------+
-| Relay Node   | Remarks                                                                 |
-+==============+=========================================================================+
-| nn.conv2d    | fp32:                                                                   |
-|              |   Simple: nn.conv2d                                                     |
-|              |   Composite: nn.pad?, nn.conv2d, nn.bias_add?, nn.relu?                 |
-|              |                                                                         |
-|              | (only groups = 1 supported)                                             |
-+--------------+-------------------------------------------------------------------------+
-| qnn.conv2d   | uint8:                                                                  |
-|              |   Composite: nn.pad?, nn.conv2d, nn.bias_add?, nn.relu?, qnn.requantize |
-|              |                                                                         |
-|              | (only groups = 1 supported)                                             |
-+--------------+-------------------------------------------------------------------------+
-| nn.dense     | fp32:                                                                   |
-|              |   Simple: nn.dense                                                      |
-|              |   Composite: nn.dense, nn.bias_add?                                     |
-+--------------+-------------------------------------------------------------------------+
-| qnn.dense    | uint8:                                                                  |
-|              |   Composite: qnn.dense, nn.bias_add?, qnn.requantize                    |
-+--------------+-------------------------------------------------------------------------+
-| nn.maxpool2d | fp32, uint8                                                             |
-+--------------+-------------------------------------------------------------------------+
-| reshape      | fp32, uint8                                                             |
-+--------------+-------------------------------------------------------------------------+
++----------------------+-------------------------------------------------------------------------+
+| Relay Node           | Remarks                                                                 |
++======================+=========================================================================+
+| nn.conv2d            | fp32:                                                                   |
+|                      |   Simple: nn.conv2d                                                     |
+|                      |   Composite: nn.pad?, nn.conv2d, nn.bias_add?, nn.relu?                 |
+|                      |                                                                         |
+|                      | (only groups = 1 supported)                                             |
++----------------------+-------------------------------------------------------------------------+
+| qnn.conv2d           | uint8:                                                                  |
+|                      |   Composite: nn.pad?, nn.conv2d, nn.bias_add?, nn.relu?, qnn.requantize |
+|                      |                                                                         |
+|                      | (only groups = 1 supported)                                             |
++----------------------+-------------------------------------------------------------------------+
+| nn.dense             | fp32:                                                                   |
+|                      |   Simple: nn.dense                                                      |
+|                      |   Composite: nn.dense, nn.bias_add?                                     |
++----------------------+-------------------------------------------------------------------------+
+| qnn.dense            | uint8:                                                                  |
+|                      |   Composite: qnn.dense, nn.bias_add?, qnn.requantize                    |
++----------------------+-------------------------------------------------------------------------+
+| nn.max_pool2d        | fp32, uint8                                                             |
++----------------------+-------------------------------------------------------------------------+
+| nn.global_max_pool2d | fp32, uint8                                                             |
++----------------------+-------------------------------------------------------------------------+
+| nn.avg_pool2d        | fp32:                                                                   |
+|                      |    Simple: nn.avg_pool2d                                                |
+|                      |                                                                         |
+|                      | uint8:                                                                  |
+|                      |    Composite: cast(int32), nn.avg_pool2d, cast(uint8)                   |
++----------------------+-------------------------------------------------------------------------+
+| nn.global_avg_pool2d | fp32:                                                                   |
+|                      |    Simple: nn.global_avg_pool2d                                         |
+|                      |                                                                         |
+|                      | uint8:                                                                  |
+|                      |    Composite: cast(int32), nn.avg_pool2d, cast(uint8)                   |
++----------------------+-------------------------------------------------------------------------+
+| power(of 2) +        | A special case for L2 pooling.                                          |
+| nn.avg_pool2d +      |                                                                         |
+| sqrt                 | fp32:                                                                   |
+|                      |    Composite: power(of 2), nn.avg_pool2d, sqrt                          |
++----------------------+-------------------------------------------------------------------------+
+| reshape              | fp32, uint8                                                             |
++----------------------+-------------------------------------------------------------------------+
 
 .. note::
     A composite operator is a series of operators that map to a single Arm Compute Library operator. You can view this
index e20f2d1..adeeeb1 100644 (file)
 # pylint: disable=invalid-name, unused-argument
 """Arm Compute Library supported operators."""
 import tvm
+from tvm.relay.expr import const
 from tvm.relay import transform
 from tvm.relay.build_module import bind_params_by_name
 
-from ...dataflow_pattern import wildcard, is_op, is_constant
+from ...dataflow_pattern import wildcard, is_op, is_constant, is_expr
 from .register import register_pattern_table
 
 
@@ -125,6 +126,33 @@ def arm_compute_lib_pattern_table():
             pattern, wildcard(), wildcard(), is_constant(), is_constant())
         return pattern
 
+    def avg_pool2d_pattern():
+        """Creates a pattern that matches either quantized
+        avg_pool2d or quantized global_avg_pool2d.
+
+        Returns
+        -------
+        pattern : dataflow_pattern.AltPattern
+            Denotes the convolution pattern.
+        """
+        pattern = is_op('cast')(wildcard())
+        pattern = is_op('nn.avg_pool2d')(pattern) | is_op('nn.global_avg_pool2d')(pattern)
+        pattern = is_op('cast')(pattern)
+        return pattern
+
+    def l2_pool2d_pattern():
+        """Create an l2 pooling pattern from equivalent relay operators.
+
+        Returns
+        -------
+        pattern : dataflow_pattern.AltPattern
+            Denotes the convolution pattern.
+        """
+        pattern = is_op('power')(wildcard(), is_expr(const(2.0)))
+        pattern = is_op('nn.avg_pool2d')(pattern)
+        pattern = is_op('sqrt')(pattern)
+        return pattern
+
     def check_conv(extract):
         """Check conv pattern is supported by ACL."""
         call = extract
@@ -157,10 +185,27 @@ def arm_compute_lib_pattern_table():
             call = call.args[0]
         return qnn_dense(call.attrs, call.args)
 
+    def check_avg_pool2d(extract):
+        """Check average pool2d pattern is supported by ACL."""
+        if extract.attrs.dtype != "uint8":
+            return False
+        pool = extract.args[0]
+        if pool.args[0].attrs.dtype != "int32":
+            return False
+        return avg_pool2d(pool.attrs, pool.args, from_quantized_composite=True)
+
+    def check_l2_pool2d(extract):
+        """Check l2 pool2d pattern is supported by ACL."""
+        pool = extract.args[0]
+        return avg_pool2d(pool.attrs, pool.args)
+
     return [('arm_compute_lib.conv2d', conv_pattern(), check_conv),
             ('arm_compute_lib.qnn_conv2d', qnn_conv_pattern(), check_qnn_conv),
             ('arm_compute_lib.dense', dense_pattern(), check_dense),
-            ('arm_compute_lib.qnn_dense', qnn_dense_pattern(), check_qnn_dense)]
+            ('arm_compute_lib.qnn_dense', qnn_dense_pattern(), check_qnn_dense),
+            ('arm_compute_lib.qnn_conv2d', qnn_conv_pattern(), check_qnn_conv),
+            ('arm_compute_lib.avg_pool2d', avg_pool2d_pattern(), check_avg_pool2d),
+            ('arm_compute_lib.l2_pool2d', l2_pool2d_pattern(), check_l2_pool2d)]
 
 
 def _register_external_op_helper(op_name, supported=True):
@@ -245,3 +290,40 @@ def max_pool2d(attrs, args):
     if typ.dtype not in ["float32", "uint8"]:
         return False
     return True
+
+
+@tvm.ir.register_op_attr("nn.avg_pool2d", "target.arm_compute_lib")
+def avg_pool2d(attrs, args, from_quantized_composite=False):
+    """Check if the external ACL codegen for avgpool2d should be used."""
+    typ = args[0].checked_type
+    if from_quantized_composite:
+        if typ.dtype != "int32":
+            return False
+    else:
+        if typ.dtype not in ["float32"]:
+            return False
+    if attrs.layout != "NHWC":
+        return False
+    return True
+
+
+@tvm.ir.register_op_attr("nn.global_max_pool2d", "target.arm_compute_lib")
+def global_max_pool2d(attrs, args):
+    """Check if the external ACL codegen for gloval_maxpool2d should be used."""
+    typ = args[0].checked_type
+    if typ.dtype not in ["float32", "uint8"]:
+        return False
+    if attrs.layout != "NHWC":
+        return False
+    return True
+
+
+@tvm.ir.register_op_attr("nn.global_avg_pool2d", "target.arm_compute_lib")
+def global_avg_pool2d(attrs, args):
+    """Check if the external ACL codegen for global_avgpool2d should be used."""
+    typ = args[0].checked_type
+    if typ.dtype not in ["float32"]:
+        return False
+    if attrs.layout != "NHWC":
+        return False
+    return True
index 1132b1c..087c895 100644 (file)
@@ -94,6 +94,10 @@ class ACLJSONSerializer : public backend::contrib::JSONSerializer {
       json_node = CreateCompositeConvJSONNode(cn);
     } else if (name == "arm_compute_lib.dense" || name == "arm_compute_lib.qnn_dense") {
       json_node = CreateCompositeDenseJSONNode(cn);
+    } else if (name == "arm_compute_lib.avg_pool2d") {
+      json_node = CreateCompositeAvgPool2DJSONNode(cn);
+    } else if (name == "arm_compute_lib.l2_pool2d") {
+      json_node = CreateCompositeL2Pool2DJSONNode(cn);
     } else {
       LOG(FATAL) << "Unrecognized Arm Compute Library pattern: " << name;
     }
@@ -267,6 +271,62 @@ class ACLJSONSerializer : public backend::contrib::JSONSerializer {
     SetCallNodeAttribute(json_node, nodes.dense);
     return json_node;
   }
+
+  /*!
+   * \brief Create a JSON representation of a composite (global) average pooling operator.
+   *
+   * A composite function is only created when using the uint8 datatype for these operators.
+   *
+   * \param cn The call to be represented.
+   * \return A JSON representation of a specific operator.
+   */
+  std::shared_ptr<JSONGraphNode> CreateCompositeAvgPool2DJSONNode(const CallNode* cn) {
+    const auto* fn = cn->op.as<FunctionNode>();
+    CHECK(fn);
+    const auto* cast = fn->body.as<CallNode>();
+    CHECK(cast);
+    const auto* avg_pool = cast->args[0].as<CallNode>();
+    CHECK(avg_pool);
+    const auto* avg_pool_op = avg_pool->op.as<OpNode>();
+    CHECK(avg_pool_op);
+    const std::string name = avg_pool_op->name;
+
+    std::vector<JSONGraphNodeEntry> inputs;
+    inputs.push_back(VisitExpr(cn->args[0])[0]);
+    auto json_node = std::make_shared<JSONGraphNode>(name, "kernel", inputs, 1);
+    SetCallNodeAttribute(json_node, avg_pool);
+    return json_node;
+  }
+
+  /*!
+   * \brief Create a JSON representation of a composite L2 pooling operator.
+   *
+   * \note Relay does not have an operator for L2 pooling, instead we can create
+   * an equivalent from power(2) + nn.avg_pool2d + sqrt.
+   *
+   * \param cn The call to be represented.
+   * \return A JSON representation of a specific operator.
+   */
+  std::shared_ptr<JSONGraphNode> CreateCompositeL2Pool2DJSONNode(const CallNode* cn) {
+    const std::string name = "nn.l2_pool2d";
+    const auto* fn = cn->op.as<FunctionNode>();
+    CHECK(fn);
+    const auto* sqrt = fn->body.as<CallNode>();
+    CHECK(sqrt);
+    const auto* avg_pool = sqrt->args[0].as<CallNode>();
+    CHECK(avg_pool);
+    const auto* pow = avg_pool->args[0].as<CallNode>();
+    CHECK(pow);
+    const auto* exponent = pow->args[1].as<ConstantNode>();
+    CHECK(exponent);
+    CHECK_EQ(*static_cast<float*>(exponent->data->data), 2) << "Exponent must be 2 for L2 pooling";
+
+    std::vector<JSONGraphNodeEntry> inputs;
+    inputs.push_back(VisitExpr(cn->args[0])[0]);
+    auto json_node = std::make_shared<JSONGraphNode>(name, "kernel", inputs, 1);
+    SetCallNodeAttribute(json_node, avg_pool);
+    return json_node;
+  }
 };
 
 /*!
index f62420a..f2d2fca 100644 (file)
@@ -132,8 +132,11 @@ class ACLRuntime : public JSONRuntimeBase {
         } else if ("nn.dense" == op_name || "qnn.dense" == op_name) {
           CreateFullyConnectedLayer(&layer_, node, mm);
           num_pools++;
-        } else if ("nn.max_pool2d" == op_name) {
+        } else if ("nn.max_pool2d" == op_name || "nn.avg_pool2d" == op_name ||
+                   "nn.l2_pool2d" == op_name) {
           CreatePoolingLayer(&layer_, node);
+        } else if ("nn.global_max_pool2d" == op_name || "nn.global_avg_pool2d" == op_name) {
+          CreateGlobalPoolingLayer(&layer_, node);
         } else if ("reshape" == op_name) {
           CreateReshapeLayer(&layer_, node);
         } else {
@@ -308,7 +311,7 @@ class ACLRuntime : public JSONRuntimeBase {
   /*!
    * \brief Create a pooling layer.
    *
-   * \note Currently only maxpool is supported.
+   * \note Currently max_pool2d, avg_pool2d and L2 pooling are supported.
    *
    * \param layer The ACL layer to build. Containing inputs, outputs and the ACL function.
    * \param node The JSON representation of the operator.
@@ -316,22 +319,65 @@ class ACLRuntime : public JSONRuntimeBase {
   void CreatePoolingLayer(CachedLayer* layer, const JSONGraphNode& node) {
     std::vector<std::string> padding = node.GetAttr<std::vector<std::string>>("padding");
     std::vector<std::string> strides = node.GetAttr<std::vector<std::string>>("strides");
-    arm_compute::PadStrideInfo pad_stride_info = MakeACLPadStride(padding, strides);
+    bool ceil_mode = std::stoi(node.GetAttr<std::vector<std::string>>("ceil_mode")[0]);
+    arm_compute::PadStrideInfo pad_stride_info = MakeACLPadStride(padding, strides, ceil_mode);
 
     auto attr_pool_size = node.GetAttr<std::vector<std::string>>("pool_size");
     int pool_size_h = std::stoi(attr_pool_size[0]);
     int pool_size_w = std::stoi(attr_pool_size[1]);
 
+    // Only applies to average pool and l2 pool.
+    // ACL exclude pad option is inverse to Relays include pad option.
+    bool exclude_pad = false;
+    if (node.HasAttr("count_include_pad")) {
+      int count_include_pad =
+          std::stoi(node.GetAttr<std::vector<std::string>>("count_include_pad")[0]);
+      exclude_pad = !count_include_pad;
+    }
+
     arm_compute::PoolingType pool_type;
     if (node.GetOpName() == "nn.max_pool2d") {
       pool_type = arm_compute::PoolingType::MAX;
+    } else if (node.GetOpName() == "nn.avg_pool2d") {
+      pool_type = arm_compute::PoolingType::AVG;
+    } else if (node.GetOpName() == "nn.l2_pool2d") {
+      pool_type = arm_compute::PoolingType::L2;
     } else {
       LOG(FATAL) << "Pooling type not supported";
     }
 
     arm_compute::PoolingLayerInfo pool_info =
         arm_compute::PoolingLayerInfo(pool_type, arm_compute::Size2D(pool_size_h, pool_size_w),
-                                      arm_compute::DataLayout::NHWC, pad_stride_info);
+                                      arm_compute::DataLayout::NHWC, pad_stride_info, exclude_pad);
+
+    layer->inputs.push_back(MakeACLTensorFromJSONEntry(node.GetInputs()[0]));
+    layer->outputs.push_back(MakeACLTensorFromJSONNode(node));
+
+    auto function = std::make_shared<arm_compute::NEPoolingLayer>();
+    function->configure(&layer->inputs[0], &layer->outputs[0], pool_info);
+    layer->function = function;
+  }
+
+  /*!
+   * \brief Create a global pooling layer.
+   *
+   * \note Currently global_max_pool2d and global_avg_pool2d are supported.
+   *
+   * \param layer The ACL layer to build. Containing inputs, outputs and the ACL function.
+   * \param node The JSON representation of the operator.
+   */
+  void CreateGlobalPoolingLayer(CachedLayer* layer, const JSONGraphNode& node) {
+    arm_compute::PoolingType pool_type;
+    if (node.GetOpName() == "nn.global_max_pool2d") {
+      pool_type = arm_compute::PoolingType::MAX;
+    } else if (node.GetOpName() == "nn.global_avg_pool2d") {
+      pool_type = arm_compute::PoolingType::AVG;
+    } else {
+      LOG(FATAL) << "Pooling type not supported";
+    }
+
+    arm_compute::PoolingLayerInfo pool_info =
+        arm_compute::PoolingLayerInfo(pool_type, arm_compute::DataLayout::NHWC);
 
     layer->inputs.push_back(MakeACLTensorFromJSONEntry(node.GetInputs()[0]));
     layer->outputs.push_back(MakeACLTensorFromJSONNode(node));
index 98c9cda..59c941d 100644 (file)
@@ -81,9 +81,11 @@ std::shared_ptr<arm_compute::MemoryManagerOnDemand> MakeACLMemoryManager() {
 }
 
 arm_compute::PadStrideInfo MakeACLPadStride(const std::vector<std::string>& pad,
-                                            const std::vector<std::string>& stride) {
+                                            const std::vector<std::string>& stride,
+                                            bool ceil_mode) {
   int pad_0 = 0, pad_1 = 0, pad_2 = 0, pad_3 = 0;
   int stride_0 = std::stoi(stride[0]), stride_1 = std::stoi(stride[1]);
+  auto dimensions_rounding = arm_compute::DimensionRoundingType::FLOOR;
   size_t size = pad.size();
   if (size == 1) {
     int pad_v = std::stoi(pad[0]);
@@ -109,8 +111,12 @@ arm_compute::PadStrideInfo MakeACLPadStride(const std::vector<std::string>& pad,
     LOG(FATAL) << "Unsupported padding dimensions";
   }
 
+  if (ceil_mode) {
+    dimensions_rounding = arm_compute::DimensionRoundingType::CEIL;
+  }
+
   return arm_compute::PadStrideInfo(stride_0, stride_1, pad_0, pad_1, pad_2, pad_3,
-                                    arm_compute::DimensionRoundingType::FLOOR);
+                                    dimensions_rounding);
 }
 
 arm_compute::DataType MakeACLDataType(const DLDataType& data_type) {
index 80c6f0b..576ed91 100644 (file)
@@ -93,10 +93,12 @@ std::shared_ptr<arm_compute::MemoryManagerOnDemand> MakeACLMemoryManager();
  *
  * \param pad The pad vector.
  * \param stride The stride vector.
+ * \param ceil_mode Dimensions rounding.
  * \return arm_compute::PadStrideInfo
  */
 arm_compute::PadStrideInfo MakeACLPadStride(const std::vector<std::string>& pad,
-                                            const std::vector<std::string>& stride);
+                                            const std::vector<std::string>& stride,
+                                            bool ceil_mode = false);
 
 /*!
  * \brief Convert DLDataType to arm_compute::DataType.
index 4e930e2..cc4818e 100644 (file)
@@ -181,9 +181,20 @@ def build_module(mod, target, params=None, enable_acl=True, tvm_ops=0, acl_parti
 
 
 def build_and_run(mod, inputs, outputs, params, device, enable_acl=True, no_runs=1,
-                  tvm_ops=0, acl_partitions=1):
+                  tvm_ops=0, acl_partitions=1, config=None):
     """Build and run the relay module."""
-    lib = build_module(mod, device.target, params, enable_acl, tvm_ops, acl_partitions)
+    if config is None:
+        config = {}
+
+    try:
+        lib = build_module(mod, device.target, params, enable_acl, tvm_ops, acl_partitions)
+    except Exception as e:
+        err_msg = "The module could not be built.\n"
+        if config:
+            err_msg += f"The test failed with the following parameters: {config}\n"
+        err_msg += str(e)
+        raise Exception(err_msg)
+
     lib = update_lib(lib, device.device, device.cross_compile)
     gen_module = graph_runtime.GraphModule(lib['default'](device.device.cpu(0)))
     gen_module.set_input(**inputs)
@@ -208,28 +219,28 @@ def update_lib(lib, device, cross_compile):
     return lib
 
 
-def verify(answers, atol, rtol, verify_saturation=False, params=None):
+def verify(answers, atol, rtol, verify_saturation=False, config=None):
     """Compare the array of answers. Each entry is a list of outputs."""
-    if params is None:
-        params = {}
+    if config is None:
+        config = {}
 
     if len(answers) < 2:
         raise RuntimeError(
             f"No results to compare: expected at least two, found {len(answers)}")
     for answer in zip_longest(*answers):
         for outs in combinations(answer, 2):
-            if verify_saturation:
-                assert np.count_nonzero(outs[0].asnumpy() == 255) < 0.25 * outs[0].asnumpy().size, \
-                    "Output is saturated: {}".format(outs[0])
-                assert np.count_nonzero(outs[0].asnumpy() == 0) < 0.25 * outs[0].asnumpy().size, \
-                    "Output is saturated: {}".format(outs[0])
             try:
+                if verify_saturation:
+                    assert np.count_nonzero(outs[0].asnumpy() == 255) < 0.25 * outs[0].asnumpy().size, \
+                        "Output is saturated: {}".format(outs[0])
+                    assert np.count_nonzero(outs[0].asnumpy() == 0) < 0.25 * outs[0].asnumpy().size, \
+                        "Output is saturated: {}".format(outs[0])
                 tvm.testing.assert_allclose(
                    outs[0].asnumpy(), outs[1].asnumpy(), rtol=rtol, atol=atol)
             except AssertionError as e:
                 err_msg = "Results not within the acceptable tolerance.\n"
-                if params:
-                    err_msg += f"The test failed with the following parameters: {params}\n"
+                if config:
+                    err_msg += f"The test failed with the following parameters: {config}\n"
                 err_msg += str(e)
                 raise AssertionError(err_msg)
 
index 555cbe1..37575cc 100644 (file)
@@ -276,7 +276,7 @@ def test_conv2d():
                                          params, device,
                                          enable_acl=acl)[0])
 
-        params = {
+        config = {
             "shape": shape,
             "groups": groups,
             "kernel size": (kernel_h, kernel_w),
@@ -286,7 +286,7 @@ def test_conv2d():
             "out channels": out_channels,
             "composite operators (pad, bias, activation)": composite
         }
-        verify(outputs, atol=0.002, rtol=0.01, params=params)
+        verify(outputs, atol=0.002, rtol=0.01, config=config)
 
 
 def test_codegen_conv2d():
@@ -380,7 +380,7 @@ def test_qnn_conv2d():
                                          params, device,
                                          enable_acl=acl)[0])
 
-        params = {
+        config = {
             "shape": shape,
             "groups": groups,
             "kernel size": (kernel_h, kernel_w),
@@ -396,15 +396,13 @@ def test_qnn_conv2d():
             "output scale": output_sc,
             "output zero point": output_zp
         }
-        verify(outputs, atol=1, rtol=0, params=params, verify_saturation=True)
+        verify(outputs, atol=1, rtol=0, config=config, verify_saturation=True)
 
 
 def test_codegen_qnn_conv2d():
     if skip_codegen_test():
         return
 
-    np.random.seed(0)
-
     kernel_hs = [1, 2, 3, 5]
     kernel_ws = [1, 2, 3, 5]
     pad = [(1, 1), (2, 2), (2, 1)]
index 18cac33..e1bb83b 100644 (file)
@@ -116,7 +116,7 @@ def test_mobilenet():
         return mod, params, inputs
 
     _build_and_run_network(*get_model(), device=device,
-                           tvm_ops=74, acl_partitions=17,
+                           tvm_ops=73, acl_partitions=18,
                            atol=0.002, rtol=0.01)
 
 
@@ -144,7 +144,7 @@ def test_quantized_mobilenet():
         return mod, params, inputs
 
     _build_and_run_network(*get_model(), device=device,
-                           tvm_ops=45, acl_partitions=16,
+                           tvm_ops=42, acl_partitions=17,
                            atol=8, rtol=0)
 
 
index 32176af..c104a06 100644 (file)
@@ -26,26 +26,70 @@ from .infrastructure import skip_runtime_test, skip_codegen_test, build_and_run,
 from .infrastructure import Device
 
 
-def _get_model(shape, dtype, typef, sizes, strides, padding,
-               ceil_mode, var_names):
+def _calculate_output_shape(shape, sizes, padding, strides):
+    """Calculate pooling output shape."""
+    output_height = ((shape[1] - sizes[0] + padding[0] + padding[2]) / strides[0]) + 1
+    output_width = ((shape[2] - sizes[1] + padding[1] + padding[3]) / strides[1]) + 1
+    return 1, int(output_height), int(output_width), shape[3]
+
+
+def _get_pooling_model(shape, dtype, typef, sizes, strides, padding,
+                       ceil_mode, count_include_pad, var_names):
+    """Return a model and any parameters it may have."""
+    if len(padding) == 2:
+        padding = (padding[0], padding[1], padding[0], padding[1])
+    out = relay.var(next(var_names), shape=shape, dtype=dtype)
+
+    if typef == "nn.max_pool2d":
+        out = relay.nn.max_pool2d(out, pool_size=sizes, strides=strides, padding=padding,
+                                  ceil_mode=ceil_mode, layout="NHWC")
+    elif typef == "nn.avg_pool2d":
+        if dtype == "uint8":
+            out = relay.cast(out, 'int32')
+        out = relay.nn.avg_pool2d(out, pool_size=sizes, strides=strides, padding=padding,
+                                  ceil_mode=ceil_mode, count_include_pad=count_include_pad,
+                                  layout="NHWC")
+        if dtype == "uint8":
+            out = relay.cast(out, 'uint8')
+    elif typef == "nn.l2_pool2d":
+        out = relay.power(out, relay.const(2.0))
+        out = relay.nn.avg_pool2d(out, pool_size=sizes, strides=strides, padding=padding,
+                                  ceil_mode=ceil_mode, count_include_pad=count_include_pad,
+                                  layout="NHWC")
+        out = relay.sqrt(out)
+    else:
+        raise ValueError("Function not supported")
+
+    return out
+
+
+def _get_global_pooling_model(shape, dtype, typef, var_names):
     """Return a model and any parameters it may have."""
-    var = relay.var(next(var_names), shape=shape, dtype=dtype)
-    pool = typef(var, pool_size=sizes, strides=strides, padding=padding,
-                 ceil_mode=ceil_mode, layout="NHWC")
-    return pool
+    out = relay.var(next(var_names), shape=shape, dtype=dtype)
+
+    if typef == "nn.global_max_pool2d":
+        out = relay.nn.global_max_pool2d(out, layout="NHWC")
+    elif typef == "nn.global_avg_pool2d":
+        if dtype == "uint8":
+            out = relay.cast(out, 'int32')
+        out = relay.nn.global_avg_pool2d(out, layout="NHWC")
+        if dtype == "uint8":
+            out = relay.cast(out, 'uint8')
+    else:
+        raise ValueError("Function not supported")
+
+    return out
 
 
-def _get_expected_codegen(shape, dtype, typef, sizes, strides,
-                          padding, ceil_mode):
+def _get_expected_pooling_codegen(shape, dtype, typef, sizes, strides,
+                                  padding, ceil_mode, count_include_pad):
     if len(padding) == 2:
-        padding = (padding[1], padding[1], padding[0], padding[0])
-    output_height = ((shape[1] - sizes[0] + padding[0] + padding[2]) / strides[0]) + 1
-    output_width = ((shape[2] - sizes[1] + padding[1] + padding[3]) / strides[1]) + 1
-    output_shape = (1, int(output_height), int(output_width), shape[3])
+        padding = (padding[0], padding[1], padding[0], padding[1])
+    output_shape = _calculate_output_shape(shape, sizes, padding, strides)
 
     node = {
         "op": "kernel",
-        "name": "nn.max_pool2d",
+        "name": typef,
         "inputs": [[0, 0, 0]],
         "attrs": {
             "num_inputs": "1",
@@ -60,6 +104,30 @@ def _get_expected_codegen(shape, dtype, typef, sizes, strides,
         },
     }
 
+    if typef == "nn.avg_pool2d" or typef == "nn.l2_pool2d":
+        node["attrs"]["count_include_pad"] = [["1" if count_include_pad else "0"]]
+
+    input = {
+        "op": "input",
+        "name": "",
+        "attrs": {"shape": [[list(shape)]], "dtype": [[dtype]]}}
+    return [input, node]
+
+
+def _get_expected_global_pooling_codegen(shape, dtype, typef):
+    node = {
+        "op": "kernel",
+        "name": typef,
+        "inputs": [[0, 0, 0]],
+        "attrs": {
+            "num_inputs": "1",
+            "num_outputs": "1",
+            "layout": [["NHWC"]],
+            "shape": [[[1, 1, 1, shape[3]]]],
+            "dtype": [[dtype]]
+        }
+    }
+
     input = {
         "op": "input",
         "name": "",
@@ -76,53 +144,160 @@ def test_pooling():
     device = Device()
     np.random.seed(0)
 
-    for dtype, low, high, atol, rtol in [("float32", -127, 128, 0.001, 0.001), ("uint8", 0, 255, 0, 0)]:
-        for size in [(2, 2), (3, 3)]:
-            for stride in [(2, 2)]:
-                shape = (1, size[0] + stride[0] * 5,
-                         size[1] + stride[1] * 5, 16)
-                pad = (0, 0)
-
-                inputs = {
-                    "a": tvm.nd.array(np.random.uniform(low, high, shape).astype(dtype)),
-                }
-
-                outputs = []
-                func = _get_model(shape, dtype, relay.nn.max_pool2d, size,
-                                  stride, pad, True, iter(inputs))
-                for acl in [False, True]:
-                    outputs.append(build_and_run(func, inputs, 1, None, device,
-                                                 enable_acl=acl)[0])
-
-                params = {
-                    "size": size,
-                    "stride": stride,
-                    "shape": shape,
-                    "pooling type": "max",
-                    "dtype": dtype,
-                    "padding": pad
-                }
-                verify(outputs, atol=atol, rtol=rtol, params=params, verify_saturation=True)
+    fp32_dtype = ("float32", -127, 128, 0.001, 0.001)
+    uint8_dtype = ("uint8", 0, 255, 1, 0)
+
+    trials = [["nn.max_pool2d", fp32_dtype, (2, 2), (2, 2), (0, 0), False, True, (16, 16, 16)],
+              ["nn.max_pool2d", fp32_dtype, (3, 3), (2, 2), (1, 1), True, True, (15, 15, 16)],
+              ["nn.max_pool2d", fp32_dtype, (2, 2), (2, 2), (0, 1), False, False, (16, 16, 16)],
+              ["nn.max_pool2d", uint8_dtype, (3, 3), (2, 2), (0, 1), False, False, (16, 16, 16)],
+              ["nn.max_pool2d", uint8_dtype, (2, 2), (2, 2), (1, 1), True, True, (15, 15, 16)],
+              ["nn.avg_pool2d", fp32_dtype, (2, 2), (2, 2), (1, 1), False, False, (16, 16, 16)],
+              ["nn.avg_pool2d", fp32_dtype, (2, 2), (2, 2), (0, 0), False, True, (16, 16, 16)],
+              ["nn.avg_pool2d", fp32_dtype, (3, 3), (2, 2), (0, 1), True, False, (15, 15, 16)],
+              ["nn.avg_pool2d", uint8_dtype, (2, 2), (2, 2), (1, 1), False, True, (16, 16, 16)],
+              ["nn.avg_pool2d", uint8_dtype, (3, 3), (2, 2), (0, 1), False, False, (16, 16, 16)],
+              ["nn.l2_pool2d", fp32_dtype, (2, 2), (2, 2), (0, 1), True, False, (16, 16, 16)],
+              ["nn.l2_pool2d", fp32_dtype, (3, 3), (2, 2), (0, 0), False, False, (16, 16, 16)],
+              ["nn.l2_pool2d", fp32_dtype, (2, 2), (2, 2), (1, 1), False, True, (15, 15, 16)]]
+
+    for typef, (dtype, low, high, atol, rtol), size, stride, pad, ceil_mode, count_include_pad, \
+            input_shape in trials:
+        shape = (1, *input_shape)
+        outputs = []
+        inputs = {
+            "a": tvm.nd.array(np.random.uniform(low, high, shape).astype(dtype)),
+        }
+
+        func = _get_pooling_model(shape, dtype, typef, size,
+                                  stride, pad, ceil_mode, count_include_pad, iter(inputs))
+
+        config = {
+            "size": size,
+            "stride": stride,
+            "shape": shape,
+            "pooling type": typef,
+            "dtype": dtype,
+            "padding": pad,
+            "ceil_mode": ceil_mode,
+            "count_include_pad": count_include_pad
+        }
+        verify_saturation = True if dtype == "uint8" else False
+
+        for acl in [False, True]:
+            outputs.append(build_and_run(func, inputs, 1, None, device,
+                                         enable_acl=acl, config=config)[0])
+
+        verify(outputs, atol=atol, rtol=rtol, config=config, verify_saturation=verify_saturation)
+
+
+def test_global_pooling():
+    Device.load("test_config.json")
+
+    if skip_runtime_test():
+        return
+
+    device = Device()
+    np.random.seed(0)
+
+    fp32_dtype = ("float32", -127, 128, 0.001, 0.001)
+    uint8_dtype = ("uint8", 0, 255, 1, 0)
+
+    trials = [["nn.global_max_pool2d", fp32_dtype, (8, 8, 16)],
+              ["nn.global_max_pool2d", fp32_dtype, (9, 9, 16)],
+              ["nn.global_max_pool2d", fp32_dtype, (8, 8, 16)],
+              ["nn.global_max_pool2d", uint8_dtype, (8, 8, 16)],
+              ["nn.global_max_pool2d", uint8_dtype, (9, 9, 16)],
+              ["nn.global_avg_pool2d", fp32_dtype, (8, 8, 16)],
+              ["nn.global_avg_pool2d", fp32_dtype, (8, 8, 16)],
+              ["nn.global_avg_pool2d", fp32_dtype, (9, 9, 16)],
+              ["nn.global_avg_pool2d", uint8_dtype, (8, 8, 16)],
+              ["nn.global_avg_pool2d", uint8_dtype, (8, 8, 16)]]
+
+    for typef, (dtype, low, high, atol, rtol), input_shape in trials:
+        shape = (1, *input_shape)
+        outputs = []
+        inputs = {
+            "a": tvm.nd.array(np.random.uniform(low, high, shape).astype(dtype)),
+        }
+
+        func = _get_global_pooling_model(shape, dtype, typef, iter(inputs))
+
+        config = {
+            "shape": shape,
+            "pooling type": typef,
+            "dtype": dtype,
+        }
+        verify_saturation = True if dtype == "uint8" else False
+
+        for acl in [False, True]:
+            outputs.append(build_and_run(func, inputs, 1, None, device,
+                                         enable_acl=acl, config=config)[0])
+
+        verify(outputs, atol=atol, rtol=rtol, config=config, verify_saturation=verify_saturation)
 
 
 def test_codegen_pooling():
     if skip_codegen_test():
         return
 
-    inputs = {"a"}
+    fp32_dtype = ("float32", -127, 128)
+    uint8_dtype = ("uint8", 0, 255)
+
+    trials = [["nn.max_pool2d", fp32_dtype, (2, 2), (2, 2), (0, 0), False, True, (16, 16, 16)],
+              ["nn.max_pool2d", fp32_dtype, (3, 3), (2, 2), (1, 1), True, True, (15, 15, 16)],
+              ["nn.max_pool2d", fp32_dtype, (2, 2), (2, 2), (0, 1), False, False, (16, 16, 16)],
+              ["nn.max_pool2d", uint8_dtype, (3, 3), (2, 2), (0, 1), False, False, (16, 16, 16)],
+              ["nn.max_pool2d", uint8_dtype, (2, 2), (2, 2), (1, 1), True, True, (15, 15, 16)],
+              ["nn.avg_pool2d", fp32_dtype, (2, 2), (2, 2), (1, 1), False, False, (16, 16, 16)],
+              ["nn.avg_pool2d", fp32_dtype, (2, 2), (2, 2), (0, 0), False, True, (16, 16, 16)],
+              ["nn.avg_pool2d", fp32_dtype, (3, 3), (2, 2), (0, 1), True, False, (15, 15, 16)],
+              ["nn.avg_pool2d", uint8_dtype, (2, 2), (2, 2), (1, 1), False, True, (16, 16, 16)],
+              ["nn.avg_pool2d", uint8_dtype, (3, 3), (2, 2), (0, 1), False, False, (16, 16, 16)],
+              ["nn.l2_pool2d", fp32_dtype, (2, 2), (2, 2), (0, 1), True, False, (15, 15, 16)],
+              ["nn.l2_pool2d", fp32_dtype, (3, 3), (2, 2), (0, 0), False, False, (16, 16, 16)],
+              ["nn.l2_pool2d", fp32_dtype, (2, 2), (2, 2), (1, 1), False, True, (15, 15, 16)]]
+
+    for typef, (dtype, low, high), size, stride, pad, ceil_mode, count_include_pad, \
+            input_shape in trials:
+        shape = (1, *input_shape)
+        inputs = {"a"}
+        args = (shape, dtype, typef, size,
+                stride, pad, False, False)
+        func = _get_pooling_model(*args, iter(inputs))
+        exp_codegen = _get_expected_pooling_codegen(*args)
+        verify_codegen(func, exp_codegen, 1)
+
+
+def test_codegen_global_pooling():
+    if skip_codegen_test():
+        return
+
+    fp32_dtype = ("float32", -127, 128)
+    uint8_dtype = ("uint8", 0, 255)
+
+    trials = [["nn.global_max_pool2d", fp32_dtype, (8, 8, 16)],
+              ["nn.global_max_pool2d", fp32_dtype, (9, 9, 16)],
+              ["nn.global_max_pool2d", fp32_dtype, (8, 8, 16)],
+              ["nn.global_max_pool2d", uint8_dtype, (8, 8, 16)],
+              ["nn.global_max_pool2d", uint8_dtype, (9, 9, 16)],
+              ["nn.global_avg_pool2d", fp32_dtype, (8, 8, 16)],
+              ["nn.global_avg_pool2d", fp32_dtype, (8, 8, 16)],
+              ["nn.global_avg_pool2d", fp32_dtype, (9, 9, 16)],
+              ["nn.global_avg_pool2d", uint8_dtype, (8, 8, 16)],
+              ["nn.global_avg_pool2d", uint8_dtype, (8, 8, 16)]]
 
-    for dtype in ["float32", "uint8"]:
-        for size in [(2, 2), (3, 3)]:
-            for stride in [(2, 2)]:
-                shape = (1, size[0] + stride[0] * 5,
-                         size[1] + stride[1] * 5, 16)
-                args = (shape, dtype, relay.nn.max_pool2d, size,
-                        stride, (0, 0), True)
-                func = _get_model(*args, iter(inputs))
-                exp_codegen = _get_expected_codegen(*args)
-                verify_codegen(func, exp_codegen, 1)
+    for typef, (dtype, low, high), input_shape in trials:
+        shape = (1, *input_shape)
+        inputs = {"a"}
+        args = (shape, dtype, typef)
+        func = _get_global_pooling_model(*args, iter(inputs))
+        exp_codegen = _get_expected_global_pooling_codegen(*args)
+        verify_codegen(func, exp_codegen, 1)
 
 
 if __name__ == "__main__":
     test_pooling()
+    test_global_pooling()
     test_codegen_pooling()
+    test_codegen_global_pooling()
index 38694e8..b6a8754 100644 (file)
@@ -78,12 +78,12 @@ def test_reshape():
                 outputs.append(build_and_run(func, inputs, 1, None, device,
                                              enable_acl=acl)[0])
 
-            params = {
+            config = {
                 "new shape": inputs["a"].shape,
                 "shape": new_shape,
                 "dtype": dtype,
             }
-            verify(outputs, atol=1e-7, rtol=1e-7, params=params)
+            verify(outputs, atol=1e-7, rtol=1e-7, config=config)
 
 
 def test_codegen_reshape():