// batch_norm
TVM_REGISTER_NODE_TYPE(BatchNormAttrs);
+Array<Array<Layout>> BatchNormInferCorrectLayout(const Attrs& attrs,
+                                                 const Array<Layout>& new_in_layouts,
+                                                 const Array<Layout>& old_in_layouts,
+                                                 const Array<Array<IndexExpr>>& old_in_shapes) {
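+  // Note: the op's attrs are mutated in place; param->axis may be rewritten
+  // below to match the new layout.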
+  BatchNormAttrs* param = const_cast<BatchNormAttrs*>(attrs.as<BatchNormAttrs>());
+
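+  // Normalize a negative axis, e.g. axis = -1 on a 4-d input becomes 3.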
+  size_t axis =
+      param->axis < 0 ? param->axis + old_in_shapes[0].size() : static_cast<size_t>(param->axis);
+
+  Layout ret = Layout::Undef();
+
+  // If new_in_layouts are defined, this code adapts the axis to the new layout.
+  if (new_in_layouts.defined() && old_in_layouts.defined()) {
+    // Get the new C axis: extract the 'C' dim from the old layout, then find the
+    // index of that dim in the new layout.
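+    // For example, old layout NHWC with axis = 3 gives dim 'C'; in a new NCHW
+    // layout, the index of 'C' is 1.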
+    const auto& bn_dim = old_in_layouts[0][axis];
+    auto new_index = new_in_layouts[0].IndexOf(bn_dim);
+    param->axis = new_index;
+    ret = new_in_layouts[0];
+  } else if (old_in_layouts.defined()) {
+    ret = old_in_layouts[0];
+  }
+  // BN has 5 inputs, 3 outputs. The last 4 inputs and last 2 outputs have "C" layout.
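+  // Inputs: data, gamma, beta, moving_mean, moving_var;
+  // outputs: out, moving_mean, moving_var.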
+ Layout c_layout = Layout("C");
+
+  return Array<Array<Layout>>{{ret, c_layout, c_layout, c_layout, c_layout},
+                              {ret, c_layout, c_layout}};
+}
+
bool BatchNormRel(const Array<Type>& types,
                  int num_inputs,
                  const Attrs& attrs,
.add_argument("beta", "Tensor", "The beta offset factor.")
.add_argument("moving_mean", "Tensor", "Running mean of input.")
.add_argument("moving_var", "Tensor", "Running variance of input.")
+.set_attr<FInferCorrectLayout>("FInferCorrectLayout", BatchNormInferCorrectLayout)
.set_support_level(1)
.add_type_rel("BatchNorm", BatchNormRel);
    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+def test_conv_bn_convert_layout():
+    """ Check that layout transforms are propagated through bn. """
+    def before():
+        x = relay.var("x", shape=(1, 56, 56, 64))
+        weight = relay.var("weight", shape=(3, 3, 64, 64))
+        y = relay.nn.conv2d(x, weight, channels=64, kernel_size=(3, 3), padding=(1, 1),
+                            data_layout='NHWC', kernel_layout='HWIO')
+
+        dtype = "float32"
+        beta = relay.var("beta", relay.TensorType((64,), dtype))
+        gamma = relay.var("gamma", relay.TensorType((64,), dtype))
+        moving_mean = relay.var("moving_mean", relay.TensorType((64,), dtype))
+        moving_var = relay.var("moving_var", relay.TensorType((64,), dtype))
+
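+        # axis=3 selects the channel dimension of the NHWC conv output.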
+        y = relay.nn.batch_norm(y, gamma, beta, moving_mean, moving_var, axis=3)
+        y = relay.nn.relu(y[0])
+        y = relay.Function(analysis.free_vars(y), y)
+        return y
+
+    def expected():
+        x = relay.var("x", shape=(1, 56, 56, 64))
+        w = relay.var("weight", shape=(3, 3, 64, 64))
+        x = relay.layout_transform(x, 'NHWC', 'NCHW')
+        w = relay.layout_transform(w, 'HWIO', 'OIHW')
+        y = relay.nn.conv2d(x, w,
+                            channels=64,
+                            kernel_size=(3, 3),
+                            padding=(1, 1))
+
+        dtype = "float32"
+        beta = relay.var("beta", relay.TensorType((64,), dtype))
+        gamma = relay.var("gamma", relay.TensorType((64,), dtype))
+        moving_mean = relay.var("moving_mean", relay.TensorType((64,), dtype))
+        moving_var = relay.var("moving_var", relay.TensorType((64,), dtype))
+
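+        # After conversion the data layout is NCHW, so batch_norm uses axis=1.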
+        y = relay.nn.batch_norm(y, gamma, beta, moving_mean, moving_var, axis=1)
+        y = relay.nn.relu(y[0])
+        y = relay.layout_transform(y, "NCHW", "NHWC")
+        y = relay.Function(analysis.free_vars(y), y)
+        return y
+
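+    # ConvertLayout('NCHW') should wrap the convolution in layout_transforms and
+    # let BatchNormInferCorrectLayout rewrite batch_norm's axis from 3 to 1.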
+    a = before()
+    a = run_opt_pass(a, transform.ConvertLayout('NCHW'))
+    b = run_opt_pass(expected(), transform.InferType())
+
+    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+
+
if __name__ == "__main__":
    test_no_convert_layout()
    test_conv_convert_layout()
    test_bn_convert_layout()
    test_resnet_convert_layout()
    test_scalar_convert_layout()
+    test_conv_bn_convert_layout()