From b330d3012140dc3f2d91f5615985ca78916a827e Mon Sep 17 00:00:00 2001 From: Jon Soifer Date: Wed, 25 Sep 2019 22:48:50 -0700 Subject: [PATCH] [TOPI][x86] Introduce schedule_injective_from_existing and unify external schedules for all targets (#3983) * Fix extern schedule for x86 * Register x86::schedule_extern * Fix * Fix * Replace extern.py with extern.h * Introduce new generic function schedule_injective_from_existing * Fix * Fix * Add back to C++ * Fix style * Injective schedule calls local schedule_injective_from_existing * Fix * Remove target arg from schedule_injective_from_existing * Fix docs * Try to fix unit test * Fix test * Fix other tests * Fix bug --- tests/cpp/build_module_test.cc | 4 + .../python/unittest/test_runtime_heterogeneous.py | 6 +- topi/include/topi/cuda/extern.h | 86 ---------------------- topi/include/topi/cuda/injective.h | 29 ++++---- topi/include/topi/generic/extern.h | 10 +++ topi/include/topi/generic/injective.h | 15 +++- topi/include/topi/x86/injective.h | 32 +++++--- topi/python/topi/arm_cpu/injective.py | 35 +++++++-- topi/python/topi/cuda/__init__.py | 1 - topi/python/topi/cuda/conv2d_int8.py | 6 +- topi/python/topi/cuda/extern.py | 48 ------------ topi/python/topi/cuda/group_conv2d_nchw.py | 6 +- topi/python/topi/cuda/injective.py | 41 +++++++---- topi/python/topi/cuda/reduction.py | 6 +- topi/python/topi/cuda/softmax.py | 4 +- topi/python/topi/cuda/sort.py | 4 +- topi/python/topi/cuda/vision.py | 4 +- topi/python/topi/generic/extern.py | 8 +- topi/python/topi/generic/injective.py | 21 +++++- topi/python/topi/hls/injective.py | 25 ++++++- topi/python/topi/opengl/injective.py | 23 ++++-- topi/python/topi/x86/dense.py | 4 + topi/python/topi/x86/injective.py | 35 +++++++-- topi/src/topi.cc | 45 +++++++++-- 24 files changed, 273 insertions(+), 225 deletions(-) delete mode 100644 topi/include/topi/cuda/extern.h delete mode 100644 topi/python/topi/cuda/extern.py diff --git a/tests/cpp/build_module_test.cc 
b/tests/cpp/build_module_test.cc index 1a7f791..a7237db 100644 --- a/tests/cpp/build_module_test.cc +++ b/tests/cpp/build_module_test.cc @@ -102,7 +102,11 @@ TEST(BuildModule, Heterogeneous) { return copy[i] - C[i]; }, "elemwise_sub"); + const runtime::PackedFunc* enter_target_scope_func = runtime::Registry::Get("_EnterTargetScope"); + (*enter_target_scope_func)(target_cuda); auto s1 = topi::cuda::schedule_injective(target_cuda, {elemwise_add}); + + (*enter_target_scope_func)(target_llvm); auto s2 = create_schedule({elemwise_sub->op}); auto config = BuildConfig::Create(); diff --git a/tests/python/unittest/test_runtime_heterogeneous.py b/tests/python/unittest/test_runtime_heterogeneous.py index e0ae2e1..eb51fae 100644 --- a/tests/python/unittest/test_runtime_heterogeneous.py +++ b/tests/python/unittest/test_runtime_heterogeneous.py @@ -174,7 +174,8 @@ def test_simplex_data_transferring(): dev_tar = {"cuda": "cuda", "opencl": "opencl"} for device, target in dev_tar.items(): - check_device(device, target) + with tvm.target.create(device): + check_device(device, target) def get_duplex_graph(host_dev_type, device_dev_type): @@ -394,7 +395,8 @@ def test_duplex_data_transferring(): dev_tar = {"cuda": "cuda", "opencl": "opencl"} for device, target in dev_tar.items(): - check_device(device, target) + with tvm.target.create(device): + check_device(device, target) if __name__ == "__main__": test_simplex_data_transferring() diff --git a/topi/include/topi/cuda/extern.h b/topi/include/topi/cuda/extern.h deleted file mode 100644 index 7800986..0000000 --- a/topi/include/topi/cuda/extern.h +++ /dev/null @@ -1,86 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. 
The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file cuda/extern.h - * \brief CUDA schedule for extern followed by injective operations - */ -#ifndef TOPI_CUDA_EXTERN_H_ -#define TOPI_CUDA_EXTERN_H_ - -#include "topi/tags.h" -#include "topi/detail/fuse.h" -#include "tvm/operation.h" -#include "tvm/build_module.h" - -namespace topi { -using namespace tvm; - -namespace cuda { -/*! - * \brief Schedule a given operation representing one of the outputs of an - * external function which is followed by injective operations. - * - * \param target The target to generate a schedule for. - * \param op The operation representing the output followed by injective operations. - * \param sch The schedule to apply this scheduling to - * - * \return The schedule given by sch - */ -inline Schedule ScheduleOutputForExtern(Target target, Operation op, Schedule sch) { - auto x = op.output(0); - auto fused = detail::Fuse(sch[x], sch[x]->op.as()->axis); - auto num_thread = target->max_num_threads; - IterVar bx, tx; - sch[x].split(fused, num_thread, &bx, &tx); - sch[x].bind(bx, tvm::thread_axis(Range(), "blockIdx.x")); - sch[x].bind(tx, tvm::thread_axis(Range(), "threadIdx.x")); - return sch; -} - -/*! -* \brief Schedule an extern op followed by injective operations. -* For example, cudnn kernel + bias add + relu -* -* \param target The target to generate a schedule for. -* \param outs The output tensors. 
-* -* \return A schedule for the op. -*/ -inline Schedule schedule_extern(const Target& target, Array outs) { - Array out_ops; - for (auto t : outs) { - out_ops.push_back(t->op); - } - auto s = create_schedule(out_ops); - - tvm::schedule::AutoInlineInjective(s); - for (auto out : outs) { - if (out->op->derived_from()) { - continue; - } - ScheduleOutputForExtern(target, out->op, s); - } - - return s; -} - -} // namespace cuda -} // namespace topi -#endif // TOPI_CUDA_EXTERN_H_ diff --git a/topi/include/topi/cuda/injective.h b/topi/include/topi/cuda/injective.h index e629ae1..663bc1f 100644 --- a/topi/include/topi/cuda/injective.h +++ b/topi/include/topi/cuda/injective.h @@ -33,21 +33,24 @@ namespace topi { using namespace tvm; namespace cuda { + /*! -* \brief Schedule a given injective operation. -* -* \param target The target to generate a schedule for. -* \param op The operation representing the injective operation. -* \param s The schedule to apply this scheduling to -*/ -inline void ScheduleInjectiveOp(const Target &target, Operation op, Schedule s) { - auto x = op.output(0); - auto fused = detail::Fuse(s[x], s[x]->op.as()->axis); + * \brief Updates an existing schedule for the given injective ops. + * + * \param sch The schedule to update. + * \param out The tensor representing the injective op. + * + * \return The updated schedule. + */ +inline Schedule schedule_injective_from_existing(Schedule sch, const Tensor& out) { + auto fused = detail::Fuse(sch[out], sch[out]->op.as()->axis); + auto target = Target::Current(false); auto num_thread = target->max_num_threads; IterVar bx, tx; - s[x].split(fused, num_thread, &bx, &tx); - s[x].bind(bx, thread_axis(Range(), "blockIdx.x")); - s[x].bind(tx, thread_axis(Range(), "threadIdx.x")); + sch[out].split(fused, num_thread, &bx, &tx); + sch[out].bind(bx, thread_axis(Range(), "blockIdx.x")); + sch[out].bind(tx, thread_axis(Range(), "threadIdx.x")); + return sch; } /*! 
@@ -66,7 +69,7 @@ inline Schedule schedule_injective(const Target &target, const Array& ou auto s = create_schedule(out_ops); tvm::schedule::AutoInlineInjective(s); for (auto out : outs) { - ScheduleInjectiveOp(target, out->op, s); + schedule_injective_from_existing(s, out); } return s; } diff --git a/topi/include/topi/generic/extern.h b/topi/include/topi/generic/extern.h index 6c51d7c..5c0c392 100644 --- a/topi/include/topi/generic/extern.h +++ b/topi/include/topi/generic/extern.h @@ -28,6 +28,7 @@ #include "topi/detail/fuse.h" #include "tvm/operation.h" #include "tvm/build_module.h" +#include "injective.h" namespace topi { using namespace tvm; @@ -47,6 +48,15 @@ inline Schedule schedule_extern(const Target& target, Array outs) { out_ops.push_back(t->op); } auto s = create_schedule(out_ops); + + tvm::schedule::AutoInlineInjective(s); + for (auto out : outs) { + if (out->op->derived_from()) { + continue; + } + tvm::GenericFunc::Get("schedule_injective_from_existing")(s, out); + } + return s; } diff --git a/topi/include/topi/generic/injective.h b/topi/include/topi/generic/injective.h index f26651c..fa7df4c 100644 --- a/topi/include/topi/generic/injective.h +++ b/topi/include/topi/generic/injective.h @@ -35,6 +35,19 @@ using namespace tvm; namespace generic { /*! + * \brief Updates an existing schedule for the given injective ops. + * + * \param sch The schedule to update. + * \param out The tensor representing the injective op. + * + * \return The updated schedule. + */ +inline Schedule schedule_injective_from_existing(Schedule sch, const Tensor& out) { + detail::Fuse(sch[out], sch[out]->op.as()->axis); + return sch; +} + +/*! * \brief Create a generic schedule for the given injective ops. * * \param target The target to generate a schedule for. 
@@ -50,7 +63,7 @@ inline Schedule schedule_injective(const Target &target, const Array& ou auto s = create_schedule(out_ops); tvm::schedule::AutoInlineInjective(s); auto x = outs[0]; - detail::Fuse(s[x], s[x]->op.as()->axis); + schedule_injective_from_existing(s, x); return s; } diff --git a/topi/include/topi/x86/injective.h b/topi/include/topi/x86/injective.h index adbe382..7cb79ae 100644 --- a/topi/include/topi/x86/injective.h +++ b/topi/include/topi/x86/injective.h @@ -33,6 +33,28 @@ namespace topi { using namespace tvm; namespace x86 { + +/*! + * \brief Updates an existing schedule for the given injective ops. + * + * \param sch The schedule to update. + * \param out The tensor representing the injective op. + * + * \return The updated schedule. + */ +inline Schedule schedule_injective_from_existing(Schedule sch, const Tensor& out) { + auto axis = sch[out]->op.as()->axis; + if (axis.size() == 4) { + auto n = axis[0]; + auto c = axis[1]; + auto fused = detail::Fuse(sch[out], { n, c }); // for nhwc layout, fuse n and h + sch[out].parallel(fused); + } else { + sch[out].parallel(axis[0]); + } + return sch; +} + /*! * \brief Create an x86 schedule for the given injective ops. * @@ -50,15 +72,7 @@ inline Schedule schedule_injective(const Target &target, const Array& ou tvm::schedule::AutoInlineInjective(s); auto x = outs[0]; - auto axis = s[x]->op.as()->axis; - if (axis.size() == 4) { - auto n = axis[0]; - auto c = axis[1]; - auto fused = detail::Fuse(s[x], { n, c }); // for nhwc layout, fuse n and h - s[x].parallel(fused); - } else { - s[x].parallel(axis[0]); - } + schedule_injective_from_existing(s, x); return s; } diff --git a/topi/python/topi/arm_cpu/injective.py b/topi/python/topi/arm_cpu/injective.py index 028558f..3727754 100644 --- a/topi/python/topi/arm_cpu/injective.py +++ b/topi/python/topi/arm_cpu/injective.py @@ -19,6 +19,32 @@ import tvm from .. 
import generic +@generic.schedule_injective_from_existing.register(["arm_cpu"]) +def schedule_injective_from_existing(sch, out): + """Schedule for injective op from existing schedule. + + Parameters + ---------- + sch: Schedule + The schedule to update. + out: Tensor + The tensor representing the injective op. + + Returns + ------- + sch: Schedule + The updated schedule. + """ + if len(sch[out].op.axis) >= 4: + fused = sch[out].fuse(sch[out].op.axis[0], sch[out].op.axis[1], sch[out].op.axis[2]) + sch[out].parallel(fused) + elif len(sch[out].op.axis) >= 3: + fused = sch[out].fuse(sch[out].op.axis[0], sch[out].op.axis[1]) + sch[out].parallel(fused) + elif len(sch[out].op.axis) >= 2: + sch[out].parallel(sch[out].op.axis[0]) + return sch + @generic.schedule_injective.register(["arm_cpu"]) def schedule_injective(outs): """ARM CPU schedule for injective op. @@ -42,14 +68,7 @@ def schedule_injective(outs): (io, ii) = s[x].split(list(s[x].op.axis)[-1], 8) s[x].vectorize(ii) tvm.schedule.AutoInlineInjective(s) - if len(s[x].op.axis) >= 4: - fused = s[x].fuse(s[x].op.axis[0], s[x].op.axis[1], s[x].op.axis[2]) - s[x].parallel(fused) - elif len(s[x].op.axis) >= 3: - fused = s[x].fuse(s[x].op.axis[0], s[x].op.axis[1]) - s[x].parallel(fused) - elif len(s[x].op.axis) >= 2: - s[x].parallel(s[x].op.axis[0]) + schedule_injective_from_existing(s, x) return s @generic.schedule_concatenate.register(["arm_cpu"]) diff --git a/topi/python/topi/cuda/__init__.py b/topi/python/topi/cuda/__init__.py index 403f67b..beab0af 100644 --- a/topi/python/topi/cuda/__init__.py +++ b/topi/python/topi/cuda/__init__.py @@ -13,7 +13,6 @@ from .softmax import schedule_softmax from .injective import schedule_injective, schedule_elemwise, schedule_broadcast from .dense import schedule_dense from .pooling import schedule_pool, schedule_adaptive_pool -from .extern import schedule_extern from .nn import schedule_lrn, schedule_l2_normalize from .batch_matmul import schedule_batch_matmul from .vision import * 
diff --git a/topi/python/topi/cuda/conv2d_int8.py b/topi/python/topi/cuda/conv2d_int8.py index 25d4945..580cf96 100644 --- a/topi/python/topi/cuda/conv2d_int8.py +++ b/topi/python/topi/cuda/conv2d_int8.py @@ -19,7 +19,7 @@ import tvm from tvm import autotvm -from .injective import _schedule_injective +from .injective import schedule_injective_from_existing from .tensor_intrin import dp4a from ..nn.pad import pad from ..nn.util import get_pad_tuple @@ -172,8 +172,8 @@ def schedule_conv2d_NCHWc_int8(cfg, s, output): if isinstance(packed_kernel.op, tvm.tensor.ComputeOp) and\ packed_kernel.name == 'packed_kernel': # data and kernel are not pre-computed, schedule layout transform here - _schedule_injective(packed_data.op, s) - _schedule_injective(packed_kernel.op, s) + schedule_injective_from_existing(s, packed_data) + schedule_injective_from_existing(s, packed_kernel) if pad_data != packed_data: s[pad_data].compute_inline() diff --git a/topi/python/topi/cuda/extern.py b/topi/python/topi/cuda/extern.py deleted file mode 100644 index 74b88ec..0000000 --- a/topi/python/topi/cuda/extern.py +++ /dev/null @@ -1,48 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-# pylint: disable=invalid-name, unused-variable, -"""Schedule for cudnn and miopen extern op""" -import tvm -from .. import generic -from .injective import _schedule_injective - - -@generic.schedule_extern.register(["cuda", "gpu"]) -def schedule_extern(outs): - """Schedule for an extern op followed by injective operations. - For example, cudnn kernel + bias add + relu. - - Parameters - ---------- - outs: Array of Tensor - The computation graph description of extern plus injective ops in the format - of an array of tensors. - - Returns - ------- - sch: Schedule - The computation schedule for the op. - """ - outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs - s = tvm.create_schedule([x.op for x in outs]) - - tvm.schedule.AutoInlineInjective(s) - for out in outs: - if isinstance(out.op, tvm.tensor.ExternOp): - continue - _schedule_injective(out.op, s) - return s diff --git a/topi/python/topi/cuda/group_conv2d_nchw.py b/topi/python/topi/cuda/group_conv2d_nchw.py index cd8c823..f4bb734 100644 --- a/topi/python/topi/cuda/group_conv2d_nchw.py +++ b/topi/python/topi/cuda/group_conv2d_nchw.py @@ -19,7 +19,7 @@ import tvm from tvm import autotvm -from .injective import _schedule_injective +from .injective import schedule_injective_from_existing from .tensor_intrin import dp4a from ..nn.pad import pad from ..nn.util import get_pad_tuple @@ -201,8 +201,8 @@ def schedule_group_conv2d_NCHWc_int8(cfg, s, output): if isinstance(packed_kernel.op, tvm.tensor.ComputeOp) and\ packed_kernel.name == 'packed_kernel': # data and kernel are not pre-computed, schedule layout transform here - _schedule_injective(packed_data.op, s) - _schedule_injective(packed_kernel.op, s) + schedule_injective_from_existing(s, packed_data) + schedule_injective_from_existing(s, packed_kernel) if pad_data != packed_data: s[pad_data].compute_inline() diff --git a/topi/python/topi/cuda/injective.py b/topi/python/topi/cuda/injective.py index 8a16c5c..a6ec853 100644 --- 
a/topi/python/topi/cuda/injective.py +++ b/topi/python/topi/cuda/injective.py @@ -19,33 +19,46 @@ import tvm from .. import generic, util -def _schedule_injective(op, sch): - x = op.output(0) - fused = sch[x].fuse(*sch[x].op.axis) +@generic.schedule_injective_from_existing.register(["cuda", "gpu"]) +def schedule_injective_from_existing(sch, out): + """Schedule for injective op from existing schedule. + + Parameters + ---------- + sch: Schedule + The schedule to update. + out: Tensor + The tensor representing the injective op. + + Returns + ------- + sch: Schedule + The updated schedule. + """ + fused = sch[out].fuse(*sch[out].op.axis) num_thread = tvm.target.current_target(allow_none=False).max_num_threads max_block = 256 try: - const_size = util.get_const_int(util.prod(x.shape)) + const_size = util.get_const_int(util.prod(out.shape)) max_block = 256 need_block_split = const_size > max_block * num_thread except ValueError: need_block_split = False if need_block_split: - xo, xi = sch[x].split(fused, factor=num_thread * max_block) - bx, tx = sch[x].split(xi, factor=num_thread) - sch[x].reorder(bx, tx, xo) - sch[x].bind(bx, tvm.thread_axis("blockIdx.x")) - sch[x].bind(tx, tvm.thread_axis("threadIdx.x")) + xo, xi = sch[out].split(fused, factor=num_thread * max_block) + bx, tx = sch[out].split(xi, factor=num_thread) + sch[out].reorder(bx, tx, xo) + sch[out].bind(bx, tvm.thread_axis("blockIdx.x")) + sch[out].bind(tx, tvm.thread_axis("threadIdx.x")) else: - bx, tx = sch[x].split(fused, factor=num_thread) - sch[x].bind(tx, tvm.thread_axis("threadIdx.x")) - sch[x].bind(bx, tvm.thread_axis("blockIdx.x")) + bx, tx = sch[out].split(fused, factor=num_thread) + sch[out].bind(tx, tvm.thread_axis("threadIdx.x")) + sch[out].bind(bx, tvm.thread_axis("blockIdx.x")) return sch - @generic.schedule_injective.register(["cuda", "gpu"]) def schedule_injective(outs): """Schedule for injective op. 
@@ -66,7 +79,7 @@ def schedule_injective(outs): tvm.schedule.AutoInlineInjective(s) for out in outs: - _schedule_injective(out.op, s) + schedule_injective_from_existing(s, out) return s schedule_elemwise = schedule_injective diff --git a/topi/python/topi/cuda/reduction.py b/topi/python/topi/cuda/reduction.py index 2588531..a7a56b9 100644 --- a/topi/python/topi/cuda/reduction.py +++ b/topi/python/topi/cuda/reduction.py @@ -20,7 +20,7 @@ from __future__ import absolute_import as _abs import tvm from .. import tag from .. import generic -from .injective import _schedule_injective +from .injective import schedule_injective_from_existing def _schedule_reduce(op, sch, is_idx_reduce=False): if is_idx_reduce: @@ -30,7 +30,7 @@ def _schedule_reduce(op, sch, is_idx_reduce=False): data_out = op.output(0) if not sch[data_out].op.reduce_axis: - return _schedule_injective(op, sch) + return schedule_injective_from_existing(sch, op.output(0)) if len(sch[data_out].op.axis) > 0: all_reduce = False @@ -126,7 +126,7 @@ def schedule_reduce(outs): """Internal travserse function""" if tag.is_broadcast(operator.tag): if operator not in scheduled_ops: - _schedule_injective(operator, sch) + schedule_injective_from_existing(sch, operator.output(0)) for tensor in operator.input_tensors: traverse_after_reduce(tensor.op) elif operator.tag == 'comm_reduce': diff --git a/topi/python/topi/cuda/softmax.py b/topi/python/topi/cuda/softmax.py index 09b6ef8..9fdac50 100644 --- a/topi/python/topi/cuda/softmax.py +++ b/topi/python/topi/cuda/softmax.py @@ -18,7 +18,7 @@ """Schedule for softmax operator""" import tvm from .. 
import generic -from .injective import _schedule_injective +from .injective import schedule_injective_from_existing @generic.schedule_softmax.register(["cuda", "gpu"]) def schedule_softmax(outs): @@ -58,7 +58,7 @@ def schedule_softmax(outs): ops.append(exp.op) for op in ops: - s = _schedule_injective(op, s) + s = schedule_injective_from_existing(s, op.output(0)) else: num_thread = 64 block_x = tvm.thread_axis("blockIdx.x") diff --git a/topi/python/topi/cuda/sort.py b/topi/python/topi/cuda/sort.py index 1d9148f..c45465e 100644 --- a/topi/python/topi/cuda/sort.py +++ b/topi/python/topi/cuda/sort.py @@ -42,10 +42,10 @@ def _schedule_sort(outs): outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs s = tvm.create_schedule([x.op for x in outs]) scheduled_ops = [] - from .injective import _schedule_injective + from .injective import schedule_injective_from_existing def traverse(op): if tag.is_injective(op.tag): - _schedule_injective(op, s) + schedule_injective_from_existing(s, op.output(0)) for tensor in op.input_tensors: if tensor.op.input_tensors and tensor.op not in scheduled_ops: traverse(tensor.op) diff --git a/topi/python/topi/cuda/vision.py b/topi/python/topi/cuda/vision.py index 968e554..3a90402 100644 --- a/topi/python/topi/cuda/vision.py +++ b/topi/python/topi/cuda/vision.py @@ -28,10 +28,10 @@ def _default_schedule(outs): outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs s = tvm.create_schedule([x.op for x in outs]) scheduled_ops = [] - from .injective import _schedule_injective + from .injective import schedule_injective_from_existing def traverse(op): if tag.is_broadcast(op.tag) or op.tag in ['bbox_score', 'sorted_bbox']: - _schedule_injective(op, s) + schedule_injective_from_existing(s, op.output(0)) for tensor in op.input_tensors: if tensor.op.input_tensors and tensor.op not in scheduled_ops: traverse(tensor.op) diff --git a/topi/python/topi/generic/extern.py b/topi/python/topi/generic/extern.py index 3728222..a060114 100644 --- 
a/topi/python/topi/generic/extern.py +++ b/topi/python/topi/generic/extern.py @@ -19,6 +19,7 @@ from __future__ import absolute_import as _abs import tvm +from .. import cpp @tvm.target.generic_func def schedule_extern(outs): @@ -35,8 +36,5 @@ def schedule_extern(outs): sch: Schedule The computation schedule for the op. """ - target = tvm.target.current_target(allow_none=False) - if target.target_name != "llvm": - raise RuntimeError("schedule_extern not registered for '%s'" % target) - outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs - return tvm.create_schedule([x.op for x in outs]) + target = tvm.target.current_target() + return cpp.generic.schedule_extern(target, outs) diff --git a/topi/python/topi/generic/injective.py b/topi/python/topi/generic/injective.py index ec1732f..178363d 100644 --- a/topi/python/topi/generic/injective.py +++ b/topi/python/topi/generic/injective.py @@ -20,6 +20,25 @@ from __future__ import absolute_import as _abs import tvm +@tvm.target.override_native_generic_func("schedule_injective_from_existing") +def schedule_injective_from_existing(sch, out): + """Schedule for injective op from existing schedule. + + Parameters + ---------- + sch: Schedule + The schedule to update. + out: Tensor + The tensor representing the injective op. + + Returns + ------- + sch: Schedule + The updated schedule. + """ + sch[out].fuse(*sch[out].op.axis) + return sch + @tvm.target.override_native_generic_func("schedule_injective") def schedule_injective(outs): """Schedule for injective op. 
@@ -42,7 +61,7 @@ def schedule_injective(outs): x = outs[0] s = tvm.create_schedule([x.op for x in outs]) tvm.schedule.AutoInlineInjective(s) - s[x].fuse(s[x].op.axis) + schedule_injective_from_existing(s, x) return s @tvm.target.generic_func diff --git a/topi/python/topi/hls/injective.py b/topi/python/topi/hls/injective.py index 5a43306..de58428 100644 --- a/topi/python/topi/hls/injective.py +++ b/topi/python/topi/hls/injective.py @@ -19,6 +19,27 @@ import tvm from .. import generic +@generic.schedule_injective_from_existing.register(["hls"]) +def schedule_injective_from_existing(sch, out): + """Schedule for injective op from existing schedule. + + Parameters + ---------- + sch: Schedule + The schedule to update. + out: Tensor + The tensor representing the injective op. + + Returns + ------- + sch: Schedule + The updated schedule. + """ + fused = sch[out].fuse(*sch[out].op.axis) + px, x = sch[out].split(fused, nparts=1) + sch[out].bind(px, tvm.thread_axis("pipeline")) + return sch + @generic.schedule_injective.register(["hls"]) def schedule_injective(outs): """Schedule for injective op. @@ -38,9 +59,7 @@ def schedule_injective(outs): s = tvm.create_schedule([x.op for x in outs]) tvm.schedule.AutoInlineInjective(s) for out in outs: - fused = s[out].fuse(*s[out].op.axis) - px, x = s[out].split(fused, nparts=1) - s[out].bind(px, tvm.thread_axis("pipeline")) + schedule_injective_from_existing(s, out) return s schedule_elemwise = schedule_injective diff --git a/topi/python/topi/opengl/injective.py b/topi/python/topi/opengl/injective.py index 1f3f494..d3ebc94 100644 --- a/topi/python/topi/opengl/injective.py +++ b/topi/python/topi/opengl/injective.py @@ -19,11 +19,24 @@ import tvm from .. import generic -def _schedule_injective(op, sch): - x = op.output(0) - sch[x].opengl() - return sch +@generic.schedule_injective_from_existing.register(["opengl"]) +def schedule_injective_from_existing(sch, out): + """Schedule for injective op from existing schedule. 
+ + Parameters + ---------- + sch: Schedule + The schedule to update. + out: Tensor + The tensor representing the injective op. + Returns + ------- + sch: Schedule + The updated schedule. + """ + sch[out].opengl() + return sch @generic.schedule_injective.register(["opengl"]) def schedule_injective(outs): @@ -45,7 +58,7 @@ def schedule_injective(outs): tvm.schedule.AutoInlineInjective(s) for out in outs: - _schedule_injective(out.op, s) + schedule_injective_from_existing(s, out) return s schedule_elemwise = schedule_injective diff --git a/topi/python/topi/x86/dense.py b/topi/python/topi/x86/dense.py index 27cadfa..9b7624e 100644 --- a/topi/python/topi/x86/dense.py +++ b/topi/python/topi/x86/dense.py @@ -111,6 +111,10 @@ def _declaration_dense_nopack(cfg, data, weight, bias=None, out_dtype=None): @autotvm.register_topi_schedule(generic.schedule_dense, "cpu", "direct") def _schedule_dense(cfg, outs): + target = tvm.target.current_target() + if "cblas" in target.libs: + return generic.schedule_extern(outs) + s = tvm.create_schedule([x.op for x in outs]) def _callback(op): diff --git a/topi/python/topi/x86/injective.py b/topi/python/topi/x86/injective.py index 37080e0..5bcb179 100644 --- a/topi/python/topi/x86/injective.py +++ b/topi/python/topi/x86/injective.py @@ -20,6 +20,32 @@ from __future__ import absolute_import as _abs import tvm from .. import generic +@generic.schedule_injective_from_existing.register(["cpu"]) +def schedule_injective_from_existing(sch, out): + """Schedule for injective op from existing schedule. + + Parameters + ---------- + sch: Schedule + The schedule to update. + out: Tensor + The tensor representing the injective op. + + Returns + ------- + sch: Schedule + The updated schedule. 
+ """ + if len(sch[out].op.axis) >= 5: + fused = sch[out].fuse(sch[out].op.axis[0], sch[out].op.axis[1], sch[out].op.axis[2]) + sch[out].parallel(fused) + elif len(sch[out].op.axis) >= 3: + fused = sch[out].fuse(sch[out].op.axis[0], sch[out].op.axis[1]) + sch[out].parallel(fused) + elif len(sch[out].op.axis) >= 1: + sch[out].parallel(sch[out].op.axis[0]) + return sch + @generic.schedule_injective.register(["cpu"]) def schedule_injective(outs): """X86 schedule for injective op. @@ -39,14 +65,7 @@ def schedule_injective(outs): x = outs[0] s = tvm.create_schedule([x.op for x in outs]) tvm.schedule.AutoInlineInjective(s) - if len(s[x].op.axis) >= 5: - fused = s[x].fuse(s[x].op.axis[0], s[x].op.axis[1], s[x].op.axis[2]) - s[x].parallel(fused) - elif len(s[x].op.axis) >= 3: - fused = s[x].fuse(s[x].op.axis[0], s[x].op.axis[1]) - s[x].parallel(fused) - elif len(s[x].op.axis) >= 1: - s[x].parallel(s[x].op.axis[0]) + schedule_injective_from_existing(s, x) return s @generic.schedule_concatenate.register(["cpu"]) diff --git a/topi/src/topi.cc b/topi/src/topi.cc index b070939..7114f4d 100644 --- a/topi/src/topi.cc +++ b/topi/src/topi.cc @@ -56,7 +56,6 @@ #include #include -#include #include #include #include @@ -586,6 +585,11 @@ TVM_REGISTER_GLOBAL("topi.generic.schedule_injective") *rv = topi::generic::schedule_injective(args[0], args[1]); }); +TVM_REGISTER_GLOBAL("topi.generic.schedule_injective_from_existing") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = topi::generic::schedule_injective_from_existing(args[0], args[1]); + }); + /* x86 schedules */ TVM_REGISTER_GLOBAL("topi.x86.schedule_binarize_pack") .set_body([](TVMArgs args, TVMRetValue *rv) { @@ -611,6 +615,11 @@ TVM_REGISTER_GLOBAL("topi.x86.schedule_injective") *rv = topi::x86::schedule_injective(args[0], args[1]); }); +TVM_REGISTER_GLOBAL("topi.x86.schedule_injective_from_existing") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = topi::x86::schedule_injective_from_existing(args[0], args[1]); + }); + 
/* ROCm schedules */ TVM_REGISTER_GLOBAL("topi.rocm.dense_cuda") .set_body([](TVMArgs args, TVMRetValue *rv) { @@ -643,16 +652,16 @@ TVM_REGISTER_GLOBAL("topi.cuda.schedule_dense") *rv = topi::cuda::schedule_dense(args[0], args[1]); }); -TVM_REGISTER_GLOBAL("topi.cuda.schedule_extern") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = topi::cuda::schedule_extern(args[0], args[1]); - }); - TVM_REGISTER_GLOBAL("topi.cuda.schedule_injective") .set_body([](TVMArgs args, TVMRetValue *rv) { *rv = topi::cuda::schedule_injective(args[0], args[1]); }); +TVM_REGISTER_GLOBAL("topi.cuda.schedule_injective_from_existing") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = topi::cuda::schedule_injective_from_existing(args[0], args[1]); + }); + TVM_REGISTER_GLOBAL("topi.cuda.schedule_pool") .set_body([](TVMArgs args, TVMRetValue *rv) { *rv = topi::cuda::schedule_pool(args[0], args[1]); @@ -752,6 +761,30 @@ TVM_REGISTER_GENERIC_FUNC(schedule_binary_dense) .set_default(WrapSchedule(topi::generic::default_schedule)) .register_func({ "cpu" }, WrapSchedule(topi::x86::schedule_binary_dense)); +/*! \brief Builder function for instantiating schedules from existing schedules. */ +using FTVMScheduleFromExistingBuilder = std::function< + tvm::Schedule(tvm::Schedule sch, const tvm::Tensor& out)>; + +/*! + * \brief Helper function for registering generic functions matching the + * FTVMScheduleFromExistingBuilder signature. The schedule builder function is wrapped + * with a PackedFunc suitable for passing to a tvm::GenericFunc. + * + * \param builder The schedule builder to wrap. 
+ * + * \return The wrapped schedule builder + */ +inline PackedFunc WrapScheduleFromExisting(FTVMScheduleFromExistingBuilder builder) { + return PackedFunc([builder](TVMArgs args, TVMRetValue* ret) { + *ret = builder(args[0], args[1]); + }); +} + +TVM_REGISTER_GENERIC_FUNC(schedule_injective_from_existing) +.set_default(WrapScheduleFromExisting(topi::generic::schedule_injective_from_existing)) +.register_func({ "cpu" }, WrapScheduleFromExisting(topi::x86::schedule_injective_from_existing)) +.register_func({ "cuda", "gpu" }, WrapScheduleFromExisting(topi::cuda::schedule_injective_from_existing)); + /*! \brief Builder function for instantiating dense ops. */ using FTVMDenseOpBuilder = std::function