From 35af4c8b42588c5736f49cf32c15f88e9a582e2e Mon Sep 17 00:00:00 2001
From: Animesh Jain
Date: Mon, 30 Dec 2019 15:35:25 -0800
Subject: [PATCH] [Relay][Convert Layout] Handling batch norm layout change. (#4600)

---
 src/relay/op/nn/nn.cc                             | 29 ++++++++++++++
 src/relay/pass/convert_layout.cc                  |  2 +-
 tests/python/relay/test_pass_convert_op_layout.py | 49 +++++++++++++++++++++++
 3 files changed, 79 insertions(+), 1 deletion(-)

diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc
index 2cb0c28..dfb360a 100644
--- a/src/relay/op/nn/nn.cc
+++ b/src/relay/op/nn/nn.cc
@@ -617,6 +617,34 @@ The whole array is rescaled by ``1/(1-p)`` to keep the expected sum of the input
 // batch_norm
 TVM_REGISTER_NODE_TYPE(BatchNormAttrs);
 
+Array<Array<Layout>> BatchNormInferCorrectLayout(const Attrs& attrs,
+                                                 const Array<Layout>& new_in_layouts,
+                                                 const Array<Layout>& old_in_layouts,
+                                                 const Array<Array<IndexExpr>>& old_in_shapes) {
+  BatchNormAttrs* param = const_cast<BatchNormAttrs*>(attrs.as<BatchNormAttrs>());
+
+  size_t axis =
+      param->axis < 0 ? param->axis + old_in_shapes[0].size() : static_cast<size_t>(param->axis);
+
+  Layout ret = Layout::Undef();
+
+  // If new_in_layouts are defined, this code tries to modify the layout.
+  if (new_in_layouts.defined() && old_in_layouts.defined()) {
+    // Get the new C axis. Extract the dim in old layout. Find the index of that dim in next layout.
+    const auto& bn_dim = old_in_layouts[0][axis];
+    auto new_index = new_in_layouts[0].IndexOf(bn_dim);
+    param->axis = new_index;
+    ret = new_in_layouts[0];
+  } else if (old_in_layouts.defined()) {
+    ret = old_in_layouts[0];
+  }
+  // BN has 5 inputs, 3 outputs. The last 4 inputs and last 2 outputs have "C" layout.
+  Layout c_layout = Layout("C");
+
+  return Array<Array<Layout>>{{ret, c_layout, c_layout, c_layout, c_layout},
+                              {ret, c_layout, c_layout}};
+}
+
 bool BatchNormRel(const Array<Type>& types,
                   int num_inputs,
                   const Attrs& attrs,
@@ -708,6 +736,7 @@ axis to be the last item in the input shape.
 .add_argument("beta", "Tensor", "The beta offset factor.")
 .add_argument("moving_mean", "Tensor", "Running mean of input.")
 .add_argument("moving_var", "Tensor", "Running variance of input.")
+.set_attr<FInferCorrectLayout>("FInferCorrectLayout", BatchNormInferCorrectLayout)
 .set_support_level(1)
 .add_type_rel("BatchNorm", BatchNormRel);
 
diff --git a/src/relay/pass/convert_layout.cc b/src/relay/pass/convert_layout.cc
index 1db4422..fa8b872 100644
--- a/src/relay/pass/convert_layout.cc
+++ b/src/relay/pass/convert_layout.cc
@@ -134,7 +134,7 @@ Pass ConvertLayout(const std::string& desired_layout) {
       };
   return CreateFunctionPass(
       pass_func, 3, "ConvertLayout",
-      {ir::StringImm::make("InferType"), ir::StringImm::make("SimplifyInference"),
+      {ir::StringImm::make("InferType"),
        ir::StringImm::make("CanonicalizeOps")});
 }
 
diff --git a/tests/python/relay/test_pass_convert_op_layout.py b/tests/python/relay/test_pass_convert_op_layout.py
index 9544525..dfd7451 100644
--- a/tests/python/relay/test_pass_convert_op_layout.py
+++ b/tests/python/relay/test_pass_convert_op_layout.py
@@ -349,6 +349,54 @@ def test_scalar_convert_layout():
     assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
 
 
+def test_conv_bn_convert_layout():
+    """ Check that layout transforms are propagated through bn. """
""" + def before(): + x = relay.var("x", shape=(1, 56, 56, 64)) + weight = relay.var("weight", shape=(3, 3, 64, 64)) + y = relay.nn.conv2d(x, weight, channels=64, kernel_size=(3, 3), padding=(1, 1), + data_layout='NHWC', kernel_layout='HWIO') + + dtype = "float32" + beta = relay.var("beta", relay.TensorType((64,), dtype)) + gamma = relay.var("gamma", relay.TensorType((64,), dtype)) + moving_mean = relay.var("moving_mean", relay.TensorType((64,), dtype)) + moving_var = relay.var("moving_var", relay.TensorType((64,), dtype)) + + y = relay.nn.batch_norm(y, gamma, beta, moving_mean, moving_var, axis=3) + y = relay.nn.relu(y[0]) + y = relay.Function(analysis.free_vars(y), y) + return y + + def expected(): + x = relay.var("x", shape=(1, 56, 56, 64)) + w = relay.var("weight", shape=(3, 3, 64, 64)) + x = relay.layout_transform(x, 'NHWC', 'NCHW') + w = relay.layout_transform(w, 'HWIO', 'OIHW') + y = relay.nn.conv2d(x, w, + channels=64, + kernel_size=(3, 3), + padding=(1, 1)) + + dtype = "float32" + beta = relay.var("beta", relay.TensorType((64,), dtype)) + gamma = relay.var("gamma", relay.TensorType((64,), dtype)) + moving_mean = relay.var("moving_mean", relay.TensorType((64,), dtype)) + moving_var = relay.var("moving_var", relay.TensorType((64,), dtype)) + + y = relay.nn.batch_norm(y, gamma, beta, moving_mean, moving_var, axis=1) + y = relay.nn.relu(y[0]) + y = relay.layout_transform(y, "NCHW", "NHWC") + y = relay.Function(analysis.free_vars(y), y) + return y + + a = before() + a = run_opt_pass(a, transform.ConvertLayout('NCHW')) + b = run_opt_pass(expected(), transform.InferType()) + + assert analysis.alpha_equal(a, b), "Actual = \n" + str(a) + + if __name__ == "__main__": test_no_convert_layout() test_conv_convert_layout() @@ -358,3 +406,4 @@ if __name__ == "__main__": test_bn_convert_layout() test_resnet_convert_layout() test_scalar_convert_layout() + test_conv_bn_convert_layout() -- 2.7.4