[Relay][Legalize][ARM_CPU] Handling NHWC layout for arm_cpu. (#3754)
author: Animesh Jain <anijain@umich.edu>
Wed, 14 Aug 2019 23:44:13 +0000 (16:44 -0700)
committer: Yao Wang <kevinthesunwy@gmail.com>
Wed, 14 Aug 2019 23:44:13 +0000 (16:44 -0700)
python/tvm/relay/op/nn/_nn.py
tests/python/relay/test_pass_legalize.py
topi/python/topi/arm_cpu/conv2d.py
topi/python/topi/nn/conv2d.py

index 3e12c1b..c896a00 100644 (file)
@@ -204,10 +204,11 @@ def alter_op_layout_conv2d(attrs, inputs, tinfos):
     from ... import op
     return topi.nn.conv2d_alter_layout(attrs, inputs, tinfos, op)
 
-# A placeholder to have at least one invocation of register legalize to register FTVMLegalize.
 @reg.register_legalize("nn.conv2d")
 def legalize_conv2d(attrs, inputs, arg_dtypes):
-    return None
+    """Legalize conv2d by dispatching to topi.nn.conv2d_legalize, with the op module as F."""
+    from ... import op  # function-local import; presumably avoids an import cycle - TODO confirm
+    return topi.nn.conv2d_legalize(attrs, inputs, arg_dtypes, op)
 
 reg.register_pattern("nn.conv2d", OpPattern.OUT_ELEMWISE_FUSABLE)
 
index 364d6b4..52deeb5 100644 (file)
 # specific language governing permissions and limitations
 # under the License.
 """Test legalize pass"""
+import numpy as np
 import tvm
 
 from tvm import relay
+from tvm.contrib import graph_runtime
 from tvm.relay.op import register_legalize
 from tvm.relay import transform, analysis
 
@@ -123,8 +125,52 @@ def test_legalize_multi_input():
 
     assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
 
+def test_legalize_arm_layout_functional():
+    """Test that the legalized conv2d yields the same result as the original."""
+    def get_output(func, data_val, parameters):
+        # Build func at opt_level=0 for the llvm target, run it on tvm.cpu()
+        # and return output 0 as a numpy array.
+        with relay.build_config(opt_level=0):
+            graph, lib, params = relay.build(func, target='llvm', params=parameters)
+        m = graph_runtime.create(graph, lib, tvm.cpu())
+        m.set_input("data", data_val)
+        m.set_input(**params)
+        m.run()
+        out = m.get_output(0, tvm.nd.empty((1, 224, 224, 32), 'float32')).asnumpy()
+        return out
+
+    def before():
+        # conv2d with NHWC data and an HWIO kernel, i.e. the layouts that the
+        # arm_cpu legalization rewrites.
+        n, ic, ih, iw, oc, kh, kw = 1, 16, 224, 224, 32, 3, 3
+        data = relay.var("data", relay.TensorType((n, ih, iw, ic), 'float32'))
+        kernel = relay.var("kernel", relay.TensorType((kh, kw, ic, oc), 'float32'))
+        y = relay.nn.conv2d(data, kernel,
+                            kernel_size=(kh, kw),
+                            channels=oc,
+                            padding=(1, 1),
+                            dilation=(1, 1),
+                            data_layout='NHWC',
+                            kernel_layout='HWIO',
+                            out_dtype='float32')
+        func = relay.Function([data, kernel], y)
+        return func
+
+    # Call the arm_cpu legalization function directly; the build target used
+    # by get_output is llvm, not arm_cpu.
+    # NOTE(review): level=101 is presumably chosen so that this registration
+    # takes precedence over the existing nn.conv2d one - confirm.
+    @register_legalize("nn.conv2d", level=101)
+    def legalize_conv2d(attrs, inputs, arg_types):
+        from topi.arm_cpu.conv2d import _conv2d_legalize
+        return _conv2d_legalize(attrs, inputs, arg_types, tvm.relay.op)
+
+    a = before()
+    b = run_opt_pass(a, transform.Legalize())
+    assert b.astext().count('transpose') == 3  # data, kernel and output transposes
+
+    # Feed identical random data and weights to both versions and compare.
+    wdata = np.random.rand(3, 3, 16, 32) * 10
+    parameters = {"kernel": tvm.nd.array(wdata.astype('float32'))}
+    data_val = np.random.rand(1, 224, 224, 16).astype('float32')
+    ref_out = get_output(a, data_val, parameters)
+    legalized_out = get_output(b, data_val, parameters)
+    np.testing.assert_allclose(ref_out, legalized_out, rtol=0.01)
+
 
 if __name__ == "__main__":
     test_legalize()
     test_legalize_none()
     test_legalize_multi_input()
+    test_legalize_arm_layout_functional()
index 62df8f9..95342b6 100644 (file)
@@ -31,6 +31,7 @@ from ..nn import dilate, pad, conv2d, conv2d_alter_layout, \
                  conv2d_winograd_without_weight_transform, \
                  conv2d_winograd_nnpack_without_weight_transform, \
                  depthwise_conv2d_nchw
+from ..nn import conv2d_legalize
 from ..nn.util import get_const_int, get_pad_tuple
 from ..nn.winograd_util import winograd_transform_matrices
 
@@ -783,3 +784,33 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F):
             # currently we only have contrib_spatial_pack and direct template
             # add more schedule templates.
             return None
+
+@conv2d_legalize.register("arm_cpu")
+def _conv2d_legalize(attrs, inputs, arg_types, F):
+    """Legalizes NHWC Conv2D for arm_cpu by falling back to NCHW.
+
+    The data is transposed from NHWC to NCHW, an HWIO or HWOI kernel is
+    transposed to OIHW (an OIHW kernel is used as is), a conv2d is built with
+    the NCHW/OIHW layouts, and its result is transposed back to NHWC.
+
+    Parameters
+    ----------
+    attrs : nnvm.top.AttrDict or tvm.attrs.Attrs
+        Attributes of current convolution
+    inputs : list of tvm.relay.Expr
+        The args of the Relay expr to be legalized, unpacked as (data, kernel).
+    arg_types : list of types
+        List of types of input arguments (not used by this function)
+    F : symbol
+        The context; only the module named 'tvm.relay.op' is handled.
+
+    Returns
+    -------
+    result : expr or None
+        The rewritten expr, or None when F is not 'tvm.relay.op', the data
+        layout is not NHWC, or the kernel layout is not HWIO, HWOI or OIHW.
+    """
+    if F.__name__ != 'tvm.relay.op':
+        return None
+    if attrs['data_layout'] == 'NHWC':
+        data, kernel = inputs
+        if attrs['kernel_layout'] == 'HWIO':
+            # Handle HWIO layout. This is common in TF graph.
+            kernel = F.transpose(kernel, axes=(3, 2, 0, 1))
+        elif attrs['kernel_layout'] == 'HWOI':
+            # Handle HWOI layout. This is common in TF depthwise conv2d graph.
+            kernel = F.transpose(kernel, axes=(2, 3, 0, 1))
+        elif attrs['kernel_layout'] != 'OIHW':
+            return None
+
+        warnings.warn("Legalize arm_cpu - NHWC schedule absent. Inserting layout transforms to "
+                      + "fallback to NCHW. This can result in performance degradation.")
+        # Set new attrs for the transposed conv.
+        new_attrs = {k: attrs[k] for k in attrs.keys()}
+        new_attrs['data_layout'] = 'NCHW'
+        new_attrs['kernel_layout'] = 'OIHW'
+
+        # Convert from NHWC to NCHW.
+        data = F.transpose(data, axes=(0, 3, 1, 2))
+        conv = F.nn.conv2d(data, kernel, **new_attrs)
+        # Convert back to original NHWC layout.
+        out = F.transpose(conv, axes=(0, 2, 3, 1))
+        return out
+    return None
index affbfca..e7ab7ba 100644 (file)
@@ -72,6 +72,28 @@ def conv2d(input, filter, strides, padding, dilation, layout='NCHW', out_dtype=N
 
 
 @tvm.target.generic_func
+def conv2d_legalize(attrs, inputs, arg_dtypes, F):
+    """Legalizes Conv2D op.
+
+    Parameters
+    ----------
+    attrs : nnvm.top.AttrDict or tvm.attrs.Attrs
+        Attributes of current convolution
+    inputs : list of tvm.relay.Expr
+        The args of the Relay expr to be legalized.
+    arg_dtypes : list of types
+        List of types of input arguments
+    F: symbol
+        The context, can be either nnvm.sym or relay.op
+
+    Returns
+    -------
+    result : expr or None
+        The legalized expr, or None to leave the op unchanged. This generic
+        default always returns None.
+
+    Note
+    ----
+    Unlike other TOPI functions, this function operates on both graph level and operator level,
+    so we have to pass 'F' to make it support our two versions of graph IR, NNVM and Relay.
+    """
+    # not to change by default
+    return None
+
+
+@tvm.target.generic_func
 def conv2d_alter_layout(attrs, inputs, tinfos, F):
     """Change Conv2D layout.