From d69c6fd8e144b4cec9c52cb521b1d10e22a9e52f Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Thu, 10 Oct 2019 22:57:09 -0700 Subject: [PATCH] [Relay][AlterOp] NHWC to NCHWc support for Pool, pad, concatenate, sum. (#4059) --- python/tvm/relay/frontend/tflite.py | 10 +- src/relay/op/nn/pad.cc | 4 +- src/relay/op/nn/pooling.cc | 10 +- src/relay/op/tensor/reduce.cc | 58 +++++- src/relay/op/tensor/transform.cc | 26 ++- tests/python/relay/test_pass_alter_op_layout.py | 248 ++++++++++++++++++++++-- 6 files changed, 321 insertions(+), 35 deletions(-) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 8b91315..a519c6f 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -748,10 +748,12 @@ class OperatorConverter(object): elif padding == Padding.SAME: pad_top, pad_bottom = get_pad_value(input_h, dilated_kernel_h, stride_h) pad_left, pad_right = get_pad_value(input_w, dilated_kernel_w, stride_w) - in_expr = _op.nn.pad(data=in_expr, pad_width=((0, 0), - (pad_top, pad_bottom), - (pad_left, pad_right), - (0, 0))) + do_pad = not (pad_top == 0 and pad_bottom == 0 and pad_left == 0 and pad_right == 0) + if do_pad: + in_expr = _op.nn.pad(data=in_expr, pad_width=((0, 0), + (pad_top, pad_bottom), + (pad_left, pad_right), + (0, 0))) else: raise tvm.error.OpAttributeUnImplemented( 'Padding format {} is not supported for operator Conv.'.format(padding)) diff --git a/src/relay/op/nn/pad.cc b/src/relay/op/nn/pad.cc index 331e50f..5127ee4 100644 --- a/src/relay/op/nn/pad.cc +++ b/src/relay/op/nn/pad.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY diff --git a/src/relay/op/nn/pooling.cc b/src/relay/op/nn/pooling.cc index 76dec99..503db41 100644 --- a/src/relay/op/nn/pooling.cc +++ b/src/relay/op/nn/pooling.cc @@ -47,15 +47,9 @@ Array > Pool2DInferCorrectLayout( T *params = const_cast(attrs.as()); if (new_in_layouts.defined()) { + // Set the pool with the new layout. CHECK_EQ(new_in_layouts.size(), 1); - - Layout raw_layout(params->layout); - Layout input = new_in_layouts[0]; - if (input.IndexOf(LayoutAxis::Get('W')) == raw_layout.IndexOf(LayoutAxis::Get('W')) && - input.IndexOf(LayoutAxis::Get('H')) == raw_layout.IndexOf(LayoutAxis::Get('H')) && - !input.Contains(LayoutAxis::Get('w')) && !input.Contains(LayoutAxis::Get('h'))) { - params->layout = input.name(); // modify self to follow the input layout - } + params->layout = new_in_layouts[0].name(); } Layout inferred_layout(params->layout); diff --git a/src/relay/op/tensor/reduce.cc b/src/relay/op/tensor/reduce.cc index a7be3ff..e41cfda 100644 --- a/src/relay/op/tensor/reduce.cc +++ b/src/relay/op/tensor/reduce.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. 
You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -119,6 +119,59 @@ Array GetExcludeAxes(size_t indim, return r_axes; } +// Return the modified layout for AlterOpLayout pass. +Array> ReduceInferCorrectLayout(const Attrs& attrs, + const Array& new_in_layouts, + const Array& old_in_layouts, + const Array>& old_in_shapes) { + // NOTE: Discard "const" qualifier here. + ReduceAttrs* params = const_cast(attrs.as()); + + // Get the reduce axes. + uint32_t indim = old_in_shapes[0].size(); + auto r_axes = GetReduceAxes(indim, params->axis, params->exclude); + + Layout ret = Layout::Undef(); + if (new_in_layouts.defined() && r_axes.size()) { + // Adapt to new layout. The axis has to change. Record original reduce axes. Convert to the + // modified layout axes. + CHECK_EQ(new_in_layouts.size(), 1); + CHECK_EQ(old_in_layouts.size(), 1); + + // 1) Collect the original axes + std::unordered_set old_r_dims; + for (auto r_axis : r_axes) { + old_r_dims.emplace(old_in_layouts[0][r_axis].name()); + } + + // 2) Collect the new axes by walking new_layout. + tvm::Array new_r_axes; + std::string new_layout_string = ""; + int axis_index = 0; + for (auto iter_var : new_in_layouts[0]->axes) { + const auto& layout_axis = LayoutAxis::Get(iter_var); + const std::string& layout_dim = layout_axis.name(); + if (old_r_dims.count(layout_dim)) { + new_r_axes.push_back(tvm::Integer(axis_index)); + } + // Collect only the primal axis. + if (layout_axis.IsPrimal()) { + new_layout_string += layout_dim; + axis_index++; + } + } + + // 3) Set the new axis and layout. + ret = Layout(new_layout_string); + params->axis = new_r_axes; + } else if (old_in_layouts.defined()) { + // If the new layout is undefined, set the old layout as the inferred layout. + CHECK_EQ(old_in_layouts.size(), 1); + ret = old_in_layouts[0]; + } + + return Array>{{ret}, {ret}}; +} template Array ReduceCompute(const Attrs& attrs, @@ -325,6 +378,7 @@ Example:: .set_attrs_type_key("relay.attrs.ReduceAttrs") .set_support_level(4) .add_type_rel("Reduce", ReduceRel) +.set_attr("FInferCorrectLayout", ReduceInferCorrectLayout) .set_attr("FTVMCompute", SumCompute) .set_attr("TOpPattern", kCommReduce); diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 3f371f2..1d0c9ec 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -283,22 +283,34 @@ Array> ConcatenateLayout( const Array& new_in_layouts, const Array& old_in_layouts, const Array> &old_in_shapes) { - const ConcatenateAttrs* param = attrs.as(); + ConcatenateAttrs* param = const_cast(attrs.as()); size_t axis = param->axis < 0 ? param->axis + old_in_shapes[0].size() : static_cast(param->axis); Layout ret; + bool is_new_layout_selected = false; if (new_in_layouts.defined()) { // this function is called after some operators are alternated. + // If all the new input layouts are same, the new in layout gets selected. For axis, the new + // axis in the new layout is identified. The param->axis is then modified on the fly to conform + // to the new input layout. 
const auto& concate_dim = old_in_layouts[0][axis]; - for (size_t i = 0; i < new_in_layouts.size(); ++i) { - if (new_in_layouts[i].ndim() > axis && - new_in_layouts[i][axis] == concate_dim) { - ret = new_in_layouts[i]; - break; + bool all_input_layouts_same = true; + for (auto new_layout : new_in_layouts) { + if (!new_layout.Equals(new_in_layouts[0])) { + all_input_layouts_same = false; } } - } else { // this function is called on the original correct relay ir + if (all_input_layouts_same) { + auto new_index = new_in_layouts[0].IndexOf(concate_dim); + ret = new_in_layouts[0]; + param->axis = new_index; + is_new_layout_selected = true; + } + } + + if (!is_new_layout_selected) { + // this function is called on the original correct relay ir for (size_t i = 0; i < old_in_layouts.size(); ++i) { if (old_in_layouts[i].defined()) { ret = old_in_layouts[i]; diff --git a/tests/python/relay/test_pass_alter_op_layout.py b/tests/python/relay/test_pass_alter_op_layout.py index cc668d7..adb8676 100644 --- a/tests/python/relay/test_pass_alter_op_layout.py +++ b/tests/python/relay/test_pass_alter_op_layout.py @@ -45,6 +45,7 @@ def test_alter_op(): y = relay.Function([x, weight], y) return y + # Register alter op layout. "level" is used to override the previously registered functions. @register_alter_op_layout("nn.conv2d", level=100) def alter_conv2d(attrs, inputs, tinfos): data, weight = inputs @@ -79,6 +80,7 @@ def test_alter_return_none(): called = [False] + # Register alter op layout. "level" is used to override the previously registered functions. @register_alter_op_layout("nn.global_max_pool2d", level=101) def alter_conv2d(attrs, inputs, tinfos): called[0] = True @@ -112,6 +114,7 @@ def test_alter_layout(): y = relay.Function(analysis.free_vars(y), y) return y + # Register alter op layout. "level" is used to override the previously registered functions. @register_alter_op_layout("nn.conv2d", level=102) def alter_conv2d(attrs, inputs, tinfos): data, weight = inputs @@ -180,6 +183,7 @@ def test_alter_layout_dual_path(): y = relay.Function(analysis.free_vars(ret), ret) return y + # Register alter op layout. "level" is used to override the previously registered functions. @register_alter_op_layout("nn.conv2d", level=103) def alter_conv2d(attrs, inputs, tinfos): data, weight = inputs @@ -241,6 +245,7 @@ def test_alter_layout_resnet(): y = relay.nn.global_max_pool2d(y) return relay.Function(analysis.free_vars(y), y) + # Register alter op layout. "level" is used to override the previously registered functions. @register_alter_op_layout("nn.conv2d", level=104) def alter_conv2d(attrs, inputs, tinfos): data, weight = inputs @@ -291,6 +296,7 @@ def test_alter_layout_broadcast_op(): y = relay.Function(analysis.free_vars(y), y) return y + # Register alter op layout. "level" is used to override the previously registered functions. @register_alter_op_layout("nn.conv2d", level=105) def alter_conv2d(attrs, inputs, tinfos): data, weight = inputs @@ -338,6 +344,7 @@ def test_alter_layout_scalar(): y = relay.Function(analysis.free_vars(y), y) return y + # Register alter op layout. "level" is used to override the previously registered functions. @register_alter_op_layout("nn.conv2d", level=106) def alter_conv2d(attrs, inputs, tinfos): data, weight = inputs @@ -370,9 +377,19 @@ def test_alter_layout_scalar(): assert analysis.alpha_equal(a, b), "Actual = \n" + str(a) + def test_alter_layout_concatenate(): - """ """ - def before(): + """ NCHW, NHWC and corner case concatenate layout transform.""" + # Register alter op layout. 
"level" is used to override the previously registered functions. + @register_alter_op_layout("nn.conv2d", level=107) + def alter_conv2d(attrs, inputs, tinfos): + data, weight = inputs + new_attrs = dict(attrs) + new_attrs['data_layout'] = 'NCHW16c' + return relay.nn.conv2d(data, weight, **new_attrs) + + # NCHW layout transformation. + def before_nchw(): x = relay.var("x", shape=(1, 64, 56, 56)) weight1 = relay.var('weight1') weight2 = relay.var('weight2') @@ -388,14 +405,7 @@ def test_alter_layout_concatenate(): y = relay.Function(analysis.free_vars(ret), ret) return y - @register_alter_op_layout("nn.conv2d", level=107) - def alter_conv2d(attrs, inputs, tinfos): - data, weight = inputs - new_attrs = dict(attrs) - new_attrs['data_layout'] = 'NCHW16c' - return relay.nn.conv2d(data, weight, **new_attrs) - - def expected(): + def expected_nchw(): x = relay.var("x", shape=(1, 64, 56, 56)) weight1 = relay.var('weight1') weight2 = relay.var('weight2') @@ -415,10 +425,57 @@ def test_alter_layout_concatenate(): y = relay.Function(analysis.free_vars(ret), ret) return y - a = before() + a = before_nchw() a = run_opt_pass(a, transform.AlterOpLayout()) - b = expected() + b = expected_nchw() + b = run_opt_pass(b, transform.InferType()) + + assert analysis.alpha_equal(a, b), "Actual = \n" + str(a) + + # NHWC layout transformation. + def before_nhwc(): + x = relay.var("x", shape=(1, 56, 56, 64)) + weight1 = relay.var('weight1') + weight2 = relay.var('weight2') + y = relay.nn.conv2d(x, weight1, + channels=32, + kernel_size=(3, 3), + padding=(1, 1), + data_layout='NHWC') + y1 = relay.nn.conv2d(y, weight2, + channels=32, + kernel_size=(3, 3), + padding=(1, 1), + data_layout='NHWC') + ret = relay.concatenate([y, y1], axis=3) + y = relay.Function(analysis.free_vars(ret), ret) + return y + + def expected_nhwc(): + x = relay.var("x", shape=(1, 56, 56, 64)) + weight1 = relay.var('weight1') + weight2 = relay.var('weight2') + y = relay.layout_transform(x, "NHWC", "NCHW16c") + y = relay.nn.conv2d(y, weight1, + channels=32, + kernel_size=(3, 3), + padding=(1, 1), + data_layout="NCHW16c") + y1 = relay.nn.conv2d(y, weight2, + channels=32, + kernel_size=(3, 3), + padding=(1, 1), + data_layout='NCHW16c') + ret = relay.concatenate([y, y1], axis=1) + ret = relay.layout_transform(ret, "NCHW16c", "NHWC") + y = relay.Function(analysis.free_vars(ret), ret) + return y + + a = before_nhwc() + a = run_opt_pass(a, transform.AlterOpLayout()) + + b = expected_nhwc() b = run_opt_pass(b, transform.InferType()) assert analysis.alpha_equal(a, b), "Actual = \n" + str(a) @@ -435,6 +492,7 @@ def test_alter_layout_nchw_upsamping_op(): y = relay.Function(analysis.free_vars(y), y) return y + # Register alter op layout. "level" is used to override the previously registered functions. @register_alter_op_layout("nn.conv2d", level=108) def alter_conv2d(attrs, inputs, tinfos): data, weight = inputs @@ -474,6 +532,7 @@ def test_alter_layout_strided_slice(): y = relay.Function(analysis.free_vars(y), y) return y + # Register alter op layout. "level" is used to override the previously registered functions. @register_alter_op_layout("nn.conv2d", level=109) def alter_conv2d(attrs, inputs, tinfos): data, weight = inputs @@ -511,6 +570,7 @@ def test_alter_layout_depthwise_conv2d(): return y import topi + # Register alter op layout. "level" is used to override the previously registered functions. 
@register_alter_op_layout("nn.conv2d", level=110) def alter_conv2d(attrs, inputs, tinfos): with tvm.target.create("llvm"): @@ -548,6 +608,7 @@ def test_alter_layout_prelu(): y = relay.Function(analysis.free_vars(y), y) return y + # Register alter op layout. "level" is used to override the previously registered functions. @register_alter_op_layout("nn.conv2d", level=111) def alter_conv2d(attrs, inputs, tinfos): data, weight = inputs @@ -580,6 +641,167 @@ def test_alter_layout_prelu(): assert(analysis.alpha_equal(a, b)) +def test_alter_layout_pool(): + """ Check NCHW, NHWC pool layout conversion""" + # Register alter op layout. "level" is used to override the previously registered functions. + @register_alter_op_layout("nn.conv2d", level=113) + def alter_conv2d(attrs, inputs, tinfos): + data, weight = inputs + new_attrs = dict(attrs) + new_attrs['data_layout'] = 'NCHW16c' + return relay.nn.conv2d(data, weight, **new_attrs) + + # Check NCHW conversion. + def before_nchw(): + x = relay.var("x", shape=(1, 64, 56, 56)) + weight1 = relay.var('weight1') + y = relay.nn.conv2d(x, weight1, + channels=32, + kernel_size=(3, 3), + padding=(1, 1)) + ret = relay.nn.avg_pool2d(y, pool_size=(1, 1)) + y = relay.Function(analysis.free_vars(ret), ret) + return y + + def expected_nchw(): + x = relay.var("x", shape=(1, 64, 56, 56)) + weight1 = relay.var('weight1') + y = relay.layout_transform(x, "NCHW", "NCHW16c") + y = relay.nn.conv2d(y, weight1, + channels=32, + kernel_size=(3, 3), + padding=(1, 1), + data_layout="NCHW16c") + ret = relay.nn.avg_pool2d(y, pool_size=(1, 1), layout='NCHW16c') + ret = relay.layout_transform(ret, "NCHW16c", "NCHW") + y = relay.Function(analysis.free_vars(ret), ret) + return y + + a = before_nchw() + a = run_opt_pass(a, transform.AlterOpLayout()) + + b = expected_nchw() + b = run_opt_pass(b, transform.InferType()) + + assert analysis.alpha_equal(a, b), "Actual = \n" + str(a) + + # Check NHWC conversion. + def before_nhwc(): + x = relay.var("x", shape=(1, 56, 56, 64)) + weight1 = relay.var('weight1') + y = relay.nn.conv2d(x, weight1, + channels=32, + kernel_size=(3, 3), + padding=(1, 1), + data_layout='NHWC') + ret = relay.nn.avg_pool2d(y, pool_size=(1, 1), layout='NHWC') + y = relay.Function(analysis.free_vars(ret), ret) + return y + + def expected_nhwc(): + x = relay.var("x", shape=(1, 56, 56, 64)) + weight1 = relay.var('weight1') + y = relay.layout_transform(x, "NHWC", "NCHW16c") + y = relay.nn.conv2d(y, weight1, + channels=32, + kernel_size=(3, 3), + padding=(1, 1), + data_layout="NCHW16c") + ret = relay.nn.avg_pool2d(y, pool_size=(1, 1), layout='NCHW16c') + ret = relay.layout_transform(ret, "NCHW16c", "NHWC") + y = relay.Function(analysis.free_vars(ret), ret) + return y + + a = before_nhwc() + a = run_opt_pass(a, transform.AlterOpLayout()) + + b = expected_nhwc() + b = run_opt_pass(b, transform.InferType()) + + assert analysis.alpha_equal(a, b), "Actual = \n" + str(a) + + +def test_alter_layout_sum(): + """ Check NCHW, NHWC sum layout conversion""" + # Register alter op layout. "level" is used to override the previously registered functions. + @register_alter_op_layout("nn.conv2d", level=114) + def alter_conv2d(attrs, inputs, tinfos): + data, weight = inputs + new_attrs = dict(attrs) + new_attrs['data_layout'] = 'NCHW16c' + return relay.nn.conv2d(data, weight, **new_attrs) + + # Check NCHW conversion. 
+ def before_nchw(): + x = relay.var("x", shape=(1, 64, 56, 56)) + weight1 = relay.var('weight1') + y = relay.nn.conv2d(x, weight1, + channels=32, + kernel_size=(3, 3), + padding=(1, 1)) + ret = relay.sum(y, axis=1, keepdims=True) + y = relay.Function(analysis.free_vars(ret), ret) + return y + + def expected_nchw(): + x = relay.var("x", shape=(1, 64, 56, 56)) + weight1 = relay.var('weight1') + y = relay.layout_transform(x, "NCHW", "NCHW16c") + y = relay.nn.conv2d(y, weight1, + channels=32, + kernel_size=(3, 3), + padding=(1, 1), + data_layout="NCHW16c") + ret = relay.layout_transform(y, "NCHW16c", "NCHW") + ret = relay.sum(ret, axis=[1], keepdims=True) + y = relay.Function(analysis.free_vars(ret), ret) + return y + + a = before_nchw() + a = run_opt_pass(a, transform.AlterOpLayout()) + + b = expected_nchw() + b = run_opt_pass(b, transform.InferType()) + + assert analysis.alpha_equal(a, b), "Actual = \n" + str(a) + + # Check NHWC conversion. + def before_nhwc(): + x = relay.var("x", shape=(1, 56, 56, 64)) + weight1 = relay.var('weight1') + y = relay.nn.conv2d(x, weight1, + channels=32, + kernel_size=(3, 3), + padding=(1, 1), + data_layout='NHWC') + ret = relay.sum(y, axis=3, keepdims=True) + y = relay.Function(analysis.free_vars(ret), ret) + return y + + def expected_nhwc(): + x = relay.var("x", shape=(1, 56, 56, 64)) + weight1 = relay.var('weight1') + y = relay.layout_transform(x, "NHWC", "NCHW16c") + y = relay.nn.conv2d(y, weight1, + channels=32, + kernel_size=(3, 3), + padding=(1, 1), + data_layout="NCHW16c") + ret = relay.layout_transform(y, "NCHW16c", "NCHW") + ret = relay.sum(ret, axis=[1], keepdims=True) + ret = relay.layout_transform(ret, "NCHW", "NHWC") + y = relay.Function(analysis.free_vars(ret), ret) + return y + + a = before_nhwc() + a = run_opt_pass(a, transform.AlterOpLayout()) + + b = expected_nhwc() + b = run_opt_pass(b, transform.InferType()) + + assert analysis.alpha_equal(a, b), "Actual = \n" + str(a) + + if __name__ == "__main__": test_alter_op() test_alter_return_none() @@ -593,3 +815,5 @@ if __name__ == "__main__": test_alter_layout_strided_slice() test_alter_layout_depthwise_conv2d() test_alter_layout_prelu() + test_alter_layout_pool() + test_alter_layout_sum() -- 2.7.4
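
Usage sketch (not part of the patch itself): the snippet below exercises the new layout propagation for pool and sum on an NHWC graph, in the same style as the tests added above. It is a minimal sketch, assuming the 0.6-era Relay Python API that the test file uses (relay.Module, transform.AlterOpLayout, the register_alter_op_layout decorator and analysis.free_vars); the helper name _demo_alter_conv2d and the level value are illustrative, not defined anywhere in the patch.

    import tvm
    from tvm import relay
    from tvm.relay import transform, analysis
    from tvm.relay.op import register_alter_op_layout

    # Ask conv2d for the NCHW16c layout, as the tests above do. The level is
    # arbitrary; it only needs to be higher than earlier registrations.
    @register_alter_op_layout("nn.conv2d", level=115)
    def _demo_alter_conv2d(attrs, inputs, tinfos):
        data, weight = inputs
        new_attrs = dict(attrs)
        new_attrs["data_layout"] = "NCHW16c"
        return relay.nn.conv2d(data, weight, **new_attrs)

    def demo():
        # NHWC graph: conv2d -> avg_pool2d -> channel sum.
        x = relay.var("x", shape=(1, 56, 56, 64))
        w = relay.var("w")
        y = relay.nn.conv2d(x, w, channels=32, kernel_size=(3, 3),
                            padding=(1, 1), data_layout="NHWC")
        y = relay.nn.avg_pool2d(y, pool_size=(1, 1), layout="NHWC")
        y = relay.sum(y, axis=3, keepdims=True)
        func = relay.Function(analysis.free_vars(y), y)

        mod = relay.Module.from_expr(func)
        seq = transform.Sequential([transform.AlterOpLayout()])
        with transform.PassContext(opt_level=3):
            mod = seq(mod)

        # Expected shape of the result, per the tests above: conv2d and
        # avg_pool2d run in NCHW16c; a layout_transform back to NCHW is
        # inserted before sum and its axis is remapped to 1 by
        # ReduceInferCorrectLayout; the output is transformed back to NHWC so
        # the function type is unchanged.
        print(mod["main"])

    if __name__ == "__main__":
        demo()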