return relay.nn.conv2d(data, weight, **new_attrs)
elif desired_data_layout == 'NHWC':
# Check for depthwise convolution.
- if is_depthwise_conv2d(data.shape, attrs['data_layout'], weight.shape,
- attrs['kernel_layout'], attrs['groups']):
+ data_info, weight_info = tinfos
+ if is_depthwise_conv2d(data_info.shape, attrs['data_layout'],
+ weight_info.shape, attrs['kernel_layout'],
+ attrs['groups']):
new_attrs['kernel_layout'] = 'HWOI'
else:
new_attrs['kernel_layout'] = 'HWIO'
return relay.qnn.op.conv2d(*inputs, **new_attrs)
if desired_data_layout == 'NHWC':
# Check for depthwise convolution.
- if is_depthwise_conv2d(inputs[0].shape, attrs['data_layout'], inputs[1].shape,
- attrs['kernel_layout'], attrs['groups']):
+ data_info, weight_info = tinfos
+ if is_depthwise_conv2d(data_info.shape, attrs['data_layout'],
+ weight_info.shape, attrs['kernel_layout'],
+ attrs['groups']):
new_attrs['kernel_layout'] = 'HWOI'
else:
new_attrs['kernel_layout'] = 'HWIO'
assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
+def test_conv_nhwc_convert_layout():
+ def before():
+ x = relay.var("x", shape=(1, 64, 56, 56))
+ weight = relay.var('weight', shape=(64, 64, 3, 3))
+ y = relay.nn.conv2d(x, weight,
+ channels=64,
+ kernel_size=(3, 3),
+ padding=(1, 1),
+ data_layout='NCHW',
+ kernel_layout='OIHW')
+ y = relay.nn.relu(y)
+ y = relay.Function([x, weight], y)
+ return y
+
+ def expected():
+ x = relay.var("x", shape=(1, 64, 56, 56))
+ weight = relay.var('weight', shape=(64, 64, 3, 3))
+ x = relay.layout_transform(x, 'NCHW', 'NHWC')
+ weight = relay.layout_transform(weight, 'OIHW', 'HWIO')
+ y = relay.nn.conv2d(x, weight,
+ channels=64,
+ kernel_size=(3, 3),
+ padding=(1, 1),
+ data_layout="NHWC",
+ kernel_layout="HWIO")
+ y = relay.nn.relu(y)
+ y = relay.layout_transform(y, 'NHWC', 'NCHW')
+ y = relay.Function(relay.analysis.free_vars(y), y)
+ return y
+
+ a = before()
+ a = run_opt_pass(a, transform.ConvertLayout({'nn.conv2d': ['NHWC', 'default']}))
+ b = run_opt_pass(expected(), transform.InferType())
+
+ assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
+
+
def test_conv_transpose_convert_layout():
def before():
x = relay.var("x", shape=(1, 56, 56, 64))
if __name__ == "__main__":
test_no_convert_layout()
test_conv_convert_layout()
+ test_conv_nhwc_convert_layout()
test_conv_bias_pool_convert_layout()
test_conv_concat_convert_layout()
test_dual_path_convert_layout()