[Relay, Topi] [Frontend][TFLite, MXNet] ReverseSequence operator (#5495)
authorMahesh Ambule <15611578+maheshambule@users.noreply.github.com>
Wed, 17 Jun 2020 03:53:08 +0000 (09:23 +0530)
committerGitHub <noreply@github.com>
Wed, 17 Jun 2020 03:53:08 +0000 (20:53 -0700)
* TFLite reverse_sequence op

* TFLite add_n implementation

* reverse_sequence implementation

* reverse_sequence implementation

* reverse sequence

* TOPI,Relay,TFLite - Reverse Sequence

Signed-off-by: maheshambule <mahesh_ambule@persistent.com>
* Reverse Sequence small fixes

Signed-off-by: maheshambule <mahesh_ambule@persistent.com>
* lint fixes

Signed-off-by: maheshambule <mdambule07@gmail.com>
* TFLite reverse_sequence op

Signed-off-by: maheshambule
* MXNet SequenceReverse implementation

* clang format

* clang format

* review comment fixes

16 files changed:
docs/api/python/topi.rst
docs/langref/relay_op.rst
include/tvm/relay/attrs/transform.h
python/tvm/relay/frontend/mxnet.py
python/tvm/relay/frontend/tflite.py
python/tvm/relay/op/_transform.py
python/tvm/relay/op/op_attrs.py
python/tvm/relay/op/transform.py
src/relay/op/tensor/transform.cc
tests/python/frontend/mxnet/test_forward.py
tests/python/frontend/tflite/test_forward.py
tests/python/relay/test_op_level3.py
topi/include/topi/transform.h
topi/python/topi/transform.py
topi/src/transform.cc
topi/tests/python/test_topi_transform.py

index 65f2375..53f2f3c 100644 (file)
@@ -46,6 +46,7 @@ List of operators
    topi.reinterpret
    topi.transpose
    topi.flip
+   topi.reverse_sequence
    topi.strided_slice
    topi.expand_dims
    topi.reshape
@@ -152,6 +153,7 @@ topi
 .. autofunction:: topi.reinterpret
 .. autofunction:: topi.transpose
 .. autofunction:: topi.flip
+.. autofunction:: topi.reverse_sequence
 .. autofunction:: topi.strided_slice
 .. autofunction:: topi.expand_dims
 .. autofunction:: topi.reshape
index cef96ef..86e0c0d 100644 (file)
@@ -132,6 +132,7 @@ This level enables additional math and transform operators.
    tvm.relay.repeat
    tvm.relay.tile
    tvm.relay.reverse
+   tvm.relay.reverse_sequence
    tvm.relay.unravel_index
    tvm.relay.sparse_to_dense
 
index cbc6034..750a8a4 100644 (file)
@@ -194,6 +194,20 @@ struct ReverseAttrs : public tvm::AttrsNode<ReverseAttrs> {
   }
 };  // struct ReverseAttrs
 
+/*! \brief Attributes used in reverse_sequence operators */
+struct ReverseSequenceAttrs : public tvm::AttrsNode<ReverseSequenceAttrs> {
+  Integer seq_axis;
+  Integer batch_axis;
+
+  TVM_DECLARE_ATTRS(ReverseSequenceAttrs, "relay.attrs.ReverseSequenceAttrs") {
+    TVM_ATTR_FIELD(seq_axis).set_default(1).describe(
+        "The seq axis along which to reverse elements.");
+    TVM_ATTR_FIELD(batch_axis)
+        .set_default(0)
+        .describe("The batch axis along which to slice the tensor.");
+  }
+};  // struct ReverseSequenceAttrs
+
 /*! \brief Attributes used in squeeze operators */
 struct SqueezeAttrs : public tvm::AttrsNode<SqueezeAttrs> {
   // use axis to make the name numpy compatible.
index 1d8842d..f77c3b5 100644 (file)
@@ -819,6 +819,21 @@ def _mx_reverse(inputs, attrs):
     return _op.reverse(inputs[0], **new_attrs)
 
 
+def _mx_sequence_reverse(inputs, attrs):
+    new_attrs = {}
+    use_seq_lengths = attrs.get_bool("use_sequence_length")
+    if not use_seq_lengths:
+        assert len(inputs) == 1
+        new_attrs["axis"] = attrs.get_int("axis")
+        return _op.reverse(inputs[0], **new_attrs)
+
+    assert len(inputs) == 2
+    new_attrs["seq_axis"] = attrs.get_int("axis")
+    # MXNet assumes batch_axis as 1.
+    new_attrs["batch_axis"] = 1
+    return _op.reverse_sequence(inputs[0], inputs[1], **new_attrs)
+
+
 def _mx_roi_align(inputs, attrs):
     new_attrs = {}
     new_attrs["pooled_size"] = attrs.get_int_tuple("pooled_size")
@@ -2078,6 +2093,7 @@ _convert_map = {
     "take"          : _mx_take,
     "gather_nd"     : _mx_gather_nd,
     "reverse"       : _mx_reverse,
+    "SequenceReverse"  : _mx_sequence_reverse,
     "squeeze"       : _mx_squeeze,
     "broadcast_axis": _mx_broadcast_axis,
     "broadcast_axes": _mx_broadcast_axis,
index b79d8e1..2fc82d7 100644 (file)
@@ -130,6 +130,7 @@ class OperatorConverter(object):
             'RESIZE_NEAREST_NEIGHBOR': self.convert_resize_nearest_neighbor,
             'ROUND': self.convert_round,
             'RSQRT': self.convert_rsqrt,
+            'REVERSE_SEQUENCE': self.convert_reverse_sequence,
             'SELECT': self.convert_select,
             'SHAPE': self.convert_shape,
             'SIN': self.convert_sin,
@@ -2002,6 +2003,33 @@ class OperatorConverter(object):
 
         return out
 
+    def convert_reverse_sequence(self, op):
+        """Convert TFLite REVERSE_SEQUENCE"""
+        try:
+            from tflite.BuiltinOptions import BuiltinOptions
+            from tflite.ReverseSequenceOptions import ReverseSequenceOptions
+        except ImportError:
+            raise ImportError("The tflite package must be installed")
+
+        if self.is_quantized(op):
+            raise tvm.error.OpNotImplemented(
+                'TFLite does not support quantized REVERSE_SEQUENCE operator yet.')
+
+        input_tensors = self.get_input_tensors(op)
+        assert len(input_tensors) == 2, "input tensors length should be 2"
+
+        in_expr = self.get_tensor_expr(input_tensors[0])
+        length_expr = self.get_tensor_expr(input_tensors[1])
+
+        assert op.BuiltinOptionsType() == BuiltinOptions.ReverseSequenceOptions
+        op_options = op.BuiltinOptions()
+        options = ReverseSequenceOptions()
+        options.Init(op_options.Bytes, op_options.Pos)
+        batch_axis = options.BatchDim()
+        seq_axis = options.SeqDim()
+
+        return _op.reverse_sequence(in_expr, length_expr, seq_axis, batch_axis)
+
     def convert_cast(self, op):
         """Convert TFLite CAST"""
         try:
@@ -2700,14 +2728,10 @@ class OperatorConverter(object):
         return self.exp_tab.has_expr(get_tensor_name(self.subgraph, input_tensor_idx))
 
     def get_tensor_expr(self, tensor):
-        """ Returns constant expr for constant else a tensor expr"""
+        """ Return the Relay expr for tensor. """
         if self.has_expr(tensor.tensor_idx):
-            # In most cases, we can assume that TOCO fuses elemwise operators
-            # with constants - it means both will be tensors.
             expr = self.get_expr(tensor.tensor_idx)
         else:
-            # However, in some corner cases, the elemwise operator is not fused,
-            # we can receive as constant.
             type_str = self.get_tensor_type_str(tensor.tensor.Type())
             expr = self.exp_tab.new_const(self.get_tensor_value(tensor), dtype=type_str)
 
index f134b82..d104c1b 100644 (file)
@@ -41,6 +41,7 @@ _reg.register_injective_schedule("full")
 _reg.register_injective_schedule("full_like")
 _reg.register_injective_schedule("arange")
 _reg.register_injective_schedule("reverse")
+_reg.register_injective_schedule("reverse_sequence")
 _reg.register_injective_schedule("cast")
 _reg.register_injective_schedule("cast_like")
 _reg.register_injective_schedule("reinterpret")
index 429c4f1..6c3dfaf 100644 (file)
@@ -227,6 +227,10 @@ class TileAttrs(Attrs):
 class ReverseAttrs(Attrs):
     """Attributes used in reverse operators"""
 
+@tvm._ffi.register_object("relay.attrs.ReverseSequenceAttrs")
+class ReverseSequenceAttrs(Attrs):
+    """Attributes used in reverse sequence operators"""
+
 
 @tvm._ffi.register_object("relay.attrs.SqueezeAttrs")
 class SqueezeAttrs(Attrs):
index 05958fc..a37226e 100644 (file)
@@ -515,6 +515,53 @@ def reverse(data, axis):
     return _make.reverse(data, axis)
 
 
+def reverse_sequence(data, seq_lengths, seq_axis=1, batch_axis=0):
+    """Reverse the tensor for variable length slices.
+    Input is first sliced along batch axis and then elements are reversed along seq axis.
+
+    Parameters
+    ----------
+    data : relay.Expr
+        The tensor to be reversed.
+
+    seq_lengths : relay.Expr
+        A 1D Tensor with length a.dims[batch_axis]
+        Must be one of the following types: int32, int64
+        if seq_lengths[i] > a.dims[seq_axis], it is rounded to a.dims[seq_axis]
+        if seq_lengths[i] < 1, it is rounded to 1
+
+    seq_axis : int, optional
+        The axis along which the elements will be reversed. Default is 1.
+
+    batch_axis : int, optional
+        The axis along which the tensor will be sliced. Default is 0.
+
+    Returns
+    -------
+    ret : relay.Expr
+        The computed result of same shape and type as of input.
+
+    Examples
+    --------
+    .. code-block:: python
+
+        x = [[0, 1, 2, 3],
+             [4, 5, 6, 7],
+             [8, 9, 10, 11],
+             [12, 13, 14, 15]]
+        relay.reverse(x, [1, 2, 3, 4], 0, 1) = [[0, 5, 10, 15],
+                                                [4, 1, 6, 11],
+                                                [8, 9, 2, 7],
+                                                [12, 13, 14, 3]]
+
+        relay.reverse(x, [1, 2, 3, 4], 1, 0) = [[0, 1, 2, 3],
+                                                [5, 4, 6, 7],
+                                                [10, 9, 8, 11],
+                                                [15, 14, 13, 12]]
+    """
+    return _make.reverse_sequence(data, seq_lengths, seq_axis, batch_axis)
+
+
 def where(condition, x, y):
     """Selecting elements from either x or y depending on the value of the
     condition.
index 2a7e4e2..ee5e291 100644 (file)
@@ -1397,7 +1397,8 @@ Array<te::Tensor> ReverseCompute(const Attrs& attrs, const Array<te::Tensor>& in
                                  const Type& out_type) {
   const ReverseAttrs* param = attrs.as<ReverseAttrs>();
   CHECK(param != nullptr);
-  return {topi::flip(inputs[0], param->axis)};
+  // pass empty seq_length tensor to reverse_sequence
+  return {topi::reverse_sequence(inputs[0], te::Tensor(), param->axis)};
 }
 
 Expr MakeReverse(Expr data, int axis) {
@@ -1423,6 +1424,96 @@ RELAY_REGISTER_OP("reverse")
     .set_attr<FTVMCompute>("FTVMCompute", ReverseCompute)
     .set_attr<TOpPattern>("TOpPattern", kInjective);
 
+// reverse sequence operator
+TVM_REGISTER_NODE_TYPE(ReverseSequenceAttrs);
+
+bool ReverseSequenceRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+                        const TypeReporter& reporter) {
+  // `types` contains: [data, seq_lengths, result]
+  CHECK_EQ(types.size(), 3);
+  const auto* data = types[0].as<TensorTypeNode>();
+
+  if (data == nullptr) {
+    CHECK(types[0].as<IncompleteTypeNode>())
+        << "reverse_sequence: expect input type to be TensorType but get " << types[0];
+    return false;
+  }
+
+  const auto* seq_lengths = types[1].as<TensorTypeNode>();
+  if (seq_lengths == nullptr) {
+    CHECK(types[1].as<IncompleteTypeNode>())
+        << "reverse_sequence: expect input type to be TensorType but get " << types[1];
+    return false;
+  }
+
+  const int seq_lengths_dim = static_cast<int>(seq_lengths->shape.size());
+  CHECK(seq_lengths_dim == 1) << "For reverse_sequnece, seq_lengths must be a 1D vector";
+  CHECK(seq_lengths->dtype.is_int())
+      << "For reverse_sequnece, seq_lengths must be tensor of integer";
+
+  const auto* param = attrs.as<ReverseSequenceAttrs>();
+  const int ndim = static_cast<int>(data->shape.size());
+  int batch_axis = param->batch_axis;
+  CHECK(-ndim <= batch_axis && batch_axis < ndim)
+      << "reverse_sequence only accepts `batch_axis` in [-data.ndim, data.ndim - 1]"
+      << ", but got batch_axis = " << batch_axis << ", and data.ndim = " << ndim;
+
+  if (batch_axis < 0) {
+    batch_axis = static_cast<int>(data->shape.size()) + batch_axis;
+  }
+  CHECK(reporter->Assert(seq_lengths->shape[0] == data->shape[batch_axis]))
+      << "For reverse_sequnece seq_lengths size should match with dimension of batch axis"
+      << ", but got dimension of batch_axis = " << data->shape[batch_axis]
+      << ", and seq_length size = " << seq_lengths->shape[0];
+
+  const int seq_axis = param->seq_axis;
+  CHECK(-ndim <= seq_axis && seq_axis < ndim)
+      << "reverse_sequnece only accepts `seq_axis` in [-data.ndim, data.ndim - 1]"
+      << ", but got seq_axis = " << seq_axis << ", and data.ndim = " << ndim;
+
+  reporter->Assign(types[2], types[0]);
+  return true;
+}
+
+Array<te::Tensor> ReverseSequenceCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
+                                         const Type& out_type) {
+  const ReverseSequenceAttrs* param = attrs.as<ReverseSequenceAttrs>();
+  CHECK(param != nullptr);
+  return {topi::reverse_sequence(inputs[0], inputs[1], param->seq_axis, param->batch_axis)};
+}
+
+Expr MakeReverseSequence(Expr data, Expr seq_lengths, int seq_axis, int batch_axis) {
+  auto attrs = make_object<ReverseSequenceAttrs>();
+  attrs->seq_axis = seq_axis;
+  attrs->batch_axis = batch_axis;
+  static const Op& op = Op::Get("reverse_sequence");
+  return Call(op, {data, seq_lengths}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_GLOBAL("relay.op._make.reverse_sequence").set_body_typed(MakeReverseSequence);
+
+RELAY_REGISTER_OP("reverse_sequence")
+    .describe(R"code(Reverses the tensor for variable length slices.
+Input is first sliced along batch axis and then elements are reversed along seq axis.
+
+- **data**: The input data to the operator.
+
+- **seq_lengths**: A 1D Tensor with length data.dims[batch_axis].
+
+- **seq_axis**: The axis along which the elements will be reversed. Default is 1.
+
+- **batch_axis**: The axis along which the tensor will be sliced. Default is 0.
+
+)code" TVM_ADD_FILELINE)
+    .set_num_inputs(2)
+    .set_attrs_type<ReverseSequenceAttrs>()
+    .add_argument("data", "Tensor", "The input tensor.")
+    .add_argument("seq_lengths", "Tensor", "A 1D Tensor with length data.dims[batch_axis]")
+    .set_support_level(3)
+    .add_type_rel("ReverseSequence", ReverseSequenceRel)
+    .set_attr<FTVMCompute>("FTVMCompute", ReverseSequenceCompute)
+    .set_attr<TOpPattern>("TOpPattern", kInjective);
+
 // where operator
 bool WhereRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
               const TypeReporter& reporter) {
index 00c077f..8b3e04b 100644 (file)
@@ -472,6 +472,39 @@ def test_forward_slice_like():
     verify((3, 4), (2, 3), (0))
     verify((3, 4), (2, 3), (-1))
 
+def test_forward_sequence_reverse():
+    def verify(shape, seq_lengths, use_seq_lengths, seq_axis):
+        data_np = np.random.uniform(size=shape).astype("float32")
+
+        ref_res_args = [mx.nd.array(data_np), None, use_seq_lengths, seq_axis]
+        mx_sym_args = [mx.sym.var("data"), None, use_seq_lengths, seq_axis]
+        from_mxnet_args = [{"data": shape}, {"data": "float32"}]
+        in_data= [data_np]
+
+        if use_seq_lengths and seq_lengths:
+            seq_lengths_np = np.array(seq_lengths).astype("int32")
+            ref_res_args[1] = mx.nd.array(seq_lengths_np)
+            mx_sym_args[1] = mx.sym.var("seq_lengths")
+            from_mxnet_args[0].update({"seq_lengths": seq_lengths_np.shape})
+            from_mxnet_args[1].update({"seq_lengths": "int32"})
+            in_data.append(seq_lengths_np)
+
+        ref_res = mx.nd.SequenceReverse(*ref_res_args)
+        mx_sym = mx.sym.SequenceReverse(*mx_sym_args)
+        mod, _ = relay.frontend.from_mxnet(mx_sym, *from_mxnet_args)
+
+        for target, ctx in ctx_list():
+            for kind in ["graph", "debug"]:
+                intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
+                op_res = intrp.evaluate()(*in_data)
+                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())
+
+    verify((3, 4), [1, 2, 3, 1], True, 0)
+    verify((3, 4), None, False, 0)
+    verify((3, 5, 5, 6), [1, 2, 3, 1, 3], True, 0)
+    # MXNet accepts axis value as 0 only
+    # verify((3, 4, 5, 6), None, False, 2)
+
 def test_forward_l2_normalize():
     data = mx.sym.var('data')
     mx_sym = mx.sym.L2Normalization(data, mode="channel")
@@ -1232,6 +1265,7 @@ if __name__ == '__main__':
     test_forward_scalar_ops()
     test_forward_slice_like()
     test_forward_slice_axis()
+    test_forward_sequence_reverse()
     test_forward_l2_normalize()
     test_forward_shape_array()
     test_forward_squeeze()
index 4adc3e0..166eb27 100644 (file)
@@ -2082,6 +2082,32 @@ def test_forward_spacetodepth():
     _test_spacetodepth(np.random.normal(size=[1, 32, 32, 4]).astype("float32"), 2)
     _test_spacetodepth(np.random.normal(size=[1, 16, 8, 32]).astype("float32"), 4)
 
+
+#######################################################################
+# ReverseSequence
+# ---------------
+
+def _test_reverse_sequence(shape, dtype, seq_lengths, batch_axis, seq_axis):
+    """ One iteration of reverse_sequence operation with given data and attributes """
+
+    data = np.random.uniform(0, 100, size=shape).astype(dtype)
+    with tf.Graph().as_default():
+        in_data = array_ops.placeholder(dtype=dtype, name="input", shape=shape)
+        out = tf.reverse_sequence(in_data, seq_lengths=seq_lengths, batch_axis=batch_axis,
+                                   seq_axis=seq_axis)
+
+        compare_tflite_with_tvm(data, 'input', [in_data], [out])
+
+
+def test_forward_reverse_sequence():
+    if package_version.parse(tf.VERSION) >= package_version.parse('1.14.0'):
+        _test_reverse_sequence([4, 3], "float32", [3, 2, 1], 1, 0)
+        _test_reverse_sequence([4, 3], "float32", [3, 2, 1, 3], 0, 1)
+        _test_reverse_sequence([2, 3, 3, 3], "float32", [2, 3, 2], 2, 1)
+        _test_reverse_sequence([2, 4, 6, 4, 5], "float32", [5, 3], 0, 2)
+        _test_reverse_sequence([2, 4, 6, 4, 5], "float32", [5, 3, 1, 4], 3, 2)
+
+
 #######################################################################
 # Sparse To Dense
 # ---------------
@@ -2602,6 +2628,7 @@ if __name__ == '__main__':
     test_forward_stridedslice()
     test_forward_depthtospace()
     test_forward_spacetodepth()
+    test_forward_reverse_sequence()
     test_forward_sparse_to_dense()
     test_forward_select()
     test_forward_quantize_dequantize()
index f50a692..f3e28db 100644 (file)
@@ -663,6 +663,73 @@ def test_reverse():
     verify_reverse((2, 3, 4), -1)
 
 
+def test_reverse_sequence():
+    def verify_reverse_sequence(x_data, seq_lengths, batch_axis, seq_axis, ref_res):
+        seq_lengths_data = np.array(seq_lengths).astype("int32")
+        x = relay.var("x", relay.TensorType(x_data.shape, str(x_data.dtype)))
+        z = relay.reverse_sequence(x, relay.const(seq_lengths_data), seq_axis, batch_axis)
+        zz = run_infer_type(z)
+        assert zz.checked_type == x.type_annotation
+        func = relay.Function([x], z)
+
+        for target, ctx in ctx_list():
+            for kind in ["graph", "debug"]:
+                intrp = relay.create_executor(kind, ctx=ctx, target=target)
+                op_res = intrp.evaluate(func)(x_data)
+                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
+
+    indata = np.array(np.arange(0, 16)).reshape([4, 4]).astype("int32")
+    result = [[0, 5, 10, 15],
+              [4, 1, 6, 11],
+              [8, 9, 2, 7],
+              [12, 13, 14, 3]]
+    verify_reverse_sequence(indata, [1, 2, 3, 4], 1, 0, np.array(result))
+    verify_reverse_sequence(indata, [1, 2, 3, 4], -1, 0, np.array(result))
+    verify_reverse_sequence(indata.astype("float32"), [1, 2, 3, 4], 1, 0, np.array(result).astype("float32"))
+
+    indata = np.array(np.arange(0, 16)).reshape([4, 4]).astype("int32")
+    result = [[0, 1, 2, 3],
+              [5, 4, 6, 7],
+              [10, 9, 8, 11],
+              [15, 14, 13, 12]]
+    verify_reverse_sequence(indata, [1, 2, 3, 4], 0, 1, np.array(result))
+    verify_reverse_sequence(indata, [1, 2, 3, 4], 0, -1, np.array(result))
+    verify_reverse_sequence(indata.astype("float32"), [1, 2, 3, 4], 0, 1, np.array(result).astype("float32"))
+
+    indata = np.array(np.arange(0, 16)).reshape([4, 4]).astype("int32")
+    result = [[0, 1, 2, 3],
+              [4, 5, 6, 7],
+              [8, 9, 10, 11],
+              [15, 14, 13, 12]]
+    verify_reverse_sequence(indata, [-1, 0, 1, 5], 0, 1, np.array(result))
+
+    indata = np.array(np.arange(0, 54)).reshape([2, 3, 3, 3]).astype("int32")
+    result = [[[[18, 19, 20], [21, 22, 23], [24, 25, 26]],
+               [[9, 10, 11], [12, 13, 14], [15, 16, 17]],
+               [[0,  1,  2], [3,  4,  5], [6,  7,  8]]],
+              [[[45, 46, 47], [48, 49, 50], [51, 52, 53]],
+               [[36, 37, 38], [39, 40, 41], [42, 43, 44]],
+               [[27, 28, 29], [30, 31, 32], [33, 34, 35]]]]
+    verify_reverse_sequence(indata, [3, 3], 0, 1, np.array(result))
+
+    indata = np.array(np.arange(0, 54)).reshape([2, 3, 3, 3]).astype("int32")
+    result = [[[[9, 10, 11], [21, 22, 23], [15, 16, 17]],
+               [[0, 1, 2], [12, 13, 14], [6, 7, 8]],
+               [[18, 19, 20], [3, 4, 5], [24, 25, 26]]],
+              [[[36, 37, 38], [48, 49, 50], [42, 43, 44]],
+               [[27, 28, 29], [39, 40, 41], [33, 34, 35]],
+               [[45, 46, 47], [30, 31, 32], [51, 52, 53]]]]
+    verify_reverse_sequence(indata, [2, 3, 2], 2, 1, np.array(result))
+
+    indata = np.array(np.arange(0, 16)).reshape([4, 4]).astype("int32")
+    result = []
+    with pytest.raises(Exception) as execinfo:
+        verify_reverse_sequence(indata, [2, 3, 2, 4, 5], 1, 0, np.array(result))
+
+    assert "For reverse_sequnece seq_lengths size should match with dimension of batch axis," \
+           " but got dimension of batch_axis = 4, and seq_length size = 5" in execinfo.value.args[0]
+
+
 def test_scatter():
 
     def ref_scatter(data, indices, updates, axis=0):
index 7947967..e0e4556 100644 (file)
@@ -150,44 +150,72 @@ inline Tensor transpose(const Tensor& x, Array<Integer> axes, std::string name =
 }
 
 /*!
- * \brief flip/reverse elements of an array in a particular axis
+ * \brief Reverse the tensor for variable length slices.
+ * Input is first sliced along batch axis and then elements are reversed along seq axis.
  *
  * \param x The input tensor
- * \param axis The axis along which the tensors will be reveresed
- * (allows negative indices)
+ * \param seq_lengths A 1D Tensor with length x.dims[batch_axis]. Optional Tensor() can be passed.
+ * If not defined batch axis is ignored and tensor is reversed along seq_axis.
+ * \param seq_axis The axis along which the elements will be reveresed
+ * \param batch_axis The axis along which the tensor will be sliced
  * \param name The name of the operation
  * \param tag The tag to mark the operation
  *
- * \return A Tensor whose op member is the reverse operation
+ * \return A Tensor whose op member is the reverse_sequence operation
  */
-inline Tensor flip(const Tensor& x, int axis = 0, std::string name = "T_flip",
-                   std::string tag = kInjective) {
+inline Tensor reverse_sequence(const Tensor& x, const Tensor& seq_lengths, int seq_axis = 1,
+                               int batch_axis = 0, std::string name = "T_reverse_sequence",
+                               std::string tag = kInjective) {
   size_t src_tensor_dim = x->shape.size();
-  int axis_inp = axis;
+  int seq_axis_inp = seq_axis;
 
-  if (axis < 0) {
-    axis = static_cast<int>(x->shape.size()) + axis;
+  if (seq_lengths.defined()) {
+    size_t seq_lengths_dim = seq_lengths->shape.size();
+    int batch_axis_inp = batch_axis;
+    if (batch_axis < 0) {
+      batch_axis = static_cast<int>(x->shape.size()) + batch_axis;
+    }
+
+    CHECK(seq_lengths_dim == 1) << "seq_lengths should be 1D vector";
+
+    CHECK(GetConstInt(seq_lengths->shape[0]) == GetConstInt(x->shape[batch_axis]))
+        << "For reverse_sequnece seq_lengths size should match with dimension of batch axis"
+        << ", but got dimension of batch_axis = " << GetConstInt(x->shape[batch_axis])
+        << ", and seq_length size = " << GetConstInt(seq_lengths->shape[0]);
+
+    CHECK((0 <= batch_axis) && (batch_axis < static_cast<int>(x->shape.size())))
+        << "batch_axis=" << batch_axis_inp << " is invalid for the "
+        << static_cast<int>(x->shape.size()) << "-dimensional input tensor";
   }
 
-  CHECK((0 <= axis) && (axis < static_cast<int>(x->shape.size())))
-      << "axis=" << axis_inp << " is invalid for the " << static_cast<int>(x->shape.size())
+  if (seq_axis < 0) {
+    seq_axis = static_cast<int>(x->shape.size()) + seq_axis;
+  }
+  CHECK((0 <= seq_axis) && (seq_axis < static_cast<int>(x->shape.size())))
+      << "seq_axis=" << seq_axis_inp << " is invalid for the " << static_cast<int>(x->shape.size())
       << "-dimensional input tensor";
 
-  // Reverse the Input Tensor in the axis specified
-  return compute(
-      x->shape,
-      [&](const Array<Var>& indices) {
-        Array<PrimExpr> real_indices;
-        for (size_t i = 0; i < src_tensor_dim; ++i) {
-          if (i == static_cast<size_t>(axis)) {
-            real_indices.push_back(x->shape[i] - indices[i] - 1);
-          } else {
-            real_indices.push_back(indices[i]);
-          }
+  auto func = [&](const Array<Var>& indices) {
+    Array<PrimExpr> real_indices;
+    for (size_t i = 0; i < src_tensor_dim; ++i) {
+      if (i == static_cast<size_t>(seq_axis)) {
+        if (seq_lengths.defined()) {
+          auto len = seq_lengths(indices[batch_axis]);
+          auto idx = if_then_else(
+              len <= 1 || len <= indices[i], indices[i],
+              if_then_else(len > x->shape[i], x->shape[i] - 1 - indices[i], len - 1 - indices[i]));
+          real_indices.push_back(idx);
+        } else {
+          real_indices.push_back(x->shape[i] - 1 - indices[i]);
         }
-        return x(real_indices);
-      },
-      name, tag);
+      } else {
+        real_indices.push_back(indices[i]);
+      }
+    }
+    return x(real_indices);
+  };
+
+  return compute(x->shape, func, name, tag);
 }
 
 /*!
index f1bcccd..a8c8b14 100644 (file)
@@ -131,6 +131,37 @@ def flip(a, axis=0):
     """
     return cpp.flip(a, axis)
 
+
+def reverse_sequence(a, seq_lengths, seq_axis=1, batch_axis=0):
+    """Reverse the tensor for variable length slices.
+    Input is first sliced along batch axis and then elements are reversed along seq axis.
+
+    Parameters
+    ----------
+    a : tvm.te.Tensor
+       The tensor to be reversed.
+
+    seq_lengths : tvm.te.Tensor
+       A 1D Tensor with length a.dims[batch_axis]
+       Must be one of the following types: int32, int64
+       if seq_lengths[i] > a.dims[seq_axis], it is rounded to a.dims[seq_axis]
+       if seq_lengths[i] < 1, it is rounded to 1
+
+    seq_axis : int, optional
+       The axis along which the elements will be reversed. Default is 1.
+
+    batch_axis : int, optional
+       The axis along which the tensor will be sliced. Default is 0.
+
+    Returns
+    -------
+    ret : tvm.te.Tensor
+       The computed result of same shape and type as of input.
+
+    """
+    return cpp.reverse_sequence(a, seq_lengths, seq_axis, batch_axis)
+
+
 def strided_slice(a, begin, end, strides=None, slice_mode="end"):
     """Slice of an array.
 
index 2791ff7..4308784 100644 (file)
@@ -40,7 +40,12 @@ TVM_REGISTER_GLOBAL("topi.transpose").set_body([](TVMArgs args, TVMRetValue* rv)
 });
 
 TVM_REGISTER_GLOBAL("topi.flip").set_body([](TVMArgs args, TVMRetValue* rv) {
-  *rv = flip(args[0], args[1]);
+  // pass empty seq_lengths tensor to reverse_sequence
+  *rv = reverse_sequence(args[0], Tensor(), args[1]);
+});
+
+TVM_REGISTER_GLOBAL("topi.reverse_sequence").set_body([](TVMArgs args, TVMRetValue* rv) {
+  *rv = reverse_sequence(args[0], args[1], args[2], args[3]);
 });
 
 TVM_REGISTER_GLOBAL("topi.reshape").set_body([](TVMArgs args, TVMRetValue* rv) {
index 96df101..b0aee6a 100644 (file)
@@ -16,6 +16,7 @@
 # under the License.
 """Test code for broadcasting operators."""
 import numpy as np
+import pytest
 import tvm
 from tvm import te
 import topi
@@ -289,6 +290,85 @@ def verify_flip(in_shape, axis):
     for device in ["llvm", "cuda", "opencl", "sdaccel", "aocl_sw_emu"]:
         check_device(device)
 
+
+def test_reverse_sequence():
+    def verify_reverse_sequence(in_data, seq_lengths, batch_axis, seq_axis, ref_res):
+        seq_lengths = np.array(seq_lengths).astype("int32")
+        A = te.placeholder(shape=in_data.shape, name="A", dtype=str(in_data.dtype))
+        B = te.placeholder(shape=seq_lengths.shape, name="B", dtype=str(seq_lengths.dtype))
+        C = topi.reverse_sequence(A, B, seq_axis, batch_axis)
+
+        def check_device(device):
+            ctx = tvm.context(device, 0)
+            if not ctx.exist:
+                print("Skip because %s is not enabled" % device)
+                return
+            print("Running on target: %s" % device)
+            with tvm.target.create(device):
+                s = topi.testing.get_injective_schedule(device)(C)
+
+            foo = tvm.build(s, [A, B, C], device, name="reverse_sequence")
+
+            data_nd = tvm.nd.array(in_data, ctx)
+            seq_lengths_nd = tvm.nd.array(seq_lengths, ctx)
+            out_nd = tvm.nd.empty(in_data.shape, ctx=ctx, dtype=A.dtype)
+            foo(data_nd, seq_lengths_nd, out_nd)
+            tvm.testing.assert_allclose(out_nd.asnumpy(), ref_res)
+
+        for device in get_all_backend():
+            check_device(device)
+
+    indata = np.array(np.arange(0, 16)).reshape([4, 4]).astype("int32")
+    result = [[0, 5, 10, 15],
+              [4, 1, 6, 11],
+              [8, 9, 2, 7],
+              [12, 13, 14, 3]]
+    verify_reverse_sequence(indata, [1, 2, 3, 4], 1, 0, np.array(result))
+    verify_reverse_sequence(indata, [1, 2, 3, 4], -1, 0, np.array(result))
+    verify_reverse_sequence(indata.astype("float32"), [1, 2, 3, 4], 1, 0, np.array(result).astype("float32"))
+
+    indata = np.array(np.arange(0, 16)).reshape([4, 4]).astype("int32")
+    result = [[0, 1, 2, 3],
+              [5, 4, 6, 7],
+              [10, 9, 8, 11],
+              [15, 14, 13, 12]]
+    verify_reverse_sequence(indata, [1, 2, 3, 4], 0, 1, np.array(result))
+    verify_reverse_sequence(indata, [1, 2, 3, 4], 0, -1, np.array(result))
+    verify_reverse_sequence(indata.astype("float32"), [1, 2, 3, 4], 0, 1, np.array(result).astype("float32"))
+
+    indata = np.array(np.arange(0, 16)).reshape([4, 4]).astype("int32")
+    result = [[0, 1, 2, 3],
+              [4, 5, 6, 7],
+              [8, 9, 10, 11],
+              [15, 14, 13, 12]]
+    verify_reverse_sequence(indata, [-1, 0, 1, 5], 0, 1, np.array(result))
+
+    indata = np.array(np.arange(0, 54)).reshape([2, 3, 3, 3]).astype("int32")
+    result = [[[[18, 19, 20], [21, 22, 23], [24, 25, 26]],
+               [[9, 10, 11], [12, 13, 14], [15, 16, 17]],
+               [[0,  1,  2], [3,  4,  5], [6,  7,  8]]],
+              [[[45, 46, 47], [48, 49, 50], [51, 52, 53]],
+               [[36, 37, 38], [39, 40, 41], [42, 43, 44]],
+               [[27, 28, 29], [30, 31, 32], [33, 34, 35]]]]
+    verify_reverse_sequence(indata, [3, 3], 0, 1, np.array(result))
+
+    indata = np.array(np.arange(0, 54)).reshape([2, 3, 3, 3]).astype("int32")
+    result = [[[[9, 10, 11], [21, 22, 23], [15, 16, 17]],
+               [[0, 1, 2], [12, 13, 14], [6, 7, 8]],
+               [[18, 19, 20], [3, 4, 5], [24, 25, 26]]],
+              [[[36, 37, 38], [48, 49, 50], [42, 43, 44]],
+               [[27, 28, 29], [39, 40, 41], [33, 34, 35]],
+               [[45, 46, 47], [30, 31, 32], [51, 52, 53]]]]
+    verify_reverse_sequence(indata, [2, 3, 2], 2, 1, np.array(result))
+
+    indata = np.array(np.arange(0, 16)).reshape([4, 4]).astype("int32")
+    result = []
+    with pytest.raises(Exception) as execinfo:
+        verify_reverse_sequence(indata, [2, 3, 2, 4, 5], 1, 0, np.array(result))
+
+    assert "For reverse_sequnece seq_lengths size should match with dimension of batch axis," \
+           " but got dimension of batch_axis = 4, and seq_length size = 5" in execinfo.value.args[0]
+
 def verify_take(src_shape, indices_src, axis=None, mode="clip"):
     src_dtype = "float32"
     indices_dtype = "int32"