// relay.nn.pad
TVM_REGISTER_NODE_TYPE(PadAttrs);
+Array<Array<Layout> > PadInferCorrectLayout(
+ const Attrs& attrs,
+ const Array<Layout>& new_in_layouts,
+ const Array<Layout>& old_in_layouts,
+ const Array<Array<IndexExpr>> &old_in_shapes) {
+ // NOTE: Discard "const" qualifier here.
+ PadAttrs *params = const_cast<PadAttrs*>(attrs.as<PadAttrs>());
+
+ Layout ret;
+ // If new_in_layouts are defined, this code tries to modify the layout.
+ bool is_layout_modified = new_in_layouts.defined();
+ if (new_in_layouts.defined()) {
+ // Create a map of axis to param_width. For the new layout, a new param_width is generated using
+ // the map. The new layout is rejected, if the padding is happening along the axis which was
+ // split.
+
+ // 1) Create a map from axis to param_width using old layout.
+ std::map<std::string, tvm::Array<tvm::Expr>> axis_pad_width;
+ int index_counter = 0;
+ CHECK_EQ(new_in_layouts.size(), 1);
+ CHECK_EQ(old_in_layouts.size(), 1);
+ for (auto iter_var : old_in_layouts[0]->axes) {
+ const auto& old_layout_axis = LayoutAxis::Get(iter_var);
+ axis_pad_width.emplace(old_layout_axis.name(), params->pad_width[index_counter]);
+ index_counter++;
+ }
+
+ // 2) Create new pad width by walking over the new layout and using the map.
+ tvm::Array<tvm::Array<tvm::Expr>> new_pad_width;
+ for (auto iter_var : new_in_layouts[0]->axes) {
+ const auto& new_layout_axis = LayoutAxis::Get(iter_var);
+ auto axis_name = new_layout_axis.name();
+ if (axis_pad_width.count(axis_name) != 0 && new_layout_axis.IsPrimal()) {
+ // This is primal axis. So, directly use the original pad_width.
+ new_pad_width.push_back(axis_pad_width.at(axis_name));
+ } else {
+ // This is the axis that got split. So, check that pad_width was [0, 0] originally.
+ const auto& dual_axis = new_layout_axis.ToPrimal();
+ auto dual_axis_name = dual_axis.name();
+ CHECK(axis_pad_width.count(dual_axis_name))
+ << "Missing axis " << dual_axis << " in " << old_in_layouts[0].name();
+ new_pad_width.push_back(axis_pad_width.at(dual_axis_name));
+
+ // If any pad_width element is not zero, do not change the layout.
+ for (auto width : axis_pad_width.at(dual_axis_name)) {
+ if (auto* width_imm = width.as<IntImm>()) {
+ if (width_imm->value != 0) {
+ is_layout_modified = false;
+ }
+ } else {
+ is_layout_modified = false;
+ }
+ }
+ }
+ }
+
+ // If the above conditions satisfied, we can set the newly created pad_width and use the new
+ // layout.
+ if (is_layout_modified) {
+ ret = new_in_layouts[0];
+ params->pad_width = new_pad_width;
+ }
+ }
+
+ if (!is_layout_modified) {
+ if (old_in_layouts.defined()) {
+ CHECK_EQ(old_in_layouts.size(), 1);
+ ret = old_in_layouts[0];
+ } else {
+ ret = Layout::Undef();
+ }
+ }
+
+ return Array<Array<Layout> >{{ret}, {ret}};
+}
+
bool PadRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(2)
.add_type_rel("Pad", PadRel)
+.set_attr<FInferCorrectLayout>("FInferCorrectLayout", PadInferCorrectLayout)
.set_attr<TOpPattern>("TOpPattern", kInjective)
.set_attr<FTVMCompute>("FTVMCompute", PadCompute);
assert(analysis.alpha_equal(a, b))
+def test_alter_layout_pad():
+ """ Check NCHW, NHWC and corner case for pad layout conversion"""
+ # Register alter op layout. "level" is used to override the previously registered functions.
+ @register_alter_op_layout("nn.conv2d", level=112)
+ def alter_conv2d(attrs, inputs, tinfos):
+ data, weight = inputs
+ new_attrs = dict(attrs)
+ new_attrs['data_layout'] = 'NCHW16c'
+ return relay.nn.conv2d(data, weight, **new_attrs)
+
+ # Check NCHW conversion.
+ def before_nchw():
+ x = relay.var("x", shape=(1, 64, 56, 56))
+ weight1 = relay.var('weight1')
+ y = relay.nn.conv2d(x, weight1,
+ channels=32,
+ kernel_size=(3, 3),
+ padding=(1, 1))
+ ret = relay.nn.pad(y, pad_width=((0, 0), (0, 0), (1, 1), (1, 1)))
+ y = relay.Function(analysis.free_vars(ret), ret)
+ return y
+
+ def expected_nchw():
+ x = relay.var("x", shape=(1, 64, 56, 56))
+ weight1 = relay.var('weight1')
+ y = relay.layout_transform(x, "NCHW", "NCHW16c")
+ y = relay.nn.conv2d(y, weight1,
+ channels=32,
+ kernel_size=(3, 3),
+ padding=(1, 1),
+ data_layout="NCHW16c")
+ ret = relay.nn.pad(y, pad_width=((0, 0), (0, 0), (1, 1), (1, 1), (0, 0)))
+ ret = relay.layout_transform(ret, "NCHW16c", "NCHW")
+ y = relay.Function(analysis.free_vars(ret), ret)
+ return y
+
+ a = before_nchw()
+ a = run_opt_pass(a, transform.AlterOpLayout())
+
+ b = expected_nchw()
+ b = run_opt_pass(b, transform.InferType())
+
+ assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+
+ # Check NHWC conversion.
+ def before_nhwc():
+ x = relay.var("x", shape=(1, 56, 56, 64))
+ weight1 = relay.var('weight1')
+ y = relay.nn.conv2d(x, weight1,
+ channels=32,
+ kernel_size=(3, 3),
+ padding=(1, 1),
+ data_layout='NHWC')
+ ret = relay.nn.pad(y, pad_width=((0, 0), (1, 1), (1, 1), (0, 0)))
+ y = relay.Function(analysis.free_vars(ret), ret)
+ return y
+
+ def expected_nhwc():
+ x = relay.var("x", shape=(1, 56, 56, 64))
+ weight1 = relay.var('weight1')
+ y = relay.layout_transform(x, "NHWC", "NCHW16c")
+ y = relay.nn.conv2d(y, weight1,
+ channels=32,
+ kernel_size=(3, 3),
+ padding=(1, 1),
+ data_layout="NCHW16c")
+ ret = relay.nn.pad(y, pad_width=((0, 0), (0, 0), (1, 1), (1, 1), (0, 0)))
+ ret = relay.layout_transform(ret, "NCHW16c", "NHWC")
+ y = relay.Function(analysis.free_vars(ret), ret)
+ return y
+
+ a = before_nhwc()
+ a = run_opt_pass(a, transform.AlterOpLayout())
+
+ b = expected_nhwc()
+ b = run_opt_pass(b, transform.InferType())
+
+ assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+
+ # Check that conversion does not happen when padding along split axis..
+ def before():
+ x = relay.var("x", shape=(1, 64, 56, 56))
+ weight1 = relay.var('weight1')
+ y = relay.nn.conv2d(x, weight1,
+ channels=32,
+ kernel_size=(3, 3),
+ padding=(1, 1))
+ ret = relay.nn.pad(y, pad_width=((0, 0), (1, 1), (1, 1), (1, 1)))
+ y = relay.Function(analysis.free_vars(ret), ret)
+ return y
+
+ def expected():
+ x = relay.var("x", shape=(1, 64, 56, 56))
+ weight1 = relay.var('weight1')
+ y = relay.layout_transform(x, "NCHW", "NCHW16c")
+ y = relay.nn.conv2d(y, weight1,
+ channels=32,
+ kernel_size=(3, 3),
+ padding=(1, 1),
+ data_layout="NCHW16c")
+ ret = relay.layout_transform(y, "NCHW16c", "NCHW")
+ ret = relay.nn.pad(ret, pad_width=((0, 0), (1, 1), (1, 1), (1, 1)))
+ y = relay.Function(analysis.free_vars(ret), ret)
+ return y
+
+ a = before()
+ a = run_opt_pass(a, transform.AlterOpLayout())
+
+ b = expected()
+ b = run_opt_pass(b, transform.InferType())
+
+ assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+
+
def test_alter_layout_pool():
""" Check NCHW, NHWC pool layout conversion"""
# Register alter op layout. "level" is used to override the previously registered functions.
test_alter_layout_strided_slice()
test_alter_layout_depthwise_conv2d()
test_alter_layout_prelu()
+ test_alter_layout_pad()
test_alter_layout_pool()
test_alter_layout_sum()