[Relay/TOPI][Op] Add TopK operator (#3256)
authorHaichen Shen <shenhaichen@gmail.com>
Tue, 4 Jun 2019 23:29:56 +0000 (16:29 -0700)
committerLeyuan Wang <laurawly@gmail.com>
Tue, 4 Jun 2019 23:29:56 +0000 (16:29 -0700)
* init impl for topk

* Fix cpu for topk

* init cuda impl for topk

* Add cuda for topk

* fix

* Add doc

* update doc

* lint

* lint

* lint

* x

* fix warning

* [Relay] Add TopK in tf converter

* Add frontend converter

* fix

24 files changed:
docs/api/python/topi.rst
docs/langref/relay_op.rst
include/tvm/relay/attrs/algorithm.h
python/tvm/relay/frontend/mxnet.py
python/tvm/relay/frontend/tensorflow.py
python/tvm/relay/op/_algorithm.py
python/tvm/relay/op/algorithm.py
python/tvm/relay/op/nn/nn.py
python/tvm/relay/op/transform.py
src/codegen/build_module.cc
src/contrib/sort/sort.cc
src/relay/op/algorithm/argsort.cc [moved from src/relay/op/algorithm/sort.cc with 94% similarity]
src/relay/op/algorithm/topk.cc [new file with mode: 0644]
tests/python/frontend/mxnet/test_forward.py
tests/python/frontend/tensorflow/test_forward.py
tests/python/relay/test_op_level6.py
topi/python/topi/cuda/__init__.py
topi/python/topi/cuda/nms.py
topi/python/topi/cuda/sort.py
topi/python/topi/generic/sort.py
topi/python/topi/sort.py
topi/python/topi/transform.py
topi/python/topi/vision/nms.py
topi/tests/python/test_topi_sort.py

index 0b217d4..ade0f1a 100644 (file)
@@ -99,6 +99,8 @@ List of operators
    topi.shape
    topi.layout_transform
    topi.image.resize
+   topi.argsort
+   topi.topk
 
 
 List of schedules
@@ -163,6 +165,8 @@ topi
 .. autofunction:: topi.tile
 .. autofunction:: topi.shape
 .. autofunction:: topi.layout_transform
+.. autofunction:: topi.argsort
+.. autofunction:: topi.topk
 
 topi.nn
 ~~~~~~~
index 836f8f3..28ee99e 100644 (file)
@@ -172,6 +172,7 @@ This level enables additional math and transform operators.
    :nosignatures:
 
    tvm.relay.argsort
+   tvm.relay.topk
 
 
 **Level 10: Temporary Operators**
@@ -309,6 +310,7 @@ Level 5 Definitions
 Level 6 Definitions
 -------------------
 .. autofunction:: tvm.relay.argsort
+.. autofunction:: tvm.relay.topk
 
 
 Level 10 Definitions
index 20f135c..f5ba699 100644 (file)
@@ -48,6 +48,31 @@ struct ArgsortAttrs : public tvm::AttrsNode<ArgsortAttrs> {
   }
 };
 
+struct TopKAttrs : public tvm::AttrsNode<TopKAttrs> {
+  int k;
+  int axis;
+  bool is_ascend;
+  std::string ret_type;
+  DataType dtype;
+
+  TVM_DECLARE_ATTRS(TopKAttrs, "relay.attrs.TopkAttrs") {
+    TVM_ATTR_FIELD(k).set_default(1)
+      .describe("Number of top elements to select");
+    TVM_ATTR_FIELD(axis).set_default(-1)
+      .describe("Axis along which to sort the input tensor.");
+    TVM_ATTR_FIELD(ret_type).set_default("both")
+      .describe("The return type [both, values, indices]."
+                "both - return both top k data and indices."
+                "values - return top k data only."
+                "indices - return top k indices only.");
+    TVM_ATTR_FIELD(is_ascend).set_default(false)
+      .describe("Whether to sort in ascending or descending order."
+                "By default, sort in descending order");
+    TVM_ATTR_FIELD(dtype).set_default(NullValue<DataType>())
+      .describe("Data type of the output indices.");
+  }
+};
+
 }  // namespace relay
 }  // namespace tvm
 #endif  // TVM_RELAY_ATTRS_ALGORITHM_H_
index 0bc7923..0975a33 100644 (file)
@@ -683,6 +683,21 @@ def _mx_argsort(inputs, attrs):
     return _op.argsort(inputs[0], **new_attrs)
 
 
+def _mx_topk(inputs, attrs):
+    assert len(inputs) == 1
+    new_attrs = {}
+    new_attrs["k"] = attrs.get_int("k", 1)
+    new_attrs["axis"] = attrs.get_int("axis", -1)
+    new_attrs["is_ascend"] = attrs.get_bool("is_ascend", True)
+    ret_type = attrs.get_str("ret_typ", "indices")
+    if ret_type == "mask":
+        raise tvm.error.OpAttributeUnimplemented(
+            "Attribute ret_type=mask is not supported in topk operator")
+    new_attrs["ret_type"] = "values" if ret_type == "value" else ret_type
+    new_attrs["dtype"] = attrs.get_str("dtype", "float32")
+    return _op.topk(inputs[0], **new_attrs)
+
+
 def _mx_rnn_param_concat(inputs, _):
     # We don't need to concatenate RNN params because we will unravel the RNN op
     return [inputs]
@@ -914,6 +929,7 @@ _convert_map = {
     "shape_array"   : _mx_shape_array,
     "Embedding"     : _mx_embedding,
     "argsort"       : _mx_argsort,
+    "topk"          : _mx_topk,
     "SoftmaxOutput" : _mx_softmax_output,
     "SoftmaxActivation" : _mx_softmax_activation,
     "LinearRegressionOutput" : _mx_linear_regression_output,
index 7fe82ea..307fb20 100644 (file)
@@ -1082,6 +1082,20 @@ def _softplus():
         return _get_relay_op('log')(add_out)
     return _impl
 
+def _topk():
+    def _impl(inputs, attr, params):
+        k = int(params.pop(inputs.pop(1).name_hint).asnumpy())
+        if k < 1:
+            raise tvm.error.OpAttributeInvalid(
+                'Attribute k must be positive in operator TopKV2')
+        if attr['sorted'] is False:
+            raise tvm.error.OpAttributeUnimplemented(
+                'Attribute sorted=False is not supported in operator TopKV2')
+        return AttrCvt(op_name='topk',
+                       ignores=['sorted'],
+                       extras={'k': k, 'is_ascend': False, 'dtype': 'int32'})(inputs, attr)
+    return _impl
+
 def _logical(name):
     def _impl(inputs, attr, params):
         return AttrCvt(op_name=name)(inputs, attr)
@@ -1271,6 +1285,7 @@ _convert_map = {
     'Sum'                               : _sum(),
     'Tanh'                              : AttrCvt('tanh'),
     'Tile'                              : _tile(),
+    'TopKV2'                            : _topk(),
     'Transpose'                         : _transpose(),
     'Unpack'                            : _unpack(),
 
index 57e7165..09746be 100644 (file)
@@ -35,11 +35,31 @@ def compute_argsort(attrs, inputs, _, target):
     """Compute definition of argsort"""
     axis = get_const_int(attrs.axis)
     is_ascend = bool(get_const_int(attrs.is_ascend))
-    dtype = str(attrs.dtype)
-    return [
-        topi.argsort(inputs[0], None, axis=axis, is_ascend=is_ascend, \
-                            dtype=dtype, flag=False)
-    ]
+    dtype = attrs.dtype
+    return [topi.argsort(inputs[0], axis=axis, is_ascend=is_ascend, dtype=dtype)]
 
 
 register_pattern("argsort", OpPattern.OPAQUE)
+
+
+@register_schedule("topk")
+def schedule_topk(_, outs, target):
+    """Schedule definition of argsort"""
+    with target:
+        return topi.generic.schedule_topk(outs)
+
+
+@register_compute("topk")
+def compute_topk(attrs, inputs, _, target):
+    """Compute definition of argsort"""
+    k = get_const_int(attrs.k)
+    axis = get_const_int(attrs.axis)
+    ret_type = attrs.ret_type
+    is_ascend = bool(get_const_int(attrs.is_ascend))
+    dtype = attrs.dtype
+    out = topi.topk(inputs[0], k, axis, ret_type, is_ascend, dtype)
+    out = out if isinstance(out, list) else [out]
+    return out
+
+
+register_pattern("topk", OpPattern.OPAQUE)
index 6451eb4..6f87591 100644 (file)
@@ -17,8 +17,9 @@
 """Classic algorithm operation"""
 from __future__ import absolute_import as _abs
 from . import _make
+from ..expr import TupleWrapper
 
-def argsort(data, axis=-1, is_ascend=1, dtype="float32"):
+def argsort(data, axis=-1, is_ascend=1, dtype="int32"):
     """Performs sorting along the given axis and returns an array of indicies
     having same shape as an input array that index data in sorted order.
 
@@ -37,7 +38,7 @@ def argsort(data, axis=-1, is_ascend=1, dtype="float32"):
         Whether to sort in ascending or descending order.
 
     dtype : string, optional
-        DType of the output indices.
+        The data type of the output indices.
 
     Returns
     -------
@@ -45,3 +46,42 @@ def argsort(data, axis=-1, is_ascend=1, dtype="float32"):
         Tensor with same shape as data.
     """
     return _make.argsort(data, axis, is_ascend, dtype)
+
+
+def topk(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int32"):
+    """Get the top k elements in an input tensor along the given axis.
+
+    ret_type specifies the return type, can be one of ("both", "values", "indices").
+
+    Parameters
+    ----------
+    data : relay.Expr
+        The input data tensor.
+
+    k : int, optional
+        Number of top elements to select. Return all elements if k < 1.
+
+    axis : int, optional
+        Axis long which to sort the input tensor.
+
+    ret_type: str, optional
+        The return type [both, values, indices].
+        "both": return both top k data and indices.
+        "values": return top k data only.
+        "indices": return top k indices only.
+
+    is_ascend : boolean, optional
+        Whether to sort in ascending or descending order.
+
+    dtype : string, optional
+        The data type of the indices output.
+
+    Returns
+    -------
+    out : relay.Expr or List[relay.Expr]
+        The computed result.
+    """
+    out = _make.topk(data, k, axis, ret_type, is_ascend, dtype)
+    if ret_type == "both":
+        return TupleWrapper(out, 2)
+    return out
index b772c43..b4ebffb 100644 (file)
@@ -401,7 +401,7 @@ def upsampling(data,
     with data of shape (n, c, h, w)
     out will have a shape (n, c, h*scale, w*scale)
 
-    method indicates the algorithm to be used while calculating ghe out value
+    method indicates the algorithm to be used while calculating the out value
     and method can be one of ("BILINEAR", "NEAREST_NEIGHBOR")
 
     Parameters
index 9c76b7e..02fd492 100644 (file)
@@ -218,9 +218,9 @@ def take(data, indices, axis=None, mode="clip"):
         the flattened input array is used.
 
     mode : str, optional
-        Specifies how out-of-bound indices will behave.
-        clip - clip to the range (default)
-        wrap - wrap around the indices
+        Specifies how out-of-bound indices will behave [clip, wrap].
+        clip: clip to the range (default).
+        wrap: wrap around the indices.
 
     Returns
     -------
index 834b4ee..0a488f3 100644 (file)
@@ -83,7 +83,7 @@ Target CreateTarget(const std::string& target_name,
     t->device_type = kDLGPU;
     t->keys_array.push_back(ir::StringImm::make("cuda"));
     t->keys_array.push_back(ir::StringImm::make("gpu"));
-    t->max_num_threads = 512;
+    t->max_num_threads = 1024;
     t->thread_warp_size = 32;
   } else if (target_name == "rocm" || target_name == "opencl") {
     // For now assume rocm schedule for opencl
index cf25e89..87691f2 100644 (file)
@@ -34,14 +34,14 @@ namespace contrib {
 using namespace runtime;
 
 template<typename DType>
-bool CompareAscend(const std::pair<int32_t, DType>& lhs,
-                   const std::pair<int32_t, DType>& rhs) {
+bool CompareAscend(const std::pair<int64_t, DType>& lhs,
+                   const std::pair<int64_t, DType>& rhs) {
   return lhs.second < rhs.second;
 }
 
 template<typename DType>
-bool CompareDescend(const std::pair<int32_t, DType>& lhs,
-                    const std::pair<int32_t, DType>& rhs) {
+bool CompareDescend(const std::pair<int64_t, DType>& lhs,
+                    const std::pair<int64_t, DType>& rhs) {
   return lhs.second > rhs.second;
 }
 
@@ -110,6 +110,41 @@ TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort_nms")
   }
 });
 
+template<typename DataType, typename OutType>
+void argsort(DLTensor* input, DLTensor* output, int32_t axis, bool is_ascend) {
+  auto data_ptr = static_cast<DataType *>(input->data);
+  auto out_ptr = static_cast<OutType *>(output->data);
+  std::vector<std::pair<int64_t, DataType> > sorter;
+
+  int axis_mul_before = 1;
+  int axis_mul_after = 1;
+  for (int i = 0; i < input->ndim; ++i) {
+    if (i < axis) {
+      axis_mul_before *= input->shape[i];
+    } else if (i > axis) {
+      axis_mul_after *= input->shape[i];
+    }
+  }
+
+  for (int i = 0 ; i < axis_mul_before; ++i) {
+    for (int j = 0 ; j < axis_mul_after; ++j) {
+      sorter.clear();
+      int64_t base_idx = i * input->shape[axis] * axis_mul_after + j;
+      for (int64_t k = 0; k < input->shape[axis]; ++k) {
+        int64_t full_idx = base_idx + k * axis_mul_after;
+        sorter.emplace_back(std::make_pair(k, data_ptr[full_idx]));
+      }
+      if (is_ascend) {
+        std::stable_sort(sorter.begin(), sorter.end(), CompareAscend<DataType>);
+      } else {
+        std::stable_sort(sorter.begin(), sorter.end(), CompareDescend<DataType>);
+      }
+      for (int64_t k = 0; k < input->shape[axis]; ++k) {
+        out_ptr[base_idx + k * axis_mul_after] = static_cast<OutType>(sorter[k].first);
+      }
+    }
+  }
+}
 
 // Argsort implemented C library sort.
 // Return indices of sorted tensor.
@@ -124,25 +159,84 @@ TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort")
   DLTensor *output = args[1];
   int32_t axis = args[2];
   bool is_ascend = args[3];
-
-  auto dtype = input->dtype;
-  auto data_ptr = static_cast<float *>(input->data);
-  std::vector<std::pair<float, float>> sorter;
-  int64_t axis_mul_before = 1;
-  int64_t axis_mul_after = 1;
-
   if (axis < 0) {
     axis = input->ndim + axis;
   }
-
-  // Currently only supports input dtype to be float32.
-  CHECK_EQ(dtype.code, 2) << "Currently only supports input dtype "
-      "to be float32.";
-  CHECK_EQ(dtype.bits, 32) << "Currently only supports input dtype "
-      "to be float32.";
   CHECK_LT(axis, input->ndim) << "Axis out of boundary for "
-      "input ndim " << input->ndim;
+                                 "input ndim " << input->ndim;
+
+  auto data_dtype = TVMType2String(input->dtype);
+  auto out_dtype = TVMType2String(output->dtype);
+
+  if (data_dtype == "float32") {
+    if (out_dtype == "int32") {
+      argsort<float, int32_t>(input, output, axis, is_ascend);
+    } else if (out_dtype == "int64") {
+      argsort<float, int64_t>(input, output, axis, is_ascend);
+    } else if (out_dtype == "float32") {
+      argsort<float, float>(input, output, axis, is_ascend);
+    } else if (out_dtype == "float64") {
+      argsort<float, double>(input, output, axis, is_ascend);
+    } else {
+      LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
+    }
+  } else if (data_dtype == "float64") {
+    if (out_dtype == "int32") {
+      argsort<double, int32_t>(input, output, axis, is_ascend);
+    } else if (out_dtype == "int64") {
+      argsort<double, int64_t>(input, output, axis, is_ascend);
+    } else if (out_dtype == "float32") {
+      argsort<double, float>(input, output, axis, is_ascend);
+    } else if (out_dtype == "float64") {
+      argsort<double, double>(input, output, axis, is_ascend);
+    } else {
+      LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
+    }
+  } else if (data_dtype == "int32") {
+    if (out_dtype == "int32") {
+      argsort<int32_t, int32_t>(input, output, axis, is_ascend);
+    } else if (out_dtype == "int64") {
+      argsort<int32_t, int64_t>(input, output, axis, is_ascend);
+    } else if (out_dtype == "float32") {
+      argsort<int32_t, float>(input, output, axis, is_ascend);
+    } else if (out_dtype == "float64") {
+      argsort<int32_t, double>(input, output, axis, is_ascend);
+    } else {
+      LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
+    }
+  }  else if (data_dtype == "int64") {
+    if (out_dtype == "int32") {
+      argsort<int64_t, int32_t>(input, output, axis, is_ascend);
+    } else if (out_dtype == "int64") {
+      argsort<int64_t, int64_t>(input, output, axis, is_ascend);
+    } else if (out_dtype == "float32") {
+      argsort<int64_t, float>(input, output, axis, is_ascend);
+    } else if (out_dtype == "float64") {
+      argsort<int64_t, double>(input, output, axis, is_ascend);
+    } else {
+      LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
+    }
+  } else {
+    LOG(FATAL) << "Unsupported input dtype: " << data_dtype;
+  }
+});
 
+template<typename DataType, typename IndicesType>
+void topk(DLTensor* input,
+          DLTensor* out_values,
+          DLTensor* out_indices,
+          int k,
+          int axis,
+          bool is_ascend) {
+  DataType* data_ptr = static_cast<DataType *>(input->data);
+  DataType* values_ptr = (out_values == nullptr) ? nullptr :
+          static_cast<DataType *>(out_values->data);
+  IndicesType* indices_ptr = (out_indices == nullptr) ? nullptr :
+          static_cast<IndicesType *>(out_indices->data);
+  std::vector<std::pair<int64_t, DataType> > sorter;
+
+  int axis_mul_before = 1;
+  int axis_mul_after = 1;
   for (int i = 0; i < input->ndim; ++i) {
     if (i < axis) {
       axis_mul_before *= input->shape[i];
@@ -150,27 +244,124 @@ TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort")
       axis_mul_after *= input->shape[i];
     }
   }
+  if (k < 1) {
+    k = input->shape[axis];
+  }
 
-  int32_t current_sort_num = input->shape[axis];
-  for (int64_t i = 0 ; i < axis_mul_before; ++i) {
-    for (int64_t j = 0 ; j < axis_mul_after; ++j) {
+  for (int i = 0 ; i < axis_mul_before; ++i) {
+    for (int j = 0 ; j < axis_mul_after; ++j) {
       sorter.clear();
-      int64_t base_idx = i * input->shape[axis] * axis_mul_after + j;
-      for (int64_t k = 0; k < current_sort_num; ++k) {
-        int64_t full_idx = base_idx + k * axis_mul_after;
-        sorter.emplace_back(std::make_pair(k, *(data_ptr + full_idx)));
+      int64_t src_base_idx = i * input->shape[axis] * axis_mul_after + j;
+      int64_t dst_base_idx = i * k * axis_mul_after + j;
+      for (int64_t kk = 0; kk < input->shape[axis]; ++kk) {
+        int64_t full_idx = src_base_idx + kk * axis_mul_after;
+        sorter.emplace_back(std::make_pair(kk, data_ptr[full_idx]));
       }
       if (is_ascend) {
-        std::stable_sort(sorter.begin(), sorter.end(), CompareAscend<float>);
+        std::stable_sort(sorter.begin(), sorter.end(), CompareAscend<DataType>);
       } else {
-        std::stable_sort(sorter.begin(), sorter.end(), CompareDescend<float>);
+        std::stable_sort(sorter.begin(), sorter.end(), CompareDescend<DataType>);
       }
-      for (int32_t k = 0; k < input->shape[axis]; ++k) {
-        *(static_cast<float *>(output->data) + base_idx + k * axis_mul_after)
-            = k < static_cast<float>(sorter.size()) ? sorter[k].first : k;
+      int64_t cnt = k > 0 ? k : input->shape[axis];
+      for (int64_t kk = 0; kk < cnt; ++kk) {
+        if (indices_ptr != nullptr) {
+          indices_ptr[dst_base_idx + kk * axis_mul_after] =
+                  static_cast<IndicesType>(sorter[kk].first);
+        }
+        if (values_ptr != nullptr) {
+          values_ptr[dst_base_idx + kk * axis_mul_after] =
+                  static_cast<DataType>(sorter[kk].second);
+        }
       }
     }
   }
+}
+
+// Argsort implemented C library sort.
+// Return indices of sorted tensor.
+// By default, the last axis will be used to sort.
+// sort_num specify the number of elements to be sorted.
+// If input tensor has dimension (d0, d1, ..., d(k-1), dk, d(k+1), ..., d(n-1))
+// and sort axis is dk. sort_num should have dimension of
+// (d1, d2, ..., d(k-1), d(k+1), ..., dn).
+TVM_REGISTER_GLOBAL("tvm.contrib.sort.topk")
+.set_body([](TVMArgs args, TVMRetValue* ret) {
+  DLTensor* input = args[0];
+  DLTensor* values_out = nullptr;
+  DLTensor* indices_out = nullptr;
+  int k = args[args.num_args - 4];
+  int axis = args[args.num_args - 3];
+  std::string ret_type = args[args.num_args - 2];
+  bool is_ascend = args[args.num_args - 1];
+  if (ret_type == "both") {
+    values_out = args[1];
+    indices_out = args[2];
+  } else if (ret_type == "values") {
+    values_out = args[1];
+  } else if (ret_type == "indices") {
+    indices_out = args[1];
+  } else {
+    LOG(FATAL) << "Unsupported ret type: " << ret_type;
+  }
+  if (axis < 0) {
+    axis = input->ndim + axis;
+  }
+  CHECK(axis >= 0 && axis < input->ndim) << "Axis out of boundary for input ndim " << input->ndim;
+
+  auto data_dtype = TVMType2String(input->dtype);
+  auto out_dtype = (indices_out == nullptr) ? "int64" : TVMType2String(indices_out->dtype);
+
+  if (data_dtype == "float32") {
+    if (out_dtype == "int32") {
+      topk<float, int32_t>(input, values_out, indices_out, k, axis, is_ascend);
+    } else if (out_dtype == "int64") {
+      topk<float, int64_t>(input, values_out, indices_out, k, axis, is_ascend);
+    } else if (out_dtype == "float32") {
+      topk<float, float>(input, values_out, indices_out, k, axis, is_ascend);
+    } else if (out_dtype == "float64") {
+      topk<float, double>(input, values_out, indices_out, k, axis, is_ascend);
+    } else {
+      LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
+    }
+  } else if (data_dtype == "float64") {
+    if (out_dtype == "int32") {
+      topk<double, int32_t>(input, values_out, indices_out, k, axis, is_ascend);
+    } else if (out_dtype == "int64") {
+      topk<double, int64_t>(input, values_out, indices_out, k, axis, is_ascend);
+    } else if (out_dtype == "float32") {
+      topk<double, float>(input, values_out, indices_out, k, axis, is_ascend);
+    } else if (out_dtype == "float64") {
+      topk<double, double>(input, values_out, indices_out, k, axis, is_ascend);
+    } else {
+      LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
+    }
+  } else if (data_dtype == "int32") {
+    if (out_dtype == "int32") {
+      topk<int32_t, int32_t>(input, values_out, indices_out, k, axis, is_ascend);
+    } else if (out_dtype == "int64") {
+      topk<int32_t, int64_t>(input, values_out, indices_out, k, axis, is_ascend);
+    } else if (out_dtype == "float32") {
+      topk<int32_t, float>(input, values_out, indices_out, k, axis, is_ascend);
+    } else if (out_dtype == "float64") {
+      topk<int32_t, double>(input, values_out, indices_out, k, axis, is_ascend);
+    } else {
+      LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
+    }
+  }  else if (data_dtype == "int64") {
+    if (out_dtype == "int32") {
+      topk<int64_t, int32_t>(input, values_out, indices_out, k, axis, is_ascend);
+    } else if (out_dtype == "int64") {
+      topk<int64_t, int64_t>(input, values_out, indices_out, k, axis, is_ascend);
+    } else if (out_dtype == "float32") {
+      topk<int64_t, float>(input, values_out, indices_out, k, axis, is_ascend);
+    } else if (out_dtype == "float64") {
+      topk<int64_t, double>(input, values_out, indices_out, k, axis, is_ascend);
+    } else {
+      LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
+    }
+  } else {
+    LOG(FATAL) << "Unsupported input dtype: " << data_dtype;
+  }
 });
 
 }  // namespace contrib
similarity index 94%
rename from src/relay/op/algorithm/sort.cc
rename to src/relay/op/algorithm/argsort.cc
index 5777b79..31aa888 100644 (file)
@@ -18,9 +18,9 @@
  */
 
 /*!
- *  Copyright (c) 2018 by Contributors
- * \file nms.cc
- * \brief Non-maximum suppression operators
+ *  Copyright (c) 2019 by Contributors
+ * \file argsort.cc
+ * \brief Argsort operators
  */
 #include <tvm/relay/op.h>
 #include <tvm/relay/attrs/algorithm.h>
@@ -44,7 +44,6 @@ bool ArgsortRel(const Array<Type>& types,
         << types[0];
     return false;
   }
-  CHECK_EQ(param->dtype, Float(32));
   reporter->Assign(types[1], TensorTypeNode::make(data->shape, param->dtype));
   return true;
 }
@@ -74,5 +73,6 @@ input array along the given axis.
 .add_argument("data", "Tensor", "Input data.")
 .set_support_level(6)
 .add_type_rel("Argsort", ArgsortRel);
+
 }  // namespace relay
 }  // namespace tvm
diff --git a/src/relay/op/algorithm/topk.cc b/src/relay/op/algorithm/topk.cc
new file mode 100644 (file)
index 0000000..c88e2c3
--- /dev/null
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *  Copyright (c) 2019 by Contributors
+ * \file topk.cc
+ * \brief TopK operators
+ */
+#include <tvm/relay/op.h>
+#include <tvm/relay/attrs/algorithm.h>
+
+namespace tvm {
+namespace relay {
+
+TVM_REGISTER_NODE_TYPE(TopKAttrs);
+
+bool TopKRel(const Array<Type>& types,
+             int num_inputs,
+             const Attrs& attrs,
+             const TypeReporter& reporter) {
+  // `types` contains: [data, result]
+  const TopKAttrs* param = attrs.as<TopKAttrs>();
+  CHECK_EQ(types.size(), 2);
+  const auto* data = types[0].as<TensorTypeNode>();
+  CHECK(data);
+  int ndim = data->shape.size();
+  int axis = param->axis;
+  if (axis < 0) {
+    axis += ndim;
+  }
+  CHECK(axis >= 0 && axis < ndim);
+  Array<IndexExpr> out_shape;
+  for (int i = 0; i < ndim; ++i) {
+    if (i != axis || param->k < 1) {
+      out_shape.push_back(data->shape[i]);
+    } else {
+      out_shape.push_back(param->k);
+    }
+  }
+  auto values_ty = TensorTypeNode::make(out_shape, data->dtype);
+  auto indices_ty = TensorTypeNode::make(out_shape, param->dtype);
+  if (param->ret_type == "both") {
+    reporter->Assign(types[1], TupleTypeNode::make({values_ty, indices_ty}));
+  } else if (param->ret_type == "values") {
+    reporter->Assign(types[1], values_ty);
+  } else if (param->ret_type == "indices") {
+    reporter->Assign(types[1], indices_ty);
+  } else {
+    LOG(FATAL) << "Unsupported ret type: " << param->ret_type;
+  }
+  return true;
+}
+
+Expr MakeTopK(Expr data,
+              int k,
+              int axis,
+              std::string ret_type,
+              bool is_ascend,
+              DataType dtype) {
+  auto attrs = make_node<TopKAttrs>();
+  attrs->k = k;
+  attrs->axis = axis;
+  attrs->ret_type = ret_type;
+  attrs->is_ascend = is_ascend;
+  attrs->dtype = dtype;
+  static const Op& op = Op::Get("topk");
+  return CallNode::make(op, {data}, Attrs(attrs), {});
+}
+
+
+TVM_REGISTER_API("relay.op._make.topk")
+.set_body_typed(MakeTopK);
+
+RELAY_REGISTER_OP("topk")
+.describe(R"doc(Get the top k elements in an input tensor along the given axis.
+)doc" TVM_ADD_FILELINE)
+.set_num_inputs(1)
+.set_attrs_type_key("relay.attrs.TopKAttrs")
+.add_argument("data", "Tensor", "Input data.")
+.set_support_level(6)
+.add_type_rel("TopK", TopKRel);
+
+}  // namespace relay
+}  // namespace tvm
+
index 50a25a9..7569257 100644 (file)
@@ -608,6 +608,45 @@ def test_forward_Crop():
     verify((5, 32, 40, 40), (5, 32, 25, 25))
     verify((5, 32, 40, 40), (5, 32, 25, 25), (5, 5))
 
+def test_forward_argsort():
+    def verify(shape, axis, is_ascend, dtype="float32"):
+        x_np = np.random.uniform(size=shape).astype("float32")
+        ref_res = mx.nd.argsort(mx.nd.array(x_np), axis=axis, is_ascend=is_ascend, dtype=dtype)
+        mx_sym = mx.sym.argsort(mx.sym.var("x"), axis=axis, is_ascend=is_ascend, dtype=dtype)
+        new_sym, _ = relay.frontend.from_mxnet(mx_sym, {"x": shape})
+        for target, ctx in ctx_list():
+            for kind in ["graph", "debug"]:
+                intrp = relay.create_executor(kind, ctx=ctx, target=target)
+                op_res = intrp.evaluate(new_sym)(x_np)
+                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())
+    verify((2, 3, 4), axis=0, is_ascend=False)
+    verify((1, 4, 6), axis=1, is_ascend=True)
+    verify((3, 5, 6), axis=-3, is_ascend=False, dtype="int32")
+
+def test_forward_topk():
+    def verify(shape, k, axis, ret_type, is_ascend=False, dtype="float32"):
+        x_np = np.random.uniform(size=shape).astype("float32")
+        ref_res = mx.nd.topk(mx.nd.array(x_np), k=k, axis=axis, ret_typ=ret_type,
+                             is_ascend=is_ascend, dtype=dtype)
+        mx_sym = mx.sym.topk(mx.sym.var("x"), k=k, axis=axis, ret_typ=ret_type,
+                             is_ascend=is_ascend, dtype=dtype)
+        new_sym, _ = relay.frontend.from_mxnet(mx_sym, {"x": shape})
+        for target, ctx in ctx_list():
+            for kind in ["graph", "debug"]:
+                intrp = relay.create_executor(kind, ctx=ctx, target=target)
+                op_res = intrp.evaluate(new_sym)(x_np)
+                if isinstance(ref_res, list):
+                    assert len(op_res) == len(ref_res)
+                    for i, t in enumerate(op_res):
+                        tvm.testing.assert_allclose(t.asnumpy(), ref_res[i].asnumpy())
+                else:
+                    tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())
+    verify((3, 4), k=1, axis=0, ret_type="both")
+    verify((3, 4), k=1, axis=-1, ret_type="indices")
+    verify((3, 5, 6), k=2, axis=2, ret_type="value")
+    verify((3, 5, 6), k=2, axis=1, ret_type="value", is_ascend=True)
+    verify((3, 5, 6), k=0, axis=2, ret_type="both", dtype="int32")
+
 
 if __name__ == '__main__':
     test_forward_mlp()
@@ -650,3 +689,5 @@ if __name__ == '__main__':
     test_forward_bilinear_resize()
     test_forward_rnn_layer()
     test_forward_Crop()
+    test_forward_argsort()
+    test_forward_topk()
index 023cdf5..eebb73c 100644 (file)
@@ -754,6 +754,24 @@ def test_forward_split():
     _test_split((3, 6, 4), -2, [1, 4, 1], 'float32')
 
 
+######################################################################
+# TopKV2
+# ------
+
+def _test_forward_top_k_v2(in_shape, k):
+    np_data = np.random.uniform(-100, 100, size=in_shape).astype("float32")
+    tf.reset_default_graph()
+    in_data = tf.placeholder("float32", in_shape, name="in_data")
+    tf.math.top_k(in_data, k, name='TopK')
+    compare_tf_with_tvm([np_data], ['in_data:0'], 'TopK:0')
+
+def test_forward_top_k_v2():
+    _test_forward_top_k_v2((3,), 1)
+    _test_forward_top_k_v2((3,), 3)
+    _test_forward_top_k_v2((3, 5, 7), 3)
+    _test_forward_top_k_v2((3, 5, 7), 3)
+
+
 #######################################################################
 # Unstack
 # -------
@@ -1704,6 +1722,7 @@ if __name__ == '__main__':
     test_forward_split()
     test_forward_unstack()
     test_forward_tile()
+    test_forward_top_k_v2()
 
     # Activations
     test_forward_sigmoid()
index 983a915..76478ba 100644 (file)
 # under the License.
 """ Support level6 operator test cases.
 """
-import math
 import numpy as np
 import tvm
 from tvm import relay
 from tvm.relay.testing import ctx_list
-import topi.testing
 
 def test_argsort():
-    def verify_argsort(shape, axis, is_ascend):
+    def verify_argsort(shape, axis, is_ascend, dtype):
         x = relay.var("x", relay.TensorType(shape, "float32"))
-        z = relay.argsort(x, axis=axis, is_ascend=is_ascend)
-        zz = relay.ir_pass.infer_type(z)
+        z = relay.argsort(x, axis=axis, is_ascend=is_ascend, dtype=dtype)
         func = relay.Function([x], z)
         x_data = np.random.uniform(size=shape).astype("float32")
         if is_ascend:
@@ -39,11 +36,58 @@ def test_argsort():
             for kind in ["graph", "debug"]:
                 intrp = relay.create_executor(kind, ctx=ctx, target=target)
                 op_res = intrp.evaluate(func)(x_data)
-                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.astype("float"), rtol=1e-5)
-    verify_argsort((2, 3, 4), axis=0, is_ascend=False)
-    verify_argsort((1, 4, 6), axis=1, is_ascend=True)
-    verify_argsort((3, 5, 6), axis=-1, is_ascend=False)
+                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.astype(dtype), rtol=1e-5)
+    for dtype in ["int32", "int64", "float32", "float64"]:
+        verify_argsort((2, 3, 4), axis=0, is_ascend=False, dtype=dtype)
+        verify_argsort((1, 4, 6), axis=1, is_ascend=True, dtype=dtype)
+        verify_argsort((3, 5, 6), axis=-1, is_ascend=False, dtype=dtype)
+
+
+def test_topk():
+    def verify_topk(k, axis, ret_type, is_ascend, dtype):
+        shape = (20, 100)
+        x = relay.var("x", relay.TensorType(shape, "float32"))
+        out = relay.topk(x, k, axis, ret_type, is_ascend, dtype)
+        if isinstance(out, relay.expr.TupleWrapper):
+            out = out.astuple()
+        func = relay.Function([x], out)
+        np_data = np.random.uniform(size=shape).astype("float32")
+        if is_ascend:
+            np_indices = np.argsort(np_data, axis=axis)
+        else:
+            np_indices = np.argsort(-np_data, axis=axis)
+        kk = k if k >= 1 else shape[axis]
+        if axis == 0:
+            np_indices = np_indices[:kk, :]
+            np_values = np.zeros(np_indices.shape).astype("float32")
+            for i in range(shape[1]):
+                np_values[:, i] = np_data[np_indices[:, i], i]
+        else:
+            np_indices = np_indices[:, :kk]
+            np_values = np.zeros(np_indices.shape).astype("float32")
+            for i in range(shape[0]):
+                np_values[i, :] = np_data[i, np_indices[i, :]]
+        np_indices = np_indices.astype(dtype)
+
+        for target, ctx in ctx_list():
+            for kind in ["graph", "debug"]:
+                intrp = relay.create_executor(kind, ctx=ctx, target=target)
+                op_res = intrp.evaluate(func)(np_data)
+                if ret_type == "both":
+                    tvm.testing.assert_allclose(op_res[0].asnumpy(), np_values)
+                    tvm.testing.assert_allclose(op_res[1].asnumpy(), np_indices)
+                elif ret_type == "values":
+                    tvm.testing.assert_allclose(op_res.asnumpy(), np_values)
+                else:
+                    tvm.testing.assert_allclose(op_res.asnumpy(), np_indices)
+    for k in [0, 1, 5]:
+        for axis in [0, -1, 1]:
+            for ret_type in ["both", "values", "indices"]:
+                for dtype in ["int64", "float32"]:
+                    verify_topk(k, axis, ret_type, False, dtype)
+                    verify_topk(k, axis, ret_type, True, dtype)
 
 
 if __name__ == "__main__":
     test_argsort()
+    test_topk()
index 526429b..403f67b 100644 (file)
@@ -21,3 +21,4 @@ from . import ssd
 from .ssd import *
 from .nms import *
 from .rcnn import *
+from .sort import *
index 925cf24..911dd84 100644 (file)
@@ -732,7 +732,7 @@ def non_max_suppression_gpu(data, valid_count, max_output_size=-1,
     score_axis = score_index
     score_shape = (batch_size, num_anchors)
     score_tensor = tvm.compute(score_shape, lambda i, j: data[i, j, score_axis], tag=tag.ELEMWISE)
-    sort_tensor = argsort(score_tensor, valid_count=valid_count, axis=1, is_ascend=False, flag=True)
+    sort_tensor = argsort(score_tensor, valid_count=valid_count, axis=1, is_ascend=False)
 
     sort_tensor_buf = api.decl_buffer(sort_tensor.shape, sort_tensor.dtype,
                                       "sort_tensor_buf", data_alignment=8)
index 678d494..1d9148f 100644 (file)
 import tvm
 
 from tvm import api
-from topi.sort import argsort
-from topi.math import identity
+from ..sort import argsort, topk
+from ..math import identity
+from ..transform import strided_slice
 from .. import generic
 from .. import tag
 
+def _schedule_sort(outs):
+    """Schedule for argsort operator.
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+        The computation graph description of argsort
+        in the format of an array of tensors.
 
-def sort_ir(data, output, axis, is_ascend):
+    Returns
+    -------
+    s: Schedule
+      The computation schedule for the op.
+    """
+    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
+    s = tvm.create_schedule([x.op for x in outs])
+    scheduled_ops = []
+    from .injective import _schedule_injective
+    def traverse(op):
+        if tag.is_injective(op.tag):
+            _schedule_injective(op, s)
+        for tensor in op.input_tensors:
+            if tensor.op.input_tensors and tensor.op not in scheduled_ops:
+                traverse(tensor.op)
+        scheduled_ops.append(op)
+    for out in outs:
+        traverse(out.op)
+    return s
+
+def sort_ir(data, values_out, axis, is_ascend, indices_out=None):
     """Low level IR to do nms sorting on the GPU, same usage as tvm.contrib.sort.argsort on the CPU.
 
     Parameters
     ----------
     data: Buffer
-        Buffer of input data.
+        Buffer of input data. Data will be sorted in place.
 
     output : Buffer
         Output buffer of indicies of sorted tensor with same shape as data.
@@ -47,14 +76,12 @@ def sort_ir(data, output, axis, is_ascend):
     stmt : Stmt
         The result IR statement.
     """
-    size = 1
     axis_mul_before = 1
     axis_mul_after = 1
     shape = data.shape
     if axis < 0:
         axis = len(shape) + axis
     for i, value in enumerate(shape, 0):
-        size *= value
         if i < axis:
             axis_mul_before *= value
         elif i > axis:
@@ -62,52 +89,62 @@ def sort_ir(data, output, axis, is_ascend):
     max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads)
     ib = tvm.ir_builder.create()
     data = ib.buffer_ptr(data)
-    output = ib.buffer_ptr(output)
+    values_out = ib.buffer_ptr(values_out)
+    if indices_out is not None:
+        indices_out = ib.buffer_ptr(indices_out)
     nthread_tx = max_threads
-    nthread_bx = size // max_threads + 1
+    nthread_bx = shape[axis] // max_threads + 1
+
     tx = tvm.thread_axis("threadIdx.x")
     bx = tvm.thread_axis("vthread")
     ib.scope_attr(tx, "thread_extent", nthread_tx)
     ib.scope_attr(bx, "virtual_thread", nthread_bx)
     tid = bx * nthread_tx + tx
-    temp_data = ib.allocate("float32", (1,), name="temp_data", scope="local")
-    temp_index = ib.allocate("float32", (1,), name="temp_index", scope="local")
-    is_ascend = tvm.make.node("IntImm", dtype="int32", value=is_ascend)
+    temp_data = ib.allocate(values_out.dtype, (1,), name="temp_data", scope="local")
+    if indices_out is not None:
+        temp_index = ib.allocate(indices_out.dtype, (1,), name="temp_index", scope="local")
 
     with ib.for_range(0, axis_mul_before) as i:
         with ib.for_range(0, axis_mul_after) as j:
-            current_sort_num = shape[axis]
             base_idx = i * shape[axis] * axis_mul_after + j
             with ib.if_scope(tid < shape[axis]):
-                output[base_idx + tid * axis_mul_after] = tid.astype("float32")
+                values_out[base_idx + tid * axis_mul_after] = data[base_idx + tid * axis_mul_after]
+                if indices_out is not None:
+                    indices_out[base_idx + tid * axis_mul_after] = \
+                        tvm.generic.cast(tid, indices_out.dtype)
+    ib.emit(tvm.make.Call(None, 'tvm_storage_sync',
+                          tvm.convert(['shared']),
+                          tvm.expr.Call.Intrinsic, None, 0))
+
+    with ib.for_range(0, axis_mul_before) as i:
+        with ib.for_range(0, axis_mul_after) as j:
+            current_sort_num = shape[axis]
+            base_idx = i * shape[axis] * axis_mul_after + j
             # OddEvenTransposeSort
             with ib.for_range(0, current_sort_num) as k:
                 with ib.if_scope(tid < (current_sort_num + 1) // 2):
                     offset = base_idx + (2 * tid + (k % 2)) * axis_mul_after
-                    with ib.if_scope(tvm.all(is_ascend == 1, \
-                                             2 * tid + (k % 2) + 1 < current_sort_num, \
-                                             data[offset] > data[offset + axis_mul_after])):
-                        temp_data[0] = data[offset]
-                        data[offset] = data[offset + axis_mul_after]
-                        data[offset + axis_mul_after] = temp_data[0]
-                        temp_index[0] = output[offset]
-                        output[offset] = output[offset + axis_mul_after]
-                        output[offset + axis_mul_after] = temp_index[0]
-                    with ib.if_scope(tvm.all(is_ascend == 0, \
-                                             2 * tid + (k % 2) + 1 < current_sort_num, \
-                                             data[offset] < data[offset + axis_mul_after])):
-                        temp_data[0] = data[offset]
-                        data[offset] = data[offset + axis_mul_after]
-                        data[offset + axis_mul_after] = temp_data[0]
-                        temp_index[0] = output[offset]
-                        output[offset] = output[offset + axis_mul_after]
-                        output[offset + axis_mul_after] = temp_index[0]
+                    if is_ascend:
+                        cond = tvm.all(2 * tid + (k % 2) + 1 < current_sort_num,
+                                       values_out[offset] > values_out[offset + axis_mul_after])
+                    else:
+                        cond = tvm.all(2 * tid + (k % 2) + 1 < current_sort_num,
+                                       values_out[offset] < values_out[offset + axis_mul_after])
+                    with ib.if_scope(cond):
+                        temp_data[0] = values_out[offset]
+                        values_out[offset] = values_out[offset + axis_mul_after]
+                        values_out[offset + axis_mul_after] = temp_data[0]
+                        if indices_out is not None:
+                            temp_index[0] = indices_out[offset]
+                            indices_out[offset] = indices_out[offset + axis_mul_after]
+                            indices_out[offset + axis_mul_after] = temp_index[0]
                 ib.emit(tvm.make.Call(None, 'tvm_storage_sync',
                                       tvm.convert(['shared']),
                                       tvm.expr.Call.Intrinsic, None, 0))
 
     return ib.get()
 
+
 def sort_nms_ir(data, valid_count, output, axis, is_ascend):
     """Low level IR to do nms sorting on the GPU, same usage as tvm.contrib.sort.argsort on the CPU.
 
@@ -197,7 +234,7 @@ def sort_nms_ir(data, valid_count, output, axis, is_ascend):
     return ib.get()
 
 @argsort.register(["cuda", "gpu"])
-def argsort_gpu(data, valid_count, axis=-1, is_ascend=1, dtype="float32", flag=0):
+def argsort_gpu(data, valid_count=None, axis=-1, is_ascend=1, dtype="float32"):
     """Performs sorting along the given axis and returns an array of indicies
     having same shape as an input array that index data in sorted order.
 
@@ -206,26 +243,27 @@ def argsort_gpu(data, valid_count, axis=-1, is_ascend=1, dtype="float32", flag=0
     data: tvm.Tensor
         The input array.
 
-    valid_count : tvm.Tensor
+    valid_count : tvm.Tensor, optional
         The number of valid elements to be sorted.
 
-    axis : int
+    axis : int, optional
         Axis long which to sort the input tensor.
 
-    is_ascend : boolean
+    is_ascend : boolean, optional
         Whether to sort in ascending or descending order.
 
-    flag : boolean
-        Whether this argsort is used in nms operator
+    dtype : string, optional
+        DType of the output indices.
 
     Returns
     -------
     out : tvm.Tensor
         The output of this function.
     """
-    sorted_data_buf = api.decl_buffer(data.shape, data.dtype, "sorted_data_buf", data_alignment=8)
-    sorted_data = identity(data)
-    if flag:
+    if valid_count is not None:
+        sorted_data = identity(data)
+        sorted_data_buf = api.decl_buffer(data.shape, data.dtype, "sorted_data_buf",
+                                          data_alignment=8)
         valid_count_buf = api.decl_buffer(valid_count.shape, valid_count.dtype,
                                           "valid_count_buf", data_alignment=4)
         out_buf = api.decl_buffer(data.shape, "int32", "out_buf", data_alignment=4)
@@ -239,16 +277,15 @@ def argsort_gpu(data, valid_count, axis=-1, is_ascend=1, dtype="float32", flag=0
                          name="argsort_nms_gpu",
                          tag="argsort_nms_gpu")
     else:
-        out_buf = api.decl_buffer(data.shape, dtype, "out_buf", data_alignment=8)
-        out = tvm.extern([data.shape],
-                         [sorted_data],
+        value_buf = api.decl_buffer(data.shape, data.dtype, "value_buf", data_alignment=8)
+        indices_buf = api.decl_buffer(data.shape, dtype, "out_buf", data_alignment=8)
+        out = tvm.extern([data.shape, data.shape],
+                         [data],
                          lambda ins, outs: sort_ir(
-                             ins[0], outs[0], axis, is_ascend),
-                         dtype=dtype,
-                         in_buffers=[sorted_data_buf],
-                         out_buffers=[out_buf],
+                             ins[0], outs[0], axis, is_ascend, indices_out=outs[1]),
+                         out_buffers=[value_buf, indices_buf],
                          name="argsort_gpu",
-                         tag="argsort_gpu")
+                         tag="argsort_gpu")[1]
     return out
 
 @generic.schedule_argsort.register(["cuda", "gpu"])
@@ -266,17 +303,99 @@ def schedule_argsort(outs):
     s: Schedule
       The computation schedule for the op.
     """
-    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
-    s = tvm.create_schedule([x.op for x in outs])
-    scheduled_ops = []
-    from .injective import _schedule_injective
-    def traverse(op):
-        if tag.is_broadcast(op.tag):
-            _schedule_injective(op, s)
-        for tensor in op.input_tensors:
-            if tensor.op.input_tensors and tensor.op not in scheduled_ops:
-                traverse(tensor.op)
-        scheduled_ops.append(op)
-    traverse(outs[0].op)
+    return _schedule_sort(outs)
 
-    return s
+@topk.register(["cuda", "gpu"])
+def topk_gpu(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int64"):
+    """Get the top k elements in an input tensor along the given axis.
+
+    Parameters
+    ----------
+    data : tvm.Tensor
+        The input tensor.
+
+    k : int, optional
+        Number of top elements to select. Return all elements if k < 1.
+
+    axis : int, optional
+        Axis long which to sort the input tensor.
+
+    ret_type: str, optional
+        The return type [both, values, indices].
+        "both": return both top k data and indices.
+        "values": return top k data only.
+        "indices": return top k indices only.
+
+    is_ascend : boolean, optional
+        Whether to sort in ascending or descending order.
+
+    dtype : string, optional
+        The data type of the indices output.
+
+    Returns
+    -------
+    out : tvm.Tensor or List[tvm.Tensor]
+        The computed result.
+    """
+    assert ret_type in ["both", "values", "indices"]
+    ndim = len(data.shape)
+    axis = axis + ndim if axis < 0 else axis
+    assert 0 <= axis < ndim
+    values_buf = api.decl_buffer(data.shape, data.dtype, "values_buf", data_alignment=8)
+    indices_buf = api.decl_buffer(data.shape, dtype, "indices_buf", data_alignment=8)
+    if ret_type == "values":
+        output = tvm.extern([data.shape],
+                            [data],
+                            lambda ins, outs: sort_ir(
+                                ins[0], outs[0], axis, is_ascend),
+                            out_buffers=[values_buf],
+                            name="topk_gpu",
+                            tag="topk_gpu")
+    else:
+        output = tvm.extern([data.shape, data.shape],
+                            [data],
+                            lambda ins, outs: sort_ir(
+                                ins[0], outs[0], axis, is_ascend, indices_out=outs[1]),
+                            out_buffers=[values_buf, indices_buf],
+                            name="topk_gpu",
+                            tag="topk_gpu")
+    if k < 1:
+        if ret_type == "indices":
+            return output[1]
+        return output
+    beg = [0] * ndim
+    end = []
+    for i in range(ndim):
+        if i == axis:
+            end.append(k)
+        else:
+            end.append(data.shape[i])
+    if ret_type == "both":
+        values_out, indices_out = output
+        values_out = strided_slice(values_out, beg, end)
+        indices_out = strided_slice(indices_out, beg, end)
+        output = [values_out, indices_out]
+    elif ret_type == "values":
+        output = [strided_slice(output, beg, end)]
+    else: # ret_type == "indices"
+        indices_out = output[1]
+        output = [strided_slice(indices_out, beg, end)]
+    return output
+
+
+@generic.schedule_topk.register(["cuda", "gpu"])
+def schedule_topk(outs):
+    """Schedule for argsort operator.
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+        The computation graph description of argsort
+        in the format of an array of tensors.
+
+    Returns
+    -------
+    s: Schedule
+      The computation schedule for the op.
+    """
+    return _schedule_sort(outs)
index 1ad088c..5462f2c 100644 (file)
@@ -36,3 +36,20 @@ def schedule_argsort(outs):
       The computation schedule for the op.
     """
     return _default_schedule(outs, False)
+
+@tvm.target.generic_func
+def schedule_topk(outs):
+    """Schedule for topk operator.
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+      The indices that would sort an input array along
+      the given axis.
+
+    Returns
+    -------
+    s: Schedule
+      The computation schedule for the op.
+    """
+    return _default_schedule(outs, False)
index 84fff8d..22899c4 100644 (file)
 """Argsort operator"""
 import tvm
 from tvm import api
+from .util import get_const_tuple
 
 @tvm.target.generic_func
-def argsort(data, valid_count, axis=-1, is_ascend=1, dtype="float32", flag=0):
+def argsort(data, valid_count=None, axis=-1, is_ascend=1, dtype="float32"):
     """Performs sorting along the given axis and returns an array
     of indices having the same shape as an input array that index
     data in sorted order.
@@ -30,22 +31,19 @@ def argsort(data, valid_count, axis=-1, is_ascend=1, dtype="float32", flag=0):
     data : tvm.Tensor
         The input tensor.
 
-    valid_count : tvm.Tensor
+    valid_count : tvm.Tensor, optional
         1-D tensor for valid number of boxes only for ssd.
 
-    axis : optional, int
-       Axis along which to sort the input tensor.
+    axis : int, optional
+           Axis along which to sort the input tensor.
         By default the flattened array is used.
 
-    is_ascend : optional, boolean
+    is_ascend : boolean, optional
         Whether to sort in ascending or descending order.
 
-    dtype : optional, string
+    dtype : string, optional
         DType of the output indices.
 
-    flag : optional, boolean
-        Whether valid_count is valid.
-
     Returns
     -------
     out : tvm.Tensor
@@ -58,23 +56,19 @@ def argsort(data, valid_count, axis=-1, is_ascend=1, dtype="float32", flag=0):
         # An example to use argsort
         dshape = (1, 5, 6)
         data = tvm.placeholder(dshape, name="data")
-        valid_count = tvm.placeholder((dshape[0],), dtype="int32", name="valid_count")
         axis = 0
         is_ascend = False
-        flag = False
-        out = argsort(data, valid_count, axis, is_ascend, flag)
+        out = argsort(data, axis=axis, is_ascend=is_ascend)
         np_data = np.random.uniform(dshape)
-        np_valid_count = np.array([4])
         s = topi.generic.schedule_argsort(out)
-        f = tvm.build(s, [data, valid_count, out], "llvm")
+        f = tvm.build(s, [data, out], "llvm")
         ctx = tvm.cpu()
         tvm_data = tvm.nd.array(np_data, ctx)
-        tvm_valid_count = tvm.nd.array(np_valid_count, ctx)
         tvm_out = tvm.nd.array(np.zeros(dshape, dtype=data.dtype), ctx)
-        f(tvm_data, tvm_valid_count, tvm_out)
+        f(tvm_data, tvm_out)
     """
     data_buf = api.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8)
-    if flag:
+    if valid_count is not None:
         valid_count_buf = api.decl_buffer(valid_count.shape, valid_count.dtype,
                                           "valid_count_buf", data_alignment=4)
         out_buf = api.decl_buffer(data.shape, "int32", "out_buf", data_alignment=8)
@@ -103,3 +97,58 @@ def argsort(data, valid_count, axis=-1, is_ascend=1, dtype="float32", flag=0):
                        name="argsort_cpu",
                        tag="argsort_cpu")
     return out
+
+
+@tvm.target.generic_func
+def topk(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int64"):
+    """Get the top k elements in an input tensor along the given axis.
+
+    Parameters
+    ----------
+    data : tvm.Tensor
+        The input tensor.
+
+    k : int, optional
+        Number of top elements to select. Return all elements if k < 1.
+
+    axis : int, optional
+        Axis long which to sort the input tensor.
+
+    ret_type: str, optional
+        The return type [both, values, indices].
+        "both": return both top k data and indices.
+        "values": return top k data only.
+        "indices": return top k indices only.
+
+    is_ascend : boolean, optional
+        Whether to sort in ascending or descending order.
+
+    dtype : string, optional
+        The data type of the indices output.
+
+    Returns
+    -------
+    out : tvm.Tensor or List[tvm.Tensor]
+        The computed result.
+    """
+    assert ret_type in ["both", "values", "indices"]
+    data_buf = api.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8)
+    out_shape = list(get_const_tuple(data.shape))
+    if k >= 1:
+        out_shape[axis] = k
+    out_bufs = []
+    if ret_type in ["both", "values"]:
+        out_bufs.append(api.decl_buffer(out_shape, data.dtype, "value_buf", data_alignment=8))
+    if ret_type in ["both", "indices"]:
+        out_bufs.append(api.decl_buffer(out_shape, dtype, "indices_buf", data_alignment=8))
+    out_shapes = [out_shape] * len(out_bufs)
+
+    out = tvm.extern(out_shapes,
+                     [data],
+                     lambda ins, outs: tvm.call_packed(
+                         "tvm.contrib.sort.topk", ins[0], *outs, k, axis, ret_type, is_ascend),
+                     in_buffers=[data_buf],
+                     out_buffers=out_bufs,
+                     name="topk_cpu",
+                     tag="topk_cpu")
+    return out
index 2ad1f6e..04af151 100644 (file)
@@ -151,6 +151,8 @@ def strided_slice(a, begin, end, strides=None):
     -------
     ret : tvm.Tensor
     """
+    if strides is None:
+        strides = []
     return cpp.strided_slice(a, begin, end, strides)
 
 
index 979565d..7c8d7db 100644 (file)
@@ -331,7 +331,7 @@ def non_max_suppression(data, valid_count, max_output_size=-1,
     score_axis = score_index
     score_shape = (batch_size, num_anchors)
     score_tensor = tvm.compute(score_shape, lambda i, j: data[i, j, score_axis])
-    sort_tensor = argsort(score_tensor, valid_count=valid_count, axis=1, is_ascend=False, flag=True)
+    sort_tensor = argsort(score_tensor, valid_count=valid_count, axis=1, is_ascend=False)
     out, box_indices = hybrid_nms(data, sort_tensor, valid_count,
                                   tvm.const(max_output_size, dtype="int32"),
                                   tvm.const(iou_threshold, dtype="float32"),
index 3a2c9c2..ed902b9 100644 (file)
 # under the License.
 """Test code for vision package"""
 from __future__ import print_function
-import math
 import numpy as np
 import tvm
 import topi
 import topi.testing
 
-from tvm.contrib.pickle_memoize import memoize
-from topi.util import get_const_tuple
-from topi import argsort
-
 def test_argsort():
-    dshape = (1, 8)
-    valid_count_shape = (2,)
+    dshape = (20, 100)
     data = tvm.placeholder(dshape, name="data", dtype="float32")
-    valid_count = tvm.placeholder((dshape[0],), dtype="int32", name="valid_count")
     np_data = np.random.rand(dshape[0], dshape[1]).astype(data.dtype)
-    np_valid_count = np.array([4]).astype(valid_count.dtype)
     np_result = np.argsort(-np_data)
     def check_device(device):
         ctx = tvm.context(device, 0)
@@ -41,19 +33,77 @@ def test_argsort():
             return
         print("Running on target: %s" % device)
         with tvm.target.create(device):
-            out = argsort(data, valid_count, axis = -1, is_ascend = False, flag=False)
+            out = topi.argsort(data, axis=-1, is_ascend=False)
             s = topi.generic.schedule_argsort(out)
 
         tvm_data = tvm.nd.array(np_data, ctx)
-        tvm_valid_count = tvm.nd.array(np_valid_count, ctx)
         tvm_out = tvm.nd.array(np.zeros(dshape, dtype="float32"), ctx)
-        f = tvm.build(s, [data, valid_count, out], device)
-        f(tvm_data, tvm_valid_count, tvm_out)
+        f = tvm.build(s, [data, out], device)
+        f(tvm_data, tvm_out)
         tvm.testing.assert_allclose(tvm_out.asnumpy(), np_result.astype("float32"), rtol=1e0)
 
     for device in ['llvm', 'cuda', 'opencl']:
         check_device(device)
 
+def verify_topk(k, axis, ret_type, is_ascend, dtype):
+    shape = (20, 100)
+    data_dtype = "float32"
+    data = tvm.placeholder(shape, name="data", dtype=data_dtype)
+
+    np_data = np.random.uniform(size=shape).astype(data_dtype)
+    if is_ascend:
+        np_indices = np.argsort(np_data, axis=axis)
+    else:
+        np_indices = np.argsort(-np_data, axis=axis)
+    kk = k if k >= 1 else shape[axis]
+    if axis == 0:
+        np_indices = np_indices[:kk, :]
+        np_values = np.zeros(np_indices.shape).astype(data_dtype)
+        for i in range(shape[1]):
+            np_values[:, i] = np_data[np_indices[:, i], i]
+    else:
+        np_indices = np_indices[:, :kk]
+        np_values = np.zeros(np_indices.shape).astype(data_dtype)
+        for i in range(shape[0]):
+            np_values[i, :] = np_data[i, np_indices[i, :]]
+    np_indices = np_indices.astype(dtype)
+
+    def check_device(device):
+        ctx = tvm.context(device, 0)
+        if not ctx.exist:
+            print("Skip because %s is not enabled" % device)
+            return
+        print("Running on target: %s" % device)
+        with tvm.target.create(device):
+            outs = topi.topk(data, k, axis, ret_type, is_ascend, dtype)
+            outs = outs if isinstance(outs, list) else [outs]
+            s = topi.generic.schedule_topk(outs)
+        tvm_data = tvm.nd.array(np_data, ctx)
+        tvm_res = []
+        for t in outs:
+            tvm_res.append(tvm.nd.empty(t.shape, dtype=t.dtype, ctx=ctx))
+        f = tvm.build(s, [data] + outs, device)
+        f(tvm_data, *tvm_res)
+        if ret_type == "both":
+            tvm.testing.assert_allclose(tvm_res[0].asnumpy(), np_values)
+            tvm.testing.assert_allclose(tvm_res[1].asnumpy(), np_indices)
+        elif ret_type == "values":
+            tvm.testing.assert_allclose(tvm_res[0].asnumpy(), np_values)
+        else:
+            tvm.testing.assert_allclose(tvm_res[0].asnumpy(), np_indices)
+
+    for device in ['llvm', 'cuda', 'opencl']:
+        check_device(device)
+
+def test_topk():
+    for k in [0, 1, 5]:
+        for axis in [0, -1, 1]:
+            for ret_type in ["both", "values", "indices"]:
+                for dtype in ["int64", "float32"]:
+                    verify_topk(k, axis, ret_type, True, dtype)
+                    verify_topk(k, axis, ret_type, False, dtype)
+
 
 if __name__ == "__main__":
     test_argsort()
+    test_topk()