.set_default(NullValue<IndexExpr>());
TVM_ATTR_FIELD(kernel_size)
.describe("Specifies the dimensions of the convolution window.")
- .set_default(NullValue<Array<IndexExpr> >());
+ .set_default(NullValue<Array<IndexExpr>>());
TVM_ATTR_FIELD(data_layout)
.set_default("NCW")
.describe(
.set_default(NullValue<IndexExpr>());
TVM_ATTR_FIELD(kernel_size)
.describe("Specifies the dimensions of the convolution window.")
- .set_default(NullValue<Array<IndexExpr> >());
+ .set_default(NullValue<Array<IndexExpr>>());
TVM_ATTR_FIELD(data_layout)
.set_default("NCHW")
.describe(
.set_default(NullValue<IndexExpr>());
TVM_ATTR_FIELD(kernel_size)
.describe("Specifies the dimensions of the convolution window.")
- .set_default(NullValue<Array<IndexExpr> >());
+ .set_default(NullValue<Array<IndexExpr>>());
TVM_ATTR_FIELD(data_layout)
.set_default("NCHW")
.describe(
.set_default(NullValue<IndexExpr>());
TVM_ATTR_FIELD(kernel_size)
.describe("Specifies the dimensions of the convolution window.")
- .set_default(NullValue<Array<IndexExpr> >());
+ .set_default(NullValue<Array<IndexExpr>>());
TVM_ATTR_FIELD(data_layout)
.set_default("NCDHW")
.describe(
"i.e. the number of output channels in the convolution.");
TVM_ATTR_FIELD(kernel_size)
.describe("The dimensions of the convolution window.")
- .set_default(NullValue<Array<IndexExpr> >());
+ .set_default(NullValue<Array<IndexExpr>>());
TVM_ATTR_FIELD(strides)
.set_default(Array<IndexExpr>({1, 1, 1}))
.describe("The strides of the convolution.");
.set_default(NullValue<IndexExpr>());
TVM_ATTR_FIELD(kernel_size)
.describe("Specifies the dimensions of the convolution window.")
- .set_default(NullValue<Array<IndexExpr> >());
+ .set_default(NullValue<Array<IndexExpr>>());
TVM_ATTR_FIELD(data_layout)
.set_default("NCDHW")
.describe(
"i.e. the number of output channels in the convolution.");
TVM_ATTR_FIELD(kernel_size)
.describe("The dimensions of the convolution window.")
- .set_default(NullValue<Array<IndexExpr> >());
+ .set_default(NullValue<Array<IndexExpr>>());
TVM_ATTR_FIELD(strides)
.set_default(Array<IndexExpr>({1, 1}))
.describe("The strides of the convolution.");
"i.e. the number of output channels in the convolution.");
TVM_ATTR_FIELD(kernel_size)
.describe("The dimensions of the convolution window.")
- .set_default(NullValue<Array<IndexExpr> >());
+ .set_default(NullValue<Array<IndexExpr>>());
TVM_ATTR_FIELD(strides)
.set_default(Array<IndexExpr>({1}))
.describe("The strides of the convolution.");
/*! \brief Attributes used for the padding operator */
struct PadAttrs : public tvm::AttrsNode<PadAttrs> {
double pad_value;
- Array<Array<IndexExpr> > pad_width;
+ Array<Array<Integer>> pad_width;
std::string pad_mode;
TVM_DECLARE_ATTRS(PadAttrs, "relay.attrs.PadAttrs") {
/*! \brief Attributes used for the MirrorPadding operator */
struct MirrorPadAttrs : public tvm::AttrsNode<MirrorPadAttrs> {
std::string mode;
- Array<Array<IndexExpr> > pad_width;
+ Array<Array<IndexExpr>> pad_width;
TVM_DECLARE_ATTRS(MirrorPadAttrs, "relay.attrs.MirrorPadAttrs") {
TVM_ATTR_FIELD(mode)
.set_default(NullValue<IndexExpr>());
TVM_ATTR_FIELD(kernel_size)
.describe("Specifies the dimensions of the convolution window.")
- .set_default(NullValue<Array<IndexExpr> >());
+ .set_default(NullValue<Array<IndexExpr>>());
TVM_ATTR_FIELD(data_layout)
.set_default("NCHW")
.describe(
* "constant" pads with constant_value;
* "edge" pads using the edge values of the input array;
* "reflect" pads by reflecting values with respect to the edges.
+ * \param dyn_output_shape Output shape of the pad op, default nullptr.
+ * You only need to pass this in if the shape was evaluated dynamically.
* \param name The name of the operation
* \param tag The tag to mark the operation
*
inline tvm::te::Tensor pad(const tvm::te::Tensor& t, const tvm::Array<tvm::PrimExpr>& pad_before,
tvm::Array<tvm::PrimExpr> pad_after = tvm::Array<tvm::PrimExpr>(),
PrimExpr pad_value = PrimExpr(), std::string name = "T_pad",
- std::string tag = kElementWise, std::string pad_mode = "constant") {
+ std::string tag = kElementWise, std::string pad_mode = "constant",
+ const Array<PrimExpr>* dyn_output_shape = nullptr) {
if (pad_after.size() < pad_before.size()) {
for (size_t i = pad_after.size(); i < pad_before.size(); ++i) {
pad_after.push_back(pad_before[i]);
}
}
+
arith::Analyzer analyzer;
CHECK_GE(pad_before.size(), 1);
CHECK_EQ(pad_before.size(), pad_after.size());
- tvm::Array<tvm::PrimExpr> output_shape;
tvm::Array<tvm::PrimExpr> pad_before_int32;
tvm::Array<tvm::PrimExpr> pad_after_int32;
+
for (const auto& ele : pad_before) {
pad_before_int32.push_back(tvm::cast(tvm::DataType::Int(32), ele));
}
for (const auto& ele : pad_after) {
pad_after_int32.push_back(tvm::cast(tvm::DataType::Int(32), ele));
}
- for (size_t i = 0; i < t->shape.size(); ++i) {
- if (i >= pad_before.size()) {
- output_shape.push_back(t->shape[i]);
- } else {
- output_shape.push_back(
- analyzer.Simplify(t->shape[i] + pad_before_int32[i] + pad_after_int32[i]));
+
+ tvm::Array<tvm::PrimExpr> output_shape;
+ if (dyn_output_shape == nullptr) {
+ for (size_t i = 0; i < t->shape.size(); ++i) {
+ if (i >= pad_before.size()) {
+ output_shape.push_back(t->shape[i]);
+ } else {
+ output_shape.push_back(
+ analyzer.Simplify(t->shape[i] + pad_before_int32[i] + pad_after_int32[i]));
+ }
+ }
+ } else {
+ for (size_t i = 0; i < dyn_output_shape->size(); i++) {
+ output_shape.push_back((*dyn_output_shape)[i]);
}
}
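As a rough illustration of the branch added above (plain Python, names hypothetical, not part of the patch): when a precomputed dynamic output shape is supplied, topi::pad uses it verbatim instead of deriving the extents from pad_before/pad_after.

    # Sketch of the output-shape selection in topi::pad (illustrative only).
    def pad_output_shape(in_shape, pad_before, pad_after, dyn_output_shape=None):
        if dyn_output_shape is not None:
            # Trust the shape computed elsewhere (e.g. by a relay shape function).
            return list(dyn_output_shape)
        # Dims not covered by the pad lists keep their original extent.
        return [d + (pad_before[i] + pad_after[i] if i < len(pad_before) else 0)
                for i, d in enumerate(in_shape)]

    # pad_output_shape([2, 7], [1, 2], [4, 2]) == [7, 11]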
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=wildcard-import, redefined-builtin, invalid-name
+"""The Relay namespace containing dynamic ops."""
+
+from . import _nn
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Constructor APIs"""
+import tvm._ffi
+
+tvm._ffi._init_api("relay.op.dyn.nn._make", __name__)
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=no-else-return, invalid-name, unused-argument, too-many-arguments, consider-using-in
+"""Backend compiler related feature registration"""
+
+from __future__ import absolute_import
+
+from tvm.te.hybrid import script
+from ...op import register_shape_func
+from ...op import register_broadcast_schedule
+
+# pad
+register_broadcast_schedule("dyn.nn.pad")
+
+#####################
+# Shape functions #
+#####################
+
+@script
+def _dyn_pad_shape_func(data, pad_width):
+ ndim = len(data.shape)
+ out = output_tensor((ndim,), "int64")
+ for i in const_range(ndim):
+ out[i] = int64(pad_width[i, 0] + pad_width[i, 1] + data.shape[i])
+ return out
+
+@register_shape_func("dyn.nn.pad", True)
+def pad_shape_func(attrs, inputs, data):
+ """
+ Shape function for dynamic pad op.
+ """
+ return [_dyn_pad_shape_func(inputs[0], inputs[1])]
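A quick numeric check of the shape function above (plain NumPy, outside the hybrid-script runtime; values are illustrative): each output extent is before-pad + after-pad + input extent.

    import numpy as np
    data_shape = np.array([2, 7], dtype="int64")
    pad_width = np.array([[1, 4], [2, 2]], dtype="int64")  # shape (ndim, 2)
    out_shape = pad_width[:, 0] + pad_width[:, 1] + data_shape
    # out_shape == array([7, 11])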
from tvm.relay import expr
from . import _make
+from ..dyn.nn import _make as _dyn_make
from .util import get_pad_tuple1d, get_pad_tuple2d, get_pad_tuple3d
+from ...expr import const, Expr
def conv1d(data,
def pad(data,
pad_width,
- pad_value=0.0,
+ pad_value=0,
pad_mode='constant'):
r"""Padding
----------
data: tvm.relay.Expr
The input data to the operator
- pad_width: tuple of <tuple of <int>>, required
+ pad_width: tuple of <tuple of <int>>, or tvm.relay.Expr, required
Number of values padded to the edges of each axis, in the format
of ((before_1, after_1), ..., (before_N, after_N))
- pad_value: float, optional, default=0.0
+ pad_value: float, or tvm.relay.Expr, optional, default=0
The value used for padding
pad_mode: 'constant', 'edge', 'reflect'
'constant' pads with constant_value pad_value
result : tvm.relay.Expr
The computed result.
"""
+ if isinstance(pad_width, Expr) or isinstance(pad_value, Expr):
+ if not isinstance(pad_width, Expr):
+ pad_width = const(list(pad_width))
+ if not isinstance(pad_value, Expr):
+ pad_value = const(pad_value)
+ return _dyn_make.pad(data, pad_width, pad_value, pad_mode)
return _make.pad(data, pad_width, pad_value, pad_mode)
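A minimal usage sketch of the dispatch above (tensor names are illustrative): a plain tuple keeps the static nn.pad path, while passing pad_width as a Relay expression routes the call to dyn.nn.pad.

    import numpy as np
    from tvm import relay

    x = relay.var("x", shape=(2, 7), dtype="float32")
    y_static = relay.nn.pad(x, ((1, 1), (2, 2)))  # plain tuple -> nn.pad
    pad_w = relay.const(np.array([[1, 1], [2, 2]], dtype="int64"))
    y_dynamic = relay.nn.pad(x, pad_w)            # Relay expr -> dyn.nn.pad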
from tvm.tir import op
return getattr(op, func_id)(*args)
-sqrt = log = exp = tanh = sigmoid = power = popcount = _math_intrin #pylint: disable=invalid-name
+sqrt = log = exp = tanh = sigmoid = power = popcount = round = _math_intrin #pylint: disable=invalid-name
def _min_max(func_id, args):
'exp' : numpy.exp,
'sigmoid' : sigmoid,
'popcount' : popcount,
+ 'round' : round,
'likely' : lambda cond: cond,
'uint8' : numpy.uint8,
'uint16' : numpy.uint16,
elif layout == "NHWC":
out_shape = (simplify(topi.cast(te.round(data.shape[1] * scale_h), data.shape[1].dtype)),
simplify(topi.cast(te.round(data.shape[2] * scale_w), data.shape[2].dtype)))
-
else:
raise ValueError("not support this layout {} yet".format(layout))
coord_trans = "align_corners" if align_corners else "asymmetric"
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file pad.cc
+ * \brief Implementation of dynamic pad
+ */
+#include <tvm/relay/attrs/nn.h>
+#include <tvm/relay/op.h>
+#include <tvm/tir/data_layout.h>
+#include <tvm/tir/op.h>
+#include <tvm/topi/nn.h>
+
+#include <vector>
+
+#include "../../make_op.h"
+#include "../../op_common.h"
+
+namespace tvm {
+namespace relay {
+namespace dyn {
+
+// relay.dyn.nn.pad
+
+bool PadRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+ const TypeReporter& reporter) {
+ // types = [data_type, pad_width_type, pad_value_type, ret_type]
+ CHECK_EQ(types.size(), 4);
+ const auto* data = types[0].as<TensorTypeNode>();
+ if (data == nullptr) return false;
+
+ const auto* pad_width = types[1].as<TensorTypeNode>();
+ if (pad_width == nullptr) return false;
+
+ const auto* pad_value = types[2].as<TensorTypeNode>();
+ if (pad_value == nullptr) return false;
+
+ int data_rank = data->shape.size();
+ CHECK(data_rank) << "Data shape must have static rank";
+
+ int pad_width_rank = pad_width->shape.size();
+ CHECK_EQ(pad_width_rank, 2) << "Pad width must be 2D";
+
+ auto pad_width_dim1 = pad_width->shape[0].as<IntImmNode>();
+ auto pad_width_dim2 = pad_width->shape[1].as<IntImmNode>();
+
+ CHECK(pad_width_dim1->value == data_rank && pad_width_dim2->value == 2)
+ << "Pad width must have shape (N, 2), where N is the rank of input data";
+
+ const PadAttrs* param = attrs.as<PadAttrs>();
+ CHECK(param != nullptr);
+
+ std::vector<IndexExpr> oshape;
+ for (int i = 0; i < data_rank; i++) {
+ oshape.push_back(Any());
+ }
+
+ reporter->Assign(types[3], TensorType(oshape, data->dtype));
+ return true;
+}
+
+Array<te::Tensor> PadCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
+ const Type& out_type) {
+ const auto* param = attrs.as<PadAttrs>();
+ CHECK(param);
+
+ auto data = inputs[0];
+ auto pad_width = inputs[1];
+
+ const PrimExpr& pad_value = inputs[2](Array<PrimExpr>());
+
+ Array<IndexExpr> pad_before;
+ Array<IndexExpr> pad_after;
+
+ for (int i = 0; i < pad_width->shape[0].as<IntImmNode>()->value; ++i) {
+ pad_before.push_back(pad_width[i][0]);
+ pad_after.push_back(pad_width[i][1]);
+ }
+
+ const auto* out_ttype = out_type.as<TensorTypeNode>();
+ CHECK(out_ttype != nullptr);
+
+ return Array<te::Tensor>{topi::pad(inputs[0], pad_before, pad_after, pad_value, "T_pad",
+ topi::kElementWise, param->pad_mode,
+ &out_type.as<TensorTypeNode>()->shape)};
+}
+
+// Handler to create a call to the padding op used by front-end FFI
+Expr MakePad(Expr data, Expr pad_width, Expr pad_value, String pad_mode) {
+ auto attrs = make_object<PadAttrs>();
+ attrs->pad_mode = std::move(pad_mode);
+ static const Op& op = Op::Get("dyn.nn.pad");
+ return Call(op, {data, pad_width, pad_value}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_GLOBAL("relay.op.dyn.nn._make.pad").set_body_typed(MakePad);
+
+RELAY_REGISTER_OP("dyn.nn.pad")
+ .describe(R"code(Pad for n-D tensor.
+
+)code" TVM_ADD_FILELINE)
+ .set_attrs_type<PadAttrs>()
+ .set_num_inputs(3)
+ .add_argument("data", "Tensor", "Tensor that will be padded")
+ .add_argument("pad_width", "Tensor", "Tensor of how much to pad by")
+ .add_argument("pad_val", "double", "The value to fill the padded area with")
+ .set_support_level(2)
+ .add_type_rel("DynamicPad", PadRel)
+ .set_attr<TOpPattern>("TOpPattern", kInjective)
+ .set_attr<FTVMCompute>("FTVMCompute", PadCompute);
+
+} // namespace dyn
+} // namespace relay
+} // namespace tvm
Expr MakeOnes(Array<Integer> shape, DataType dtype);
-Expr MakePad(Expr data, Array<Array<IndexExpr>> pad_width, double pad_value, String pad_mode);
+Expr MakePad(Expr data, Array<Array<Integer>> pad_width, double pad_value, String pad_mode);
Expr MakeReduce(Expr data, Array<Integer> axis, bool keepdims, bool exclude, String op_name);
// split.
// 1) Create a map from axis to param_width using old layout.
- std::map<std::string, tvm::Array<tvm::PrimExpr>> axis_pad_width;
+ std::map<std::string, tvm::Array<Integer>> axis_pad_width;
int index_counter = 0;
CHECK_EQ(new_in_layouts.size(), 1);
CHECK_EQ(old_in_layouts.size(), 1);
}
// 2) Create new pad width by walking over the new layout and using the map.
- tvm::Array<tvm::Array<tvm::PrimExpr>> new_pad_width;
+ tvm::Array<tvm::Array<Integer>> new_pad_width;
for (auto iter_var : new_in_layouts[0]->axes) {
const auto& new_layout_axis = LayoutAxis::Get(iter_var);
auto axis_name = new_layout_axis.name();
}
// Handler to create a call to the padding op used by front-end FFI
-Expr MakePad(Expr data, Array<Array<IndexExpr>> pad_width, double pad_value, String pad_mode) {
+Expr MakePad(Expr data, Array<Array<Integer>> pad_width, double pad_value, String pad_mode) {
auto attrs = make_object<PadAttrs>();
attrs->pad_value = pad_value;
attrs->pad_width = std::move(pad_width);
}
return Expr(nullptr);
}},
+ {Op::Get("dyn.nn.pad"),
+ [](const CallNode* call_node) {
+ const ConstantNode* pad_width = call_node->args[1].as<ConstantNode>();
+ const ConstantNode* pad_fill = call_node->args[2].as<ConstantNode>();
+ if (pad_width && pad_fill) {
+ CHECK_EQ(pad_fill->data->ndim, 0); // pad_val is a 0-d scalar
+ CHECK_EQ(pad_width->data->ndim, 2); // pad_width is 2d
+
+ const PadAttrs* param = call_node->attrs.as<PadAttrs>();
+ CHECK(param);
+ return MakePad(call_node->args[0], ToMatrix(pad_width->data), ToScalar(pad_fill->data),
+ param->pad_mode);
+ }
+ return Expr(nullptr);
+ }},
};
}
*/
static inline Array<Integer> ToVector(const runtime::NDArray& array) {
size_t ndim = array.Shape().size();
- CHECK_EQ(ndim, 1) << "This function should only used for shape tensor.";
+ CHECK_EQ(ndim, 1) << "This function should only be used for 1D NDArrays";
size_t len = array.Shape().front();
Array<Integer> out;
for (size_t i = 0; i < len; ++i) {
return out;
}
+/*!
+ * \brief Convert a NDArray with type int or float to Array<Array<Integer>>.
+ * \param array Input NDArray
+ * \return Converted Array.
+ */
+static inline Array<Array<Integer>> ToMatrix(const runtime::NDArray& array) {
+ size_t ndim = array.Shape().size();
+ CHECK_EQ(ndim, 2) << "This function should only be used for 2D NDArrays";
+ size_t dim1 = array.Shape().at(0);
+ size_t dim2 = array.Shape().at(1);
+
+ Array<Array<Integer>> out;
+
+ for (size_t i = 0; i < dim1; ++i) {
+ Array<Integer> inner_out;
+ for (size_t j = 0; j < dim2; ++j) {
+ double elem_val = ToScalar(array, i * dim2 + j);
+ inner_out.push_back(Integer(static_cast<int>(elem_val)));
+ }
+ out.push_back(inner_out);
+ }
+ return out;
+}
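For reference, the same row-major conversion sketched in plain Python (purely illustrative, not part of the change):

    import numpy as np
    arr = np.array([[1, 1], [2, 2]])
    dim1, dim2 = arr.shape
    flat = arr.reshape(-1)
    matrix = [[int(flat[i * dim2 + j]) for j in range(dim2)] for i in range(dim1)]
    # matrix == [[1, 1], [2, 2]]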
+
inline Expr GetField(Expr t, size_t i) { return TupleGetItem(t, i); }
inline Expr Pair(Expr l, Expr r) { return Tuple({l, r}); }
static inline Expr Pad(Expr data, Array<Array<IndexExpr>> pad_width, double pad_value,
std::string pad_mode) {
- return MakePad(data, pad_width, pad_value, pad_mode);
+ Array<Array<Integer>> pad_width_int;
+ for (size_t i = 0; i < pad_width.size(); ++i) {
+ pad_width_int.push_back(CheckConstantShapeArrayInteger(pad_width[i]));
+ }
+ return MakePad(data, pad_width_int, pad_value, pad_mode);
}
static inline Expr Tile(Expr data, Array<Integer> reps) { return MakeTile(data, reps); }
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+""" Support level2 dynamic operator test cases.
+"""
+
+import numpy as np
+import tvm
+from tvm import relay
+from tvm import te
+from tvm.relay.testing import ctx_list
+import random
+from test_dynamic_op_level3 import verify_func
+import tvm.topi.testing
+from tvm.relay.testing import run_infer_type
+
+def test_dyn_pad():
+ def verify_pad(dshape, pad_width, pad_val, dtype):
+ x = relay.var("x", relay.TensorType(dshape, dtype))
+ ndim = len(dshape)
+ pad_width_var = relay.var("pad_width_var", relay.TensorType((ndim, 2), 'int64'))
+ pad_val_var = relay.var("pad_val_var", relay.TensorType((), dtype))
+ y = relay.nn.pad(x, pad_width_var, pad_val_var)
+ yy = run_infer_type(y)
+
+ assert yy.checked_type == relay.ty.TensorType((relay.Any(),) * ndim, dtype)
+ func = relay.Function([x, pad_width_var, pad_val_var], y)
+ data = np.random.uniform(size=dshape).astype(dtype)
+ ref_res = np.pad(data, pad_width, 'constant', constant_values=(((pad_val,)*2),) * ndim)
+ pad_width = np.array(pad_width).astype('int64')
+
+ verify_func(func, [data, pad_width, np.array(pad_val).astype(dtype)], ref_res)
+
+ def verify_pad_default_fill(dshape, pad_width, dtype):
+ x = relay.var("x", relay.TensorType(dshape, dtype))
+ ndim = len(dshape)
+ pad_width_var = relay.var("pad_width_var", relay.TensorType((ndim, 2), 'int64'))
+ y = relay.nn.pad(x, pad_width_var)
+ yy = run_infer_type(y)
+
+ assert yy.checked_type == relay.ty.TensorType((relay.Any(),) * ndim, dtype)
+ func = relay.Function([x, pad_width_var], y)
+ data = np.random.uniform(size=dshape).astype(dtype)
+ ref_res = np.pad(data, pad_width)
+ pad_width = np.array(pad_width).astype('int64')
+
+ verify_func(func, [data, pad_width], ref_res)
+
+ verify_pad((4, 10, 7, 7), ((1, 1), (2, 2), (3, 3), (4, 4)), 2.0, "int32")
+ verify_pad((2, 7), ((1, 4), (2, 2)), 4.0, "float64")
+ verify_pad_default_fill((4, 10, 7, 7), ((1, 1), (2, 2), (3, 3), (4, 4)), "float64")
+ verify_pad_default_fill((2, 7), ((1, 4), (2, 2)), "int32")
+
+if __name__ == "__main__":
+ test_dyn_pad()
from tvm.relay.testing import run_infer_type, create_workload, ctx_list
import tvm.topi.testing
-
def run_opt_pass(expr, opt_pass):
assert isinstance(opt_pass, tvm.transform.Pass)
zz = func2.body
assert isinstance(zz, relay.Call)
- assert zz.checked_type == relay.TensorType(fill_shape, dtype)
+ assert zz.op == relay.op.get("full")
ref_res = np.full(fill_shape, fill_value).astype(dtype)
y_data = np.random.uniform(low=-1, high=1, size=fill_shape).astype('int64')
verify_full(4, (1, 2, 3, 4), 'int32')
verify_full(4.0, (1, 2, 8, 10), 'float32')
+def test_dynamic_to_static_pad():
+ def verify_pad(data_shape, pad_width, pad_val, dtype):
+ x = relay.var("x", relay.TensorType(data_shape, dtype))
+ z = relay.nn.pad(x, relay.const(np.array(pad_width)), pad_val)
+ func = run_infer_type(relay.Function([x], z))
+ func2 = run_opt_pass(run_opt_pass(func, transform.DynamicToStatic()), transform.InferType())
+ zz = func2.body
+ assert isinstance(zz, relay.Call)
+ assert zz.op == relay.op.get("nn.pad")
+
+ x_data = np.random.uniform(size=data_shape).astype(dtype)
+ ref_res = np.pad(x_data, pad_width, 'constant', constant_values=(((pad_val,)*2),) * len(data_shape))
+ verify_func(func2, [x_data], ref_res)
+
+ verify_pad((4, 10, 7, 7), ((1, 1), (2, 2), (3, 3), (4, 4)), 2.0, "int32")
+ verify_pad((2, 7), ((1, 4), (2, 2)), 4.0, "float64")
+
+
if __name__ == "__main__":
test_dynamic_to_static_reshape()
test_dynamic_to_static_double_reshape()
test_dynamic_to_static_resize()
test_dynamic_to_static_one_hot()
test_dynamic_to_static_full()
+ test_dynamic_to_static_pad()