[RELAY] [OP] [MXNet Frontend] Add sequence_mask (#3437)
authorXingjian Shi <sxjscience001@gmail.com>
Fri, 28 Jun 2019 04:51:04 +0000 (21:51 -0700)
committerHaichen Shen <shenhaichen@gmail.com>
Fri, 28 Jun 2019 04:51:04 +0000 (21:51 -0700)
* Add sequence_mask

use exactly the same arguments as mxnet

fix

* fix lint

* fix lint

* add mxnet conversion + relay

* update

* update doc

* fix pylint

* fix doc

* address comment

* try to address comments

* try to enable shape check for valid_length

* fix

* try to fix

* fix bug

* try to fix

* address comment

* address comment

15 files changed:
docs/api/python/topi.rst
docs/langref/relay_op.rst
include/tvm/relay/attrs/transform.h
python/tvm/relay/frontend/mxnet.py
python/tvm/relay/op/_transform.py
python/tvm/relay/op/transform.py
src/relay/op/tensor/transform.cc
tests/python/frontend/mxnet/test_forward.py
tests/python/relay/test_op_level10.py
topi/include/topi/transform.h
topi/python/topi/testing/__init__.py
topi/python/topi/testing/sequence_mask_python.py [new file with mode: 0644]
topi/python/topi/transform.py
topi/src/topi.cc
topi/tests/python/test_topi_transform.py

index ade0f1a..367ad1a 100644 (file)
@@ -101,6 +101,7 @@ List of operators
    topi.image.resize
    topi.argsort
    topi.topk
+   topi.sequence_mask
 
 
 List of schedules
@@ -167,6 +168,7 @@ topi
 .. autofunction:: topi.layout_transform
 .. autofunction:: topi.argsort
 .. autofunction:: topi.topk
+.. autofunction:: topi.sequence_mask
 
 topi.nn
 ~~~~~~~
index 28ee99e..ccdb3e8 100644 (file)
@@ -190,6 +190,7 @@ This level support backpropagation of broadcast operators. It is temporary.
    tvm.relay.device_copy
    tvm.relay.annotation.on_device
    tvm.relay.reverse_reshape
+   tvm.relay.sequence_mask
    tvm.relay.nn.batch_matmul
    tvm.relay.contrib.adaptive_max_pool2d
    tvm.relay.contrib.adaptive_avg_pool2d
@@ -323,6 +324,7 @@ Level 10 Definitions
 .. autofunction:: tvm.relay.device_copy
 .. autofunction:: tvm.relay.annotation.on_device
 .. autofunction:: tvm.relay.reverse_reshape
+.. autofunction:: tvm.relay.sequence_mask
 .. autofunction:: tvm.relay.nn.batch_matmul
 .. autofunction:: tvm.relay.contrib.adaptive_max_pool2d
 .. autofunction:: tvm.relay.contrib.adaptive_avg_pool2d
index 5e31518..1247884 100644 (file)
@@ -275,6 +275,18 @@ struct ShapeOfAttrs : public tvm::AttrsNode<ShapeOfAttrs> {
   }
 };
 
+struct SequenceMaskAttrs : public tvm::AttrsNode<SequenceMaskAttrs> {
+  double mask_value;
+  int axis;
+
+  TVM_DECLARE_ATTRS(SequenceMaskAttrs, "relay.attrs.SequenceMaskAttrs") {
+    TVM_ATTR_FIELD(mask_value).set_default(0)
+      .describe("The masking value.");
+    TVM_ATTR_FIELD(axis).set_default(0)
+      .describe("The axis of the length dimension. Can only be 0 or 1.");
+  }
+};  // struct SequenceMaskAttrs.
+
 }  // namespace relay
 }  // namespace tvm
 #endif  // TVM_RELAY_ATTRS_TRANSFORM_H_
index 2f36355..0bcee63 100644 (file)
@@ -14,7 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=invalid-name, import-self, len-as-condition
+# pylint: disable=invalid-name, import-self, len-as-condition, no-else-return
 """MXNet symbol frontend."""
 from __future__ import absolute_import as _abs
 
@@ -709,6 +709,18 @@ def _mx_topk(inputs, attrs):
     return _op.topk(inputs[0], **new_attrs)
 
 
+def _mx_SequenceMask(inputs, attrs):
+    assert len(inputs) == 1 or len(inputs) == 2
+    new_attrs = {}
+    use_sequence_length = attrs.get_bool('use_sequence_length', False)
+    new_attrs['mask_value'] = attrs.get_float('value', 0.0)
+    new_attrs['axis'] = attrs.get_int('axis', 0)
+    if use_sequence_length:
+        return _op.sequence_mask(*inputs, **new_attrs)
+    else:
+        return inputs[0]
+
+
 def _mx_rnn_param_concat(inputs, _):
     # We don't need to concatenate RNN params because we will unravel the RNN op
     return [inputs]
@@ -994,6 +1006,7 @@ _convert_map = {
     "Embedding"     : _mx_embedding,
     "argsort"       : _mx_argsort,
     "topk"          : _mx_topk,
+    "SequenceMask"  : _mx_SequenceMask,
     "SoftmaxOutput" : _mx_softmax_output,
     "SoftmaxActivation" : _mx_softmax_activation,
     "LinearRegressionOutput" : _mx_linear_regression_output,
index 95fb2ad..0749bbd 100644 (file)
@@ -19,7 +19,7 @@
 from __future__ import absolute_import
 from . import op as _reg
 from ._reduce import _schedule_reduce
-from .op import schedule_injective, OpPattern
+from .op import OpPattern
 
 schedule_injective = _reg.schedule_injective
 schedule_broadcast = _reg.schedule_injective
@@ -50,6 +50,8 @@ _reg.register_schedule("stack", schedule_injective)
 _reg.register_schedule("concatenate", schedule_concatenate)
 _reg.register_schedule("_contrib_reverse_reshape", schedule_injective)
 _reg.register_schedule("gather_nd", schedule_injective)
+_reg.register_schedule("sequence_mask", schedule_injective)
+
 
 # layout_transform
 _reg.register_schedule("layout_transform", schedule_injective)
index dce2258..bac60a0 100644 (file)
@@ -678,3 +678,49 @@ def gather_nd(data, indices):
         relay.gather_nd(data, indices) = [[3, 4], [5, 6]]
     """
     return _make.gather_nd(data, indices)
+
+
+def sequence_mask(data, valid_length, mask_value=0, axis=0):
+    """Sets all elements outside the expected length of the sequence to a constant value.
+
+    This function takes an n-dimensional input array of the form [MAX_LENGTH, batch_size, ...] or
+    [batch_size, MAX_LENGTH, ...] and returns an array of the same shape.
+
+    Parameters
+    ----------
+    data : relay.Expr
+        The input data.
+
+    valid_length : relay.Expr
+        The expected (valid) length of each sequence in the tensor.
+
+    mask_value : float
+        The masking value.
+
+    axis : int
+        The axis of the length dimension.
+
+    Returns
+    -------
+    ret : relay.Expr
+        The computed result.
+
+    Examples
+    --------
+    .. code-block:: python
+
+        x = [[[  1.,   2.,   3.], [  4.,   5.,   6.]],
+             [[  7.,   8.,   9.], [ 10.,  11.,  12.]],
+             [[ 13.,  14.,   15.], [ 16.,  17.,   18.]]]
+
+       relay.sequence_mask(x, valid_length=[1, 1]) =
+            [[[  1.,   2.,   3.], [  4.,   5.,   6.]],
+             [[  0.,   0.,   0.], [  0.,   0.,   0.]],
+             [[  0.,   0.,   0.], [  0.,   0.,   0.]]]
+
+       relay.sequence_mask(x, valid_length=[2, 3], mask_value=0.1) =
+            [[[  1.,   2.,   3.], [  4.,   5.,   6.]],
+             [[  7.,   8.,   9.], [  10.,  11.,  12.]],
+             [[  0.1,  0.1,  0.1], [  16.,  17.,  18.]]]
+    """
+    return _make.sequence_mask(data, valid_length, mask_value, axis)
index 873e75d..da93860 100644 (file)
@@ -805,7 +805,7 @@ Examples::
 .set_num_inputs(2)
 .add_argument("data", "Tensor", "The input tensor.")
 .add_argument("indices", "Tensor", "The indices tensor.")
-.set_support_level(2)
+.set_support_level(3)
 .add_type_rel("Take", TakeRel)
 .set_attr<FTVMCompute>("FTVMCompute", TakeCompute)
 .set_attr<TOpPattern>("TOpPattern", kInjective);
@@ -2218,5 +2218,108 @@ output shape will simply be (Y_0, ..., Y_{K-1}).
 .set_attr<FTVMCompute>("FTVMCompute", GatherNDCompute)
 .set_attr<TOpPattern>("TOpPattern", kInjective);
 
+// relay.sequence_mask
+TVM_REGISTER_NODE_TYPE(SequenceMaskAttrs);
+
+bool SequenceMaskRel(const Array<Type>& types,
+                     int num_inputs,
+                     const Attrs& attrs,
+                     const TypeReporter& reporter) {
+  // `types` contains: [data, valid_length, result]
+  CHECK_EQ(types.size(), 3);
+  const auto* data = types[0].as<TensorTypeNode>();
+  const auto* valid_length = types[1].as<TensorTypeNode>();
+  CHECK(data);
+  CHECK(valid_length);
+  const auto param = attrs.as<SequenceMaskAttrs>();
+  Array<IndexExpr> valid_length_shape;
+  CHECK(param->axis == 0 || param->axis == 1);
+  valid_length_shape.push_back(data->shape[1 - param->axis]);
+  reporter->Assign(types[1], TensorTypeNode::make(valid_length_shape, valid_length->dtype));
+  reporter->Assign(types[2], types[0]);
+  return true;
+}
+
+Array<Tensor> SequenceMaskCompute(const Attrs& attrs,
+                                  const Array<Tensor>& inputs,
+                                  const Type& out_type,
+                                  const Target& target) {
+  const auto* param = attrs.as<SequenceMaskAttrs>();
+  CHECK(param != nullptr);
+  return Array<Tensor>{ topi::sequence_mask(inputs[0], inputs[1], param->mask_value, param->axis) };
+}
+
+Expr MakeSequenceMask(Expr data,
+                      Expr valid_length,
+                      double mask_value,
+                      int axis) {
+  auto attrs = make_node<SequenceMaskAttrs>();
+  attrs->mask_value = std::move(mask_value);
+  attrs->axis = std::move(axis);
+  static const Op& op = Op::Get("sequence_mask");
+  return CallNode::make(op, {data, valid_length}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_API("relay.op._make.sequence_mask")
+.set_body_typed(MakeSequenceMask);
+
+RELAY_REGISTER_OP("sequence_mask")
+.describe(R"code(Sets all elements outside the expected length of the sequence to a constant value.
+
+This function takes an n-dimensional input array of the form [MAX_LENGTH, batch_size, ...] or
+[batch_size, MAX_LENGTH, ...] and returns an array of the same shape.
+
+`axis` means the axis of the length dimension and can only be 0 or 1. If axis is 0,
+the data must have shape [MAX_LENGTH, batch_size, ...]. Otherwise (axis=1), the data must have
+shape [batch_size, MAX_LENGTH, ...].
+
+`valid_length` gives the length of each sequence. `valid_length` should be
+a 1D int array with positive ints and has dimension [batch_size,].
+
+Examples::
+
+  x = [[[  1.,   2.,   3.],
+        [  4.,   5.,   6.]],
+
+       [[  7.,   8.,   9.],
+        [ 10.,  11.,  12.]],
+
+       [[ 13.,  14.,   15.],
+        [ 16.,  17.,   18.]]]
+
+  // valid_length [1, 1] means only the first block of each batch will be kept
+  // and other blocks are masked with default mask value = 0
+  sequence_mask(x, valid_length=[1, 1]) =
+       [[[  1.,   2.,   3.],
+         [  4.,   5.,   6.]],
+
+        [[  0.,   0.,   0.],
+         [  0.,   0.,   0.]],
+
+        [[  0.,   0.,   0.],
+         [  0.,   0.,   0.]]]
+
+  // valid_length [2, 3] means the first 2 blocks of the 1st batch will be kept
+  // and the first 3 blocks of the 2nd batch will be kept
+  // the masked values are set to be the specified mask value = 0.1
+  sequence_mask(x, valid_length=[2, 3], mask_value=0.1) =
+       [[[  1.,   2.,   3.],
+         [  4.,   5.,   6.]],
+
+        [[  7.,   8.,   9.],
+         [  10.,  11.,  12.]],
+
+        [[  0.1,  0.1,  0.1],
+         [  16.,  17.,  18.]]]
+)code" TVM_ADD_FILELINE)
+.set_attrs_type_key("relay.attrs.SequenceMaskAttrs")
+.set_num_inputs(2)
+.add_argument("data", "Tensor", "The input tensor.")
+.add_argument("valid_length", "Tensor", "The real (valid) length of each sequence.")
+.set_support_level(10)
+.add_type_rel("SequenceMask", SequenceMaskRel)
+.set_attr<FTVMCompute>("FTVMCompute", SequenceMaskCompute)
+.set_attr<TOpPattern>("TOpPattern", kInjective);
+
 }  // namespace relay
 }  // namespace tvm
index ffef538..aec1980 100644 (file)
@@ -666,6 +666,51 @@ def test_forward_topk():
     verify((3, 5, 6), k=2, axis=1, ret_type="value", is_ascend=True)
     verify((3, 5, 6), k=0, axis=2, ret_type="both", dtype="int32")
 
+def test_forward_sequence_mask():
+    def verify(shape, use_sequence_length, value, axis, dtype, itype):
+        data_np = np.random.uniform(size=shape).astype(dtype)
+        valid_length_np = np.random.randint(0, shape[axis], size=shape[1-axis]).astype(itype)
+        if use_sequence_length:
+            ref_res = mx.nd.SequenceMask(mx.nd.array(data_np, dtype=dtype),
+                                         sequence_length=mx.nd.array(valid_length_np, dtype=itype),
+                                         use_sequence_length=use_sequence_length,
+                                         value=value,
+                                         axis=axis)
+            mx_sym = mx.sym.SequenceMask(mx.sym.var('data'),
+                                         sequence_length=mx.sym.var('valid_length'),
+                                         use_sequence_length=use_sequence_length,
+                                         value=value,
+                                         axis=axis)
+            mod, _ = relay.frontend.from_mxnet(mx_sym, {"data": shape,
+                                                        'valid_length': valid_length_np.shape},
+                                               dtype={"data": dtype,
+                                                      "valid_length": itype})
+        else:
+            ref_res = mx.nd.SequenceMask(mx.nd.array(data_np, dtype=dtype),
+                                         use_sequence_length=use_sequence_length,
+                                         value=value,
+                                         axis=axis)
+            mx_sym = mx.sym.SequenceMask(mx.sym.var('data'),
+                                         use_sequence_length=use_sequence_length,
+                                         value=value,
+                                         axis=axis)
+            mod, _ = relay.frontend.from_mxnet(mx_sym, {"data": shape}, dtype={"data": dtype})
+        for target, ctx in ctx_list():
+            for kind in ['graph', 'debug']:
+                if use_sequence_length is False and kind == 'graph':
+                    # Disable the test for 'graph' when it's identity.
+                    continue
+                intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
+                if use_sequence_length:
+                    op_res = intrp.evaluate()(data_np, valid_length_np)
+                else:
+                    op_res = intrp.evaluate()(data_np)
+                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())
+    verify((5, 10), True, 0.0, 0, 'float32', 'float32')
+    verify((5, 4, 3), True, 1.0, 1, 'float32', 'float32')
+    verify((5, 4, 3), False, 1.0, 1, 'float64', 'float64')
+    verify((5, 4, 3, 2), True, 1.0, 0, 'float32', 'float32')
+
 
 if __name__ == '__main__':
     test_forward_mlp()
@@ -710,3 +755,4 @@ if __name__ == '__main__':
     test_forward_Crop()
     test_forward_argsort()
     test_forward_topk()
+    test_forward_sequence_mask()
index 244744c..f904fb0 100644 (file)
@@ -249,6 +249,27 @@ def test_adaptive_pool2d():
     verify_adaptive_pool2d((1, 14, 56, 78), (34, 13), "max")
     verify_adaptive_pool2d((1, 5, 46, 97), (4, 96), "avg")
 
+def test_sequence_mask():
+    def _verify(data_shape, mask_value, axis, dtype, itype):
+        max_length = data_shape[axis]
+        nbatch = data_shape[1 - axis]
+        data = relay.var("data", relay.TensorType(data_shape, dtype))
+        valid_length = relay.var("valid_length", relay.TensorType((nbatch,), itype))
+        out = relay.sequence_mask(data, valid_length, mask_value, axis)
+        assert relay.ir_pass.infer_type(out).checked_type == relay.ty.TensorType(data_shape, dtype)
+        func = relay.Function([data, valid_length], out)
+        data_np = np.random.uniform(size=data_shape).astype(dtype)
+        valid_length_np = np.random.randint(0, max_length, size=nbatch).astype(itype)
+        gt_out_np = topi.testing.sequence_mask(data_np, valid_length_np, mask_value, axis)
+
+        for target, ctx in ctx_list():
+            for kind in ["graph", "debug"]:
+                intrp = relay.create_executor(kind, ctx=ctx, target=target)
+                out_relay = intrp.evaluate(func)(data_np, valid_length_np)
+                tvm.testing.assert_allclose(out_relay.asnumpy(), gt_out_np)
+    _verify((5, 10), 0.0, 1, 'float32', 'int32')
+    _verify((2, 3, 5, 3), 0.0, 0, 'float32', 'int64')
+    _verify((5, 8, 3), 0.1, 1, 'float64', 'float32')
 
 if __name__ == "__main__":
     test_adaptive_pool2d()
@@ -258,3 +279,4 @@ if __name__ == "__main__":
     test_reverse_reshape()
     test_batch_matmul()
     test_shape_of()
+    test_sequence_mask()
index c992be6..a7314a7 100644 (file)
@@ -657,6 +657,43 @@ inline Tensor take(const Tensor& a,
   }
 }
 
+
+/*!
+* \brief Mask the out-of-boundary elements of each sequence.
+*
+* \param data The source array.
+* \param valid_length The real length of each sequence.
+* \param mask_value The masking value.
+* \param axis The axis of the temporal dimension of the sequence
+* \param name The name of the operation.
+* \param tag The tag to mark the operation.
+*
+* \return A Tensor whose op member is the sequence_mask operation
+*/
+inline Tensor sequence_mask(const Tensor& data,
+                            const Tensor& valid_length,
+                            double mask_value,
+                            int axis,
+                            std::string name = "T_sequence_mask",
+                            std::string tag = kInjective) {
+  CHECK(axis == 0 || axis == 1) << "axis must be either 0 or 1";
+  CHECK_EQ(valid_length->shape.size(), 1) << "valid_length must have ndim=1, i.e., (batch_size,).";
+  auto length_dim = data->shape[axis];
+  auto batch_dim = data->shape[1 - axis];
+  Array<Expr> out_shape = data->shape;
+  Tensor out = compute(
+      out_shape, [&](const Array<Var>& out_index) {
+        Array<Expr> len_index;
+        auto tid = out_index[axis];
+        auto bid = out_index[1 - axis];
+        len_index.push_back(bid);
+        Expr ret = tvm::if_then_else(tvm::cast(valid_length->dtype, tid) >= valid_length(len_index),
+                                     tvm::cast(data->dtype, Expr(mask_value)), data(out_index));
+        return ret;
+      }, name, tag);
+  return out;
+}
+
 /*!
 * \brief Take elements from an array along an axis.
 *
index 40c1bdc..2d76ba9 100644 (file)
@@ -23,3 +23,4 @@ from .gather_nd_python import gather_nd_python
 from .strided_slice_python import strided_slice_python
 from .batch_matmul import batch_matmul
 from .slice_axis_python import slice_axis_python
+from .sequence_mask_python import sequence_mask
diff --git a/topi/python/topi/testing/sequence_mask_python.py b/topi/python/topi/testing/sequence_mask_python.py
new file mode 100644 (file)
index 0000000..d77eb6f
--- /dev/null
@@ -0,0 +1,52 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name
+"""Sequence mask in python"""
+import numpy as np
+
+def sequence_mask(data, valid_length, mask_value, axis):
+    """batch_matmul operator implemented in numpy.
+
+    Parameters
+    ----------
+    data : numpy.ndarray
+        N-D with shape [batch_size, MAX_LENGTH, ...] or [MAX_LENGTH, batch_size, ...]
+
+    valid_length : numpy.ndarray
+        1-D with shape [batch_size,]
+
+    mask_value : float
+        Masking value
+
+    axis : int
+        The axis of the length dimension
+
+    Returns
+    -------
+    out : numpy.ndarray
+        N-D with shape same as data
+    """
+    in_shape = data.shape
+    max_length = data.shape[axis]
+    val_len_expand_shape = [1 for _ in range(len(in_shape))]
+    val_len_expand_shape[1 - axis] = in_shape[1 - axis]
+    seq_len_expand_shape = [1 for _ in range(len(in_shape))]
+    seq_len_expand_shape[axis] = in_shape[axis]
+    mask = np.broadcast_to(np.arange(max_length).reshape(seq_len_expand_shape),
+                           in_shape) >= valid_length.reshape(val_len_expand_shape)
+    out = data * (1 - mask) + mask_value * mask
+    return out
index 3d7293e..738754e 100644 (file)
@@ -436,3 +436,44 @@ def shape(array, dtype="int32"):
         The resulting tensor.
     """
     return cpp.shape(array, dtype)
+
+
+def sequence_mask(data, valid_length, mask_value=0, axis=0):
+    """Sets all elements outside the expected length of the sequence to a constant value.
+
+    This function takes an n-dimensional input array of the form [MAX_LENGTH, batch_size, ...] or
+    [batch_size, MAX_LENGTH, ...] and returns an array of the same shape.
+
+    `axis` means the axis of the length dimension and can only be 0 or 1. If `axis` is 0,
+    the data must have shape [MAX_LENGTH, batch_size, ...]. Otherwise (axis=1), the data must have
+    shape [batch_size, MAX_LENGTH, ...].
+
+    `valid_length` gives the length of each sequence. `valid_length` should be
+    a 1D int array with positive ints and has dimension [batch_size,].
+
+    Parameters
+    ----------
+    data : tvm.Tensor
+        N-D with shape [MAX_LENGTH, batch_size, ...] or [batch_size, MAX_LENGTH, ...]
+        depending on the value of `axis`.
+
+    valid_length : tvm.Tensor
+        1-D with shape [batch_size,]
+
+    mask_value : float, optional
+        The masking value, default 0
+
+    axis : int, optional
+        axis of the length dimension, must be 0 or 1, default 0
+
+    Returns
+    -------
+    output : tvm.Tensor
+        N-D with shape [MAX_LENGTH, batch_size, ...] or [batch_size, MAX_LENGTH, ...]
+        depending on the value of `axis`.
+    """
+
+    assert len(data.shape) >= 2,\
+        "only support data.ndim >= 2, received data.shape = {}".format(data.shape)
+    assert axis == 0 or axis == 1, "only support axis = 0, 1, received axis = {}".format(axis)
+    return cpp.sequence_mask(data, valid_length, mask_value, axis)
index 57a2743..688cc9f 100644 (file)
@@ -337,6 +337,14 @@ TVM_REGISTER_GLOBAL("topi.take")
   }
   });
 
+TVM_REGISTER_GLOBAL("topi.sequence_mask")
+.set_body([](TVMArgs args, TVMRetValue *rv) {
+  double pad_val = args[2];
+  int axis = args[3];
+  *rv = sequence_mask(args[0], args[1], pad_val, axis);
+});
+
+
 TVM_REGISTER_GLOBAL("topi.where")
 .set_body([](TVMArgs args, TVMRetValue *rv) {
   *rv = where(args[0], args[1], args[2]);
index 5682fde..9d69734 100644 (file)
@@ -619,6 +619,36 @@ def test_shape():
         check_device(backend)
 
 
+def test_sequence_mask():
+    for in_shape in (5, 10), (3, 4, 5, 4):
+        for axis in [0, 1]:
+            for mask_value in [0.0, 1.0]:
+                max_length = in_shape[axis]
+                batch_size = in_shape[1 - axis]
+                A = tvm.placeholder(shape=in_shape, dtype="float32", name="A")
+                B = tvm.placeholder(shape=(batch_size,), dtype="int32", name="B")
+                C = topi.sequence_mask(A, B, axis=axis, mask_value=mask_value)
+                A_data = np.random.normal(0, 1, in_shape).astype(np.float32)
+                B_data = np.random.randint(1, max_length, (batch_size,)).astype(np.int32)
+                C_gt_data = topi.testing.sequence_mask(A_data, B_data, mask_value, axis)
+
+                def check_device(device):
+                    ctx = tvm.context(device, 0)
+                    if not ctx.exist:
+                        print("Skip because %s is not enabled" % device)
+                        return
+                    tvm_A = tvm.nd.array(A_data, ctx)
+                    tvm_B = tvm.nd.array(B_data, ctx)
+                    tvm_C = tvm.nd.empty(in_shape, ctx=ctx, dtype="float32")
+                    print("Running on target: %s" % device)
+                    with tvm.target.create(device):
+                        s = topi.generic.schedule_injective(C)
+                    f = tvm.build(s, [A, B, C], device, name="SequenceMask")
+                    f(tvm_A, tvm_B, tvm_C)
+                    tvm.testing.assert_allclose(tvm_C.asnumpy(), C_gt_data)
+                for backend in get_all_backend():
+                    check_device(backend)
+
 if __name__ == "__main__":
     test_strided_slice()
     test_concatenate()
@@ -637,3 +667,4 @@ if __name__ == "__main__":
     test_repeat()
     test_tile()
     test_shape()
+    test_sequence_mask()