From fdc8b0dd1763aece4ce457a7baf522c2989ac6c4 Mon Sep 17 00:00:00 2001 From: Mahesh Ambule <15611578+maheshambule@users.noreply.github.com> Date: Mon, 23 Mar 2020 07:10:54 +0530 Subject: [PATCH] [Relay, Topi] [TF, MXNet] Unravel Index operator (#5082) * first cut unravel_index * merge fixes * change rates to dilations * unravel_index op relay, topi, mxnet, tf * doc changes * small changes * remove empty unravel and argwhere attrs * remove empty unravel and argwhere attrs --- docs/api/python/topi.rst | 2 + docs/frontend/tensorflow.rst | 1 + docs/langref/relay_op.rst | 3 +- include/tvm/relay/attrs/transform.h | 6 -- python/tvm/relay/frontend/mxnet.py | 8 +++ python/tvm/relay/frontend/tensorflow.py | 10 +++- python/tvm/relay/op/_transform.py | 1 + python/tvm/relay/op/transform.py | 23 ++++++++ src/relay/op/tensor/transform.cc | 72 +++++++++++++++++++++++- tests/python/frontend/mxnet/test_forward.py | 29 +++++++++- tests/python/frontend/tensorflow/test_forward.py | 52 +++++++++++++++++ tests/python/relay/test_op_level3.py | 39 +++++++++++++ topi/include/topi/transform.h | 48 ++++++++++++++++ topi/python/topi/transform.py | 25 +++++++- topi/src/topi.cc | 5 ++ topi/tests/python/test_topi_transform.py | 44 +++++++++++++++ 16 files changed, 353 insertions(+), 15 deletions(-) diff --git a/docs/api/python/topi.rst b/docs/api/python/topi.rst index 49fd94d..676dde9 100644 --- a/docs/api/python/topi.rst +++ b/docs/api/python/topi.rst @@ -47,6 +47,7 @@ List of operators topi.strided_slice topi.expand_dims topi.reshape + topi.unravel_index topi.squeeze topi.concatenate topi.split @@ -147,6 +148,7 @@ topi .. autofunction:: topi.strided_slice .. autofunction:: topi.expand_dims .. autofunction:: topi.reshape +.. autofunction:: topi.unravel_index .. autofunction:: topi.squeeze .. autofunction:: topi.concatenate .. autofunction:: topi.split diff --git a/docs/frontend/tensorflow.rst b/docs/frontend/tensorflow.rst index e06794d..80230c6 100644 --- a/docs/frontend/tensorflow.rst +++ b/docs/frontend/tensorflow.rst @@ -242,5 +242,6 @@ Supported Ops - Transpose - TruncateMod - Unpack +- UnravelIndex - Where - ZerosLike diff --git a/docs/langref/relay_op.rst b/docs/langref/relay_op.rst index 35f9eeb..ac636f8 100644 --- a/docs/langref/relay_op.rst +++ b/docs/langref/relay_op.rst @@ -124,6 +124,7 @@ This level enables additional math and transform operators. tvm.relay.repeat tvm.relay.tile tvm.relay.reverse + tvm.relay.unravel_index **Level 4: Broadcast and Reductions** @@ -217,4 +218,4 @@ This level supports dialect operators. :nosignatures: tvm.relay.qnn.op.requantize - tvm.relay.qnn.op.conv2d + tvm.relay.qnn.op.conv2d \ No newline at end of file diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index 11c7886..ae2ac11 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -315,12 +315,6 @@ struct OneHotAttrs : public tvm::AttrsNode { } }; // struct OneHotAttrs -/*! \brief Attributes for ArgWhere operator */ -struct ArgWhereAttrs : public tvm::AttrsNode { - TVM_DECLARE_ATTRS(ArgWhereAttrs, "relay.attrs.ArgWhereAttrs") { - } -}; // struct ArgWhereAttrs - } // namespace relay } // namespace tvm #endif // TVM_RELAY_ATTRS_TRANSFORM_H_ diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py index 17be368..b3feded 100644 --- a/python/tvm/relay/frontend/mxnet.py +++ b/python/tvm/relay/frontend/mxnet.py @@ -120,6 +120,13 @@ def _mx_compare(new_op, wrapper): return impl +def _mx_unravel_index(inputs, attrs): + assert len(inputs) == 1 + shape = attrs.get_int_tuple("shape") + shape_expr = _expr.const(list(shape)) + return _op.unravel_index(inputs[0], shape_expr) + + def _mx_zeros(inputs, attrs): assert len(inputs) == 0 shape = attrs.get_int_tuple("shape") @@ -1826,6 +1833,7 @@ _convert_map = { "Embedding" : _mx_embedding, "argsort" : _mx_argsort, "topk" : _mx_topk, + "_unravel_index": _mx_unravel_index, "SequenceMask" : _mx_sequence_mask, "SoftmaxOutput" : _mx_softmax_output, "SoftmaxActivation" : _mx_softmax_activation, diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index ff69ccc..4221cac 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -627,6 +627,11 @@ def _decode_image(): return inputs[0] return _impl +def _unravel_index(): + def _impl(inputs, attr, params): + return _op.unravel_index(inputs[0], inputs[1]) + return _impl + def _crop_and_resize(): def _impl(inputs, attr, params): # input image is a 4-D tensor of shape [batch, image_height, image_width, depth] @@ -1744,6 +1749,7 @@ _convert_map = { 'Transpose' : _transpose(), 'TruncateMod' : _elemwise('mod'), 'Unpack' : _unpack(), + 'UnravelIndex' : _unravel_index(), 'Where' : _where(), 'ZerosLike' : AttrCvt('zeros_like'), @@ -2517,9 +2523,7 @@ class GraphProto(object): array_ndim = len(np_array.shape) if array_ndim == 0: - new_array = np.empty([1], dtype=np_array.dtype) - new_array[0] = np_array - self._nodes[name] = [tvm.relay.const(new_array)] + self._nodes[name] = [tvm.relay.const(np_array)] else: self._params[name] = tvm.nd.array(np_array) self._nodes[name] = [_expr.var(name, diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index 4b35009..1f85e31 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -54,6 +54,7 @@ _reg.register_injective_schedule("gather_nd") _reg.register_injective_schedule("sequence_mask") _reg.register_injective_schedule("one_hot") _reg.register_reduce_schedule("collapse_sum_like") +_reg.register_injective_schedule("unravel_index") # concatenate _reg.register_schedule("concatenate", strategy.schedule_concatenate) diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index 6a30eb2..d7a7b4f 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -861,3 +861,26 @@ def one_hot(indices, on_value, off_value, depth, axis, dtype): [0, 0, 1]] """ return _make.one_hot(indices, on_value, off_value, depth, axis, dtype) + + +def unravel_index(indices, shape): + """Convert a flat index or array of flat indices into a tuple of coordinate arrays. + + Example:: + - unravel_index([22, 41, 37], [7, 6]) = [[3, 6, 6],[4, 5, 1]] + + Parameters + ---------- + indices : relay.Expr + An integer array containing indices. + + shape : relay.Expr + The shape of the array. + + Returns + ------- + result : relay.Expr + The tuple of coordinate arrays. + """ + + return _make.unravel_index(indices, shape) diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 32df221..942ba7e 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -806,15 +806,13 @@ bool ArgWhereRel(const Array& types, TVM_REGISTER_GLOBAL("relay.op._make.argwhere") .set_body_typed([](Expr data) { static const Op& op = Op::Get("argwhere"); - auto attrs = make_object(); - return CallNode::make(op, {data}, Attrs(attrs), {}); + return CallNode::make(op, {data}, Attrs(), {}); }); RELAY_REGISTER_OP("argwhere") .describe(R"doc(Find the indices of elements of a tensor that are non-zero)doc" TVM_ADD_FILELINE) .set_num_inputs(1) -.set_attrs_type() .add_argument("condition", "Tensor", "The input condition tensor.") .add_type_rel("ArgWhere", ArgWhereRel) .set_attr("TOpIsStateful", false) @@ -2662,5 +2660,73 @@ RELAY_REGISTER_OP("one_hot") .set_attr("FTVMCompute", OneHotCompute) .set_attr("TOpPattern", kOutEWiseFusable); +/* relay.unravel_index */ +bool UnRavelIndexRel(const Array& types, + int num_inputs, + const Attrs& attrs, + const TypeReporter& reporter) { + CHECK_EQ(types.size(), 3); + + const auto* indices = types[0].as(); + if (indices == nullptr) { + CHECK(types[0].as()) + << "unravel_index: expect input type to be TensorType but get " + << types[0]; + return false; + } + CHECK(indices->dtype.is_int()) + << "indices of unravel_index must be tensor of integer"; + + const auto* shape = types[1].as(); + if (shape == nullptr) { + CHECK(types[1].as()) + << "unravel_index: expect input type to be TensorType but get " + << types[1]; + return false; + } + CHECK(indices->dtype.is_int()) + << "shape of unravel_index must be tensor of integer"; + + Array indices_shape; + Array shape_shape; + indices_shape = indices->shape; + shape_shape = shape->shape; + + Array oshape; + oshape.push_back(shape_shape[0]); + if (indices_shape.size() != 0) { + oshape.push_back(indices_shape[0]); + } + reporter->Assign(types[2], TensorType(oshape, indices->dtype)); + return true; +} + +Array UnRavelIndexCompute(const Attrs& attrs, + const Array& inputs, + const Type& out_type) { + return Array{topi::unravel_index(inputs[0], inputs[1])}; +} + +Expr MakeUnRavelIndex(Expr data, + Expr shape) { + static const Op& op = Op::Get("unravel_index"); + return CallNode::make(op, {data, shape}, Attrs(), {}); +} + +TVM_REGISTER_GLOBAL("relay.op._make.unravel_index") +.set_body_typed(MakeUnRavelIndex); + +RELAY_REGISTER_OP("unravel_index") +.describe(R"code(Converts a flat index or array of flat indices into a tuple of coordinate arrays. + +Example:: + - unravel_index([22, 41, 37], (7, 6)) = [[3, 6, 6], [4, 5, 1]] +)code" TVM_ADD_FILELINE) +.set_num_inputs(2) +.set_support_level(3) +.add_type_rel("UnRavelIndexRel", UnRavelIndexRel) +.set_attr("FTVMCompute", UnRavelIndexCompute) +.set_attr("TOpPattern", kInjective); + } // namespace relay } // namespace tvm diff --git a/tests/python/frontend/mxnet/test_forward.py b/tests/python/frontend/mxnet/test_forward.py index b81fbab..102905a 100644 --- a/tests/python/frontend/mxnet/test_forward.py +++ b/tests/python/frontend/mxnet/test_forward.py @@ -949,6 +949,32 @@ def test_forward_cond(): verify(np.asarray([4.0], 'float32'), np.asarray([3.0],'float32')) +def test_forward_unravel_index(): + def verify(x, shape, dtype): + a_np = np.array(x).astype(dtype) + mx_sym = _mx_symbol(mx.sym, 'unravel_index', [mx.sym.var('a'), shape]) + ref_res = _mx_symbol(mx.nd, 'unravel_index', [mx.nd.array(a_np), shape]) + shapes = {'a': a_np.shape} + mod, _ = relay.frontend.from_mxnet(mx_sym, shapes, dtype) + + for target, ctx in ctx_list(): + for kind in ["graph", "vm", "debug"]: + intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) + op_res = intrp.evaluate()(a_np) + tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy()) + + for dtype in ["int32", "int64"]: + verify([0, 1, 2, 3], [2, 2], dtype) + verify([144, 13, 45], [6, 7, 10, 2], dtype) + verify([456], [6, 7, 10, 2], dtype) + + # In below example, 5 is out of bound for array of size 4. + # MXNet implementation provides different result than TVM + # TVM implementation is inline with Tensorflow + # Ideally error should be thrown just like Numpy + # verify([0, 1, 2, 5], [2, 2], dtype) + + if __name__ == '__main__': test_forward_mlp() test_forward_vgg() @@ -1003,4 +1029,5 @@ if __name__ == '__main__': test_forward_convolution() test_forward_deconvolution() test_forward_cond() - test_forward_make_loss() \ No newline at end of file + test_forward_make_loss() + test_forward_unravel_index() \ No newline at end of file diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index 2342606..3c51977 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -3057,6 +3057,57 @@ def test_forward_add_n(): _test_forward_add_n(in5) +####################################################################### +# Unravel Index +# ---------------------- +def _test_forward_unravel_index(inputs): + tf.reset_default_graph() + with tf.Graph().as_default(): + temp = [] + for each in inputs: + temp.append(tf.placeholder(shape=each.shape, dtype=each.dtype)) + output = tf.unravel_index(temp[0], temp[1]) + compare_tf_with_tvm([each for each in inputs], [ + each.name for each in temp], output.name) + + +def _test_forward_unravel_index_scalar(x, y, dtype="int32"): + tf.reset_default_graph() + with tf.Graph().as_default(): + indices_1 = constant_op.constant(x, dtype=dtype) + dims_1 = constant_op.constant(y, dtype=dtype) + out_1 = array_ops.unravel_index(indices_1, dims_1) + compare_tf_with_tvm([], [], out_1.name) + + +def test_forward_unravel_index(): + x = np.array([0, 1, 2, 3]) + y = np.array([2, 2]) + _test_forward_unravel_index([x, y]) + + x = np.array([0, 1, 2, 5]) + y = np.array([2, 2]) + _test_forward_unravel_index([x, y]) + + x = np.array([0, 1, 2, 5]) + y = np.array([2]) + _test_forward_unravel_index([x, y]) + + x = np.array([102, 300, 16]) + y = np.array([10, 10, 9, 6]) + _test_forward_unravel_index([x, y]) + + x = np.array([100]) + y = np.array([10, 10, 9, 6]) + _test_forward_unravel_index([x, y]) + + # Test scalar input + _test_forward_unravel_index_scalar(13, [1, 4, 5, 2]) + + +####################################################################### +# Dilation2d +# ---------------------- def _test_dilation2d(tensor_in_sizes, filter_in_sizes, strides, dilations, padding): """ One iteration of dilation2d with given shapes and attributes """ @@ -3173,6 +3224,7 @@ if __name__ == '__main__': test_forward_squared_difference() test_forward_add_n() test_forward_floormod() + test_forward_unravel_index() # Reductions test_forward_argminmax() diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index 7e5314d..fffb1de 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -683,6 +683,44 @@ def test_gather_nd(): verify_gather_nd((3, 2, 2), (2, 2), [[0, 1], [1, 0]]) verify_gather_nd((3, 2), (2, 2, 3), [[[0, 1, 2], [2, 0, 1]], [[0, 0, 0], [1, 1, 1]]]) + +def test_unravel_index(): + def verify_unravel_index(indices, shape, dtype): + x_data = np.array(indices).astype(dtype) + y_data = np.array(shape).astype(dtype) + x = relay.var("x", relay.TensorType(x_data.shape, dtype)) + y = relay.var("y", relay.TensorType(y_data.shape, dtype)) + + z = relay.unravel_index(x, y) + zz = run_infer_type(z) + + if len(x_data.shape) == 1: + out_shape = [y_data.shape[0], x_data.shape[0]] + else: + out_shape = [y_data.shape[0]] + assert zz.checked_type == relay.ty.TensorType(out_shape, dtype) + + func = relay.Function([x, y], z) + ref_res = np.unravel_index(x_data, y_data) + for target, ctx in ctx_list(): + for kind in ["graph", "debug"]: + intrp = relay.create_executor(kind, ctx=ctx, target=target) + op_res = intrp.evaluate(func)(x_data, y_data) + tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5) + + for dtype in ["int64", "int32"]: + verify_unravel_index([0, 1, 2, 3], [2, 2], dtype) + verify_unravel_index([144], [5, 5, 5, 2], dtype) + verify_unravel_index(144, [5, 5, 5, 2], dtype) + verify_unravel_index([100, 13, 5], [5, 5, 5, 2], dtype) + + # In below example, 5 is out of bound for array of size 4. + # Numpy implementation throws error for it + # TVM implementation does not throw error instead it produces + # output which is inline with Tensorflow + # verify_unravel_index([0, 1, 2, 5], [2, 2], dtype) + + if __name__ == "__main__": test_arange() test_cast() @@ -713,3 +751,4 @@ if __name__ == "__main__": test_tile() test_repeat() test_gather_nd() + test_unravel_index() diff --git a/topi/include/topi/transform.h b/topi/include/topi/transform.h index efbffad..40bdcc6 100644 --- a/topi/include/topi/transform.h +++ b/topi/include/topi/transform.h @@ -233,6 +233,54 @@ inline Tensor reshape(const Tensor& x, } /*! + * \brief Converts a flat index or array of flat indices into a tuple of coordinate arrays + * + * \param x The input tensor having indices. + * \param shape The shape tensor + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return A Tensor of coordinate arrays. + */ + +inline Tensor unravel_index(const Tensor& x, + const Tensor& shape, + std::string name = "T_unravel", + std::string tag = kInjective) { + auto x_shape = x->shape; + auto shape_shape = shape->shape; + + Array oshape; + oshape.push_back(shape_shape[0]); + if (x_shape.size() != 0) { + oshape.push_back(x_shape[0]); + } + + auto func = [&](const Array& indices) { + auto i = indices[0]; + std::vector indices_divs; + PrimExpr ret = 0; + PrimExpr cur_val = 0; + PrimExpr index_val = 0; + + if (x_shape.size() != 0) { + index_val = x[indices[1]]; + } else { + index_val = x(); + } + indices_divs.push_back(index_val); + for (int v = GetConstInt(shape_shape[0]) - 1; v >= 0; --v) { + ret = tvm::if_then_else(i == v, indexmod(indices_divs.back(), shape[v]), ret); + cur_val = indexdiv(indices_divs.back(), shape[v]); + indices_divs.push_back(cur_val); + } + return ret; + }; + + return compute(oshape, func, name, tag); +} + +/*! * \brief Remove size 1 dimensions from the shape of a tensor. * The removed dimensions must have a constant size of 1. * diff --git a/topi/python/topi/transform.py b/topi/python/topi/transform.py index 036191b..ef54560 100644 --- a/topi/python/topi/transform.py +++ b/topi/python/topi/transform.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=invalid-name,consider-using-enumerate +# pylint: disable=invalid-name,consider-using-enumerate,redefined-outer-name """Injective transformation operators""" from __future__ import absolute_import as _abs import tvm @@ -653,3 +653,26 @@ def one_hot(indices, on_value, off_value, depth, axis, dtype): [0, 0, 1]] """ return cpp.one_hot(indices, on_value, off_value, depth, axis, dtype) + + +def unravel_index(indices, shape): + """Convert a flat index or array of flat indices into a tuple of coordinate arrays. + + Example:: + - unravel_index([22, 41, 37], [7, 6]) = [[3, 6, 6], [4, 5, 1]] + + Parameters + ---------- + indices : relay.Expr + An integer array containing indices. + + shape : relay.Expr + The shape of the array. + + Returns + ------- + result : relay.Expr + The tuple of coordinate arrays. + """ + + return cpp.unravel_index(indices, shape) diff --git a/topi/src/topi.cc b/topi/src/topi.cc index 5581f2b..3a3175c 100644 --- a/topi/src/topi.cc +++ b/topi/src/topi.cc @@ -435,6 +435,11 @@ TVM_REGISTER_GLOBAL("topi.gather_nd") *rv = gather_nd(args[0], args[1]); }); +TVM_REGISTER_GLOBAL("topi.unravel_index") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = unravel_index(args[0], args[1]); + }); + TVM_REGISTER_GLOBAL("topi.matmul") .set_body([](TVMArgs args, TVMRetValue *rv) { switch ( args.size() ) { diff --git a/topi/tests/python/test_topi_transform.py b/topi/tests/python/test_topi_transform.py index 097c87d..b98ce09 100644 --- a/topi/tests/python/test_topi_transform.py +++ b/topi/tests/python/test_topi_transform.py @@ -562,6 +562,40 @@ def verify_one_hot(indices_shape, depth, on_value, off_value, axis, dtype): for device in get_all_backend(): check_device(device) + +def verify_unravel_index(indices, shape, dtype): + x_data = np.array(indices).astype(dtype) + y_data = np.array(shape).astype(dtype) + if len(x_data.shape) == 1: + dst_shape = [y_data.shape[0], x_data.shape[0]] + else: + dst_shape = [y_data.shape[0]] + + X = te.placeholder(shape=x_data.shape, dtype=dtype, name="X") + Y = te.placeholder(shape=y_data.shape, dtype=dtype, name="Y") + Z = topi.unravel_index(X, Y) + + def check_device(device): + ctx = tvm.context(device, 0) + if not ctx.exist: + print("Skip because %s is not enabled" % device) + return + print("Running on target: %s" % device) + with tvm.target.create(device): + s = topi.testing.get_injective_schedule(device)(Z) + foo = tvm.build(s, [X, Y, Z], device, name="unravel_index") + + out_npy = np.unravel_index(x_data, y_data) + datax_nd = tvm.nd.array(x_data, ctx) + datay_nd = tvm.nd.array(y_data, ctx) + out_nd = tvm.nd.empty(dst_shape, ctx=ctx, dtype=Z.dtype) + foo(datax_nd, datay_nd, out_nd) + tvm.testing.assert_allclose(out_nd.asnumpy(), out_npy) + + for device in get_all_backend(): + check_device(device) + + def test_strided_slice(): verify_strided_slice((3, 4, 3), [0, 0, 0], [4, -5, 4], [1, -1, 2]) verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1]) @@ -882,6 +916,15 @@ def test_one_hot(): verify_one_hot((3, 2, 4, 5), 6, 1, 0, 1, "int32") verify_one_hot((3, 2, 4, 5), 6, 1.0, 0.0, 0, "float32") + +def test_unravel_index(): + for dtype in ["int32", "int64"]: + verify_unravel_index([0, 1, 2, 3], [2, 2], dtype) + verify_unravel_index([144], [5, 5, 5, 2], dtype) + verify_unravel_index(144, [5, 5, 5, 2], dtype) + verify_unravel_index([100, 13, 5], [5, 5, 5, 2], dtype) + + if __name__ == "__main__": test_strided_slice() test_concatenate() @@ -905,3 +948,4 @@ if __name__ == "__main__": test_ndarray_size() test_where_fusion() test_one_hot() + test_unravel_index() -- 2.7.4