From 1d2436647bffdcbb1e133b55dc4c7365f604fc3d Mon Sep 17 00:00:00 2001
From: Animesh Jain
Date: Sun, 10 Nov 2019 19:09:16 -0800
Subject: [PATCH] [TOPI][AlterOpLayout][ARM] Enabling NHWC to NCHW layout transformation. (#4249)

---
 tests/python/relay/test_pass_alter_op_layout.py |  62 +++++++++++
 tests/python/relay/test_pass_legalize.py        |  44 --------
 topi/python/topi/arm_cpu/conv2d.py              | 130 ++++++++++++------------
 3 files changed, 129 insertions(+), 107 deletions(-)

diff --git a/tests/python/relay/test_pass_alter_op_layout.py b/tests/python/relay/test_pass_alter_op_layout.py
index f1200ec..2738690 100644
--- a/tests/python/relay/test_pass_alter_op_layout.py
+++ b/tests/python/relay/test_pass_alter_op_layout.py
@@ -916,6 +916,67 @@ def test_alter_layout_sum():
     assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
 
 
+def test_alter_layout_nhwc_nchw_arm():
+    """Check NHWC to NCHW conversion for a small sequence of ops."""
+    # Register alter op layout. "level" is used to override the previously registered functions.
+    @register_alter_op_layout("nn.conv2d", level=115)
+    def alter_conv2d(attrs, inputs, tinfos):
+        from topi.arm_cpu.conv2d import _alter_conv2d_layout_arm
+        return _alter_conv2d_layout_arm(attrs, inputs, tinfos, tvm.relay)
+
+    # Check NHWC conversion.
+    def before_nhwc():
+        x = relay.var("x", shape=(1, 56, 56, 64))
+        weight1 = relay.var('weight1', shape=(3, 3, 64, 64))
+        weight2 = relay.var('weight2', shape=(3, 3, 64, 64))
+        y = relay.nn.conv2d(x, weight1,
+                            channels=64,
+                            kernel_size=(3, 3),
+                            data_layout='NHWC',
+                            kernel_layout='HWIO')
+        y = relay.nn.relu(y)
+        y = relay.nn.avg_pool2d(y,
+                                pool_size=(1, 1),
+                                layout='NHWC')
+        y = relay.nn.conv2d(y, weight2,
+                            channels=64,
+                            kernel_size=(3, 3),
+                            data_layout='NHWC',
+                            kernel_layout='HWIO')
+        y = relay.nn.relu(y)
+        y = relay.Function(analysis.free_vars(y), y)
+        return y
+
+    def expected_nhwc():
+        x = relay.var("x", shape=(1, 56, 56, 64))
+        weight1 = relay.var('weight1', shape=(3, 3, 64, 64))
+        weight2 = relay.var('weight2', shape=(3, 3, 64, 64))
+        y = relay.layout_transform(x, "NHWC", "NCHW")
+        weight1 = relay.layout_transform(weight1, "HWIO", "OIHW")
+        weight2 = relay.layout_transform(weight2, "HWIO", "OIHW")
+        y = relay.nn.conv2d(y, weight1,
+                            channels=64,
+                            kernel_size=(3, 3))
+        y = relay.nn.relu(y)
+        y = relay.nn.avg_pool2d(y,
+                                pool_size=(1, 1))
+        y = relay.nn.conv2d(y, weight2,
+                            channels=64,
+                            kernel_size=(3, 3))
+        y = relay.nn.relu(y)
+        y = relay.layout_transform(y, "NCHW", "NHWC")
+        y = relay.Function(analysis.free_vars(y), y)
+        return y
+
+    a = before_nhwc()
+    a = run_opt_pass(a, transform.AlterOpLayout())
+
+    b = expected_nhwc()
+    b = run_opt_pass(b, transform.InferType())
+
+    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+
+
 if __name__ == "__main__":
     test_alter_op()
     test_alter_return_none()
@@ -932,3 +993,4 @@ if __name__ == "__main__":
     test_alter_layout_pad()
     test_alter_layout_pool()
     test_alter_layout_sum()
+    test_alter_layout_nhwc_nchw_arm()
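The layout_transform pairs that expected_nhwc() anticipates are numerically
nothing more than transposes of the data tensor. A quick standalone sanity
check (illustrative only, not part of the patch):

    import numpy as np

    x = np.random.rand(1, 56, 56, 64).astype('float32')   # NHWC
    nchw = np.transpose(x, (0, 3, 1, 2))                   # NHWC -> NCHW
    back = np.transpose(nchw, (0, 2, 3, 1))                # NCHW -> NHWC
    assert np.array_equal(x, back)                         # lossless round trip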
diff --git a/tests/python/relay/test_pass_legalize.py b/tests/python/relay/test_pass_legalize.py
index c5303ef..2f0fbee 100644
--- a/tests/python/relay/test_pass_legalize.py
+++ b/tests/python/relay/test_pass_legalize.py
@@ -171,53 +171,9 @@ def test_legalize_multi_input():
     assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
 
 
-def test_legalize_arm_layout_functional():
-    """Test if the legalized conversion yields same result as original"""
-    def get_output(func, data_val, parameters):
-        with relay.build_config(opt_level=0):
-            graph, lib, params = relay.build(func, target='llvm', params=parameters)
-            m = graph_runtime.create(graph, lib, tvm.cpu())
-            m.set_input("data", data_val)
-            m.set_input(**params)
-            m.run()
-            out = m.get_output(0, tvm.nd.empty((1, 224, 224, 32), 'float32')).asnumpy()
-        return out
-
-    def before():
-        n, ic, ih, iw, oc, kh, kw = 1, 16, 224, 224, 32, 3, 3
-        data = relay.var("data", relay.TensorType((n, ih, iw, ic), 'float32'))
-        kernel = relay.var("kernel", relay.TensorType((kh, kw, ic, oc), 'float32'))
-        y = relay.nn.conv2d(data, kernel,
-                            kernel_size=(kh, kw),
-                            channels=oc,
-                            padding=(1, 1),
-                            dilation=(1, 1),
-                            data_layout='NHWC',
-                            kernel_layout='HWIO',
-                            out_dtype='float32')
-        func = relay.Function([data, kernel], y)
-        return func
-
-    @register_legalize("nn.conv2d", level=105)
-    def legalize_conv2d(attrs, inputs, types):
-        from topi.arm_cpu.conv2d import _conv2d_legalize
-        return _conv2d_legalize(attrs, inputs, types)
-
-    a = before()
-    b = run_opt_pass(a, transform.Legalize())
-    assert b.astext().count('transpose') == 3
-
-    wdata = np.random.rand(3, 3, 16, 32) * 10
-    parameters = {"kernel": tvm.nd.array(wdata.astype('float32'))}
-    data_val = np.random.rand(1, 224, 224, 16).astype('float32')
-    ref_out = get_output(a, data_val, parameters)
-    legalized_out = get_output(b, data_val, parameters)
-    np.testing.assert_allclose(ref_out, legalized_out, rtol=0.01)
-
 
 if __name__ == "__main__":
     test_legalize()
     test_legalize_none()
     test_legalize_multiple_ops()
     test_legalize_multi_input()
-    test_legalize_arm_layout_functional()
diff --git a/topi/python/topi/arm_cpu/conv2d.py b/topi/python/topi/arm_cpu/conv2d.py
index c06c739..cbb6085 100644
--- a/topi/python/topi/arm_cpu/conv2d.py
+++ b/topi/python/topi/arm_cpu/conv2d.py
@@ -22,7 +22,6 @@ import logging
 
 import tvm
 from tvm import autotvm
-from tvm import relay
 import tvm.contrib.nnpack
 
 from ..generic import schedule_conv2d_nchw, schedule_conv2d_winograd_without_weight_transform, \
@@ -32,7 +31,6 @@ from ..nn import dilate, pad, conv2d, conv2d_alter_layout, \
     conv2d_winograd_without_weight_transform, \
     conv2d_winograd_nnpack_without_weight_transform, \
     depthwise_conv2d_nchw
-from ..nn import conv2d_legalize
 from ..nn.util import get_const_int, get_pad_tuple
 from ..nn.winograd_util import winograd_transform_matrices
 from .conv2d_spatial_pack import conv2d_spatial_pack_nchw, \
@@ -508,32 +506,63 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F):
     groups = attrs.get_int('groups')
     data_layout_key = "data_layout" if "data_layout" in new_attrs else "layout"
     layout = attrs[data_layout_key]
+    kernel_layout = attrs['kernel_layout']
     out_dtype = attrs["out_dtype"]
     if out_dtype in ("same", ""):
         out_dtype = tinfos[0].dtype
 
-    if layout != 'NCHW':
-        return None
     if dilation != (1, 1):
         logger.warning("Does not support weight pre-transform for dilated convolution.")
         return None
 
+    # query config of this workload
     data, kernel = tinfos[0:2]
-    N, CI, H, W = get_const_tuple(data.shape)
-    CO, _, KH, KW = get_const_tuple(kernel.shape)
+    if groups == 1:
+        workload = autotvm.task.args_to_workload(
+            [data, kernel, strides, padding, dilation, layout, out_dtype], conv2d)
+    else:
+        workload = autotvm.task.args_to_workload(
+            [data, kernel, strides, padding, dilation, out_dtype], depthwise_conv2d_nchw)
+
+    if layout == 'NCHW' and kernel_layout == 'OIHW':
+        N, CI, H, W = get_const_tuple(data.shape)
+        CO, _, KH, KW = get_const_tuple(kernel.shape)
+    elif layout == 'NHWC' and kernel_layout == 'HWIO':
+        N, H, W, CI = get_const_tuple(data.shape)
+        KH, KW, _, CO = get_const_tuple(kernel.shape)
+        # Also modify the workload, because we convert to NCHW layout later.
+        new_data = tvm.placeholder((N, CI, H, W), dtype=data.dtype)
+        new_kernel = tvm.placeholder((CO, CI, KH, KW), dtype=kernel.dtype)
+        new_layout = 'NCHW'
+        workload = autotvm.task.args_to_workload(
+            [new_data, new_kernel, strides, padding, dilation, new_layout, out_dtype], conv2d)
+    elif layout == 'NHWC' and kernel_layout == 'HWOI':
+        # This is the case for depthwise convolution.
+        N, H, W, CI = get_const_tuple(data.shape)
+        KH, KW, CO, M = get_const_tuple(kernel.shape)
+        # Also modify the workload, because we convert to NCHW layout later.
+        new_data = tvm.placeholder((N, CI, H, W), dtype=data.dtype)
+        new_kernel = tvm.placeholder((CO, M, KH, KW), dtype=kernel.dtype)
+        workload = autotvm.task.args_to_workload(
+            [new_data, new_kernel, strides, padding, dilation, out_dtype], depthwise_conv2d_nchw)
+    else:
+        return None
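In the NHWC branches above, the AutoTVM workload is rewritten against
placeholders for the equivalent NCHW/OIHW tensors, so the query hits the same
schedules as a native NCHW graph. A minimal sketch of that shape mapping
(illustrative only; 0.6-era tvm.placeholder API, shapes taken from the test
case above):

    import tvm

    N, H, W, CI = 1, 56, 56, 64      # NHWC data shape
    KH, KW, _, CO = 3, 3, 64, 64     # HWIO kernel shape
    new_data = tvm.placeholder((N, CI, H, W), dtype='float32')       # NCHW
    new_kernel = tvm.placeholder((CO, CI, KH, KW), dtype='float32')  # OIHW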
 
     idxd = tvm.indexdiv
 
     if groups == 1:
-        # query config of this workload
-        workload = autotvm.task.args_to_workload(
-            [data, kernel, strides, padding, dilation, layout, out_dtype], conv2d)
         target = tvm.target.current_target()
         dispatch_ctx = autotvm.DispatchContext.current
         cfg = dispatch_ctx.query(target, workload)
 
         if cfg.is_fallback:  # if is fallback, clear query cache and return None
             autotvm.task.clear_fallback_cache(target, workload)
+            if layout == 'NHWC' and kernel_layout == 'HWIO':
+                new_attrs['data_layout'] = 'NCHW'
+                new_attrs['kernel_layout'] = 'OIHW'
+                return F.nn.conv2d(*copy_inputs, **new_attrs)
             return None
 
         if cfg.template_key == 'direct':
             # pack weight tensor
@@ -541,7 +570,8 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F):
             new_attrs['kernel_layout'] = 'OIHW%do' % VC
 
             # Store the same config for the altered operator (workload)
-            new_data = data
+            new_data = tvm.placeholder((N, CI, H, W), dtype=data.dtype)
+            new_attrs[data_layout_key] = 'NCHW'
             new_kernel = tvm.placeholder((idxd(CO, VC), CI, KH, KW, VC), dtype=kernel.dtype)
             new_workload = autotvm.task.args_to_workload(
                 [new_data, new_kernel, strides, padding, dilation, 'NCHW', out_dtype], conv2d)
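The 'OIHW%do' % VC kernel layout used here tiles the output channels by the
factor VC. An illustrative numpy sketch of the repacking (not part of the
patch; VC = 8 assumed):

    import numpy as np

    CO, CI, KH, KW, VC = 64, 64, 3, 3, 8
    w = np.random.rand(CO, CI, KH, KW).astype('float32')              # OIHW
    # Split O into (O // VC, VC) and move the VC tile innermost.
    w_packed = w.reshape(CO // VC, VC, CI, KH, KW).transpose(0, 2, 3, 4, 1)
    assert w_packed.shape == (CO // VC, CI, KH, KW, VC)               # OIHW8o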
@@ -560,7 +590,10 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F):
             tile_size = _pick_tile_size(tinfos[0], tinfos[1])
             VC = cfg['tile_bna'].val
 
-            weight = F.nn.contrib_conv2d_winograd_weight_transform(copy_inputs[1],
+            weight = copy_inputs[1]
+            if kernel_layout != 'OIHW':
+                weight = F.transpose(weight, axes=(2, 3, 0, 1))
+            weight = F.nn.contrib_conv2d_winograd_weight_transform(weight,
                                                                    tile_size=tile_size)
             if VC > 0:
                 weight = F.reshape(weight,
@@ -581,9 +614,10 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F):
             copy_inputs[1] = weight
             new_attrs['tile_size'] = tile_size
+            new_attrs[data_layout_key] = 'NCHW'
 
             # Store the same config for the altered operator (workload)
-            new_data = data
+            new_data = tvm.placeholder((N, CI, H, W), dtype=data.dtype)
             new_workload = autotvm.task.args_to_workload(
                 [new_data, new_weight, strides, padding, dilation,
                  new_attrs[data_layout_key], out_dtype, tile_size],
@@ -596,14 +630,21 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F):
             # for winograd_nnpack_fp16, the precompute/prune pass must run on a device
             # where float16 is supported
             weight_dtype = 'float32'
+            weight = copy_inputs[1]
+            if kernel_layout != 'OIHW':
+                weight = F.transpose(weight, axes=(2, 3, 0, 1))
             transformed_kernel = F.nn.contrib_conv2d_winograd_nnpack_weight_transform(
-                copy_inputs[1],
+                weight,
                 convolution_algorithm=cfg['winograd_nnpack_algorithm'].val,
                 out_dtype=weight_dtype)
             copy_inputs[1] = transformed_kernel
-            new_data = data
+
+            new_data = tvm.placeholder((N, CI, H, W), dtype=data.dtype)
             new_kernel = tvm.placeholder((CO, CI, 8, 8), "float32")
             bias = tvm.placeholder((CO, ), "float32")
+            new_attrs[data_layout_key] = 'NCHW'
             new_workload = autotvm.task.args_to_workload(
                 [new_data, new_kernel, bias, strides, padding, dilation,
                  new_attrs[data_layout_key], out_dtype]
@@ -617,22 +658,30 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F):
         else:
             raise RuntimeError("Unsupported template_key '%s'" % cfg.template_key)
     else:
-        workload = autotvm.task.args_to_workload(
-            [data, kernel, strides, padding, dilation, out_dtype], depthwise_conv2d_nchw)
         target = tvm.target.current_target()
         dispatch_ctx = autotvm.DispatchContext.current
         cfg = dispatch_ctx.query(target, workload)
 
         if cfg.is_fallback:  # if is fallback, clear query cache and return None
             autotvm.task.clear_fallback_cache(tvm.target.current_target(), workload)
+            if layout == 'NHWC' and kernel_layout == 'HWOI':
+                new_attrs['data_layout'] = 'NCHW'
+                new_attrs['kernel_layout'] = 'OIHW'
+                return F.nn.conv2d(*copy_inputs, **new_attrs)
             return None
         if cfg.template_key == 'contrib_spatial_pack':
             VC = cfg['tile_co'].size[-1]
             new_attrs['kernel_layout'] = 'OIHW%do' % (cfg['tile_co'].size[-1])
 
             # Store the same config for the altered operator (workload)
-            new_data = data
-            CO, M, KH, KW = get_const_tuple(kernel.shape)
+            new_data = tvm.placeholder((N, CI, H, W), dtype=data.dtype)
+            new_attrs[data_layout_key] = 'NCHW'
+            if attrs['kernel_layout'] == 'OIHW':
+                CO, M, KH, KW = get_const_tuple(kernel.shape)
+            elif attrs['kernel_layout'] == 'HWOI':
+                KH, KW, CO, M = get_const_tuple(kernel.shape)
+            else:
+                raise RuntimeError("Depthwise conv2d kernel layout should be either OIHW or HWOI")
             new_kernel = tvm.placeholder((idxd(CO, VC), M, KH, KW, VC), dtype=kernel.dtype)
             new_workload = autotvm.task.args_to_workload(
                 [new_data, new_kernel, strides, padding, dilation, out_dtype],
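For the depthwise HWOI kernels handled above, axes=(2, 3, 0, 1) is exactly the
HWOI -> OIHW permutation. A quick numerical check (illustrative only, assuming
64 channels and a channel multiplier of 1):

    import numpy as np

    k_hwoi = np.random.rand(3, 3, 64, 1).astype('float32')   # H, W, O, I
    k_oihw = np.transpose(k_hwoi, (2, 3, 0, 1))              # O, I, H, W
    assert k_oihw.shape == (64, 1, 3, 3)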
@@ -644,48 +693,3 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F):
             # currently we only have contrib_spatial_pack and direct template
             # add more schedule templates.
             return None
-
-@conv2d_legalize.register("arm_cpu")
-def _conv2d_legalize(attrs, inputs, arg_types):
-    """Legalizes Conv2D op.
-
-    Parameters
-    ----------
-    attrs : tvm.attrs.Attrs
-        Attributes of current convolution
-    inputs : list of tvm.relay.Expr
-        The args of the Relay expr to be legalized
-    types : list of types
-        List of input and output types
-
-    Returns
-    -------
-    result : tvm.relay.Expr
-        The legalized expr
-    """
-
-    if attrs['data_layout'] == 'NHWC':
-        data, kernel = inputs
-        if attrs['kernel_layout'] == 'HWIO':
-            # Handle HWIO layout. This is common in TF graph.
-            kernel = relay.transpose(kernel, axes=(3, 2, 0, 1))
-        elif attrs['kernel_layout'] == 'HWOI':
-            # Handle HWOI layout. This is common in TF depthwise conv2d graph.
-            kernel = relay.transpose(kernel, axes=(2, 3, 0, 1))
-        elif attrs['kernel_layout'] != 'OIHW':
-            return None
-
-        logger.warning("Legalize arm_cpu - NHWC schedule absent. Inserting layout transforms to "
-                       + "fallback to NCHW. This can result in performance degradation.")
-        # Set new attrs for the transposed conv.
-        new_attrs = {k: attrs[k] for k in attrs.keys()}
-        new_attrs['data_layout'] = 'NCHW'
-        new_attrs['kernel_layout'] = 'OIHW'
-
-        # Convert from NHWC to NCHW.
-        data = relay.transpose(data, axes=(0, 3, 1, 2))
-        conv = relay.nn.conv2d(data, kernel, **new_attrs)
-        # Convert back to original NHWC layout.
-        out = relay.transpose(conv, axes=(0, 2, 3, 1))
-        return out
-    return None
--
2.7.4
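As an end-to-end illustration of what this patch enables, the following
standalone sketch builds a single NHWC conv2d and runs AlterOpLayout on it.
It is illustrative only: it assumes the 0.6-era relay API used by the tests
above, and that an alter function for nn.conv2d is registered (for example
by registering _alter_conv2d_layout_arm as in the new test):

    import tvm
    from tvm import relay

    x = relay.var("x", shape=(1, 56, 56, 64))
    w = relay.var("w", shape=(3, 3, 64, 64))
    y = relay.nn.conv2d(x, w, channels=64, kernel_size=(3, 3),
                        data_layout='NHWC', kernel_layout='HWIO')
    func = relay.Function(relay.analysis.free_vars(y), y)
    mod = relay.Module.from_expr(func)
    mod = relay.transform.AlterOpLayout()(mod)
    # The conv2d now runs in NCHW between a pair of layout_transform ops.
    print(mod["main"])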