From 5498e54de80d26208b76bc4b3ecd7e3eacb7ff85 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Wed, 14 Aug 2019 16:44:13 -0700 Subject: [PATCH] [Relay][Legalize][ARM_CPU] Handling NHWC layout for arm_cpu. (#3754) --- python/tvm/relay/op/nn/_nn.py | 5 ++-- tests/python/relay/test_pass_legalize.py | 46 ++++++++++++++++++++++++++++++++ topi/python/topi/arm_cpu/conv2d.py | 31 +++++++++++++++++++++ topi/python/topi/nn/conv2d.py | 22 +++++++++++++++ 4 files changed, 102 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py index 3e12c1b..c896a00 100644 --- a/python/tvm/relay/op/nn/_nn.py +++ b/python/tvm/relay/op/nn/_nn.py @@ -204,10 +204,11 @@ def alter_op_layout_conv2d(attrs, inputs, tinfos): from ... import op return topi.nn.conv2d_alter_layout(attrs, inputs, tinfos, op) -# A placeholder to have at least one invocation of register legalize to register FTVMLegalize. @reg.register_legalize("nn.conv2d") def legalize_conv2d(attrs, inputs, arg_dtypes): - return None + """Legalize conv2d""" + from ... import op + return topi.nn.conv2d_legalize(attrs, inputs, arg_dtypes, op) reg.register_pattern("nn.conv2d", OpPattern.OUT_ELEMWISE_FUSABLE) diff --git a/tests/python/relay/test_pass_legalize.py b/tests/python/relay/test_pass_legalize.py index 364d6b4..52deeb5 100644 --- a/tests/python/relay/test_pass_legalize.py +++ b/tests/python/relay/test_pass_legalize.py @@ -15,9 +15,11 @@ # specific language governing permissions and limitations # under the License. 
"""Test legalize pass""" +import numpy as np import tvm from tvm import relay +from tvm.contrib import graph_runtime from tvm.relay.op import register_legalize from tvm.relay import transform, analysis @@ -123,8 +125,52 @@ def test_legalize_multi_input(): assert analysis.alpha_equal(a, b), "Actual = \n" + str(a) +def test_legalize_arm_layout_functional(): + """Test if the legalized conversion yields same result as original""" + def get_output(func, data_val, parameters): + with relay.build_config(opt_level=0): + graph, lib, params = relay.build(func, target='llvm', params=parameters) + m = graph_runtime.create(graph, lib, tvm.cpu()) + m.set_input("data", data_val) + m.set_input(**params) + m.run() + out = m.get_output(0, tvm.nd.empty((1, 224, 224, 32), 'float32')).asnumpy() + return out + + def before(): + n, ic, ih, iw, oc, kh, kw = 1, 16, 224, 224, 32, 3, 3 + data = relay.var("data", relay.TensorType((n, ih, iw, ic), 'float32')) + kernel = relay.var("kernel", relay.TensorType((kh, kw, ic, oc), 'float32')) + y = relay.nn.conv2d(data, kernel, + kernel_size=(kh, kw), + channels=oc, + padding=(1, 1), + dilation=(1, 1), + data_layout='NHWC', + kernel_layout='HWIO', + out_dtype='float32') + func = relay.Function([data, kernel], y) + return func + + @register_legalize("nn.conv2d", level=101) + def legalize_conv2d(attrs, inputs, arg_types): + from topi.arm_cpu.conv2d import _conv2d_legalize + return _conv2d_legalize(attrs, inputs, arg_types, tvm.relay.op) + + a = before() + b = run_opt_pass(a, transform.Legalize()) + assert b.astext().count('transpose') == 3 + + wdata = np.random.rand(3, 3, 16, 32) * 10 + parameters = {"kernel": tvm.nd.array(wdata.astype('float32'))} + data_val = np.random.rand(1, 224, 224, 16).astype('float32') + ref_out = get_output(a, data_val, parameters) + legalized_out = get_output(b, data_val, parameters) + np.testing.assert_allclose(ref_out, legalized_out, rtol=0.01) + if __name__ == "__main__": test_legalize() test_legalize_none() 
test_legalize_multi_input() + test_legalize_arm_layout_functional() diff --git a/topi/python/topi/arm_cpu/conv2d.py b/topi/python/topi/arm_cpu/conv2d.py index 62df8f9..95342b6 100644 --- a/topi/python/topi/arm_cpu/conv2d.py +++ b/topi/python/topi/arm_cpu/conv2d.py @@ -31,6 +31,7 @@ from ..nn import dilate, pad, conv2d, conv2d_alter_layout, \ conv2d_winograd_without_weight_transform, \ conv2d_winograd_nnpack_without_weight_transform, \ depthwise_conv2d_nchw +from ..nn import conv2d_legalize from ..nn.util import get_const_int, get_pad_tuple from ..nn.winograd_util import winograd_transform_matrices @@ -783,3 +784,33 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F): # currently we only have contrib_spatial_pack and direct template # add more schedule templates. return None + +@conv2d_legalize.register("arm_cpu") +def _conv2d_legalize(attrs, inputs, arg_types, F): + if F.__name__ != 'tvm.relay.op': + return None + if attrs['data_layout'] == 'NHWC': + data, kernel = inputs + if attrs['kernel_layout'] == 'HWIO': + # Handle HWIO layout. This is common in TF graph. + kernel = F.transpose(kernel, axes=(3, 2, 0, 1)) + elif attrs['kernel_layout'] == 'HWOI': + # Handle HWOI layout. This is common in TF depthwise conv2d graph. + kernel = F.transpose(kernel, axes=(2, 3, 0, 1)) + elif attrs['kernel_layout'] != 'OIHW': + return None + + warnings.warn("Legalize arm_cpu - NHWC schedule absent. Inserting layout transforms to " + + "fallback to NCHW. This can result in performance degradation.") + # Set new attrs for the transposed conv. + new_attrs = {k: attrs[k] for k in attrs.keys()} + new_attrs['data_layout'] = 'NCHW' + new_attrs['kernel_layout'] = 'OIHW' + + # Convert from NHWC to NCHW. + data = F.transpose(data, axes=(0, 3, 1, 2)) + conv = F.nn.conv2d(data, kernel, **new_attrs) + # Convert back to original NHWC layout.
+ out = F.transpose(conv, axes=(0, 2, 3, 1)) + return out + return None diff --git a/topi/python/topi/nn/conv2d.py b/topi/python/topi/nn/conv2d.py index affbfca..e7ab7ba 100644 --- a/topi/python/topi/nn/conv2d.py +++ b/topi/python/topi/nn/conv2d.py @@ -72,6 +72,28 @@ def conv2d(input, filter, strides, padding, dilation, layout='NCHW', out_dtype=N @tvm.target.generic_func +def conv2d_legalize(attrs, inputs, arg_dtypes, F): + """Legalizes Conv2D op. + Parameters + ---------- + attrs : nnvm.top.AttrDict or tvm.attrs.Attrs + Attributes of current convolution + inputs : list of tvm.relay.Expr + The args of the Relay expr to be legalized. + arg_dtypes : list of types + List of types of input arguments + F: symbol + The context, can be either nnvm.sym or relay.op + Note + ---- + Unlike other TOPI functions, this function operates on both graph level and operator level, + so we have to pass 'F' to make it support our two versions of graph IR, NNVM and Relay. + """ + # not to change by default + return None + + +@tvm.target.generic_func def conv2d_alter_layout(attrs, inputs, tinfos, F): """Change Conv2D layout. -- 2.7.4