[PASS][ConvertLayout] Fixes AttributeError during ConvertLayout to NHWC (#6419)
authorlhutton1 <35535092+lhutton1@users.noreply.github.com>
Wed, 9 Sep 2020 15:43:39 +0000 (16:43 +0100)
committerGitHub <noreply@github.com>
Wed, 9 Sep 2020 15:43:39 +0000 (08:43 -0700)
Fixes an issue described in #6410. In order to retrieve the shape a tensor `checked_type` should be used.

Change-Id: I991d194d9cc15ee20464ff2e239fd05c035000c8

python/tvm/relay/op/nn/_nn.py
python/tvm/relay/qnn/op/layout_conversions.py
tests/python/relay/test_pass_convert_op_layout.py

index 43fca6d..02cf78d 100644 (file)
@@ -157,8 +157,10 @@ def convert_conv2d(attrs, inputs, tinfos, desired_layouts):
         return relay.nn.conv2d(data, weight, **new_attrs)
     elif desired_data_layout == 'NHWC':
         # Check for depthwise convolution.
-        if is_depthwise_conv2d(data.shape, attrs['data_layout'], weight.shape,
-                               attrs['kernel_layout'], attrs['groups']):
+        data_info, weight_info = tinfos
+        if is_depthwise_conv2d(data_info.shape, attrs['data_layout'],
+                               weight_info.shape, attrs['kernel_layout'],
+                               attrs['groups']):
             new_attrs['kernel_layout'] = 'HWOI'
         else:
             new_attrs['kernel_layout'] = 'HWIO'
index 391714a..3d71438 100644 (file)
@@ -62,8 +62,10 @@ def convert_qnn_conv2d(attrs, inputs, tinfos, desired_layouts):
         return relay.qnn.op.conv2d(*inputs, **new_attrs)
     if desired_data_layout == 'NHWC':
         # Check for depthwise convolution.
-        if is_depthwise_conv2d(inputs[0].shape, attrs['data_layout'], inputs[1].shape,
-                               attrs['kernel_layout'], attrs['groups']):
+        data_info, weight_info = tinfos
+        if is_depthwise_conv2d(data_info.shape, attrs['data_layout'],
+                               weight_info.shape, attrs['kernel_layout'],
+                               attrs['groups']):
             new_attrs['kernel_layout'] = 'HWOI'
         else:
             new_attrs['kernel_layout'] = 'HWIO'
index aec758d..e71cfdc 100644 (file)
@@ -90,6 +90,43 @@ def test_conv_convert_layout():
     assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
 
 
+def test_conv_nhwc_convert_layout():
+    def before():
+        x = relay.var("x", shape=(1, 64, 56, 56))
+        weight = relay.var('weight', shape=(64, 64, 3, 3))
+        y = relay.nn.conv2d(x, weight,
+                            channels=64,
+                            kernel_size=(3, 3),
+                            padding=(1, 1),
+                            data_layout='NCHW',
+                            kernel_layout='OIHW')
+        y = relay.nn.relu(y)
+        y = relay.Function([x, weight], y)
+        return y
+
+    def expected():
+        x = relay.var("x", shape=(1, 64, 56, 56))
+        weight = relay.var('weight', shape=(64, 64, 3, 3))
+        x = relay.layout_transform(x, 'NCHW', 'NHWC')
+        weight = relay.layout_transform(weight, 'OIHW', 'HWIO')
+        y = relay.nn.conv2d(x, weight,
+                            channels=64,
+                            kernel_size=(3, 3),
+                            padding=(1, 1),
+                            data_layout="NHWC",
+                            kernel_layout="HWIO")
+        y = relay.nn.relu(y)
+        y = relay.layout_transform(y, 'NHWC', 'NCHW')
+        y = relay.Function(relay.analysis.free_vars(y), y)
+        return y
+
+    a = before()
+    a = run_opt_pass(a, transform.ConvertLayout({'nn.conv2d': ['NHWC', 'default']}))
+    b = run_opt_pass(expected(), transform.InferType())
+
+    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
+
+
 def test_conv_transpose_convert_layout():
     def before():
         x = relay.var("x", shape=(1, 56, 56, 64))
@@ -795,6 +832,7 @@ def test_different_ops_convert_layout():
 if __name__ == "__main__":
     test_no_convert_layout()
     test_conv_convert_layout()
+    test_conv_nhwc_convert_layout()
     test_conv_bias_pool_convert_layout()
     test_conv_concat_convert_layout()
     test_dual_path_convert_layout()