[RELAY][DYN] Implementation of the dynamic pad operator (#6284)
authorLily Orth-Smith <lilyorthsmith@gmail.com>
Thu, 20 Aug 2020 01:57:05 +0000 (18:57 -0700)
committerGitHub <noreply@github.com>
Thu, 20 Aug 2020 01:57:05 +0000 (18:57 -0700)
16 files changed:
include/tvm/relay/attrs/nn.h
include/tvm/topi/nn.h
python/tvm/relay/op/dyn/nn/__init__.py [new file with mode: 0644]
python/tvm/relay/op/dyn/nn/_make.py [new file with mode: 0644]
python/tvm/relay/op/dyn/nn/_nn.py [new file with mode: 0644]
python/tvm/relay/op/nn/nn.py
python/tvm/te/hybrid/calls.py
python/tvm/te/hybrid/runtime.py
python/tvm/topi/nn/upsampling.py
src/relay/op/dyn/nn/pad.cc [new file with mode: 0644]
src/relay/op/make_op.h
src/relay/op/nn/pad.cc
src/relay/transforms/dynamic_to_static.cc
src/relay/transforms/pattern_util.h
tests/python/relay/dyn/test_dynamic_op_level2.py [new file with mode: 0644]
tests/python/relay/test_pass_dynamic_to_static.py

index 5f1ee2f..6bfdb49 100644 (file)
@@ -85,7 +85,7 @@ struct Conv1DAttrs : public tvm::AttrsNode<Conv1DAttrs> {
         .set_default(NullValue<IndexExpr>());
     TVM_ATTR_FIELD(kernel_size)
         .describe("Specifies the dimensions of the convolution window.")
-        .set_default(NullValue<Array<IndexExpr> >());
+        .set_default(NullValue<Array<IndexExpr>>());
     TVM_ATTR_FIELD(data_layout)
         .set_default("NCW")
         .describe(
@@ -148,7 +148,7 @@ struct Conv2DAttrs : public tvm::AttrsNode<Conv2DAttrs> {
         .set_default(NullValue<IndexExpr>());
     TVM_ATTR_FIELD(kernel_size)
         .describe("Specifies the dimensions of the convolution window.")
-        .set_default(NullValue<Array<IndexExpr> >());
+        .set_default(NullValue<Array<IndexExpr>>());
     TVM_ATTR_FIELD(data_layout)
         .set_default("NCHW")
         .describe(
@@ -242,7 +242,7 @@ struct Conv2DWinogradAttrs : public tvm::AttrsNode<Conv2DWinogradAttrs> {
         .set_default(NullValue<IndexExpr>());
     TVM_ATTR_FIELD(kernel_size)
         .describe("Specifies the dimensions of the convolution window.")
-        .set_default(NullValue<Array<IndexExpr> >());
+        .set_default(NullValue<Array<IndexExpr>>());
     TVM_ATTR_FIELD(data_layout)
         .set_default("NCHW")
         .describe(
@@ -331,7 +331,7 @@ struct Conv3DAttrs : public tvm::AttrsNode<Conv3DAttrs> {
         .set_default(NullValue<IndexExpr>());
     TVM_ATTR_FIELD(kernel_size)
         .describe("Specifies the dimensions of the convolution window.")
-        .set_default(NullValue<Array<IndexExpr> >());
+        .set_default(NullValue<Array<IndexExpr>>());
     TVM_ATTR_FIELD(data_layout)
         .set_default("NCDHW")
         .describe(
@@ -381,7 +381,7 @@ struct Conv3DTransposeAttrs : public tvm::AttrsNode<Conv3DTransposeAttrs> {
             "i.e. the number of output channels in the convolution.");
     TVM_ATTR_FIELD(kernel_size)
         .describe("The dimensions of the convolution window.")
-        .set_default(NullValue<Array<IndexExpr> >());
+        .set_default(NullValue<Array<IndexExpr>>());
     TVM_ATTR_FIELD(strides)
         .set_default(Array<IndexExpr>({1, 1, 1}))
         .describe("The strides of the convolution.");
@@ -480,7 +480,7 @@ struct Conv3DWinogradAttrs : public tvm::AttrsNode<Conv3DWinogradAttrs> {
         .set_default(NullValue<IndexExpr>());
     TVM_ATTR_FIELD(kernel_size)
         .describe("Specifies the dimensions of the convolution window.")
-        .set_default(NullValue<Array<IndexExpr> >());
+        .set_default(NullValue<Array<IndexExpr>>());
     TVM_ATTR_FIELD(data_layout)
         .set_default("NCDHW")
         .describe(
@@ -539,7 +539,7 @@ struct Conv2DTransposeAttrs : public tvm::AttrsNode<Conv2DTransposeAttrs> {
             "i.e. the number of output channels in the convolution.");
     TVM_ATTR_FIELD(kernel_size)
         .describe("The dimensions of the convolution window.")
-        .set_default(NullValue<Array<IndexExpr> >());
+        .set_default(NullValue<Array<IndexExpr>>());
     TVM_ATTR_FIELD(strides)
         .set_default(Array<IndexExpr>({1, 1}))
         .describe("The strides of the convolution.");
@@ -626,7 +626,7 @@ struct Conv1DTransposeAttrs : public tvm::AttrsNode<Conv1DTransposeAttrs> {
             "i.e. the number of output channels in the convolution.");
     TVM_ATTR_FIELD(kernel_size)
         .describe("The dimensions of the convolution window.")
-        .set_default(NullValue<Array<IndexExpr> >());
+        .set_default(NullValue<Array<IndexExpr>>());
     TVM_ATTR_FIELD(strides)
         .set_default(Array<IndexExpr>({1}))
         .describe("The strides of the convolution.");
@@ -1016,7 +1016,7 @@ struct UpSampling3DAttrs : public tvm::AttrsNode<UpSampling3DAttrs> {
 /*! \brief Attributes used for the padding operator */
 struct PadAttrs : public tvm::AttrsNode<PadAttrs> {
   double pad_value;
-  Array<Array<IndexExpr> > pad_width;
+  Array<Array<Integer>> pad_width;
   std::string pad_mode;
 
   TVM_DECLARE_ATTRS(PadAttrs, "relay.attrs.PadAttrs") {
@@ -1037,7 +1037,7 @@ struct PadAttrs : public tvm::AttrsNode<PadAttrs> {
 /*! \brief Attributes used for the MirrorPadding operator */
 struct MirrorPadAttrs : public tvm::AttrsNode<MirrorPadAttrs> {
   std::string mode;
-  Array<Array<IndexExpr> > pad_width;
+  Array<Array<IndexExpr>> pad_width;
 
   TVM_DECLARE_ATTRS(MirrorPadAttrs, "relay.attrs.MirrorPadAttrs") {
     TVM_ATTR_FIELD(mode)
@@ -1242,7 +1242,7 @@ struct DeformableConv2DAttrs : public tvm::AttrsNode<DeformableConv2DAttrs> {
         .set_default(NullValue<IndexExpr>());
     TVM_ATTR_FIELD(kernel_size)
         .describe("Specifies the dimensions of the convolution window.")
-        .set_default(NullValue<Array<IndexExpr> >());
+        .set_default(NullValue<Array<IndexExpr>>());
     TVM_ATTR_FIELD(data_layout)
         .set_default("NCHW")
         .describe(
index 17eb0d0..d257d3c 100644 (file)
@@ -124,6 +124,8 @@ inline tvm::te::Tensor prelu(const tvm::te::Tensor& x, const tvm::te::Tensor& sl
  * "constant" pads with constant_value;
  * "edge" pads using the edge values of the input array;
  * "reflect" pads by reflecting values with respect to the edges.
+ * \param dyn_output_shape Output shape of the pad op, default nullptr.
+ * You only need to pass this in if the shape was evaluated dynamically.
  * \param name The name of the operation
  * \param tag The tag to mark the operation
  *
@@ -151,30 +153,40 @@ inline tvm::te::Tensor prelu(const tvm::te::Tensor& x, const tvm::te::Tensor& sl
 inline tvm::te::Tensor pad(const tvm::te::Tensor& t, const tvm::Array<tvm::PrimExpr>& pad_before,
                            tvm::Array<tvm::PrimExpr> pad_after = tvm::Array<tvm::PrimExpr>(),
                            PrimExpr pad_value = PrimExpr(), std::string name = "T_pad",
-                           std::string tag = kElementWise, std::string pad_mode = "constant") {
+                           std::string tag = kElementWise, std::string pad_mode = "constant",
+                           const Array<PrimExpr>* dyn_output_shape = nullptr) {
   if (pad_after.size() < pad_before.size()) {
     for (size_t i = pad_after.size(); i < pad_before.size(); ++i) {
       pad_after.push_back(pad_before[i]);
     }
   }
+
   arith::Analyzer analyzer;
   CHECK_GE(pad_before.size(), 1);
   CHECK_EQ(pad_before.size(), pad_after.size());
-  tvm::Array<tvm::PrimExpr> output_shape;
   tvm::Array<tvm::PrimExpr> pad_before_int32;
   tvm::Array<tvm::PrimExpr> pad_after_int32;
+
   for (const auto& ele : pad_before) {
     pad_before_int32.push_back(tvm::cast(tvm::DataType::Int(32), ele));
   }
   for (const auto& ele : pad_after) {
     pad_after_int32.push_back(tvm::cast(tvm::DataType::Int(32), ele));
   }
-  for (size_t i = 0; i < t->shape.size(); ++i) {
-    if (i >= pad_before.size()) {
-      output_shape.push_back(t->shape[i]);
-    } else {
-      output_shape.push_back(
-          analyzer.Simplify(t->shape[i] + pad_before_int32[i] + pad_after_int32[i]));
+
+  tvm::Array<tvm::PrimExpr> output_shape;
+  if (dyn_output_shape == nullptr) {
+    for (size_t i = 0; i < t->shape.size(); ++i) {
+      if (i >= pad_before.size()) {
+        output_shape.push_back(t->shape[i]);
+      } else {
+        output_shape.push_back(
+            analyzer.Simplify(t->shape[i] + pad_before_int32[i] + pad_after_int32[i]));
+      }
+    }
+  } else {
+    for (size_t i = 0; i < dyn_output_shape->size(); i++) {
+      output_shape.push_back((*dyn_output_shape)[i]);
     }
   }
 
diff --git a/python/tvm/relay/op/dyn/nn/__init__.py b/python/tvm/relay/op/dyn/nn/__init__.py
new file mode 100644 (file)
index 0000000..01a3a1b
--- /dev/null
@@ -0,0 +1,20 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=wildcard-import, redefined-builtin, invalid-name
+"""The Relay namespace containing dynamic ops."""
+
+from . import _nn
diff --git a/python/tvm/relay/op/dyn/nn/_make.py b/python/tvm/relay/op/dyn/nn/_make.py
new file mode 100644 (file)
index 0000000..280fe72
--- /dev/null
@@ -0,0 +1,20 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Constructor APIs"""
+import tvm._ffi
+
+tvm._ffi._init_api("relay.op.dyn.nn._make", __name__)
diff --git a/python/tvm/relay/op/dyn/nn/_nn.py b/python/tvm/relay/op/dyn/nn/_nn.py
new file mode 100644 (file)
index 0000000..141fc22
--- /dev/null
@@ -0,0 +1,46 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=no-else-return, invalid-name, unused-argument, too-many-arguments, consider-using-in
+"""Backend compiler related feature registration"""
+
+from __future__ import absolute_import
+
+from tvm.te.hybrid import script
+from ...op import register_shape_func
+from ...op import register_broadcast_schedule
+
+# pad
+register_broadcast_schedule("dyn.nn.pad")
+
+#####################
+#  Shape functions  #
+#####################
+
+@script
+def _dyn_pad_shape_func(data, pad_width):
+    ndim = len(data.shape)
+    out = output_tensor((ndim,), "int64")
+    for i in const_range(ndim):
+        out[i] = int64(pad_width[i, 0] + pad_width[i, 1] + data.shape[i])
+    return out
+
+@register_shape_func("dyn.nn.pad", True)
+def pad_shape_func(attrs, inputs, data):
+    """
+    Shape function for dynamic pad op.
+    """
+    return [_dyn_pad_shape_func(inputs[0], inputs[1])]
index b2df850..c04db30 100644 (file)
@@ -19,7 +19,9 @@
 from tvm.relay import expr
 
 from . import _make
+from ..dyn.nn import _make as _dyn_make
 from .util import get_pad_tuple1d, get_pad_tuple2d, get_pad_tuple3d
+from ...expr import const, Expr
 
 
 def conv1d(data,
@@ -1410,7 +1412,7 @@ def prelu(data, alpha, axis=1):
 
 def pad(data,
         pad_width,
-        pad_value=0.0,
+        pad_value=0,
         pad_mode='constant'):
     r"""Padding
 
@@ -1421,10 +1423,10 @@ def pad(data,
     ----------
     data: tvm.relay.Expr
         The input data to the operator
-    pad_width: tuple of <tuple of <int>>, required
+    pad_width: tuple of <tuple of <int>>, or tvm.relay.Expr, required
         Number of values padded to the edges of each axis, in the format
         of ((before_1, after_1), ..., (before_N, after_N))
-    pad_value: float, optional, default=0.0
+    pad_value: float, or tvm.relay.Expr, optional, default=0
         The value used for padding
     pad_mode: 'constant', 'edge', 'reflect'
         'constant' pads with constant_value pad_value
@@ -1435,6 +1437,12 @@ def pad(data,
     result : tvm.relay.Expr
         The computed result.
     """
+    if (isinstance(pad_width, Expr) or (isinstance(pad_value, Expr))):
+        if not isinstance(pad_width, Expr):
+            pad_width = const(list(pad_width))
+        if not isinstance(pad_value, Expr):
+            pad_value = const(pad_value)
+        return _dyn_make.pad(data, pad_width, pad_value, pad_mode)
     return _make.pad(data, pad_width, pad_value, pad_mode)
 
 
index 78ed1dc..88ade6e 100644 (file)
@@ -73,7 +73,7 @@ def _math_intrin(func_id, args):
     from tvm.tir import op
     return getattr(op, func_id)(*args)
 
-sqrt = log = exp = tanh = sigmoid = power = popcount = _math_intrin #pylint: disable=invalid-name
+sqrt = log = exp = tanh = sigmoid = power = popcount = round = _math_intrin #pylint: disable=invalid-name
 
 
 def _min_max(func_id, args):
index 7dcfc7c..7987e46 100644 (file)
@@ -126,6 +126,7 @@ HYBRID_GLOBALS = {
     'exp'            : numpy.exp,
     'sigmoid'        : sigmoid,
     'popcount'       : popcount,
+    'round'          : round,
     'likely'         : lambda cond: cond,
     'uint8'          : numpy.uint8,
     'uint16'         : numpy.uint16,
index 96a13ef..d8da41f 100644 (file)
@@ -57,7 +57,6 @@ def upsampling(data, scale_h, scale_w, layout="NCHW", method='nearest_neighbor',
     elif layout == "NHWC":
         out_shape = (simplify(topi.cast(te.round(data.shape[1] * scale_h), data.shape[1].dtype)),
                      simplify(topi.cast(te.round(data.shape[2] * scale_w), data.shape[2].dtype)))
-
     else:
         raise ValueError("not support this layout {} yet".format(layout))
     coord_trans = "align_corners" if align_corners else "asymmetric"
diff --git a/src/relay/op/dyn/nn/pad.cc b/src/relay/op/dyn/nn/pad.cc
new file mode 100644 (file)
index 0000000..8a17f50
--- /dev/null
@@ -0,0 +1,130 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file pad.cc
+ * \brief Implementation of dynamic pad
+ */
+#include <tvm/relay/attrs/nn.h>
+#include <tvm/relay/op.h>
+#include <tvm/tir/data_layout.h>
+#include <tvm/tir/op.h>
+#include <tvm/topi/nn.h>
+
+#include <vector>
+
+#include "../../make_op.h"
+#include "../../op_common.h"
+
+namespace tvm {
+namespace relay {
+namespace dyn {
+
+// relay.dyn.nn.pad
+
+bool PadRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+            const TypeReporter& reporter) {
+  // types = [data_type, pad_width_type, pad_value_type, ret_type]
+  CHECK_EQ(types.size(), 4);
+  const auto* data = types[0].as<TensorTypeNode>();
+  if (data == nullptr) return false;
+
+  const auto* pad_width = types[1].as<TensorTypeNode>();
+  if (pad_width == nullptr) return false;
+
+  const auto* pad_value = types[2].as<TensorTypeNode>();
+  if (pad_value == nullptr) return false;
+
+  int data_rank = data->shape.size();
+  CHECK(data_rank) << "Data shape must have static rank";
+
+  int pad_width_rank = pad_width->shape.size();
+  CHECK_EQ(pad_width_rank, 2) << "Pad width must be 2D";
+
+  auto pad_width_dim1 = pad_width->shape[0].as<IntImmNode>();
+  auto pad_width_dim2 = pad_width->shape[1].as<IntImmNode>();
+
+  CHECK(pad_width_dim1->value == data_rank && pad_width_dim2->value == 2)
+      << "Pad width must have shape (N, 2), where N is the rank of input data";
+
+  const PadAttrs* param = attrs.as<PadAttrs>();
+  CHECK(param != nullptr);
+
+  std::vector<IndexExpr> oshape;
+  for (int i = 0; i < data_rank; i++) {
+    oshape.push_back(Any());
+  }
+
+  reporter->Assign(types[3], TensorType(oshape, data->dtype));
+  return true;
+}
+
+Array<te::Tensor> PadCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
+                             const Type& out_type) {
+  const auto* param = attrs.as<PadAttrs>();
+  CHECK(param);
+
+  auto data = inputs[0];
+  auto pad_width = inputs[1];
+
+  const PrimExpr& pad_value = inputs[2](Array<PrimExpr>());
+
+  Array<IndexExpr> pad_before;
+  Array<IndexExpr> pad_after;
+
+  for (int i = 0; i < pad_width->shape[0].as<IntImmNode>()->value; ++i) {
+    pad_before.push_back(pad_width[i][0]);
+    pad_after.push_back(pad_width[i][1]);
+  }
+
+  const auto* out_ttype = out_type.as<TensorTypeNode>();
+  CHECK(out_ttype != nullptr);
+
+  return Array<te::Tensor>{topi::pad(inputs[0], pad_before, pad_after, pad_value, "T_pad",
+                                     topi::kElementWise, param->pad_mode,
+                                     &out_type.as<TensorTypeNode>()->shape)};
+}
+
+// Handler to create a call to the padding op used by front-end FFI
+Expr MakePad(Expr data, Expr pad_width, Expr pad_value, String pad_mode) {
+  auto attrs = make_object<PadAttrs>();
+  attrs->pad_mode = std::move(pad_mode);
+  static const Op& op = Op::Get("dyn.nn.pad");
+  return Call(op, {data, pad_width, pad_value}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_GLOBAL("relay.op.dyn.nn._make.pad").set_body_typed(MakePad);
+
+RELAY_REGISTER_OP("dyn.nn.pad")
+    .describe(R"code(Pad for n-D tensor.
+
+)code" TVM_ADD_FILELINE)
+    .set_attrs_type<PadAttrs>()
+    .set_num_inputs(3)
+    .add_argument("data", "Tensor", "Tensor that will be padded")
+    .add_argument("pad_width", "Tensor", "Tensor of how much to pad by")
+    .add_argument("pad_val", "double", "The value to fill the padded area with")
+    .set_support_level(2)
+    .add_type_rel("DynamicPad", PadRel)
+    .set_attr<TOpPattern>("TOpPattern", kInjective)
+    .set_attr<FTVMCompute>("FTVMCompute", PadCompute);
+
+}  // namespace dyn
+}  // namespace relay
+}  // namespace tvm
index 1e17bbe..c759be3 100644 (file)
@@ -54,7 +54,7 @@ Expr MakeLayoutTransform(Expr data, String src_layout, String dst_layout);
 
 Expr MakeOnes(Array<Integer> shape, DataType dtype);
 
-Expr MakePad(Expr data, Array<Array<IndexExpr>> pad_width, double pad_value, String pad_mode);
+Expr MakePad(Expr data, Array<Array<Integer>> pad_width, double pad_value, String pad_mode);
 
 Expr MakeReduce(Expr data, Array<Integer> axis, bool keepdims, bool exclude, String op_name);
 
index d710360..45447e1 100644 (file)
@@ -53,7 +53,7 @@ Array<Array<Layout>> PadInferCorrectLayout(const Attrs& attrs, const Array<Layou
     // split.
 
     // 1) Create a map from axis to param_width using old layout.
-    std::map<std::string, tvm::Array<tvm::PrimExpr>> axis_pad_width;
+    std::map<std::string, tvm::Array<Integer>> axis_pad_width;
     int index_counter = 0;
     CHECK_EQ(new_in_layouts.size(), 1);
     CHECK_EQ(old_in_layouts.size(), 1);
@@ -64,7 +64,7 @@ Array<Array<Layout>> PadInferCorrectLayout(const Attrs& attrs, const Array<Layou
     }
 
     // 2) Create new pad width by walking over the new layout and using the map.
-    tvm::Array<tvm::Array<tvm::PrimExpr>> new_pad_width;
+    tvm::Array<tvm::Array<Integer>> new_pad_width;
     for (auto iter_var : new_in_layouts[0]->axes) {
       const auto& new_layout_axis = LayoutAxis::Get(iter_var);
       auto axis_name = new_layout_axis.name();
@@ -178,7 +178,7 @@ Array<te::Tensor> PadCompute(const Attrs& attrs, const Array<te::Tensor>& inputs
 }
 
 // Handler to create a call to the padding op used by front-end FFI
-Expr MakePad(Expr data, Array<Array<IndexExpr>> pad_width, double pad_value, String pad_mode) {
+Expr MakePad(Expr data, Array<Array<Integer>> pad_width, double pad_value, String pad_mode) {
   auto attrs = make_object<PadAttrs>();
   attrs->pad_value = pad_value;
   attrs->pad_width = std::move(pad_width);
index 0ccc4c3..3de773e 100644 (file)
@@ -124,6 +124,21 @@ class DynamicToStaticMutator : public MixedModeMutator {
            }
            return Expr(nullptr);
          }},
+        {Op::Get("dyn.nn.pad"),
+         [](const CallNode* call_node) {
+           const ConstantNode* pad_width = call_node->args[1].as<ConstantNode>();
+           const ConstantNode* pad_fill = call_node->args[2].as<ConstantNode>();
+           if (pad_width && pad_fill) {
+             CHECK_EQ(pad_fill->data->ndim, 0);   // pad_val is 1d
+             CHECK_EQ(pad_width->data->ndim, 2);  // pad_width is 2d
+
+             const PadAttrs* param = call_node->attrs.as<PadAttrs>();
+             CHECK(param);
+             return MakePad(call_node->args[0], ToMatrix(pad_width->data), ToScalar(pad_fill->data),
+                            param->pad_mode);
+           }
+           return Expr(nullptr);
+         }},
     };
   }
 
index 0b64846..f493720 100644 (file)
@@ -419,7 +419,7 @@ static inline long double ToScalar(const runtime::NDArray& array, size_t i = 0)
  */
 static inline Array<Integer> ToVector(const runtime::NDArray& array) {
   size_t ndim = array.Shape().size();
-  CHECK_EQ(ndim, 1) << "This function should only used for shape tensor.";
+  CHECK_EQ(ndim, 1) << "This function should only be used for 1D NDArrays";
   size_t len = array.Shape().front();
   Array<Integer> out;
   for (size_t i = 0; i < len; ++i) {
@@ -429,6 +429,30 @@ static inline Array<Integer> ToVector(const runtime::NDArray& array) {
   return out;
 }
 
+/*!
+ * \brief Convert a NDArray with type int or float to Array<Array<Integer>>.
+ * \param array Input NDArray
+ * \return Converted Array.
+ */
+static inline Array<Array<Integer>> ToMatrix(const runtime::NDArray& array) {
+  size_t ndim = array.Shape().size();
+  CHECK_EQ(ndim, 2) << "This function should only used for 2D NDArrays";
+  size_t dim1 = array.Shape().at(0);
+  size_t dim2 = array.Shape().at(1);
+
+  Array<Array<Integer>> out;
+
+  for (size_t i = 0; i < dim1; ++i) {
+    Array<Integer> inner_out;
+    for (size_t j = 0; j < dim2; ++j) {
+      double elem_val = ToScalar(array, i * dim2 + j);
+      inner_out.push_back(Integer(static_cast<int>(elem_val)));
+    }
+    out.push_back(inner_out);
+  }
+  return out;
+}
+
 inline Expr GetField(Expr t, size_t i) { return TupleGetItem(t, i); }
 
 inline Expr Pair(Expr l, Expr r) { return Tuple({l, r}); }
@@ -629,7 +653,11 @@ static inline Expr AvgPool2D(Expr data, Array<IndexExpr> pool_size, Array<IndexE
 
 static inline Expr Pad(Expr data, Array<Array<IndexExpr>> pad_width, double pad_value,
                        std::string pad_mode) {
-  return MakePad(data, pad_width, pad_value, pad_mode);
+  Array<Array<Integer>> pad_width_int;
+  for (size_t i = 0; i < pad_width.size(); ++i) {
+    pad_width_int.push_back(CheckConstantShapeArrayInteger(pad_width[i]));
+  }
+  return MakePad(data, pad_width_int, pad_value, pad_mode);
 }
 
 static inline Expr Tile(Expr data, Array<Integer> reps) { return MakeTile(data, reps); }
diff --git a/tests/python/relay/dyn/test_dynamic_op_level2.py b/tests/python/relay/dyn/test_dynamic_op_level2.py
new file mode 100644 (file)
index 0000000..137febd
--- /dev/null
@@ -0,0 +1,68 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+""" Support level2 dynamic operator test cases.
+"""
+
+import numpy as np
+import tvm
+from tvm import relay
+from tvm import te
+from tvm.relay.testing import ctx_list
+import random
+from test_dynamic_op_level3 import verify_func
+import tvm.topi.testing
+from tvm.relay.testing import run_infer_type
+
+def test_dyn_pad():
+    def verify_pad(dshape, pad_width, pad_val, dtype):
+        x = relay.var("x", relay.TensorType(dshape, dtype))
+        ndim = len(dshape)
+        pad_width_var = relay.var("pad_width_var", relay.TensorType((ndim, 2), 'int64'))
+        pad_val_var = relay.var("pad_val_var", relay.TensorType((), dtype))
+        y = relay.nn.pad(x, pad_width_var, pad_val_var)
+        yy = run_infer_type(y)
+
+        assert yy.checked_type == relay.ty.TensorType((relay.Any(),) * ndim, dtype)
+        func = relay.Function([x, pad_width_var, pad_val_var], y)
+        data = np.random.uniform(size=dshape).astype(dtype)
+        ref_res = np.pad(data, pad_width, 'constant', constant_values=(((pad_val,)*2),) * ndim)
+        pad_width = np.array(pad_width).astype('int64')
+
+        verify_func(func, [data, pad_width, np.array(pad_val).astype(dtype)], ref_res)
+
+    def verify_pad_default_fill(dshape, pad_width, dtype):
+        x = relay.var("x", relay.TensorType(dshape, dtype))
+        ndim = len(dshape)
+        pad_width_var = relay.var("pad_width_var", relay.TensorType((ndim, 2), 'int64'))
+        y = relay.nn.pad(x, pad_width_var)
+        yy = run_infer_type(y)
+
+        assert yy.checked_type == relay.ty.TensorType((relay.Any(),) * ndim, dtype)
+        func = relay.Function([x, pad_width_var], y)
+        data = np.random.uniform(size=dshape).astype(dtype)
+        ref_res = np.pad(data, pad_width)
+        pad_width = np.array(pad_width).astype('int64')
+
+        verify_func(func, [data, pad_width], ref_res)
+
+    verify_pad((4, 10, 7, 7), ((1, 1), (2, 2), (3, 3), (4, 4)), 2.0, "int32")
+    verify_pad((2, 7), ((1, 4), (2, 2)), 4.0, "float64")
+    verify_pad_default_fill((4, 10, 7, 7), ((1, 1), (2, 2), (3, 3), (4, 4)), "float64")
+    verify_pad_default_fill((2, 7), ((1, 4), (2, 2)), "int32")
+
+if __name__ == "__main__":
+    test_dyn_pad()
index c61f169..ed9b94c 100644 (file)
@@ -23,7 +23,6 @@ from tvm.relay.build_module import bind_params_by_name
 from tvm.relay.testing import run_infer_type, create_workload, ctx_list
 import tvm.topi.testing
 
-
 def run_opt_pass(expr, opt_pass):
     assert isinstance(opt_pass, tvm.transform.Pass)
 
@@ -312,7 +311,7 @@ def test_dynamic_to_static_full():
         
         zz = func2.body
         assert isinstance(zz, relay.Call)
-        assert zz.checked_type == relay.TensorType(fill_shape, dtype)
+        assert zz.op == relay.op.get("full")
 
         ref_res = np.full(fill_shape, fill_value).astype(dtype)
         y_data = np.random.uniform(low=-1, high=1, size=fill_shape).astype('int64')
@@ -321,6 +320,24 @@ def test_dynamic_to_static_full():
     verify_full(4, (1, 2, 3, 4), 'int32')
     verify_full(4.0, (1, 2, 8, 10), 'float32')
 
+def test_dynamic_to_static_pad():
+    def verify_pad(data_shape, pad_width, pad_val, dtype):
+        x = relay.var("x", relay.TensorType(data_shape, dtype))
+        z = relay.nn.pad(x, relay.const(np.array(pad_width)), pad_val)
+        func = run_infer_type(relay.Function([x], z))
+        func2 = run_opt_pass(run_opt_pass(func, transform.DynamicToStatic()), transform.InferType())
+        zz = func2.body
+        assert isinstance(zz, relay.Call)
+        assert zz.op == relay.op.get("nn.pad")
+
+        x_data = np.random.uniform(size=data_shape).astype(dtype)
+        ref_res = np.pad(x_data, pad_width, 'constant', constant_values=(((pad_val,)*2),) * len(data_shape))
+        verify_func(func2, [x_data], ref_res)
+
+    verify_pad((4, 10, 7, 7), ((1, 1), (2, 2), (3, 3), (4, 4)), 2.0, "int32")
+    verify_pad((2, 7), ((1, 4), (2, 2)), 4.0, "float64")
+
+
 if __name__ == "__main__":
     test_dynamic_to_static_reshape()
     test_dynamic_to_static_double_reshape()
@@ -332,3 +349,4 @@ if __name__ == "__main__":
     test_dynamic_to_static_resize()
     test_dynamic_to_static_one_hot()
     test_dynamic_to_static_full()
+    test_dynamic_to_static_pad()