From aa808570dbee19a178130c0f7d9b397ff7f51f0e Mon Sep 17 00:00:00 2001 From: Yong Wu Date: Wed, 10 Jun 2020 00:48:33 +0800 Subject: [PATCH] [TOPI][Relay][OP] support dynamic NMS(Non Maximum Suppression), symbolic begin, end, and strides for strided_slice (#4312) * [TOPI][Relay][OP] Dynamic NMS and strided_slice * Incorporate comments * fix nnvm compatibility issues * fix InferCorrectLayout * Minor fix * fix for fuse * Workaround to pass batch_size into hybrid function to handle dynamic shape * Seperate rearrange * fix lint * fix ci, comments * change attr to Optional * clang format * remove empty lines * partial ignore for end of strided_slice * pylint * add out_indices for gpu get_valid_counts * change to slice_mode * clang-format, fix comments * fix comment * change slice_mode to string * fix CI * update docstring Co-authored-by: Yao Wang --- include/tvm/relay/attrs/transform.h | 18 +- include/tvm/relay/attrs/vision.h | 4 +- python/tvm/relay/_parser.py | 2 +- python/tvm/relay/frontend/keras.py | 4 +- python/tvm/relay/frontend/mxnet.py | 25 +- python/tvm/relay/frontend/onnx.py | 13 +- python/tvm/relay/frontend/pytorch.py | 16 +- python/tvm/relay/frontend/tensorflow.py | 84 +++++- python/tvm/relay/frontend/tflite.py | 2 +- python/tvm/relay/op/_tensor_grad.py | 6 +- python/tvm/relay/op/_transform.py | 69 +++++ python/tvm/relay/op/strategy/generic.py | 10 +- python/tvm/relay/op/transform.py | 37 ++- python/tvm/relay/op/vision/_vision.py | 37 +++ python/tvm/relay/op/vision/nms.py | 49 +++- src/relay/analysis/util.cc | 7 + src/relay/op/tensor/transform.cc | 305 ++++++++++++++------- src/relay/op/vision/nms.cc | 25 +- src/relay/transforms/combine_parallel_conv2d.cc | 17 +- src/relay/transforms/pattern_util.h | 2 +- tests/python/frontend/tensorflow/test_debugging.py | 7 +- tests/python/frontend/tensorflow/test_forward.py | 28 +- tests/python/relay/test_any.py | 51 +++- tests/python/relay/test_op_level2.py | 2 +- tests/python/relay/test_op_level4.py | 64 +++-- tests/python/relay/test_op_level5.py | 60 ++-- tests/python/relay/test_pass_alter_op_layout.py | 41 ++- .../relay/test_pass_combine_parallel_conv2d.py | 58 +++- topi/include/topi/transform.h | 24 +- topi/python/topi/cuda/conv2d_alter_op.py | 3 +- topi/python/topi/cuda/nms.py | 32 ++- topi/python/topi/cuda/ssd/multibox.py | 2 +- topi/python/topi/image/dilation2d.py | 12 +- topi/python/topi/math.py | 12 +- topi/python/topi/sort.py | 4 +- topi/python/topi/testing/strided_slice_python.py | 32 ++- topi/python/topi/transform.py | 17 +- topi/python/topi/vision/nms.py | 226 ++++++++++++--- topi/python/topi/vision/ssd/multibox.py | 2 +- topi/python/topi/x86/conv2d_alter_op.py | 4 +- topi/python/topi/x86/conv3d.py | 6 +- topi/src/transform.cc | 2 +- topi/tests/python/test_topi_vision.py | 54 ++-- 43 files changed, 1123 insertions(+), 352 deletions(-) diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index 03605ee..d709ff2 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -210,14 +210,24 @@ struct SplitAttrs : public tvm::AttrsNode { /*! 
\brief Attributes for StridedSlice operator */ struct StridedSliceAttrs : public tvm::AttrsNode { - Array begin; - Array end; - Array strides; + Optional> begin; + Optional> end; + Optional> strides; + std::string slice_mode; TVM_DECLARE_ATTRS(StridedSliceAttrs, "relay.attrs.StridedSliceAttrs") { TVM_ATTR_FIELD(begin).describe("Indices for begin of slice, begin index is also inclusive"); TVM_ATTR_FIELD(end).describe("Indices for end of slice, end index is exclusive"); - TVM_ATTR_FIELD(strides).set_default(Array({})).describe("Stride values of the slice"); + TVM_ATTR_FIELD(strides).describe( + "Stride values of the slice, a stride can be negative, which causes a reverse slice."); + TVM_ATTR_FIELD(slice_mode) + .set_default("end") + .describe( + "The slice mode [end, size]." + "end - The default slice mode, ending indices for the slice." + "size - The input strides will be ignored, input end in this mode indicates the size" + "of a slice starting at the location specified by begin. If end[i] is -1," + "all remaining elements in that dimension are included in the slice"); } }; diff --git a/include/tvm/relay/attrs/vision.h b/include/tvm/relay/attrs/vision.h index e7e24b1..550e24b 100644 --- a/include/tvm/relay/attrs/vision.h +++ b/include/tvm/relay/attrs/vision.h @@ -104,7 +104,9 @@ struct NonMaximumSuppressionAttrs : public tvm::AttrsNode 4 else 0.0 + + # Generate data with shape (1, num_anchors, 5) + scores = AttrCvt(op_name="expand_dims", + ignores=['T_threshold'], + extras={'axis': -1, 'num_newaxis': 1})([inputs[1]], attr) + data = get_relay_op('concatenate')([scores, inputs[0]], -1) + data = get_relay_op('expand_dims')(data, 0, 1) + + # reason why using get_valid_counts is for inference performance + ct, data, indices = get_relay_op('get_valid_counts')(data, + score_threshold=score_threshold, + id_index=-1, + score_index=0) + # TensorFlow NMS doesn't have parameter top_k + top_k = -1 + # TF doesn't have class id for nms input + score_index = 0 + nms_ret = get_relay_op('non_max_suppression')(data=data, + valid_count=ct, + indices=indices, + max_output_size=max_output_size, + iou_threshold=iou_threshold, + force_suppress=True, + top_k=top_k, + coord_start=1, + score_index=score_index, + id_index=-1, + return_indices=True, + invalid_to_bottom=False) + + # squeeze it, TF NMS is not batched + size = get_relay_op("squeeze")(nms_ret[1], axis=[1]) + data_slice = get_relay_op("squeeze")(nms_ret[0], axis=[0]) + + # slice to get the dynamic result + ret = get_relay_op("strided_slice")(data_slice, begin=_expr.const([0]), + end=size, slice_mode="size") + return ret + return _impl + def _decode_image(): def _impl(inputs, attr, params, mod): # Image decode wrapper: Expecting user to feed decoded input to next layer drop this layer. 
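The TensorFlow converters in this file (the NMS wrapper above and the _slice handler below) lean on the new slice_mode="size" behaviour of strided_slice, where end[i] is read as a length rather than a stop index, strides are ignored, and -1 means "take everything remaining in that axis". A minimal NumPy sketch of that semantics follows; the shapes and values are illustrative only and are not part of the patch:

    import numpy as np

    data = np.arange(24).reshape(2, 3, 4)
    begin, size = [0, 1, 0], [-1, 2, 3]
    # slice_mode="size": a size of -1 keeps the rest of the axis,
    # otherwise the slice covers data[begin[i] : begin[i] + size[i]]
    stops = [dim if s == -1 else b + s
             for b, s, dim in zip(begin, size, data.shape)]
    out = data[tuple(slice(b, e) for b, e in zip(begin, stops))]
    assert out.shape == (2, 2, 3)
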
@@ -1119,25 +1175,20 @@ def _slice(): try: begin = _get_list_param(params, inputs[1]) except (IndexError, KeyError, AttributeError): - begin = _infer_value(inputs[1], params).asnumpy().tolist()[0] + # Handle symbolic begin + try: + begin = _infer_value(inputs[1], params).asnumpy().tolist() + except Exception: + begin = inputs[1] try: size = _get_list_param(params, inputs[2]) except (IndexError, KeyError, AttributeError): # Handle symbolic size try: - size = _infer_value(inputs[2], params).asnumpy().tolist()[0] + size = _infer_value(inputs[2], params).asnumpy().tolist() except Exception: size = inputs[2] - data_shape = _infer_shape(inputs[0], mod) - data_dim = len(data_shape) - end = size - if not isinstance(end, (_expr.Call, _expr.Var)): - for i in range(data_dim): - if size[i] == -1: - end[i] = data_shape[i] - else: - end[i] += begin[i] - return _op.strided_slice(inputs[0], begin=begin, end=end) + return _op.strided_slice(inputs[0], begin=begin, end=size, slice_mode="size") return _impl @@ -1466,8 +1517,11 @@ def _stridedSlice(): fshape_indices = None if begin_mask or end_mask or ellipsis_mask or new_axis_mask or shrink_axis_mask: begin, end, stride, fshape_indices = _transform_mask(stride_dim, ellipsis_mask) - out = _op.strided_slice(inputs[0], begin=begin, end=end, strides=stride) - out_shape = _infer_shape(out, mod) + out = _op.strided_slice(inputs[0], + begin=begin, + end=end, + strides=stride) + out_shape = _infer_shape(out, mod=mod) if not fshape_indices: fshape_indices = range(len(out_shape)) @@ -2026,6 +2080,8 @@ _convert_map = { 'Mod' : _elemwise('mod'), 'Mul' : _elemwise('multiply'), 'Neg' : AttrCvt('negative'), + 'NonMaxSuppressionV2' : _nms(), + 'NonMaxSuppressionV3' : _nms(), 'NoOp' : _no_op(), 'NotEqual' : _broadcast('not_equal'), 'OneHot' : _one_hot(), diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 08ea715..113f764 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -2544,7 +2544,7 @@ class OperatorConverter(object): ret = _op.vision.multibox_transform_loc(cls_pred, loc_prob, anchor_expr, **multibox_transform_loc_attrs) - ret = _op.vision.non_max_suppression(ret[0], ret[1], **non_max_suppression_attrs) + ret = _op.vision.non_max_suppression(ret[0], ret[1], ret[1], **non_max_suppression_attrs) ret = _op.vision.get_valid_counts(ret, 0) valid_count = ret[0] # keep only the top 'max_detections' rows diff --git a/python/tvm/relay/op/_tensor_grad.py b/python/tvm/relay/op/_tensor_grad.py index 61488f1..0deb87a 100644 --- a/python/tvm/relay/op/_tensor_grad.py +++ b/python/tvm/relay/op/_tensor_grad.py @@ -390,8 +390,10 @@ def conv2d_grad(orig, grad): assert padded_weight_grad_h >= filter_h assert padded_weight_grad_w >= filter_w if padded_weight_grad_h > filter_h or padded_weight_grad_w > filter_w: - backward_weight = strided_slice(backward_weight, begin=[0, 0, 0, 0], - end=[None, None, filter_h, filter_w]) + backward_weight = strided_slice(backward_weight, + begin=const([0, 0, 0, 0], dtype="int64"), + end=const([out_channel, in_channel // attrs.groups, + filter_h, filter_w], dtype="int64")) return [backward_data, backward_weight] diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index c99be3c..1d9253f 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -100,9 +100,78 @@ def _arange_shape_func(start, stop, step): @_reg.register_shape_func("arange", True) def arange_shape_func(attrs, inputs, _): + """ + Shape func for arange + 
""" return [_arange_shape_func(*inputs)] @script +def _strided_slice_shape_func_input_data(data, begin, end, strides, + slice_mode): + ndim = len(data.shape) + out = output_tensor((ndim,), "int64") + for i in const_range(ndim): + cbegin = 0 + cend = data.shape[i] + cstride = 1 + if strides.shape[0] > i: + cstride = strides[i] + if begin.shape[0] > i: + cbegin = begin[i] + if end.shape[0] <= i: + cend = data.shape[i] + elif slice_mode != 0: + cstride = 1 + if end[i] < 0: + cend = data.shape[i] + else: + cend = cbegin + end[i] + else: + cend = end[i] + assert cstride != 0, "Strides can't be zero." + out[i] = int64(ceil_div((int64(cend) - int64(cbegin)), int64(cstride))) + return out + +@script +def _strided_slice_shape_func_input_shape(data_shape, begin, end, strides, slice_mode): + ndim = data_shape.shape[0] + out = output_tensor((ndim,), "int64") + for i in const_range(ndim): + cbegin = int64(0) + cend = int64(data_shape[i]) + cstride = int64(1) + if len(strides) > i: + cstride = int64(strides[i]) + if len(begin) > i: + cbegin = int64(begin[i]) + if len(end) <= i: + cend = int64(data_shape[i]) + elif slice_mode != 0: + cstride = int64(1) + if end[i] < 0: + cend = int64(data_shape[i]) + else: + cend = cbegin + int64(end[i]) + else: + cend = int64(end[i]) + assert cstride != 0, "Strides can't be zero." + out[i] = int64(ceil_div((int64(cend) - int64(cbegin)), int64(cstride))) + return out + + +@_reg.register_shape_func("strided_slice", True) +def strided_slice_shape_func(attrs, inputs, _): + """ + Shape func for strided_slice + """ + slice_mode = convert(0 if attrs.slice_mode == "end" else 1) + # data independent if begin, end and strides exist + if attrs.begin and attrs.end and attrs.strides: + return [_strided_slice_shape_func_input_shape(inputs[0], attrs.begin, attrs.end, + attrs.strides, slice_mode)] + return [_strided_slice_shape_func_input_data(*inputs, slice_mode)] + +@script def _concatenate_shape_func(inputs, axis): ndim = inputs[0].shape[0] out = output_tensor((ndim,), "int64") diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py index 99439af..de808d1 100644 --- a/python/tvm/relay/op/strategy/generic.py +++ b/python/tvm/relay/op/strategy/generic.py @@ -695,9 +695,13 @@ def wrap_compute_nms(topi_compute): score_index = get_const_int(attrs.score_index) id_index = get_const_int(attrs.id_index) invalid_to_bottom = bool(get_const_int(attrs.invalid_to_bottom)) - return [topi_compute(inputs[0], inputs[1], max_output_size, iou_threshold, - force_suppress, top_k, coord_start, score_index, - id_index, return_indices, invalid_to_bottom)] + if return_indices: + return topi_compute(inputs[0], inputs[1], inputs[2], max_output_size, iou_threshold, + force_suppress, top_k, coord_start, score_index, id_index, + return_indices, invalid_to_bottom) + return [topi_compute(inputs[0], inputs[1], inputs[2], max_output_size, iou_threshold, + force_suppress, top_k, coord_start, score_index, id_index, + return_indices, invalid_to_bottom)] return _compute_nms @override_native_generic_func("non_max_suppression_strategy") diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index a3c6517..1ee2bdb 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -611,7 +611,7 @@ def split(data, indices_or_sections, axis=0): return TupleWrapper(_make.split(data, indices_or_sections, axis), ret_size) -def strided_slice(data, begin, end, strides=None): +def strided_slice(data, begin, end, strides=None, slice_mode="end"): 
"""Strided slice of an array. Parameters @@ -619,23 +619,36 @@ def strided_slice(data, begin, end, strides=None): data : relay.Expr The source array to be sliced. - begin: list of int + begin : relay.Expr, Tuple[int], or List[int] The indices to begin with in the slicing. - end: list of int + end : relay.Expr, Tuple[int], or List[int] Indices indicating end of the slice. - strides: list of int, optional + strides : relay.Expr, Tuple[int], or List[int], optional Specifies the stride values, it can be negative in that case, the input tensor will be reversed in that particular axis. + slice_mode : str, optional + The slice mode [end, size]. + end: The ending indices for the slice [default]. + size: The input strides will be ignored, input end in this mode indicates + the size of a slice starting at the location specified by begin. If end[i] + is -1, all remaining elements in that dimension are included in the slice. + Returns ------- ret : relay.Expr The computed result. """ - strides = strides or [] - return _make.strided_slice(data, list(begin), list(end), list(strides)) + strides = strides or const([1], dtype="int32") + if isinstance(begin, (tuple, list)): + begin = const(list(begin)) + if isinstance(end, (tuple, list)): + end = const(list(end)) + if isinstance(strides, (tuple, list)): + strides = const(list(strides)) + return _make.strided_slice(data, begin, end, strides, slice_mode) def strided_set(data, v, begin, end, strides=None): @@ -649,13 +662,13 @@ def strided_set(data, v, begin, end, strides=None): v : relay.Expr The data to be set. - begin: relay.Expr + begin: relay.Expr, Tuple[int], or List[int] The indices to begin with in the slicing. - end: relay.Expr + end: relay.Expr, Tuple[int], or List[int] Indices indicating end of the slice. - strides: relay.Expr, optional + strides: relay.Expr, Tuple[int], or List[int], optional Specifies the stride values, it can be negative in that case, the input tensor will be reversed in that particular axis. @@ -665,6 +678,12 @@ def strided_set(data, v, begin, end, strides=None): The computed result. """ strides = strides or const([1], dtype="int32") + if isinstance(begin, (tuple, list)): + begin = const(list(begin)) + if isinstance(end, (tuple, list)): + end = const(list(end)) + if isinstance(strides, (tuple, list)): + strides = const(list(strides)) return _make.strided_set(data, v, begin, end, strides) diff --git a/python/tvm/relay/op/vision/_vision.py b/python/tvm/relay/op/vision/_vision.py index 6e2008a..f6c4f81 100644 --- a/python/tvm/relay/op/vision/_vision.py +++ b/python/tvm/relay/op/vision/_vision.py @@ -18,6 +18,8 @@ """Definition of vision ops""" from __future__ import absolute_import +import topi +from tvm.te.hybrid import script from .. import op as reg from .. 
import strategy from ..op import OpPattern @@ -40,3 +42,38 @@ reg.register_pattern("vision.get_valid_counts", OpPattern.OPAQUE) # non-maximum suppression reg.register_strategy("vision.non_max_suppression", strategy.nms_strategy) reg.register_pattern("vision.non_max_suppression", OpPattern.OPAQUE) + +@script +def _get_valid_counts_shape_func(data_shape): + valid_counts_shape = output_tensor((1,), "int64") + out_tensor_shape = output_tensor((data_shape.shape[0],), "int64") + out_indices_shape = output_tensor((2,), "int64") + + valid_counts_shape[0] = data_shape[0] + for i in const_range(data_shape.shape[0]): + out_tensor_shape[i] = data_shape[i] + out_indices_shape[0] = data_shape[0] + out_indices_shape[1] = data_shape[1] + + return valid_counts_shape, out_tensor_shape, out_indices_shape + +@reg.register_shape_func("vision.get_valid_counts", False) +def get_valid_counts_shape_func(attrs, inputs, _): + return _get_valid_counts_shape_func(inputs[0]) + +@script +def _nms_shape_func(data_shape): + out_shape = output_tensor((2,), "int64") + count_shape = output_tensor((2,), "int64") + + out_shape[0] = data_shape[0] + out_shape[1] = data_shape[1] + count_shape[0] = data_shape[0] + count_shape[1] = int64(1) + return out_shape, count_shape + +@reg.register_shape_func("vision.non_max_suppression", False) +def nms_shape_func(attrs, inputs, _): + if attrs.return_indices: + return _nms_shape_func(inputs[0]) + return [topi.math.identity(inputs[0])] diff --git a/python/tvm/relay/op/vision/nms.py b/python/tvm/relay/op/vision/nms.py index 70a9ec9..b60b49a 100644 --- a/python/tvm/relay/op/vision/nms.py +++ b/python/tvm/relay/op/vision/nms.py @@ -47,14 +47,18 @@ def get_valid_counts(data, out_tensor : relay.Expr Rearranged data tensor. + + out_indices: relay.Expr + Indices in input data """ return expr.TupleWrapper( _make.get_valid_counts(data, score_threshold, - id_index, score_index), 2) + id_index, score_index), 3) def non_max_suppression(data, valid_count, + indices, max_output_size=-1, iou_threshold=0.5, force_suppress=False, @@ -69,12 +73,23 @@ def non_max_suppression(data, Parameters ---------- data : relay.Expr - 3-D tensor with shape [batch_size, num_anchors, 6]. + 3-D tensor with shape [batch_size, num_anchors, 6] + or [batch_size, num_anchors, 5]. The last dimension should be in format of - [class_id, score, box_left, box_top, box_right, box_bottom]. + [class_id, score, box_left, box_top, box_right, box_bottom] + or [score, box_left, box_top, box_right, box_bottom]. It could + be the second output out_tensor of get_valid_counts. valid_count : relay.Expr - 1-D tensor for valid number of boxes. + 1-D tensor for valid number of boxes. It could be the output + valid_count of get_valid_counts. + + indices: relay.Expr + 2-D tensor with shape [batch_size, num_anchors], represents + the index of box in original data. It could be the third + output out_indices of get_valid_counts. The values in the + second dimension are like the output of arange(num_anchors) + if get_valid_counts is not used before non_max_suppression. max_output_size : int, optional Max number of output valid boxes for each instance. @@ -106,10 +121,24 @@ def non_max_suppression(data, Returns ------- - out : relay.Expr - 3-D tensor with shape [batch_size, num_anchors, 6]. + out : relay.Expr or relay.Tuple + return relay.Expr if return_indices is disabled, a 3-D tensor + with shape [batch_size, num_anchors, 6] or [batch_size, num_anchors, 5]. 
+ if return_indices is True, return relay.Tuple of two 2-D tensors, with + shape [batch_size, num_anchors] and [batch_size, num_valid_anchors] respectively. """ - return _make.non_max_suppression(data, valid_count, max_output_size, - iou_threshold, force_suppress, top_k, - coord_start, score_index, id_index, - return_indices, invalid_to_bottom) + out = _make.non_max_suppression(data, + valid_count, + indices, + max_output_size, + iou_threshold, + force_suppress, + top_k, + coord_start, + score_index, + id_index, + return_indices, + invalid_to_bottom) + if return_indices: + return expr.TupleWrapper(out, 2) + return out diff --git a/src/relay/analysis/util.cc b/src/relay/analysis/util.cc index 346adf9..0885a35 100644 --- a/src/relay/analysis/util.cc +++ b/src/relay/analysis/util.cc @@ -458,6 +458,13 @@ bool IsDataDependant(const CallNode* call) { return false; } } + } else if (op->name == "strided_slice") { + if (const auto* attrs = call->attrs.as()) { + if (attrs->begin && attrs->end && attrs->strides) { + // not data dependant if begin, end and strides exist + return false; + } + } } return tshape_data_dependant[op]; diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 0275a89..136ae00 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -1096,7 +1096,8 @@ bool ArangeRel(const Array& types, int num_inputs, const Attrs& raw_attrs, inline te::Tensor DynamicArange(const te::Tensor& start, const te::Tensor& stop, const te::Tensor& step, tvm::DataType dtype, - std::string name = "tensor", std::string tag = topi::kInjective) { + std::string name = "T_arange_dynamic", + std::string tag = topi::kInjective) { tvm::PrimExpr num_elem = tvm::tir::Var("num_elem"); return te::compute( {num_elem}, @@ -1109,6 +1110,7 @@ inline te::Tensor DynamicArange(const te::Tensor& start, const te::Tensor& stop, Array ArangeCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { const ArangeAttrs* param = attrs.as(); + CHECK(param != nullptr); te::Tensor start = inputs[0]; te::Tensor stop = inputs[1]; te::Tensor step = inputs[2]; @@ -1670,93 +1672,109 @@ Array GetIntArray(Array arr) { // strided_slice TVM_REGISTER_NODE_TYPE(StridedSliceAttrs); + bool StridedSliceRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { - CHECK_EQ(types.size(), 2); - const auto* data = types[0].as(); - if (data == nullptr) return false; - + CHECK_EQ(types.size(), 5); const StridedSliceAttrs* param = attrs.as(); CHECK(param != nullptr); - + const auto* data = types[0].as(); + CHECK(data != nullptr); auto dshape = data->shape; - auto num_axis = dshape.size(); - - std::vector stride_vec; - for (Integer i : param->strides) { - CHECK(i.defined()); - stride_vec.push_back(i->value); - } - for (size_t i = stride_vec.size(); i < num_axis; ++i) { - stride_vec.push_back(1); - } - const int64_t max_range = std::numeric_limits::max(); - - std::vector begin_vec; - for (size_t i = 0; i < param->begin.size(); ++i) { - if (!param->begin[i].defined()) { - // value=None + int64_t num_axis = dshape.size(); + + // calculate output shape + std::vector oshape(num_axis); + if (param->begin && param->end && param->strides) { + // stride will be set as 1 if slice mode is enabled + std::vector stride_vec(num_axis, 1); + if (param->slice_mode == "end") { + for (size_t i = 0; i < param->strides.value().size(); ++i) { + CHECK(param->strides.value()[i].defined()); + stride_vec[i] = param->strides.value()[i]->value; + } + } + const int64_t max_range 
= std::numeric_limits::max(); + std::vector begin_vec; + for (size_t i = 0; i < param->begin.value().size(); ++i) { + if (!param->begin.value()[i].defined()) { + begin_vec.push_back(stride_vec[i] > 0 ? 0 : max_range); + } else { + begin_vec.push_back(param->begin.value()[i]->value); + } + } + for (int64_t i = begin_vec.size(); i < num_axis; ++i) { begin_vec.push_back(stride_vec[i] > 0 ? 0 : max_range); - } else { - begin_vec.push_back(param->begin[i]->value); } - } - for (size_t i = begin_vec.size(); i < num_axis; ++i) { - begin_vec.push_back(stride_vec[i] > 0 ? 0 : max_range); - } - std::vector end_vec; - for (size_t i = 0; i < param->end.size(); ++i) { - // allow end to be None - if (!param->end[i].defined()) { + std::vector end_vec; + for (size_t i = 0; i < param->end.value().size(); ++i) { + // allow end to be None + if (!param->end.value()[i].defined()) { + end_vec.push_back(stride_vec[i] < 0 ? 0 : max_range); + } else if (param->slice_mode == "size") { + if (param->end.value()[i]->value < 0) { + end_vec.push_back(max_range); + } else { + end_vec.push_back(begin_vec[i] + param->end.value()[i]->value); + } + } else if (param->slice_mode == "end") { + end_vec.push_back(param->end.value()[i]->value); + } else { + LOG(FATAL) << "Unsupported slice mode: " << param->slice_mode; + } + } + for (int64_t i = end_vec.size(); i < num_axis; ++i) { end_vec.push_back(stride_vec[i] < 0 ? 0 : max_range); - } else { - end_vec.push_back(param->end[i]->value); } - } - for (size_t i = end_vec.size(); i < num_axis; ++i) { - end_vec.push_back(stride_vec[i] < 0 ? 0 : max_range); - } - std::vector oshape(dshape.size()); - for (size_t i = 0; i < num_axis; ++i) { - int64_t stride_v = stride_vec[i]; - int64_t begin_v = begin_vec[i]; - int64_t end_v = end_vec[i]; + for (int64_t i = 0; i < num_axis; ++i) { + int64_t stride_v = stride_vec[i]; + int64_t begin_v = begin_vec[i]; + int64_t end_v = end_vec[i]; - if ((stride_v == 1 && begin_v == 0 && end_v == max_range) || - (stride_v == -1 && begin_v == max_range && end_v == 0)) { - // Quick path, do not slice this dimension. - oshape[i] = dshape[i]; - continue; + if ((stride_v == 1 && begin_v == 0 && end_v == max_range) || + (stride_v == -1 && begin_v == max_range && end_v == 0)) { + // Quick path, do not slice this dimension. + oshape[i] = dshape[i]; + continue; + } + // Normal path, require the shape to be concrete integer. + // Require concrete integer as symbolic inference of min/max + // can get complicated and not very helpful. + const int64_t* p_dim_size = tir::as_const_int(dshape[i]); + if (!p_dim_size) { + oshape[i] = dshape[i]; + continue; + } + int64_t dim_size = p_dim_size[0]; + begin_v = (begin_v < 0) ? dim_size + begin_v : begin_v; + end_v = (end_v < 0) ? dim_size + end_v : end_v; + + int64_t slice_range, step; + if (stride_v < 0) { + if (end_v < -1) end_v = -1; + CHECK_LE(end_v, begin_v) << "strided_slice get empty slice at axis " << i; + begin_v = std::min(dim_size - 1, begin_v); + slice_range = begin_v - end_v; + step = -stride_v; + } else { + if (begin_v < 0) begin_v = 0; + CHECK_GE(stride_v, 0); + CHECK_LE(begin_v, end_v) << "strided_slice get invalid slice at axis " << i; + end_v = std::min(dim_size, end_v); + slice_range = end_v - begin_v; + step = stride_v; + } + oshape[i] = tir::make_const(dshape[i].dtype(), (slice_range + step - 1) / step); } - // Normal path, require the shape to be concrete integer. - // Require concrete integer as symbolic inference of min/max - // can get complicated and not very helpful. 
- const int64_t* p_dim_size = tir::as_const_int(dshape[i]); - CHECK(p_dim_size) << "strided_slice requires sliced dimension to be concrete int"; - int64_t dim_size = p_dim_size[0]; - begin_v = (begin_v < 0) ? dim_size + begin_v : begin_v; - end_v = (end_v < 0) ? dim_size + end_v : end_v; - - int64_t slice_range, step; - if (stride_v < 0) { - if (end_v < -1) end_v = -1; - CHECK_LT(end_v, begin_v) << "strided_slice get empty slice at axis " << i; - begin_v = std::min(dim_size - 1, begin_v); - slice_range = begin_v - end_v; - step = -stride_v; - } else { - if (begin_v < 0) begin_v = 0; - CHECK_GE(stride_v, 0); - CHECK_LT(begin_v, end_v) << "strided_slice get empty slice at axis " << i; - end_v = std::min(dim_size, end_v); - slice_range = end_v - begin_v; - step = stride_v; + } else { + for (int64_t i = 0; i < num_axis; ++i) { + oshape[i] = Any::make(); } - oshape[i] = tir::make_const(dshape[i].dtype(), (slice_range + step - 1) / step); } - reporter->Assign(types[1], TensorType(oshape, data->dtype)); + + reporter->Assign(types[4], TensorType(oshape, data->dtype)); return true; } @@ -1771,22 +1789,39 @@ Array> StridedSliceInferCorrectLayout(const Attrs& attrs, } CHECK(old_in_layouts.defined()); - CHECK_EQ(old_in_layouts.size(), 1); + CHECK_GE(old_in_layouts.size(), 1); CHECK(old_in_shapes.defined()); - CHECK_EQ(old_in_shapes.size(), 1); + CHECK_GE(old_in_shapes.size(), 1); auto layout = old_in_layouts[0]; if (layout.defined() && new_in_layouts.defined()) { - CHECK_EQ(new_in_layouts.size(), 1); + CHECK_GE(new_in_layouts.size(), 1); auto new_layout = new_in_layouts[0]; auto shape = old_in_shapes[0]; // NOTE: Discard "const" qualifier here. auto* params = const_cast(attrs.as()); + CHECK(params != nullptr); + Array begin, end, strides; + if (params->begin && params->end && params->strides) { + for (Integer i : params->strides.value()) { + CHECK(i.defined()); + strides.push_back(params->slice_mode == "size" ? 1 : i->value); + } + + for (Integer i : params->begin.value()) { + CHECK(i.defined()); + begin.push_back(i->value); + } + for (Integer i : params->end.value()) { + CHECK(i.defined()); + end.push_back(i->value); + } + } Array new_begin, new_end; - for (size_t i = 0; i < params->begin.size(); i++) { + for (size_t i = 0; i < begin.size(); i++) { const LayoutAxis& axis = layout[i]; if (!axis.IsPrimal()) { // original layout that contains splitted axes is not supported @@ -1794,50 +1829,115 @@ Array> StridedSliceInferCorrectLayout(const Attrs& attrs, } auto factor = new_layout.FactorOf(axis); if (factor == -1) { - new_begin.push_back(params->begin[i]); - new_end.push_back(params->end[i]); + new_begin.push_back(begin[i]); + new_end.push_back(end[i]); } else { - if (params->strides.defined() && i < params->strides.size()) { - auto stride = params->strides[i]; + if (strides.defined() && i < strides.size()) { + auto stride = strides[i]; // arbitrary stride is not supported if (stride.defined() && stride->value != 1) { return {{Layout::Undef()}, {Layout::Undef()}}; } } - int64_t begin = params->begin[i].defined() ? params->begin[i]->value : 0; - int64_t end = - params->end[i].defined() ? params->end[i]->value : shape[i].as()->value; - if (begin % factor || end % factor) { + int64_t bg = begin[i].defined() ? 
begin[i]->value : 0; + int64_t ed; + if (!end[i].defined()) { + ed = shape[i].as()->value; + } else if (params->slice_mode == "size") { + if (end[i]->value < 0) { + ed = shape[i].as()->value; + } else { + ed = bg + end[i]->value; + } + } else { + ed = end[i]->value; + } + + if (bg % factor || ed % factor) { // transform to original layout return {{Layout::Undef()}, {Layout::Undef()}}; } - new_begin.push_back(tvm::Integer(begin / factor)); - new_end.push_back(tvm::Integer(end / factor)); + new_begin.push_back(tvm::Integer(bg / factor)); + new_end.push_back(tvm::Integer(ed / factor)); } } + layout = new_layout; params->begin = new_begin; params->end = new_end; } - return {{layout}, {layout}}; + return {{layout, Layout("C"), Layout("C"), Layout("C")}, {layout}}; } -// Positional relay function to create StridedSlice operator used by frontend FFI. -Expr MakeStridedSlice(Expr data, Array begin, Array end, Array strides) { - auto attrs = make_object(); - attrs->begin = std::move(begin); - attrs->end = std::move(end); - attrs->strides = std::move(strides); - static const Op& op = Op::Get("strided_slice"); - return Call(op, {data}, Attrs(attrs), {}); +inline te::Tensor DynamicStridedSlice(const te::Tensor& input, const te::Tensor& begin, + const te::Tensor& end, const te::Tensor& strides, + std::string name = "T_strided_slice_dynamic", + std::string tag = topi::kInjective) { + int64_t src_tensor_dim = input->shape.size(); + Array out_shape; + for (int64_t i = 0; i < src_tensor_dim; ++i) { + out_shape.push_back(tvm::tir::Var("dim")); + } + // TODO(yongwww): move the compute into topi + return te::compute( + out_shape, + [&](const Array& indices) { + Array real_indices; + for (int32_t i = 0; i < src_tensor_dim; ++i) { + real_indices.push_back(indices[i] * strides(i) + begin(i)); + } + return input(real_indices); + }, + name, tag); } Array StridedSliceCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { const StridedSliceAttrs* param = attrs.as(); CHECK(param != nullptr); - return Array{ - topi::strided_slice(inputs[0], param->begin, param->end, param->strides)}; + if (param->begin && param->end && param->strides) { + Array begin, end, strides; + begin = param->begin.value(); + end = param->end.value(); + strides = param->strides.value(); + return Array{ + topi::strided_slice(inputs[0], begin, end, strides, param->slice_mode)}; + } else { + te::Tensor data = inputs[0]; + te::Tensor begin = inputs[1]; + te::Tensor end = inputs[2]; + te::Tensor strides = inputs[3]; + // Dynamic computation + int64_t attr_size = data->shape.size(); + CHECK(begin->shape[0].as()->value == attr_size && + end->shape[0].as()->value == attr_size && + strides->shape[0].as()->value == attr_size) + << "begin, end, and strides are required to have the same length" + << " if they are non-constant."; + return Array{DynamicStridedSlice(data, begin, end, strides)}; + } +} + +// Positional relay function to create StridedSlice operator used by frontend FFI. 
+Expr MakeStridedSlice(Expr data, Expr begin, Expr end, Expr strides, String slice_mode) { + auto attrs = make_object(); + const ConstantNode *cbegin, *cend, *cstrides; + if ((cbegin = begin.as()) && (cend = end.as()) && + (cstrides = strides.as())) { + CHECK_EQ(cbegin->data->ndim, 1); + CHECK_EQ(cend->data->ndim, 1); + CHECK_EQ(cstrides->data->ndim, 1); + Array begin, end, strides; + begin = ToVector(cbegin->data); + end = ToVector(cend->data); + strides = ToVector(cstrides->data); + attrs->begin = begin; + attrs->end = end; + attrs->strides = strides; + } + attrs->slice_mode = slice_mode; + static const Op& op = Op::Get("strided_slice"); + return Call(op, {data, begin, end, strides}, Attrs(attrs), {}); } TVM_REGISTER_GLOBAL("relay.op._make.strided_slice").set_body_typed(MakeStridedSlice); @@ -1866,13 +1966,18 @@ Examples:: [[ 5., 6.], [ 7., 8.]]] )code" TVM_ADD_FILELINE) - .set_num_inputs(1) + .set_num_inputs(4) .add_argument("data", "Tensor", "The input tensor.") + .add_argument("begin", "Tensor", "The indices to begin with in the slicing.") + .add_argument("end", "Tensor", "Indices indicating end of the slice.") + .add_argument("strides", "Tensor", "The stride values.") + .add_argument("slice_mode", "Tensor", "The slice mode.") .set_support_level(4) .set_attrs_type() .add_type_rel("StridedSlice", StridedSliceRel) .set_attr("FTVMCompute", StridedSliceCompute) .set_attr("TOpPattern", kInjective) + .set_attr("AnyCodegenStrategy", kVariableDimensions) .set_attr("FInferCorrectLayout", StridedSliceInferCorrectLayout); // strided_set @@ -2126,7 +2231,7 @@ Array SliceLikeCompute(const Attrs& attrs, const Array& } } return Array{topi::strided_slice(inputs[0], GetIntArray(begin_idx), - GetIntArray(end_idx), GetIntArray(strides))}; + GetIntArray(end_idx), GetIntArray(strides), "end")}; } TVM_REGISTER_GLOBAL("relay.op._make.slice_like").set_body_typed(MakeSliceLike); diff --git a/src/relay/op/vision/nms.cc b/src/relay/op/vision/nms.cc index b1aaaf0..7486db7 100644 --- a/src/relay/op/vision/nms.cc +++ b/src/relay/op/vision/nms.cc @@ -37,9 +37,11 @@ bool GetValidCountRel(const Array& types, int num_inputs, const Attrs& att CHECK_EQ(dshape.size(), 3) << "Input data should be 3-D."; std::vector oshape({data->shape[0]}); + std::vector oshape_indices({data->shape[0], data->shape[1]}); std::vector fields; fields.push_back(TensorType(oshape, DataType::Int(32))); fields.push_back(TensorType(data->shape, data->dtype)); + fields.push_back(TensorType(oshape_indices, DataType::Int(32))); // assign output type reporter->Assign(types[1], TupleType(Array(fields))); @@ -71,7 +73,7 @@ TVM_REGISTER_NODE_TYPE(NonMaximumSuppressionAttrs); bool NMSRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { - CHECK_EQ(types.size(), 3); + CHECK_EQ(types.size(), 4); const auto* data = types[0].as(); const auto* valid_count = types[1].as(); const NonMaximumSuppressionAttrs* param = attrs.as(); @@ -82,15 +84,20 @@ bool NMSRel(const Array& types, int num_inputs, const Attrs& attrs, // assign output type if (param->return_indices) { + std::vector fields; + // dynamic happens for return_indices in TensorFlow & ONNX std::vector oshape({dshape[0], dshape[1]}); - reporter->Assign(types[2], TensorType(oshape, DataType::Int(32))); + fields.push_back(TensorType(oshape, DataType::Int(32))); + std::vector countshape({dshape[0], 1}); + fields.push_back(TensorType(countshape, DataType::Int(32))); + reporter->Assign(types[3], TupleType(Array(fields))); } else { - reporter->Assign(types[2], 
TensorType(dshape, data->dtype)); + reporter->Assign(types[3], TensorType(dshape, data->dtype)); } return true; } -Expr MakeNMS(Expr data, Expr valid_count, int max_output_size, double iou_threshold, +Expr MakeNMS(Expr data, Expr valid_count, Expr indices, int max_output_size, double iou_threshold, bool force_suppress, int top_k, int coord_start, int score_index, int id_index, bool return_indices, bool invalid_to_bottom) { auto attrs = make_object(); @@ -104,19 +111,21 @@ Expr MakeNMS(Expr data, Expr valid_count, int max_output_size, double iou_thresh attrs->return_indices = return_indices; attrs->invalid_to_bottom = invalid_to_bottom; static const Op& op = Op::Get("vision.non_max_suppression"); - return Call(op, {data, valid_count}, Attrs(attrs), {}); + return Call(op, {data, valid_count, indices}, Attrs(attrs), {}); } TVM_REGISTER_GLOBAL("relay.op.vision._make.non_max_suppression").set_body_typed(MakeNMS); RELAY_REGISTER_OP("vision.non_max_suppression") .describe(R"doc(Non-maximum suppression. The input boxes should -be in the format of [class_id, score, left, top, right, bottom]. -Set id_index to be -1 to ignore class_id axis. +be in the format of [class_id, score, left, top, right, bottom] +or [score, left, top, right, bottom]. Set id_index to be -1 to +ignore class_id axis. )doc" TVM_ADD_FILELINE) - .set_num_inputs(2) + .set_num_inputs(3) .add_argument("data", "Tensor", "Input data.") .add_argument("valid_count", "Tensor", "Number of valid anchor boxes.") + .add_argument("indices", "Tensor", "Corresponding indices in original input tensor.") .set_support_level(5) .add_type_rel("NMS", NMSRel); diff --git a/src/relay/transforms/combine_parallel_conv2d.cc b/src/relay/transforms/combine_parallel_conv2d.cc index 1990414..0bf9e7f 100644 --- a/src/relay/transforms/combine_parallel_conv2d.cc +++ b/src/relay/transforms/combine_parallel_conv2d.cc @@ -164,19 +164,28 @@ class ParallelConv2DCombiner : public ParallelOpCombiner { void UpdateGroupOutput(const Expr& data, const Group& branches, size_t depth, ExprSubstMap* subst_map) { int64_t index = 0; + for (const auto& branch : branches) { const CallNode* conv2d = branch[0]; int64_t channels = GetConv2DSuperChannelsDim(conv2d); - Array begin; - Array end; + std::vector begin; + std::vector end; for (size_t i = 0; i < channel_pos_; i++) { begin.push_back(0); - end.push_back(NullValue()); + end.push_back(-1); } begin.push_back(index); index += channels; end.push_back(index); - auto slice = MakeStridedSlice(data, std::move(begin), std::move(end), Array{}); + std::vector strides(begin.size(), 1); + for (size_t i = 0; i < begin.size(); ++i) { + end[i] -= begin[i]; + } + std::vector ndarray_shape = {static_cast(begin.size())}; + Constant begin_const = MakeConstantTensor(DataType::Int(64), ndarray_shape, begin); + Constant end_const = MakeConstantTensor(DataType::Int(64), ndarray_shape, end); + Constant strides_const = MakeConstantTensor(DataType::Int(64), ndarray_shape, strides); + auto slice = MakeStridedSlice(data, begin_const, end_const, strides_const, "size"); subst_map->insert({GetRef(branch[depth]), slice}); } } diff --git a/src/relay/transforms/pattern_util.h b/src/relay/transforms/pattern_util.h index 06b1e82..7518eb9 100644 --- a/src/relay/transforms/pattern_util.h +++ b/src/relay/transforms/pattern_util.h @@ -673,7 +673,7 @@ Expr MakeConcatenate(Expr data, int axis); Expr MakeRepeat(Expr data, int repeats, int axis); -Expr MakeStridedSlice(Expr data, Array begin, Array end, Array strides); +Expr MakeStridedSlice(Expr data, Expr begin, Expr 
end, Expr strides, String slice_mode); Expr MakeStack(Expr data, int axis); diff --git a/tests/python/frontend/tensorflow/test_debugging.py b/tests/python/frontend/tensorflow/test_debugging.py index 01ad6a2..a6df6ff 100644 --- a/tests/python/frontend/tensorflow/test_debugging.py +++ b/tests/python/frontend/tensorflow/test_debugging.py @@ -51,7 +51,8 @@ def test_assert_true(): # do that, it's happening in Relay, and that optimization shouldn't # affect the arity of the main function. We should have to pass in # x_value here. - np.testing.assert_allclose(0, run_relay(g, {'input':shape}).asnumpy()) + np.testing.assert_allclose(0, run_relay(g, {'input': shape}).asnumpy()) + def test_assert_true_var_capture(): g = tf.Graph() @@ -71,7 +72,7 @@ def test_assert_true_var_capture(): # the graph as a boolean, which is not correct - as you can see above, # TF believes that the value of this graph is None. np.testing.assert_allclose(True, - run_relay(g, None, x_value).asnumpy()) + run_relay(g, None, x_value).asnumpy()) def test_assert_false(): g = tf.Graph() @@ -91,9 +92,7 @@ def test_assert_false(): # argument is false. np.testing.assert_allclose(0, run_relay(g).asnumpy()) - if __name__ == "__main__": test_assert_true() test_assert_true_var_capture() test_assert_false() - diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index f9fc5dd..07c1cd3 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -474,7 +474,7 @@ def test_forward_convolution(): ####################################################################### # Convolution3D -# ----------- +# ------------- def _test_convolution3d(opname, tensor_in_sizes, filter_in_sizes, @@ -1892,6 +1892,30 @@ def test_forward_crop_and_resize(): ####################################################################### +# Non Max Suppression +# ------------------- +def _test_forward_nms_v3(bx_shape, score_shape, iou_threshold, score_threshold, out_size, dtype="float32"): + boxes = np.random.uniform(0, 10, size=bx_shape).astype(dtype) + scores = np.random.uniform(size=score_shape).astype(dtype) + tf.reset_default_graph() + in_data_1 = tf.placeholder(dtype, boxes.shape, name="in_data_1") + in_data_2 = tf.placeholder(dtype, scores.shape, name="in_data_2") + tf.image.non_max_suppression(boxes=in_data_1, scores=in_data_2, + max_output_size=out_size, iou_threshold=iou_threshold, + score_threshold=score_threshold, name="nms") + compare_tf_with_tvm([boxes, scores], ['in_data_1:0', 'in_data_2:0'], + 'nms/NonMaxSuppressionV3:0', mode='vm') + compare_tf_with_tvm([boxes, scores], ['in_data_1:0', 'in_data_2:0'], + 'nms/NonMaxSuppressionV3:0', mode='debug') + +def test_forward_nms_v3(): + """ NonMaxSuppressionV3 """ + _test_forward_nms_v3((5, 4), (5,), 0.7, 0.5, 5) + _test_forward_nms_v3((20, 4), (20,), 0.5, 0.6, 10) + _test_forward_nms_v3((1000, 4), (1000,), 0.3, 0.7, 1000) + + +####################################################################### # LSTM # ---- @@ -3568,6 +3592,7 @@ if __name__ == '__main__': test_forward_truncatemod() test_forward_one_hot() test_forward_atan2() + test_forward_nms_v3() # Activations test_forward_sigmoid() @@ -3625,6 +3650,7 @@ if __name__ == '__main__': # NN test_forward_convolution() + test_forward_convolution3d() test_forward_pooling() test_forward_concat_v2() test_forward_lrn() diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py index 504c20a..8e535a6 100644 --- 
a/tests/python/relay/test_any.py +++ b/tests/python/relay/test_any.py @@ -22,6 +22,7 @@ from tvm import te from tvm import relay from tvm.relay.loops import while_loop from tvm.relay.testing import run_infer_type as infer_type +import topi.testing def int32(val): return relay.const(val, 'int32') @@ -642,6 +643,52 @@ def test_arange_with_dynamic_shape(): result = ex.evaluate()(data) tvm.testing.assert_allclose(result.asnumpy(), np.array(range(10)).astype("int32")+1) +def verify_any_strided_slice(data_shape, begin_shape, end_shape, strides_shape, + data_np_shape, slice_mode="end", const_attrs=False): + # Generate random numpy input data + np_data = np.random.uniform(size=data_np_shape).astype('float32') + np_begin = np.random.randint(2, size=begin_shape, dtype="int32") + np_end = np.random.randint(5, 10, size=end_shape, dtype="int32") + np_strides = np.random.randint(1, 2 if slice_mode == "size" else 3, size=strides_shape, dtype="int32") + # target numpy result + ref_res = topi.testing.strided_slice_python(np_data, np_begin, np_end, np_strides, slice_mode) + + # Relay Module + mod = tvm.IRModule() + data = relay.var('data', shape=data_shape, dtype='float32') + if const_attrs: + data = relay.var('data', shape=data_np_shape, dtype='float32') + begin = relay.const(np_begin) + end = relay.const(np_end) + strides = relay.const(np_strides) + args = [data] + np_inputs = [np_data] + else: + begin = relay.var('begin', shape=begin_shape, dtype="int32") + end = relay.var('end', shape=end_shape, dtype="int32") + strides = relay.var('strides', shape=strides_shape, dtype="int32") + args = [data, begin, end, strides] + np_inputs = [np_data, np_begin, np_end, np_strides] + + y = relay.strided_slice(data, begin=begin, end=end, + strides=strides, slice_mode=slice_mode) + mod["main"] = relay.Function(args, y) + + for kind in ["debug", "vm"]: + ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") + result = ex.evaluate()(*np_inputs) + tvm.testing.assert_allclose(result.asnumpy(), ref_res) + + +def test_any_strided_slice(): + verify_any_strided_slice(any_dims(2), (2,), (2,), (2,), (15, 21)) + verify_any_strided_slice(any_dims(3), (3,), (3,), (3,), (15, 17, 21)) + verify_any_strided_slice(any_dims(3), (3,), (3,), (3,), (23, 29, 41)) + verify_any_strided_slice(any_dims(4), (4,), (4,), (4,), (40, 50, 60, 70)) + verify_any_strided_slice(any_dims(3), (3,), (3,), (3,), (15, 17, 21), slice_mode="size") + verify_any_strided_slice(any_dims(2), (2,), (2,), (2,), (15, 21), const_attrs=True) + + def test_recursive_concat(): """ fn @concat_loop(%i: int32, %st: (any, 1)) -> (any, 1) { @@ -767,7 +814,7 @@ def test_mixed_input_type(): ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") result = ex.evaluate()([[data_np0, data_np0], data_np0], data_np1) assert result.asnumpy().shape == ref_out_shape, \ - "Shape mismatch: expect %s but got %s." % (str(ref_out_shape), str(ret.asnumpy().shape)) + "Shape mismatch: expect %s but got %s." 
% (str(ref_out_shape), str(result.asnumpy().shape)) if __name__ == "__main__": test_any_full() @@ -796,7 +843,9 @@ if __name__ == "__main__": test_any_softmax() test_any_topk() test_fused_ops() + test_any_argwhere() test_arange_with_dynamic_shape() + test_any_strided_slice() test_recursive_concat() test_recursive_concat_with_wrong_annotation() test_tuple_get_item() diff --git a/tests/python/relay/test_op_level2.py b/tests/python/relay/test_op_level2.py index 3e8720d..2b5e67c 100644 --- a/tests/python/relay/test_op_level2.py +++ b/tests/python/relay/test_op_level2.py @@ -780,7 +780,7 @@ def _test_pool2d_int(opfunc, reffunc, dtype): x = relay.var("x", shape=dshape, dtype=dtype) y = opfunc(x, pool_size=(2, 2), strides=(2, 2), padding=(0, 0)) func = relay.Function([x], y) - data = np.random.random_integers(low=-128, high=128, size=dshape) + data = np.random.randint(low=-128, high=128, size=dshape) ref_res = reffunc(data.reshape(1,3,14,2,14,2), axis=(3,5)).astype(dtype) for target, ctx in ctx_list(): intrp1 = relay.create_executor("graph", ctx=ctx, target=target) diff --git a/tests/python/relay/test_op_level4.py b/tests/python/relay/test_op_level4.py index 947a4bf..74231cb 100644 --- a/tests/python/relay/test_op_level4.py +++ b/tests/python/relay/test_op_level4.py @@ -58,11 +58,11 @@ def test_binary_op(): def test_cmp_type(): for op, ref in ((relay.greater, np.greater), - (relay.greater_equal, np.greater_equal), - (relay.less, np.less), - (relay.less_equal, np.less_equal), - (relay.equal, np.equal), - (relay.not_equal, np.not_equal)): + (relay.greater_equal, np.greater_equal), + (relay.less, np.less), + (relay.less_equal, np.less_equal), + (relay.equal, np.equal), + (relay.not_equal, np.not_equal)): x = relay.var("x", relay.TensorType((10, 4), "float32")) y = relay.var("y", relay.TensorType((5, 10, 1), "float32")) z = op(x, y) @@ -296,38 +296,68 @@ def test_mean_var_std(): def test_strided_slice(): - def verify(dshape, begin, end, strides, output, test_ref=True): + def verify(dshape, begin, end, strides, output, slice_mode="end", + attr_const=True, test_ref=True, dtype="int32"): x = relay.var("x", relay.TensorType(dshape, "float32")) - z = relay.strided_slice(x, begin=begin, end=end, strides=strides) + ndim = len(dshape) + begin = begin if begin else [0] * ndim + end = end if end else list(dshape) + + # target numpy result + x_data = np.random.uniform(size=dshape).astype("float32") + ref_res = topi.testing.strided_slice_python( + x_data, begin, end, strides, slice_mode) + + if attr_const: + begin = relay.const(begin, dtype=dtype) + end = relay.const(end, dtype=dtype) + + if strides: + if attr_const: + strides = relay.const(strides, dtype=dtype) + z = relay.strided_slice(x, + begin=begin, + end=end, + strides=strides, + slice_mode=slice_mode) + else: + z = relay.strided_slice(x, + begin=begin, + end=end, + slice_mode=slice_mode) func = relay.Function([x], z) + func = run_infer_type(func) text = func.astext() assert "begin=" in text assert "end=" in text + if output: assert func.body.checked_type == relay.ty.TensorType(output, "float32") + if not test_ref: return - x_data = np.random.uniform(size=dshape).astype("float32") - ref_res = topi.testing.strided_slice_python( - x_data, begin, end, strides) for target, ctx in ctx_list(): intrp = relay.create_executor("graph", ctx=ctx, target=target) op_res = intrp.evaluate(func)(x_data) tvm.testing.assert_allclose(op_res.asnumpy(), ref_res) - d1, d2, d3, d4 = te.var("d1"), te.var("d2"), te.var("d3"), te.var("d4") - verify((d1, d2, 3), [None, None, 1], 
[None, None, 2], None, (d1, d2, 1), False) + verify((1, 3, 10, 10), [0, 0, 0, 0], [-1, 3, 10, 10], [1], (0, 3, 10, 10), dtype="int64") + verify((1, 224, 224, 3), [0, 20, 20, 0], [1, 140, 140, 3], + [1, 1, 1, 1], (1, 120, 120, 3), dtype="int64") + verify((3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1], (1, 3, 3), dtype="int16") verify((3, 4, 3), [0, 0, 0], [4, -5, 4], [1, -1, 2], (3, 1, 2)) - verify((3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1], (1, 3, 3)) - verify((3, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1], (1, 4, 3)) - verify((3, 4, 3), [1, 0, 0], [2, 2, 3], [1, 1, 2], (1, 2, 2)) - verify((3, 4, 3), [1, -1, 0], [2, -3, 3], [1, -1, 1], (1, 2, 3)) + verify((3, 4, 3), [0, 0, 0], [4, -5, 4], [1, -1, 2], (3, 1, 2), attr_const=False) verify((3, 4, 3), [1, 1, 0], [4, 4, 3], None, (2, 3, 3)) verify((3, 4, 3), [1, 1, 0], [4, 1000, 3], None, (2, 3, 3)) verify((3, 4, 3), [1, 1, 0], [4, 4], None, (2, 3, 3)) verify((3, 4, 3), [1, 1], [4, 4, 3], None, (2, 3, 3)) - + verify((3, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1], (1, 4, 3)) + verify((3, 4, 3), [1, -1, 0], [2, -3, 3], [1, -1, 1], (1, 2, 3)) + verify((3, 4, 3), [1, 0, 0], [3, -1, 3], [1, 1, 1], + (2, 4, 3), slice_mode="size", test_ref=False) + verify((3, 4, 3), [1, 0, 0], [-1, 2, 3], [1, 1, 1], + (2, 2, 3), slice_mode="size", test_ref=True) def test_strided_set(): def verify(dshape, begin, end, strides, vshape, test_ref=True): diff --git a/tests/python/relay/test_op_level5.py b/tests/python/relay/test_op_level5.py index c306752..40842eb 100644 --- a/tests/python/relay/test_op_level5.py +++ b/tests/python/relay/test_op_level5.py @@ -244,6 +244,7 @@ def test_get_valid_counts(): np_data = np.random.uniform(low=-2, high=2, size=dshape).astype(dtype) np_out1 = np.zeros(shape=(batch_size,)) np_out2 = np.zeros(shape=dshape).astype(dtype) + np_out3 = np.zeros(shape=(batch_size, num_anchor)) for i in range(batch_size): np_out1[i] = 0 inter_idx = 0 @@ -253,10 +254,12 @@ def test_get_valid_counts(): for k in range(elem_length): np_out2[i, inter_idx, k] = np_data[i, j, k] np_out1[i] += 1 + np_out3[i, inter_idx] = j inter_idx += 1 if j >= np_out1[i]: for k in range(elem_length): np_out2[i, j, k] = -1.0 + np_out3[i, j] = -1 x = relay.var("x", relay.ty.TensorType(dshape, dtype)) z = relay.vision.get_valid_counts(x, score_threshold, id_index, score_index) @@ -271,6 +274,7 @@ def test_get_valid_counts(): if target == 'cuda': return tvm.testing.assert_allclose(out[1].asnumpy(), np_out2, rtol=1e-3, atol=1e-04) + tvm.testing.assert_allclose(out[2].asnumpy(), np_out3, rtol=1e-3, atol=1e-04) verify_get_valid_counts((1, 2500, 6), 0, 0, 1) verify_get_valid_counts((1, 2500, 5), -1, -1, 0) @@ -279,69 +283,79 @@ def test_get_valid_counts(): def test_non_max_suppression(): - def verify_nms(x0_data, x1_data, dshape, ref_res, ref_indices_res, + def verify_nms(x0_data, x1_data, x2_data, dshape, ref_res, ref_indices_res, iou_threshold=0.5, force_suppress=False, top_k=-1, check_type_only=False): x0 = relay.var("x0", relay.ty.TensorType(dshape, "float32")) x1 = relay.var("x1", relay.ty.TensorType((dshape[0],), "int32")) - z = relay.vision.non_max_suppression(x0, x1, max_output_size = -1, \ - iou_threshold = iou_threshold, force_suppress = force_suppress, \ - top_k = top_k, return_indices=False) - z_indices = relay.vision.non_max_suppression(x0, x1, max_output_size = -1, \ - iou_threshold = iou_threshold, force_suppress = force_suppress, \ - top_k = top_k) + x2 = relay.var("x2", relay.ty.TensorType((dshape[0], dshape[1]), "int32")) + z = relay.vision.non_max_suppression(x0, x1, x2, 
max_output_size=-1, \ + iou_threshold=iou_threshold, force_suppress=force_suppress, \ + top_k=top_k, return_indices=False) + z_indices = relay.vision.non_max_suppression(x0, x1, x2, max_output_size=-1, \ + iou_threshold=iou_threshold, force_suppress=force_suppress, \ + top_k=top_k, return_indices=True) + if isinstance(z_indices, relay.expr.TupleWrapper): + z_indices = z_indices.astuple() assert "iou_threshold" in z.astext() assert "iou_threshold" in z_indices.astext() zz = run_infer_type(z) zz_indices = run_infer_type(z_indices) assert zz.checked_type == relay.ty.TensorType(dshape, "float32") - assert zz_indices.checked_type == relay.ty.TensorType((dshape[0], dshape[1]), "int32") + assert zz_indices.checked_type == relay.ty.TupleType( + [relay.ty.TensorType((dshape[0], dshape[1]), "int32"), + relay.ty.TensorType((dshape[0], 1), "int32")]) if check_type_only: return - func = relay.Function([x0, x1], z) + func = relay.Function([x0, x1, x2], z) func = run_infer_type(func) - func_indices = relay.Function([x0, x1], z_indices) + func_indices = relay.Function([x0, x1, x2], z_indices) func_indices = run_infer_type(func_indices) for target, ctx in ctx_list(): intrp1 = relay.create_executor("graph", ctx=ctx, target=target) - op_res1 = intrp1.evaluate(func)(x0_data, x1_data) - op_indices_res1 = intrp1.evaluate(func_indices)(x0_data, x1_data) + op_res1 = intrp1.evaluate(func)(x0_data, x1_data, x2_data) tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5) - tvm.testing.assert_allclose(op_indices_res1.asnumpy(), ref_indices_res, rtol=1e-5) intrp2 = relay.create_executor("debug", ctx=ctx, target=target) - op_res2 = intrp2.evaluate(func)(x0_data, x1_data) - op_indices_res2 = intrp2.evaluate(func_indices)(x0_data, x1_data) + op_res2 = intrp2.evaluate(func)(x0_data, x1_data, x2_data) tvm.testing.assert_allclose(op_res2.asnumpy(), ref_res, rtol=1e-5) - tvm.testing.assert_allclose(op_indices_res2.asnumpy(), ref_indices_res, rtol=1e-5) + if target == 'cuda': + return + op_indices_res1 = intrp1.evaluate(func_indices)(x0_data, x1_data, x2_data) + tvm.testing.assert_allclose(op_indices_res1[0].asnumpy(), ref_indices_res, rtol=1e-5) + op_indices_res2 = intrp2.evaluate(func_indices)(x0_data, x1_data, x2_data) + tvm.testing.assert_allclose(op_indices_res2[0].asnumpy(), ref_indices_res, rtol=1e-5) np_data = np.array([[[0, 0.8, 1, 20, 25, 45], [1, 0.7, 30, 60, 50, 80], [0, 0.4, 4, 21, 19, 40], [2, 0.9, 35, 61, 52, 79], [1, 0.5, 100, 60, 70, 110]]]).astype("float32") np_valid_count = np.array([4]).astype("int32") + + np_indices = np.array([[0, 1, 3, 4, -1]]).astype("int32") + np_result = np.array([[[2, 0.9, 35, 61, 52, 79], [0, 0.8, 1, 20, 25, 45], [-1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1]]]) - np_indices_result = np.array([[3, 0, -1, -1, -1]]) + np_indices_result = np.array([[4, 0, -1, -1, -1]]) num_anchors = 5 dshape = (te.size_var("n"), num_anchors, 6) - verify_nms(np_data, np_valid_count, dshape, np_result, np_indices_result, + verify_nms(np_data, np_valid_count, np_indices, dshape, np_result, np_indices_result, force_suppress=True, top_k=2, check_type_only=True) dshape = (1, num_anchors, 6) - verify_nms(np_data, np_valid_count, dshape, np_result, np_indices_result, + verify_nms(np_data, np_valid_count, np_indices, dshape, np_result, np_indices_result, force_suppress=True, top_k=2, check_type_only=False) np_result = np.array([[[2, 0.9, 35, 61, 52, 79], [0, 0.8, 1, 20, 25, 45], [1, 0.7, 30, 60, 50, 80], [-1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1]]]) - 
np_indices_result = np.array([[3, 0, 1, -1, -1]]) + np_indices_result = np.array([[4, 0, 1, -1, -1]]) dshape = (te.size_var("n"), num_anchors, 6) - verify_nms(np_data, np_valid_count, dshape, np_result, + verify_nms(np_data, np_valid_count, np_indices, dshape, np_result, np_indices_result, check_type_only=True) dshape = (1, num_anchors, 6) - verify_nms(np_data, np_valid_count, dshape, np_result, + verify_nms(np_data, np_valid_count, np_indices, dshape, np_result, np_indices_result, top_k=3) @@ -384,7 +398,7 @@ def test_multibox_transform_loc(): assert ret.checked_type == ref_type - nms = relay.vision.non_max_suppression(mtl[0], mtl[1], return_indices=False) + nms = relay.vision.non_max_suppression(mtl[0], mtl[1], mtl[0], return_indices=False) func = relay.Function([cls_prob, loc_pred, anchors], nms) func = run_infer_type(func) for target, ctx in ctx_list(): diff --git a/tests/python/relay/test_pass_alter_op_layout.py b/tests/python/relay/test_pass_alter_op_layout.py index bc0420f..bbe10c7 100644 --- a/tests/python/relay/test_pass_alter_op_layout.py +++ b/tests/python/relay/test_pass_alter_op_layout.py @@ -18,10 +18,11 @@ import pytest import tvm -from tvm import te from tvm import relay from tvm.relay import transform, analysis from tvm.relay.testing.temp_op_attr import TempOpAttr +from tvm.relay.testing import ctx_list, run_infer_type +import numpy as np def run_opt_pass(expr, passes): passes = passes if isinstance(passes, list) else [passes] @@ -620,7 +621,10 @@ def test_alter_layout_strided_slice(): x = relay.var("x", shape=(1, 32, 28, 28)) weight = relay.var('weight', shape=(32, 32, 3, 3)) y = relay.nn.conv2d(x, weight, channels=32, kernel_size=(3, 3), padding=(1, 1)) - y = relay.strided_slice(y, begin=[0, 16], end=[None, None]) + y = relay.strided_slice(y, + begin=relay.const([0, 16], "int32"), + end=relay.const([1, 33], "int32"), + strides=relay.const([1, 1], "int32")) y = relay.Function(analysis.free_vars(y), y) return y @@ -632,22 +636,41 @@ def test_alter_layout_strided_slice(): def expected(): x = relay.var("x", shape=(1, 32, 28, 28)) - weight = relay.var("weight") + weight = relay.var("weight", shape=(32, 32, 3, 3)) + weight = relay.layout_transform(weight, "OIHW", "OIHW4i4o") x = relay.layout_transform(x, "NCHW", "NCHW4c") - y = relay.nn.conv2d(x, weight, channels=32, kernel_size=(3, 3), padding=(1, 1), - data_layout="NCHW4c") - y = relay.strided_slice(y, begin=[0, 4], end=[None, 8]) + y = relay.op.nn.contrib_conv2d_nchwc(x, weight, channels=32, kernel_size=(3, 3), padding=(1, 1), + data_layout="NCHW4c") + + y = relay.strided_slice(y, + begin=relay.const([0, 4], "int32"), + end=relay.const([1, 21], "int32"), + strides=relay.const([1, 1], "int32")) + y = relay.layout_transform(y, "NCHW4c", "NCHW") y = relay.Function(analysis.free_vars(y), y) return y with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d): a = before() - a = run_opt_pass(a, [transform.CanonicalizeOps(), - transform.AlterOpLayout()]) b = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + # Verify inference result + mod_before = tvm.IRModule() + mod_new = tvm.IRModule() + mod_before['main'] = a + mod_new['main'] = b + with relay.build_config(opt_level=3): + for target, ctx in ctx_list(): + for kind in ["graph", "debug", "vm"]: + ex_before = relay.create_executor(kind, mod=mod_before, ctx=ctx, target=target) + ex_new = relay.create_executor(kind, mod=mod_new, ctx=ctx, target=target) + np_data = np.random.uniform(size=(1, 32, 28, 
28)).astype("float32") + np_weight = np.random.uniform(size=(32, 32, 3, 3)).astype("float32") + result_before = ex_before.evaluate()(np_data, np_weight) + result_new = ex_new.evaluate()(np_data, np_weight) + tvm.testing.assert_allclose(result_before.asnumpy(), result_new.asnumpy(), rtol=1e-5, atol=1e-5) + def test_alter_layout_depthwise_conv2d(): """Test depthwise_conv2d operator""" diff --git a/tests/python/relay/test_pass_combine_parallel_conv2d.py b/tests/python/relay/test_pass_combine_parallel_conv2d.py index 7f7f185..68e7fec 100644 --- a/tests/python/relay/test_pass_combine_parallel_conv2d.py +++ b/tests/python/relay/test_pass_combine_parallel_conv2d.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. import tvm -from tvm import te from tvm import relay from tvm.relay import transform @@ -50,17 +49,28 @@ def test_combine_parallel_conv2d(): args = [x, w1, w2, w3, w4] w = relay.concatenate((w1, w2, w4), axis=0) y = relay.nn.conv2d(x, w, channels=channels1 + channels2 + channels4) - y1 = relay.strided_slice(y, [0, 0], [None, channels1]) - y2 = relay.strided_slice(y, [0, channels1], [None, channels1 + channels2]) + y1 = relay.strided_slice(y, + begin=relay.const([0, 0], "int64"), + end=relay.const([-1, channels1], "int64"), + strides=relay.const([1, 1], 'int64'), + slice_mode="size") + y2 = relay.strided_slice(y, + begin=relay.const([0, channels1], "int64"), + end=relay.const([-1, channels2], "int64"), + strides=relay.const([1, 1], 'int64'), + slice_mode="size") y3 = relay.nn.conv2d(x, w3) - y4 = relay.strided_slice(y, [0, channels1 + channels2], - [None, channels1 + channels2 + channels4]) + y4 = relay.strided_slice(y, + begin=relay.const([0, channels1 + channels2], "int64"), + end=relay.const([-1, channels4], "int64"), + strides=relay.const([1, 1], 'int64'), + slice_mode="size") y5 = relay.nn.max_pool2d(x) y = relay.Tuple((y1, y2, y3, y4, y5)) return relay.Function(args, y) def check(x_shape, channels1, channels2, channels3, channels4): - x = relay.var("x", shape=x_shape) + x = relay.var("x", shape=x_shape) in_c = x_shape[1] w1 = relay.var("w1", shape=(channels1, in_c, 1, 1)) w2 = relay.var("w2", shape=(channels2, in_c, 1, 1)) @@ -99,8 +109,16 @@ def test_combine_parallel_conv2d_scale_relu(): y = relay.nn.conv2d(x, w, channels=channels1 + channels2) y = relay.multiply(y, scale) y = relay.nn.relu(y) - y1 = relay.strided_slice(y, [0, 0], [None, channels1]) - y2 = relay.strided_slice(y, [0, channels1], [None, channels1 + channels2]) + y1 = relay.strided_slice(y, + begin=relay.const([0, 0], "int64"), + end=relay.const([-1, channels1], "int64"), + strides=relay.const([1, 1], "int64"), + slice_mode="size") + y2 = relay.strided_slice(y, + begin=relay.const([0, channels1], "int64"), + end=relay.const([-1, channels2], "int64"), + strides=relay.const([1, 1], "int64"), + slice_mode="size") y2 = relay.add(y2, bias) y = relay.Tuple((y1, y2)) return relay.Function(args, y) @@ -138,8 +156,16 @@ def test_combine_parallel_conv2d_scale(): args = [x, w1, w2, scale1, scale2] w = relay.concatenate((w1, w2), axis=0) y = relay.nn.conv2d(x, w, channels=channels1 + channels2) - y1 = relay.strided_slice(y, [0, 0], [None, channels1]) - y2 = relay.strided_slice(y, [0, channels1], [None, channels1 + channels2]) + y1 = relay.strided_slice(y, + begin=relay.const([0, 0], "int64"), + end=relay.const([-1, channels1], "int64"), + strides=relay.const([1, 1], "int64"), + slice_mode="size") + y2 = relay.strided_slice(y, + begin=relay.const([0, channels1], "int64"), + 
end=relay.const([-1, channels2], "int64"), + strides=relay.const([1, 1], "int64"), + slice_mode="size") y1 = relay.multiply(y1, scale1) y2 = relay.multiply(y2, scale2) y = relay.Tuple((y1, y2)) @@ -178,8 +204,16 @@ def test_combine_parallel_conv2d_multiple_blocks(): for i in range(repeat): w_concat = relay.concatenate((w, w), axis=0) y = relay.nn.conv2d(y, w_concat, channels=channels*2) - y1 = relay.strided_slice(y, [0, 0], [None, channels]) - y2 = relay.strided_slice(y, [0, channels], [None, channels * 2]) + y1 = relay.strided_slice(y, + begin=relay.const([0, 0], "int64"), + end=relay.const([-1, channels], "int64"), + strides=relay.const([1, 1], "int64"), + slice_mode="size") + y2 = relay.strided_slice(y, + begin=relay.const([0, channels], "int64"), + end=relay.const([-1, channels], "int64"), + strides=relay.const([1, 1], "int64"), + slice_mode="size") y = relay.concatenate((y1, y2), axis=1) return relay.Function(args, y) diff --git a/topi/include/topi/transform.h b/topi/include/topi/transform.h index 813d7d7..4b7f7ca 100644 --- a/topi/include/topi/transform.h +++ b/topi/include/topi/transform.h @@ -521,26 +521,25 @@ inline Array split(const Tensor& x, Array split_indices, int ax * \param end Indicies indicating end of the slice * \param strides Specifies the stride values, it can be negative * in that case, the input tensor will be reversed in that particular axis + * \param slice_mode Specifies the slice mode * \param name The name of the operation * \param tag The tag to mark the operation * * \return A Tensor whose op member is the split operation */ inline Tensor strided_slice(const Tensor& x, const Array& begin, const Array& end, - const Array& strides, std::string name = "T_strided_slice", - std::string tag = kInjective) { + const Array& strides, std::string slice_mode = "end", + std::string name = "T_strided_slice", std::string tag = kInjective) { size_t src_tensor_dim = static_cast(x->shape.size()); // Setup the ranges. // NOTE: this code duplicates the shape inference logic relay.op // Consider to refactor in the future. - std::vector stride_vec; - for (Integer i : strides) { - CHECK(i.defined()); - stride_vec.push_back(i->value); - } - for (size_t i = stride_vec.size(); i < src_tensor_dim; ++i) { - stride_vec.push_back(1); + std::vector stride_vec(src_tensor_dim, 1); + for (size_t i = 0; i < strides.size(); ++i) { + CHECK(strides[i].defined()); + stride_vec[i] = strides[i]->value; } + const int64_t max_range = std::numeric_limits::max(); std::vector begin_vec; @@ -559,8 +558,15 @@ inline Tensor strided_slice(const Tensor& x, const Array& begin, const std::vector end_vec; for (size_t i = 0; i < end.size(); ++i) { // allow end to be None + if (!end[i].defined()) { end_vec.push_back(stride_vec[i] < 0 ? 0 : max_range); + } else if (slice_mode == "size") { + if (end[i]->value < 0) { + end_vec.push_back(stride_vec[i] < 0 ? 
0 : max_range); + } else { + end_vec.push_back(begin_vec[i] + end[i]->value); + } } else { end_vec.push_back(end[i]->value); } diff --git a/topi/python/topi/cuda/conv2d_alter_op.py b/topi/python/topi/cuda/conv2d_alter_op.py index c1e207c..c2a1905 100644 --- a/topi/python/topi/cuda/conv2d_alter_op.py +++ b/topi/python/topi/cuda/conv2d_alter_op.py @@ -246,7 +246,8 @@ def _conv2d_legalize(attrs, inputs, arg_types): new_attrs['channels'] = new_out_channel out = tvm.relay.nn.conv2d(data, kernel, **new_attrs) original_out_shape = [x.value for x in output_tensor.shape] - out = relay.strided_slice(out, begin=(0, 0, 0, 0), end=original_out_shape) + out = relay.strided_slice(out, begin=relay.const([0, 0, 0, 0]), + end=relay.const(original_out_shape)) else: out = relay.nn.conv2d(data, kernel, **new_attrs) return out diff --git a/topi/python/topi/cuda/nms.py b/topi/python/topi/cuda/nms.py index d8be3bd..f2c1143 100644 --- a/topi/python/topi/cuda/nms.py +++ b/topi/python/topi/cuda/nms.py @@ -43,7 +43,8 @@ def atomic_add(x, y): return tvm.tir.call_pure_intrin(y.dtype, "atomic_add", x, y) -def get_valid_counts_ir(data, valid_count, out, score_threshold, id_index, score_index): +def get_valid_counts_ir(data, valid_count, out, out_indices, + score_threshold, id_index, score_index): """Low level IR to get valid count of bounding boxes given a score threshold. Also prepares to move valid boxes to the top of input data. @@ -83,6 +84,7 @@ def get_valid_counts_ir(data, valid_count, out, score_threshold, id_index, score valid_count = ib.buffer_ptr(valid_count) out = ib.buffer_ptr(out) + out_indices = ib.buffer_ptr(out_indices) atomic_add_return = ib.allocate( valid_count.dtype, (1,), name='atomic_add_return', scope='local') one_count = tvm.tir.const(1, dtype=valid_count.dtype) @@ -115,9 +117,11 @@ def get_valid_counts_ir(data, valid_count, out, score_threshold, id_index, score valid_count[i]), one_count) with ib.for_range(0, elem_length) as k: out[tid * elem_length + k] = data[tid * elem_length + k] + out_indices[tid + k] = tid + k with ib.else_scope(): with ib.for_range(0, elem_length) as k: out[tid * elem_length + k] = -one + out_indices[tid + k] = -one_count return ib.get() @@ -149,24 +153,27 @@ def get_valid_counts(data, score_threshold=0, id_index=0, score_index=1): Rearranged data tensor. 
""" batch_size = data.shape[0] + num_anchors = data.shape[1] data_buf = tvm.tir.decl_buffer( data.shape, data.dtype, "data_buf", data_alignment=8) valid_count_buf = tvm.tir.decl_buffer( (batch_size,), "int32", "valid_count_buf", data_alignment=8) out_buf = tvm.tir.decl_buffer( data.shape, data.dtype, "out_buf", data_alignment=8) + out_indices_buf = tvm.tir.decl_buffer( + (batch_size, num_anchors), "int32", "out_buf", data_alignment=8) - valid_count, out = \ - te.extern([(batch_size,), data.shape], [data], + valid_count, out, out_indices = \ + te.extern([(batch_size,), data.shape, (batch_size, num_anchors)], [data], lambda ins, outs: get_valid_counts_ir( - ins[0], outs[0], outs[1], score_threshold, id_index, score_index), + ins[0], outs[0], outs[1], outs[2], score_threshold, id_index, score_index), dtype=["int32", data.dtype], in_buffers=[data_buf], - out_buffers=[valid_count_buf, out_buf], + out_buffers=[valid_count_buf, out_buf, out_indices_buf], name="get_valid_counts", tag="get_valid_counts_gpu") - return [valid_count, out] + return [valid_count, out, out_indices] def nms_ir(data, sorted_index, valid_count, out, box_indices, @@ -335,7 +342,7 @@ def nms_ir(data, sorted_index, valid_count, out, box_indices, return ib.get() -def non_max_suppression(data, valid_count, max_output_size=-1, +def non_max_suppression(data, valid_count, indices, max_output_size=-1, iou_threshold=0.5, force_suppress=False, top_k=-1, coord_start=2, score_index=1, id_index=0, return_indices=True, invalid_to_bottom=False): @@ -347,9 +354,18 @@ def non_max_suppression(data, valid_count, max_output_size=-1, 3-D tensor with shape [batch_size, num_anchors, elem_length]. The last dimension should be in format of [class_id, score, box_left, box_top, box_right, box_bottom]. + It could be the second output out_tensor of get_valid_counts. valid_count : tvm.te.Tensor - 1-D tensor for valid number of boxes. + 1-D tensor for valid number of boxes. It could be the output + valid_count of get_valid_counts. + + indices : tvm.te.Tensor + 2-D tensor with shape [batch_size, num_anchors], represents + the index of box in original data. It could be the third + output out_indices of get_valid_counts. The values in the + second dimension are like the output of arange(num_anchors) + if get_valid_counts is not used before non_max_suppression. max_output_size : optional, int Max number of output valid boxes for each instance. 
diff --git a/topi/python/topi/cuda/ssd/multibox.py b/topi/python/topi/cuda/ssd/multibox.py index 30784f4..22d7443 100644 --- a/topi/python/topi/cuda/ssd/multibox.py +++ b/topi/python/topi/cuda/ssd/multibox.py @@ -459,7 +459,7 @@ def multibox_detection(cls_prob, loc_pred, anchor, clip=True, threshold=0.01, nm """ inter_out = multibox_transform_loc(cls_prob, loc_pred, anchor, clip, threshold, variances) - out = non_max_suppression(inter_out[0], inter_out[1], max_output_size=-1, + out = non_max_suppression(inter_out[0], inter_out[1], inter_out[1], max_output_size=-1, iou_threshold=nms_threshold, force_suppress=force_suppress, top_k=nms_topk, return_indices=False) return out diff --git a/topi/python/topi/image/dilation2d.py b/topi/python/topi/image/dilation2d.py index a71866e..074ca6c 100644 --- a/topi/python/topi/image/dilation2d.py +++ b/topi/python/topi/image/dilation2d.py @@ -29,10 +29,10 @@ def dilation2d_nchw(input, filter, stride, padding, dilations, out_dtype=None): Parameters ---------- - input : tvm.Tensor + input : tvm.te.Tensor 4-D with shape [batch, in_channel, in_height, in_width] - filter : tvm.Tensor + filter : tvm.te.Tensor 3-D with shape [ in_channel, filter_height, filter_width] stride : int or a list/tuple of two ints @@ -49,7 +49,7 @@ def dilation2d_nchw(input, filter, stride, padding, dilations, out_dtype=None): Returns ------- - Output : tvm.Tensor + Output : tvm.te.Tensor 4-D with shape [batch, in_channel, out_height, out_width] """ if out_dtype is None: @@ -100,10 +100,10 @@ def dilation2d_nhwc(input, filter, stride, padding, dilations, out_dtype=None): Parameters ---------- - input : tvm.Tensor + input : tvm.te.Tensor 4-D with shape [batch, in_height, in_width, in_channel] - filter : tvm.Tensor + filter : tvm.te.Tensor 3-D with shape [filter_height, filter_width, in_channel] stride : int or a list/tuple of two ints @@ -120,7 +120,7 @@ def dilation2d_nhwc(input, filter, stride, padding, dilations, out_dtype=None): Returns ------- - Output : tvm.Tensor + Output : tvm.te.Tensor 4-D with shape [batch, out_height, out_width, in_channel] """ if out_dtype is None: diff --git a/topi/python/topi/math.py b/topi/python/topi/math.py index d715308..b4228a4 100644 --- a/topi/python/topi/math.py +++ b/topi/python/topi/math.py @@ -401,12 +401,12 @@ def isfinite(x): Parameters ---------- - x : tvm.Tensor + x : tvm.te.Tensor Input argument. Returns ------- - y : tvm.Tensor + y : tvm.te.Tensor The result. """ return te.compute(x.shape, lambda *i: te.isfinite(x(*i))) @@ -418,12 +418,12 @@ def isinf(x): Parameters ---------- - x : tvm.Tensor + x : tvm.te.Tensor Input argument. Returns ------- - y : tvm.Tensor + y : tvm.te.Tensor The result. """ return te.compute(x.shape, lambda *i: te.isinf(x(*i))) @@ -677,12 +677,12 @@ def fast_tanh(x): Parameters ---------- - x : tvm.Tensor + x : tvm.te.Tensor Input argument. Returns ------- - y : tvm.Tensor + y : tvm.te.Tensor The result. """ return cpp.fast_tanh(x, x.dtype, tag.ELEMWISE) diff --git a/topi/python/topi/sort.py b/topi/python/topi/sort.py index e492d68..f79eb52 100644 --- a/topi/python/topi/sort.py +++ b/topi/python/topi/sort.py @@ -31,10 +31,10 @@ def argsort(data, valid_count=None, axis=-1, is_ascend=1, dtype="float32"): The input tensor. valid_count : tvm.te.Tensor, optional - 1-D tensor for valid number of boxes only for ssd. + 1-D tensor for valid number of boxes. axis : int, optional - Axis along which to sort the input tensor. + Axis along which to sort the input tensor. By default the flattened array is used. 
     is_ascend : boolean, optional
diff --git a/topi/python/topi/testing/strided_slice_python.py b/topi/python/topi/testing/strided_slice_python.py
index c1c899a..970e1de 100644
--- a/topi/python/topi/testing/strided_slice_python.py
+++ b/topi/python/topi/testing/strided_slice_python.py
@@ -17,7 +17,7 @@
 """strided_slice/set in python"""
 
 
-def strided_slice_python(data, begin, end, strides):
+def strided_slice_python(data, begin, end, strides, slice_mode="end"):
     """Python version of strided slice operator.
 
     Parameters
@@ -34,6 +34,14 @@ def strided_slice_python(data, begin, end, strides):
     strides : list
         The stride of each slice.
 
+    slice_mode : str, optional
+        The slice mode [end, size].
+        end: The default slice mode, ending indices for the slice.
+        size: The input strides will be ignored, input end in this mode indicates
+        the size of a slice starting at the location specified by begin. If end[i] is -1,
+        all remaining elements in that dimension are included in the slice.
+
+
     Returns
     -------
     result : numpy.ndarray
@@ -42,10 +50,24 @@ def strided_slice_python(data, begin, end, strides):
     strides = [] if strides is None else strides
     slices = []
     for i in range(len(data.shape)):
-        slices.append(slice(
-            begin[i] if i < len(begin) else None,
-            end[i] if i < len(end) else None,
-            strides[i] if i < len(strides) else None))
+        new_stride = None
+        if slice_mode == "end" and i < len(strides):
+            new_stride = strides[i]
+
+        new_begin = begin[i] if i < len(begin) else None
+        if i >= len(end):
+            new_end = None
+        elif slice_mode == "size":
+            if end[i] < 0:
+                new_end = None
+            else:
+                new_end = new_begin + end[i]
+        else:
+            new_end = end[i]
+
+        slices.append(slice(new_begin,
+                            new_end,
+                            new_stride))
     return data[tuple(slices)]
 
 
diff --git a/topi/python/topi/transform.py b/topi/python/topi/transform.py
index e0f5c59..5a0bf11 100644
--- a/topi/python/topi/transform.py
+++ b/topi/python/topi/transform.py
@@ -131,7 +131,7 @@ def flip(a, axis=0):
     """
     return cpp.flip(a, axis)
 
-def strided_slice(a, begin, end, strides=None):
+def strided_slice(a, begin, end, strides=None, slice_mode="end"):
     """Slice of an array.
 
     Parameters
@@ -139,24 +139,31 @@ def strided_slice(a, begin, end, strides=None):
     a : tvm.te.Tensor
         The tensor to be sliced.
 
-    begin: list of int
+    begin : list of int
         The indices to begin with in the slicing.
 
-    end: list of int
+    end : list of int
         Indicies indicating end of the slice.
 
-    strides: list of int, optional
+    strides : list of int, optional
         Specifies the stride values, it can be negative
         in that case, the input tensor will be reversed
         in that particular axis.
 
+    slice_mode : str, optional
+        The slice mode [end, size].
+        end - The ending indices for the slice [default].
+        size - The input strides will be ignored, input end in this mode indicates
+        the size of a slice starting at the location specified by begin. If end[i]
+        is -1, all remaining elements in that dimension are included in the slice.
+ Returns ------- ret : tvm.te.Tensor """ if strides is None: strides = [] - return cpp.strided_slice(a, begin, end, strides) + return cpp.strided_slice(a, begin, end, strides, slice_mode) @tvm.te.tag_scope(tag=tag.INJECTIVE+",strided_set") def strided_set(a, v, begin, end, strides=None): diff --git a/topi/python/topi/vision/nms.py b/topi/python/topi/vision/nms.py index 28598de..269c876 100644 --- a/topi/python/topi/vision/nms.py +++ b/topi/python/topi/vision/nms.py @@ -23,7 +23,7 @@ from tvm.te import hybrid from ..sort import argsort @hybrid.script -def hybrid_rearrange_out(data, one): +def hybrid_rearrange_box_out(data, one, batch_size, num_anchors): """Hybrid routine to rearrange nms output to move all valid entries to top. @@ -36,14 +36,19 @@ def hybrid_rearrange_out(data, one): one: tvm.tir.const Constant one with the same dtype as data. + batch_size: tvm.tir.IntImm or tvm.tir.Var + Batch size. We need to pass it in since hybrid script doesn't support + binding variable to symbolic dim. + + num_anchors: tvm.tir.IntImm or tvm.tir.Var + Number of anchors. + Returns ------- output : tvm.te.Tensor or numpy NDArray Transformed NMS output. 3-D tensor with shape [batch_size, num_anchors, 6]. """ - batch_size = data.shape[0] - num_anchors = data.shape[1] elem_length = data.shape[2] output = output_tensor((batch_size, num_anchors, @@ -64,7 +69,59 @@ def hybrid_rearrange_out(data, one): @hybrid.script -def hybrid_get_valid_counts(data, score_threshold, id_index, score_index, one): +def hybrid_rearrange_indices_out(data, one, batch_size, num_anchors): + """Hybrid routine to rearrange nms output to + move all valid entries to top. + + Parameters + ---------- + data : tvm.te.Tensor or numpy NDArray + NMS output. 3-D tensor with shape + [batch_size, num_anchors, 6] or + [batch_size, num_anchors, 5], or 2-D + tensor with shape [batch_size, num_anchors]. + + one: tvm.tir.const + Constant one with the same dtype as data. + + batch_size: tvm.tir.IntImm or tvm.tir.Var + Batch size. We need to pass it in since hybrid script doesn't support + binding variable to symbolic dim. + + num_anchors: tvm.tir.IntImm or tvm.tir.Var + Number of anchors. + + Returns + ------- + output : tvm.te.Tensor or numpy NDArray + 2-D tensor with shape [batch_size, num_anchors]. + + valid_box_count : tvm.te.Tensor or numpy NDArray + Tensor with shape [batch_size, 1], indicates + the valid number of boxes. + """ + valid_box_count = output_tensor((batch_size, 1), "int32") + output = output_tensor((batch_size, num_anchors), data.dtype) + + for i in parallel(batch_size): + valid_idx = 0 + for j in range(num_anchors): + if data[i, j] >= 0: + output[i, valid_idx] = data[i, j] + valid_idx += 1 + if data[i, j] > num_anchors or data[i, j] < -num_anchors: + output[i, valid_idx] = 0 + valid_idx += 1 + if j >= valid_idx: + output[i, j] = -one + valid_box_count[i, 0] = valid_idx + + return output, valid_box_count + + +@hybrid.script +def hybrid_get_valid_counts(data, score_threshold, id_index, score_index, + one, batch_size, num_anchors): """Hybrid routine to get valid count of bounding boxes given a score threshold. Also moves valid boxes to the top of input data. @@ -87,22 +144,31 @@ def hybrid_get_valid_counts(data, score_threshold, id_index, score_index, one): one: tvm.tir.const Constant one with the same dtype as data. + batch_size: tvm.tir.IntImm or tvm.tir.Var + Batch size. We need to pass it in since hybrid script doesn't support + binding variable to symbolic dim. + + num_anchors: tvm.tir.IntImm or tvm.tir.Var + Number of anchors. 
+ Returns ------- + valid_count : tvm.te.Tensor or numpy NDArray + 1-D tensor for valid number of boxes. + out_tensor : tvm.te.Tensor or numpy NDArray Rearranged data tensor. - valid_count : tvm.te.Tensor or numpy NDArray - 1-D tensor for valid number of boxes. + out_indices: tvm.te.Tensor or numpy NDArray + Related index in input data. """ - batch_size = data.shape[0] - num_anchors = data.shape[1] box_data_length = data.shape[2] valid_count = output_tensor((batch_size,), "int32") out_tensor = output_tensor((batch_size, num_anchors, box_data_length), data.dtype) + out_indices = output_tensor((batch_size, num_anchors), "int32") for i in parallel(batch_size): valid_count[i] = 0 for j in range(num_anchors): @@ -111,11 +177,13 @@ def hybrid_get_valid_counts(data, score_threshold, id_index, score_index, one): (id_index < 0 or data[i, j, id_index] >= 0): for k in range(box_data_length): out_tensor[i, valid_count[i], k] = data[i, j, k] + out_indices[i, valid_count[i]] = j valid_count[i] += 1 if j >= valid_count[i]: for k in range(box_data_length): out_tensor[i, j, k] = -one - return valid_count, out_tensor + out_indices[i, j] = -1 + return valid_count, out_tensor, out_indices def get_valid_counts(data, score_threshold=0, id_index=0, score_index=1): @@ -139,38 +207,55 @@ def get_valid_counts(data, score_threshold=0, id_index=0, score_index=1): Returns ------- + valid_count : tvm.te.Tensor + 1-D tensor for valid number of boxes. + out_tensor : tvm.te.Tensor Rearranged data tensor. - valid_count : tvm.te.Tensor - 1-D tensor for valid number of boxes. + out_indices: tvm.te.Tensor or numpy NDArray + Related index in input data. """ score_threshold_const = tvm.tir.const(score_threshold, data.dtype) id_index_const = tvm.tir.const(id_index, "int32") score_index_const = tvm.tir.const(score_index, "int32") return hybrid_get_valid_counts(data, score_threshold_const, id_index_const, score_index_const, - tvm.tir.const(1, data.dtype)) + tvm.tir.const(1, data.dtype), + data.shape[0], data.shape[1]) @hybrid.script -def hybrid_nms(data, sorted_index, valid_count, - max_output_size, iou_threshold, force_suppress, - top_k, coord_start, id_index, score_index, zero, one): +def hybrid_nms(data, sorted_index, valid_count, indices, batch_size, num_anchors, + max_output_size, iou_threshold, force_suppress, top_k, coord_start, + score_index, id_index, return_indices, zero, one): """Hybrid routing for non-maximum suppression. Parameters ---------- data: tvm.te.Tensor or numpy NDArray Bounding boxes with class and score. 3-D tensor with shape - [batch_size, num_anchors, 6]. + [batch_size, num_anchors, 6]. It could be the second output + out_tensor of get_valid_counts. sorted_index : tvm.te.Tensor or numpy NDArray Bounding box indexes sorted by score, with shape [batch_size, num_anchors]. valid_count : tvm.te.Tensor or numpy NDArray - 1-D tensor for valid number of boxes. + 1-D tensor for valid number of boxes. It could be the output + valid_count of get_valid_counts. + + indices : tvm.te.Tensor or numpy.NDArray + indices in original tensor, with shape [batch_size, num_anchors], + represents the index of box in original data. It could be the third + output out_indices of get_valid_counts. The values in the second + dimension are like the output of arange(num_anchors) if get_valid_counts + is not used before non_max_suppression. + + batch_size: tvm.tir.IntImm or tvm.tir.Var + Batch size. We need to pass it in since hybrid script doesn't support + binding variable to symbolic dim. 
max_output_size : tvm.tir.const Max number of output valid boxes for each instance. @@ -188,11 +273,14 @@ def hybrid_nms(data, sorted_index, valid_count, coord_start : tvm.tir.const Start index of the consecutive 4 coordinates. + score_index: tvm.tir.const + Index of the scores/confidence of boxes. + id_index : tvm.tir.const index of the class categories, -1 to disable. - score_index: tvm.tir.const - Index of the scores/confidence of boxes. + return_indices : tvm.tir.const + Whether to return box indices in input data. zero: tvm.tir.const Constant zero with the same dtype as data. @@ -203,15 +291,17 @@ def hybrid_nms(data, sorted_index, valid_count, Returns ------- output : tvm.te.Tensor - 3-D tensor with shape [batch_size, num_anchors, 6]. + 3-D tensor with shape [batch_size, num_anchors, 6] + or [batch_size, num_anchors, 5]. box_indices: tvm.te.Tensor 2-D tensor with shape [batch_size, num_anchors]. """ - batch_size = data.shape[0] - num_anchors = data.shape[1] + box_data_length = data.shape[2] - box_indices = output_tensor((batch_size, num_anchors), "int32") + + # box_indices is the expected value, similar to TF & ONNX + box_indices = output_tensor((batch_size, num_anchors), sorted_index.dtype) output = output_tensor((batch_size, num_anchors, box_data_length,), data.dtype) @@ -232,9 +322,11 @@ def hybrid_nms(data, sorted_index, valid_count, for k in range(box_data_length): output[i, j + nkeep, k] = -one box_indices[i, j + nkeep] = -1 + # Apply nms box_start_idx = coord_start batch_idx = i + for j in range(valid_count[i]): if output[i, j, score_index] > 0 and (id_index < 0 or output[i, j, id_index] >= 0): box_a_idx = j @@ -246,36 +338,62 @@ def hybrid_nms(data, sorted_index, valid_count, check_iou = 1 elif id_index < 0 or output[i, j, id_index] == output[i, k, id_index]: check_iou = 1 + if check_iou > 0: - a_l = output[batch_idx, box_a_idx, box_start_idx] - a_t = output[batch_idx, box_a_idx, box_start_idx + 1] - a_r = output[batch_idx, box_a_idx, box_start_idx + 2] - a_b = output[batch_idx, box_a_idx, box_start_idx + 3] + # a_l: left, a_t: top, a_r: right, a_b: bottom + a_l = min(output[batch_idx, box_a_idx, box_start_idx], + output[batch_idx, box_a_idx, box_start_idx + 2]) + a_t = min(output[batch_idx, box_a_idx, box_start_idx + 1], + output[batch_idx, box_a_idx, box_start_idx + 3]) + a_r = max(output[batch_idx, box_a_idx, box_start_idx], + output[batch_idx, box_a_idx, box_start_idx + 2]) + a_b = max(output[batch_idx, box_a_idx, box_start_idx + 1], + output[batch_idx, box_a_idx, box_start_idx + 3]) + box_b_idx = k - b_t = output[batch_idx, box_b_idx, box_start_idx + 1] - b_b = output[batch_idx, box_b_idx, box_start_idx + 3] - b_l = output[batch_idx, box_b_idx, box_start_idx] - b_r = output[batch_idx, box_b_idx, box_start_idx + 2] + + # b_l: left, b_t: top, b_r: right, b_b: bottom + b_l = min(output[batch_idx, box_b_idx, box_start_idx], + output[batch_idx, box_b_idx, box_start_idx + 2]) + b_t = min(output[batch_idx, box_b_idx, box_start_idx + 1], + output[batch_idx, box_b_idx, box_start_idx + 3]) + b_r = max(output[batch_idx, box_b_idx, box_start_idx], + output[batch_idx, box_b_idx, box_start_idx + 2]) + b_b = max(output[batch_idx, box_b_idx, box_start_idx + 1], + output[batch_idx, box_b_idx, box_start_idx + 3]) + + # Overlapping width and height w = max(zero, min(a_r, b_r) - max(a_l, b_l)) h = max(zero, min(a_b, b_b) - max(a_t, b_t)) + + # Overlapping area area = h * w + + # total area of the figure formed by box a and box b + # except for overlapping area u = (a_r - a_l) * (a_b - a_t) 
+ (b_r - b_l) * (b_b - b_t) - area + + # get the iou iou = zero if u <= zero else area / u + if iou >= iou_threshold: output[i, k, score_index] = -one if id_index >= 0: output[i, k, id_index] = -one box_indices[i, k] = -1 + else: for j in parallel(valid_count[i]): for k in range(box_data_length): output[i, j, k] = data[i, j, k] box_indices[i, j] = j + # Set invalid entry to be -1 for j in parallel(num_anchors - valid_count[i]): for k in range(box_data_length): output[i, j + valid_count[i], k] = -one box_indices[i, j + valid_count[i]] = -1 + # Only return max_output_size valid boxes num_valid_boxes = 0 if max_output_size > 0: @@ -287,10 +405,17 @@ def hybrid_nms(data, sorted_index, valid_count, box_indices[i, j] = -1 else: num_valid_boxes += 1 - return output, box_indices + if return_indices: + for j in range(valid_count[i]): + idx = box_indices[i, j] + if box_indices[i, j] >= 0: + box_indices[i, j] = indices[i, idx] + + return output, box_indices -def non_max_suppression(data, valid_count, max_output_size=-1, +@tvm.target.generic_func +def non_max_suppression(data, valid_count, indices, max_output_size=-1, iou_threshold=0.5, force_suppress=False, top_k=-1, coord_start=2, score_index=1, id_index=0, return_indices=True, invalid_to_bottom=False): @@ -304,6 +429,9 @@ def non_max_suppression(data, valid_count, max_output_size=-1, valid_count : tvm.te.Tensor 1-D tensor for valid number of boxes. + indices : tvm.te.Tensor + 2-D tensor with shape [batch_size, num_anchors]. + max_output_size : optional, int Max number of output valid boxes for each instance. By default all valid boxes are returned. @@ -334,8 +462,12 @@ def non_max_suppression(data, valid_count, max_output_size=-1, Returns ------- - out : tvm.te.Tensor - 3-D tensor with shape [batch_size, num_anchors, 6]. + out : tvm.te.Tensor or tuple of tvm.te.Tensor + 3-D tensor with shape [batch_size, num_anchors, 6] + or [batch_size, num_anchors, 5]. Out is a tuple of tvm.te.Tensor + if return_indices is True, the Tensor in the tuple is 2-D tensor + with shape [batch_size, num_anchors] and shape + [batch_size, num_valid_anchors] respectively. 
Example -------- @@ -348,7 +480,7 @@ def non_max_suppression(data, valid_count, max_output_size=-1, iou_threshold = 0.7 force_suppress = True top_k = -1 - out = non_max_suppression(data, valid_count, iou_threshold=iou_threshold, + out = non_max_suppression(data, valid_count, indices, iou_threshold=iou_threshold, force_suppress=force_suppress, top_k=top_k) np_data = np.random.uniform(dshape) np_valid_count = np.array([4]) @@ -366,17 +498,27 @@ def non_max_suppression(data, valid_count, max_output_size=-1, score_shape = (batch_size, num_anchors) score_tensor = te.compute(score_shape, lambda i, j: data[i, j, score_axis]) sort_tensor = argsort(score_tensor, valid_count=valid_count, axis=1, is_ascend=False) - out, box_indices = hybrid_nms(data, sort_tensor, valid_count, + out, box_indices = hybrid_nms(data, + sort_tensor, + valid_count, + indices, + batch_size, + num_anchors, tvm.tir.const(max_output_size, dtype="int32"), tvm.tir.const(iou_threshold, dtype=data.dtype), tvm.tir.const(force_suppress, dtype="bool"), tvm.tir.const(top_k, dtype="int32"), tvm.tir.const(coord_start, dtype="int32"), - tvm.tir.const(id_index, dtype="int32"), tvm.tir.const(score_index, dtype="int32"), + tvm.tir.const(id_index, dtype="int32"), + tvm.tir.const(return_indices, dtype="bool"), zero=tvm.tir.const(0, dtype=data.dtype), one=tvm.tir.const(1, dtype=data.dtype)) - if not return_indices and invalid_to_bottom: - out = hybrid_rearrange_out(out, one=tvm.tir.const(1, dtype=data.dtype)) - - return box_indices if return_indices else out + if return_indices: + return hybrid_rearrange_indices_out(box_indices, one=tvm.tir.const(1, dtype="int32"), + batch_size=batch_size, num_anchors=num_anchors) + + if invalid_to_bottom: + out = hybrid_rearrange_box_out(out, one=tvm.tir.const(1, dtype=data.dtype), + batch_size=batch_size, num_anchors=num_anchors) + return out diff --git a/topi/python/topi/vision/ssd/multibox.py b/topi/python/topi/vision/ssd/multibox.py index ba0cf54..e5b9215 100644 --- a/topi/python/topi/vision/ssd/multibox.py +++ b/topi/python/topi/vision/ssd/multibox.py @@ -304,7 +304,7 @@ def multibox_detection(cls_prob, loc_pred, anchor, clip=True, threshold=0.01, nm """ inter_out = multibox_transform_loc(cls_prob, loc_pred, anchor, clip, threshold, variances) - out = non_max_suppression(inter_out[0], inter_out[1], max_output_size=-1, + out = non_max_suppression(inter_out[0], inter_out[1], inter_out[1], max_output_size=-1, iou_threshold=nms_threshold, force_suppress=force_suppress, top_k=nms_topk, return_indices=False) return out diff --git a/topi/python/topi/x86/conv2d_alter_op.py b/topi/python/topi/x86/conv2d_alter_op.py index d1c607f..e9fc422 100644 --- a/topi/python/topi/x86/conv2d_alter_op.py +++ b/topi/python/topi/x86/conv2d_alter_op.py @@ -312,7 +312,9 @@ def _conv2d_legalize(attrs, inputs, arg_types): new_attrs['channels'] = new_out_channel out = tvm.relay.nn.conv2d(data, kernel, **new_attrs) original_out_shape = [x.value for x in output_tensor.shape] - out = relay.strided_slice(out, begin=(0, 0, 0, 0), end=original_out_shape) + out = relay.strided_slice(out, + begin=relay.const([0, 0, 0, 0], "int32"), + end=relay.const(original_out_shape, "int32")) else: out = relay.nn.conv2d(data, kernel, **new_attrs) diff --git a/topi/python/topi/x86/conv3d.py b/topi/python/topi/x86/conv3d.py index 27f48f8..f0dee31 100644 --- a/topi/python/topi/x86/conv3d.py +++ b/topi/python/topi/x86/conv3d.py @@ -78,11 +78,11 @@ def conv3d_ncdhw(cfg, data, kernel, strides, padding, dilation, out_dtype): Parameters ---------- - input : 
tvm.Tensor + input : tvm.te.Tensor 5-D input data with shapes: [batch, in_channel, in_depth, in_height, in_width] for NCDHW layout - filter : tvm.Tensor + filter : tvm.te.Tensor 5-D filter with shape [out_channels, in_channels, kernel_depth, kernel_height, kernel_width] strides : int or a list/tuple of three ints @@ -96,7 +96,7 @@ def conv3d_ncdhw(cfg, data, kernel, strides, padding, dilation, out_dtype): Returns ------- - output : tvm.Tensor + output : tvm.te.Tensor 5-D with shape [batch, out_channel, out_depth, out_height, out_width] for NCDHW layout """ layout = "NCDHW" diff --git a/topi/src/transform.cc b/topi/src/transform.cc index 4af491e..5300973 100644 --- a/topi/src/transform.cc +++ b/topi/src/transform.cc @@ -152,7 +152,7 @@ TVM_REGISTER_GLOBAL("topi.tensordot").set_body([](TVMArgs args, TVMRetValue* rv) }); TVM_REGISTER_GLOBAL("topi.strided_slice").set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = strided_slice(args[0], args[1], args[2], args[3]); + *rv = strided_slice(args[0], args[1], args[2], args[3], args[4]); }); TVM_REGISTER_GLOBAL("topi.one_hot").set_body([](TVMArgs args, TVMRetValue* rv) { diff --git a/topi/tests/python/test_topi_vision.py b/topi/tests/python/test_topi_vision.py index 3ccb44d..d2331ee 100644 --- a/topi/tests/python/test_topi_vision.py +++ b/topi/tests/python/test_topi_vision.py @@ -69,6 +69,7 @@ def verify_get_valid_counts(dshape, score_threshold, id_index, score_index): np_data = np.random.uniform(low=-2, high=2, size=dshape).astype(dtype) np_out1 = np.zeros(shape=(batch_size,)) np_out2 = np.zeros(shape=dshape).astype(dtype) + np_out3 = np.zeros(shape=(batch_size, num_anchor)) for i in range(batch_size): np_out1[i] = 0 inter_idx = 0 @@ -78,10 +79,12 @@ def verify_get_valid_counts(dshape, score_threshold, id_index, score_index): for k in range(elem_length): np_out2[i, inter_idx, k] = np_data[i, j, k] np_out1[i] += 1 + np_out3[i, inter_idx] = j inter_idx += 1 if j >= np_out1[i]: for k in range(elem_length): np_out2[i, j, k] = -1.0 + np_out3[i, j] = -1 def check_device(device): ctx = tvm.context(device, 0) @@ -98,10 +101,18 @@ def verify_get_valid_counts(dshape, score_threshold, id_index, score_index): tvm_input_data = tvm.nd.array(np_data, ctx) tvm_out1 = tvm.nd.array(np.zeros(np_out1.shape, dtype="int32"), ctx) tvm_out2 = tvm.nd.array(np.zeros(np_out2.shape, dtype=dtype), ctx) - f = tvm.build(s, [data, outs[0], outs[1]], device) - f(tvm_input_data, tvm_out1, tvm_out2) - tvm.testing.assert_allclose(tvm_out1.asnumpy(), np_out1, rtol=1e-3) - tvm.testing.assert_allclose(tvm_out2.asnumpy(), np_out2, rtol=1e-3) + tvm_out3 = tvm.nd.array(np.zeros(np_out3.shape, dtype="int32"), ctx) + if device == "llvm": + f = tvm.build(s, [data, outs[0], outs[1], outs[2]], device) + f(tvm_input_data, tvm_out1, tvm_out2, tvm_out3) + tvm.testing.assert_allclose(tvm_out1.asnumpy(), np_out1, rtol=1e-3) + tvm.testing.assert_allclose(tvm_out2.asnumpy(), np_out2, rtol=1e-3) + tvm.testing.assert_allclose(tvm_out3.asnumpy(), np_out3, rtol=1e-3) + else: + f = tvm.build(s, [data, outs[0], outs[1]], device) + f(tvm_input_data, tvm_out1, tvm_out2) + tvm.testing.assert_allclose(tvm_out1.asnumpy(), np_out1, rtol=1e-3) + tvm.testing.assert_allclose(tvm_out2.asnumpy(), np_out2, rtol=1e-3) """ Skip this test as it is intermittent see https://github.com/apache/incubator-tvm/pull/4901#issuecomment-595040094 @@ -114,19 +125,21 @@ def verify_get_valid_counts(dshape, score_threshold, id_index, score_index): def test_get_valid_counts(): + verify_get_valid_counts((1, 1000, 5), 0.5, -1, 0) 
verify_get_valid_counts((1, 2500, 6), 0, 0, 1) verify_get_valid_counts((1, 2500, 5), -1, -1, 0) verify_get_valid_counts((3, 1000, 6), 0.55, 1, 0) verify_get_valid_counts((16, 500, 5), 0.95, -1, 1) -def verify_non_max_suppression(np_data, np_valid_count, np_result, np_indices_result, iou_threshold, - force_suppress, top_k, coord_start, score_index, id_index): +def verify_non_max_suppression(np_data, np_valid_count, np_indices, np_result, np_indices_result, + iou_threshold, force_suppress, top_k, coord_start, score_index, id_index): dshape = np_data.shape batch, num_anchors, _ = dshape indices_dshape = (batch, num_anchors) data = te.placeholder(dshape, name="data") valid_count = te.placeholder((batch,), dtype="int32", name="valid_count") + indices = te.placeholder((batch, num_anchors), dtype="int32", name="indices") def check_device(device): ctx = tvm.context(device, 0) @@ -136,25 +149,31 @@ def verify_non_max_suppression(np_data, np_valid_count, np_result, np_indices_re print("Running on target: %s" % device) with tvm.target.create(device): fcompute, fschedule = topi.testing.dispatch(device, _nms_implement) - out = fcompute(data, valid_count, -1, iou_threshold, force_suppress, top_k, + out = fcompute(data, valid_count, indices, -1, iou_threshold, force_suppress, top_k, coord_start=coord_start, score_index=score_index, id_index=id_index, return_indices=False) - indices_out = fcompute(data, valid_count, -1, iou_threshold, force_suppress, top_k, - coord_start=coord_start, score_index=score_index, id_index=id_index) + indices_out = fcompute(data, valid_count, indices, -1, iou_threshold, force_suppress, top_k, + coord_start=coord_start, score_index=score_index, id_index=id_index, + return_indices=True) s = fschedule(out) indices_s = fschedule(indices_out) tvm_data = tvm.nd.array(np_data, ctx) tvm_valid_count = tvm.nd.array(np_valid_count, ctx) + tvm_indices = tvm.nd.array(np_indices, ctx) tvm_out = tvm.nd.array(np.zeros(dshape, dtype=data.dtype), ctx) - f = tvm.build(s, [data, valid_count, out], device) - f(tvm_data, tvm_valid_count, tvm_out) + f = tvm.build(s, [data, valid_count, indices, out], device) + f(tvm_data, tvm_valid_count, tvm_indices, tvm_out) tvm.testing.assert_allclose(tvm_out.asnumpy(), np_result, rtol=1e-4) tvm_indices_out = tvm.nd.array(np.zeros(indices_dshape, dtype="int32"), ctx) - f = tvm.build(indices_s, [data, valid_count, indices_out], device) - f(tvm_data, tvm_valid_count, tvm_indices_out) + if device == 'llvm': + f = tvm.build(indices_s, [data, valid_count, indices, indices_out[0]], device) + f(tvm_data, tvm_valid_count, tvm_indices, tvm_indices_out) + else: + f = tvm.build(indices_s, [data, valid_count, indices, indices_out], device) + f(tvm_data, tvm_valid_count, tvm_indices, tvm_indices_out) tvm.testing.assert_allclose(tvm_indices_out.asnumpy(), np_indices_result, rtol=1e-4) for device in ['llvm', 'cuda', 'opencl']: @@ -166,23 +185,24 @@ def test_non_max_suppression(): [0, 0.4, 4, 21, 19, 40], [2, 0.9, 35, 61, 52, 79], [1, 0.5, 100, 60, 70, 110]]]).astype("float32") np_valid_count = np.array([4]).astype("int32") + np_indices = np.array([[0, 1, 2, 3, 4]]).astype("int32") np_result = np.array([[[2, 0.9, 35, 61, 52, 79], [0, 0.8, 1, 20, 25, 45], [-1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1]]]) np_indices_result = np.array([[3, 0, -1, -1, -1]]) - verify_non_max_suppression(np_data, np_valid_count, np_result, np_indices_result, 0.7, True, 2, 2, 1, 0) + verify_non_max_suppression(np_data, np_valid_count, np_indices, np_result, 
np_indices_result, 0.7, True, 2, 2, 1, 0) np_data = np.array([[[0.8, 1, 20, 25, 45], [0.7, 30, 60, 50, 80], [0.4, 4, 21, 19, 40], [0.9, 35, 61, 52, 79], [0.5, 100, 60, 70, 110]]]).astype("float32") np_valid_count = np.array([4]).astype("int32") + np_indices = np.array([[0, 1, 2, 3, 4]]).astype("int32") np_result = np.array([[[0.9, 35, 61, 52, 79], [0.8, 1, 20, 25, 45], [-1, -1, -1, -1, -1], [-1, -1, -1, -1, -1], [-1, -1, -1, -1, -1]]]) np_indices_result = np.array([[3, 0, -1, -1, -1]]) - verify_non_max_suppression(np_data, np_valid_count, np_result, np_indices_result, 0.7, False, 2, 1, 0, -1) - + verify_non_max_suppression(np_data, np_valid_count, np_indices, np_result, np_indices_result, 0.7, False, 2, 1, 0, -1) def verify_multibox_prior(dshape, sizes=(1,), ratios=(1,), steps=(-1, -1), offsets=(0.5, 0.5), clip=False): @@ -459,9 +479,9 @@ def test_proposal(): if __name__ == "__main__": test_get_valid_counts() - test_non_max_suppression() test_multibox_prior() test_multibox_detection() test_roi_align() test_roi_pool() test_proposal() + test_non_max_suppression() -- 2.7.4
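
For illustration only (not part of the patch; the shapes, dtypes, and llvm target are arbitrary choices), a short sketch of the new slice_mode="size" semantics on relay.strided_slice, which now takes relay expressions for begin, end, and strides:

import numpy as np
import tvm
from tvm import relay

x = relay.var("x", shape=(1, 16, 8, 8), dtype="float32")
# In "size" mode end[i] gives the extent of the slice along axis i;
# -1 keeps all remaining elements, and the strides input is ignored.
y = relay.strided_slice(x,
                        begin=relay.const([0, 4, 0, 0], "int64"),
                        end=relay.const([-1, 8, -1, -1], "int64"),
                        strides=relay.const([1, 1, 1, 1], "int64"),
                        slice_mode="size")
mod = tvm.IRModule.from_expr(relay.Function([x], y))
np_data = np.random.uniform(size=(1, 16, 8, 8)).astype("float32")
ex = relay.create_executor("graph", mod=mod, ctx=tvm.cpu(0), target="llvm")
res = ex.evaluate()(np_data)  # expected output shape: (1, 8, 8, 8)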