[Relay][TOPI] Gluncv SSD support on the GPU (#2784)
authorLeyuan Wang <laurawly@gmail.com>
Mon, 29 Apr 2019 03:47:21 +0000 (20:47 -0700)
committerHaichen Shen <shenhaichen@gmail.com>
Mon, 29 Apr 2019 03:47:21 +0000 (20:47 -0700)
* ssd gluoncv gpu op updated

* ssd gluoncv gpu op updated

* tutorials and testes modified

* tutorials and testes modified

* fix lint

* fix lint

* address comment

* multibox bug fixed

* space line added

* use less threads per block

* use less threads per block

* less threads per block for get valid count

* less threads per block for get valid count

* merge with master

* Revert "less threads per block for get valid count"

This reverts commit 08896cfccc34b0b2a1646d01d01ea4cad73941c4.

* Revert "less threads per block for get valid count"

This reverts commit 08896cfccc34b0b2a1646d01d01ea4cad73941c4.

* typo fixed

* elem length made to a variable

* fix lint error

* fix lint error

* lint fixed

* bug fixed

* bug fixed

* lint fixed

* error fixed

* error fixed

* test ci

* test ci

* seperate argsort to be an independent op

* seperate argsort to be an independent op

* fix lint

* fix lint

* remove unsupported models

* typo fixed

* argsort added to realy

* solve conflicts with master

* fix lint

* fix lint

* test push

* Revert "test push"

This reverts commit 6db00883fab6cc06bddf564c926bb27c874397d8.

* fix lint error

* fix more lint

* cpu test_sort udpated

* debug ci

* nms fixed

* expose argsort to relay frontend

* test ci

* fix lint

* sort register error fixed

* fix nnvm

* nms type fixed

* adaptive pooling added to relay

* Revert "adaptive pooling added to relay"

This reverts commit 1119f1f2c055753e0cc5611627597749134c5c8c.

* fix lint

* expose argsort op

* fix lint

* fix lint

* fix lint

* sort test updated

* sort bug fixed

* nnvm error fixed

* fix argsort default data type returned to be float insteaf of int

* fix lint

* fix lint

* test fixed

* fix valid count

* fix titanx bug

* tutorial add both targets

* titanx error fixed

* try to fix CI old gpu error

* try to solve CI GPU error

* get_valid_count added

* reverse get_valid_count

* get valid count optimized

* address comments

* fix ci error

* remove unessesary block sync

* add back one sync

* address comments

* address more comments

* more comments

* move sort to be indepent algorithm

* typo fixed

* more typos

* comments addressed

* doc updated

* fix pylint

* address final comments

* apache license added

34 files changed:
docs/langref/relay_op.rst
include/tvm/relay/attrs/algorithm.h [new file with mode: 0644]
include/tvm/relay/attrs/vision.h
nnvm/include/nnvm/top/nn.h
nnvm/python/nnvm/top/vision.py
nnvm/tests/python/compiler/test_top_level4.py
python/tvm/relay/__init__.py
python/tvm/relay/frontend/mxnet.py
python/tvm/relay/op/__init__.py
python/tvm/relay/op/_algorithm.py [new file with mode: 0644]
python/tvm/relay/op/algorithm.py [new file with mode: 0644]
python/tvm/relay/op/tensor.py
python/tvm/relay/op/transform.py
python/tvm/relay/op/vision/_vision.py
python/tvm/relay/op/vision/nms.py
src/contrib/sort/sort.cc
src/relay/op/algorithm/sort.cc [new file with mode: 0644]
src/relay/op/vision/nms.cc
tests/python/contrib/test_sort.py
tests/python/relay/test_op_level5.py
tests/python/relay/test_op_level6.py [new file with mode: 0644]
topi/python/topi/__init__.py
topi/python/topi/cuda/nms.py
topi/python/topi/cuda/sort.py [new file with mode: 0644]
topi/python/topi/cuda/ssd/multibox.py
topi/python/topi/cuda/vision.py
topi/python/topi/generic/__init__.py
topi/python/topi/generic/sort.py [new file with mode: 0644]
topi/python/topi/sort.py [new file with mode: 0644]
topi/python/topi/vision/nms.py
topi/python/topi/vision/ssd/multibox.py
topi/tests/python/test_topi_sort.py [new file with mode: 0644]
topi/tests/python/test_topi_vision.py
tutorials/frontend/deploy_ssd_gluoncv.py

index c45e9b9..4719aba 100644 (file)
@@ -165,6 +165,14 @@ This level enables additional math and transform operators.
    tvm.relay.vision.yolo_reorg
 
 
+**Level 6: Algorithm Operators**
+
+.. autosummary::
+   :nosignatures:
+
+   tvm.relay.argsort
+
+
 **Level 10: Temporary Operators**
 
 This level support backpropagation of broadcast operators. It is temporary.
@@ -294,6 +302,11 @@ Level 5 Definitions
 .. autofunction:: tvm.relay.vision.yolo_reorg
 
 
+Level 6 Definitions
+-------------------
+.. autofunction:: tvm.relay.argsort
+
+
 Level 10 Definitions
 --------------------
 .. autofunction:: tvm.relay.broadcast_to_like
diff --git a/include/tvm/relay/attrs/algorithm.h b/include/tvm/relay/attrs/algorithm.h
new file mode 100644 (file)
index 0000000..20f135c
--- /dev/null
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/relay/attrs/vision.h
+ * \brief Auxiliary attributes for vision operators.
+ */
+#ifndef TVM_RELAY_ATTRS_ALGORITHM_H_
+#define TVM_RELAY_ATTRS_ALGORITHM_H_
+
+#include <tvm/attrs.h>
+#include <string>
+
+namespace tvm {
+namespace relay {
+
+/*! \brief Attributes used in argsort operators */
+struct ArgsortAttrs : public tvm::AttrsNode<ArgsortAttrs> {
+  int axis;
+  bool is_ascend;
+  DataType dtype;
+
+  TVM_DECLARE_ATTRS(ArgsortAttrs, "relay.attrs.ArgsortAttrs") {
+    TVM_ATTR_FIELD(axis).set_default(-1)
+      .describe("Axis along which to sort the input tensor."
+                "If not given, the flattened array is used.");
+    TVM_ATTR_FIELD(is_ascend).set_default(true)
+      .describe("Whether to sort in ascending or descending order."
+                "By default, sort in ascending order");
+    TVM_ATTR_FIELD(dtype).set_default(NullValue<DataType>())
+      .describe("DType of the output indices.");
+  }
+};
+
+}  // namespace relay
+}  // namespace tvm
+#endif  // TVM_RELAY_ATTRS_ALGORITHM_H_
index 2b3eb4f..11b4ebf 100644 (file)
@@ -92,6 +92,8 @@ struct NonMaximumSuppressionAttrs : public tvm::AttrsNode<NonMaximumSuppressionA
   double iou_threshold;
   bool force_suppress;
   int top_k;
+  int coord_start;
+  int score_index;
   int id_index;
   bool return_indices;
   bool invalid_to_bottom;
@@ -106,6 +108,10 @@ struct NonMaximumSuppressionAttrs : public tvm::AttrsNode<NonMaximumSuppressionA
       .describe("Suppress all detections regardless of class_id.");
     TVM_ATTR_FIELD(top_k).set_default(-1)
       .describe("Keep maximum top k detections before nms, -1 for no limit.");
+    TVM_ATTR_FIELD(coord_start).set_default(2)
+      .describe("Start index of the consecutive 4 coordinates.");
+    TVM_ATTR_FIELD(score_index).set_default(1)
+      .describe("Index of the scores/confidence of boxes.");
     TVM_ATTR_FIELD(id_index).set_default(0)
       .describe("Axis index of id.");
     TVM_ATTR_FIELD(return_indices).set_default(true)
index 424a6a0..137d8ca 100644 (file)
@@ -488,6 +488,8 @@ struct NonMaximumSuppressionParam : public dmlc::Parameter<NonMaximumSuppression
   bool force_suppress;
   int top_k;
   int id_index;
+  int coord_start;
+  int score_index;
   int max_output_size;
   bool invalid_to_bottom;
   DMLC_DECLARE_PARAMETER(NonMaximumSuppressionParam) {
@@ -500,6 +502,10 @@ struct NonMaximumSuppressionParam : public dmlc::Parameter<NonMaximumSuppression
       .describe("Suppress all detections regardless of class_id.");
     DMLC_DECLARE_FIELD(top_k).set_default(-1)
       .describe("Keep maximum top k detections before nms, -1 for no limit.");
+    DMLC_DECLARE_FIELD(coord_start).set_default(2)
+      .describe("Start index of the consecutive 4 coordinates.");
+    DMLC_DECLARE_FIELD(score_index).set_default(1)
+      .describe("Index of the scores/confidence of boxes.");
     DMLC_DECLARE_FIELD(id_index).set_default(0)
       .describe("Axis index of id.");
     DMLC_DECLARE_FIELD(return_indices).set_default(true)
index b366429..2e18cf7 100644 (file)
@@ -94,8 +94,12 @@ def compute_nms(attrs, inputs, _):
     id_index = attrs.get_int('id_index')
     invalid_to_bottom = attrs.get_bool('invalid_to_bottom')
 
-    return topi.vision.non_max_suppression(inputs[0], inputs[1], max_output_size,
-                                           iou_threshold, force_suppress, top_k,
-                                           id_index, return_indices, invalid_to_bottom)
+    return topi.vision.non_max_suppression(inputs[0], inputs[1],
+                                           max_output_size=max_output_size,
+                                           iou_threshold=iou_threshold,
+                                           force_suppress=force_suppress,
+                                           top_k=top_k, id_index=id_index,
+                                           return_indices=return_indices,
+                                           invalid_to_bottom=invalid_to_bottom)
 
 reg.register_pattern("non_max_suppression", OpPattern.OPAQUE)
index f8d4f5b..6911639 100644 (file)
@@ -543,14 +543,13 @@ def verify_multibox_prior(dshape, sizes=(1,), ratios=(1,), steps=(-1, -1),
     if clip:
         np_out = np.clip(np_out, 0, 1)
 
-    target = "llvm"
-    ctx = tvm.cpu()
-    graph, lib, _ = nnvm.compiler.build(out, target, {"data": dshape})
-    m = graph_runtime.create(graph, lib, ctx)
-    m.set_input("data", np.random.uniform(size=dshape).astype(dtype))
-    m.run()
-    out = m.get_output(0, tvm.nd.empty(np_out.shape, dtype))
-    tvm.testing.assert_allclose(out.asnumpy(), np_out, atol=1e-5, rtol=1e-5)
+    for target, ctx in ctx_list():
+        graph, lib, _ = nnvm.compiler.build(out, target, {"data": dshape})
+        m = graph_runtime.create(graph, lib, ctx)
+        m.set_input("data", np.random.uniform(size=dshape).astype(dtype))
+        m.run()
+        tvm_out = m.get_output(0, tvm.nd.empty(np_out.shape, dtype))
+        tvm.testing.assert_allclose(tvm_out.asnumpy(), np_out, atol=1e-5, rtol=1e-5)
 
 def test_multibox_prior():
     verify_multibox_prior((1, 3, 50, 50))
@@ -577,17 +576,16 @@ def test_multibox_transform_loc():
                                  [0, 0.44999999, 1, 1, 1, 1],
                                  [0, 0.30000001, 0, 0, 0.22903419, 0.20435292]]])
 
-    target = "llvm"
     dtype = "float32"
-    ctx = tvm.cpu()
-    graph, lib, _ = nnvm.compiler.build(out, target, {"cls_prob": (batch_size, num_anchors, num_classes),
-                                                      "loc_preds": (batch_size, num_anchors * 4),
-                                                      "anchors": (1, num_anchors, 4)})
-    m = graph_runtime.create(graph, lib, ctx)
-    m.set_input(**{"cls_prob": np_cls_prob.astype(dtype), "loc_preds": np_loc_preds.astype(dtype), "anchors": np_anchors.astype(dtype)})
-    m.run()
-    out = m.get_output(0, tvm.nd.empty(expected_np_out.shape, dtype))
-    tvm.testing.assert_allclose(out.asnumpy(), expected_np_out, atol=1e-5, rtol=1e-5)
+    for target, ctx in ctx_list():
+        graph, lib, _ = nnvm.compiler.build(out, target, {"cls_prob": (batch_size, num_anchors, num_classes),
+                                                          "loc_preds": (batch_size, num_anchors * 4),
+                                                          "anchors": (1, num_anchors, 4)})
+        m = graph_runtime.create(graph, lib, ctx)
+        m.set_input(**{"cls_prob": np_cls_prob.astype(dtype), "loc_preds": np_loc_preds.astype(dtype), "anchors": np_anchors.astype(dtype)})
+        m.run()
+        tvm_out = m.get_output(0, tvm.nd.empty(expected_np_out.shape, dtype))
+        tvm.testing.assert_allclose(tvm_out.asnumpy(), expected_np_out, atol=1e-5, rtol=1e-5)
 
 def test_non_max_suppression():
     dshape = (1, 5, 6)
@@ -607,15 +605,14 @@ def test_non_max_suppression():
                            [-1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1],
                            [-1, -1, -1, -1, -1, -1]]])
 
-    target = "llvm"
-    ctx = tvm.cpu()
-    graph, lib, _ = nnvm.compiler.build(out, target, {"data": dshape, "valid_count": (dshape[0],)},
-                                        dtype={"data": "float32", "valid_count": "int32"})
-    m = graph_runtime.create(graph, lib, ctx)
-    m.set_input(**{"data": np_data, "valid_count": np_valid_count})
-    m.run()
-    out = m.get_output(0, tvm.nd.empty(np_result.shape, "float32"))
-    tvm.testing.assert_allclose(out.asnumpy(), np_result, atol=1e-5, rtol=1e-5)
+    for target, ctx in ctx_list():
+        graph, lib, _ = nnvm.compiler.build(out, target, {"data": dshape, "valid_count": (dshape[0],)},
+                                            dtype={"data": "float32", "valid_count": "int32"})
+        m = graph_runtime.create(graph, lib, ctx)
+        m.set_input(**{"data": np_data, "valid_count": np_valid_count})
+        m.run()
+        tvm_out = m.get_output(0, tvm.nd.empty(np_result.shape, "float32"))
+        tvm.testing.assert_allclose(tvm_out.asnumpy(), np_result, atol=1e-5, rtol=1e-5)
 
 def np_slice_like(np_data, np_shape_like, axis=[]):
     begin_idx = [0 for _ in np_data.shape]
index 2ab4ca2..80555d3 100644 (file)
@@ -36,6 +36,7 @@ from .op import Op
 from .op.reduce import *
 from .op.tensor import *
 from .op.transform import *
+from .op.algorithm import *
 from . import nn
 from . import annotation
 from . import vision
index 1218e65..f1bf678 100644 (file)
@@ -186,6 +186,13 @@ def _mx_pooling(inputs, attrs):
         'Operator {} Pooling is not supported for frontend MXNet.'.format(pool_type.capitalize()))
 
 
+def _mx_adaptive_avg_pooling(inputs, attrs):
+    output_size = attrs.get_int_tuple("output_size", [])
+    if output_size != (1,):
+        raise RuntimeError("AdaptiveAvgPooling with output_size other than 1 is not supported yet.")
+    return _op.nn.global_avg_pool2d(inputs[0])
+
+
 def _mx_dropout(inputs, attrs):
     rate = attrs.get_float("p", 0.5)
     return _op.nn.dropout(inputs[0], rate=rate)
@@ -529,15 +536,6 @@ def _mx_box_nms(inputs, attrs):
     id_index = attrs.get_int('id_index', -1)
     in_format = attrs.get_str('in_format', 'corner')
     out_format = attrs.get_str('out_format', 'corner')
-    if coord_start != 2:
-        raise tvm.error.OpAttributeInvalid(
-            'Value of attribute "coord_start" must equal 2 for operator box_nms.')
-    if score_index != 1:
-        raise tvm.error.OpAttributeInvalid(
-            'Value of attribute "score_index" must equal 1 for operator box_nms.')
-    if id_index != -1 and int(id_index) != 0:
-        raise tvm.error.OpAttributeInvalid(
-            'Value of attribute "id_index" must equal either -1 or 0 for operator box_nms.')
     if in_format != 'corner':
         raise tvm.error.OpAttributeInvalid(
             'Value of attribute "in_format" must equal "corner" for operator box_nms.')
@@ -551,6 +549,8 @@ def _mx_box_nms(inputs, attrs):
                                              iou_threshold=iou_thresh,
                                              force_suppress=force_suppress,
                                              top_k=top_k,
+                                             coord_start=coord_start,
+                                             score_index=score_index,
                                              id_index=id_index,
                                              return_indices=False,
                                              invalid_to_bottom=True)
@@ -648,6 +648,15 @@ def _mx_deformable_convolution(inputs, attrs):
     return res
 
 
+def _mx_argsort(inputs, attrs):
+    assert len(inputs) == 1
+    new_attrs = {}
+    new_attrs["axis"] = attrs.get_int("axis", -1)
+    new_attrs["is_ascend"] = attrs.get_bool("is_ascend", True)
+    new_attrs["dtype"] = attrs.get_str("dtype", "float32")
+    return _op.argsort(inputs[0], **new_attrs)
+
+
 # Note: due to attribute conversion constraint
 # ops in the identity set must be attribute free
 _identity_list = [
@@ -783,6 +792,7 @@ _convert_map = {
     "BlockGrad"     : _mx_BlockGrad,
     "shape_array"   : _mx_shape_array,
     "Embedding"     : _mx_embedding,
+    "argsort"       : _mx_argsort,
     "SoftmaxOutput" : _mx_softmax_output,
     "SoftmaxActivation" : _mx_softmax_activation,
     "smooth_l1"     : _mx_smooth_l1,
@@ -796,6 +806,7 @@ _convert_map = {
     "_contrib_MultiProposal" : _mx_proposal,
     "_contrib_box_nms" : _mx_box_nms,
     "_contrib_DeformableConvolution" : _mx_deformable_convolution,
+    "_contrib_AdaptiveAvgPooling2D" : _mx_adaptive_avg_pooling,
     # List of missing operators that are present in NNVMv1
     # TODO(tvm-tvm): support all operators.
     #
index fdc990e..3bea795 100644 (file)
@@ -24,6 +24,7 @@ from .op import get, register, register_schedule, register_compute, register_gra
 from .reduce import *
 from .tensor import *
 from .transform import *
+from .algorithm import *
 from . import nn
 from . import annotation
 from . import image
@@ -36,6 +37,7 @@ from . import _tensor
 from . import _tensor_grad
 from . import _transform
 from . import _reduce
+from . import _algorithm
 from ..expr import Expr
 from ..base import register_relay_node
 
diff --git a/python/tvm/relay/op/_algorithm.py b/python/tvm/relay/op/_algorithm.py
new file mode 100644 (file)
index 0000000..57e7165
--- /dev/null
@@ -0,0 +1,45 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"Definition of classic algorithms"
+# pylint: disable=invalid-name,unused-argument
+from __future__ import absolute_import
+
+import topi
+from topi.util import get_const_int
+from ..op import OpPattern, register_compute, register_schedule, register_pattern
+
+
+@register_schedule("argsort")
+def schedule_argsort(_, outs, target):
+    """Schedule definition of argsort"""
+    with target:
+        return topi.generic.schedule_argsort(outs)
+
+
+@register_compute("argsort")
+def compute_argsort(attrs, inputs, _, target):
+    """Compute definition of argsort"""
+    axis = get_const_int(attrs.axis)
+    is_ascend = bool(get_const_int(attrs.is_ascend))
+    dtype = str(attrs.dtype)
+    return [
+        topi.argsort(inputs[0], None, axis=axis, is_ascend=is_ascend, \
+                            dtype=dtype, flag=False)
+    ]
+
+
+register_pattern("argsort", OpPattern.OPAQUE)
diff --git a/python/tvm/relay/op/algorithm.py b/python/tvm/relay/op/algorithm.py
new file mode 100644 (file)
index 0000000..6451eb4
--- /dev/null
@@ -0,0 +1,47 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Classic algorithm operation"""
+from __future__ import absolute_import as _abs
+from . import _make
+
+def argsort(data, axis=-1, is_ascend=1, dtype="float32"):
+    """Performs sorting along the given axis and returns an array of indicies
+    having same shape as an input array that index data in sorted order.
+
+    Parameters
+    ----------
+    data : relay.Expr
+        The input data tensor.
+
+    valid_count : tvm.Tensor
+        The number of valid elements to be sorted.
+
+    axis : int, optional
+        Axis long which to sort the input tensor.
+
+    is_ascend : boolean, optional
+        Whether to sort in ascending or descending order.
+
+    dtype : string, optional
+        DType of the output indices.
+
+    Returns
+    -------
+    out : relay.Expr
+        Tensor with same shape as data.
+    """
+    return _make.argsort(data, axis, is_ascend, dtype)
index bcbbe0c..3795e65 100644 (file)
@@ -710,6 +710,30 @@ def concatenate(data, axis):
     return _make.concatenate(Tuple(data), axis)
 
 
+def stack(data, axis):
+    """Join a sequence of arrays along a new axis.
+
+    Parameters
+    ----------
+    data : Union(List[relay.Expr], Tuple(relay.Expr))
+        A list of tensors.
+
+    axis : int
+        The axis in the result array along which the input arrays are stacked.
+
+    Returns
+    -------
+    ret : relay.Expr
+        The stacked tensor.
+    """
+    data = list(data)
+    if not data:
+        raise ValueError("relay.stack requires data to be non-empty.")
+    if not isinstance(axis, int):
+        raise ValueError("For now, we only support integer axis")
+    return _make.stack(Tuple(data), axis)
+
+
 def copy(data):
     """Copy a tensor.
 
index 5489ad1..9c76b7e 100644 (file)
@@ -315,28 +315,6 @@ def arange(start, stop=None, step=1, dtype="float32"):
     return _make.arange(start, stop, step, dtype)
 
 
-def stack(data, axis):
-    """Join a sequence of arrays along a new axis.
-
-    Parameters
-    ----------
-    data : relay.Expr
-        The input data to the operator.
-
-    axis : int
-        The axis in the result array along which the input arrays are stacked.
-
-    .. note::
-        Each array in the input array sequence must have the same shape.
-
-    Returns
-    -------
-    ret : relay.Expr
-        The computed result.
-    """
-    return _make.stack(data, axis)
-
-
 def repeat(data, repeats, axis):
     """Repeats elements of an array.
     By default, repeat flattens the input array into 1-D and then repeats the elements.
@@ -698,5 +676,4 @@ def gather_nd(data, indices):
         indices = [[0, 1], [1, 0]]
         relay.gather_nd(data, indices) = [[3, 4], [5, 6]]
     """
-
     return _make.gather_nd(data, indices)
index bcf7e06..8c8c4cd 100644 (file)
@@ -103,12 +103,15 @@ def compute_nms(attrs, inputs, _, target):
     iou_threshold = get_const_float(attrs.iou_threshold)
     force_suppress = bool(get_const_int(attrs.force_suppress))
     top_k = get_const_int(attrs.top_k)
+    coord_start = get_const_int(attrs.coord_start)
+    score_index = get_const_int(attrs.score_index)
     id_index = get_const_int(attrs.id_index)
     invalid_to_bottom = bool(get_const_int(attrs.invalid_to_bottom))
     return [
         topi.vision.non_max_suppression(inputs[0], inputs[1], max_output_size,
                                         iou_threshold, force_suppress, top_k,
-                                        id_index, return_indices, invalid_to_bottom)
+                                        coord_start, score_index, id_index,
+                                        return_indices, invalid_to_bottom)
     ]
 
 
index b8f9bf1..ab34eb6 100644 (file)
@@ -49,6 +49,8 @@ def non_max_suppression(data,
                         iou_threshold=0.5,
                         force_suppress=False,
                         top_k=-1,
+                        coord_start=2,
+                        score_index=1,
                         id_index=0,
                         return_indices=True,
                         invalid_to_bottom=False):
@@ -77,6 +79,12 @@ def non_max_suppression(data,
     top_k : int, optional
         Keep maximum top k detections before nms, -1 for no limit.
 
+    coord_start : int, optional
+        The starting index of the consecutive 4 coordinates.
+
+    score_index : int, optional
+        Index of the scores/confidence of boxes.
+
     id_index : int, optional
         index of the class categories, -1 to disable.
 
@@ -93,4 +101,5 @@ def non_max_suppression(data,
     """
     return _make.non_max_suppression(data, valid_count, max_output_size,
                                      iou_threshold, force_suppress, top_k,
-                                     id_index, return_indices, invalid_to_bottom)
+                                     coord_start, score_index, id_index,
+                                     return_indices, invalid_to_bottom)
index fd0107c..cf25e89 100644 (file)
@@ -46,20 +46,20 @@ bool CompareDescend(const std::pair<int32_t, DType>& lhs,
 }
 
 
-// Argsort implemented C library sort.
+// Argsort implemented C library sort for nms.
 // Return indices of sorted tensor.
 // By default, the last axis will be used to sort.
 // sort_num specify the number of elements to be sorted.
 // If input tensor has dimension (d0, d1, ..., d(k-1), dk, d(k+1), ..., d(n-1))
 // and sort axis is dk. sort_num should have dimension of
 // (d1, d2, ..., d(k-1), d(k+1), ..., dn).
-TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort")
+TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort_nms")
 .set_body([](TVMArgs args, TVMRetValue *ret) {
   DLTensor *input = args[0];
   DLTensor *sort_num = args[1];
   DLTensor *output = args[2];
   int32_t axis = args[3];
-  bool is_descend = args[4];
+  bool is_ascend = args[4];
 
   auto dtype = input->dtype;
   auto data_ptr = static_cast<float *>(input->data);
@@ -97,10 +97,10 @@ TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort")
         int64_t full_idx = base_idx + k * axis_mul_after;
         sorter.emplace_back(std::make_pair(k, *(data_ptr + full_idx)));
       }
-      if (is_descend) {
-        std::stable_sort(sorter.begin(), sorter.end(), CompareDescend<float>);
-      } else {
+      if (is_ascend) {
         std::stable_sort(sorter.begin(), sorter.end(), CompareAscend<float>);
+      } else {
+        std::stable_sort(sorter.begin(), sorter.end(), CompareDescend<float>);
       }
       for (int32_t k = 0; k < input->shape[axis]; ++k) {
         *(static_cast<int32_t *>(output->data) + base_idx + k * axis_mul_after)
@@ -110,5 +110,68 @@ TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort")
   }
 });
 
+
+// Argsort implemented C library sort.
+// Return indices of sorted tensor.
+// By default, the last axis will be used to sort.
+// sort_num specify the number of elements to be sorted.
+// If input tensor has dimension (d0, d1, ..., d(k-1), dk, d(k+1), ..., d(n-1))
+// and sort axis is dk. sort_num should have dimension of
+// (d1, d2, ..., d(k-1), d(k+1), ..., dn).
+TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort")
+.set_body([](TVMArgs args, TVMRetValue *ret) {
+  DLTensor *input = args[0];
+  DLTensor *output = args[1];
+  int32_t axis = args[2];
+  bool is_ascend = args[3];
+
+  auto dtype = input->dtype;
+  auto data_ptr = static_cast<float *>(input->data);
+  std::vector<std::pair<float, float>> sorter;
+  int64_t axis_mul_before = 1;
+  int64_t axis_mul_after = 1;
+
+  if (axis < 0) {
+    axis = input->ndim + axis;
+  }
+
+  // Currently only supports input dtype to be float32.
+  CHECK_EQ(dtype.code, 2) << "Currently only supports input dtype "
+      "to be float32.";
+  CHECK_EQ(dtype.bits, 32) << "Currently only supports input dtype "
+      "to be float32.";
+  CHECK_LT(axis, input->ndim) << "Axis out of boundary for "
+      "input ndim " << input->ndim;
+
+  for (int i = 0; i < input->ndim; ++i) {
+    if (i < axis) {
+      axis_mul_before *= input->shape[i];
+    } else if (i > axis) {
+      axis_mul_after *= input->shape[i];
+    }
+  }
+
+  int32_t current_sort_num = input->shape[axis];
+  for (int64_t i = 0 ; i < axis_mul_before; ++i) {
+    for (int64_t j = 0 ; j < axis_mul_after; ++j) {
+      sorter.clear();
+      int64_t base_idx = i * input->shape[axis] * axis_mul_after + j;
+      for (int64_t k = 0; k < current_sort_num; ++k) {
+        int64_t full_idx = base_idx + k * axis_mul_after;
+        sorter.emplace_back(std::make_pair(k, *(data_ptr + full_idx)));
+      }
+      if (is_ascend) {
+        std::stable_sort(sorter.begin(), sorter.end(), CompareAscend<float>);
+      } else {
+        std::stable_sort(sorter.begin(), sorter.end(), CompareDescend<float>);
+      }
+      for (int32_t k = 0; k < input->shape[axis]; ++k) {
+        *(static_cast<float *>(output->data) + base_idx + k * axis_mul_after)
+            = k < static_cast<float>(sorter.size()) ? sorter[k].first : k;
+      }
+    }
+  }
+});
+
 }  // namespace contrib
 }  // namespace tvm
diff --git a/src/relay/op/algorithm/sort.cc b/src/relay/op/algorithm/sort.cc
new file mode 100644 (file)
index 0000000..5777b79
--- /dev/null
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *  Copyright (c) 2018 by Contributors
+ * \file nms.cc
+ * \brief Non-maximum suppression operators
+ */
+#include <tvm/relay/op.h>
+#include <tvm/relay/attrs/algorithm.h>
+
+namespace tvm {
+namespace relay {
+
+TVM_REGISTER_NODE_TYPE(ArgsortAttrs);
+
+bool ArgsortRel(const Array<Type>& types,
+                int num_inputs,
+                const Attrs& attrs,
+                const TypeReporter& reporter) {
+  // `types` contains: [data, result]
+  const ArgsortAttrs* param = attrs.as<ArgsortAttrs>();
+  CHECK_EQ(types.size(), 2);
+  const auto* data = types[0].as<TensorTypeNode>();
+  if (data == nullptr) {
+    CHECK(types[0].as<IncompleteTypeNode>())
+        << "Argsort: expect input type to be TensorType but get "
+        << types[0];
+    return false;
+  }
+  CHECK_EQ(param->dtype, Float(32));
+  reporter->Assign(types[1], TensorTypeNode::make(data->shape, param->dtype));
+  return true;
+}
+
+Expr MakeArgsort(Expr data,
+                 int axis,
+                 bool is_ascend,
+                 DataType dtype) {
+  auto attrs = make_node<ArgsortAttrs>();
+  attrs->axis = axis;
+  attrs->is_ascend = is_ascend;
+  attrs->dtype = dtype;
+  static const Op& op = Op::Get("argsort");
+  return CallNode::make(op, {data}, Attrs(attrs), {});
+}
+
+
+TVM_REGISTER_API("relay.op._make.argsort")
+.set_body_typed(MakeArgsort);
+
+RELAY_REGISTER_OP("argsort")
+.describe(R"doc(Returns the indices that would sort an
+input array along the given axis.
+)doc" TVM_ADD_FILELINE)
+.set_num_inputs(1)
+.set_attrs_type_key("relay.attrs.ArgsortAttrs")
+.add_argument("data", "Tensor", "Input data.")
+.set_support_level(6)
+.add_type_rel("Argsort", ArgsortRel);
+}  // namespace relay
+}  // namespace tvm
index 5344bce..2e5661c 100644 (file)
@@ -106,6 +106,8 @@ Expr MakeNMS(Expr data,
              double iou_threshold,
              bool force_suppress,
              int top_k,
+             int coord_start,
+             int score_index,
              int id_index,
              bool return_indices,
              bool invalid_to_bottom) {
@@ -114,6 +116,8 @@ Expr MakeNMS(Expr data,
   attrs->iou_threshold = iou_threshold;
   attrs->force_suppress = force_suppress;
   attrs->top_k = top_k;
+  attrs->coord_start = coord_start;
+  attrs->score_index = score_index;
   attrs->id_index = id_index;
   attrs->return_indices = return_indices;
   attrs->invalid_to_bottom = invalid_to_bottom;
index 856d3fa..87cdac0 100644 (file)
@@ -24,11 +24,11 @@ def test_sort():
     data = tvm.placeholder((n, l, m), name='data')
     sort_num = tvm.placeholder((n, m), name="sort_num", dtype="int32")
     axis = 1
-    is_descend = True
+    is_ascend = False
     out = tvm.extern(data.shape, [data, sort_num],
                      lambda ins, outs: tvm.call_packed(
-                         "tvm.contrib.sort.argsort", ins[0],
-                         ins[1], outs[0], axis, is_descend),
+                         "tvm.contrib.sort.argsort_nms", ins[0],
+                         ins[1], outs[0], axis, is_ascend),
                      dtype='int32', name="sort_tensor")
     input = [[[1, 2, 3], [2, 4.5, 3.5], [1.1, 0.5, 1], [3.2, -5, 0.5], [1.5, 0, 0]],
              [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12], [13, 14, 15]]]
@@ -50,13 +50,13 @@ def test_sort_np():
     dshape = (1, 2, 3, 4, 5, 6)
     axis = 4
     reduced_shape = (1, 2, 3, 4, 6)
-    is_descend = False
+    is_ascend = True
     data = tvm.placeholder(dshape, name='data')
     sort_num = tvm.placeholder(reduced_shape, name="sort_num", dtype="int32")
     out = tvm.extern(data.shape, [data, sort_num],
                      lambda ins, outs: tvm.call_packed(
-                         "tvm.contrib.sort.argsort", ins[0],
-                         ins[1], outs[0], axis, is_descend),
+                         "tvm.contrib.sort.argsort_nms", ins[0],
+                         ins[1], outs[0], axis, is_ascend),
                      dtype='int32', name="sort_tensor")
 
     ctx = tvm.cpu(0)
index 7e1c371..e6d99c7 100644 (file)
@@ -177,12 +177,13 @@ def test_get_valid_counts():
         assert "score_threshold" in z.astext()
         func = relay.Function([x], z.astuple())
         func = relay.ir_pass.infer_type(func)
-        ctx_list = [("llvm", tvm.cpu(0))]
-        for target, ctx in ctx_list:
+        for target, ctx in ctx_list():
+            if target == 'cuda':
+                return
             intrp = relay.create_executor("debug", ctx=ctx, target=target)
             out = intrp.evaluate(func)(np_data)
-            tvm.testing.assert_allclose(out[0].asnumpy(), np_out1, rtol=1e-3)
-            tvm.testing.assert_allclose(out[1].asnumpy(), np_out2, rtol=1e-3)
+            tvm.testing.assert_allclose(out[0].asnumpy(), np_out1, rtol=1e-3, atol=1e-04)
+            tvm.testing.assert_allclose(out[1].asnumpy(), np_out2, rtol=1e-3, atol=1e-04)
 
     verify_get_valid_counts((1, 2500, 6), 0)
     verify_get_valid_counts((1, 2500, 6), -1)
@@ -195,9 +196,13 @@ def test_non_max_suppression():
                    iou_threshold=0.5, force_suppress=False, top_k=-1,
                    check_type_only=False):
         x0 = relay.var("x0", relay.ty.TensorType(dshape, "float32"))
-        x1 = relay.var("x1", relay.ty.TensorType((dshape[0],), "int"))
-        z = relay.vision.non_max_suppression(x0, x1, -1, iou_threshold, force_suppress, top_k, return_indices=False)
-        z_indices = relay.vision.non_max_suppression(x0, x1, -1, iou_threshold, force_suppress, top_k)
+        x1 = relay.var("x1", relay.ty.TensorType((dshape[0],), "int32"))
+        z = relay.vision.non_max_suppression(x0, x1, max_output_size = -1, \
+            iou_threshold = iou_threshold, force_suppress = force_suppress, \
+            top_k = top_k, return_indices=False)
+        z_indices = relay.vision.non_max_suppression(x0, x1, max_output_size = -1, \
+                    iou_threshold = iou_threshold, force_suppress = force_suppress, \
+                    top_k = top_k)
         assert "iou_threshold" in z.astext()
         assert "iou_threshold" in z_indices.astext()
         zz = relay.ir_pass.infer_type(z)
@@ -212,8 +217,7 @@ def test_non_max_suppression():
         func = relay.ir_pass.infer_type(func)
         func_indices = relay.Function([x0, x1], z_indices)
         func_indices = relay.ir_pass.infer_type(func_indices)
-        ctx_list = [("llvm", tvm.cpu(0))]
-        for target, ctx in ctx_list:
+        for target, ctx in ctx_list():
             intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
             op_res1 = intrp1.evaluate(func)(x0_data, x1_data)
             op_indices_res1 = intrp1.evaluate(func_indices)(x0_data, x1_data)
@@ -296,8 +300,7 @@ def test_multibox_transform_loc():
         nms = relay.vision.non_max_suppression(mtl[0], mtl[1], return_indices=False)
         func = relay.Function([cls_prob, loc_pred, anchors], nms)
         func = relay.ir_pass.infer_type(func)
-        ctx_list = [("llvm", tvm.cpu(0))]
-        for target, ctx in ctx_list:
+        for target, ctx in ctx_list():
             intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
             op_res1 = intrp1.evaluate(func)(np_cls_prob, np_loc_preds,
                                             np_anchors)
diff --git a/tests/python/relay/test_op_level6.py b/tests/python/relay/test_op_level6.py
new file mode 100644 (file)
index 0000000..983a915
--- /dev/null
@@ -0,0 +1,49 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+""" Support level6 operator test cases.
+"""
+import math
+import numpy as np
+import tvm
+from tvm import relay
+from tvm.relay.testing import ctx_list
+import topi.testing
+
+def test_argsort():
+    def verify_argsort(shape, axis, is_ascend):
+        x = relay.var("x", relay.TensorType(shape, "float32"))
+        z = relay.argsort(x, axis=axis, is_ascend=is_ascend)
+        zz = relay.ir_pass.infer_type(z)
+        func = relay.Function([x], z)
+        x_data = np.random.uniform(size=shape).astype("float32")
+        if is_ascend:
+            ref_res = np.argsort(x_data, axis=axis)
+        else:
+            ref_res = np.argsort(-x_data, axis=axis)
+
+        for target, ctx in ctx_list():
+            for kind in ["graph", "debug"]:
+                intrp = relay.create_executor(kind, ctx=ctx, target=target)
+                op_res = intrp.evaluate(func)(x_data)
+                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.astype("float"), rtol=1e-5)
+    verify_argsort((2, 3, 4), axis=0, is_ascend=False)
+    verify_argsort((1, 4, 6), axis=1, is_ascend=True)
+    verify_argsort((3, 5, 6), axis=-1, is_ascend=False)
+
+
+if __name__ == "__main__":
+    test_argsort()
index 2eb460d..a998414 100644 (file)
@@ -21,6 +21,7 @@ from .generic_op_impl import *
 from .reduction import *
 from .transform import *
 from .broadcast import *
+from .sort import *
 from . import nn
 from . import x86
 from . import cuda
index e6377fa..5d04d72 100644 (file)
@@ -20,77 +20,380 @@ import math
 import tvm
 
 from tvm import api
-from topi.vision import non_max_suppression
-from ..util import get_const_tuple
+from tvm.generic import cast
+from tvm.intrin import if_then_else, log, power
+from topi.vision import non_max_suppression, get_valid_counts
+from .sort import argsort
 
-def sort_ir(data, index, output):
-    """Low level IR to do sorting on the GPU, same usage as tvm.contrib.sort.argsort on the CPU.
+
+def get_valid_counts_pre(data, flag, idx, score_threshold):
+    """Low level IR to Prepare get valid count of bounding boxes
+    given a score threshold. Also moves valid boxes to the
+    top of input data.
 
     Parameters
     ----------
     data: Buffer
-        2D Buffer of input boxes' score with shape [batch_size, num_anchors].
+        3D Buffer with shape [batch_size, num_anchors, elem_length], output of nms.
+
+    flag : Buffer
+        2D Buffer of flag indicating valid data with shape [batch_size, num_anchors].
 
-    index : Buffer
-        1D Buffer of number of valid number of boxes.
+    idx : Buffer
+        2D Buffer of valid data indices with shape [batch_size, num_anchors].
 
-    output : Buffer
-        2D Output buffer of indicies of sorted tensor with shape [batch_size, num_anchors].
+    score_threshold : float32
+        Lower limit of score for valid bounding boxes.
 
     Returns
     -------
     stmt : Stmt
         The result IR statement.
     """
+    batch_size = data.shape[0]
+    num_anchors = data.shape[1]
+    box_data_length = data.shape[2]
+
+    ib = tvm.ir_builder.create()
+
+    data = ib.buffer_ptr(data)
+    flag = ib.buffer_ptr(flag)
+    idx = ib.buffer_ptr(idx)
+    score_threshold = tvm.make.node("FloatImm", dtype="float32", value=score_threshold)
 
-    assert data.dtype == "float32", "Currently only supports input dtype to be float32"
-    batch, num_anchors = get_const_tuple(data.shape)
     max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads)
+    nthread_tx = max_threads
+    nthread_bx = batch_size * num_anchors // max_threads + 1
+    tx = tvm.thread_axis("threadIdx.x")
+    bx = tvm.thread_axis("blockIdx.x")
+    ib.scope_attr(tx, "thread_extent", nthread_tx)
+    ib.scope_attr(bx, "thread_extent", nthread_bx)
+    tid = bx * max_threads + tx
+
+    with ib.if_scope(tid < batch_size * num_anchors):
+        with ib.if_scope(data[tid * box_data_length + 1] > score_threshold):
+            flag[tid] = 1
+            idx[tid] = 1
+        with ib.else_scope():
+            flag[tid] = 0
+            idx[tid] = 0
+
+    return ib.get()
+
+def get_valid_counts_upsweep(data, idx_in, idx, partial):
+    """Low level IR of first step of scan: unsweep.
+
+    Parameters
+    ----------
+    data: Buffer
+        3D Buffer with shape [batch_size, num_anchors, elem_length], output of nms.
+
+    idx_in : Buffer
+        2D Buffer of valid data indices with shape [batch_size, num_anchors].
+
+    idx : Buffer
+        2D Buffer of valid data indices with shape [batch_size, num_anchors].
+
+    partial : Buffer
+        2D Buffer of valid data indices with shape [batch_size, new_range].
+
+    Returns
+    -------
+    stmt : Stmt
+        The result IR statement.
+    """
+    batch_size = data.shape[0]
+    num_anchors = data.shape[1]
     ib = tvm.ir_builder.create()
-    p_data = ib.buffer_ptr(data)
-    p_index = ib.buffer_ptr(index)
-    p_out = ib.buffer_ptr(output)
+    data = ib.buffer_ptr(data)
+    idx_in = ib.buffer_ptr(idx_in)
+    idx = ib.buffer_ptr(idx)
+    partial = ib.buffer_ptr(partial)
+    max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads)
+    elem_per_thread = num_anchors // max_threads + 1
     nthread_tx = max_threads
-    nthread_bx = num_anchors // max_threads + 1
+    nthread_bx = batch_size
     tx = tvm.thread_axis("threadIdx.x")
-    bx = tvm.thread_axis("vthread")
+    bx = tvm.thread_axis("blockIdx.x")
     ib.scope_attr(tx, "thread_extent", nthread_tx)
-    ib.scope_attr(bx, "virtual_thread", nthread_bx)
-    tid = bx * nthread_tx + tx
-    temp_data = ib.allocate("float32", (1,), name="temp_data", scope="local")
-    temp_index = ib.allocate("int32", (1,), name="temp_index", scope="local")
-
-    with ib.for_range(0, batch, for_type="unroll") as b:
-        start = b * num_anchors
-        with ib.if_scope(tid < num_anchors):
-            p_out[start + tid] = tid
-        # OddEvenTransposeSort
-        with ib.for_range(0, p_index[b]) as k:
-            with ib.if_scope(tid < (p_index[b] + 1) // 2):
-                offset = start + 2 * tid + (k % 2)
-                with ib.if_scope( \
-                        tvm.all(offset + 1 < p_index[0], p_data[offset] < p_data[offset + 1])):
-                    temp_data[0] = p_data[offset]
-                    p_data[offset] = p_data[offset + 1]
-                    p_data[offset + 1] = temp_data[0]
-                    temp_index[0] = p_out[offset]
-                    p_out[offset] = p_out[offset + 1]
-                    p_out[offset + 1] = temp_index[0]
+    ib.scope_attr(bx, "thread_extent", nthread_bx)
+    new_range = num_anchors // elem_per_thread + 1
+    # Scan: Upsweep:
+    with ib.if_scope(tvm.all(bx < batch_size, tx < new_range)):
+        with ib.for_range(0, elem_per_thread) as i:
+            with ib.if_scope(bx * num_anchors + \
+                             tx * elem_per_thread + i < batch_size * num_anchors):
+                with ib.if_scope(i == 0):
+                    partial[bx * new_range + tx] = idx_in[bx * num_anchors + tx * elem_per_thread]
+                    idx[bx * num_anchors + tx * elem_per_thread] = \
+                    idx_in[bx * num_anchors + tx * elem_per_thread]
+                with ib.else_scope():
+                    partial[bx * new_range + tx] += \
+                    idx_in[bx * num_anchors + tx * elem_per_thread + i]
+                    idx[bx * num_anchors + tx * elem_per_thread + i] = \
+                    idx[bx * num_anchors + tx * elem_per_thread + i - 1] + \
+                    idx_in[bx * num_anchors + tx * elem_per_thread + i]
+    return ib.get()
+
+def get_valid_counts_scan(data, partial_in, partial):
+    """Low level IR to do scan.
+
+    Parameters
+    ----------
+    data: Buffer
+        3D Buffer with shape [batch_size, num_anchors, elem_length], output of nms.
+
+    idx_in : Buffer
+        2D Buffer of valid data indices with shape [batch_size, num_anchors].
+
+    idx : Buffer
+        2D Buffer of valid data indices with shape [batch_size, num_anchors].
+
+    partial : Buffer
+        2D Buffer of valid data indices with shape [batch_size, new_range].
+
+    Returns
+    -------
+    stmt : Stmt
+        The result IR statement.
+    """
+    batch_size = data.shape[0]
+    num_anchors = data.shape[1]
+    ib = tvm.ir_builder.create()
+    partial_in = ib.buffer_ptr(partial_in)
+    partial = ib.buffer_ptr(partial)
+    max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads)
+    elem_per_thread = num_anchors // max_threads + 1
+    nthread_tx = max_threads
+    nthread_bx = batch_size
+    tx = tvm.thread_axis("threadIdx.x")
+    bx = tvm.thread_axis("blockIdx.x")
+    ib.scope_attr(tx, "thread_extent", nthread_tx)
+    ib.scope_attr(bx, "thread_extent", nthread_bx)
+    var = tvm.make.node("FloatImm", dtype="float32", value=2)
+    new_range = num_anchors // elem_per_thread + 1
+    iteration = log(cast(new_range, "float32")) // math.log(2)
+    # Scan: Kogge-Stone adder
+    with ib.if_scope(tvm.all(bx < batch_size, tx < tvm.min(new_range, num_anchors))):
+        with ib.for_range(0, iteration) as k:
+            with ib.if_scope(k == 0):
+                with ib.if_scope(tvm.all(tx > 0, tx < tvm.min(new_range, num_anchors))):
+                    partial[bx * new_range + tx] = \
+                    partial_in[bx * new_range + tx] + partial_in[bx * new_range + tx - 1]
+                with ib.else_scope():
+                    partial[bx * new_range] = partial_in[bx * new_range]
+            with ib.else_scope():
+                with ib.if_scope(tvm.all(tx >= cast(power(var, k), "int32"), \
+                                         tx < tvm.min(new_range, num_anchors))):
+                    partial[bx * new_range + tx] += \
+                    partial[bx * new_range + tx - cast(power(var, k), "int32")]
             ib.emit(tvm.make.Call(None, 'tvm_storage_sync',
                                   tvm.convert(['shared']),
                                   tvm.expr.Call.Intrinsic, None, 0))
+    return ib.get()
+
+def get_valid_counts_downsweep(data, idx_in, partial, idx):
+    """Low level IR to do downsweep of scan.
+
+    Parameters
+    ----------
+    data: Buffer
+        3D Buffer with shape [batch_size, num_anchors, elem_length], output of nms.
+
+    idx_in : Buffer
+        2D Buffer of valid data indices with shape [batch_size, num_anchors].
+
+    partial : Buffer
+        2D Buffer of valid data indices with shape [batch_size, new_range].
+
+    idx : Buffer
+        2D Buffer of valid data indices with shape [batch_size, num_anchors].
+
+    Returns
+    -------
+    stmt : Stmt
+        The result IR statement.
+    """
+    batch_size = data.shape[0]
+    num_anchors = data.shape[1]
+    ib = tvm.ir_builder.create()
+    idx_in = ib.buffer_ptr(idx_in)
+    idx = ib.buffer_ptr(idx)
+    partial = ib.buffer_ptr(partial)
+    max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads)
+    elem_per_thread = num_anchors // max_threads + 1
+    nthread_tx = max_threads
+    nthread_bx = batch_size * num_anchors // max_threads + 1
+    tx = tvm.thread_axis("threadIdx.x")
+    bx = tvm.thread_axis("blockIdx.x")
+    ib.scope_attr(tx, "thread_extent", nthread_tx)
+    ib.scope_attr(bx, "thread_extent", nthread_bx)
+    tid = bx * max_threads + tx
+    new_range = num_anchors // elem_per_thread + 1
+    # Scan: Downsweep:
+    with ib. if_scope(tid < batch_size * num_anchors):
+        i = tid / num_anchors # number of batches
+        j = tid % num_anchors # number of anchors
+        with ib.if_scope(j < elem_per_thread):
+            idx[tid] = idx_in[tid]
+        with ib.else_scope():
+            idx[tid] = idx_in[tid] + partial[i * new_range + j // elem_per_thread - 1]
 
     return ib.get()
 
-def nms_ir(data, sort_result, valid_count, out, nms_threshold, force_suppress, nms_topk):
+def get_valid_counts_ir(data, flag, idx, valid_count, out):
+    """Low level IR to get valid count of bounding boxes
+    given a score threshold. Also moves valid boxes to the
+    top of input data.
+
+    Parameters
+    ----------
+    data : Buffer
+        Input data. 3-D Buffer with shape [batch_size, num_anchors, elem_length].
+
+    flag : Buffer
+        2D Buffer of flag indicating valid data with shape [batch_size, num_anchors].
+
+    idx : Buffer
+        2D Buffer of valid data indices with shape [batch_size, num_anchors].
+
+    valid_count : Buffer
+        1-D buffer for valid number of boxes.
+
+    out : Buffer
+        Rearranged data buffer.
+
+    Returns
+    -------
+    stmt : Stmt
+        The result IR statement.
+    """
+    batch_size = data.shape[0]
+    num_anchors = data.shape[1]
+    elem_length = data.shape[2]
+    size = batch_size * num_anchors * elem_length
+
+    ib = tvm.ir_builder.create()
+
+    data = ib.buffer_ptr(data)
+    flag = ib.buffer_ptr(flag)
+    idx = ib.buffer_ptr(idx)
+    valid_count = ib.buffer_ptr(valid_count)
+    out = ib.buffer_ptr(out)
+
+    max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads)
+    nthread_tx = max_threads
+    nthread_bx = batch_size * num_anchors * elem_length // max_threads + 1
+    tx = tvm.thread_axis("threadIdx.x")
+    bx = tvm.thread_axis("blockIdx.x")
+    ib.scope_attr(tx, "thread_extent", nthread_tx)
+    ib.scope_attr(bx, "thread_extent", nthread_bx)
+    tid = bx * max_threads + tx
+
+    with ib.if_scope(tid < batch_size * num_anchors):
+        i = tid / num_anchors
+        j = tid % num_anchors
+        base_idx = i * num_anchors * elem_length
+        with ib.if_scope(flag[tid] > 0):
+            with ib.for_range(0, elem_length) as k:
+                with ib.if_scope(base_idx + (idx[tid] - 1) * elem_length + k < size):
+                    out[base_idx + (idx[tid] - 1) * elem_length + k] =\
+                    data[base_idx + j * elem_length + k]
+        with ib.if_scope(j == 0):
+            valid_count[i] = idx[tid + num_anchors - 1]
+        with ib.if_scope(j >= idx[i * num_anchors + num_anchors - 1]):
+            with ib.for_range(0, elem_length) as l:
+                with ib.if_scope(tid * elem_length + l < size):
+                    out[tid * elem_length + l] = -1.0
+    return ib.get()
+
+
+@get_valid_counts.register(["cuda", "gpu"])
+def get_valid_counts_gpu(data, score_threshold=0):
+    """Get valid count of bounding boxes given a score threshold.
+    Also moves valid boxes to the top of input data.
+
+    Parameters
+    ----------
+    data : tvm.Tensor
+        Input data. 3-D tensor with shape [batch_size, num_anchors, elem_length].
+
+    score_threshold : optional, float
+        Lower limit of score for valid bounding boxes.
+
+    Returns
+    -------
+    valid_count : tvm.Tensor
+        1-D tensor for valid number of boxes.
+
+    out_tensor : tvm.Tensor
+        Rearranged data tensor.
+    """
+    batch_size = data.shape[0]
+    num_anchors = data.shape[1]
+    max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads)
+    elem_per_thread = num_anchors // max_threads + 1
+    new_range = num_anchors // elem_per_thread + 1
+    temp_flag_buf = api.decl_buffer(
+        (batch_size, num_anchors,), "int32", "temp_flag", data_alignment=8)
+    temp_idx_buf = api.decl_buffer(
+        (batch_size, num_anchors,), "int32", "temp_idx", data_alignment=8)
+    temp_partial_buf = api.decl_buffer(
+        (batch_size, new_range), "int32", "temp_partial", data_alignment=8)
+    data_buf = api.decl_buffer(
+        data.shape, data.dtype, "data_buf", data_alignment=8)
+
+    temp_flag, temp_idx = \
+        tvm.extern([(batch_size, num_anchors,), (batch_size, num_anchors,)], [data],
+                   lambda ins, outs: get_valid_counts_pre(
+                       ins[0], outs[0], outs[1], score_threshold),
+                   dtype=["int32", "int32"],
+                   out_buffers=[temp_flag_buf, temp_idx_buf],
+                   name="get_valid_counts_phase_one")
+    temp_idx_new, temp_partial = \
+        tvm.extern([(batch_size, num_anchors,), (batch_size, new_range)], [data, temp_idx],
+                   lambda ins, outs: get_valid_counts_upsweep(
+                       ins[0], ins[1], outs[0], outs[1]),
+                   dtype=["int32", "int32"],
+                   out_buffers=[temp_idx_buf, temp_partial_buf],
+                   name="get_valid_counts_phase_two")
+    temp_partial_new = \
+        tvm.extern([(batch_size, new_range)], [data, temp_partial],
+                   lambda ins, outs: get_valid_counts_scan(
+                       ins[0], ins[1], outs[0]),
+                   dtype=["int32"],
+                   out_buffers=[temp_partial_buf],
+                   name="get_valid_counts_phase_three")
+    temp_idx_final = \
+        tvm.extern([(batch_size, num_anchors)], [data, temp_idx_new, temp_partial_new],
+                   lambda ins, outs: get_valid_counts_downsweep(
+                       ins[0], ins[1], ins[2], outs[0]),
+                   dtype=["int32"],
+                   out_buffers=[temp_idx_buf],
+                   name="get_valid_counts_phase_four")
+    valid_count, out_tensor = \
+       tvm.extern([(batch_size,), data.shape], [data, temp_flag, temp_idx_final],
+            lambda ins, outs: get_valid_counts_ir(
+                ins[0], ins[1], ins[2], outs[0], outs[1]),
+            dtype=["int32", data.dtype],
+            in_buffers=[data_buf, temp_flag_buf, temp_idx_buf],
+            name="get_valid_counts_phase_five",
+            tag="get_valid_counts_gpu")
+
+    return [valid_count, out_tensor]
+
+
+def nms_ir(data, sorted_index, valid_count, out, box_indices,
+           max_output_size, iou_threshold, force_suppress,
+           top_k, coord_start, id_index):
     """Low level IR routing for transform location in multibox_detection operator.
 
     Parameters
     ----------
-    data: Buffer
+    data : Buffer
         Buffer of output boxes with class and score.
 
-    sort_result : Buffer
+    sort_index : Buffer
         Buffer of output box indexes sorted by score.
 
     valid_count : Buffer
@@ -99,15 +402,25 @@ def nms_ir(data, sort_result, valid_count, out, nms_threshold, force_suppress, n
     out : Buffer
         Output buffer.
 
-    nms_threshold : float
-        Non-maximum suppression threshold.
+    max_output_size : int
+        Max number of output valid boxes for each instance.
+        By default all valid boxes are returned.
+
+    iou_threshold : float
+        Overlapping(IoU) threshold to suppress object with smaller score.
 
     force_suppress : boolean
         Whether to suppress all detections regardless of class_id.
 
-    nms_topk : int
+    top_k : int
         Keep maximum top k detections before nms, -1 for no limit.
 
+    coord_start : int
+        Start index of the consecutive 4 coordinates.
+
+    id_index : int
+        index of the class categories, -1 to disable.
+
     Returns
     -------
     stmt : Stmt
@@ -127,100 +440,232 @@ def nms_ir(data, sort_result, valid_count, out, nms_threshold, force_suppress, n
             (out_tensor[box_b_idx + 3] - out_tensor[box_b_idx + 1]) - i
         return tvm.expr.Select(u <= 0.0, 0.0, i / u)
 
+    batch_size = data.shape[0]
+    num_anchors = data.shape[1]
+    box_data_length = data.shape[2]
+
+    ib = tvm.ir_builder.create()
+
+    data = ib.buffer_ptr(data)
+    sorted_index = ib.buffer_ptr(sorted_index)
+    valid_count = ib.buffer_ptr(valid_count)
+    out = ib.buffer_ptr(out)
+    box_indices = ib.buffer_ptr(box_indices)
+    num_valid_boxes = ib.allocate("int32", (1,), name="num_valid_boxes", scope="local")
+
     max_threads = int(math.sqrt(
         tvm.target.current_target(allow_none=False).max_num_threads))
-    ib = tvm.ir_builder.create()
-    p_data = ib.buffer_ptr(data)
-    p_sort_result = ib.buffer_ptr(sort_result)
-    p_valid_count = ib.buffer_ptr(valid_count)
-    p_out = ib.buffer_ptr(out)
-    batch_size = out.shape[0]
-    num_anchors = out.shape[1]
     nthread_tx = max_threads
     nthread_bx = num_anchors // max_threads + 1
     tx = tvm.thread_axis("threadIdx.x")
     bx = tvm.thread_axis("blockIdx.x")
     ib.scope_attr(tx, "thread_extent", nthread_tx)
     ib.scope_attr(bx, "thread_extent", nthread_bx)
-    i = bx * max_threads + tx
-
-    nms_threshold_node = tvm.make.node(
-        "FloatImm", dtype="float32", value=nms_threshold)
-    nms_topk_node = tvm.make.node("IntImm", dtype="int32", value=nms_topk)
-    force_suppress_node = tvm.make.node(
-        "IntImm", dtype="int32", value=1 if force_suppress else 0)
-    with ib.for_range(0, batch_size, for_type="unroll") as b:
-        base_idx = b * num_anchors * 6
-        with ib.if_scope( \
-                tvm.all(nms_threshold_node > 0, nms_threshold_node < 1,
-                        p_valid_count[0] > 0)):
+    k = bx * max_threads + tx
+
+    iou_threshold = tvm.make.node("FloatImm", dtype="float32", value=iou_threshold)
+    top_k = tvm.make.node("IntImm", dtype="int32", value=top_k)
+    coord_start = tvm.make.node("IntImm", dtype="int32", value=coord_start)
+    id_index = tvm.make.node("IntImm", dtype="int32", value=id_index)
+    force_suppress = tvm.make.node("IntImm", dtype="int32", value=1 if force_suppress else 0)
+
+    with ib.for_range(0, batch_size, for_type="unroll") as i:
+        base_idx = i * num_anchors * box_data_length
+        with ib.if_scope(tvm.all(iou_threshold > 0, valid_count[i] > 0)):
             # Reorder output
-            nkeep = tvm.if_then_else( \
-                    tvm.all(nms_topk_node > 0, nms_topk < p_valid_count[b]),
-                    nms_topk, p_valid_count[b])
-            with ib.for_range(0, nkeep) as l:
-                with ib.if_scope(i < 6):
-                    p_out[(base_idx + l * 6 + i)] = \
-                            p_data[(base_idx + p_sort_result[b * num_anchors + l] * 6 + i)]
-            with ib.if_scope(tvm.all(nms_topk_node > 0, nms_topk < p_valid_count[b])):
-                with ib.for_range(0, p_valid_count[b] - nkeep) as l:
-                    with ib.if_scope(i < 6):
-                        p_out[(base_idx + (l + nkeep) * 6 + i)] = -1.0
+            nkeep = if_then_else( \
+                    tvm.all(top_k > 0, top_k < valid_count[i]),
+                    top_k, valid_count[i])
+            with ib.for_range(0, nkeep) as j:
+                with ib.if_scope(k < box_data_length):
+                    out[(base_idx + j * box_data_length + k)] = \
+                    data[(base_idx + sorted_index[i * num_anchors + j] \
+                    * box_data_length + k)]
+                box_indices[i * num_anchors + j] = sorted_index[i * num_anchors + j]
+            with ib.if_scope(tvm.all(top_k > 0, top_k < valid_count[i])):
+                with ib.for_range(0, valid_count[i] - nkeep) as j:
+                    with ib.if_scope(k < box_data_length):
+                        out[(base_idx + (j + nkeep) * box_data_length + k)] = -1.0
+                    box_indices[i * num_anchors + (j + nkeep)] = -1
             # Apply nms
-            with ib.for_range(0, p_valid_count[b]) as l:
-                offset_l = l * 6
-                with ib.if_scope(p_out[base_idx + offset_l] >= 0):
-                    with ib.if_scope(i < p_valid_count[b]):
-                        offset_i = i * 6
-                        with ib.if_scope(tvm.all(i > l, p_out[base_idx
-                                                              + offset_i] >= 0)):
-                            with ib.if_scope(tvm.any(force_suppress_node > 0,
-                                                     p_out[base_idx + offset_l] ==
-                                                     p_out[base_idx + offset_i])):
-                                # When force_suppress == True or class_id equals
-                                iou = calculate_overlap(p_out, base_idx + offset_l + 2,
-                                                        base_idx + offset_i + 2)
-                                with ib.if_scope(iou >= nms_threshold):
-                                    p_out[base_idx + offset_i] = -1.0
+            with ib.for_range(0, valid_count[i]) as j:
+                offset_j = j * box_data_length
+                with ib.if_scope(out[base_idx + offset_j] >= 0):
+                    with ib.if_scope(k < valid_count[i]):
+                        offset_k = k * box_data_length
+                        with ib.if_scope(tvm.all(k > j, out[base_idx + offset_k] >= 0, \
+                                                tvm.any(force_suppress > 0, id_index < 0, \
+                                                         out[base_idx + offset_j] == \
+                                                         out[base_idx + offset_k]))):
+                            iou = calculate_overlap(out, base_idx + offset_k + coord_start,
+                                                    base_idx + offset_j + coord_start)
+                            with ib.if_scope(iou >= iou_threshold):
+                                out[base_idx + offset_k] = -1.0
+                                box_indices[i * num_anchors + k] = -1
                 ib.emit(tvm.make.Call(None, 'tvm_storage_sync',
                                       tvm.convert(['shared']),
                                       tvm.expr.Call.Intrinsic, None, 0))
         with ib.else_scope():
-            with ib.for_range(0, p_valid_count[b]) as c:
-                with ib.if_scope(i < 6):
-                    p_out[(base_idx + c * 6 + i)] = p_data[base_idx + c * 6 + i]
+            with ib.for_range(0, valid_count[i]) as j:
+                offset_j = j * box_data_length
+                with ib.if_scope(k < box_data_length):
+                    out[(base_idx + offset_j + k)] = data[base_idx + offset_j + k]
+                box_indices[i * num_anchors + j] = j
         # Set invalid entry to be -1
-        with ib.for_range(0, num_anchors - p_valid_count[b]) as c:
-            with ib.if_scope(i < 6):
-                p_out[base_idx + (c + p_valid_count[b]) * 6 + i] = -1.0
-    body = ib.get()
-    return body
+        with ib.for_range(0, num_anchors - valid_count[i]) as j:
+            with ib.if_scope(k < box_data_length):
+                out[base_idx + (j + valid_count[i]) * box_data_length + k] = -1.0
+            box_indices[i * num_anchors + j + valid_count[i]] = -1
+        # Only return max_output_size number of valid boxes
+        num_valid_boxes[0] = 0
+        with ib.if_scope(max_output_size > 0):
+            with ib.for_range(0, valid_count[i]) as j:
+                offset_j = j * box_data_length
+                with ib.if_scope(out[base_idx + offset_j] >= 0):
+                    with ib.if_scope(num_valid_boxes[0] == max_output_size):
+                        with ib.if_scope(k < box_data_length):
+                            out[base_idx + offset_j + k] = -1.0
+                        box_indices[i * num_anchors + j] = -1
+                    with ib.else_scope():
+                        num_valid_boxes[0] += 1
+                ib.emit(tvm.make.Call(None, 'tvm_storage_sync',
+                                      tvm.convert(['shared']),
+                                      tvm.expr.Call.Intrinsic, None, 0))
+
+    return ib.get()
+
+
+def invalid_to_bottom_pre(data, flag, idx):
+    """Low level IR to rearrange nms output to move all valid entries to top.
+
+    Parameters
+    ----------
+    data: Buffer
+        3D Buffer with shape [batch_size, num_anchors, elem_length], output of nms.
+
+    flag : Buffer
+        1D Buffer of flag indicating valid data with [num_anchors].
+
+    idx : Buffer
+        1D Buffer of valid data indices with [num_anchors].
+
+    Returns
+    -------
+    stmt : Stmt
+        The result IR statement.
+    """
+    batch_size = data.shape[0]
+    num_anchors = data.shape[1]
+    elem_length = data.shape[2]
+
+    ib = tvm.ir_builder.create()
+
+    data = ib.buffer_ptr(data)
+    flag = ib.buffer_ptr(flag)
+    idx = ib.buffer_ptr(idx)
+
+    max_threads = int(math.sqrt(
+        tvm.target.current_target(allow_none=False).max_num_threads))
+    nthread_tx = max_threads
+    nthread_bx = num_anchors // max_threads + 1
+    tx = tvm.thread_axis("threadIdx.x")
+    bx = tvm.thread_axis("blockIdx.x")
+    ib.scope_attr(tx, "thread_extent", nthread_tx)
+    ib.scope_attr(bx, "thread_extent", nthread_bx)
+    j = bx * max_threads + tx
+
+    with ib.for_range(0, batch_size, for_type="unroll") as i:
+        base_idx = i * num_anchors * elem_length
+        with ib.if_scope(j < num_anchors):
+            with ib.if_scope(data[base_idx + j * elem_length] >= 0):
+                flag[i * num_anchors + j] = 1
+                idx[i * num_anchors + j] = 1
+            with ib.else_scope():
+                flag[i * num_anchors + j] = 0
+                idx[i * num_anchors + j] = 0
+
+    with ib.if_scope(j < batch_size):
+        with ib.for_range(0, num_anchors) as k:
+            with ib.if_scope(k > 0):
+                idx[j * num_anchors + k] += idx[j * num_anchors + k - 1]
+    return ib.get()
+
+
+def invalid_to_bottom_ir(data, flag, idx, out):
+    """Low level IR to rearrange nms output to move all valid entries to top.
+
+    Parameters
+    ----------
+    data: Buffer
+        3D Buffer with shape [batch_size, num_anchors, elem_length], output of nms.
+
+    flag : Buffer
+        1D Buffer of flag indicating valid data with [num_anchors].
+
+    idx : Buffer
+        1D Buffer of valid data indices with [num_anchors].
+
+    out : Buffer
+        3D Buffer of rearranged nms output with shape [batch_size, num_anchors, elem_length].
+
+    Returns
+    -------
+    stmt : Stmt
+        The result IR statement.
+    """
+    batch_size = data.shape[0]
+    num_anchors = data.shape[1]
+    elem_length = data.shape[2]
+
+    ib = tvm.ir_builder.create()
+
+    data = ib.buffer_ptr(data)
+    flag = ib.buffer_ptr(flag)
+    idx = ib.buffer_ptr(idx)
+    out = ib.buffer_ptr(out)
+
+    max_threads = int(math.sqrt(
+        tvm.target.current_target(allow_none=False).max_num_threads))
+    nthread_tx = max_threads
+    nthread_bx = num_anchors // max_threads + 1
+    tx = tvm.thread_axis("threadIdx.x")
+    bx = tvm.thread_axis("blockIdx.x")
+    ib.scope_attr(tx, "thread_extent", nthread_tx)
+    ib.scope_attr(bx, "thread_extent", nthread_bx)
+    j = bx * max_threads + tx
+
+    with ib.for_range(0, batch_size, for_type="unroll") as i:
+        base_idx = i * num_anchors * elem_length
+        with ib.if_scope(j < num_anchors):
+            with ib.for_range(0, elem_length) as k:
+                out[base_idx + j * elem_length + k] = -1.0
+            with ib.if_scope(flag[i * num_anchors + j] > 0):
+                with ib.for_range(0, elem_length) as k:
+                    out[base_idx + (idx[i * num_anchors + j] - 1) * elem_length + k] \
+                    = data[base_idx + j * elem_length + k]
+    return ib.get()
 
 
 @non_max_suppression.register(["cuda", "gpu"])
-def nms_gpu(data,
-            valid_count,
-            max_output_size=-1,
-            iou_threshold=0.5,
-            force_suppress=False,
-            top_k=-1,
-            id_index=0,
-            return_indices=True,
-            invalid_to_bottom=False):
+def non_max_suppression_gpu(data, valid_count, max_output_size=-1,
+                            iou_threshold=0.5, force_suppress=False, top_k=-1,
+                            coord_start=2, score_index=1, id_index=0,
+                            return_indices=True, invalid_to_bottom=False):
     """Non-maximum suppression operator for object detection.
 
     Parameters
     ----------
     data : tvm.Tensor
-        3-D tensor with shape [batch_size, num_anchors, 6].
+        3-D tensor with shape [batch_size, num_anchors, elem_length].
         The last dimension should be in format of
         [class_id, score, box_left, box_top, box_right, box_bottom].
 
     valid_count : tvm.Tensor
         1-D tensor for valid number of boxes.
 
-    return_indices : boolean
-        Whether to return box indices in input data.
+    max_output_size : optional, int
+        Max number of output valid boxes for each instance.
+        By default all valid boxes are returned.
 
     iou_threshold : optional, float
         Non-maximum suppression threshold.
@@ -231,16 +676,25 @@ def nms_gpu(data,
     top_k : optional, int
         Keep maximum top k detections before nms, -1 for no limit.
 
+    coord_start : required, int
+        Start index of the consecutive 4 coordinates.
+
+    score_index : optional, int
+        Index of the scores/confidence of boxes.
+
     id_index : optional, int
         index of the class categories, -1 to disable.
 
+    return_indices : boolean
+        Whether to return box indices in input data.
+
     invalid_to_bottom : optional, boolean
         Whether to move all valid bounding boxes to the top.
 
     Returns
     -------
     out : tvm.Tensor
-        3-D tensor with shape [batch_size, num_anchors, 6].
+        3-D tensor with shape [batch_size, num_anchors, elem_length].
 
     Example
     --------
@@ -253,12 +707,13 @@ def nms_gpu(data,
         iou_threshold = 0.7
         force_suppress = True
         top_k = -1
-        out = nms(data, valid_count, iou_threshold, force_suppress, topk)
+        out = non_max_suppression(data=data, valid_count=valid_count, iou_threshold=iou_threshold,
+                                 force_suppress=force_supress, top_k=top_k, return_indices=False)
         np_data = np.random.uniform(dshape)
         np_valid_count = np.array([4])
         s = topi.generic.schedule_nms(out)
-        f = tvm.build(s, [data, valid_count, out], "llvm")
-        ctx = tvm.cpu()
+        f = tvm.build(s, [data, valid_count, out], "cuda")
+        ctx = tvm.gpu(0)
         tvm_data = tvm.nd.array(np_data, ctx)
         tvm_valid_count = tvm.nd.array(np_valid_count, ctx)
         tvm_out = tvm.nd.array(np.zeros(dshape, dtype=data.dtype), ctx)
@@ -266,38 +721,62 @@ def nms_gpu(data,
     """
     batch_size = data.shape[0]
     num_anchors = data.shape[1]
+
     valid_count_dtype = "int32"
     valid_count_buf = api.decl_buffer(valid_count.shape, valid_count_dtype,
                                       "valid_count_buf", data_alignment=4)
-    data_buf = api.decl_buffer(
-        data.shape, data.dtype, "data_buf", data_alignment=8)
+    score_axis = score_index
     score_shape = (batch_size, num_anchors)
-    score_tensor = tvm.compute(
-        score_shape, lambda i, j: data[i, j, 1], name="score_tensor")
-    score_tensor_buf = api.decl_buffer(score_tensor.shape, data.dtype,
-                                       "score_tensor_buf", data_alignment=8)
+    score_tensor = tvm.compute(score_shape, lambda i, j: data[i, j, score_axis])
+    sort_tensor = argsort(score_tensor, valid_count=valid_count, axis=1, is_ascend=False, flag=True)
 
-    sort_tensor_dtype = "int32"
-    sort_tensor_buf = api.decl_buffer(score_shape, sort_tensor_dtype,
+    sort_tensor_buf = api.decl_buffer(sort_tensor.shape, sort_tensor.dtype,
                                       "sort_tensor_buf", data_alignment=8)
 
-    sort_tensor = \
-        tvm.extern(score_shape,
-                   [score_tensor, valid_count],
-                   lambda ins, outs: sort_ir(
-                       ins[0], ins[1], outs[0]),
-                   dtype=sort_tensor_dtype,
-                   in_buffers=[score_tensor_buf, valid_count_buf],
-                   out_buffers=sort_tensor_buf,
-                   name="nms_sort")
+    data_buf = api.decl_buffer(
+        data.shape, data.dtype, "data_buf", data_alignment=8)
 
-    out = \
-        tvm.extern(data.shape,
+    out_buf = api.decl_buffer(
+        data.shape, data.dtype, "out_buf", data_alignment=8)
+
+    out, box_indices = \
+        tvm.extern([data.shape, score_shape],
                    [data, sort_tensor, valid_count],
                    lambda ins, outs: nms_ir(
-                       ins[0], ins[1], ins[2], outs[0], iou_threshold,
-                       force_suppress, top_k),
-                   dtype="float32",
+                       ins[0], ins[1], ins[2], outs[0], outs[1],
+                       max_output_size, iou_threshold, force_suppress,
+                       top_k, coord_start, id_index),
+                   dtype=[data.dtype, "int32"],
                    in_buffers=[data_buf, sort_tensor_buf, valid_count_buf],
+                   name="nms",
                    tag="nms")
+
+    if return_indices:
+        return box_indices
+
+    if invalid_to_bottom:
+        output_buf = api.decl_buffer(
+            data.shape, data.dtype, "output_buf", data_alignment=8)
+        temp_flag_buf = api.decl_buffer(
+            score_shape, valid_count_dtype, "temp_flag", data_alignment=8)
+        temp_idx_buf = api.decl_buffer(
+            score_shape, valid_count_dtype, "temp_idx", data_alignment=8)
+        temp_flag, temp_idx = tvm.extern([score_shape, score_shape], [out],
+                                         lambda ins, outs: invalid_to_bottom_pre(
+                                             ins[0], outs[0], outs[1]),
+                                         dtype=["int32", "int32"],
+                                         in_buffers=[out_buf],
+                                         out_buffers=[temp_flag_buf, temp_idx_buf],
+                                         name="invalid_to_bottom_phase_one")
+
+        output = tvm.extern([data.shape], [out, temp_flag, temp_idx],
+                            lambda ins, outs: invalid_to_bottom_ir(
+                                ins[0], ins[1], ins[2], outs[0]),
+                            dtype=[data.dtype],
+                            in_buffers=[out_buf, temp_flag_buf, temp_idx_buf],
+                            out_buffers=[output_buf],
+                            name="invalid_to_bottom",
+                            tag="invalid_to_bottom")
+        return output
+
     return out
diff --git a/topi/python/topi/cuda/sort.py b/topi/python/topi/cuda/sort.py
new file mode 100644 (file)
index 0000000..99ba852
--- /dev/null
@@ -0,0 +1,249 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments, too-many-statements, singleton-comparison, unused-argument
+"""Argsort operator """
+import tvm
+
+from tvm import api
+from topi.sort import argsort
+
+def sort_ir(data, output, axis, is_ascend):
+    """Low level IR to do nms sorting on the GPU, same usage as tvm.contrib.sort.argsort on the CPU.
+
+    Parameters
+    ----------
+    data: Buffer
+        Buffer of input data.
+
+    output : Buffer
+        Output buffer of indicies of sorted tensor with same shape as data.
+
+    axis : Int
+        Axis long which to sort the input tensor.
+
+    is_ascend : Boolean
+        Whether to sort in ascending or descending order.
+
+    Returns
+    -------
+    stmt : Stmt
+        The result IR statement.
+    """
+    size = 1
+    axis_mul_before = 1
+    axis_mul_after = 1
+    shape = data.shape
+    if axis < 0:
+        axis = len(shape) + axis
+    for i, value in enumerate(shape, 0):
+        size *= value
+        if i < axis:
+            axis_mul_before *= value
+        elif i > axis:
+            axis_mul_after *= value
+    max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads)
+    ib = tvm.ir_builder.create()
+    data = ib.buffer_ptr(data)
+    output = ib.buffer_ptr(output)
+    nthread_tx = max_threads
+    nthread_bx = size // max_threads + 1
+    tx = tvm.thread_axis("threadIdx.x")
+    bx = tvm.thread_axis("vthread")
+    ib.scope_attr(tx, "thread_extent", nthread_tx)
+    ib.scope_attr(bx, "virtual_thread", nthread_bx)
+    tid = bx * nthread_tx + tx
+    temp_data = ib.allocate("float32", (1,), name="temp_data", scope="local")
+    temp_index = ib.allocate("float32", (1,), name="temp_index", scope="local")
+    is_ascend = tvm.make.node("IntImm", dtype="int32", value=is_ascend)
+
+    with ib.for_range(0, axis_mul_before) as i:
+        with ib.for_range(0, axis_mul_after) as j:
+            current_sort_num = shape[axis]
+            base_idx = i * shape[axis] * axis_mul_after + j
+            with ib.if_scope(tid < shape[axis]):
+                output[base_idx + tid * axis_mul_after] = tid.astype("float32")
+            # OddEvenTransposeSort
+            with ib.for_range(0, current_sort_num) as k:
+                with ib.if_scope(tid < (current_sort_num + 1) // 2):
+                    offset = base_idx + (2 * tid + (k % 2)) * axis_mul_after
+                    with ib.if_scope(tvm.all(is_ascend == 1, \
+                                             2 * tid + (k % 2) + 1 < current_sort_num, \
+                                             data[offset] > data[offset + axis_mul_after])):
+                        temp_data[0] = data[offset]
+                        data[offset] = data[offset + axis_mul_after]
+                        data[offset + axis_mul_after] = temp_data[0]
+                        temp_index[0] = output[offset]
+                        output[offset] = output[offset + axis_mul_after]
+                        output[offset + axis_mul_after] = temp_index[0]
+                    with ib.if_scope(tvm.all(is_ascend == 0, \
+                                             2 * tid + (k % 2) + 1 < current_sort_num, \
+                                             data[offset] < data[offset + axis_mul_after])):
+                        temp_data[0] = data[offset]
+                        data[offset] = data[offset + axis_mul_after]
+                        data[offset + axis_mul_after] = temp_data[0]
+                        temp_index[0] = output[offset]
+                        output[offset] = output[offset + axis_mul_after]
+                        output[offset + axis_mul_after] = temp_index[0]
+                ib.emit(tvm.make.Call(None, 'tvm_storage_sync',
+                                      tvm.convert(['shared']),
+                                      tvm.expr.Call.Intrinsic, None, 0))
+
+    return ib.get()
+
+
+
+def sort_nms_ir(data, valid_count, output, axis, is_ascend):
+    """Low level IR to do nms sorting on the GPU, same usage as tvm.contrib.sort.argsort on the CPU.
+
+    Parameters
+    ----------
+    data: Buffer
+        Buffer of input data.
+
+    valid_count : Buffer
+        1D Buffer of number of valid number of boxes.
+
+    output : Buffer
+        Output buffer of indicies of sorted tensor with same shape as data.
+
+    axis : Int
+        Axis long which to sort the input tensor.
+
+    is_ascend : Boolean
+        Whether to sort in ascending or descending order.
+
+    Returns
+    -------
+    stmt : Stmt
+        The result IR statement.
+    """
+
+    size = 1
+    axis_mul_before = 1
+    axis_mul_after = 1
+    shape = data.shape
+    if axis < 0:
+        axis = len(shape) + axis
+    for i, value in enumerate(shape, 0):
+        size *= value
+        if i < axis:
+            axis_mul_before *= value
+        elif i > axis:
+            axis_mul_after *= value
+    max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads)
+    ib = tvm.ir_builder.create()
+    data = ib.buffer_ptr(data)
+    valid_count = ib.buffer_ptr(valid_count)
+    output = ib.buffer_ptr(output)
+    nthread_tx = max_threads
+    nthread_bx = size // max_threads + 1
+    tx = tvm.thread_axis("threadIdx.x")
+    bx = tvm.thread_axis("vthread")
+    ib.scope_attr(tx, "thread_extent", nthread_tx)
+    ib.scope_attr(bx, "virtual_thread", nthread_bx)
+    tid = bx * nthread_tx + tx
+    temp_data = ib.allocate("float32", (1,), name="temp_data", scope="local")
+    temp_index = ib.allocate("int32", (1,), name="temp_index", scope="local")
+    is_ascend = tvm.make.node("IntImm", dtype="int32", value=is_ascend)
+
+    with ib.for_range(0, axis_mul_before) as i:
+        with ib.for_range(0, axis_mul_after) as j:
+            current_sort_num = valid_count[i * axis_mul_after + j]
+            base_idx = i * shape[axis] * axis_mul_after + j
+            with ib.if_scope(tid < shape[axis]):
+                output[base_idx + tid * axis_mul_after] = tid
+            # OddEvenTransposeSort
+            with ib.for_range(0, current_sort_num) as k:
+                with ib.if_scope(tid < (current_sort_num + 1) // 2):
+                    offset = base_idx + (2 * tid + (k % 2)) * axis_mul_after
+                    with ib.if_scope(tvm.all(is_ascend == 1, \
+                                             2 * tid + (k % 2) + 1 < current_sort_num, \
+                                             data[offset] > data[offset + axis_mul_after])):
+                        temp_data[0] = data[offset]
+                        data[offset] = data[offset + axis_mul_after]
+                        data[offset + axis_mul_after] = temp_data[0]
+                        temp_index[0] = output[offset]
+                        output[offset] = output[offset + axis_mul_after]
+                        output[offset + axis_mul_after] = temp_index[0]
+                    with ib.if_scope(tvm.all(is_ascend == 0, \
+                                             2 * tid + (k % 2) + 1 < current_sort_num, \
+                                             data[offset] < data[offset + axis_mul_after])):
+                        temp_data[0] = data[offset]
+                        data[offset] = data[offset + axis_mul_after]
+                        data[offset + axis_mul_after] = temp_data[0]
+                        temp_index[0] = output[offset]
+                        output[offset] = output[offset + axis_mul_after]
+                        output[offset + axis_mul_after] = temp_index[0]
+                ib.emit(tvm.make.Call(None, 'tvm_storage_sync',
+                                      tvm.convert(['shared']),
+                                      tvm.expr.Call.Intrinsic, None, 0))
+
+    return ib.get()
+
+@argsort.register(["cuda", "gpu"])
+def argsort_gpu(data, valid_count, axis=-1, is_ascend=1, dtype="float32", flag=0):
+    """Performs sorting along the given axis and returns an array of indicies
+    having same shape as an input array that index data in sorted order.
+
+    Parameters
+    ----------
+    data: tvm.Tensor
+        The input array.
+
+    valid_count : tvm.Tensor
+        The number of valid elements to be sorted.
+
+    axis : int
+        Axis long which to sort the input tensor.
+
+    is_ascend : boolean
+        Whether to sort in ascending or descending order.
+
+    flag : boolean
+        Whether this argsort is used in nms operator
+
+    Returns
+    -------
+    out : tvm.Tensor
+        The output of this function.
+    """
+    data_buf = api.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8)
+    if flag:
+        valid_count_buf = api.decl_buffer(valid_count.shape, valid_count.dtype,
+                                          "valid_count_buf", data_alignment=4)
+        out_buf = api.decl_buffer(data.shape, "int32", "out_buf", data_alignment=4)
+        out = tvm.extern([data.shape],
+                         [data, valid_count],
+                         lambda ins, outs: sort_nms_ir(
+                             ins[0], ins[1], outs[0], axis, is_ascend),
+                         dtype="int32",
+                         in_buffers=[data_buf, valid_count_buf],
+                         out_buffers=[out_buf],
+                         name="argsort_nms_gpu",
+                         tag="argsort_nms_gpu")
+    else:
+        out_buf = api.decl_buffer(data.shape, dtype, "out_buf", data_alignment=8)
+        out = tvm.extern([data.shape],
+                         [data],
+                         lambda ins, outs: sort_ir(
+                             ins[0], outs[0], axis, is_ascend),
+                         dtype=dtype,
+                         in_buffers=[data_buf],
+                         out_buffers=[out_buf],
+                         name="argsort_gpu",
+                         tag="argsort_gpu")
+    return out
index 38b76f3..f7e5f94 100644 (file)
@@ -21,6 +21,7 @@ import math
 import tvm
 
 from tvm import api
+from tvm.intrin import if_then_else, exp
 
 import topi
 
@@ -93,12 +94,11 @@ def multibox_prior_ir(data, out, sizes, ratios, steps, offsets):
             center_w = (j + offset_w) * steps_w
 
             for k in range(num_sizes + num_ratios - 1):
-                w = tvm.if_then_else(k < num_sizes,
-                                     size_ratio_concat[
-                                         k] * in_height / in_width / 2.0,
-                                     size_ratio_concat[0] * in_height / in_width *
-                                     math.sqrt(size_ratio_concat[k + 1]) / 2.0)
-                h = tvm.if_then_else(
+                w = if_then_else(k < num_sizes,
+                                 size_ratio_concat[k] * in_height / in_width / 2.0,
+                                 size_ratio_concat[0] * in_height / in_width *
+                                 math.sqrt(size_ratio_concat[k + 1]) / 2.0)
+                h = if_then_else(
                     k < num_sizes, size_ratio_concat[k] / 2.0,
                     size_ratio_concat[0] / math.sqrt(size_ratio_concat[k + 1]) / 2.0)
                 count = (i * in_width * (num_sizes + num_ratios - 1) +
@@ -154,8 +154,7 @@ def multibox_prior_gpu(data, sizes=(1,), ratios=(1,), steps=(-1, -1),
         out = topi.clip(out, 0, 1)
     return out
 
-
-def transform_loc_pre(cls_prob, valid_count, temp_flag, temp_id, temp_score_out, threshold):
+def transform_loc_pre(cls_prob, valid_count, temp_valid_count, temp_cls_id, temp_score, threshold):
     """Low level IR routing for transform location data preparation.
 
     Parameters
@@ -166,13 +165,13 @@ def transform_loc_pre(cls_prob, valid_count, temp_flag, temp_id, temp_score_out,
     valid_count : Buffer
         Buffer of number of valid output boxes.
 
-    temp_flag : Buffer
+    temp_valid_count : Buffer
         Output intermediate result buffer
 
-    temp_id : Buffer
+    temp_cls_id : Buffer
         Output intermediate result buffer
 
-    temp_score_out : Buffer
+    temp_score : Buffer
         Output buffer
 
     threshold : float
@@ -187,53 +186,53 @@ def transform_loc_pre(cls_prob, valid_count, temp_flag, temp_id, temp_score_out,
     num_classes = cls_prob.shape[1]
     num_anchors = cls_prob.shape[2]
 
-    max_threads = int(
-        tvm.target.current_target(allow_none=False).max_num_threads)
     ib = tvm.ir_builder.create()
-    score = ib.buffer_ptr(temp_score_out)
-    cls_id = ib.buffer_ptr(temp_id)
-    flag = ib.buffer_ptr(temp_flag)
+
+    cls_prob = ib.buffer_ptr(cls_prob)
+    cls_id = ib.buffer_ptr(temp_cls_id)
+    valid_count = ib.buffer_ptr(valid_count)
+    temp_valid_count = ib.buffer_ptr(temp_valid_count)
+    score = ib.buffer_ptr(temp_score)
+
+    threshold = tvm.make.node("FloatImm", dtype="float32", value=threshold)
+
+    max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads)
+    nthread_tx = max_threads
+    nthread_bx = (batch_size *  num_anchors) // max_threads + 1
     tx = tvm.thread_axis("threadIdx.x")
     bx = tvm.thread_axis("blockIdx.x")
-    nthread_tx = max_threads
-    nthread_bx = (batch_size * num_anchors * num_classes) // max_threads + 1
     ib.scope_attr(tx, "thread_extent", nthread_tx)
     ib.scope_attr(bx, "thread_extent", nthread_bx)
     tid = bx * max_threads + tx
-    p_cls_prob = ib.buffer_ptr(cls_prob)
-    p_valid_count = ib.buffer_ptr(valid_count)
 
     with ib.if_scope(tid < batch_size * num_anchors):
-        n = tid / num_anchors  # number of batches
-        i = tid % num_anchors  # number of anchors
-        score[i] = -1.0
-        cls_id[i] = 0
-        p_valid_count[n] = 0
-        with ib.for_range(0, num_classes-1, name="k") as k:
-            temp = p_cls_prob[n * num_anchors * num_classes + (k + 1) * num_anchors + i]
-            with ib.if_scope(temp > score[i]):
-                cls_id[i] = k + 1
-                score[i] = temp
-        with ib.if_scope(tvm.all(cls_id[i] > 0, score[i] < threshold)):
-            cls_id[i] = 0
-        with ib.if_scope(cls_id[i] > 0):
-            flag[i] = 1
+        i = tid / num_anchors
+        j = tid % num_anchors
+        valid_count[i] = 0
+        score[tid] = -1.0
+        cls_id[tid] = 0
+        with ib.for_range(0, num_classes - 1) as k:
+            temp = cls_prob[i * num_classes * num_anchors + (k + 1) * num_anchors + j]
+            cls_id[tid] = if_then_else(temp > score[tid], k + 1, cls_id[tid])
+            score[tid] = tvm.max(temp, score[tid])
+        with ib.if_scope(tvm.all(cls_id[tid] > 0, score[tid] < threshold)):
+            cls_id[tid] = 0
+        with ib.if_scope(cls_id[tid] > 0):
+            temp_valid_count[tid] = 1
         with ib.else_scope():
-            flag[i] = 0
+            temp_valid_count[tid] = 0
 
         with ib.if_scope(tid < batch_size):
-            with ib.for_range(0, num_anchors, name="k") as k:
+            with ib.for_range(0, num_anchors) as k:
                 with ib.if_scope(k > 0):
-                    flag[tid * num_anchors +
-                         k] += flag[tid * num_anchors + k - 1]
-            p_valid_count[n] = flag[tid * num_anchors + num_anchors - 1]
+                    temp_valid_count[tid * num_anchors + k] += \
+                    temp_valid_count[tid * num_anchors + k - 1]
+            valid_count[i] = temp_valid_count[tid * num_anchors + num_anchors - 1]
 
-    body = ib.get()
-    return body
+    return ib.get()
 
-
-def transform_loc_ir(loc_pred, anchor, temp_flag, temp_id, temp_score_in, \
-                     out, clip, variances, batch_size, num_classes, num_anchors):
+def transform_loc_ir(loc_pred, anchor, temp_valid_count, temp_cls_id, temp_score, out, \
+                     clip, variances, batch_size, num_anchors):
     """Low level IR routing for transform location in multibox_detection operator.
 
     Parameters
@@ -244,13 +243,13 @@ def transform_loc_ir(loc_pred, anchor, temp_flag, temp_id, temp_score_in, \
     anchor : Buffer
         Buffer of prior anchor boxes.
 
-    temp_flag : Buffer
+    temp_valid_count : Buffer
         Intermediate result buffer.
 
-    temp_id : Buffer
+    temp_cls_id : Buffer
         Intermediate result buffer.
 
-    temp_score_in : Buffer
+    temp_score : Buffer
         Input buffer which stores intermediate results.
 
     out : Buffer
@@ -265,9 +264,6 @@ def transform_loc_ir(loc_pred, anchor, temp_flag, temp_id, temp_score_in, \
     batch_size : int
         Batch size
 
-    num_classes : int
-        Number of classes
-
     num_anchors : int
         Number of anchors
 
@@ -293,47 +289,55 @@ def transform_loc_ir(loc_pred, anchor, temp_flag, temp_id, temp_score_in, \
         ph = loc[loc_base_idx + 3]
         ox = px * vx * aw + ax
         oy = py * vy * ah + ay
-        ow = tvm.exp(pw * vw) * aw / 2.0
-        oh = tvm.exp(ph * vh) * ah / 2.0
-        return tvm.if_then_else(clip, tvm.make.Max(0.0, tvm.make.Min(1.0, ox - ow)), ox - ow), \
-            tvm.if_then_else(clip, tvm.make.Max(0.0, tvm.make.Min(1.0, oy - oh)), oy - oh), \
-            tvm.if_then_else(clip, tvm.make.Max(0.0, tvm.make.Min(1.0, ox + ow)), ox + ow), \
-            tvm.if_then_else(clip, tvm.make.Max(0.0, tvm.make.Min(1.0, oy + oh)), oy + oh)
-
-    max_threads = int(
-        tvm.target.current_target(allow_none=False).max_num_threads)
+        ow = exp(pw * vw) * aw / 2.0
+        oh = exp(ph * vh) * ah / 2.0
+        return tvm.if_then_else(clip, tvm.max(0.0, tvm.min(1.0, ox - ow)), ox - ow), \
+            tvm.if_then_else(clip, tvm.max(0.0, tvm.min(1.0, oy - oh)), oy - oh), \
+            tvm.if_then_else(clip, tvm.max(0.0, tvm.min(1.0, ox + ow)), ox + ow), \
+            tvm.if_then_else(clip, tvm.max(0.0, tvm.min(1.0, oy + oh)), oy + oh)
+
     ib = tvm.ir_builder.create()
-    score = ib.buffer_ptr(temp_score_in)
-    cls_id = ib.buffer_ptr(temp_id)
-    flag = ib.buffer_ptr(temp_flag)
+
+    loc_pred = ib.buffer_ptr(loc_pred)
+    anchor = ib.buffer_ptr(anchor)
+    temp_valid_count = ib.buffer_ptr(temp_valid_count)
+    cls_id = ib.buffer_ptr(temp_cls_id)
+    score = ib.buffer_ptr(temp_score)
+    out_loc = ib.buffer_ptr(out)
+
+    max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads)
+    nthread_tx = max_threads
+    nthread_bx = (batch_size * num_anchors) // max_threads + 1
     tx = tvm.thread_axis("threadIdx.x")
     bx = tvm.thread_axis("blockIdx.x")
-    nthread_tx = max_threads
-    nthread_bx = (batch_size * num_anchors * num_classes) // max_threads + 1
     ib.scope_attr(tx, "thread_extent", nthread_tx)
     ib.scope_attr(bx, "thread_extent", nthread_bx)
     tid = bx * max_threads + tx
-    p_loc_pred = ib.buffer_ptr(loc_pred)
-    p_anchor = ib.buffer_ptr(anchor)
-    p_out = ib.buffer_ptr(out)
 
     with ib.if_scope(tid < batch_size * num_anchors):
-        n = tid / num_anchors  # number of batches
-        i = tid % num_anchors  # number of anchors
+        i = tid / num_anchors
+        j = tid % num_anchors
         with ib.if_scope(cls_id[tid] > 0):
             with ib.if_scope(tid == 0):
-                out_base_idx = n * num_anchors * 6
+                out_base_idx = i * num_anchors * 6
+                out_loc[out_base_idx] = cls_id[tid] - 1.0
+                out_loc[out_base_idx + 1] = score[tid]
+                out_loc[out_base_idx + 2], out_loc[out_base_idx + 3], out_loc[out_base_idx + 4], \
+                    out_loc[out_base_idx + 5] = transform_loc(loc_pred, tid * 4,
+                                                              anchor, j * 4, clip, variances[0],
+                                                              variances[1], variances[2],
+                                                              variances[3])
             with ib.else_scope():
-                out_base_idx = n * num_anchors * 6 + flag[tid - 1] * 6
-            p_out[out_base_idx] = cls_id[tid] - 1.0
-            p_out[out_base_idx + 1] = score[tid]
-            p_out[out_base_idx + 2], p_out[out_base_idx + 3], p_out[out_base_idx + 4], \
-                p_out[out_base_idx + 5] = transform_loc(p_loc_pred, tid * 4,
-                                                        p_anchor, i*4, clip, variances[0],
-                                                        variances[1], variances[2], variances[3])
+                out_base_idx = i * num_anchors * 6 + temp_valid_count[tid - 1] * 6
+                out_loc[out_base_idx] = cls_id[tid] - 1.0
+                out_loc[out_base_idx + 1] = score[tid]
+                out_loc[out_base_idx + 2], out_loc[out_base_idx + 3], out_loc[out_base_idx + 4], \
+                    out_loc[out_base_idx + 5] = transform_loc(loc_pred, tid * 4,
+                                                              anchor, j * 4, clip, variances[0],
+                                                              variances[1], variances[2],
+                                                              variances[3])
 
-    body = ib.get()
-    return body
+    return ib.get()
 
 
 @multibox_transform_loc.register(["cuda", "gpu"])
@@ -372,44 +376,48 @@ def multibox_transform_loc_gpu(cls_prob, loc_pred, anchor, clip=True, \
         1-D tensor with shape (batch_size,), number of valid anchor boxes.
     """
     batch_size = cls_prob.shape[0]
-    num_classes = cls_prob.shape[1]
     num_anchors = cls_prob.shape[2]
     oshape = (batch_size, num_anchors, 6)
     # Define data alignment for intermediate buffer
     valid_count_dtype = "int32"
+    out_loc_dtype = loc_pred.dtype
+
     valid_count_buf = api.decl_buffer((batch_size,), valid_count_dtype,
                                       "valid_count_buf", data_alignment=4)
-    out_buf = api.decl_buffer(
-        oshape, cls_prob.dtype, "out_buf", data_alignment=8)
-    size = num_anchors
-    temp_flag_buf = api.decl_buffer(
-        (size,), valid_count_dtype, "flag", data_alignment=8)
-    temp_id_buf = api.decl_buffer(
-        (size,), valid_count_dtype, "cls_id", data_alignment=8)
+    loc_pred_buf = api.decl_buffer(loc_pred.shape, loc_pred.dtype,
+                                   "loc_pred_buf", data_alignment=8)
+    anchor_buf = api.decl_buffer(anchor.shape, anchor.dtype,
+                                 "anchor_buf", data_alignment=8)
+
+    temp_valid_count_buf = api.decl_buffer(
+        (batch_size, num_anchors,), valid_count_dtype, "temp_valid_count", data_alignment=8)
+    temp_cls_id_buf = api.decl_buffer(
+        (batch_size, num_anchors,), valid_count_dtype, "temp_cls_id", data_alignment=8)
     temp_score_buf = api.decl_buffer(
-        (size,), cls_prob.dtype, "score", data_alignment=8)
+        (batch_size, num_anchors,), cls_prob.dtype, "temp_score", data_alignment=8)
 
-    valid_count, temp_flag, temp_id, temp_score = \
-        tvm.extern([(batch_size,), (size,), (size,), (size,)],
-                   [cls_prob],
+    valid_count, temp_valid_count, temp_cls_id, temp_score = \
+        tvm.extern([(batch_size,), (batch_size, num_anchors,), (batch_size, num_anchors,), \
+                    (batch_size, num_anchors,)], [cls_prob],
                    lambda ins, outs: transform_loc_pre(
                        ins[0], outs[0], outs[1], outs[2], outs[3], threshold),
-                   dtype=[valid_count_dtype,
-                          valid_count_dtype, valid_count_dtype, cls_prob.dtype],
-                   out_buffers=[valid_count_buf,
-                                temp_flag_buf, temp_id_buf, temp_score_buf],
-                   tag="multibox_transform_loc_first_step")
+                   dtype=[valid_count_dtype, valid_count_dtype, valid_count_dtype, cls_prob.dtype],
+                   out_buffers=[valid_count_buf, temp_valid_count_buf, \
+                                temp_cls_id_buf, temp_score_buf],
+                   tag="multibox_transform_loc_phase_one")
 
-    out = \
+    out_loc = \
         tvm.extern([oshape],
-                   [loc_pred, anchor, temp_flag, temp_id, temp_score],
+                   [loc_pred, anchor, temp_valid_count, temp_cls_id, temp_score],
                    lambda ins, outs: transform_loc_ir(
-                       ins[0], ins[1], ins[2], ins[3], ins[4], outs[0], clip, \
-                       variances, batch_size, num_classes, num_anchors),
-                   dtype=[cls_prob.dtype],
-                   out_buffers=[out_buf],
+                       ins[0], ins[1], ins[2], ins[3], ins[4], outs[0], clip, variances, \
+                       batch_size, num_anchors),
+                   in_buffers=[loc_pred_buf, anchor_buf, temp_valid_count_buf, \
+                               temp_cls_id_buf, temp_score_buf],
+                   dtype=[out_loc_dtype],
                    tag="multibox_transform_loc")
-    return [out, valid_count]
+
+    return [out_loc, valid_count]
 
 
 @multibox_detection.register(["cuda", "gpu"])
@@ -453,6 +461,7 @@ def multibox_detection_gpu(cls_prob, loc_pred, anchor, clip=True, threshold=0.01
     """
     inter_out = multibox_transform_loc(cls_prob, loc_pred, anchor,
                                        clip, threshold, variances)
-    out = non_max_suppression(
-        inter_out[0], inter_out[1], nms_threshold, force_suppress, nms_topk)
+    out = non_max_suppression(inter_out[0], inter_out[1], max_output_size=-1,
+                              iou_threshold=nms_threshold, force_suppress=force_suppress,
+                              top_k=nms_topk, return_indices=False)
     return out
index 5d7bc9e..78f5c1f 100644 (file)
@@ -32,11 +32,15 @@ def _default_schedule(outs):
 
     def traverse(op):
         """inline all one-to-one-mapping operators except the last stage (output)"""
-        if "nms" in op.tag:
-            sort = op.input_tensors[1]
+        if op.tag in ["nms", "invalid_to_bottom"]:
+            if op.tag == "nms":
+                sort = op.input_tensors[1]
+            else:
+                out = op.input_tensors[0]
+                sort = s[out].op.input_tensors[1]
             score = s[sort].op.input_tensors[0]
             fused = s[score].fuse(*s[score].op.axis)
-            num_thread = tvm.target.current_target(allow_none=False).max_num_threads
+            num_thread = int(tvm.target.current_target(allow_none=False).max_num_threads)
             bx, tx = s[score].split(fused, factor=num_thread)
             s[score].bind(bx, tvm.thread_axis("blockIdx.x"))
             s[score].bind(tx, tvm.thread_axis("threadIdx.x"))
@@ -199,3 +203,30 @@ def schedule_get_valid_counts(outs):
       The computation schedule for the op.
     """
     return _default_schedule(outs)
+
+@generic.schedule_argsort.register(["cuda", "gpu"])
+def schedule_argsort(outs):
+    """Schedule for argsort operator.
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+        The computation graph description of argsort
+        in the format of an array of tensors.
+
+    Returns
+    -------
+    s: Schedule
+      The computation schedule for the op.
+    """
+    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
+    s = tvm.create_schedule([x.op for x in outs])
+    scheduled_ops = []
+    from .injective import _schedule_injective
+    def traverse(op):
+        for tensor in op.input_tensors:
+            if tensor.op.input_tensors and tensor.op not in scheduled_ops:
+                traverse(tensor.op)
+        scheduled_ops.append(op)
+    traverse(outs[0].op)
+    return s
index 8450e2d..6bf5f3a 100644 (file)
@@ -19,3 +19,4 @@ from .nn import *
 from .injective import *
 from .extern import *
 from .vision import *
+from .sort import *
diff --git a/topi/python/topi/generic/sort.py b/topi/python/topi/generic/sort.py
new file mode 100644 (file)
index 0000000..1ad088c
--- /dev/null
@@ -0,0 +1,38 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, no-member
+"""Generic vision operators"""
+from __future__ import absolute_import as _abs
+import tvm
+from .vision import _default_schedule
+
+@tvm.target.generic_func
+def schedule_argsort(outs):
+    """Schedule for argsort operator.
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+      The indices that would sort an input array along
+      the given axis.
+
+    Returns
+    -------
+    s: Schedule
+      The computation schedule for the op.
+    """
+    return _default_schedule(outs, False)
diff --git a/topi/python/topi/sort.py b/topi/python/topi/sort.py
new file mode 100644 (file)
index 0000000..84fff8d
--- /dev/null
@@ -0,0 +1,105 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=too-many-arguments
+"""Argsort operator"""
+import tvm
+from tvm import api
+
+@tvm.target.generic_func
+def argsort(data, valid_count, axis=-1, is_ascend=1, dtype="float32", flag=0):
+    """Performs sorting along the given axis and returns an array
+    of indices having the same shape as an input array that index
+    data in sorted order.
+
+    Parameters
+    ----------
+    data : tvm.Tensor
+        The input tensor.
+
+    valid_count : tvm.Tensor
+        1-D tensor for valid number of boxes only for ssd.
+
+    axis : optional, int
+       Axis along which to sort the input tensor.
+        By default the flattened array is used.
+
+    is_ascend : optional, boolean
+        Whether to sort in ascending or descending order.
+
+    dtype : optional, string
+        DType of the output indices.
+
+    flag : optional, boolean
+        Whether valid_count is valid.
+
+    Returns
+    -------
+    out : tvm.Tensor
+        Sorted index tensor.
+
+    Example
+    --------
+    .. code-block:: python
+
+        # An example to use argsort
+        dshape = (1, 5, 6)
+        data = tvm.placeholder(dshape, name="data")
+        valid_count = tvm.placeholder((dshape[0],), dtype="int32", name="valid_count")
+        axis = 0
+        is_ascend = False
+        flag = False
+        out = argsort(data, valid_count, axis, is_ascend, flag)
+        np_data = np.random.uniform(dshape)
+        np_valid_count = np.array([4])
+        s = topi.generic.schedule_argsort(out)
+        f = tvm.build(s, [data, valid_count, out], "llvm")
+        ctx = tvm.cpu()
+        tvm_data = tvm.nd.array(np_data, ctx)
+        tvm_valid_count = tvm.nd.array(np_valid_count, ctx)
+        tvm_out = tvm.nd.array(np.zeros(dshape, dtype=data.dtype), ctx)
+        f(tvm_data, tvm_valid_count, tvm_out)
+    """
+    data_buf = api.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8)
+    if flag:
+        valid_count_buf = api.decl_buffer(valid_count.shape, valid_count.dtype,
+                                          "valid_count_buf", data_alignment=4)
+        out_buf = api.decl_buffer(data.shape, "int32", "out_buf", data_alignment=8)
+        out = \
+            tvm.extern(data.shape,
+                       [data, valid_count],
+                       lambda ins, outs: tvm.call_packed(
+                           "tvm.contrib.sort.argsort_nms", ins[0], ins[1],
+                           outs[0], axis, is_ascend),
+                       dtype="int32",
+                       in_buffers=[data_buf, valid_count_buf],
+                       out_buffers=out_buf,
+                       name="argsort_nms_cpu",
+                       tag="argsort_nms_cpu")
+    else:
+        out_buf = api.decl_buffer(data.shape, dtype, "out_buf", data_alignment=8)
+        out = \
+            tvm.extern(data.shape,
+                       [data],
+                       lambda ins, outs: tvm.call_packed(
+                           "tvm.contrib.sort.argsort", ins[0],
+                           outs[0], axis, is_ascend),
+                       dtype=dtype,
+                       in_buffers=[data_buf],
+                       out_buffers=out_buf,
+                       name="argsort_cpu",
+                       tag="argsort_cpu")
+    return out
index d8b15aa..979565d 100644 (file)
@@ -18,7 +18,8 @@
 """Non-maximum suppression operator"""
 import tvm
 
-from tvm import api, hybrid
+from tvm import hybrid
+from ..sort import argsort
 
 @hybrid.script
 def hybrid_rearrange_out(data):
@@ -129,7 +130,7 @@ def get_valid_counts(data, score_threshold=0):
 @hybrid.script
 def hybrid_nms(data, sorted_index, valid_count,
                max_output_size, iou_threshold, force_suppress,
-               top_k, id_index):
+               top_k, coord_start, id_index):
     """Hybrid routing for non-maximum suppression.
 
     Parameters
@@ -158,6 +159,9 @@ def hybrid_nms(data, sorted_index, valid_count,
     top_k : tvm.const
         Keep maximum top k detections before nms, -1 for no limit.
 
+    coord_start : tvm.const
+        Start index of the consecutive 4 coordinates.
+
     id_index : tvm.const
         index of the class categories, -1 to disable.
 
@@ -208,7 +212,7 @@ def hybrid_nms(data, sorted_index, valid_count,
                             batch_idx = i
                             box_a_idx = j
                             box_b_idx = k
-                            box_start_idx = 2
+                            box_start_idx = coord_start
                             a_t = output[batch_idx, box_a_idx, box_start_idx + 1]
                             a_b = output[batch_idx, box_a_idx, box_start_idx + 3]
                             a_l = output[batch_idx, box_a_idx, box_start_idx]
@@ -252,7 +256,8 @@ def hybrid_nms(data, sorted_index, valid_count,
 @tvm.target.generic_func
 def non_max_suppression(data, valid_count, max_output_size=-1,
                         iou_threshold=0.5, force_suppress=False, top_k=-1,
-                        id_index=0, return_indices=True, invalid_to_bottom=False):
+                        coord_start=2, score_index=1, id_index=0,
+                        return_indices=True, invalid_to_bottom=False):
     """Non-maximum suppression operator for object detection.
 
     Parameters
@@ -278,6 +283,12 @@ def non_max_suppression(data, valid_count, max_output_size=-1,
     top_k : optional, int
         Keep maximum top k detections before nms, -1 for no limit.
 
+    coord_start : required, int
+        Start index of the consecutive 4 coordinates.
+
+    score_index: optional, int
+        Index of the scores/confidence of boxes.
+
     id_index : optional, int
         index of the class categories, -1 to disable.
 
@@ -317,32 +328,16 @@ def non_max_suppression(data, valid_count, max_output_size=-1,
     """
     batch_size = data.shape[0]
     num_anchors = data.shape[1]
-    valid_count_dtype = "int32"
-    valid_count_buf = api.decl_buffer(valid_count.shape, valid_count_dtype,
-                                      "valid_count_buf", data_alignment=4)
-    score_axis = 1
+    score_axis = score_index
     score_shape = (batch_size, num_anchors)
     score_tensor = tvm.compute(score_shape, lambda i, j: data[i, j, score_axis])
-    score_tensor_buf = api.decl_buffer(score_tensor.shape, data.dtype,
-                                       "score_tensor_buf", data_alignment=8)
-    sort_tensor_dtype = "int32"
-    sort_tensor_buf = api.decl_buffer(score_shape, sort_tensor_dtype,
-                                      "sort_tensor_buf", data_alignment=8)
-    sort_tensor = \
-        tvm.extern(score_shape,
-                   [score_tensor, valid_count],
-                   lambda ins, outs: tvm.call_packed(
-                       "tvm.contrib.sort.argsort", ins[0], ins[1],
-                       outs[0], score_axis, True),
-                   dtype=sort_tensor_dtype,
-                   in_buffers=[score_tensor_buf, valid_count_buf],
-                   out_buffers=sort_tensor_buf,
-                   name="nms_sort")
+    sort_tensor = argsort(score_tensor, valid_count=valid_count, axis=1, is_ascend=False, flag=True)
     out, box_indices = hybrid_nms(data, sort_tensor, valid_count,
                                   tvm.const(max_output_size, dtype="int32"),
                                   tvm.const(iou_threshold, dtype="float32"),
                                   tvm.const(force_suppress, dtype="bool"),
                                   tvm.const(top_k, dtype="int32"),
+                                  tvm.const(coord_start, dtype="int32"),
                                   tvm.const(id_index, dtype="int32"))
     if not return_indices and invalid_to_bottom:
         out = hybrid_rearrange_out(out)
index 7996690..ca1b4a9 100644 (file)
@@ -308,7 +308,7 @@ def multibox_detection(cls_prob, loc_pred, anchor, clip=True, threshold=0.01, nm
     """
     inter_out = multibox_transform_loc(cls_prob, loc_pred, anchor,
                                        clip, threshold, variances)
-    out = non_max_suppression(inter_out[0], inter_out[1], -1,
-                              nms_threshold, force_suppress, nms_topk,
-                              return_indices=False)
+    out = non_max_suppression(inter_out[0], inter_out[1], max_output_size=-1,
+                              iou_threshold=nms_threshold, force_suppress=force_suppress,
+                              top_k=nms_topk, return_indices=False)
     return out
diff --git a/topi/tests/python/test_topi_sort.py b/topi/tests/python/test_topi_sort.py
new file mode 100644 (file)
index 0000000..3a2c9c2
--- /dev/null
@@ -0,0 +1,59 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Test code for vision package"""
+from __future__ import print_function
+import math
+import numpy as np
+import tvm
+import topi
+import topi.testing
+
+from tvm.contrib.pickle_memoize import memoize
+from topi.util import get_const_tuple
+from topi import argsort
+
+def test_argsort():
+    dshape = (1, 8)
+    valid_count_shape = (2,)
+    data = tvm.placeholder(dshape, name="data", dtype="float32")
+    valid_count = tvm.placeholder((dshape[0],), dtype="int32", name="valid_count")
+    np_data = np.random.rand(dshape[0], dshape[1]).astype(data.dtype)
+    np_valid_count = np.array([4]).astype(valid_count.dtype)
+    np_result = np.argsort(-np_data)
+    def check_device(device):
+        ctx = tvm.context(device, 0)
+        if not ctx.exist:
+            print("Skip because %s is not enabled" % device)
+            return
+        print("Running on target: %s" % device)
+        with tvm.target.create(device):
+            out = argsort(data, valid_count, axis = -1, is_ascend = False, flag=False)
+            s = topi.generic.schedule_argsort(out)
+
+        tvm_data = tvm.nd.array(np_data, ctx)
+        tvm_valid_count = tvm.nd.array(np_valid_count, ctx)
+        tvm_out = tvm.nd.array(np.zeros(dshape, dtype="float32"), ctx)
+        f = tvm.build(s, [data, valid_count, out], device)
+        f(tvm_data, tvm_valid_count, tvm_out)
+        tvm.testing.assert_allclose(tvm_out.asnumpy(), np_result.astype("float32"), rtol=1e0)
+
+    for device in ['llvm', 'cuda', 'opencl']:
+        check_device(device)
+
+
+if __name__ == "__main__":
+    test_argsort()
index 6bb57b5..483f3a6 100644 (file)
@@ -66,7 +66,7 @@ def verify_get_valid_counts(dshape, score_threshold):
         tvm.testing.assert_allclose(tvm_out1.asnumpy(), np_out1, rtol=1e-3)
         tvm.testing.assert_allclose(tvm_out2.asnumpy(), np_out2, rtol=1e-3)
 
-    for device in ['llvm']:
+    for device in ['llvm', 'cuda', 'opencl']:
         check_device(device)
 
 
@@ -124,7 +124,7 @@ def test_non_max_suppression():
         f(tvm_data, tvm_valid_count, tvm_indices_out)
         tvm.testing.assert_allclose(tvm_indices_out.asnumpy(), np_indices_result, rtol=1e-4)
 
-    for device in ['llvm']:
+    for device in ['llvm', 'cuda', 'opencl']:
         check_device(device)
 
 
@@ -231,7 +231,7 @@ def test_multibox_detection():
         f(tvm_cls_prob, tvm_loc_preds, tvm_anchors, tvm_out)
         tvm.testing.assert_allclose(tvm_out.asnumpy(), expected_np_out, rtol=1e-4)
 
-    for device in ['llvm', 'opencl']:
+    for device in ['llvm', 'opencl', 'cuda']:
         check_device(device)
 
 
@@ -275,7 +275,7 @@ def verify_roi_align(batch, in_channel, in_size, num_roi, pooled_size, spatial_s
         f(tvm_a, tvm_rois, tvm_b)
         tvm.testing.assert_allclose(tvm_b.asnumpy(), b_np, rtol=1e-3)
 
-    for device in ['llvm', 'cuda']:
+    for device in ['llvm', 'cuda', 'opencl']:
         check_device(device)
 
 
index fe84283..ff7691c 100644 (file)
@@ -18,6 +18,7 @@
 Deploy Single Shot Multibox Detector(SSD) model
 ===============================================
 **Author**: `Yao Wang <https://github.com/kevinthesun>`_
+`Leyuan Wang <https://github.com/Laurawly>`_
 
 This article is an introductory tutorial to deploy SSD models with TVM.
 We will use GluonCV pre-trained SSD model and convert it to Relay IR
@@ -37,30 +38,29 @@ from gluoncv import model_zoo, data, utils
 # ------------------------------
 # .. note::
 #
-#   Currently we support compiling SSD on CPU only.
-#   GPU support is in progress.
+#   We support compiling SSD on bot CPUs and GPUs now.
 #
 #   To get best inference performance on CPU, change
 #   target argument according to your device and
 #   follow the :ref:`tune_relay_x86` to tune x86 CPU and
 #   :ref:`tune_relay_arm` for arm cpu.
 #
+#   To get best performance fo SSD on Intel graphics,
+#   change target argument to 'opencl -device=intel_graphics'
+#
 #   SSD with VGG as body network is not supported yet since
 #   x86 conv2d schedule doesn't support dilation.
 
 supported_model = [
-    'ssd_512_resnet18_v1_voc',
-    'ssd_512_resnet18_v1_coco',
     'ssd_512_resnet50_v1_voc',
     'ssd_512_resnet50_v1_coco',
     'ssd_512_resnet101_v2_voc',
-    'ssd_512_mobilenet1_0_voc',
-    'ssd_512_mobilenet1_0_coco',
+    'ssd_512_mobilenet1.0_voc',
+    'ssd_512_mobilenet1.0_coco',
 ]
 
-model_name = "ssd_512_resnet50_v1_voc"
+model_name = supported_model[0]
 dshape = (1, 3, 512, 512)
-dtype = "float32"
 target_list = ctx_list()
 
 ######################################################################
@@ -76,7 +76,7 @@ x, img = data.transforms.presets.ssd.load_test(im_fname, short=512)
 
 block = model_zoo.get_model(model_name, pretrained=True)
 
-def compile(target):
+def build(target):
     net, params = relay.frontend.from_mxnet(block, {"data": dshape})
     with relay.build_config(opt_level=3):
         graph, lib, params = relay.build(net, target, params=params)
@@ -98,10 +98,7 @@ def run(graph, lib, params, ctx):
     return class_IDs, scores, bounding_boxs
 
 for target, ctx in target_list:
-    if target == "cuda":
-        print("GPU not supported yet, skip.")
-        continue
-    graph, lib, params = compile(target)
+    graph, lib, params = build(target)
     class_IDs, scores, bounding_boxs = run(graph, lib, params, ctx)
 
 ######################################################################