[Relay][Convert Layout] Handling batch norm layout change. (#4600)
author    Animesh Jain <anijain@umich.edu>
          Mon, 30 Dec 2019 23:35:25 +0000 (15:35 -0800)
committer Yizhi Liu <liuyizhi@apache.org>
          Mon, 30 Dec 2019 23:35:25 +0000 (15:35 -0800)
src/relay/op/nn/nn.cc
src/relay/pass/convert_layout.cc
tests/python/relay/test_pass_convert_op_layout.py

diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc
index 2cb0c28..dfb360a 100644
--- a/src/relay/op/nn/nn.cc
+++ b/src/relay/op/nn/nn.cc
@@ -617,6 +617,34 @@ The whole array is rescaled by ``1/(1-p)`` to keep the expected sum of the input
 // batch_norm
 TVM_REGISTER_NODE_TYPE(BatchNormAttrs);
 
+Array<Array<Layout>> BatchNormInferCorrectLayout(const Attrs& attrs,
+                                                 const Array<Layout>& new_in_layouts,
+                                                 const Array<Layout>& old_in_layouts,
+                                                 const Array<Array<IndexExpr>>& old_in_shapes) {
+  BatchNormAttrs* param = const_cast<BatchNormAttrs*>(attrs.as<BatchNormAttrs>());
+
+  size_t axis =
+      param->axis < 0 ? param->axis + old_in_shapes[0].size() : static_cast<size_t>(param->axis);
+
+  Layout ret = Layout::Undef();
+
+  // If new_in_layouts are defined, this code tries to modify the layout.
+  if (new_in_layouts.defined() && old_in_layouts.defined()) {
+    // Get the new C axis: take the dim at the old batch_norm axis and find its index in the new layout.
+    const auto& bn_dim = old_in_layouts[0][axis];
+    auto new_index = new_in_layouts[0].IndexOf(bn_dim);
+    param->axis = new_index;
+    ret = new_in_layouts[0];
+  } else if (old_in_layouts.defined()) {
+    ret = old_in_layouts[0];
+  }
+  // BN has 5 inputs, 3 outputs. The last 4 inputs and last 2 outputs have "C" layout.
+  Layout c_layout = Layout("C");
+
+  return Array<Array<Layout>>{{ret, c_layout, c_layout, c_layout, c_layout},
+                              {ret, c_layout, c_layout}};
+}
+
 bool BatchNormRel(const Array<Type>& types,
                   int num_inputs,
                   const Attrs& attrs,
@@ -708,6 +736,7 @@ axis to be the last item in the input shape.
 .add_argument("beta", "Tensor", "The beta offset factor.")
 .add_argument("moving_mean", "Tensor", "Running mean of input.")
 .add_argument("moving_var", "Tensor", "Running variance of input.")
+.set_attr<FInferCorrectLayout>("FInferCorrectLayout", BatchNormInferCorrectLayout)
 .set_support_level(1)
 .add_type_rel("BatchNorm", BatchNormRel);
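
The core of BatchNormInferCorrectLayout above is remapping batch_norm's channel axis when the data layout changes: take the layout dim sitting at the old axis and look up that dim's position in the new layout, so NHWC/axis=3 becomes NCHW/axis=1. A plain-Python sketch of just that index arithmetic (the helper name is made up for illustration, not part of the patch):

    def remap_bn_axis(old_layout, new_layout, axis):
        # Normalize a negative axis the same way the C++ code does.
        if axis < 0:
            axis += len(old_layout)
        bn_dim = old_layout[axis]        # e.g. 'C' for axis=3 in "NHWC"
        return new_layout.index(bn_dim)  # e.g. 1 in "NCHW"

    assert remap_bn_axis("NHWC", "NCHW", 3) == 1
    assert remap_bn_axis("NHWC", "NCHW", -1) == 1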
 
diff --git a/src/relay/pass/convert_layout.cc b/src/relay/pass/convert_layout.cc
index 1db4422..fa8b872 100644
--- a/src/relay/pass/convert_layout.cc
+++ b/src/relay/pass/convert_layout.cc
@@ -134,7 +134,7 @@ Pass ConvertLayout(const std::string& desired_layout) {
       };
   return CreateFunctionPass(
       pass_func, 3, "ConvertLayout",
-      {ir::StringImm::make("InferType"), ir::StringImm::make("SimplifyInference"),
+      {ir::StringImm::make("InferType"),
        ir::StringImm::make("CanonicalizeOps")});
 }
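
With batch_norm now carrying its own FInferCorrectLayout, the hunk above drops SimplifyInference from ConvertLayout's required passes: the op no longer has to be decomposed before the layout pass runs. A minimal invocation sketch against TVM as of this commit (it simply mirrors the new test below; relay.Module was the module class name at this point):

    from tvm import relay
    from tvm.relay import analysis, transform

    # The same NHWC conv2d + batch_norm graph the new test constructs.
    x = relay.var("x", shape=(1, 56, 56, 64))
    weight = relay.var("weight", shape=(3, 3, 64, 64))
    y = relay.nn.conv2d(x, weight, channels=64, kernel_size=(3, 3), padding=(1, 1),
                        data_layout='NHWC', kernel_layout='HWIO')
    gamma = relay.var("gamma", relay.TensorType((64,), "float32"))
    beta = relay.var("beta", relay.TensorType((64,), "float32"))
    moving_mean = relay.var("moving_mean", relay.TensorType((64,), "float32"))
    moving_var = relay.var("moving_var", relay.TensorType((64,), "float32"))
    y = relay.nn.batch_norm(y, gamma, beta, moving_mean, moving_var, axis=3)[0]
    func = relay.Function(analysis.free_vars(y), y)

    # ConvertLayout alone is enough: batch_norm reaches the pass intact and its
    # axis is remapped from 3 (NHWC) to 1 (NCHW); no SimplifyInference needed.
    mod = relay.Module.from_expr(func)
    with transform.PassContext(opt_level=3):
        mod = transform.Sequential([transform.ConvertLayout('NCHW')])(mod)
    print(mod)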
 
diff --git a/tests/python/relay/test_pass_convert_op_layout.py b/tests/python/relay/test_pass_convert_op_layout.py
index 9544525..dfd7451 100644
--- a/tests/python/relay/test_pass_convert_op_layout.py
+++ b/tests/python/relay/test_pass_convert_op_layout.py
@@ -349,6 +349,54 @@ def test_scalar_convert_layout():
     assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
 
 
+def test_conv_bn_convert_layout():
+    """ Check that layout transforms are propagated through bn. """
+    def before():
+        x = relay.var("x", shape=(1, 56, 56, 64))
+        weight = relay.var("weight", shape=(3, 3, 64, 64))
+        y = relay.nn.conv2d(x, weight, channels=64, kernel_size=(3, 3), padding=(1, 1),
+                            data_layout='NHWC', kernel_layout='HWIO')
+
+        dtype = "float32"
+        beta = relay.var("beta", relay.TensorType((64,), dtype))
+        gamma = relay.var("gamma", relay.TensorType((64,), dtype))
+        moving_mean = relay.var("moving_mean", relay.TensorType((64,), dtype))
+        moving_var = relay.var("moving_var", relay.TensorType((64,), dtype))
+
+        y = relay.nn.batch_norm(y, gamma, beta, moving_mean, moving_var, axis=3)
+        y = relay.nn.relu(y[0])
+        y = relay.Function(analysis.free_vars(y), y)
+        return y
+
+    def expected():
+        x = relay.var("x", shape=(1, 56, 56, 64))
+        w = relay.var("weight", shape=(3, 3, 64, 64))
+        x = relay.layout_transform(x, 'NHWC', 'NCHW')
+        w = relay.layout_transform(w, 'HWIO', 'OIHW')
+        y = relay.nn.conv2d(x, w,
+                            channels=64,
+                            kernel_size=(3, 3),
+                            padding=(1, 1))
+
+        dtype = "float32"
+        beta = relay.var("beta", relay.TensorType((64,), dtype))
+        gamma = relay.var("gamma", relay.TensorType((64,), dtype))
+        moving_mean = relay.var("moving_mean", relay.TensorType((64,), dtype))
+        moving_var = relay.var("moving_var", relay.TensorType((64,), dtype))
+
+        y = relay.nn.batch_norm(y, gamma, beta, moving_mean, moving_var, axis=1)
+        y = relay.nn.relu(y[0])
+        y = relay.layout_transform(y, "NCHW", "NHWC")
+        y = relay.Function(analysis.free_vars(y), y)
+        return y
+
+    a = before()
+    a = run_opt_pass(a, transform.ConvertLayout('NCHW'))
+    b = run_opt_pass(expected(), transform.InferType())
+
+    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+
+
 if __name__ == "__main__":
     test_no_convert_layout()
     test_conv_convert_layout()
@@ -358,3 +406,4 @@ if __name__ == "__main__":
     test_bn_convert_layout()
     test_resnet_convert_layout()
     test_scalar_convert_layout()
+    test_conv_bn_convert_layout()