return copy[i] - C[i];
}, "elemwise_sub");
+ const runtime::PackedFunc* enter_target_scope_func = runtime::Registry::Get("_EnterTargetScope");
+ (*enter_target_scope_func)(target_cuda);
auto s1 = topi::cuda::schedule_injective(target_cuda, {elemwise_add});
+
+ (*enter_target_scope_func)(target_llvm);
auto s2 = create_schedule({elemwise_sub->op});
auto config = BuildConfig::Create();
dev_tar = {"cuda": "cuda", "opencl": "opencl"}
for device, target in dev_tar.items():
- check_device(device, target)
+ with tvm.target.create(device):
+ check_device(device, target)
def get_duplex_graph(host_dev_type, device_dev_type):
dev_tar = {"cuda": "cuda", "opencl": "opencl"}
for device, target in dev_tar.items():
- check_device(device, target)
+ with tvm.target.create(device):
+ check_device(device, target)
if __name__ == "__main__":
test_simplex_data_transferring()
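The target-scope wrappers added above are required because the rewritten injective schedules no longer receive a target argument; they read whatever target is current when they run. A minimal sketch of the requirement (illustrative, not part of this diff):

```python
import tvm

# The *_from_existing schedules call tvm.target.current_target(allow_none=False),
# which raises unless a scope like this is active.
with tvm.target.create("cuda"):
    target = tvm.target.current_target(allow_none=False)
    print(target.max_num_threads)  # the CUDA schedules read this attribute
```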
+++ /dev/null
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file cuda/extern.h
- * \brief CUDA schedule for extern followed by injective operations
- */
-#ifndef TOPI_CUDA_EXTERN_H_
-#define TOPI_CUDA_EXTERN_H_
-
-#include "topi/tags.h"
-#include "topi/detail/fuse.h"
-#include "tvm/operation.h"
-#include "tvm/build_module.h"
-
-namespace topi {
-using namespace tvm;
-
-namespace cuda {
-/*!
- * \brief Schedule a given operation representing one of the outputs of an
- * external function which is followed by injective operations.
- *
- * \param target The target to generate a schedule for.
- * \param op The operation representing the output followed by injective operations.
- * \param sch The schedule to apply this scheduling to
- *
- * \return The schedule given by sch
- */
-inline Schedule ScheduleOutputForExtern(Target target, Operation op, Schedule sch) {
- auto x = op.output(0);
- auto fused = detail::Fuse(sch[x], sch[x]->op.as<ComputeOpNode>()->axis);
- auto num_thread = target->max_num_threads;
- IterVar bx, tx;
- sch[x].split(fused, num_thread, &bx, &tx);
- sch[x].bind(bx, tvm::thread_axis(Range(), "blockIdx.x"));
- sch[x].bind(tx, tvm::thread_axis(Range(), "threadIdx.x"));
- return sch;
-}
-
-/*!
-* \brief Schedule an extern op followed by injective operations.
-* For example, cudnn kernel + bias add + relu
-*
-* \param target The target to generate a schedule for.
-* \param outs The output tensors.
-*
-* \return A schedule for the op.
-*/
-inline Schedule schedule_extern(const Target& target, Array<Tensor> outs) {
- Array<Operation> out_ops;
- for (auto t : outs) {
- out_ops.push_back(t->op);
- }
- auto s = create_schedule(out_ops);
-
- tvm::schedule::AutoInlineInjective(s);
- for (auto out : outs) {
- if (out->op->derived_from<ExternOpNode>()) {
- continue;
- }
- ScheduleOutputForExtern(target, out->op, s);
- }
-
- return s;
-}
-
-} // namespace cuda
-} // namespace topi
-#endif // TOPI_CUDA_EXTERN_H_
using namespace tvm;
namespace cuda {
+
/*!
-* \brief Schedule a given injective operation.
-*
-* \param target The target to generate a schedule for.
-* \param op The operation representing the injective operation.
-* \param s The schedule to apply this scheduling to
-*/
-inline void ScheduleInjectiveOp(const Target &target, Operation op, Schedule s) {
- auto x = op.output(0);
- auto fused = detail::Fuse(s[x], s[x]->op.as<ComputeOpNode>()->axis);
+ * \brief Updates an existing schedule for the given injective ops.
+ *
+ * \param sch The schedule to update.
+ * \param out The tensor representing the injective op.
+ *
+ * \return The updated schedule.
+ */
+inline Schedule schedule_injective_from_existing(Schedule sch, const Tensor& out) {
+ auto fused = detail::Fuse(sch[out], sch[out]->op.as<ComputeOpNode>()->axis);
+ auto target = Target::Current(false);
auto num_thread = target->max_num_threads;
IterVar bx, tx;
- s[x].split(fused, num_thread, &bx, &tx);
- s[x].bind(bx, thread_axis(Range(), "blockIdx.x"));
- s[x].bind(tx, thread_axis(Range(), "threadIdx.x"));
+ sch[out].split(fused, num_thread, &bx, &tx);
+ sch[out].bind(bx, thread_axis(Range(), "blockIdx.x"));
+ sch[out].bind(tx, thread_axis(Range(), "threadIdx.x"));
+ return sch;
}
/*!
auto s = create_schedule(out_ops);
tvm::schedule::AutoInlineInjective(s);
for (auto out : outs) {
- ScheduleInjectiveOp(target, out->op, s);
+ schedule_injective_from_existing(s, out);
}
return s;
}
#include "topi/detail/fuse.h"
#include "tvm/operation.h"
#include "tvm/build_module.h"
+#include "injective.h"
namespace topi {
using namespace tvm;
out_ops.push_back(t->op);
}
auto s = create_schedule(out_ops);
+
+ tvm::schedule::AutoInlineInjective(s);
+ for (auto out : outs) {
+ if (out->op->derived_from<ExternOpNode>()) {
+ continue;
+ }
+ tvm::GenericFunc::Get("schedule_injective_from_existing")(s, out);
+ }
+
return s;
}
namespace generic {
/*!
+ * \brief Updates an existing schedule for the given injective ops.
+ *
+ * \param sch The schedule to update.
+ * \param out The tensor representing the injective op.
+ *
+ * \return The updated schedule.
+ */
+inline Schedule schedule_injective_from_existing(Schedule sch, const Tensor& out) {
+ detail::Fuse(sch[out], sch[out]->op.as<ComputeOpNode>()->axis);
+ return sch;
+}
+
+/*!
* \brief Create a generic schedule for the given injective ops.
*
* \param target The target to generate a schedule for.
auto s = create_schedule(out_ops);
tvm::schedule::AutoInlineInjective(s);
auto x = outs[0];
- detail::Fuse(s[x], s[x]->op.as<ComputeOpNode>()->axis);
+ schedule_injective_from_existing(s, x);
return s;
}
using namespace tvm;
namespace x86 {
+
+/*!
+ * \brief Updates an existing schedule for the given injective ops.
+ *
+ * \param sch The schedule to update.
+ * \param out The tensor representing the injective op.
+ *
+ * \return The updated schedule.
+ */
+inline Schedule schedule_injective_from_existing(Schedule sch, const Tensor& out) {
+ auto axis = sch[out]->op.as<ComputeOpNode>()->axis;
+ if (axis.size() == 4) {
+    auto n = axis[0];
+    auto h = axis[1];
+    auto fused = detail::Fuse(sch[out], { n, h }); // for NHWC layout, fuse n and h
+ sch[out].parallel(fused);
+ } else {
+ sch[out].parallel(axis[0]);
+ }
+ return sch;
+}
+
/*!
* \brief Create an x86 schedule for the given injective ops.
*
tvm::schedule::AutoInlineInjective(s);
auto x = outs[0];
- auto axis = s[x]->op.as<ComputeOpNode>()->axis;
- if (axis.size() == 4) {
- auto n = axis[0];
- auto c = axis[1];
- auto fused = detail::Fuse(s[x], { n, c }); // for nhwc layout, fuse n and h
- s[x].parallel(fused);
- } else {
- s[x].parallel(axis[0]);
- }
+ schedule_injective_from_existing(s, x);
return s;
}
import tvm
from .. import generic
+@generic.schedule_injective_from_existing.register(["arm_cpu"])
+def schedule_injective_from_existing(sch, out):
+ """Schedule for injective op from existing schedule.
+
+ Parameters
+ ----------
+ sch: Schedule
+ The schedule to update.
+ out: Tensor
+ The tensor representing the injective op.
+
+ Returns
+ -------
+ sch: Schedule
+ The updated schedule.
+ """
+ if len(sch[out].op.axis) >= 4:
+ fused = sch[out].fuse(sch[out].op.axis[0], sch[out].op.axis[1], sch[out].op.axis[2])
+ sch[out].parallel(fused)
+ elif len(sch[out].op.axis) >= 3:
+ fused = sch[out].fuse(sch[out].op.axis[0], sch[out].op.axis[1])
+ sch[out].parallel(fused)
+ elif len(sch[out].op.axis) >= 2:
+ sch[out].parallel(sch[out].op.axis[0])
+ return sch
+
@generic.schedule_injective.register(["arm_cpu"])
def schedule_injective(outs):
"""ARM CPU schedule for injective op.
(io, ii) = s[x].split(list(s[x].op.axis)[-1], 8)
s[x].vectorize(ii)
tvm.schedule.AutoInlineInjective(s)
- if len(s[x].op.axis) >= 4:
- fused = s[x].fuse(s[x].op.axis[0], s[x].op.axis[1], s[x].op.axis[2])
- s[x].parallel(fused)
- elif len(s[x].op.axis) >= 3:
- fused = s[x].fuse(s[x].op.axis[0], s[x].op.axis[1])
- s[x].parallel(fused)
- elif len(s[x].op.axis) >= 2:
- s[x].parallel(s[x].op.axis[0])
+ schedule_injective_from_existing(s, x)
return s
@generic.schedule_concatenate.register(["arm_cpu"])
from .injective import schedule_injective, schedule_elemwise, schedule_broadcast
from .dense import schedule_dense
from .pooling import schedule_pool, schedule_adaptive_pool
-from .extern import schedule_extern
from .nn import schedule_lrn, schedule_l2_normalize
from .batch_matmul import schedule_batch_matmul
from .vision import *
import tvm
from tvm import autotvm
-from .injective import _schedule_injective
+from .injective import schedule_injective_from_existing
from .tensor_intrin import dp4a
from ..nn.pad import pad
from ..nn.util import get_pad_tuple
if isinstance(packed_kernel.op, tvm.tensor.ComputeOp) and\
packed_kernel.name == 'packed_kernel':
# data and kernel are not pre-computed, schedule layout transform here
- _schedule_injective(packed_data.op, s)
- _schedule_injective(packed_kernel.op, s)
+ schedule_injective_from_existing(s, packed_data)
+ schedule_injective_from_existing(s, packed_kernel)
if pad_data != packed_data:
s[pad_data].compute_inline()
+++ /dev/null
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-# pylint: disable=invalid-name, unused-variable,
-"""Schedule for cudnn and miopen extern op"""
-import tvm
-from .. import generic
-from .injective import _schedule_injective
-
-
-@generic.schedule_extern.register(["cuda", "gpu"])
-def schedule_extern(outs):
- """Schedule for an extern op followed by injective operations.
- For example, cudnn kernel + bias add + relu.
-
- Parameters
- ----------
- outs: Array of Tensor
- The computation graph description of extern plus injective ops in the format
- of an array of tensors.
-
- Returns
- -------
- sch: Schedule
- The computation schedule for the op.
- """
- outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
- s = tvm.create_schedule([x.op for x in outs])
-
- tvm.schedule.AutoInlineInjective(s)
- for out in outs:
- if isinstance(out.op, tvm.tensor.ExternOp):
- continue
- _schedule_injective(out.op, s)
- return s
import tvm
from tvm import autotvm
-from .injective import _schedule_injective
+from .injective import schedule_injective_from_existing
from .tensor_intrin import dp4a
from ..nn.pad import pad
from ..nn.util import get_pad_tuple
if isinstance(packed_kernel.op, tvm.tensor.ComputeOp) and\
packed_kernel.name == 'packed_kernel':
# data and kernel are not pre-computed, schedule layout transform here
- _schedule_injective(packed_data.op, s)
- _schedule_injective(packed_kernel.op, s)
+ schedule_injective_from_existing(s, packed_data)
+ schedule_injective_from_existing(s, packed_kernel)
if pad_data != packed_data:
s[pad_data].compute_inline()
import tvm
from .. import generic, util
-def _schedule_injective(op, sch):
- x = op.output(0)
- fused = sch[x].fuse(*sch[x].op.axis)
+@generic.schedule_injective_from_existing.register(["cuda", "gpu"])
+def schedule_injective_from_existing(sch, out):
+ """Schedule for injective op from existing schedule.
+
+ Parameters
+ ----------
+ sch: Schedule
+ The schedule to update.
+ out: Tensor
+ The tensor representing the injective op.
+
+ Returns
+ -------
+ sch: Schedule
+ The updated schedule.
+ """
+ fused = sch[out].fuse(*sch[out].op.axis)
num_thread = tvm.target.current_target(allow_none=False).max_num_threads
max_block = 256
try:
- const_size = util.get_const_int(util.prod(x.shape))
+ const_size = util.get_const_int(util.prod(out.shape))
max_block = 256
need_block_split = const_size > max_block * num_thread
except ValueError:
need_block_split = False
if need_block_split:
- xo, xi = sch[x].split(fused, factor=num_thread * max_block)
- bx, tx = sch[x].split(xi, factor=num_thread)
- sch[x].reorder(bx, tx, xo)
- sch[x].bind(bx, tvm.thread_axis("blockIdx.x"))
- sch[x].bind(tx, tvm.thread_axis("threadIdx.x"))
+ xo, xi = sch[out].split(fused, factor=num_thread * max_block)
+ bx, tx = sch[out].split(xi, factor=num_thread)
+ sch[out].reorder(bx, tx, xo)
+ sch[out].bind(bx, tvm.thread_axis("blockIdx.x"))
+ sch[out].bind(tx, tvm.thread_axis("threadIdx.x"))
else:
- bx, tx = sch[x].split(fused, factor=num_thread)
- sch[x].bind(tx, tvm.thread_axis("threadIdx.x"))
- sch[x].bind(bx, tvm.thread_axis("blockIdx.x"))
+ bx, tx = sch[out].split(fused, factor=num_thread)
+ sch[out].bind(tx, tvm.thread_axis("threadIdx.x"))
+ sch[out].bind(bx, tvm.thread_axis("blockIdx.x"))
return sch
-
@generic.schedule_injective.register(["cuda", "gpu"])
def schedule_injective(outs):
"""Schedule for injective op.
tvm.schedule.AutoInlineInjective(s)
for out in outs:
- _schedule_injective(out.op, s)
+ schedule_injective_from_existing(s, out)
return s
schedule_elemwise = schedule_injective
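For reference, a minimal sketch of calling the new entry point directly on one stage of an existing schedule, as the conv2d code above does for its layout transforms (tensor names are illustrative):

```python
import tvm
from topi.cuda.injective import schedule_injective_from_existing

A = tvm.placeholder((1024,), name="A")
B = tvm.compute((1024,), lambda i: A[i] * 2, name="B")  # an injective stage

with tvm.target.create("cuda"):  # supplies max_num_threads to the schedule
    s = tvm.create_schedule([B.op])
    s = schedule_injective_from_existing(s, B)  # fuses, splits, and binds B
```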
import tvm
from .. import tag
from .. import generic
-from .injective import _schedule_injective
+from .injective import schedule_injective_from_existing
def _schedule_reduce(op, sch, is_idx_reduce=False):
if is_idx_reduce:
data_out = op.output(0)
if not sch[data_out].op.reduce_axis:
- return _schedule_injective(op, sch)
+ return schedule_injective_from_existing(sch, op.output(0))
if len(sch[data_out].op.axis) > 0:
all_reduce = False
"""Internal travserse function"""
if tag.is_broadcast(operator.tag):
if operator not in scheduled_ops:
- _schedule_injective(operator, sch)
+ schedule_injective_from_existing(sch, operator.output(0))
for tensor in operator.input_tensors:
traverse_after_reduce(tensor.op)
elif operator.tag == 'comm_reduce':
"""Schedule for softmax operator"""
import tvm
from .. import generic
-from .injective import _schedule_injective
+from .injective import schedule_injective_from_existing
@generic.schedule_softmax.register(["cuda", "gpu"])
def schedule_softmax(outs):
ops.append(exp.op)
for op in ops:
- s = _schedule_injective(op, s)
+ s = schedule_injective_from_existing(s, op.output(0))
else:
num_thread = 64
block_x = tvm.thread_axis("blockIdx.x")
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs])
scheduled_ops = []
- from .injective import _schedule_injective
+ from .injective import schedule_injective_from_existing
def traverse(op):
if tag.is_injective(op.tag):
- _schedule_injective(op, s)
+ schedule_injective_from_existing(s, op.output(0))
for tensor in op.input_tensors:
if tensor.op.input_tensors and tensor.op not in scheduled_ops:
traverse(tensor.op)
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs])
scheduled_ops = []
- from .injective import _schedule_injective
+ from .injective import schedule_injective_from_existing
def traverse(op):
if tag.is_broadcast(op.tag) or op.tag in ['bbox_score', 'sorted_bbox']:
- _schedule_injective(op, s)
+ schedule_injective_from_existing(s, op.output(0))
for tensor in op.input_tensors:
if tensor.op.input_tensors and tensor.op not in scheduled_ops:
traverse(tensor.op)
from __future__ import absolute_import as _abs
import tvm
+from .. import cpp
@tvm.target.generic_func
def schedule_extern(outs):
sch: Schedule
The computation schedule for the op.
"""
- target = tvm.target.current_target(allow_none=False)
- if target.target_name != "llvm":
- raise RuntimeError("schedule_extern not registered for '%s'" % target)
- outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
- return tvm.create_schedule([x.op for x in outs])
+ target = tvm.target.current_target()
+ return cpp.generic.schedule_extern(target, outs)
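Since the Python side now defers to the C++ implementation, schedule_extern works for any target whose schedule_injective_from_existing is registered, not just llvm. A hedged sketch of the intended pattern; the "tvm.contrib.my_copy" packed function is hypothetical and stands in for e.g. a cuDNN call:

```python
import tvm
import topi

A = tvm.placeholder((64,), name="A")
B = tvm.extern((64,), [A],
               lambda ins, outs: tvm.call_packed(
                   "tvm.contrib.my_copy", ins[0], outs[0]),  # hypothetical extern kernel
               name="B")
C = tvm.compute((64,), lambda i: B[i] + 1, name="C")  # injective epilogue

with tvm.target.create("cuda"):
    s = topi.generic.schedule_extern([C])  # dispatches into cpp.generic.schedule_extern
```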
import tvm
+@tvm.target.override_native_generic_func("schedule_injective_from_existing")
+def schedule_injective_from_existing(sch, out):
+ """Schedule for injective op from existing schedule.
+
+ Parameters
+ ----------
+ sch: Schedule
+ The schedule to update.
+ out: Tensor
+ The tensor representing the injective op.
+
+ Returns
+ -------
+ sch: Schedule
+ The updated schedule.
+ """
+    sch[out].fuse(*sch[out].op.axis)
+ return sch
+
@tvm.target.override_native_generic_func("schedule_injective")
def schedule_injective(outs):
"""Schedule for injective op.
x = outs[0]
s = tvm.create_schedule([x.op for x in outs])
tvm.schedule.AutoInlineInjective(s)
- s[x].fuse(s[x].op.axis)
+ schedule_injective_from_existing(s, x)
return s
@tvm.target.generic_func
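Which implementation the generic function resolves to is decided by the target scope active at call time; a short sketch (tensor names are illustrative):

```python
import tvm
import topi

A = tvm.placeholder((16, 16), name="A")
B = tvm.compute((16, 16), lambda i, j: A[i, j] + 1, name="B")

s = tvm.create_schedule([B.op])
with tvm.target.create("llvm"):
    # resolves to the "cpu" registration (or its Python override)
    s = topi.generic.schedule_injective_from_existing(s, B)
```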
import tvm
from .. import generic
+@generic.schedule_injective_from_existing.register(["hls"])
+def schedule_injective_from_existing(sch, out):
+ """Schedule for injective op from existing schedule.
+
+ Parameters
+ ----------
+ sch: Schedule
+ The schedule to update.
+ out: Tensor
+ The tensor representing the injective op.
+
+ Returns
+ -------
+ sch: Schedule
+ The updated schedule.
+ """
+ fused = sch[out].fuse(*sch[out].op.axis)
+ px, x = sch[out].split(fused, nparts=1)
+ sch[out].bind(px, tvm.thread_axis("pipeline"))
+ return sch
+
@generic.schedule_injective.register(["hls"])
def schedule_injective(outs):
"""Schedule for injective op.
s = tvm.create_schedule([x.op for x in outs])
tvm.schedule.AutoInlineInjective(s)
for out in outs:
- fused = s[out].fuse(*s[out].op.axis)
- px, x = s[out].split(fused, nparts=1)
- s[out].bind(px, tvm.thread_axis("pipeline"))
+ schedule_injective_from_existing(s, out)
return s
schedule_elemwise = schedule_injective
import tvm
from .. import generic
-def _schedule_injective(op, sch):
- x = op.output(0)
- sch[x].opengl()
- return sch
+@generic.schedule_injective_from_existing.register(["opengl"])
+def schedule_injective_from_existing(sch, out):
+ """Schedule for injective op from existing schedule.
+
+ Parameters
+ ----------
+ sch: Schedule
+ The schedule to update.
+ out: Tensor
+ The tensor representing the injective op.
+
+    Returns
+ -------
+ sch: Schedule
+ The updated schedule.
+ """
+ sch[out].opengl()
+ return sch
@generic.schedule_injective.register(["opengl"])
def schedule_injective(outs):
tvm.schedule.AutoInlineInjective(s)
for out in outs:
- _schedule_injective(out.op, s)
+ schedule_injective_from_existing(s, out)
return s
schedule_elemwise = schedule_injective
@autotvm.register_topi_schedule(generic.schedule_dense, "cpu", "direct")
def _schedule_dense(cfg, outs):
+ target = tvm.target.current_target()
+ if "cblas" in target.libs:
+ return generic.schedule_extern(outs)
+
s = tvm.create_schedule([x.op for x in outs])
def _callback(op):
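The early-out added above keys off the libs recorded on the target; a quick sketch of the triggering condition (flags shown are illustrative):

```python
import tvm

target = tvm.target.create("llvm -libs=cblas")
assert "cblas" in target.libs  # dense is an extern op here, so defer to schedule_extern
```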
import tvm
from .. import generic
+@generic.schedule_injective_from_existing.register(["cpu"])
+def schedule_injective_from_existing(sch, out):
+ """Schedule for injective op from existing schedule.
+
+ Parameters
+ ----------
+ sch: Schedule
+ The schedule to update.
+ out: Tensor
+ The tensor representing the injective op.
+
+ Returns
+ -------
+ sch: Schedule
+ The updated schedule.
+ """
+ if len(sch[out].op.axis) >= 5:
+ fused = sch[out].fuse(sch[out].op.axis[0], sch[out].op.axis[1], sch[out].op.axis[2])
+ sch[out].parallel(fused)
+ elif len(sch[out].op.axis) >= 3:
+ fused = sch[out].fuse(sch[out].op.axis[0], sch[out].op.axis[1])
+ sch[out].parallel(fused)
+ elif len(sch[out].op.axis) >= 1:
+ sch[out].parallel(sch[out].op.axis[0])
+ return sch
+
@generic.schedule_injective.register(["cpu"])
def schedule_injective(outs):
"""X86 schedule for injective op.
x = outs[0]
s = tvm.create_schedule([x.op for x in outs])
tvm.schedule.AutoInlineInjective(s)
- if len(s[x].op.axis) >= 5:
- fused = s[x].fuse(s[x].op.axis[0], s[x].op.axis[1], s[x].op.axis[2])
- s[x].parallel(fused)
- elif len(s[x].op.axis) >= 3:
- fused = s[x].fuse(s[x].op.axis[0], s[x].op.axis[1])
- s[x].parallel(fused)
- elif len(s[x].op.axis) >= 1:
- s[x].parallel(s[x].op.axis[0])
+ schedule_injective_from_existing(s, x)
return s
@generic.schedule_concatenate.register(["cpu"])
#include <topi/generic/injective.h>
#include <topi/cuda/dense.h>
-#include <topi/cuda/extern.h>
#include <topi/cuda/injective.h>
#include <topi/cuda/pooling.h>
#include <topi/cuda/reduction.h>
*rv = topi::generic::schedule_injective(args[0], args[1]);
});
+TVM_REGISTER_GLOBAL("topi.generic.schedule_injective_from_existing")
+.set_body([](TVMArgs args, TVMRetValue *rv) {
+ *rv = topi::generic::schedule_injective_from_existing(args[0], args[1]);
+ });
+
/* x86 schedules */
TVM_REGISTER_GLOBAL("topi.x86.schedule_binarize_pack")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::x86::schedule_injective(args[0], args[1]);
});
+TVM_REGISTER_GLOBAL("topi.x86.schedule_injective_from_existing")
+.set_body([](TVMArgs args, TVMRetValue *rv) {
+ *rv = topi::x86::schedule_injective_from_existing(args[0], args[1]);
+ });
+
/* ROCm schedules */
TVM_REGISTER_GLOBAL("topi.rocm.dense_cuda")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::cuda::schedule_dense(args[0], args[1]);
});
-TVM_REGISTER_GLOBAL("topi.cuda.schedule_extern")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
- *rv = topi::cuda::schedule_extern(args[0], args[1]);
- });
-
TVM_REGISTER_GLOBAL("topi.cuda.schedule_injective")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::cuda::schedule_injective(args[0], args[1]);
});
+TVM_REGISTER_GLOBAL("topi.cuda.schedule_injective_from_existing")
+.set_body([](TVMArgs args, TVMRetValue *rv) {
+ *rv = topi::cuda::schedule_injective_from_existing(args[0], args[1]);
+ });
+
TVM_REGISTER_GLOBAL("topi.cuda.schedule_pool")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::cuda::schedule_pool(args[0], args[1]);
.set_default(WrapSchedule(topi::generic::default_schedule))
.register_func({ "cpu" }, WrapSchedule(topi::x86::schedule_binary_dense));
+/*! \brief Builder function for instantiating schedules from existing schedules. */
+using FTVMScheduleFromExistingBuilder = std::function<
+ tvm::Schedule(tvm::Schedule sch, const tvm::Tensor& out)>;
+
+/*!
+ * \brief Helper function for registering generic functions matching the
+ * FTVMScheduleFromExistingBuilder signature. The schedule builder function is wrapped
+ * with a PackedFunc suitable for passing to a tvm::GenericFunc.
+ *
+ * \param builder The schedule builder to wrap.
+ *
+ * \return The wrapped schedule builder
+ */
+inline PackedFunc WrapScheduleFromExisting(FTVMScheduleFromExistingBuilder builder) {
+ return PackedFunc([builder](TVMArgs args, TVMRetValue* ret) {
+ *ret = builder(args[0], args[1]);
+ });
+}
+
+TVM_REGISTER_GENERIC_FUNC(schedule_injective_from_existing)
+.set_default(WrapScheduleFromExisting(topi::generic::schedule_injective_from_existing))
+.register_func({ "cpu" }, WrapScheduleFromExisting(topi::x86::schedule_injective_from_existing))
+.register_func({ "cuda", "gpu" }, WrapScheduleFromExisting(topi::cuda::schedule_injective_from_existing));
+
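This native registration is the same GenericFunc that tvm.target.override_native_generic_func("schedule_injective_from_existing") wraps on the Python side, so C++ callers (such as the generic schedule_extern above) and Python callers share one dispatch table. A sketch, assuming GenericFunc objects retrieved via get_native_generic_func remain callable from Python:

```python
import tvm

A = tvm.placeholder((8, 8), name="A")
B = tvm.compute((8, 8), lambda i, j: A[i, j] + 1, name="B")
s = tvm.create_schedule([B.op])

gf = tvm.target.get_native_generic_func("schedule_injective_from_existing")
with tvm.target.create("llvm"):
    s = gf(s, B)  # picks the "cpu" registration made above
```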
/*! \brief Builder function for instantiating dense ops. */
using FTVMDenseOpBuilder = std::function<tvm::Tensor(const Target& target,
const tvm::Tensor& data,