[TOPI-ARM] Do not alter layout if layout is NHWC (#5350)
authorAnimesh Jain <anijain@umich.edu>
Fri, 17 Apr 2020 07:10:02 +0000 (00:10 -0700)
committerGitHub <noreply@github.com>
Fri, 17 Apr 2020 07:10:02 +0000 (15:10 +0800)
* [TOPI-ARM] Do not alter layout if layout is NHWC

* Add test.

tests/python/relay/test_pass_alter_op_layout.py
topi/python/topi/arm_cpu/conv2d_alter_op.py

index 2a2e265..9b18f72 100644 (file)
@@ -940,11 +940,8 @@ def test_alter_layout_sum():
     assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
 
 
-# TODO(@anijain2305, @icemelon9): We should fix this. This doesn't seem to be the
-#   right behavior of alter_layout
-@pytest.mark.skip
-def test_alter_layout_nhwc_nchw_arm():
-    """ Check NHWC to NHCW conversion for a small sequence of ops."""
+def test_alter_layout_nhwc_arm():
+    """ Check that AlterOplayout does not alter NHWC data layout. """
     def alter_conv2d(attrs, inputs, tinfos, out_type):
         import topi
         with tvm.target.create("llvm -device=arm_cpu"):
@@ -974,25 +971,7 @@ def test_alter_layout_nhwc_nchw_arm():
         return y
 
     def expected_nhwc():
-        x = relay.var("x", shape=(1, 56, 56, 64))
-        weight1 = relay.var('weight1', shape=(3, 3, 64, 64))
-        weight2 = relay.var('weight2', shape=(3, 3, 64, 64))
-        y = relay.layout_transform(x, "NHWC", "NCHW")
-        weight1 = relay.layout_transform(weight1, "HWIO", "OIHW")
-        weight2 = relay.layout_transform(weight2, "HWIO", "OIHW")
-        y = relay.nn.conv2d(y, weight1,
-                            channels=64,
-                            kernel_size=(3, 3))
-        y = relay.nn.relu(y)
-        y = relay.nn.avg_pool2d(y,
-                                pool_size=(1,1))
-        y = relay.nn.conv2d(y, weight2,
-                            channels=64,
-                            kernel_size=(3, 3))
-        y = relay.nn.relu(y)
-        y = relay.layout_transform(y, "NCHW", "NHWC")
-        y = relay.Function(analysis.free_vars(y), y)
-        return y
+        return before_nhwc()
 
     with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
         a = before_nhwc()
@@ -1060,5 +1039,5 @@ if __name__ == "__main__":
     test_alter_layout_pad()
     test_alter_layout_pool()
     test_alter_layout_sum()
-    # test_alter_layout_nhwc_nchw_arm()
+    test_alter_layout_nhwc_arm()
     test_alter_op_with_global_var()
index 02bb4c3..aa4c878 100644 (file)
@@ -59,6 +59,10 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
     data, kernel = tinfos
     out_dtype = out_type.dtype
 
+    # We only perform layout alteration for NCHW data layout.
+    if data_layout == "NHWC":
+        return None
+
     # Extract data types
     data_tensor, kernel_tensor = tinfos
     data_dtype = data_tensor.dtype