Fix x86 depthwise conv2d alter_op_layout (#3264)
authorYao Wang <kevinthesunwy@gmail.com>
Thu, 6 Jun 2019 18:41:50 +0000 (11:41 -0700)
committerYizhi Liu <liuyizhi@apache.org>
Thu, 6 Jun 2019 18:41:50 +0000 (11:41 -0700)
* Fix x86 depthwise conv2d alter_op_layout

* Small fix

* Add test case

* Fix test

* Assert kernel layout

* Minor fix

* Add get_shape function

* Minor change

tests/python/relay/test_pass_alter_op_layout.py
topi/python/topi/arm_cpu/depthwise_conv2d.py
topi/python/topi/util.py
topi/python/topi/x86/conv2d.py
topi/python/topi/x86/depthwise_conv2d.py
topi/tests/python/test_topi_util.py [new file with mode: 0644]

index 2eea1c4..7d022ba 100644 (file)
@@ -15,6 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 """Test alter op layout pass"""
+import tvm
 
 from tvm import relay
 from tvm.relay.op import register_alter_op_layout
@@ -513,6 +514,45 @@ def test_alter_layout_strided_slice():
 
     assert alpha_equal(a, b), "Actual = \n" + str(a)
 
+def test_alter_layout_depthwise_conv2d():
+    """Test depthwise_conv2d operator"""
+    def before():
+        x = relay.var("x", shape=(1, 32, 56, 56))
+        w = relay.var("w", shape=(32, 1, 3, 3))
+        y = relay.nn.conv2d(x, w, padding=(1, 1), channels=32, kernel_size=(3, 3), groups=32)
+        y = relay.Function(free_vars(y), y)
+        return y
+
+    import topi
+    @register_alter_op_layout("nn.conv2d", level=110)
+    def alter_conv2d(attrs, inputs, tinfos):
+        with tvm.target.create("llvm"):
+            return topi.nn.conv2d_alter_layout(attrs, inputs, tinfos, relay)
+
+    def expected():
+        x = relay.var("x", shape=(1, 32, 56, 56))
+        w = relay.var("w", shape=(32, 1, 3, 3))
+        x = relay.layout_transform(x, "NCHW", "NCHW8c")
+        w = relay.layout_transform(w, "OIHW", "OIHW1i8o")
+        y = relay.nn.contrib_depthwise_conv2d_nchwc(x, w, padding=(1, 1), channels=32, kernel_size=(3, 3),
+                                                    groups=32, data_layout="NCHW8c", kernel_layout="OIHW1i8o",
+                                                    out_layout="NCHW8c")
+        y = relay.layout_transform(y, "NCHW8c", "NCHW")
+        y = relay.Function(free_vars(y), y)
+        return y
+
+    a = before()
+    a = infer_type(a)
+    a = canonicalize_ops(a)
+    a = infer_type(a)
+    a = alter_op_layout(a)
+    a = infer_type(a)
+
+    b = expected()
+    b = infer_type(b)
+
+    assert(alpha_equal(a, b))
+
 def test_alter_layout_prelu():
     """Test PRelu operator"""
     def before():
@@ -524,7 +564,7 @@ def test_alter_layout_prelu():
         y = relay.Function(free_vars(y), y)
         return y
 
-    @register_alter_op_layout("nn.conv2d", level=110)
+    @register_alter_op_layout("nn.conv2d", level=111)
     def alter_conv2d(attrs, inputs, tinfos):
         data, weight = inputs
         new_attrs = dict(attrs)
@@ -571,4 +611,5 @@ if __name__ == "__main__":
     test_alter_layout_concatenate()
     test_alter_layout_nchw_upsamping_op()
     test_alter_layout_strided_slice()
+    test_alter_layout_depthwise_conv2d()
     test_alter_layout_prelu()
index e09e355..51088df 100644 (file)
@@ -26,11 +26,11 @@ from ..util import traverse_inline, get_const_tuple, get_const_int
 from ..nn.util import get_pad_tuple
 
 # register original implementation of depthwise_conv2d_nchw since we don't need to change this part
-autotvm.register_topi_compute(depthwise_conv2d_nchw, ['arm_cpu', 'cpu'], 'direct',
+autotvm.register_topi_compute(depthwise_conv2d_nchw, 'arm_cpu', 'direct',
                               depthwise_conv2d_nchw.fdefault)
 
 # register customized schedule for arm cpu.
-@autotvm.register_topi_schedule(schedule_depthwise_conv2d_nchw, ['arm_cpu', 'cpu'],
+@autotvm.register_topi_schedule(schedule_depthwise_conv2d_nchw, 'arm_cpu',
                                 ['direct', 'contrib_spatial_pack'])
 def schedule_depthwise_conv2d_nchw_arm(cfg, outs):
     """Schedule depthwise conv2d
@@ -151,7 +151,7 @@ def schedule_depthwise_conv2d_nchw_arm(cfg, outs):
     traverse_inline(s, outs[0].op, _callback)
     return s
 
-@autotvm.register_topi_compute(depthwise_conv2d_nchw, ['arm_cpu', 'cpu'], ['contrib_spatial_pack'])
+@autotvm.register_topi_compute(depthwise_conv2d_nchw, 'arm_cpu', ['contrib_spatial_pack'])
 def depthwise_conv2d_arm_cpu(cfg, data, kernel, strides, padding, dilation, out_dtype):
     """TOPI compute callback for depthwise_conv2d nchw
 
index 623c81a..d4e23be 100644 (file)
@@ -20,6 +20,7 @@ from __future__ import absolute_import as _abs
 from numbers import Integral
 
 import tvm
+from tvm.api import layout, bijective_layout
 from . import tag
 
 def traverse_inline(s, final_op, callback):
@@ -289,3 +290,41 @@ def get_max_power2_factor(n, max_value=None):
         x *= 2
         n /= 2
     return x
+
+
+def get_shape(src_shape, src_layout, dst_layout):
+    """Given a source shape, a source layout and a destination layout, infer
+    the destination shape.
+
+    Parameter
+    ---------
+    src_shape : tuple of int or IntImm
+        Source shape
+
+    src_layout : str or Layout
+        Source layout
+
+    dst_layout : str or Layout
+        Destination layout
+
+    Returns
+    -------
+    dst_shape : tuple of int
+        Destination shape
+    """
+    if src_layout == dst_layout:
+        return get_const_tuple(src_shape)
+
+    if isinstance(src_layout, str):
+        src_layout = layout(src_layout)
+    if isinstance(dst_layout, str):
+        dst_layout = layout(dst_layout)
+
+    assert len(src_layout) == len(dst_layout), \
+        "Incompatible layout %s vs %s" % (src_layout, dst_layout)
+
+    layout_mapping = bijective_layout(src_layout, dst_layout)
+    dst_indices = layout_mapping.forward_index(
+        tvm.convert([i for i in range(len(src_layout))]))
+
+    return get_const_tuple(tuple([src_shape[i.value] for i in dst_indices]))
index d0894ad..82d1caa 100644 (file)
@@ -26,7 +26,7 @@ from tvm.autotvm.task.topi_integration import deserialize_args
 from tvm.autotvm.task import get_config
 from .. import generic, tag
 from .. import nn
-from ..util import get_const_tuple
+from ..util import get_const_tuple, get_shape
 from ..nn.conv2d import conv2d, conv2d_NCHWc, \
     conv2d_alter_layout, conv2d_infer_layout, _get_workload as _get_conv2d_workload
 from ..nn.depthwise_conv2d import _get_workload as _get_depthwise_conv2d_workload
@@ -415,11 +415,14 @@ def _alter_conv2d_layout(attrs, inputs, tinfo, F):
 
     dtype = data.dtype
     out_dtype = dtype if out_dtype in ("same", "") else out_dtype
-    is_depthwise = groups == in_channel and groups == out_channel
+
+    kshape = get_shape(kernel.shape, attrs["kernel_layout"], "OIHW")
+    is_depthwise = groups == kshape[0] and kshape[1] == 1
 
     # only optimize for NCHW
-    if layout != 'NCHW':
+    if layout != 'NCHW' or attrs["kernel_layout"] != "OIHW":
         return None
+
     if groups != 1 and not is_depthwise:
         return None
 
index 6ea11f2..ddcd841 100644 (file)
@@ -22,11 +22,12 @@ from tvm.autotvm.task import get_config
 from tvm.autotvm.task.space import SplitEntity
 from tvm.autotvm.task.topi_integration import deserialize_args
 from .. import generic, tag
+from ..generic import schedule_depthwise_conv2d_nchw
 from ..nn.pad import pad
 from ..util import get_const_tuple
 from ..nn.util import get_pad_tuple
-from ..nn.depthwise_conv2d import depthwise_conv2d_NCHWc, _get_workload, \
-    depthwise_conv2d_infer_layout
+from ..nn.depthwise_conv2d import depthwise_conv2d_nchw, depthwise_conv2d_NCHWc, \
+    _get_workload, depthwise_conv2d_infer_layout
 
 from .util import get_fp32_len
 
@@ -70,6 +71,12 @@ def _fallback_schedule(cfg, wkl):
     cfg["tile_ow"] = SplitEntity([out_width // reg_n, reg_n])
 
 
+autotvm.register_topi_compute(depthwise_conv2d_nchw, 'cpu', 'direct',
+                              depthwise_conv2d_nchw.fdefault)
+autotvm.register_topi_schedule(schedule_depthwise_conv2d_nchw, 'cpu', 'direct',
+                               schedule_depthwise_conv2d_nchw.fdefault)
+
+
 @autotvm.register_topi_compute(depthwise_conv2d_NCHWc, 'cpu', 'direct')
 def _depthwise_conv2d_NCHWc_cpu(cfg, data, kernel, strides, padding, dilation,
                                 layout, out_layout, out_dtype=None):
diff --git a/topi/tests/python/test_topi_util.py b/topi/tests/python/test_topi_util.py
new file mode 100644 (file)
index 0000000..534b699
--- /dev/null
@@ -0,0 +1,35 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Test code for util"""
+
+import topi
+
+
+def verify_get_shape(src_shape, src_layout, dst_layout, expect_shape):
+    dst_shape = topi.util.get_shape(src_shape, src_layout, dst_layout)
+    assert dst_shape == expect_shape, \
+        "Shape mismatch: expecting %s but got %s" % (expect_shape, dst_shape)
+
+
+def test_get_shape():
+    verify_get_shape((1, 3, 224, 224), "NCHW", "NCHW", (1, 3, 224, 224))
+    verify_get_shape((1, 3, 224, 224), "NCHW", "NHWC", (1, 224, 224, 3))
+    verify_get_shape((3, 2, 32, 48, 16), "NCHW16c", "NC16cWH", (3, 2, 16, 48, 32))
+    verify_get_shape((2, 3, 32, 32, 16, 8), "OIHW16i8o", "HWO8oI16i", (32, 32, 2, 8, 3, 16))
+
+if __name__ == "__main__":
+    test_get_shape()
\ No newline at end of file