[Relay][Frontend] Add reverse op to relay (#2800)
authorLeyuan Wang <laurawly@gmail.com>
Wed, 13 Mar 2019 21:24:46 +0000 (14:24 -0700)
committerHaichen Shen <shenhaichen@gmail.com>
Wed, 13 Mar 2019 21:24:46 +0000 (14:24 -0700)
* start adding reverse

* reverse updated

* reverse uses topi::flip

* typo fixed

* comment addressed

* exp simplified

docs/langref/relay_op.rst
include/tvm/relay/attrs/transform.h
python/tvm/relay/frontend/mxnet.py
python/tvm/relay/op/_transform.py
python/tvm/relay/op/transform.py
src/relay/op/tensor/transform.cc
tests/python/relay/test_op_level3.py

index f20c443..e16da29 100644 (file)
@@ -99,6 +99,7 @@ This level enables additional math and transform operators.
    tvm.relay.stack
    tvm.relay.repeat
    tvm.relay.tile
+   tvm.relay.reverse
 
 
 **Level 4: Broadcast and Reductions**
@@ -229,6 +230,7 @@ Level 3 Definitions
 .. autofunction:: tvm.relay.stack
 .. autofunction:: tvm.relay.repeat
 .. autofunction:: tvm.relay.tile
+.. autofunction:: tvm.relay.reverse
 
 
 Level 4 Definitions
index 5382017..326c9f0 100644 (file)
@@ -146,6 +146,15 @@ struct TileAttrs : public tvm::AttrsNode<TileAttrs> {
   }
 };  // struct TileAttrs
 
+/*! \brief Attributes used in reverse operators */
+struct ReverseAttrs : public tvm::AttrsNode<ReverseAttrs> {
+  Integer axis;
+  TVM_DECLARE_ATTRS(ReverseAttrs, "relay.attrs.ReverseAttrs") {
+    TVM_ATTR_FIELD(axis).set_default(NullValue<Integer>())
+        .describe("The axis along which to reverse elements.");
+  }
+};  // struct ReverseAttrs
+
 /*! \brief Attributes used in squeeze operators */
 struct SqueezeAttrs : public tvm::AttrsNode<SqueezeAttrs> {
   // use axis to make the name numpy compatible.
index 93bd8ef..42f3e8a 100644 (file)
@@ -422,6 +422,13 @@ def _mx_tile(inputs, attrs):
     return _op.tile(inputs[0], **new_attrs)
 
 
+def _mx_reverse(inputs, attrs):
+    assert len(inputs) == 1
+    new_attrs = {}
+    new_attrs["axis"] = attrs.get_int("axis")
+    return _op.reverse(inputs[0], **new_attrs)
+
+
 def _mx_roi_align(inputs, attrs):
     new_attrs = {}
     new_attrs["pooled_size"] = attrs.get_int_tuple("pooled_size")
@@ -612,6 +619,7 @@ _convert_map = {
     "_arange"       : _mx_arange,
     "repeat"        : _mx_repeat,
     "tile"          : _mx_tile,
+    "reverse"       : _mx_reverse,
     "BlockGrad"     : _mx_BlockGrad,
     "SoftmaxOutput" : _mx_softmax_output,
     "SoftmaxActivation" : _mx_softmax_activation,
index 2b43c21..72fbca9 100644 (file)
@@ -19,6 +19,7 @@ _reg.register_schedule("reshape_like", schedule_injective)
 _reg.register_schedule("full", schedule_injective)
 _reg.register_schedule("full_like", schedule_injective)
 _reg.register_schedule("arange", schedule_injective)
+_reg.register_schedule("reverse", schedule_injective)
 _reg.register_schedule("repeat", schedule_broadcast)
 _reg.register_schedule("tile", schedule_broadcast)
 _reg.register_schedule("cast", schedule_injective)
index b772698..37aace5 100644 (file)
@@ -385,6 +385,35 @@ def tile(data, reps):
     return _make.tile(data, reps)
 
 
+def reverse(data, axis):
+    """Reverses the order of elements along given axis while preserving array shape.
+    By default, repeat flattens the input array into 1-D and then repeats the elements.
+
+    Parameters
+    ----------
+    data : relay.Expr
+        The input data to the operator.
+
+    axis: int
+        The axis along which to reverse elements.
+
+    Returns
+    -------
+    ret : relay.Expr
+        The computed result.
+
+    Examples
+    --------
+    .. code-block:: python
+
+        x = [[1., 2.], [3., 4.]]
+        relay.reverse(x, axis=0) = [[3., 4.], [1., 2.]]
+
+        relay.reverse(x, axis=1) = [[2., 1.], [4., 3.]]
+    """
+    return _make.reverse(data, axis)
+
+
 def where(condition, x, y):
     """Selecting elements from either x or y depending on the value of the
     condition.
index 7aa98e3..36b93ee 100644 (file)
@@ -1086,8 +1086,8 @@ Array<Tensor> RepeatCompute(const Attrs& attrs,
 }
 
 Expr MakeRepeat(Expr data,
-                    int repeats,
-                    int axis) {
+                int repeats,
+                int axis) {
   auto attrs = make_node<RepeatAttrs>();
   attrs->repeats = repeats;
   attrs->axis = axis;
@@ -1204,6 +1204,69 @@ RELAY_REGISTER_OP("tile")
 .set_attr<FTVMCompute>("FTVMCompute", TileCompute)
 .set_attr<TOpPattern>("TOpPattern", kBroadcast);
 
+// reverse operator
+TVM_REGISTER_NODE_TYPE(ReverseAttrs);
+
+bool ReverseRel(const Array<Type>& types,
+               int num_inputs,
+               const Attrs& attrs,
+               const TypeReporter& reporter) {
+  // `types` contains: [data, result]
+  CHECK_EQ(types.size(), 2);
+  const auto* data = types[0].as<TensorTypeNode>();
+  if (data == nullptr) {
+    CHECK(types[0].as<IncompleteTypeNode>())
+        << "reverse: expect input type to be TensorType but get "
+        << types[0];
+    return false;
+  }
+  const auto* param = attrs.as<ReverseAttrs>();
+  const int ndim = static_cast<int>(data->shape.size());
+  const int axis = param->axis;
+  CHECK(-ndim <= axis && axis < ndim)
+    << "reverse only accepts `axis` in [-data.ndim, data.ndim - 1]"
+    << ", but got axis = " << axis
+    << ", and data.ndim = " << ndim;
+  reporter->Assign(types[1], types[0]);
+  return true;
+}
+
+Array<Tensor> ReverseCompute(const Attrs& attrs,
+                             const Array<Tensor>& inputs,
+                             const Type& out_type,
+                             const Target& target) {
+  const ReverseAttrs *param = attrs.as<ReverseAttrs>();
+  CHECK(param != nullptr);
+  return { topi::flip(inputs[0], param->axis) };
+}
+
+Expr MakeReverse(Expr data,
+                 int axis) {
+  auto attrs = make_node<ReverseAttrs>();
+  attrs->axis = axis;
+  static const Op& op = Op::Get("reverse");
+  return CallNode::make(op, {data}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_API("relay.op._make.reverse")
+.set_body([](const TVMArgs& args, TVMRetValue* rv) {
+    runtime::detail::unpack_call<Expr, 2>(MakeReverse, args, rv);
+});
+
+RELAY_REGISTER_OP("reverse")
+.describe(R"code(Reverses the order of elements along given `axis` while preserving array shape.
+
+- **data**: The input data to the operator.
+
+)code" TVM_ADD_FILELINE)
+.set_num_inputs(1)
+.set_attrs_type_key("relay.attrs.Reverse")
+.add_argument("data", "Tensor", "The input tensor.")
+.set_support_level(3)
+.add_type_rel("Reverse", ReverseRel)
+.set_attr<FTVMCompute>("FTVMCompute", ReverseCompute)
+.set_attr<TOpPattern>("TOpPattern", kInjective);
+
 // where operator
 bool WhereRel(const Array<Type>& types,
               int num_inputs,
index e762c7d..eee0bcf 100644 (file)
@@ -491,6 +491,25 @@ def test_arange():
     verify_arange(20, 1, -1.5)
 
 
+def test_reverse():
+    def verify_reverse(dshape, axis):
+        x = relay.var("x", relay.TensorType(dshape, "float32"))
+        z = relay.reverse(x, axis=axis)
+        zz = relay.ir_pass.infer_type(z)
+
+        func = relay.Function([x], z)
+        x_data = np.random.uniform(low=-1, high=1, size=dshape).astype("float32")
+        ref_res = np.flip(x_data, axis)
+        for target, ctx in ctx_list():
+            for kind in ["graph", "debug"]:
+                intrp = relay.create_executor(kind, ctx=ctx, target=target)
+                op_res = intrp.evaluate(func)(x_data)
+                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
+    verify_reverse((2, 3, 4), 1)
+    verify_reverse((4, 7), 0)
+    verify_reverse((2, 3, 4), -1)
+
+
 if __name__ == "__main__":
     test_cast()
     test_zeros_ones()
@@ -515,3 +534,4 @@ if __name__ == "__main__":
     test_squeeze_bad_axes_infer_type()
     test_split_infer_type()
     test_arange()
+    test_reverse()