tvm.relay.vision.yolo_reorg
+**Level 6: Algorithm Operators**
+
+.. autosummary::
+ :nosignatures:
+
+ tvm.relay.argsort
+
+
**Level 10: Temporary Operators**
This level supports backpropagation of broadcast operators. It is temporary.
.. autofunction:: tvm.relay.vision.yolo_reorg
+Level 6 Definitions
+-------------------
+.. autofunction:: tvm.relay.argsort
+
+
Level 10 Definitions
--------------------
.. autofunction:: tvm.relay.broadcast_to_like
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/relay/attrs/algorithm.h
+ * \brief Auxiliary attributes for algorithm operators.
+ */
+#ifndef TVM_RELAY_ATTRS_ALGORITHM_H_
+#define TVM_RELAY_ATTRS_ALGORITHM_H_
+
+#include <tvm/attrs.h>
+#include <string>
+
+namespace tvm {
+namespace relay {
+
+/*! \brief Attributes used in argsort operators */
+struct ArgsortAttrs : public tvm::AttrsNode<ArgsortAttrs> {
+ int axis;
+ bool is_ascend;
+ DataType dtype;
+
+ TVM_DECLARE_ATTRS(ArgsortAttrs, "relay.attrs.ArgsortAttrs") {
+ TVM_ATTR_FIELD(axis).set_default(-1)
+ .describe("Axis along which to sort the input tensor."
+ "If not given, the flattened array is used.");
+ TVM_ATTR_FIELD(is_ascend).set_default(true)
+ .describe("Whether to sort in ascending or descending order."
+ "By default, sort in ascending order");
+ TVM_ATTR_FIELD(dtype).set_default(NullValue<DataType>())
+ .describe("DType of the output indices.");
+ }
+};
+
+} // namespace relay
+} // namespace tvm
+#endif // TVM_RELAY_ATTRS_ALGORITHM_H_
double iou_threshold;
bool force_suppress;
int top_k;
+ int coord_start;
+ int score_index;
int id_index;
bool return_indices;
bool invalid_to_bottom;
.describe("Suppress all detections regardless of class_id.");
TVM_ATTR_FIELD(top_k).set_default(-1)
.describe("Keep maximum top k detections before nms, -1 for no limit.");
+ TVM_ATTR_FIELD(coord_start).set_default(2)
+ .describe("Start index of the consecutive 4 coordinates.");
+ TVM_ATTR_FIELD(score_index).set_default(1)
+ .describe("Index of the scores/confidence of boxes.");
TVM_ATTR_FIELD(id_index).set_default(0)
.describe("Axis index of id.");
TVM_ATTR_FIELD(return_indices).set_default(true)
bool force_suppress;
int top_k;
int id_index;
+ int coord_start;
+ int score_index;
int max_output_size;
bool invalid_to_bottom;
DMLC_DECLARE_PARAMETER(NonMaximumSuppressionParam) {
.describe("Suppress all detections regardless of class_id.");
DMLC_DECLARE_FIELD(top_k).set_default(-1)
.describe("Keep maximum top k detections before nms, -1 for no limit.");
+ DMLC_DECLARE_FIELD(coord_start).set_default(2)
+ .describe("Start index of the consecutive 4 coordinates.");
+ DMLC_DECLARE_FIELD(score_index).set_default(1)
+ .describe("Index of the scores/confidence of boxes.");
DMLC_DECLARE_FIELD(id_index).set_default(0)
.describe("Axis index of id.");
DMLC_DECLARE_FIELD(return_indices).set_default(true)
id_index = attrs.get_int('id_index')
invalid_to_bottom = attrs.get_bool('invalid_to_bottom')
- return topi.vision.non_max_suppression(inputs[0], inputs[1], max_output_size,
- iou_threshold, force_suppress, top_k,
- id_index, return_indices, invalid_to_bottom)
+ return topi.vision.non_max_suppression(inputs[0], inputs[1],
+ max_output_size=max_output_size,
+ iou_threshold=iou_threshold,
+ force_suppress=force_suppress,
+ top_k=top_k, id_index=id_index,
+ return_indices=return_indices,
+ invalid_to_bottom=invalid_to_bottom)
reg.register_pattern("non_max_suppression", OpPattern.OPAQUE)
if clip:
np_out = np.clip(np_out, 0, 1)
- target = "llvm"
- ctx = tvm.cpu()
- graph, lib, _ = nnvm.compiler.build(out, target, {"data": dshape})
- m = graph_runtime.create(graph, lib, ctx)
- m.set_input("data", np.random.uniform(size=dshape).astype(dtype))
- m.run()
- out = m.get_output(0, tvm.nd.empty(np_out.shape, dtype))
- tvm.testing.assert_allclose(out.asnumpy(), np_out, atol=1e-5, rtol=1e-5)
+ for target, ctx in ctx_list():
+ graph, lib, _ = nnvm.compiler.build(out, target, {"data": dshape})
+ m = graph_runtime.create(graph, lib, ctx)
+ m.set_input("data", np.random.uniform(size=dshape).astype(dtype))
+ m.run()
+ tvm_out = m.get_output(0, tvm.nd.empty(np_out.shape, dtype))
+ tvm.testing.assert_allclose(tvm_out.asnumpy(), np_out, atol=1e-5, rtol=1e-5)
def test_multibox_prior():
verify_multibox_prior((1, 3, 50, 50))
[0, 0.44999999, 1, 1, 1, 1],
[0, 0.30000001, 0, 0, 0.22903419, 0.20435292]]])
- target = "llvm"
dtype = "float32"
- ctx = tvm.cpu()
- graph, lib, _ = nnvm.compiler.build(out, target, {"cls_prob": (batch_size, num_anchors, num_classes),
- "loc_preds": (batch_size, num_anchors * 4),
- "anchors": (1, num_anchors, 4)})
- m = graph_runtime.create(graph, lib, ctx)
- m.set_input(**{"cls_prob": np_cls_prob.astype(dtype), "loc_preds": np_loc_preds.astype(dtype), "anchors": np_anchors.astype(dtype)})
- m.run()
- out = m.get_output(0, tvm.nd.empty(expected_np_out.shape, dtype))
- tvm.testing.assert_allclose(out.asnumpy(), expected_np_out, atol=1e-5, rtol=1e-5)
+ for target, ctx in ctx_list():
+ graph, lib, _ = nnvm.compiler.build(out, target, {"cls_prob": (batch_size, num_anchors, num_classes),
+ "loc_preds": (batch_size, num_anchors * 4),
+ "anchors": (1, num_anchors, 4)})
+ m = graph_runtime.create(graph, lib, ctx)
+ m.set_input(**{"cls_prob": np_cls_prob.astype(dtype), "loc_preds": np_loc_preds.astype(dtype), "anchors": np_anchors.astype(dtype)})
+ m.run()
+ tvm_out = m.get_output(0, tvm.nd.empty(expected_np_out.shape, dtype))
+ tvm.testing.assert_allclose(tvm_out.asnumpy(), expected_np_out, atol=1e-5, rtol=1e-5)
def test_non_max_suppression():
dshape = (1, 5, 6)
[-1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1]]])
- target = "llvm"
- ctx = tvm.cpu()
- graph, lib, _ = nnvm.compiler.build(out, target, {"data": dshape, "valid_count": (dshape[0],)},
- dtype={"data": "float32", "valid_count": "int32"})
- m = graph_runtime.create(graph, lib, ctx)
- m.set_input(**{"data": np_data, "valid_count": np_valid_count})
- m.run()
- out = m.get_output(0, tvm.nd.empty(np_result.shape, "float32"))
- tvm.testing.assert_allclose(out.asnumpy(), np_result, atol=1e-5, rtol=1e-5)
+ for target, ctx in ctx_list():
+ graph, lib, _ = nnvm.compiler.build(out, target, {"data": dshape, "valid_count": (dshape[0],)},
+ dtype={"data": "float32", "valid_count": "int32"})
+ m = graph_runtime.create(graph, lib, ctx)
+ m.set_input(**{"data": np_data, "valid_count": np_valid_count})
+ m.run()
+ tvm_out = m.get_output(0, tvm.nd.empty(np_result.shape, "float32"))
+ tvm.testing.assert_allclose(tvm_out.asnumpy(), np_result, atol=1e-5, rtol=1e-5)
def np_slice_like(np_data, np_shape_like, axis=[]):
begin_idx = [0 for _ in np_data.shape]
from .op.reduce import *
from .op.tensor import *
from .op.transform import *
+from .op.algorithm import *
from . import nn
from . import annotation
from . import vision
'Operator {} Pooling is not supported for frontend MXNet.'.format(pool_type.capitalize()))
+def _mx_adaptive_avg_pooling(inputs, attrs):
+ output_size = attrs.get_int_tuple("output_size", [])
+ if output_size != (1,):
+ raise RuntimeError("AdaptiveAvgPooling with output_size other than 1 is not supported yet.")
+ return _op.nn.global_avg_pool2d(inputs[0])
+
+
def _mx_dropout(inputs, attrs):
rate = attrs.get_float("p", 0.5)
return _op.nn.dropout(inputs[0], rate=rate)
id_index = attrs.get_int('id_index', -1)
in_format = attrs.get_str('in_format', 'corner')
out_format = attrs.get_str('out_format', 'corner')
- if coord_start != 2:
- raise tvm.error.OpAttributeInvalid(
- 'Value of attribute "coord_start" must equal 2 for operator box_nms.')
- if score_index != 1:
- raise tvm.error.OpAttributeInvalid(
- 'Value of attribute "score_index" must equal 1 for operator box_nms.')
- if id_index != -1 and int(id_index) != 0:
- raise tvm.error.OpAttributeInvalid(
- 'Value of attribute "id_index" must equal either -1 or 0 for operator box_nms.')
if in_format != 'corner':
raise tvm.error.OpAttributeInvalid(
'Value of attribute "in_format" must equal "corner" for operator box_nms.')
iou_threshold=iou_thresh,
force_suppress=force_suppress,
top_k=top_k,
+ coord_start=coord_start,
+ score_index=score_index,
id_index=id_index,
return_indices=False,
invalid_to_bottom=True)
return res
+def _mx_argsort(inputs, attrs):
+ assert len(inputs) == 1
+ new_attrs = {}
+ new_attrs["axis"] = attrs.get_int("axis", -1)
+ new_attrs["is_ascend"] = attrs.get_bool("is_ascend", True)
+ new_attrs["dtype"] = attrs.get_str("dtype", "float32")
+ return _op.argsort(inputs[0], **new_attrs)
+
+
# Note: due to attribute conversion constraint
# ops in the identity set must be attribute free
_identity_list = [
"BlockGrad" : _mx_BlockGrad,
"shape_array" : _mx_shape_array,
"Embedding" : _mx_embedding,
+ "argsort" : _mx_argsort,
"SoftmaxOutput" : _mx_softmax_output,
"SoftmaxActivation" : _mx_softmax_activation,
"smooth_l1" : _mx_smooth_l1,
"_contrib_MultiProposal" : _mx_proposal,
"_contrib_box_nms" : _mx_box_nms,
"_contrib_DeformableConvolution" : _mx_deformable_convolution,
+ "_contrib_AdaptiveAvgPooling2D" : _mx_adaptive_avg_pooling,
# List of missing operators that are present in NNVMv1
# TODO(tvm-tvm): support all operators.
#
from .reduce import *
from .tensor import *
from .transform import *
+from .algorithm import *
from . import nn
from . import annotation
from . import image
from . import _tensor_grad
from . import _transform
from . import _reduce
+from . import _algorithm
from ..expr import Expr
from ..base import register_relay_node
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"Definition of classic algorithms"
+# pylint: disable=invalid-name,unused-argument
+from __future__ import absolute_import
+
+import topi
+from topi.util import get_const_int
+from ..op import OpPattern, register_compute, register_schedule, register_pattern
+
+
+@register_schedule("argsort")
+def schedule_argsort(_, outs, target):
+ """Schedule definition of argsort"""
+ with target:
+ return topi.generic.schedule_argsort(outs)
+
+
+@register_compute("argsort")
+def compute_argsort(attrs, inputs, _, target):
+ """Compute definition of argsort"""
+ axis = get_const_int(attrs.axis)
+ is_ascend = bool(get_const_int(attrs.is_ascend))
+ dtype = str(attrs.dtype)
+ return [
+ topi.argsort(inputs[0], None, axis=axis, is_ascend=is_ascend, \
+ dtype=dtype, flag=False)
+ ]
+
+
+register_pattern("argsort", OpPattern.OPAQUE)
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Classic algorithm operation"""
+from __future__ import absolute_import as _abs
+from . import _make
+
+def argsort(data, axis=-1, is_ascend=1, dtype="float32"):
+ """Performs sorting along the given axis and returns an array of indicies
+ having same shape as an input array that index data in sorted order.
+
+ Parameters
+ ----------
+ data : relay.Expr
+ The input data tensor.
+
+ axis : int, optional
+ Axis along which to sort the input tensor.
+
+ is_ascend : boolean, optional
+ Whether to sort in ascending or descending order.
+
+ dtype : string, optional
+ DType of the output indices.
+
+ Returns
+ -------
+ out : relay.Expr
+ Tensor of indices with the same shape as data.
+ """
+ return _make.argsort(data, axis, is_ascend, dtype)
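+
+# A minimal usage sketch (illustrative only, not part of the module; assumes the
+# public alias relay.argsort exported through the op namespace above):
+#
+#   x = relay.var("x", relay.TensorType((3, 4), "float32"))
+#   idx = relay.argsort(x, axis=1, is_ascend=False)  # indices of a descending sort
+#
+# The resulting expression has the same (3, 4) shape as `x` and holds indices.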
return _make.concatenate(Tuple(data), axis)
+def stack(data, axis):
+ """Join a sequence of arrays along a new axis.
+
+ Parameters
+ ----------
+ data : Union(List[relay.Expr], Tuple(relay.Expr))
+ A list of tensors.
+
+ axis : int
+ The axis in the result array along which the input arrays are stacked.
+
+ Returns
+ -------
+ ret : relay.Expr
+ The stacked tensor.
+ """
+ data = list(data)
+ if not data:
+ raise ValueError("relay.stack requires data to be non-empty.")
+ if not isinstance(axis, int):
+ raise ValueError("For now, we only support integer axis")
+ return _make.stack(Tuple(data), axis)
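+
+# A minimal usage sketch (illustrative only; `t1` and `t2` are hypothetical
+# tensors of identical shape):
+#
+#   s = relay.stack([t1, t2], axis=0)  # result gains a new leading axis of length 2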
+
+
def copy(data):
"""Copy a tensor.
return _make.arange(start, stop, step, dtype)
-def stack(data, axis):
- """Join a sequence of arrays along a new axis.
-
- Parameters
- ----------
- data : relay.Expr
- The input data to the operator.
-
- axis : int
- The axis in the result array along which the input arrays are stacked.
-
- .. note::
- Each array in the input array sequence must have the same shape.
-
- Returns
- -------
- ret : relay.Expr
- The computed result.
- """
- return _make.stack(data, axis)
-
-
def repeat(data, repeats, axis):
"""Repeats elements of an array.
By default, repeat flattens the input array into 1-D and then repeats the elements.
indices = [[0, 1], [1, 0]]
relay.gather_nd(data, indices) = [[3, 4], [5, 6]]
"""
-
return _make.gather_nd(data, indices)
iou_threshold = get_const_float(attrs.iou_threshold)
force_suppress = bool(get_const_int(attrs.force_suppress))
top_k = get_const_int(attrs.top_k)
+ coord_start = get_const_int(attrs.coord_start)
+ score_index = get_const_int(attrs.score_index)
id_index = get_const_int(attrs.id_index)
invalid_to_bottom = bool(get_const_int(attrs.invalid_to_bottom))
return [
topi.vision.non_max_suppression(inputs[0], inputs[1], max_output_size,
iou_threshold, force_suppress, top_k,
- id_index, return_indices, invalid_to_bottom)
+ coord_start, score_index, id_index,
+ return_indices, invalid_to_bottom)
]
iou_threshold=0.5,
force_suppress=False,
top_k=-1,
+ coord_start=2,
+ score_index=1,
id_index=0,
return_indices=True,
invalid_to_bottom=False):
top_k : int, optional
Keep maximum top k detections before nms, -1 for no limit.
+ coord_start : int, optional
+ The starting index of the consecutive 4 coordinates.
+
+ score_index : int, optional
+ Index of the scores/confidence of boxes.
+
id_index : int, optional
index of the class categories, -1 to disable.
"""
return _make.non_max_suppression(data, valid_count, max_output_size,
iou_threshold, force_suppress, top_k,
- id_index, return_indices, invalid_to_bottom)
+ coord_start, score_index, id_index,
+ return_indices, invalid_to_bottom)
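+
+# A minimal usage sketch (illustrative only; assumes `boxes` is a (1, N, 6) tensor in
+# [class_id, score, x1, y1, x2, y2] layout and `count` its valid-count tensor):
+#
+#   out = relay.vision.non_max_suppression(boxes, count, max_output_size=-1,
+#                                          iou_threshold=0.5, coord_start=2,
+#                                          score_index=1, return_indices=False)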
}
-// Argsort implemented C library sort.
+// Argsort implemented C library sort for nms.
// Return indices of sorted tensor.
// By default, the last axis will be used to sort.
// sort_num specify the number of elements to be sorted.
// If input tensor has dimension (d0, d1, ..., d(k-1), dk, d(k+1), ..., d(n-1))
// and sort axis is dk. sort_num should have dimension of
// (d1, d2, ..., d(k-1), d(k+1), ..., dn).
-TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort")
+TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort_nms")
.set_body([](TVMArgs args, TVMRetValue *ret) {
DLTensor *input = args[0];
DLTensor *sort_num = args[1];
DLTensor *output = args[2];
int32_t axis = args[3];
- bool is_descend = args[4];
+ bool is_ascend = args[4];
auto dtype = input->dtype;
auto data_ptr = static_cast<float *>(input->data);
int64_t full_idx = base_idx + k * axis_mul_after;
sorter.emplace_back(std::make_pair(k, *(data_ptr + full_idx)));
}
- if (is_descend) {
- std::stable_sort(sorter.begin(), sorter.end(), CompareDescend<float>);
- } else {
+ if (is_ascend) {
std::stable_sort(sorter.begin(), sorter.end(), CompareAscend<float>);
+ } else {
+ std::stable_sort(sorter.begin(), sorter.end(), CompareDescend<float>);
}
for (int32_t k = 0; k < input->shape[axis]; ++k) {
*(static_cast<int32_t *>(output->data) + base_idx + k * axis_mul_after)
}
});
+
+// Argsort implemented C library sort.
+// Return indices of sorted tensor.
+// By default, the last axis will be used to sort.
+// sort_num specify the number of elements to be sorted.
+// If input tensor has dimension (d0, d1, ..., d(k-1), dk, d(k+1), ..., d(n-1))
+// and sort axis is dk. sort_num should have dimension of
+// (d1, d2, ..., d(k-1), d(k+1), ..., dn).
+TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort")
+.set_body([](TVMArgs args, TVMRetValue *ret) {
+ DLTensor *input = args[0];
+ DLTensor *output = args[1];
+ int32_t axis = args[2];
+ bool is_ascend = args[3];
+
+ auto dtype = input->dtype;
+ auto data_ptr = static_cast<float *>(input->data);
+ std::vector<std::pair<float, float>> sorter;
+ int64_t axis_mul_before = 1;
+ int64_t axis_mul_after = 1;
+
+ if (axis < 0) {
+ axis = input->ndim + axis;
+ }
+
+ // Currently only supports input dtype to be float32.
+ CHECK_EQ(dtype.code, 2) << "Currently only supports input dtype "
+ "to be float32.";
+ CHECK_EQ(dtype.bits, 32) << "Currently only supports input dtype "
+ "to be float32.";
+ CHECK_LT(axis, input->ndim) << "Axis out of boundary for "
+ "input ndim " << input->ndim;
+
+ for (int i = 0; i < input->ndim; ++i) {
+ if (i < axis) {
+ axis_mul_before *= input->shape[i];
+ } else if (i > axis) {
+ axis_mul_after *= input->shape[i];
+ }
+ }
+
+ int32_t current_sort_num = input->shape[axis];
+ for (int64_t i = 0 ; i < axis_mul_before; ++i) {
+ for (int64_t j = 0 ; j < axis_mul_after; ++j) {
+ sorter.clear();
+ int64_t base_idx = i * input->shape[axis] * axis_mul_after + j;
+ for (int64_t k = 0; k < current_sort_num; ++k) {
+ int64_t full_idx = base_idx + k * axis_mul_after;
+ sorter.emplace_back(std::make_pair(k, *(data_ptr + full_idx)));
+ }
+ if (is_ascend) {
+ std::stable_sort(sorter.begin(), sorter.end(), CompareAscend<float>);
+ } else {
+ std::stable_sort(sorter.begin(), sorter.end(), CompareDescend<float>);
+ }
+ for (int32_t k = 0; k < input->shape[axis]; ++k) {
+ *(static_cast<float *>(output->data) + base_idx + k * axis_mul_after)
+ = k < static_cast<float>(sorter.size()) ? sorter[k].first : k;
+ }
+ }
+ }
+});
+
} // namespace contrib
} // namespace tvm
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file argsort.cc
+ * \brief Argsort operators
+ */
+#include <tvm/relay/op.h>
+#include <tvm/relay/attrs/algorithm.h>
+
+namespace tvm {
+namespace relay {
+
+TVM_REGISTER_NODE_TYPE(ArgsortAttrs);
+
+bool ArgsortRel(const Array<Type>& types,
+ int num_inputs,
+ const Attrs& attrs,
+ const TypeReporter& reporter) {
+ // `types` contains: [data, result]
+ const ArgsortAttrs* param = attrs.as<ArgsortAttrs>();
+ CHECK_EQ(types.size(), 2);
+ const auto* data = types[0].as<TensorTypeNode>();
+ if (data == nullptr) {
+ CHECK(types[0].as<IncompleteTypeNode>())
+ << "Argsort: expect input type to be TensorType but get "
+ << types[0];
+ return false;
+ }
+ CHECK_EQ(param->dtype, Float(32));
+ reporter->Assign(types[1], TensorTypeNode::make(data->shape, param->dtype));
+ return true;
+}
+
+Expr MakeArgsort(Expr data,
+ int axis,
+ bool is_ascend,
+ DataType dtype) {
+ auto attrs = make_node<ArgsortAttrs>();
+ attrs->axis = axis;
+ attrs->is_ascend = is_ascend;
+ attrs->dtype = dtype;
+ static const Op& op = Op::Get("argsort");
+ return CallNode::make(op, {data}, Attrs(attrs), {});
+}
+
+
+TVM_REGISTER_API("relay.op._make.argsort")
+.set_body_typed(MakeArgsort);
+
+RELAY_REGISTER_OP("argsort")
+.describe(R"doc(Returns the indices that would sort an
+input array along the given axis.
+)doc" TVM_ADD_FILELINE)
+.set_num_inputs(1)
+.set_attrs_type_key("relay.attrs.ArgsortAttrs")
+.add_argument("data", "Tensor", "Input data.")
+.set_support_level(6)
+.add_type_rel("Argsort", ArgsortRel);
+} // namespace relay
+} // namespace tvm
double iou_threshold,
bool force_suppress,
int top_k,
+ int coord_start,
+ int score_index,
int id_index,
bool return_indices,
bool invalid_to_bottom) {
attrs->iou_threshold = iou_threshold;
attrs->force_suppress = force_suppress;
attrs->top_k = top_k;
+ attrs->coord_start = coord_start;
+ attrs->score_index = score_index;
attrs->id_index = id_index;
attrs->return_indices = return_indices;
attrs->invalid_to_bottom = invalid_to_bottom;
data = tvm.placeholder((n, l, m), name='data')
sort_num = tvm.placeholder((n, m), name="sort_num", dtype="int32")
axis = 1
- is_descend = True
+ is_ascend = False
out = tvm.extern(data.shape, [data, sort_num],
lambda ins, outs: tvm.call_packed(
- "tvm.contrib.sort.argsort", ins[0],
- ins[1], outs[0], axis, is_descend),
+ "tvm.contrib.sort.argsort_nms", ins[0],
+ ins[1], outs[0], axis, is_ascend),
dtype='int32', name="sort_tensor")
input = [[[1, 2, 3], [2, 4.5, 3.5], [1.1, 0.5, 1], [3.2, -5, 0.5], [1.5, 0, 0]],
[[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12], [13, 14, 15]]]
dshape = (1, 2, 3, 4, 5, 6)
axis = 4
reduced_shape = (1, 2, 3, 4, 6)
- is_descend = False
+ is_ascend = True
data = tvm.placeholder(dshape, name='data')
sort_num = tvm.placeholder(reduced_shape, name="sort_num", dtype="int32")
out = tvm.extern(data.shape, [data, sort_num],
lambda ins, outs: tvm.call_packed(
- "tvm.contrib.sort.argsort", ins[0],
- ins[1], outs[0], axis, is_descend),
+ "tvm.contrib.sort.argsort_nms", ins[0],
+ ins[1], outs[0], axis, is_ascend),
dtype='int32', name="sort_tensor")
ctx = tvm.cpu(0)
assert "score_threshold" in z.astext()
func = relay.Function([x], z.astuple())
func = relay.ir_pass.infer_type(func)
- ctx_list = [("llvm", tvm.cpu(0))]
- for target, ctx in ctx_list:
+ for target, ctx in ctx_list():
+ if target == 'cuda':
+ return
intrp = relay.create_executor("debug", ctx=ctx, target=target)
out = intrp.evaluate(func)(np_data)
- tvm.testing.assert_allclose(out[0].asnumpy(), np_out1, rtol=1e-3)
- tvm.testing.assert_allclose(out[1].asnumpy(), np_out2, rtol=1e-3)
+ tvm.testing.assert_allclose(out[0].asnumpy(), np_out1, rtol=1e-3, atol=1e-04)
+ tvm.testing.assert_allclose(out[1].asnumpy(), np_out2, rtol=1e-3, atol=1e-04)
verify_get_valid_counts((1, 2500, 6), 0)
verify_get_valid_counts((1, 2500, 6), -1)
iou_threshold=0.5, force_suppress=False, top_k=-1,
check_type_only=False):
x0 = relay.var("x0", relay.ty.TensorType(dshape, "float32"))
- x1 = relay.var("x1", relay.ty.TensorType((dshape[0],), "int"))
- z = relay.vision.non_max_suppression(x0, x1, -1, iou_threshold, force_suppress, top_k, return_indices=False)
- z_indices = relay.vision.non_max_suppression(x0, x1, -1, iou_threshold, force_suppress, top_k)
+ x1 = relay.var("x1", relay.ty.TensorType((dshape[0],), "int32"))
+ z = relay.vision.non_max_suppression(x0, x1, max_output_size=-1,
+ iou_threshold=iou_threshold, force_suppress=force_suppress,
+ top_k=top_k, return_indices=False)
+ z_indices = relay.vision.non_max_suppression(x0, x1, max_output_size=-1,
+ iou_threshold=iou_threshold, force_suppress=force_suppress,
+ top_k=top_k)
assert "iou_threshold" in z.astext()
assert "iou_threshold" in z_indices.astext()
zz = relay.ir_pass.infer_type(z)
func = relay.ir_pass.infer_type(func)
func_indices = relay.Function([x0, x1], z_indices)
func_indices = relay.ir_pass.infer_type(func_indices)
- ctx_list = [("llvm", tvm.cpu(0))]
- for target, ctx in ctx_list:
+ for target, ctx in ctx_list():
intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
op_res1 = intrp1.evaluate(func)(x0_data, x1_data)
op_indices_res1 = intrp1.evaluate(func_indices)(x0_data, x1_data)
nms = relay.vision.non_max_suppression(mtl[0], mtl[1], return_indices=False)
func = relay.Function([cls_prob, loc_pred, anchors], nms)
func = relay.ir_pass.infer_type(func)
- ctx_list = [("llvm", tvm.cpu(0))]
- for target, ctx in ctx_list:
+ for target, ctx in ctx_list():
intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
op_res1 = intrp1.evaluate(func)(np_cls_prob, np_loc_preds,
np_anchors)
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+""" Support level6 operator test cases.
+"""
+import math
+import numpy as np
+import tvm
+from tvm import relay
+from tvm.relay.testing import ctx_list
+import topi.testing
+
+def test_argsort():
+ def verify_argsort(shape, axis, is_ascend):
+ x = relay.var("x", relay.TensorType(shape, "float32"))
+ z = relay.argsort(x, axis=axis, is_ascend=is_ascend)
+ zz = relay.ir_pass.infer_type(z)
+ func = relay.Function([x], z)
+ x_data = np.random.uniform(size=shape).astype("float32")
+ if is_ascend:
+ ref_res = np.argsort(x_data, axis=axis)
+ else:
+ ref_res = np.argsort(-x_data, axis=axis)
+
+ for target, ctx in ctx_list():
+ for kind in ["graph", "debug"]:
+ intrp = relay.create_executor(kind, ctx=ctx, target=target)
+ op_res = intrp.evaluate(func)(x_data)
+ tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.astype("float"), rtol=1e-5)
+ verify_argsort((2, 3, 4), axis=0, is_ascend=False)
+ verify_argsort((1, 4, 6), axis=1, is_ascend=True)
+ verify_argsort((3, 5, 6), axis=-1, is_ascend=False)
+
+
+if __name__ == "__main__":
+ test_argsort()
from .reduction import *
from .transform import *
from .broadcast import *
+from .sort import *
from . import nn
from . import x86
from . import cuda
import tvm
from tvm import api
-from topi.vision import non_max_suppression
-from ..util import get_const_tuple
+from tvm.generic import cast
+from tvm.intrin import if_then_else, log, power
+from topi.vision import non_max_suppression, get_valid_counts
+from .sort import argsort
-def sort_ir(data, index, output):
- """Low level IR to do sorting on the GPU, same usage as tvm.contrib.sort.argsort on the CPU.
+
+def get_valid_counts_pre(data, flag, idx, score_threshold):
+ """Low level IR to Prepare get valid count of bounding boxes
+ given a score threshold. Also moves valid boxes to the
+ top of input data.
Parameters
----------
data: Buffer
- 2D Buffer of input boxes' score with shape [batch_size, num_anchors].
+ 3D Buffer with shape [batch_size, num_anchors, elem_length], the input data.
+
+ flag : Buffer
+ 2D Buffer of flag indicating valid data with shape [batch_size, num_anchors].
- index : Buffer
- 1D Buffer of number of valid number of boxes.
+ idx : Buffer
+ 2D Buffer of valid data indices with shape [batch_size, num_anchors].
- output : Buffer
- 2D Output buffer of indicies of sorted tensor with shape [batch_size, num_anchors].
+ score_threshold : float32
+ Lower limit of score for valid bounding boxes.
Returns
-------
stmt : Stmt
The result IR statement.
"""
+ batch_size = data.shape[0]
+ num_anchors = data.shape[1]
+ box_data_length = data.shape[2]
+
+ ib = tvm.ir_builder.create()
+
+ data = ib.buffer_ptr(data)
+ flag = ib.buffer_ptr(flag)
+ idx = ib.buffer_ptr(idx)
+ score_threshold = tvm.make.node("FloatImm", dtype="float32", value=score_threshold)
- assert data.dtype == "float32", "Currently only supports input dtype to be float32"
- batch, num_anchors = get_const_tuple(data.shape)
max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads)
+ nthread_tx = max_threads
+ nthread_bx = batch_size * num_anchors // max_threads + 1
+ tx = tvm.thread_axis("threadIdx.x")
+ bx = tvm.thread_axis("blockIdx.x")
+ ib.scope_attr(tx, "thread_extent", nthread_tx)
+ ib.scope_attr(bx, "thread_extent", nthread_bx)
+ tid = bx * max_threads + tx
+
+ with ib.if_scope(tid < batch_size * num_anchors):
+ with ib.if_scope(data[tid * box_data_length + 1] > score_threshold):
+ flag[tid] = 1
+ idx[tid] = 1
+ with ib.else_scope():
+ flag[tid] = 0
+ idx[tid] = 0
+
+ return ib.get()
+
+def get_valid_counts_upsweep(data, idx_in, idx, partial):
+ """Low level IR of first step of scan: unsweep.
+
+ Parameters
+ ----------
+ data: Buffer
+ 3D Buffer with shape [batch_size, num_anchors, elem_length], the input data.
+
+ idx_in : Buffer
+ 2D Buffer of valid data indices with shape [batch_size, num_anchors].
+
+ idx : Buffer
+ 2D Buffer of valid data indices with shape [batch_size, num_anchors].
+
+ partial : Buffer
+ 2D Buffer of valid data indices with shape [batch_size, new_range].
+
+ Returns
+ -------
+ stmt : Stmt
+ The result IR statement.
+ """
+ batch_size = data.shape[0]
+ num_anchors = data.shape[1]
ib = tvm.ir_builder.create()
- p_data = ib.buffer_ptr(data)
- p_index = ib.buffer_ptr(index)
- p_out = ib.buffer_ptr(output)
+ data = ib.buffer_ptr(data)
+ idx_in = ib.buffer_ptr(idx_in)
+ idx = ib.buffer_ptr(idx)
+ partial = ib.buffer_ptr(partial)
+ max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads)
+ elem_per_thread = num_anchors // max_threads + 1
nthread_tx = max_threads
- nthread_bx = num_anchors // max_threads + 1
+ nthread_bx = batch_size
tx = tvm.thread_axis("threadIdx.x")
- bx = tvm.thread_axis("vthread")
+ bx = tvm.thread_axis("blockIdx.x")
ib.scope_attr(tx, "thread_extent", nthread_tx)
- ib.scope_attr(bx, "virtual_thread", nthread_bx)
- tid = bx * nthread_tx + tx
- temp_data = ib.allocate("float32", (1,), name="temp_data", scope="local")
- temp_index = ib.allocate("int32", (1,), name="temp_index", scope="local")
-
- with ib.for_range(0, batch, for_type="unroll") as b:
- start = b * num_anchors
- with ib.if_scope(tid < num_anchors):
- p_out[start + tid] = tid
- # OddEvenTransposeSort
- with ib.for_range(0, p_index[b]) as k:
- with ib.if_scope(tid < (p_index[b] + 1) // 2):
- offset = start + 2 * tid + (k % 2)
- with ib.if_scope( \
- tvm.all(offset + 1 < p_index[0], p_data[offset] < p_data[offset + 1])):
- temp_data[0] = p_data[offset]
- p_data[offset] = p_data[offset + 1]
- p_data[offset + 1] = temp_data[0]
- temp_index[0] = p_out[offset]
- p_out[offset] = p_out[offset + 1]
- p_out[offset + 1] = temp_index[0]
+ ib.scope_attr(bx, "thread_extent", nthread_bx)
+ new_range = num_anchors // elem_per_thread + 1
+ # Scan: Upsweep:
+ with ib.if_scope(tvm.all(bx < batch_size, tx < new_range)):
+ with ib.for_range(0, elem_per_thread) as i:
+ with ib.if_scope(bx * num_anchors + \
+ tx * elem_per_thread + i < batch_size * num_anchors):
+ with ib.if_scope(i == 0):
+ partial[bx * new_range + tx] = idx_in[bx * num_anchors + tx * elem_per_thread]
+ idx[bx * num_anchors + tx * elem_per_thread] = \
+ idx_in[bx * num_anchors + tx * elem_per_thread]
+ with ib.else_scope():
+ partial[bx * new_range + tx] += \
+ idx_in[bx * num_anchors + tx * elem_per_thread + i]
+ idx[bx * num_anchors + tx * elem_per_thread + i] = \
+ idx[bx * num_anchors + tx * elem_per_thread + i - 1] + \
+ idx_in[bx * num_anchors + tx * elem_per_thread + i]
+ return ib.get()
+
+def get_valid_counts_scan(data, partial_in, partial):
+ """Low level IR to do scan.
+
+ Parameters
+ ----------
+ data: Buffer
+ 3D Buffer with shape [batch_size, num_anchors, elem_length], output of nms.
+
+ idx_in : Buffer
+ 2D Buffer of valid data indices with shape [batch_size, num_anchors].
+
+ idx : Buffer
+ 2D Buffer of valid data indices with shape [batch_size, num_anchors].
+
+ partial : Buffer
+ 2D Buffer of valid data indices with shape [batch_size, new_range].
+
+ Returns
+ -------
+ stmt : Stmt
+ The result IR statement.
+ """
+ batch_size = data.shape[0]
+ num_anchors = data.shape[1]
+ ib = tvm.ir_builder.create()
+ partial_in = ib.buffer_ptr(partial_in)
+ partial = ib.buffer_ptr(partial)
+ max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads)
+ elem_per_thread = num_anchors // max_threads + 1
+ nthread_tx = max_threads
+ nthread_bx = batch_size
+ tx = tvm.thread_axis("threadIdx.x")
+ bx = tvm.thread_axis("blockIdx.x")
+ ib.scope_attr(tx, "thread_extent", nthread_tx)
+ ib.scope_attr(bx, "thread_extent", nthread_bx)
+ var = tvm.make.node("FloatImm", dtype="float32", value=2)
+ new_range = num_anchors // elem_per_thread + 1
+ iteration = log(cast(new_range, "float32")) // math.log(2)
+ # Scan: Kogge-Stone adder
+ with ib.if_scope(tvm.all(bx < batch_size, tx < tvm.min(new_range, num_anchors))):
+ with ib.for_range(0, iteration) as k:
+ with ib.if_scope(k == 0):
+ with ib.if_scope(tvm.all(tx > 0, tx < tvm.min(new_range, num_anchors))):
+ partial[bx * new_range + tx] = \
+ partial_in[bx * new_range + tx] + partial_in[bx * new_range + tx - 1]
+ with ib.else_scope():
+ partial[bx * new_range] = partial_in[bx * new_range]
+ with ib.else_scope():
+ with ib.if_scope(tvm.all(tx >= cast(power(var, k), "int32"), \
+ tx < tvm.min(new_range, num_anchors))):
+ partial[bx * new_range + tx] += \
+ partial[bx * new_range + tx - cast(power(var, k), "int32")]
ib.emit(tvm.make.Call(None, 'tvm_storage_sync',
tvm.convert(['shared']),
tvm.expr.Call.Intrinsic, None, 0))
+ return ib.get()
+
+def get_valid_counts_downsweep(data, idx_in, partial, idx):
+ """Low level IR to do downsweep of scan.
+
+ Parameters
+ ----------
+ data: Buffer
+ 3D Buffer with shape [batch_size, num_anchors, elem_length], the input data.
+
+ idx_in : Buffer
+ 2D Buffer of valid data indices with shape [batch_size, num_anchors].
+
+ partial : Buffer
+ 2D Buffer of valid data indices with shape [batch_size, new_range].
+
+ idx : Buffer
+ 2D Buffer of valid data indices with shape [batch_size, num_anchors].
+
+ Returns
+ -------
+ stmt : Stmt
+ The result IR statement.
+ """
+ batch_size = data.shape[0]
+ num_anchors = data.shape[1]
+ ib = tvm.ir_builder.create()
+ idx_in = ib.buffer_ptr(idx_in)
+ idx = ib.buffer_ptr(idx)
+ partial = ib.buffer_ptr(partial)
+ max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads)
+ elem_per_thread = num_anchors // max_threads + 1
+ nthread_tx = max_threads
+ nthread_bx = batch_size * num_anchors // max_threads + 1
+ tx = tvm.thread_axis("threadIdx.x")
+ bx = tvm.thread_axis("blockIdx.x")
+ ib.scope_attr(tx, "thread_extent", nthread_tx)
+ ib.scope_attr(bx, "thread_extent", nthread_bx)
+ tid = bx * max_threads + tx
+ new_range = num_anchors // elem_per_thread + 1
+ # Scan: Downsweep:
+ with ib.if_scope(tid < batch_size * num_anchors):
+ i = tid / num_anchors # batch index
+ j = tid % num_anchors # anchor index within the batch
+ with ib.if_scope(j < elem_per_thread):
+ idx[tid] = idx_in[tid]
+ with ib.else_scope():
+ idx[tid] = idx_in[tid] + partial[i * new_range + j // elem_per_thread - 1]
return ib.get()
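+
+# Taken together, the upsweep, scan, and downsweep phases above compute, for each
+# batch row, an inclusive prefix sum of the per-box valid flags. A NumPy sketch of
+# the intended result (illustrative only, not used by the kernels):
+#
+#   idx_ref = np.cumsum(flag, axis=1)  # flag has shape (batch_size, num_anchors)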
-def nms_ir(data, sort_result, valid_count, out, nms_threshold, force_suppress, nms_topk):
+def get_valid_counts_ir(data, flag, idx, valid_count, out):
+ """Low level IR to get valid count of bounding boxes
+ given a score threshold. Also moves valid boxes to the
+ top of input data.
+
+ Parameters
+ ----------
+ data : Buffer
+ Input data. 3-D Buffer with shape [batch_size, num_anchors, elem_length].
+
+ flag : Buffer
+ 2D Buffer of flag indicating valid data with shape [batch_size, num_anchors].
+
+ idx : Buffer
+ 2D Buffer of valid data indices with shape [batch_size, num_anchors].
+
+ valid_count : Buffer
+ 1-D buffer for valid number of boxes.
+
+ out : Buffer
+ Rearranged data buffer.
+
+ Returns
+ -------
+ stmt : Stmt
+ The result IR statement.
+ """
+ batch_size = data.shape[0]
+ num_anchors = data.shape[1]
+ elem_length = data.shape[2]
+ size = batch_size * num_anchors * elem_length
+
+ ib = tvm.ir_builder.create()
+
+ data = ib.buffer_ptr(data)
+ flag = ib.buffer_ptr(flag)
+ idx = ib.buffer_ptr(idx)
+ valid_count = ib.buffer_ptr(valid_count)
+ out = ib.buffer_ptr(out)
+
+ max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads)
+ nthread_tx = max_threads
+ nthread_bx = batch_size * num_anchors * elem_length // max_threads + 1
+ tx = tvm.thread_axis("threadIdx.x")
+ bx = tvm.thread_axis("blockIdx.x")
+ ib.scope_attr(tx, "thread_extent", nthread_tx)
+ ib.scope_attr(bx, "thread_extent", nthread_bx)
+ tid = bx * max_threads + tx
+
+ with ib.if_scope(tid < batch_size * num_anchors):
+ i = tid / num_anchors
+ j = tid % num_anchors
+ base_idx = i * num_anchors * elem_length
+ with ib.if_scope(flag[tid] > 0):
+ with ib.for_range(0, elem_length) as k:
+ with ib.if_scope(base_idx + (idx[tid] - 1) * elem_length + k < size):
+ out[base_idx + (idx[tid] - 1) * elem_length + k] =\
+ data[base_idx + j * elem_length + k]
+ with ib.if_scope(j == 0):
+ valid_count[i] = idx[tid + num_anchors - 1]
+ with ib.if_scope(j >= idx[i * num_anchors + num_anchors - 1]):
+ with ib.for_range(0, elem_length) as l:
+ with ib.if_scope(tid * elem_length + l < size):
+ out[tid * elem_length + l] = -1.0
+ return ib.get()
+
+
+@get_valid_counts.register(["cuda", "gpu"])
+def get_valid_counts_gpu(data, score_threshold=0):
+ """Get valid count of bounding boxes given a score threshold.
+ Also moves valid boxes to the top of input data.
+
+ Parameters
+ ----------
+ data : tvm.Tensor
+ Input data. 3-D tensor with shape [batch_size, num_anchors, elem_length].
+
+ score_threshold : optional, float
+ Lower limit of score for valid bounding boxes.
+
+ Returns
+ -------
+ valid_count : tvm.Tensor
+ 1-D tensor for valid number of boxes.
+
+ out_tensor : tvm.Tensor
+ Rearranged data tensor.
+ """
+ batch_size = data.shape[0]
+ num_anchors = data.shape[1]
+ max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads)
+ elem_per_thread = num_anchors // max_threads + 1
+ new_range = num_anchors // elem_per_thread + 1
+ temp_flag_buf = api.decl_buffer(
+ (batch_size, num_anchors,), "int32", "temp_flag", data_alignment=8)
+ temp_idx_buf = api.decl_buffer(
+ (batch_size, num_anchors,), "int32", "temp_idx", data_alignment=8)
+ temp_partial_buf = api.decl_buffer(
+ (batch_size, new_range), "int32", "temp_partial", data_alignment=8)
+ data_buf = api.decl_buffer(
+ data.shape, data.dtype, "data_buf", data_alignment=8)
+
+ temp_flag, temp_idx = \
+ tvm.extern([(batch_size, num_anchors,), (batch_size, num_anchors,)], [data],
+ lambda ins, outs: get_valid_counts_pre(
+ ins[0], outs[0], outs[1], score_threshold),
+ dtype=["int32", "int32"],
+ out_buffers=[temp_flag_buf, temp_idx_buf],
+ name="get_valid_counts_phase_one")
+ temp_idx_new, temp_partial = \
+ tvm.extern([(batch_size, num_anchors,), (batch_size, new_range)], [data, temp_idx],
+ lambda ins, outs: get_valid_counts_upsweep(
+ ins[0], ins[1], outs[0], outs[1]),
+ dtype=["int32", "int32"],
+ out_buffers=[temp_idx_buf, temp_partial_buf],
+ name="get_valid_counts_phase_two")
+ temp_partial_new = \
+ tvm.extern([(batch_size, new_range)], [data, temp_partial],
+ lambda ins, outs: get_valid_counts_scan(
+ ins[0], ins[1], outs[0]),
+ dtype=["int32"],
+ out_buffers=[temp_partial_buf],
+ name="get_valid_counts_phase_three")
+ temp_idx_final = \
+ tvm.extern([(batch_size, num_anchors)], [data, temp_idx_new, temp_partial_new],
+ lambda ins, outs: get_valid_counts_downsweep(
+ ins[0], ins[1], ins[2], outs[0]),
+ dtype=["int32"],
+ out_buffers=[temp_idx_buf],
+ name="get_valid_counts_phase_four")
+ valid_count, out_tensor = \
+ tvm.extern([(batch_size,), data.shape], [data, temp_flag, temp_idx_final],
+ lambda ins, outs: get_valid_counts_ir(
+ ins[0], ins[1], ins[2], outs[0], outs[1]),
+ dtype=["int32", data.dtype],
+ in_buffers=[data_buf, temp_flag_buf, temp_idx_buf],
+ name="get_valid_counts_phase_five",
+ tag="get_valid_counts_gpu")
+
+ return [valid_count, out_tensor]
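+
+# A NumPy reference sketch of what get_valid_counts_gpu is expected to produce for a
+# single batch (illustrative only; assumes a score threshold `thresh` and the score
+# stored at column 1, as in get_valid_counts_pre above):
+#
+#   keep = data[0, :, 1] > thresh
+#   valid_count_ref = int(keep.sum())
+#   out_ref = np.full_like(data[0], -1.0)
+#   out_ref[:valid_count_ref] = data[0, keep]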
+
+
+def nms_ir(data, sorted_index, valid_count, out, box_indices,
+ max_output_size, iou_threshold, force_suppress,
+ top_k, coord_start, id_index):
"""Low level IR routing for transform location in multibox_detection operator.
Parameters
----------
- data: Buffer
+ data : Buffer
Buffer of output boxes with class and score.
- sort_result : Buffer
+ sorted_index : Buffer
Buffer of output box indexes sorted by score.
valid_count : Buffer
out : Buffer
Output buffer.
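+ box_indices : Buffer
+ Output buffer holding, for each kept box, its index in the input data (-1 for suppressed or padded entries).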
- nms_threshold : float
- Non-maximum suppression threshold.
+ max_output_size : int
+ Max number of output valid boxes for each instance.
+ By default all valid boxes are returned.
+
+ iou_threshold : float
+ Overlapping(IoU) threshold to suppress object with smaller score.
force_suppress : boolean
Whether to suppress all detections regardless of class_id.
- nms_topk : int
+ top_k : int
Keep maximum top k detections before nms, -1 for no limit.
+ coord_start : int
+ Start index of the consecutive 4 coordinates.
+
+ id_index : int
+ Index of the class categories, -1 to disable.
+
Returns
-------
stmt : Stmt
(out_tensor[box_b_idx + 3] - out_tensor[box_b_idx + 1]) - i
return tvm.expr.Select(u <= 0.0, 0.0, i / u)
+ batch_size = data.shape[0]
+ num_anchors = data.shape[1]
+ box_data_length = data.shape[2]
+
+ ib = tvm.ir_builder.create()
+
+ data = ib.buffer_ptr(data)
+ sorted_index = ib.buffer_ptr(sorted_index)
+ valid_count = ib.buffer_ptr(valid_count)
+ out = ib.buffer_ptr(out)
+ box_indices = ib.buffer_ptr(box_indices)
+ num_valid_boxes = ib.allocate("int32", (1,), name="num_valid_boxes", scope="local")
+
max_threads = int(math.sqrt(
tvm.target.current_target(allow_none=False).max_num_threads))
- ib = tvm.ir_builder.create()
- p_data = ib.buffer_ptr(data)
- p_sort_result = ib.buffer_ptr(sort_result)
- p_valid_count = ib.buffer_ptr(valid_count)
- p_out = ib.buffer_ptr(out)
- batch_size = out.shape[0]
- num_anchors = out.shape[1]
nthread_tx = max_threads
nthread_bx = num_anchors // max_threads + 1
tx = tvm.thread_axis("threadIdx.x")
bx = tvm.thread_axis("blockIdx.x")
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(bx, "thread_extent", nthread_bx)
- i = bx * max_threads + tx
-
- nms_threshold_node = tvm.make.node(
- "FloatImm", dtype="float32", value=nms_threshold)
- nms_topk_node = tvm.make.node("IntImm", dtype="int32", value=nms_topk)
- force_suppress_node = tvm.make.node(
- "IntImm", dtype="int32", value=1 if force_suppress else 0)
- with ib.for_range(0, batch_size, for_type="unroll") as b:
- base_idx = b * num_anchors * 6
- with ib.if_scope( \
- tvm.all(nms_threshold_node > 0, nms_threshold_node < 1,
- p_valid_count[0] > 0)):
+ k = bx * max_threads + tx
+
+ iou_threshold = tvm.make.node("FloatImm", dtype="float32", value=iou_threshold)
+ top_k = tvm.make.node("IntImm", dtype="int32", value=top_k)
+ coord_start = tvm.make.node("IntImm", dtype="int32", value=coord_start)
+ id_index = tvm.make.node("IntImm", dtype="int32", value=id_index)
+ force_suppress = tvm.make.node("IntImm", dtype="int32", value=1 if force_suppress else 0)
+
+ with ib.for_range(0, batch_size, for_type="unroll") as i:
+ base_idx = i * num_anchors * box_data_length
+ with ib.if_scope(tvm.all(iou_threshold > 0, valid_count[i] > 0)):
# Reorder output
- nkeep = tvm.if_then_else( \
- tvm.all(nms_topk_node > 0, nms_topk < p_valid_count[b]),
- nms_topk, p_valid_count[b])
- with ib.for_range(0, nkeep) as l:
- with ib.if_scope(i < 6):
- p_out[(base_idx + l * 6 + i)] = \
- p_data[(base_idx + p_sort_result[b * num_anchors + l] * 6 + i)]
- with ib.if_scope(tvm.all(nms_topk_node > 0, nms_topk < p_valid_count[b])):
- with ib.for_range(0, p_valid_count[b] - nkeep) as l:
- with ib.if_scope(i < 6):
- p_out[(base_idx + (l + nkeep) * 6 + i)] = -1.0
+ nkeep = if_then_else( \
+ tvm.all(top_k > 0, top_k < valid_count[i]),
+ top_k, valid_count[i])
+ with ib.for_range(0, nkeep) as j:
+ with ib.if_scope(k < box_data_length):
+ out[(base_idx + j * box_data_length + k)] = \
+ data[(base_idx + sorted_index[i * num_anchors + j] \
+ * box_data_length + k)]
+ box_indices[i * num_anchors + j] = sorted_index[i * num_anchors + j]
+ with ib.if_scope(tvm.all(top_k > 0, top_k < valid_count[i])):
+ with ib.for_range(0, valid_count[i] - nkeep) as j:
+ with ib.if_scope(k < box_data_length):
+ out[(base_idx + (j + nkeep) * box_data_length + k)] = -1.0
+ box_indices[i * num_anchors + (j + nkeep)] = -1
# Apply nms
- with ib.for_range(0, p_valid_count[b]) as l:
- offset_l = l * 6
- with ib.if_scope(p_out[base_idx + offset_l] >= 0):
- with ib.if_scope(i < p_valid_count[b]):
- offset_i = i * 6
- with ib.if_scope(tvm.all(i > l, p_out[base_idx
- + offset_i] >= 0)):
- with ib.if_scope(tvm.any(force_suppress_node > 0,
- p_out[base_idx + offset_l] ==
- p_out[base_idx + offset_i])):
- # When force_suppress == True or class_id equals
- iou = calculate_overlap(p_out, base_idx + offset_l + 2,
- base_idx + offset_i + 2)
- with ib.if_scope(iou >= nms_threshold):
- p_out[base_idx + offset_i] = -1.0
+ with ib.for_range(0, valid_count[i]) as j:
+ offset_j = j * box_data_length
+ with ib.if_scope(out[base_idx + offset_j] >= 0):
+ with ib.if_scope(k < valid_count[i]):
+ offset_k = k * box_data_length
+ with ib.if_scope(tvm.all(k > j, out[base_idx + offset_k] >= 0, \
+ tvm.any(force_suppress > 0, id_index < 0, \
+ out[base_idx + offset_j] == \
+ out[base_idx + offset_k]))):
+ iou = calculate_overlap(out, base_idx + offset_k + coord_start,
+ base_idx + offset_j + coord_start)
+ with ib.if_scope(iou >= iou_threshold):
+ out[base_idx + offset_k] = -1.0
+ box_indices[i * num_anchors + k] = -1
ib.emit(tvm.make.Call(None, 'tvm_storage_sync',
tvm.convert(['shared']),
tvm.expr.Call.Intrinsic, None, 0))
with ib.else_scope():
- with ib.for_range(0, p_valid_count[b]) as c:
- with ib.if_scope(i < 6):
- p_out[(base_idx + c * 6 + i)] = p_data[base_idx + c * 6 + i]
+ with ib.for_range(0, valid_count[i]) as j:
+ offset_j = j * box_data_length
+ with ib.if_scope(k < box_data_length):
+ out[(base_idx + offset_j + k)] = data[base_idx + offset_j + k]
+ box_indices[i * num_anchors + j] = j
# Set invalid entry to be -1
- with ib.for_range(0, num_anchors - p_valid_count[b]) as c:
- with ib.if_scope(i < 6):
- p_out[base_idx + (c + p_valid_count[b]) * 6 + i] = -1.0
- body = ib.get()
- return body
+ with ib.for_range(0, num_anchors - valid_count[i]) as j:
+ with ib.if_scope(k < box_data_length):
+ out[base_idx + (j + valid_count[i]) * box_data_length + k] = -1.0
+ box_indices[i * num_anchors + j + valid_count[i]] = -1
+ # Only return max_output_size number of valid boxes
+ num_valid_boxes[0] = 0
+ with ib.if_scope(max_output_size > 0):
+ with ib.for_range(0, valid_count[i]) as j:
+ offset_j = j * box_data_length
+ with ib.if_scope(out[base_idx + offset_j] >= 0):
+ with ib.if_scope(num_valid_boxes[0] == max_output_size):
+ with ib.if_scope(k < box_data_length):
+ out[base_idx + offset_j + k] = -1.0
+ box_indices[i * num_anchors + j] = -1
+ with ib.else_scope():
+ num_valid_boxes[0] += 1
+ ib.emit(tvm.make.Call(None, 'tvm_storage_sync',
+ tvm.convert(['shared']),
+ tvm.expr.Call.Intrinsic, None, 0))
+
+ return ib.get()
+
+
+def invalid_to_bottom_pre(data, flag, idx):
+ """Low level IR to rearrange nms output to move all valid entries to top.
+
+ Parameters
+ ----------
+ data: Buffer
+ 3D Buffer with shape [batch_size, num_anchors, elem_length], output of nms.
+
+ flag : Buffer
+ 1D Buffer of flag indicating valid data with [num_anchors].
+
+ idx : Buffer
+ 1D Buffer of valid data indices with [num_anchors].
+
+ Returns
+ -------
+ stmt : Stmt
+ The result IR statement.
+ """
+ batch_size = data.shape[0]
+ num_anchors = data.shape[1]
+ elem_length = data.shape[2]
+
+ ib = tvm.ir_builder.create()
+
+ data = ib.buffer_ptr(data)
+ flag = ib.buffer_ptr(flag)
+ idx = ib.buffer_ptr(idx)
+
+ max_threads = int(math.sqrt(
+ tvm.target.current_target(allow_none=False).max_num_threads))
+ nthread_tx = max_threads
+ nthread_bx = num_anchors // max_threads + 1
+ tx = tvm.thread_axis("threadIdx.x")
+ bx = tvm.thread_axis("blockIdx.x")
+ ib.scope_attr(tx, "thread_extent", nthread_tx)
+ ib.scope_attr(bx, "thread_extent", nthread_bx)
+ j = bx * max_threads + tx
+
+ with ib.for_range(0, batch_size, for_type="unroll") as i:
+ base_idx = i * num_anchors * elem_length
+ with ib.if_scope(j < num_anchors):
+ with ib.if_scope(data[base_idx + j * elem_length] >= 0):
+ flag[i * num_anchors + j] = 1
+ idx[i * num_anchors + j] = 1
+ with ib.else_scope():
+ flag[i * num_anchors + j] = 0
+ idx[i * num_anchors + j] = 0
+
+ with ib.if_scope(j < batch_size):
+ with ib.for_range(0, num_anchors) as k:
+ with ib.if_scope(k > 0):
+ idx[j * num_anchors + k] += idx[j * num_anchors + k - 1]
+ return ib.get()
+
+
+def invalid_to_bottom_ir(data, flag, idx, out):
+ """Low level IR to rearrange nms output to move all valid entries to top.
+
+ Parameters
+ ----------
+ data: Buffer
+ 3D Buffer with shape [batch_size, num_anchors, elem_length], output of nms.
+
+ flag : Buffer
+ 1D Buffer of flag indicating valid data with [num_anchors].
+
+ idx : Buffer
+ 1D Buffer of valid data indices with [num_anchors].
+
+ out : Buffer
+ 3D Buffer of rearranged nms output with shape [batch_size, num_anchors, elem_length].
+
+ Returns
+ -------
+ stmt : Stmt
+ The result IR statement.
+ """
+ batch_size = data.shape[0]
+ num_anchors = data.shape[1]
+ elem_length = data.shape[2]
+
+ ib = tvm.ir_builder.create()
+
+ data = ib.buffer_ptr(data)
+ flag = ib.buffer_ptr(flag)
+ idx = ib.buffer_ptr(idx)
+ out = ib.buffer_ptr(out)
+
+ max_threads = int(math.sqrt(
+ tvm.target.current_target(allow_none=False).max_num_threads))
+ nthread_tx = max_threads
+ nthread_bx = num_anchors // max_threads + 1
+ tx = tvm.thread_axis("threadIdx.x")
+ bx = tvm.thread_axis("blockIdx.x")
+ ib.scope_attr(tx, "thread_extent", nthread_tx)
+ ib.scope_attr(bx, "thread_extent", nthread_bx)
+ j = bx * max_threads + tx
+
+ with ib.for_range(0, batch_size, for_type="unroll") as i:
+ base_idx = i * num_anchors * elem_length
+ with ib.if_scope(j < num_anchors):
+ with ib.for_range(0, elem_length) as k:
+ out[base_idx + j * elem_length + k] = -1.0
+ with ib.if_scope(flag[i * num_anchors + j] > 0):
+ with ib.for_range(0, elem_length) as k:
+ out[base_idx + (idx[i * num_anchors + j] - 1) * elem_length + k] \
+ = data[base_idx + j * elem_length + k]
+ return ib.get()
@non_max_suppression.register(["cuda", "gpu"])
-def nms_gpu(data,
- valid_count,
- max_output_size=-1,
- iou_threshold=0.5,
- force_suppress=False,
- top_k=-1,
- id_index=0,
- return_indices=True,
- invalid_to_bottom=False):
+def non_max_suppression_gpu(data, valid_count, max_output_size=-1,
+ iou_threshold=0.5, force_suppress=False, top_k=-1,
+ coord_start=2, score_index=1, id_index=0,
+ return_indices=True, invalid_to_bottom=False):
"""Non-maximum suppression operator for object detection.
Parameters
----------
data : tvm.Tensor
- 3-D tensor with shape [batch_size, num_anchors, 6].
+ 3-D tensor with shape [batch_size, num_anchors, elem_length].
The last dimension should be in format of
[class_id, score, box_left, box_top, box_right, box_bottom].
valid_count : tvm.Tensor
1-D tensor for valid number of boxes.
- return_indices : boolean
- Whether to return box indices in input data.
+ max_output_size : optional, int
+ Max number of output valid boxes for each instance.
+ By default all valid boxes are returned.
iou_threshold : optional, float
Non-maximum suppression threshold.
top_k : optional, int
Keep maximum top k detections before nms, -1 for no limit.
+ coord_start : optional, int
+ Start index of the consecutive 4 coordinates.
+
+ score_index : optional, int
+ Index of the scores/confidence of boxes.
+
id_index : optional, int
index of the class categories, -1 to disable.
+ return_indices : boolean
+ Whether to return box indices in input data.
+
invalid_to_bottom : optional, boolean
Whether to move all valid bounding boxes to the top.
Returns
-------
out : tvm.Tensor
- 3-D tensor with shape [batch_size, num_anchors, 6].
+ 3-D tensor with shape [batch_size, num_anchors, elem_length].
Example
--------
iou_threshold = 0.7
force_suppress = True
top_k = -1
- out = nms(data, valid_count, iou_threshold, force_suppress, topk)
+ out = non_max_suppression(data=data, valid_count=valid_count, iou_threshold=iou_threshold,
+ force_suppress=force_suppress, top_k=top_k, return_indices=False)
np_data = np.random.uniform(size=dshape)
np_valid_count = np.array([4])
s = topi.generic.schedule_nms(out)
- f = tvm.build(s, [data, valid_count, out], "llvm")
- ctx = tvm.cpu()
+ f = tvm.build(s, [data, valid_count, out], "cuda")
+ ctx = tvm.gpu(0)
tvm_data = tvm.nd.array(np_data, ctx)
tvm_valid_count = tvm.nd.array(np_valid_count, ctx)
tvm_out = tvm.nd.array(np.zeros(dshape, dtype=data.dtype), ctx)
"""
batch_size = data.shape[0]
num_anchors = data.shape[1]
+
valid_count_dtype = "int32"
valid_count_buf = api.decl_buffer(valid_count.shape, valid_count_dtype,
"valid_count_buf", data_alignment=4)
- data_buf = api.decl_buffer(
- data.shape, data.dtype, "data_buf", data_alignment=8)
+ score_axis = score_index
score_shape = (batch_size, num_anchors)
- score_tensor = tvm.compute(
- score_shape, lambda i, j: data[i, j, 1], name="score_tensor")
- score_tensor_buf = api.decl_buffer(score_tensor.shape, data.dtype,
- "score_tensor_buf", data_alignment=8)
+ score_tensor = tvm.compute(score_shape, lambda i, j: data[i, j, score_axis])
+ sort_tensor = argsort(score_tensor, valid_count=valid_count, axis=1, is_ascend=False, flag=True)
- sort_tensor_dtype = "int32"
- sort_tensor_buf = api.decl_buffer(score_shape, sort_tensor_dtype,
+ sort_tensor_buf = api.decl_buffer(sort_tensor.shape, sort_tensor.dtype,
"sort_tensor_buf", data_alignment=8)
- sort_tensor = \
- tvm.extern(score_shape,
- [score_tensor, valid_count],
- lambda ins, outs: sort_ir(
- ins[0], ins[1], outs[0]),
- dtype=sort_tensor_dtype,
- in_buffers=[score_tensor_buf, valid_count_buf],
- out_buffers=sort_tensor_buf,
- name="nms_sort")
+ data_buf = api.decl_buffer(
+ data.shape, data.dtype, "data_buf", data_alignment=8)
- out = \
- tvm.extern(data.shape,
+ out_buf = api.decl_buffer(
+ data.shape, data.dtype, "out_buf", data_alignment=8)
+
+ out, box_indices = \
+ tvm.extern([data.shape, score_shape],
[data, sort_tensor, valid_count],
lambda ins, outs: nms_ir(
- ins[0], ins[1], ins[2], outs[0], iou_threshold,
- force_suppress, top_k),
- dtype="float32",
+ ins[0], ins[1], ins[2], outs[0], outs[1],
+ max_output_size, iou_threshold, force_suppress,
+ top_k, coord_start, id_index),
+ dtype=[data.dtype, "int32"],
in_buffers=[data_buf, sort_tensor_buf, valid_count_buf],
+ name="nms",
tag="nms")
+
+ if return_indices:
+ return box_indices
+
+ if invalid_to_bottom:
+ output_buf = api.decl_buffer(
+ data.shape, data.dtype, "output_buf", data_alignment=8)
+ temp_flag_buf = api.decl_buffer(
+ score_shape, valid_count_dtype, "temp_flag", data_alignment=8)
+ temp_idx_buf = api.decl_buffer(
+ score_shape, valid_count_dtype, "temp_idx", data_alignment=8)
+ temp_flag, temp_idx = tvm.extern([score_shape, score_shape], [out],
+ lambda ins, outs: invalid_to_bottom_pre(
+ ins[0], outs[0], outs[1]),
+ dtype=["int32", "int32"],
+ in_buffers=[out_buf],
+ out_buffers=[temp_flag_buf, temp_idx_buf],
+ name="invalid_to_bottom_phase_one")
+
+ output = tvm.extern([data.shape], [out, temp_flag, temp_idx],
+ lambda ins, outs: invalid_to_bottom_ir(
+ ins[0], ins[1], ins[2], outs[0]),
+ dtype=[data.dtype],
+ in_buffers=[out_buf, temp_flag_buf, temp_idx_buf],
+ out_buffers=[output_buf],
+ name="invalid_to_bottom",
+ tag="invalid_to_bottom")
+ return output
+
return out
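
For reference, the box layout this kernel assumes is the default one used throughout the patch: each detection row is [class_id, score, x1, y1, x2, y2], i.e. id_index=0, score_index=1, coord_start=2. Below is a minimal NumPy sketch of the suppression rule for one batch (illustrative only, not the TOPI kernel; the ref_nms name is made up for this note).

.. code-block:: python

    import numpy as np

    def ref_nms(boxes, valid_count, iou_threshold=0.5, force_suppress=False,
                coord_start=2, score_index=1, id_index=0):
        """Reference NMS over one batch of boxes shaped (num_anchors, 6)."""
        keep = []
        order = np.argsort(-boxes[:valid_count, score_index])  # descending by score
        for i in order:
            if boxes[i, id_index] < 0:          # already marked invalid
                continue
            suppressed = False
            for j in keep:
                # By default only boxes of the same class suppress each other.
                if not force_suppress and boxes[i, id_index] != boxes[j, id_index]:
                    continue
                a = boxes[i, coord_start:coord_start + 4]
                b = boxes[j, coord_start:coord_start + 4]
                iw = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
                ih = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
                inter = iw * ih
                union = ((a[2] - a[0]) * (a[3] - a[1]) +
                         (b[2] - b[0]) * (b[3] - b[1]) - inter)
                if union > 0 and inter / union >= iou_threshold:
                    suppressed = True
                    break
            if not suppressed:
                keep.append(int(i))
        return keep                              # indices of surviving boxes
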
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments, too-many-statements, singleton-comparison, unused-argument
+"""Argsort operator """
+import tvm
+
+from tvm import api
+from topi.sort import argsort
+
+def sort_ir(data, output, axis, is_ascend):
+ """Low level IR to do nms sorting on the GPU, same usage as tvm.contrib.sort.argsort on the CPU.
+
+ Parameters
+ ----------
+ data: Buffer
+ Buffer of input data.
+
+ output : Buffer
+ Output buffer of indices of sorted tensor with same shape as data.
+
+ axis : Int
+ Axis along which to sort the input tensor.
+
+ is_ascend : Boolean
+ Whether to sort in ascending or descending order.
+
+ Returns
+ -------
+ stmt : Stmt
+ The result IR statement.
+ """
+ size = 1
+ axis_mul_before = 1
+ axis_mul_after = 1
+ shape = data.shape
+ if axis < 0:
+ axis = len(shape) + axis
+ for i, value in enumerate(shape, 0):
+ size *= value
+ if i < axis:
+ axis_mul_before *= value
+ elif i > axis:
+ axis_mul_after *= value
+ max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads)
+ ib = tvm.ir_builder.create()
+ data = ib.buffer_ptr(data)
+ output = ib.buffer_ptr(output)
+ nthread_tx = max_threads
+ nthread_bx = size // max_threads + 1
+ tx = tvm.thread_axis("threadIdx.x")
+ bx = tvm.thread_axis("vthread")
+ ib.scope_attr(tx, "thread_extent", nthread_tx)
+ ib.scope_attr(bx, "virtual_thread", nthread_bx)
+ tid = bx * nthread_tx + tx
+ temp_data = ib.allocate("float32", (1,), name="temp_data", scope="local")
+ temp_index = ib.allocate("float32", (1,), name="temp_index", scope="local")
+ is_ascend = tvm.make.node("IntImm", dtype="int32", value=is_ascend)
+
+ with ib.for_range(0, axis_mul_before) as i:
+ with ib.for_range(0, axis_mul_after) as j:
+ current_sort_num = shape[axis]
+ base_idx = i * shape[axis] * axis_mul_after + j
+ with ib.if_scope(tid < shape[axis]):
+ output[base_idx + tid * axis_mul_after] = tid.astype("float32")
+ # OddEvenTransposeSort
+ with ib.for_range(0, current_sort_num) as k:
+ with ib.if_scope(tid < (current_sort_num + 1) // 2):
+ offset = base_idx + (2 * tid + (k % 2)) * axis_mul_after
+ with ib.if_scope(tvm.all(is_ascend == 1, \
+ 2 * tid + (k % 2) + 1 < current_sort_num, \
+ data[offset] > data[offset + axis_mul_after])):
+ temp_data[0] = data[offset]
+ data[offset] = data[offset + axis_mul_after]
+ data[offset + axis_mul_after] = temp_data[0]
+ temp_index[0] = output[offset]
+ output[offset] = output[offset + axis_mul_after]
+ output[offset + axis_mul_after] = temp_index[0]
+ with ib.if_scope(tvm.all(is_ascend == 0, \
+ 2 * tid + (k % 2) + 1 < current_sort_num, \
+ data[offset] < data[offset + axis_mul_after])):
+ temp_data[0] = data[offset]
+ data[offset] = data[offset + axis_mul_after]
+ data[offset + axis_mul_after] = temp_data[0]
+ temp_index[0] = output[offset]
+ output[offset] = output[offset + axis_mul_after]
+ output[offset + axis_mul_after] = temp_index[0]
+ ib.emit(tvm.make.Call(None, 'tvm_storage_sync',
+ tvm.convert(['shared']),
+ tvm.expr.Call.Intrinsic, None, 0))
+
+ return ib.get()
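
The inner loop above is an odd-even transposition sort: each pass compares disjoint neighbouring pairs (even pairs on even passes, odd pairs on odd passes), so every compare-and-swap in a pass can run in parallel across threads. A plain-Python sketch of the same idea on a single 1-D sequence (illustrative only; odd_even_argsort is not part of the patch):

.. code-block:: python

    def odd_even_argsort(values, is_ascend=True):
        """Odd-even transposition argsort of a 1-D sequence (reference only)."""
        n = len(values)
        data = list(values)            # values being sorted
        index = list(range(n))         # indices carried along; the actual result
        for k in range(n):             # n passes are enough to fully sort
            for i in range(k % 2, n - 1, 2):   # disjoint pairs -> parallelizable
                wrong = data[i] > data[i + 1] if is_ascend else data[i] < data[i + 1]
                if wrong:
                    data[i], data[i + 1] = data[i + 1], data[i]
                    index[i], index[i + 1] = index[i + 1], index[i]
        return index

    # e.g. odd_even_argsort([0.3, 0.9, 0.1], is_ascend=False) returns [1, 0, 2]
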
+
+
+def sort_nms_ir(data, valid_count, output, axis, is_ascend):
+ """Low level IR to do nms sorting on the GPU, same usage as tvm.contrib.sort.argsort on the CPU.
+
+ Parameters
+ ----------
+ data: Buffer
+ Buffer of input data.
+
+ valid_count : Buffer
+ 1D Buffer of number of valid boxes.
+
+ output : Buffer
+ Output buffer of indices of sorted tensor with same shape as data.
+
+ axis : Int
+ Axis along which to sort the input tensor.
+
+ is_ascend : Boolean
+ Whether to sort in ascending or descending order.
+
+ Returns
+ -------
+ stmt : Stmt
+ The result IR statement.
+ """
+
+ size = 1
+ axis_mul_before = 1
+ axis_mul_after = 1
+ shape = data.shape
+ if axis < 0:
+ axis = len(shape) + axis
+ for i, value in enumerate(shape, 0):
+ size *= value
+ if i < axis:
+ axis_mul_before *= value
+ elif i > axis:
+ axis_mul_after *= value
+ max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads)
+ ib = tvm.ir_builder.create()
+ data = ib.buffer_ptr(data)
+ valid_count = ib.buffer_ptr(valid_count)
+ output = ib.buffer_ptr(output)
+ nthread_tx = max_threads
+ nthread_bx = size // max_threads + 1
+ tx = tvm.thread_axis("threadIdx.x")
+ bx = tvm.thread_axis("vthread")
+ ib.scope_attr(tx, "thread_extent", nthread_tx)
+ ib.scope_attr(bx, "virtual_thread", nthread_bx)
+ tid = bx * nthread_tx + tx
+ temp_data = ib.allocate("float32", (1,), name="temp_data", scope="local")
+ temp_index = ib.allocate("int32", (1,), name="temp_index", scope="local")
+ is_ascend = tvm.make.node("IntImm", dtype="int32", value=is_ascend)
+
+ with ib.for_range(0, axis_mul_before) as i:
+ with ib.for_range(0, axis_mul_after) as j:
+ current_sort_num = valid_count[i * axis_mul_after + j]
+ base_idx = i * shape[axis] * axis_mul_after + j
+ with ib.if_scope(tid < shape[axis]):
+ output[base_idx + tid * axis_mul_after] = tid
+ # OddEvenTransposeSort
+ with ib.for_range(0, current_sort_num) as k:
+ with ib.if_scope(tid < (current_sort_num + 1) // 2):
+ offset = base_idx + (2 * tid + (k % 2)) * axis_mul_after
+ with ib.if_scope(tvm.all(is_ascend == 1, \
+ 2 * tid + (k % 2) + 1 < current_sort_num, \
+ data[offset] > data[offset + axis_mul_after])):
+ temp_data[0] = data[offset]
+ data[offset] = data[offset + axis_mul_after]
+ data[offset + axis_mul_after] = temp_data[0]
+ temp_index[0] = output[offset]
+ output[offset] = output[offset + axis_mul_after]
+ output[offset + axis_mul_after] = temp_index[0]
+ with ib.if_scope(tvm.all(is_ascend == 0, \
+ 2 * tid + (k % 2) + 1 < current_sort_num, \
+ data[offset] < data[offset + axis_mul_after])):
+ temp_data[0] = data[offset]
+ data[offset] = data[offset + axis_mul_after]
+ data[offset + axis_mul_after] = temp_data[0]
+ temp_index[0] = output[offset]
+ output[offset] = output[offset + axis_mul_after]
+ output[offset + axis_mul_after] = temp_index[0]
+ ib.emit(tvm.make.Call(None, 'tvm_storage_sync',
+ tvm.convert(['shared']),
+ tvm.expr.Call.Intrinsic, None, 0))
+
+ return ib.get()
+
+@argsort.register(["cuda", "gpu"])
+def argsort_gpu(data, valid_count, axis=-1, is_ascend=1, dtype="float32", flag=0):
+ """Performs sorting along the given axis and returns an array of indicies
+ having same shape as an input array that index data in sorted order.
+
+ Parameters
+ ----------
+ data: tvm.Tensor
+ The input array.
+
+ valid_count : tvm.Tensor
+ The number of valid elements to be sorted.
+
+ axis : int
+ Axis along which to sort the input tensor.
+
+ is_ascend : boolean
+ Whether to sort in ascending or descending order.
+
+ flag : boolean
+ Whether this argsort is used in the nms operator.
+
+ Returns
+ -------
+ out : tvm.Tensor
+ The output of this function.
+ """
+ data_buf = api.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8)
+ if flag:
+ valid_count_buf = api.decl_buffer(valid_count.shape, valid_count.dtype,
+ "valid_count_buf", data_alignment=4)
+ out_buf = api.decl_buffer(data.shape, "int32", "out_buf", data_alignment=4)
+ out = tvm.extern([data.shape],
+ [data, valid_count],
+ lambda ins, outs: sort_nms_ir(
+ ins[0], ins[1], outs[0], axis, is_ascend),
+ dtype="int32",
+ in_buffers=[data_buf, valid_count_buf],
+ out_buffers=[out_buf],
+ name="argsort_nms_gpu",
+ tag="argsort_nms_gpu")
+ else:
+ out_buf = api.decl_buffer(data.shape, dtype, "out_buf", data_alignment=8)
+ out = tvm.extern([data.shape],
+ [data],
+ lambda ins, outs: sort_ir(
+ ins[0], outs[0], axis, is_ascend),
+ dtype=dtype,
+ in_buffers=[data_buf],
+ out_buffers=[out_buf],
+ name="argsort_gpu",
+ tag="argsort_gpu")
+ return out
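
A usage sketch for the registration above, mirroring the unit test added later in this patch (assumes a CUDA-enabled build; the shapes and values are illustrative):

.. code-block:: python

    import numpy as np
    import tvm
    import topi

    dshape = (1, 64)
    data = tvm.placeholder(dshape, name="data", dtype="float32")
    valid_count = tvm.placeholder((dshape[0],), name="valid_count", dtype="int32")

    with tvm.target.create("cuda"):
        out = topi.argsort(data, valid_count, axis=-1, is_ascend=False, flag=False)
        s = topi.generic.schedule_argsort(out)

    f = tvm.build(s, [data, valid_count, out], "cuda")
    ctx = tvm.gpu(0)
    np_data = np.random.uniform(size=dshape).astype("float32")
    tvm_data = tvm.nd.array(np_data, ctx)
    tvm_valid_count = tvm.nd.array(np.array([dshape[1]], dtype="int32"), ctx)
    tvm_out = tvm.nd.array(np.zeros(dshape, dtype="float32"), ctx)
    f(tvm_data, tvm_valid_count, tvm_out)  # tvm_out now holds descending-order indices
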
import tvm
from tvm import api
+from tvm.intrin import if_then_else, exp
import topi
center_w = (j + offset_w) * steps_w
for k in range(num_sizes + num_ratios - 1):
- w = tvm.if_then_else(k < num_sizes,
- size_ratio_concat[
- k] * in_height / in_width / 2.0,
- size_ratio_concat[0] * in_height / in_width *
- math.sqrt(size_ratio_concat[k + 1]) / 2.0)
- h = tvm.if_then_else(
+ w = if_then_else(k < num_sizes,
+ size_ratio_concat[k] * in_height / in_width / 2.0,
+ size_ratio_concat[0] * in_height / in_width *
+ math.sqrt(size_ratio_concat[k + 1]) / 2.0)
+ h = if_then_else(
k < num_sizes, size_ratio_concat[k] / 2.0,
size_ratio_concat[0] / math.sqrt(size_ratio_concat[k + 1]) / 2.0)
count = (i * in_width * (num_sizes + num_ratios - 1) +
out = topi.clip(out, 0, 1)
return out
-
-def transform_loc_pre(cls_prob, valid_count, temp_flag, temp_id, temp_score_out, threshold):
+def transform_loc_pre(cls_prob, valid_count, temp_valid_count, temp_cls_id, temp_score, threshold):
"""Low level IR routing for transform location data preparation.
Parameters
valid_count : Buffer
Buffer of number of valid output boxes.
- temp_flag : Buffer
+ temp_valid_count : Buffer
Output intermediate result buffer
- temp_id : Buffer
+ temp_cls_id : Buffer
Output intermediate result buffer
- temp_score_out : Buffer
+ temp_score : Buffer
Output buffer
threshold : float
num_classes = cls_prob.shape[1]
num_anchors = cls_prob.shape[2]
- max_threads = int(
- tvm.target.current_target(allow_none=False).max_num_threads)
ib = tvm.ir_builder.create()
- score = ib.buffer_ptr(temp_score_out)
- cls_id = ib.buffer_ptr(temp_id)
- flag = ib.buffer_ptr(temp_flag)
+
+ cls_prob = ib.buffer_ptr(cls_prob)
+ cls_id = ib.buffer_ptr(temp_cls_id)
+ valid_count = ib.buffer_ptr(valid_count)
+ temp_valid_count = ib.buffer_ptr(temp_valid_count)
+ score = ib.buffer_ptr(temp_score)
+
+ threshold = tvm.make.node("FloatImm", dtype="float32", value=threshold)
+
+ max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads)
+ nthread_tx = max_threads
+ nthread_bx = (batch_size * num_anchors) // max_threads + 1
tx = tvm.thread_axis("threadIdx.x")
bx = tvm.thread_axis("blockIdx.x")
- nthread_tx = max_threads
- nthread_bx = (batch_size * num_anchors * num_classes) // max_threads + 1
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(bx, "thread_extent", nthread_bx)
tid = bx * max_threads + tx
- p_cls_prob = ib.buffer_ptr(cls_prob)
- p_valid_count = ib.buffer_ptr(valid_count)
with ib.if_scope(tid < batch_size * num_anchors):
- n = tid / num_anchors # number of batches
- i = tid % num_anchors # number of anchors
- score[i] = -1.0
- cls_id[i] = 0
- p_valid_count[n] = 0
- with ib.for_range(0, num_classes-1, name="k") as k:
- temp = p_cls_prob[n * num_anchors * num_classes + (k + 1) * num_anchors + i]
- with ib.if_scope(temp > score[i]):
- cls_id[i] = k + 1
- score[i] = temp
- with ib.if_scope(tvm.all(cls_id[i] > 0, score[i] < threshold)):
- cls_id[i] = 0
- with ib.if_scope(cls_id[i] > 0):
- flag[i] = 1
+ i = tid / num_anchors
+ j = tid % num_anchors
+ valid_count[i] = 0
+ score[tid] = -1.0
+ cls_id[tid] = 0
+ with ib.for_range(0, num_classes - 1) as k:
+ temp = cls_prob[i * num_classes * num_anchors + (k + 1) * num_anchors + j]
+ cls_id[tid] = if_then_else(temp > score[tid], k + 1, cls_id[tid])
+ score[tid] = tvm.max(temp, score[tid])
+ with ib.if_scope(tvm.all(cls_id[tid] > 0, score[tid] < threshold)):
+ cls_id[tid] = 0
+ with ib.if_scope(cls_id[tid] > 0):
+ temp_valid_count[tid] = 1
with ib.else_scope():
- flag[i] = 0
+ temp_valid_count[tid] = 0
with ib.if_scope(tid < batch_size):
- with ib.for_range(0, num_anchors, name="k") as k:
+ with ib.for_range(0, num_anchors) as k:
with ib.if_scope(k > 0):
- flag[tid * num_anchors +
- k] += flag[tid * num_anchors + k - 1]
- p_valid_count[n] = flag[tid * num_anchors + num_anchors - 1]
+ temp_valid_count[tid * num_anchors + k] += \
+ temp_valid_count[tid * num_anchors + k - 1]
+ valid_count[i] = temp_valid_count[tid * num_anchors + num_anchors - 1]
- body = ib.get()
- return body
+ return ib.get()
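
What this first phase computes can be stated compactly in NumPy: pick the best foreground class per anchor, zero out anchors below the score threshold, and prefix-sum the keep flags so the second phase knows both the per-batch valid count and each kept anchor's output slot. A reference sketch under those assumptions (ref_transform_loc_pre is illustrative, not part of the patch):

.. code-block:: python

    import numpy as np

    def ref_transform_loc_pre(cls_prob, threshold):
        """cls_prob: (batch, num_classes, num_anchors); class 0 is background."""
        cls_id = np.argmax(cls_prob[:, 1:, :], axis=1) + 1   # best foreground class
        score = np.max(cls_prob[:, 1:, :], axis=1)           # its probability
        cls_id[score < threshold] = 0                        # drop low-confidence anchors
        flags = (cls_id > 0).astype("int32")
        valid_count = flags.sum(axis=1)                      # valid anchors per batch
        out_slot = np.cumsum(flags, axis=1) - 1              # output index of kept anchors
        return cls_id, score, valid_count, out_slot
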
-
-def transform_loc_ir(loc_pred, anchor, temp_flag, temp_id, temp_score_in, \
- out, clip, variances, batch_size, num_classes, num_anchors):
+def transform_loc_ir(loc_pred, anchor, temp_valid_count, temp_cls_id, temp_score, out, \
+ clip, variances, batch_size, num_anchors):
"""Low level IR routing for transform location in multibox_detection operator.
Parameters
anchor : Buffer
Buffer of prior anchor boxes.
- temp_flag : Buffer
+ temp_valid_count : Buffer
Intermediate result buffer.
- temp_id : Buffer
+ temp_cls_id : Buffer
Intermediate result buffer.
- temp_score_in : Buffer
+ temp_score : Buffer
Input buffer which stores intermediate results.
out : Buffer
batch_size : int
Batch size
- num_classes : int
- Number of classes
-
num_anchors : int
Number of anchors
ph = loc[loc_base_idx + 3]
ox = px * vx * aw + ax
oy = py * vy * ah + ay
- ow = tvm.exp(pw * vw) * aw / 2.0
- oh = tvm.exp(ph * vh) * ah / 2.0
- return tvm.if_then_else(clip, tvm.make.Max(0.0, tvm.make.Min(1.0, ox - ow)), ox - ow), \
- tvm.if_then_else(clip, tvm.make.Max(0.0, tvm.make.Min(1.0, oy - oh)), oy - oh), \
- tvm.if_then_else(clip, tvm.make.Max(0.0, tvm.make.Min(1.0, ox + ow)), ox + ow), \
- tvm.if_then_else(clip, tvm.make.Max(0.0, tvm.make.Min(1.0, oy + oh)), oy + oh)
-
- max_threads = int(
- tvm.target.current_target(allow_none=False).max_num_threads)
+ ow = exp(pw * vw) * aw / 2.0
+ oh = exp(ph * vh) * ah / 2.0
+ return tvm.if_then_else(clip, tvm.max(0.0, tvm.min(1.0, ox - ow)), ox - ow), \
+ tvm.if_then_else(clip, tvm.max(0.0, tvm.min(1.0, oy - oh)), oy - oh), \
+ tvm.if_then_else(clip, tvm.max(0.0, tvm.min(1.0, ox + ow)), ox + ow), \
+ tvm.if_then_else(clip, tvm.max(0.0, tvm.min(1.0, oy + oh)), oy + oh)
+
ib = tvm.ir_builder.create()
- score = ib.buffer_ptr(temp_score_in)
- cls_id = ib.buffer_ptr(temp_id)
- flag = ib.buffer_ptr(temp_flag)
+
+ loc_pred = ib.buffer_ptr(loc_pred)
+ anchor = ib.buffer_ptr(anchor)
+ temp_valid_count = ib.buffer_ptr(temp_valid_count)
+ cls_id = ib.buffer_ptr(temp_cls_id)
+ score = ib.buffer_ptr(temp_score)
+ out_loc = ib.buffer_ptr(out)
+
+ max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads)
+ nthread_tx = max_threads
+ nthread_bx = (batch_size * num_anchors) // max_threads + 1
tx = tvm.thread_axis("threadIdx.x")
bx = tvm.thread_axis("blockIdx.x")
- nthread_tx = max_threads
- nthread_bx = (batch_size * num_anchors * num_classes) // max_threads + 1
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(bx, "thread_extent", nthread_bx)
tid = bx * max_threads + tx
- p_loc_pred = ib.buffer_ptr(loc_pred)
- p_anchor = ib.buffer_ptr(anchor)
- p_out = ib.buffer_ptr(out)
with ib.if_scope(tid < batch_size * num_anchors):
- n = tid / num_anchors # number of batches
- i = tid % num_anchors # number of anchors
+ i = tid / num_anchors
+ j = tid % num_anchors
with ib.if_scope(cls_id[tid] > 0):
with ib.if_scope(tid == 0):
- out_base_idx = n * num_anchors * 6
+ out_base_idx = i * num_anchors * 6
+ out_loc[out_base_idx] = cls_id[tid] - 1.0
+ out_loc[out_base_idx + 1] = score[tid]
+ out_loc[out_base_idx + 2], out_loc[out_base_idx + 3], out_loc[out_base_idx + 4], \
+ out_loc[out_base_idx + 5] = transform_loc(loc_pred, tid * 4,
+ anchor, j * 4, clip, variances[0],
+ variances[1], variances[2],
+ variances[3])
with ib.else_scope():
- out_base_idx = n * num_anchors * 6 + flag[tid - 1] * 6
- p_out[out_base_idx] = cls_id[tid] - 1.0
- p_out[out_base_idx + 1] = score[tid]
- p_out[out_base_idx + 2], p_out[out_base_idx + 3], p_out[out_base_idx + 4], \
- p_out[out_base_idx + 5] = transform_loc(p_loc_pred, tid * 4,
- p_anchor, i*4, clip, variances[0],
- variances[1], variances[2], variances[3])
+ out_base_idx = i * num_anchors * 6 + temp_valid_count[tid - 1] * 6
+ out_loc[out_base_idx] = cls_id[tid] - 1.0
+ out_loc[out_base_idx + 1] = score[tid]
+ out_loc[out_base_idx + 2], out_loc[out_base_idx + 3], out_loc[out_base_idx + 4], \
+ out_loc[out_base_idx + 5] = transform_loc(loc_pred, tid * 4,
+ anchor, j * 4, clip, variances[0],
+ variances[1], variances[2],
+ variances[3])
- body = ib.get()
- return body
+ return ib.get()
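
The transform_loc helper above is the usual SSD center-size decoding: the anchor's center and size are shifted by the predicted offsets scaled by the variances, and the width/height offsets pass through an exp. A NumPy sketch of the same formula (ref_transform_loc is illustrative, not part of the patch; the (0.1, 0.1, 0.2, 0.2) defaults are the conventional SSD variances):

.. code-block:: python

    import numpy as np

    def ref_transform_loc(loc, anchor, variances=(0.1, 0.1, 0.2, 0.2), clip=True):
        """loc = (dx, dy, dw, dh) offsets; anchor = (x1, y1, x2, y2) corner box."""
        al, at, ar, ab = anchor
        aw, ah = ar - al, ab - at
        ax, ay = (al + ar) / 2.0, (at + ab) / 2.0
        px, py, pw, ph = loc
        vx, vy, vw, vh = variances
        ox = px * vx * aw + ax                  # decoded center
        oy = py * vy * ah + ay
        ow = np.exp(pw * vw) * aw / 2.0         # decoded half width / height
        oh = np.exp(ph * vh) * ah / 2.0
        box = np.array([ox - ow, oy - oh, ox + ow, oy + oh])
        return np.clip(box, 0.0, 1.0) if clip else box
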
@multibox_transform_loc.register(["cuda", "gpu"])
1-D tensor with shape (batch_size,), number of valid anchor boxes.
"""
batch_size = cls_prob.shape[0]
- num_classes = cls_prob.shape[1]
num_anchors = cls_prob.shape[2]
oshape = (batch_size, num_anchors, 6)
# Define data alignment for intermediate buffer
valid_count_dtype = "int32"
+ out_loc_dtype = loc_pred.dtype
+
valid_count_buf = api.decl_buffer((batch_size,), valid_count_dtype,
"valid_count_buf", data_alignment=4)
- out_buf = api.decl_buffer(
- oshape, cls_prob.dtype, "out_buf", data_alignment=8)
- size = num_anchors
- temp_flag_buf = api.decl_buffer(
- (size,), valid_count_dtype, "flag", data_alignment=8)
- temp_id_buf = api.decl_buffer(
- (size,), valid_count_dtype, "cls_id", data_alignment=8)
+ loc_pred_buf = api.decl_buffer(loc_pred.shape, loc_pred.dtype,
+ "loc_pred_buf", data_alignment=8)
+ anchor_buf = api.decl_buffer(anchor.shape, anchor.dtype,
+ "anchor_buf", data_alignment=8)
+
+ temp_valid_count_buf = api.decl_buffer(
+ (batch_size, num_anchors,), valid_count_dtype, "temp_valid_count", data_alignment=8)
+ temp_cls_id_buf = api.decl_buffer(
+ (batch_size, num_anchors,), valid_count_dtype, "temp_cls_id", data_alignment=8)
temp_score_buf = api.decl_buffer(
- (size,), cls_prob.dtype, "score", data_alignment=8)
+ (batch_size, num_anchors,), cls_prob.dtype, "temp_score", data_alignment=8)
- valid_count, temp_flag, temp_id, temp_score = \
- tvm.extern([(batch_size,), (size,), (size,), (size,)],
- [cls_prob],
+ valid_count, temp_valid_count, temp_cls_id, temp_score = \
+ tvm.extern([(batch_size,), (batch_size, num_anchors,), (batch_size, num_anchors,), \
+ (batch_size, num_anchors,)], [cls_prob],
lambda ins, outs: transform_loc_pre(
ins[0], outs[0], outs[1], outs[2], outs[3], threshold),
- dtype=[valid_count_dtype,
- valid_count_dtype, valid_count_dtype, cls_prob.dtype],
- out_buffers=[valid_count_buf,
- temp_flag_buf, temp_id_buf, temp_score_buf],
- tag="multibox_transform_loc_first_step")
+ dtype=[valid_count_dtype, valid_count_dtype, valid_count_dtype, cls_prob.dtype],
+ out_buffers=[valid_count_buf, temp_valid_count_buf, \
+ temp_cls_id_buf, temp_score_buf],
+ tag="multibox_transform_loc_phase_one")
- out = \
+ out_loc = \
tvm.extern([oshape],
- [loc_pred, anchor, temp_flag, temp_id, temp_score],
+ [loc_pred, anchor, temp_valid_count, temp_cls_id, temp_score],
lambda ins, outs: transform_loc_ir(
- ins[0], ins[1], ins[2], ins[3], ins[4], outs[0], clip, \
- variances, batch_size, num_classes, num_anchors),
- dtype=[cls_prob.dtype],
- out_buffers=[out_buf],
+ ins[0], ins[1], ins[2], ins[3], ins[4], outs[0], clip, variances, \
+ batch_size, num_anchors),
+ in_buffers=[loc_pred_buf, anchor_buf, temp_valid_count_buf, \
+ temp_cls_id_buf, temp_score_buf],
+ dtype=[out_loc_dtype],
tag="multibox_transform_loc")
- return [out, valid_count]
+
+ return [out_loc, valid_count]
@multibox_detection.register(["cuda", "gpu"])
"""
inter_out = multibox_transform_loc(cls_prob, loc_pred, anchor,
clip, threshold, variances)
- out = non_max_suppression(
- inter_out[0], inter_out[1], nms_threshold, force_suppress, nms_topk)
+ out = non_max_suppression(inter_out[0], inter_out[1], max_output_size=-1,
+ iou_threshold=nms_threshold, force_suppress=force_suppress,
+ top_k=nms_topk, return_indices=False)
return out
def traverse(op):
"""inline all one-to-one-mapping operators except the last stage (output)"""
- if "nms" in op.tag:
- sort = op.input_tensors[1]
+ if op.tag in ["nms", "invalid_to_bottom"]:
+ if op.tag == "nms":
+ sort = op.input_tensors[1]
+ else:
+ out = op.input_tensors[0]
+ sort = s[out].op.input_tensors[1]
score = s[sort].op.input_tensors[0]
fused = s[score].fuse(*s[score].op.axis)
- num_thread = tvm.target.current_target(allow_none=False).max_num_threads
+ num_thread = int(tvm.target.current_target(allow_none=False).max_num_threads)
bx, tx = s[score].split(fused, factor=num_thread)
s[score].bind(bx, tvm.thread_axis("blockIdx.x"))
s[score].bind(tx, tvm.thread_axis("threadIdx.x"))
The computation schedule for the op.
"""
return _default_schedule(outs)
+
+@generic.schedule_argsort.register(["cuda", "gpu"])
+def schedule_argsort(outs):
+ """Schedule for argsort operator.
+
+ Parameters
+ ----------
+ outs: Array of Tensor
+ The computation graph description of argsort
+ in the format of an array of tensors.
+
+ Returns
+ -------
+ s: Schedule
+ The computation schedule for the op.
+ """
+ outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
+ s = tvm.create_schedule([x.op for x in outs])
+ scheduled_ops = []
+ from .injective import _schedule_injective
+ def traverse(op):
+ for tensor in op.input_tensors:
+ if tensor.op.input_tensors and tensor.op not in scheduled_ops:
+ traverse(tensor.op)
+ scheduled_ops.append(op)
+ traverse(outs[0].op)
+ return s
from .injective import *
from .extern import *
from .vision import *
+from .sort import *
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, no-member
+"""Generic vision operators"""
+from __future__ import absolute_import as _abs
+import tvm
+from .vision import _default_schedule
+
+@tvm.target.generic_func
+def schedule_argsort(outs):
+ """Schedule for argsort operator.
+
+ Parameters
+ ----------
+ outs: Array of Tensor
+ The computation graph description of argsort
+ in the format of an array of tensors.
+
+ Returns
+ -------
+ s: Schedule
+ The computation schedule for the op.
+ """
+ return _default_schedule(outs, False)
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=too-many-arguments
+"""Argsort operator"""
+import tvm
+from tvm import api
+
+@tvm.target.generic_func
+def argsort(data, valid_count, axis=-1, is_ascend=1, dtype="float32", flag=0):
+ """Performs sorting along the given axis and returns an array
+ of indices having the same shape as an input array that index
+ data in sorted order.
+
+ Parameters
+ ----------
+ data : tvm.Tensor
+ The input tensor.
+
+ valid_count : tvm.Tensor
+ 1-D tensor of the number of valid boxes, only used for SSD.
+
+ axis : optional, int
+ Axis along which to sort the input tensor.
+ By default the flattened array is used.
+
+ is_ascend : optional, boolean
+ Whether to sort in ascending or descending order.
+
+ dtype : optional, string
+ DType of the output indices.
+
+ flag : optional, boolean
+ Whether this argsort is used in the nms operator and valid_count should be used.
+
+ Returns
+ -------
+ out : tvm.Tensor
+ Sorted index tensor.
+
+ Example
+ --------
+ .. code-block:: python
+
+ # An example to use argsort
+ dshape = (1, 5, 6)
+ data = tvm.placeholder(dshape, name="data")
+ valid_count = tvm.placeholder((dshape[0],), dtype="int32", name="valid_count")
+ axis = 0
+ is_ascend = False
+ flag = False
+ out = argsort(data, valid_count, axis=axis, is_ascend=is_ascend, flag=flag)
+ np_data = np.random.uniform(size=dshape)
+ np_valid_count = np.array([4])
+ s = topi.generic.schedule_argsort(out)
+ f = tvm.build(s, [data, valid_count, out], "llvm")
+ ctx = tvm.cpu()
+ tvm_data = tvm.nd.array(np_data, ctx)
+ tvm_valid_count = tvm.nd.array(np_valid_count, ctx)
+ tvm_out = tvm.nd.array(np.zeros(dshape, dtype=data.dtype), ctx)
+ f(tvm_data, tvm_valid_count, tvm_out)
+ """
+ data_buf = api.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8)
+ if flag:
+ valid_count_buf = api.decl_buffer(valid_count.shape, valid_count.dtype,
+ "valid_count_buf", data_alignment=4)
+ out_buf = api.decl_buffer(data.shape, "int32", "out_buf", data_alignment=8)
+ out = \
+ tvm.extern(data.shape,
+ [data, valid_count],
+ lambda ins, outs: tvm.call_packed(
+ "tvm.contrib.sort.argsort_nms", ins[0], ins[1],
+ outs[0], axis, is_ascend),
+ dtype="int32",
+ in_buffers=[data_buf, valid_count_buf],
+ out_buffers=out_buf,
+ name="argsort_nms_cpu",
+ tag="argsort_nms_cpu")
+ else:
+ out_buf = api.decl_buffer(data.shape, dtype, "out_buf", data_alignment=8)
+ out = \
+ tvm.extern(data.shape,
+ [data],
+ lambda ins, outs: tvm.call_packed(
+ "tvm.contrib.sort.argsort", ins[0],
+ outs[0], axis, is_ascend),
+ dtype=dtype,
+ in_buffers=[data_buf],
+ out_buffers=out_buf,
+ name="argsort_cpu",
+ tag="argsort_cpu")
+ return out
"""Non-maximum suppression operator"""
import tvm
-from tvm import api, hybrid
+from tvm import hybrid
+from ..sort import argsort
@hybrid.script
def hybrid_rearrange_out(data):
@hybrid.script
def hybrid_nms(data, sorted_index, valid_count,
max_output_size, iou_threshold, force_suppress,
- top_k, id_index):
+ top_k, coord_start, id_index):
"""Hybrid routing for non-maximum suppression.
Parameters
top_k : tvm.const
Keep maximum top k detections before nms, -1 for no limit.
+ coord_start : tvm.const
+ Start index of the consecutive 4 coordinates.
+
id_index : tvm.const
index of the class categories, -1 to disable.
batch_idx = i
box_a_idx = j
box_b_idx = k
- box_start_idx = 2
+ box_start_idx = coord_start
a_t = output[batch_idx, box_a_idx, box_start_idx + 1]
a_b = output[batch_idx, box_a_idx, box_start_idx + 3]
a_l = output[batch_idx, box_a_idx, box_start_idx]
@tvm.target.generic_func
def non_max_suppression(data, valid_count, max_output_size=-1,
iou_threshold=0.5, force_suppress=False, top_k=-1,
- id_index=0, return_indices=True, invalid_to_bottom=False):
+ coord_start=2, score_index=1, id_index=0,
+ return_indices=True, invalid_to_bottom=False):
"""Non-maximum suppression operator for object detection.
Parameters
top_k : optional, int
Keep maximum top k detections before nms, -1 for no limit.
+ coord_start : optional, int
+ Start index of the consecutive 4 coordinates.
+
+ score_index : optional, int
+ Index of the scores/confidence of boxes.
+
id_index : optional, int
index of the class categories, -1 to disable.
"""
batch_size = data.shape[0]
num_anchors = data.shape[1]
- valid_count_dtype = "int32"
- valid_count_buf = api.decl_buffer(valid_count.shape, valid_count_dtype,
- "valid_count_buf", data_alignment=4)
- score_axis = 1
+ score_axis = score_index
score_shape = (batch_size, num_anchors)
score_tensor = tvm.compute(score_shape, lambda i, j: data[i, j, score_axis])
- score_tensor_buf = api.decl_buffer(score_tensor.shape, data.dtype,
- "score_tensor_buf", data_alignment=8)
- sort_tensor_dtype = "int32"
- sort_tensor_buf = api.decl_buffer(score_shape, sort_tensor_dtype,
- "sort_tensor_buf", data_alignment=8)
- sort_tensor = \
- tvm.extern(score_shape,
- [score_tensor, valid_count],
- lambda ins, outs: tvm.call_packed(
- "tvm.contrib.sort.argsort", ins[0], ins[1],
- outs[0], score_axis, True),
- dtype=sort_tensor_dtype,
- in_buffers=[score_tensor_buf, valid_count_buf],
- out_buffers=sort_tensor_buf,
- name="nms_sort")
+ sort_tensor = argsort(score_tensor, valid_count=valid_count, axis=1, is_ascend=False, flag=True)
out, box_indices = hybrid_nms(data, sort_tensor, valid_count,
tvm.const(max_output_size, dtype="int32"),
tvm.const(iou_threshold, dtype="float32"),
tvm.const(force_suppress, dtype="bool"),
tvm.const(top_k, dtype="int32"),
+ tvm.const(coord_start, dtype="int32"),
tvm.const(id_index, dtype="int32"))
if not return_indices and invalid_to_bottom:
out = hybrid_rearrange_out(out)
"""
inter_out = multibox_transform_loc(cls_prob, loc_pred, anchor,
clip, threshold, variances)
- out = non_max_suppression(inter_out[0], inter_out[1], -1,
- nms_threshold, force_suppress, nms_topk,
- return_indices=False)
+ out = non_max_suppression(inter_out[0], inter_out[1], max_output_size=-1,
+ iou_threshold=nms_threshold, force_suppress=force_suppress,
+ top_k=nms_topk, return_indices=False)
return out
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Test code for vision package"""
+from __future__ import print_function
+import math
+import numpy as np
+import tvm
+import topi
+import topi.testing
+
+from tvm.contrib.pickle_memoize import memoize
+from topi.util import get_const_tuple
+from topi import argsort
+
+def test_argsort():
+ dshape = (1, 8)
+ valid_count_shape = (2,)
+ data = tvm.placeholder(dshape, name="data", dtype="float32")
+ valid_count = tvm.placeholder((dshape[0],), dtype="int32", name="valid_count")
+ np_data = np.random.rand(dshape[0], dshape[1]).astype(data.dtype)
+ np_valid_count = np.array([4]).astype(valid_count.dtype)
+ np_result = np.argsort(-np_data)
+ def check_device(device):
+ ctx = tvm.context(device, 0)
+ if not ctx.exist:
+ print("Skip because %s is not enabled" % device)
+ return
+ print("Running on target: %s" % device)
+ with tvm.target.create(device):
+ out = argsort(data, valid_count, axis=-1, is_ascend=False, flag=False)
+ s = topi.generic.schedule_argsort(out)
+
+ tvm_data = tvm.nd.array(np_data, ctx)
+ tvm_valid_count = tvm.nd.array(np_valid_count, ctx)
+ tvm_out = tvm.nd.array(np.zeros(dshape, dtype="float32"), ctx)
+ f = tvm.build(s, [data, valid_count, out], device)
+ f(tvm_data, tvm_valid_count, tvm_out)
+ tvm.testing.assert_allclose(tvm_out.asnumpy(), np_result.astype("float32"), rtol=1e0)
+
+ for device in ['llvm', 'cuda', 'opencl']:
+ check_device(device)
+
+
+if __name__ == "__main__":
+ test_argsort()
tvm.testing.assert_allclose(tvm_out1.asnumpy(), np_out1, rtol=1e-3)
tvm.testing.assert_allclose(tvm_out2.asnumpy(), np_out2, rtol=1e-3)
- for device in ['llvm']:
+ for device in ['llvm', 'cuda', 'opencl']:
check_device(device)
f(tvm_data, tvm_valid_count, tvm_indices_out)
tvm.testing.assert_allclose(tvm_indices_out.asnumpy(), np_indices_result, rtol=1e-4)
- for device in ['llvm']:
+ for device in ['llvm', 'cuda', 'opencl']:
check_device(device)
f(tvm_cls_prob, tvm_loc_preds, tvm_anchors, tvm_out)
tvm.testing.assert_allclose(tvm_out.asnumpy(), expected_np_out, rtol=1e-4)
- for device in ['llvm', 'opencl']:
+ for device in ['llvm', 'opencl', 'cuda']:
check_device(device)
f(tvm_a, tvm_rois, tvm_b)
tvm.testing.assert_allclose(tvm_b.asnumpy(), b_np, rtol=1e-3)
- for device in ['llvm', 'cuda']:
+ for device in ['llvm', 'cuda', 'opencl']:
check_device(device)
Deploy Single Shot Multibox Detector(SSD) model
===============================================
**Author**: `Yao Wang <https://github.com/kevinthesun>`_
+`Leyuan Wang <https://github.com/Laurawly>`_
This article is an introductory tutorial to deploy SSD models with TVM.
We will use GluonCV pre-trained SSD model and convert it to Relay IR
# ------------------------------
# .. note::
#
-# Currently we support compiling SSD on CPU only.
-# GPU support is in progress.
+# We support compiling SSD on both CPUs and GPUs now.
#
# To get best inference performance on CPU, change
# target argument according to your device and
# follow the :ref:`tune_relay_x86` to tune x86 CPU and
# :ref:`tune_relay_arm` for arm cpu.
#
+# To get best performance for SSD on Intel graphics,
+# change target argument to 'opencl -device=intel_graphics'
+#
# SSD with VGG as body network is not supported yet since
# x86 conv2d schedule doesn't support dilation.
supported_model = [
- 'ssd_512_resnet18_v1_voc',
- 'ssd_512_resnet18_v1_coco',
'ssd_512_resnet50_v1_voc',
'ssd_512_resnet50_v1_coco',
'ssd_512_resnet101_v2_voc',
- 'ssd_512_mobilenet1_0_voc',
- 'ssd_512_mobilenet1_0_coco',
+ 'ssd_512_mobilenet1.0_voc',
+ 'ssd_512_mobilenet1.0_coco',
]
-model_name = "ssd_512_resnet50_v1_voc"
+model_name = supported_model[0]
dshape = (1, 3, 512, 512)
-dtype = "float32"
target_list = ctx_list()
######################################################################
block = model_zoo.get_model(model_name, pretrained=True)
-def compile(target):
+def build(target):
net, params = relay.frontend.from_mxnet(block, {"data": dshape})
with relay.build_config(opt_level=3):
graph, lib, params = relay.build(net, target, params=params)
return class_IDs, scores, bounding_boxs
for target, ctx in target_list:
- if target == "cuda":
- print("GPU not supported yet, skip.")
- continue
- graph, lib, params = compile(target)
+ graph, lib, params = build(target)
class_IDs, scores, bounding_boxs = run(graph, lib, params, ctx)
######################################################################