# specific language governing permissions and limitations
# under the License.
"""Test alter op layout pass"""
+import tvm
from tvm import relay
from tvm.relay.op import register_alter_op_layout
assert alpha_equal(a, b), "Actual = \n" + str(a)
+def test_alter_layout_depthwise_conv2d():
+    """Test alter_op_layout on a depthwise conv2d (groups == channels)."""
+    def before():
+        # Depthwise convolution: 32 input channels, 32 groups, one 3x3 filter
+        # per channel, kernel shaped (channels, 1, kh, kw).
+        x = relay.var("x", shape=(1, 32, 56, 56))
+        w = relay.var("w", shape=(32, 1, 3, 3))
+        y = relay.nn.conv2d(x, w, padding=(1, 1), channels=32, kernel_size=(3, 3), groups=32)
+        y = relay.Function(free_vars(y), y)
+        return y
+
+    import topi
+    # level=110 so this registration takes priority over lower-level ones
+    # registered by sibling tests in this file.
+    @register_alter_op_layout("nn.conv2d", level=110)
+    def alter_conv2d(attrs, inputs, tinfos):
+        with tvm.target.create("llvm"):
+            return topi.nn.conv2d_alter_layout(attrs, inputs, tinfos, relay)
+
+    def expected():
+        # After the pass: data/kernel are moved to blocked layouts (NCHW8c /
+        # OIHW1i8o) and the op is replaced by contrib_depthwise_conv2d_nchwc,
+        # with a transform back to NCHW at the end.
+        x = relay.var("x", shape=(1, 32, 56, 56))
+        w = relay.var("w", shape=(32, 1, 3, 3))
+        x = relay.layout_transform(x, "NCHW", "NCHW8c")
+        w = relay.layout_transform(w, "OIHW", "OIHW1i8o")
+        y = relay.nn.contrib_depthwise_conv2d_nchwc(x, w, padding=(1, 1), channels=32, kernel_size=(3, 3),
+                                                    groups=32, data_layout="NCHW8c", kernel_layout="OIHW1i8o",
+                                                    out_layout="NCHW8c")
+        y = relay.layout_transform(y, "NCHW8c", "NCHW")
+        y = relay.Function(free_vars(y), y)
+        return y
+
+    a = before()
+    a = infer_type(a)
+    # canonicalize_ops runs first so the graph matches what alter_op_layout expects.
+    a = canonicalize_ops(a)
+    a = infer_type(a)
+    a = alter_op_layout(a)
+    a = infer_type(a)
+
+    b = expected()
+    b = infer_type(b)
+
+    assert(alpha_equal(a, b))
+
def test_alter_layout_prelu():
"""Test PRelu operator"""
def before():
y = relay.Function(free_vars(y), y)
return y
- @register_alter_op_layout("nn.conv2d", level=110)
+ @register_alter_op_layout("nn.conv2d", level=111)
def alter_conv2d(attrs, inputs, tinfos):
data, weight = inputs
new_attrs = dict(attrs)
test_alter_layout_concatenate()
test_alter_layout_nchw_upsamping_op()
test_alter_layout_strided_slice()
+ test_alter_layout_depthwise_conv2d()
test_alter_layout_prelu()
from ..nn.util import get_pad_tuple
# register original implementation of depthwise_conv2d_nchw since we don't need to change this part
-autotvm.register_topi_compute(depthwise_conv2d_nchw, ['arm_cpu', 'cpu'], 'direct',
+autotvm.register_topi_compute(depthwise_conv2d_nchw, 'arm_cpu', 'direct',
depthwise_conv2d_nchw.fdefault)
# register customized schedule for arm cpu.
-@autotvm.register_topi_schedule(schedule_depthwise_conv2d_nchw, ['arm_cpu', 'cpu'],
+@autotvm.register_topi_schedule(schedule_depthwise_conv2d_nchw, 'arm_cpu',
['direct', 'contrib_spatial_pack'])
def schedule_depthwise_conv2d_nchw_arm(cfg, outs):
"""Schedule depthwise conv2d
traverse_inline(s, outs[0].op, _callback)
return s
-@autotvm.register_topi_compute(depthwise_conv2d_nchw, ['arm_cpu', 'cpu'], ['contrib_spatial_pack'])
+@autotvm.register_topi_compute(depthwise_conv2d_nchw, 'arm_cpu', ['contrib_spatial_pack'])
def depthwise_conv2d_arm_cpu(cfg, data, kernel, strides, padding, dilation, out_dtype):
"""TOPI compute callback for depthwise_conv2d nchw
from numbers import Integral
import tvm
+from tvm.api import layout, bijective_layout
from . import tag
def traverse_inline(s, final_op, callback):
x *= 2
n /= 2
return x
+
+
+def get_shape(src_shape, src_layout, dst_layout):
+    """Given a source shape, a source layout and a destination layout, infer
+    the destination shape.
+
+    Parameters
+    ----------
+    src_shape : tuple of int or IntImm
+        Source shape
+
+    src_layout : str or Layout
+        Source layout
+
+    dst_layout : str or Layout
+        Destination layout
+
+    Returns
+    -------
+    dst_shape : tuple of int
+        Destination shape
+    """
+    # Fast path: identical layouts require no axis remapping.
+    if src_layout == dst_layout:
+        return get_const_tuple(src_shape)
+
+    # Normalize string layouts into Layout objects.
+    if isinstance(src_layout, str):
+        src_layout = layout(src_layout)
+    if isinstance(dst_layout, str):
+        dst_layout = layout(dst_layout)
+
+    assert len(src_layout) == len(dst_layout), \
+        "Incompatible layout %s vs %s" % (src_layout, dst_layout)
+
+    # forward_index maps each source axis position to its position in the
+    # destination layout; permute the shape accordingly.
+    layout_mapping = bijective_layout(src_layout, dst_layout)
+    dst_indices = layout_mapping.forward_index(
+        tvm.convert([i for i in range(len(src_layout))]))
+
+    return get_const_tuple(tuple([src_shape[i.value] for i in dst_indices]))
from tvm.autotvm.task import get_config
from .. import generic, tag
from .. import nn
-from ..util import get_const_tuple
+from ..util import get_const_tuple, get_shape
from ..nn.conv2d import conv2d, conv2d_NCHWc, \
conv2d_alter_layout, conv2d_infer_layout, _get_workload as _get_conv2d_workload
from ..nn.depthwise_conv2d import _get_workload as _get_depthwise_conv2d_workload
dtype = data.dtype
out_dtype = dtype if out_dtype in ("same", "") else out_dtype
- is_depthwise = groups == in_channel and groups == out_channel
+
+ kshape = get_shape(kernel.shape, attrs["kernel_layout"], "OIHW")
+ is_depthwise = groups == kshape[0] and kshape[1] == 1
# only optimize for NCHW
- if layout != 'NCHW':
+ if layout != 'NCHW' or attrs["kernel_layout"] != "OIHW":
return None
+
if groups != 1 and not is_depthwise:
return None
from tvm.autotvm.task.space import SplitEntity
from tvm.autotvm.task.topi_integration import deserialize_args
from .. import generic, tag
+from ..generic import schedule_depthwise_conv2d_nchw
from ..nn.pad import pad
from ..util import get_const_tuple
from ..nn.util import get_pad_tuple
-from ..nn.depthwise_conv2d import depthwise_conv2d_NCHWc, _get_workload, \
- depthwise_conv2d_infer_layout
+from ..nn.depthwise_conv2d import depthwise_conv2d_nchw, depthwise_conv2d_NCHWc, \
+ _get_workload, depthwise_conv2d_infer_layout
from .util import get_fp32_len
cfg["tile_ow"] = SplitEntity([out_width // reg_n, reg_n])
+autotvm.register_topi_compute(depthwise_conv2d_nchw, 'cpu', 'direct',
+ depthwise_conv2d_nchw.fdefault)
+autotvm.register_topi_schedule(schedule_depthwise_conv2d_nchw, 'cpu', 'direct',
+ schedule_depthwise_conv2d_nchw.fdefault)
+
+
@autotvm.register_topi_compute(depthwise_conv2d_NCHWc, 'cpu', 'direct')
def _depthwise_conv2d_NCHWc_cpu(cfg, data, kernel, strides, padding, dilation,
layout, out_layout, out_dtype=None):
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Test code for util"""
+
+import topi
+
+
+def verify_get_shape(src_shape, src_layout, dst_layout, expect_shape):
+    """Assert that topi.util.get_shape maps src_shape from src_layout to dst_layout."""
+    dst_shape = topi.util.get_shape(src_shape, src_layout, dst_layout)
+    assert dst_shape == expect_shape, \
+        "Shape mismatch: expecting %s but got %s" % (expect_shape, dst_shape)
+
+
+def test_get_shape():
+    """Exercise get_shape on identity, pure-permutation and blocked (sub-axis) layouts."""
+    verify_get_shape((1, 3, 224, 224), "NCHW", "NCHW", (1, 3, 224, 224))
+    verify_get_shape((1, 3, 224, 224), "NCHW", "NHWC", (1, 224, 224, 3))
+    # Blocked layouts: lower-case axes (16c, 16i, 8o) are sub-axes of the
+    # corresponding upper-case primal axis.
+    verify_get_shape((3, 2, 32, 48, 16), "NCHW16c", "NC16cWH", (3, 2, 16, 48, 32))
+    verify_get_shape((2, 3, 32, 32, 16, 8), "OIHW16i8o", "HWO8oI16i", (32, 32, 2, 8, 3, 16))
+
+if __name__ == "__main__":
+    test_get_shape()
\ No newline at end of file