def test_qnn_conv_nhwc_convert_layout():
def before():
- x = relay.var("x", shape=(1, 64, 56, 56), dtype='int8')
- weight = relay.var('weight', shape=(64, 64, 3, 3), dtype='int8')
- y = relay.qnn.op.conv2d(x, weight,
- relay.const(1, 'int32'),
- relay.const(1, 'int32'),
- relay.const(1, 'float32'),
- relay.const(1, 'float32'),
- channels=64,
- kernel_size=(3, 3),
- padding=(1, 1),
- data_layout='NCHW',
- kernel_layout='OIHW')
+ x = relay.var("x", shape=(1, 64, 56, 56), dtype="int8")
+ weight = relay.var("weight", shape=(64, 64, 3, 3), dtype="int8")
+ y = relay.qnn.op.conv2d(
+ x,
+ weight,
+ relay.const(1, "int32"),
+ relay.const(1, "int32"),
+ relay.const(1, "float32"),
+ relay.const(1, "float32"),
+ channels=64,
+ kernel_size=(3, 3),
+ padding=(1, 1),
+ data_layout="NCHW",
+ kernel_layout="OIHW",
+ )
y = relay.nn.relu(y)
y = relay.Function([x, weight], y)
return y
def expected():
- x = relay.var("x", shape=(1, 64, 56, 56), dtype='int8')
- weight = relay.var('weight', shape=(64, 64, 3, 3), dtype='int8')
- x = relay.layout_transform(x, 'NCHW', 'NHWC')
- weight = relay.layout_transform(weight, 'OIHW', 'HWIO')
- y = relay.qnn.op.conv2d(x, weight,
- relay.const(1, 'int32'),
- relay.const(1, 'int32'),
- relay.const(1, 'float32'),
- relay.const(1, 'float32'),
- channels=64,
- kernel_size=(3, 3),
- padding=(1, 1),
- data_layout="NHWC",
- kernel_layout="HWIO")
+ x = relay.var("x", shape=(1, 64, 56, 56), dtype="int8")
+ weight = relay.var("weight", shape=(64, 64, 3, 3), dtype="int8")
+ x = relay.layout_transform(x, "NCHW", "NHWC")
+ weight = relay.layout_transform(weight, "OIHW", "HWIO")
+ y = relay.qnn.op.conv2d(
+ x,
+ weight,
+ relay.const(1, "int32"),
+ relay.const(1, "int32"),
+ relay.const(1, "float32"),
+ relay.const(1, "float32"),
+ channels=64,
+ kernel_size=(3, 3),
+ padding=(1, 1),
+ data_layout="NHWC",
+ kernel_layout="HWIO",
+ )
y = relay.nn.relu(y)
- y = relay.layout_transform(y, 'NHWC', 'NCHW')
+ y = relay.layout_transform(y, "NHWC", "NCHW")
y = relay.Function(relay.analysis.free_vars(y), y)
return y
a = before()
- a = run_opt_pass(a, transform.ConvertLayout({'qnn.conv2d': ['NHWC', 'default']}))
+ a = run_opt_pass(a, transform.ConvertLayout({"qnn.conv2d": ["NHWC", "default"]}))
b = run_opt_pass(expected(), transform.InferType())
assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)