[relay][topi] Add operation relay.nn.dilate() which calls topi.nn.dilate() (#5331)
authornotoraptor <notoraptor@users.noreply.github.com>
Mon, 27 Apr 2020 07:36:35 +0000 (03:36 -0400)
committerGitHub <noreply@github.com>
Mon, 27 Apr 2020 07:36:35 +0000 (16:36 +0900)
* Add operation relay.nn.dilate() which calls topi.nn.dilate().

* Fix typo

* Set op pattern to injective

include/tvm/relay/attrs/nn.h
python/tvm/relay/op/nn/_nn.py
python/tvm/relay/op/nn/nn.py
python/tvm/relay/op/op_attrs.py
src/relay/op/nn/nn.cc
tests/python/relay/test_any.py

index f985a90..fdf56a7 100644 (file)
@@ -442,6 +442,16 @@ struct Conv2DTransposeAttrs : public tvm::AttrsNode<Conv2DTransposeAttrs> {
   }
 };
 
+/*! \brief Attributes used in dilate operator */
+struct DilateAttrs : public tvm::AttrsNode<DilateAttrs> {
+  Array<IndexExpr> strides;
+
+  TVM_DECLARE_ATTRS(DilateAttrs, "relay.attrs.DilateAttrs") {
+    TVM_ATTR_FIELD(strides).set_default(Array<IndexExpr>({1, 1}))
+      .describe("Dilation stride on each dimension, 1 means no dilation.");
+  }
+};
+
 /*! \brief Attributes used in 1D transposed convolution operator */
 struct Conv1DTransposeAttrs : public tvm::AttrsNode<Conv1DTransposeAttrs> {
   IndexExpr channels;
index 5f6aa89..ad8c654 100644 (file)
@@ -502,6 +502,15 @@ reg.register_reduce_schedule("nn.cross_entropy")
 reg.register_pattern("nn.cross_entropy", OpPattern.OPAQUE)
 
 
+# dilate
+@reg.register_compute("nn.dilate")
+def compute_dilate(attrs, inputs, out_dtype):
+    return [topi.nn.dilate(inputs[0], attrs.strides)]
+
+reg.register_broadcast_schedule("nn.dilate")
+reg.register_pattern("nn.dilate", OpPattern.INJECTIVE)
+
+
 # cross_entropy_with_logits
 @reg.register_compute("nn.cross_entropy_with_logits")
 def compute_cross_entropy_with_logits(attrs, inputs, out_dtype):
@@ -697,6 +706,21 @@ def pad_shape_func(attrs, inputs, _):
         pad_width.append(get_const_tuple(pair))
     return [_pad_shape_func(inputs[0], convert(pad_width))]
 
+@script
+def _dilate_shape_func(data_shape, strides):
+    out = output_tensor((data_shape.shape[0],), "int64")
+    for i in const_range(out.shape[0]):
+        out[i] = (data_shape[i] - 1) * strides[i] + 1
+
+    return out
+
+@reg.register_shape_func("nn.dilate", False)
+def dilate_shape_func(attrs, inputs, _):
+    """
+    Shape function for dilate op.
+    """
+    return [_dilate_shape_func(inputs[0], convert(attrs.strides))]
+
 reg.register_shape_func("nn.bias_add", False, elemwise_shape_func)
 reg.register_shape_func("nn.softmax", False, elemwise_shape_func)
 reg.register_shape_func("nn.relu", False, elemwise_shape_func)
index 622b0fa..c879eb6 100644 (file)
@@ -1347,6 +1347,25 @@ def pad(data,
     return _make.pad(data, pad_width, pad_value, pad_mode)
 
 
+def dilate(data, strides):
+    """Dilate data with zeros.
+
+    Parameters
+    ----------
+    data : tvm.relay.Expr
+        n-D, can be any layout.
+
+    strides : <tuple of <int>
+        Dilation stride on each dimension, 1 means no dilation.
+
+    Returns
+    -------
+    Output : tvm.relay.Expr
+        The computed result
+    """
+    return _make.dilate(data, strides)
+
+
 def mirror_pad(data,
                pad_width,
                mode="SYMMETRIC"):
index a47be76..a1c73ef 100644 (file)
@@ -350,6 +350,11 @@ class Conv2DTransposeAttrs(Attrs):
     """Attributes used in Transposed Conv2D operators"""
 
 
+@tvm._ffi.register_object("relay.attrs.DilateAttrs")
+class DilateAttrs(Attrs):
+    """Attributes used in dilate operators"""
+
+
 @tvm._ffi.register_object("relay.attrs.SubPixelAttrs")
 class SubPixelAttrs(Attrs):
     """Attributes used in depth to space and space to depth operators"""
index 5cdca80..3c9b077 100644 (file)
@@ -1035,6 +1035,54 @@ Do log on the data - do not accept logits.
 .add_type_rel("CrossEntropy", CrossEntropyRel);
 
 
+// relay.nn.dilate
+TVM_REGISTER_NODE_TYPE(DilateAttrs);
+
+bool DilateRel(const Array<Type>& types,
+               int num_inputs,
+               const Attrs& attrs,
+               const TypeReporter& reporter) {
+  CHECK_EQ(types.size(), 2);
+  const auto* x = types[0].as<TensorTypeNode>();
+  const DilateAttrs* param = attrs.as<DilateAttrs>();
+  if (x == nullptr) return false;
+  CHECK_EQ(x->shape.size(), param->strides.size());
+
+  std::vector<IndexExpr> oshape;
+  for (size_t i = 0; i < param->strides.size(); ++i) {
+    if (!x->shape[i].as<tir::AnyNode>()) {
+      oshape.push_back((x->shape[i] - 1) * param->strides[i] + 1);
+    } else {
+      oshape.push_back(x->shape[i]);
+    }
+  }
+
+  reporter->Assign(types[1], TensorType(Array<IndexExpr>(oshape), x->dtype));
+  return true;
+}
+
+// Positional relay function to create dilate operator used by frontend FFI.
+Expr MakeDilate(Expr data, Array<IndexExpr> strides) {
+  auto attrs = make_object<DilateAttrs>();
+  attrs->strides = std::move(strides);
+  static const Op& op = Op::Get("nn.dilate");
+  return Call(op, {data}, Attrs(attrs), {});
+}
+
+
+TVM_REGISTER_GLOBAL("relay.op.nn._make.dilate")
+.set_body_typed(MakeDilate);
+
+
+RELAY_REGISTER_OP("nn.dilate")
+.describe(R"code(
+Dilate data with zeros.
+)code" TVM_ADD_FILELINE)
+.set_num_inputs(1)
+.add_argument("x", "1D Tensor", "Data to dilate.")
+.set_support_level(10)
+.add_type_rel("Dilate", DilateRel);
+
 // Positional relay function to create cross_entropy_with_logits operator used by frontend FFI.
 Expr MakeCrossEntropyWithLogits(Expr predictions, Expr targets) {
   static const Op& op = Op::Get("nn.cross_entropy_with_logits");
index aa81e31..6ce59bb 100644 (file)
@@ -508,6 +508,34 @@ def test_any_pad():
     verify_any_pad(any_dims(3), ((0, 0), (1, 1), (2, 2)), (1, 2, 3))
     verify_any_pad(any_dims(4), ((1, 0), (1, 3), (0, 2), (9, 0)), (13, 11, 3, 1))
 
+def verify_any_dilate(data_shape, strides, static_data_shape):
+    assert len(data_shape) == len(strides)
+    mod = tvm.IRModule()
+    dtype = "float32"
+    data = relay.var('data', shape=data_shape, dtype=dtype)
+    y = relay.nn.dilate(data, strides)
+    mod["main"] = relay.Function([data], y)
+    data_np = np.random.uniform(size=static_data_shape).astype(dtype)
+    ref_shape = tuple((static_data_shape[i] - 1) * strides[i] + 1
+                      for i in range(len(static_data_shape)))
+    ref_out = np.zeros(shape=ref_shape, dtype=dtype)
+    ref_out[tuple(slice(None, None, strides[i]) for i in range(len(data_shape)))] = data_np
+
+    for kind in ["debug", "vm"]:
+        ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
+        result = ex.evaluate()(data_np)
+        tvm.testing.assert_allclose(result.asnumpy(), ref_out)
+
+def test_any_dilate():
+    verify_any_dilate(any_dims(1), (1,), (1,))
+    verify_any_dilate(any_dims(1), (1,), (5,))
+    verify_any_dilate(any_dims(1), (5,), (5,))
+    verify_any_dilate(any_dims(3), (1, 1, 1), (1, 2, 3))
+    verify_any_dilate(any_dims(3), (1, 1, 2), (1, 2, 3))
+    verify_any_dilate(any_dims(3), (1, 1, 5), (1, 2, 3))
+    verify_any_dilate(any_dims(3), (3, 7, 5), (1, 2, 3))
+    verify_any_dilate(any_dims(4), (3, 7, 1, 5), (1, 2, 3, 4))
+
 def verify_any_softmax(data_shape, axis, static_data_shape, ref_out_shape):
     mod = tvm.IRModule()
     dtype = "float32"