[TOPI][x86] Introduce schedule_injective_from_existing and unify external schedules...
authorJon Soifer <soiferj@gmail.com>
Thu, 26 Sep 2019 05:48:50 +0000 (22:48 -0700)
committermasahi <masahi129@gmail.com>
Thu, 26 Sep 2019 05:48:50 +0000 (14:48 +0900)
* Fix extern schedule for x86

* Register x86::schedule_extern

* Fix

* Fix

* Replace extern.py with extern.h

* Introduce new generic function schedule_injective_from_existing

* Fix

* Fix

* Add back to C++

* Fix style

* Injective schedule calls local schedule_injective_from_existing

* Fix

* Remove target arg from schedule_injective_from_existing

* Fix docs

* Try to fix unit test

* Fix test

* Fix other tests

* Fix bug

24 files changed:
tests/cpp/build_module_test.cc
tests/python/unittest/test_runtime_heterogeneous.py
topi/include/topi/cuda/extern.h [deleted file]
topi/include/topi/cuda/injective.h
topi/include/topi/generic/extern.h
topi/include/topi/generic/injective.h
topi/include/topi/x86/injective.h
topi/python/topi/arm_cpu/injective.py
topi/python/topi/cuda/__init__.py
topi/python/topi/cuda/conv2d_int8.py
topi/python/topi/cuda/extern.py [deleted file]
topi/python/topi/cuda/group_conv2d_nchw.py
topi/python/topi/cuda/injective.py
topi/python/topi/cuda/reduction.py
topi/python/topi/cuda/softmax.py
topi/python/topi/cuda/sort.py
topi/python/topi/cuda/vision.py
topi/python/topi/generic/extern.py
topi/python/topi/generic/injective.py
topi/python/topi/hls/injective.py
topi/python/topi/opengl/injective.py
topi/python/topi/x86/dense.py
topi/python/topi/x86/injective.py
topi/src/topi.cc

index 1a7f791..a7237db 100644 (file)
@@ -102,7 +102,11 @@ TEST(BuildModule, Heterogeneous) {
     return copy[i] - C[i];
   }, "elemwise_sub");
 
+  const runtime::PackedFunc* enter_target_scope_func = runtime::Registry::Get("_EnterTargetScope");
+  (*enter_target_scope_func)(target_cuda);
   auto s1 = topi::cuda::schedule_injective(target_cuda, {elemwise_add});
+
+  (*enter_target_scope_func)(target_llvm);
   auto s2 = create_schedule({elemwise_sub->op});
 
   auto config = BuildConfig::Create();
index e0ae2e1..eb51fae 100644 (file)
@@ -174,7 +174,8 @@ def test_simplex_data_transferring():
 
     dev_tar = {"cuda": "cuda", "opencl": "opencl"}
     for device, target in dev_tar.items():
-        check_device(device, target)
+        with tvm.target.create(device):
+            check_device(device, target)
 
 
 def get_duplex_graph(host_dev_type, device_dev_type):
@@ -394,7 +395,8 @@ def test_duplex_data_transferring():
 
     dev_tar = {"cuda": "cuda", "opencl": "opencl"}
     for device, target in dev_tar.items():
-        check_device(device, target)
+        with tvm.target.create(device):
+            check_device(device, target)
 
 if __name__ == "__main__":
     test_simplex_data_transferring()
diff --git a/topi/include/topi/cuda/extern.h b/topi/include/topi/cuda/extern.h
deleted file mode 100644 (file)
index 7800986..0000000
+++ /dev/null
@@ -1,86 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file cuda/extern.h
- * \brief CUDA schedule for extern followed by injective operations
- */
-#ifndef TOPI_CUDA_EXTERN_H_
-#define TOPI_CUDA_EXTERN_H_
-
-#include "topi/tags.h"
-#include "topi/detail/fuse.h"
-#include "tvm/operation.h"
-#include "tvm/build_module.h"
-
-namespace topi {
-using namespace tvm;
-
-namespace cuda {
-/*!
- * \brief Schedule a given operation representing one of the outputs of an
- * external function which is followed by injective operations.
- *
- * \param target The target to generate a schedule for.
- * \param op The operation representing the output followed by injective operations.
- * \param sch The schedule to apply this scheduling to
- *
- * \return The schedule given by sch
- */
-inline Schedule ScheduleOutputForExtern(Target target, Operation op, Schedule sch) {
-  auto x = op.output(0);
-  auto fused = detail::Fuse(sch[x], sch[x]->op.as<ComputeOpNode>()->axis);
-  auto num_thread = target->max_num_threads;
-  IterVar bx, tx;
-  sch[x].split(fused, num_thread, &bx, &tx);
-  sch[x].bind(bx, tvm::thread_axis(Range(), "blockIdx.x"));
-  sch[x].bind(tx, tvm::thread_axis(Range(), "threadIdx.x"));
-  return sch;
-}
-
-/*!
-* \brief Schedule an extern op followed by injective operations.
-* For example, cudnn kernel + bias add + relu
-*
-* \param target The target to generate a schedule for.
-* \param outs The output tensors.
-*
-* \return A schedule for the op.
-*/
-inline Schedule schedule_extern(const Target& target, Array<Tensor> outs) {
-  Array<Operation> out_ops;
-  for (auto t : outs) {
-    out_ops.push_back(t->op);
-  }
-  auto s = create_schedule(out_ops);
-
-  tvm::schedule::AutoInlineInjective(s);
-  for (auto out : outs) {
-    if (out->op->derived_from<ExternOpNode>()) {
-      continue;
-    }
-    ScheduleOutputForExtern(target, out->op, s);
-  }
-
-  return s;
-}
-
-}  // namespace cuda
-}  // namespace topi
-#endif  // TOPI_CUDA_EXTERN_H_
index e629ae1..663bc1f 100644 (file)
@@ -33,21 +33,24 @@ namespace topi {
 using namespace tvm;
 
 namespace cuda {
+
 /*!
-* \brief Schedule a given injective operation.
-*
-* \param target The target to generate a schedule for.
-* \param op The operation representing the injective operation.
-* \param s The schedule to apply this scheduling to
-*/
-inline void ScheduleInjectiveOp(const Target &target, Operation op, Schedule s) {
-  auto x = op.output(0);
-  auto fused = detail::Fuse(s[x], s[x]->op.as<ComputeOpNode>()->axis);
+ * \brief Updates an existing schedule for the given injective ops.
+ *
+ * \param sch The schedule to update.
+ * \param out The tensor representing the injective op.
+ * 
+ * \return The updated schedule.
+ */
+inline Schedule schedule_injective_from_existing(Schedule sch, const Tensor& out) {
+  auto fused = detail::Fuse(sch[out], sch[out]->op.as<ComputeOpNode>()->axis);
+  auto target = Target::Current(false);
   auto num_thread = target->max_num_threads;
   IterVar bx, tx;
-  s[x].split(fused, num_thread, &bx, &tx);
-  s[x].bind(bx, thread_axis(Range(), "blockIdx.x"));
-  s[x].bind(tx, thread_axis(Range(), "threadIdx.x"));
+  sch[out].split(fused, num_thread, &bx, &tx);
+  sch[out].bind(bx, thread_axis(Range(), "blockIdx.x"));
+  sch[out].bind(tx, thread_axis(Range(), "threadIdx.x"));
+  return sch;
 }
 
 /*!
@@ -66,7 +69,7 @@ inline Schedule schedule_injective(const Target &target, const Array<Tensor>& ou
   auto s = create_schedule(out_ops);
   tvm::schedule::AutoInlineInjective(s);
   for (auto out : outs) {
-    ScheduleInjectiveOp(target, out->op, s);
+    schedule_injective_from_existing(s, out);
   }
   return s;
 }
index 6c51d7c..5c0c392 100644 (file)
@@ -28,6 +28,7 @@
 #include "topi/detail/fuse.h"
 #include "tvm/operation.h"
 #include "tvm/build_module.h"
+#include "injective.h"
 
 namespace topi {
 using namespace tvm;
@@ -47,6 +48,15 @@ inline Schedule schedule_extern(const Target& target, Array<Tensor> outs) {
     out_ops.push_back(t->op);
   }
   auto s = create_schedule(out_ops);
+
+  tvm::schedule::AutoInlineInjective(s);
+  for (auto out : outs) {
+    if (out->op->derived_from<ExternOpNode>()) {
+      continue;
+    }
+    tvm::GenericFunc::Get("schedule_injective_from_existing")(s, out);
+  }
+
   return s;
 }
 
index f26651c..fa7df4c 100644 (file)
@@ -35,6 +35,19 @@ using namespace tvm;
 namespace generic {
 
 /*!
+ * \brief Updates an existing schedule for the given injective ops.
+ *
+ * \param sch The schedule to update.
+ * \param out The tensor representing the injective op.
+ * 
+ * \return The updated schedule.
+ */
+inline Schedule schedule_injective_from_existing(Schedule sch, const Tensor& out) {
+  detail::Fuse(sch[out], sch[out]->op.as<ComputeOpNode>()->axis);
+  return sch;
+}
+
+/*!
  * \brief Create a generic schedule for the given injective ops.
  *
  * \param target The target to generate a schedule for.
@@ -50,7 +63,7 @@ inline Schedule schedule_injective(const Target &target, const Array<Tensor>& ou
   auto s = create_schedule(out_ops);
   tvm::schedule::AutoInlineInjective(s);
   auto x = outs[0];
-  detail::Fuse(s[x], s[x]->op.as<ComputeOpNode>()->axis);
+  schedule_injective_from_existing(s, x);
 
   return s;
 }
index adbe382..7cb79ae 100644 (file)
@@ -33,6 +33,28 @@ namespace topi {
 using namespace tvm;
 
 namespace x86 {
+
+/*!
+ * \brief Updates an existing schedule for the given injective ops.
+ *
+ * \param sch The schedule to update.
+ * \param out The tensor representing the injective op.
+ * 
+ * \return The updated schedule.
+ */
+inline Schedule schedule_injective_from_existing(Schedule sch, const Tensor& out) {
+  auto axis = sch[out]->op.as<ComputeOpNode>()->axis;
+  if (axis.size() == 4) {
+    auto n = axis[0];
+    auto c = axis[1];
+    auto fused = detail::Fuse(sch[out], { n, c });  // for nhwc layout, fuse n and h
+    sch[out].parallel(fused);
+  } else {
+    sch[out].parallel(axis[0]);
+  }
+  return sch;
+}
+
 /*!
 * \brief Create an x86 schedule for the given injective ops.
 *
@@ -50,15 +72,7 @@ inline Schedule schedule_injective(const Target &target, const Array<Tensor>& ou
   tvm::schedule::AutoInlineInjective(s);
 
   auto x = outs[0];
-  auto axis = s[x]->op.as<ComputeOpNode>()->axis;
-  if (axis.size() == 4) {
-    auto n = axis[0];
-    auto c = axis[1];
-    auto fused = detail::Fuse(s[x], { n, c });  // for nhwc layout, fuse n and h
-    s[x].parallel(fused);
-  } else {
-    s[x].parallel(axis[0]);
-  }
+  schedule_injective_from_existing(s, x);
 
   return s;
 }
index 028558f..3727754 100644 (file)
 import tvm
 from .. import generic
 
+@generic.schedule_injective_from_existing.register(["arm_cpu"])
+def schedule_injective_from_existing(sch, out):
+    """Schedule for injective op from existing schedule.
+
+    Parameters
+    ----------
+    sch: Schedule
+         The schedule to update.
+    out: Tensor
+         The tensor representing the injective op.
+
+    Returns
+    -------
+    sch: Schedule
+         The updated schedule.
+    """
+    if len(sch[out].op.axis) >= 4:
+        fused = sch[out].fuse(sch[out].op.axis[0], sch[out].op.axis[1], sch[out].op.axis[2])
+        sch[out].parallel(fused)
+    elif len(sch[out].op.axis) >= 3:
+        fused = sch[out].fuse(sch[out].op.axis[0], sch[out].op.axis[1])
+        sch[out].parallel(fused)
+    elif len(sch[out].op.axis) >= 2:
+        sch[out].parallel(sch[out].op.axis[0])
+    return sch
+
 @generic.schedule_injective.register(["arm_cpu"])
 def schedule_injective(outs):
     """ARM CPU schedule for injective op.
@@ -42,14 +68,7 @@ def schedule_injective(outs):
         (io, ii) = s[x].split(list(s[x].op.axis)[-1], 8)
         s[x].vectorize(ii)
     tvm.schedule.AutoInlineInjective(s)
-    if len(s[x].op.axis) >= 4:
-        fused = s[x].fuse(s[x].op.axis[0], s[x].op.axis[1], s[x].op.axis[2])
-        s[x].parallel(fused)
-    elif len(s[x].op.axis) >= 3:
-        fused = s[x].fuse(s[x].op.axis[0], s[x].op.axis[1])
-        s[x].parallel(fused)
-    elif len(s[x].op.axis) >= 2:
-        s[x].parallel(s[x].op.axis[0])
+    schedule_injective_from_existing(s, x)
     return s
 
 @generic.schedule_concatenate.register(["arm_cpu"])
index 403f67b..beab0af 100644 (file)
@@ -13,7 +13,6 @@ from .softmax import schedule_softmax
 from .injective import schedule_injective, schedule_elemwise, schedule_broadcast
 from .dense import schedule_dense
 from .pooling import schedule_pool, schedule_adaptive_pool
-from .extern import schedule_extern
 from .nn import schedule_lrn, schedule_l2_normalize
 from .batch_matmul import schedule_batch_matmul
 from .vision import *
index 25d4945..580cf96 100644 (file)
@@ -19,7 +19,7 @@
 import tvm
 from tvm import autotvm
 
-from .injective import _schedule_injective
+from .injective import schedule_injective_from_existing
 from .tensor_intrin import dp4a
 from ..nn.pad import pad
 from ..nn.util import get_pad_tuple
@@ -172,8 +172,8 @@ def schedule_conv2d_NCHWc_int8(cfg, s, output):
         if isinstance(packed_kernel.op, tvm.tensor.ComputeOp) and\
                        packed_kernel.name == 'packed_kernel':
             # data and kernel are not pre-computed, schedule layout transform here
-            _schedule_injective(packed_data.op, s)
-            _schedule_injective(packed_kernel.op, s)
+            schedule_injective_from_existing(s, packed_data)
+            schedule_injective_from_existing(s, packed_kernel)
 
     if pad_data != packed_data:
         s[pad_data].compute_inline()
diff --git a/topi/python/topi/cuda/extern.py b/topi/python/topi/cuda/extern.py
deleted file mode 100644 (file)
index 74b88ec..0000000
+++ /dev/null
@@ -1,48 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-# pylint: disable=invalid-name, unused-variable,
-"""Schedule for cudnn and miopen extern op"""
-import tvm
-from .. import generic
-from .injective import _schedule_injective
-
-
-@generic.schedule_extern.register(["cuda", "gpu"])
-def schedule_extern(outs):
-    """Schedule for an extern op followed by injective operations.
-       For example, cudnn kernel + bias add + relu.
-
-    Parameters
-    ----------
-    outs: Array of Tensor
-          The computation graph description of extern plus injective ops in the format
-          of an array of tensors.
-
-    Returns
-    -------
-    sch: Schedule
-        The computation schedule for the op.
-    """
-    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
-    s = tvm.create_schedule([x.op for x in outs])
-
-    tvm.schedule.AutoInlineInjective(s)
-    for out in outs:
-        if isinstance(out.op, tvm.tensor.ExternOp):
-            continue
-        _schedule_injective(out.op, s)
-    return s
index cd8c823..f4bb734 100644 (file)
@@ -19,7 +19,7 @@
 import tvm
 from tvm import autotvm
 
-from .injective import _schedule_injective
+from .injective import schedule_injective_from_existing
 from .tensor_intrin import dp4a
 from ..nn.pad import pad
 from ..nn.util import get_pad_tuple
@@ -201,8 +201,8 @@ def schedule_group_conv2d_NCHWc_int8(cfg, s, output):
         if isinstance(packed_kernel.op, tvm.tensor.ComputeOp) and\
                 packed_kernel.name == 'packed_kernel':
             # data and kernel are not pre-computed, schedule layout transform here
-            _schedule_injective(packed_data.op, s)
-            _schedule_injective(packed_kernel.op, s)
+            schedule_injective_from_existing(s, packed_data)
+            schedule_injective_from_existing(s, packed_kernel)
 
     if pad_data != packed_data:
         s[pad_data].compute_inline()
index 8a16c5c..a6ec853 100644 (file)
 import tvm
 from .. import generic, util
 
-def _schedule_injective(op, sch):
-    x = op.output(0)
-    fused = sch[x].fuse(*sch[x].op.axis)
+@generic.schedule_injective_from_existing.register(["cuda", "gpu"])
+def schedule_injective_from_existing(sch, out):
+    """Schedule for injective op from existing schedule.
+
+    Parameters
+    ----------
+    sch: Schedule
+         The schedule to update.
+    out: Tensor
+         The tensor representing the injective op.
+
+    Returns
+    -------
+    sch: Schedule
+         The updated schedule.
+    """
+    fused = sch[out].fuse(*sch[out].op.axis)
     num_thread = tvm.target.current_target(allow_none=False).max_num_threads
     max_block = 256
 
     try:
-        const_size = util.get_const_int(util.prod(x.shape))
+        const_size = util.get_const_int(util.prod(out.shape))
         max_block = 256
         need_block_split = const_size > max_block * num_thread
     except ValueError:
         need_block_split = False
 
     if need_block_split:
-        xo, xi = sch[x].split(fused, factor=num_thread * max_block)
-        bx, tx = sch[x].split(xi, factor=num_thread)
-        sch[x].reorder(bx, tx, xo)
-        sch[x].bind(bx, tvm.thread_axis("blockIdx.x"))
-        sch[x].bind(tx, tvm.thread_axis("threadIdx.x"))
+        xo, xi = sch[out].split(fused, factor=num_thread * max_block)
+        bx, tx = sch[out].split(xi, factor=num_thread)
+        sch[out].reorder(bx, tx, xo)
+        sch[out].bind(bx, tvm.thread_axis("blockIdx.x"))
+        sch[out].bind(tx, tvm.thread_axis("threadIdx.x"))
     else:
-        bx, tx = sch[x].split(fused, factor=num_thread)
-        sch[x].bind(tx, tvm.thread_axis("threadIdx.x"))
-        sch[x].bind(bx, tvm.thread_axis("blockIdx.x"))
+        bx, tx = sch[out].split(fused, factor=num_thread)
+        sch[out].bind(tx, tvm.thread_axis("threadIdx.x"))
+        sch[out].bind(bx, tvm.thread_axis("blockIdx.x"))
 
     return sch
 
-
 @generic.schedule_injective.register(["cuda", "gpu"])
 def schedule_injective(outs):
     """Schedule for injective op.
@@ -66,7 +79,7 @@ def schedule_injective(outs):
 
     tvm.schedule.AutoInlineInjective(s)
     for out in outs:
-        _schedule_injective(out.op, s)
+        schedule_injective_from_existing(s, out)
     return s
 
 schedule_elemwise = schedule_injective
index 2588531..a7a56b9 100644 (file)
@@ -20,7 +20,7 @@ from __future__ import absolute_import as _abs
 import tvm
 from .. import tag
 from .. import generic
-from .injective import _schedule_injective
+from .injective import schedule_injective_from_existing
 
 def _schedule_reduce(op, sch, is_idx_reduce=False):
     if is_idx_reduce:
@@ -30,7 +30,7 @@ def _schedule_reduce(op, sch, is_idx_reduce=False):
         data_out = op.output(0)
 
     if not sch[data_out].op.reduce_axis:
-        return _schedule_injective(op, sch)
+        return schedule_injective_from_existing(sch, op.output(0))
 
     if len(sch[data_out].op.axis) > 0:
         all_reduce = False
@@ -126,7 +126,7 @@ def schedule_reduce(outs):
         """Internal travserse function"""
         if tag.is_broadcast(operator.tag):
             if operator not in scheduled_ops:
-                _schedule_injective(operator, sch)
+                schedule_injective_from_existing(sch, operator.output(0))
             for tensor in operator.input_tensors:
                 traverse_after_reduce(tensor.op)
         elif operator.tag == 'comm_reduce':
index 09b6ef8..9fdac50 100644 (file)
@@ -18,7 +18,7 @@
 """Schedule for softmax operator"""
 import tvm
 from .. import generic
-from .injective import _schedule_injective
+from .injective import schedule_injective_from_existing
 
 @generic.schedule_softmax.register(["cuda", "gpu"])
 def schedule_softmax(outs):
@@ -58,7 +58,7 @@ def schedule_softmax(outs):
             ops.append(exp.op)
             
         for op in ops:
-            s = _schedule_injective(op, s)
+            s = schedule_injective_from_existing(s, op.output(0))
     else:
         num_thread = 64
         block_x = tvm.thread_axis("blockIdx.x")
index 1d9148f..c45465e 100644 (file)
@@ -42,10 +42,10 @@ def _schedule_sort(outs):
     outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
     s = tvm.create_schedule([x.op for x in outs])
     scheduled_ops = []
-    from .injective import _schedule_injective
+    from .injective import schedule_injective_from_existing
     def traverse(op):
         if tag.is_injective(op.tag):
-            _schedule_injective(op, s)
+            schedule_injective_from_existing(s, op.output(0))
         for tensor in op.input_tensors:
             if tensor.op.input_tensors and tensor.op not in scheduled_ops:
                 traverse(tensor.op)
index 968e554..3a90402 100644 (file)
@@ -28,10 +28,10 @@ def _default_schedule(outs):
     outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
     s = tvm.create_schedule([x.op for x in outs])
     scheduled_ops = []
-    from .injective import _schedule_injective
+    from .injective import schedule_injective_from_existing
     def traverse(op):
         if tag.is_broadcast(op.tag) or op.tag in ['bbox_score', 'sorted_bbox']:
-            _schedule_injective(op, s)
+            schedule_injective_from_existing(s, op.output(0))
         for tensor in op.input_tensors:
             if tensor.op.input_tensors and tensor.op not in scheduled_ops:
                 traverse(tensor.op)
index 3728222..a060114 100644 (file)
@@ -19,6 +19,7 @@
 from __future__ import absolute_import as _abs
 
 import tvm
+from .. import cpp
 
 @tvm.target.generic_func
 def schedule_extern(outs):
@@ -35,8 +36,5 @@ def schedule_extern(outs):
     sch: Schedule
         The computation schedule for the op.
     """
-    target = tvm.target.current_target(allow_none=False)
-    if target.target_name != "llvm":
-        raise RuntimeError("schedule_extern not registered for '%s'" % target)
-    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
-    return tvm.create_schedule([x.op for x in outs])
+    target = tvm.target.current_target()
+    return cpp.generic.schedule_extern(target, outs)
index ec1732f..178363d 100644 (file)
@@ -20,6 +20,25 @@ from __future__ import absolute_import as _abs
 
 import tvm
 
+@tvm.target.override_native_generic_func("schedule_injective_from_existing")
+def schedule_injective_from_existing(sch, out):
+    """Schedule for injective op from existing schedule.
+
+    Parameters
+    ----------
+    sch: Schedule
+         The schedule to update.
+    out: Tensor
+         The tensor representing the injective op.
+
+    Returns
+    -------
+    sch: Schedule
+         The updated schedule.
+    """
+    sch[out].fuse(s[out].op.axis)
+    return sch
+
 @tvm.target.override_native_generic_func("schedule_injective")
 def schedule_injective(outs):
     """Schedule for injective op.
@@ -42,7 +61,7 @@ def schedule_injective(outs):
     x = outs[0]
     s = tvm.create_schedule([x.op for x in outs])
     tvm.schedule.AutoInlineInjective(s)
-    s[x].fuse(s[x].op.axis)
+    schedule_injective_from_existing(s, x)
     return s
 
 @tvm.target.generic_func
index 5a43306..de58428 100644 (file)
 import tvm
 from .. import generic
 
+@generic.schedule_injective_from_existing.register(["hls"])
+def schedule_injective_from_existing(sch, out):
+    """Schedule for injective op from existing schedule.
+
+    Parameters
+    ----------
+    sch: Schedule
+         The schedule to update.
+    out: Tensor
+         The tensor representing the injective op.
+
+    Returns
+    -------
+    sch: Schedule
+         The updated schedule.
+    """
+    fused = sch[out].fuse(*sch[out].op.axis)
+    px, x = sch[out].split(fused, nparts=1)
+    sch[out].bind(px, tvm.thread_axis("pipeline"))
+    return sch
+
 @generic.schedule_injective.register(["hls"])
 def schedule_injective(outs):
     """Schedule for injective op.
@@ -38,9 +59,7 @@ def schedule_injective(outs):
     s = tvm.create_schedule([x.op for x in outs])
     tvm.schedule.AutoInlineInjective(s)
     for out in outs:
-        fused = s[out].fuse(*s[out].op.axis)
-        px, x = s[out].split(fused, nparts=1)
-        s[out].bind(px, tvm.thread_axis("pipeline"))
+        schedule_injective_from_existing(s, out)
     return s
 
 schedule_elemwise = schedule_injective
index 1f3f494..d3ebc94 100644 (file)
 import tvm
 from .. import generic
 
-def _schedule_injective(op, sch):
-    x = op.output(0)
-    sch[x].opengl()
-    return sch
+@generic.schedule_injective_from_existing.register(["opengl"])
+def schedule_injective_from_existing(sch, out):
+    """Schedule for injective op from existing schedule.
+
+    Parameters
+    ----------
+    sch: Schedule
+         The schedule to update.
+    out: Tensor
+         The tensor representing the injective op.
 
+    Returns
+    -------
+    sch: Schedule
+         The updated schedule.
+    """
+    sch[out].opengl()
+    return sch
 
 @generic.schedule_injective.register(["opengl"])
 def schedule_injective(outs):
@@ -45,7 +58,7 @@ def schedule_injective(outs):
 
     tvm.schedule.AutoInlineInjective(s)
     for out in outs:
-        _schedule_injective(out.op, s)
+        schedule_injective_from_existing(s, out)
     return s
 
 schedule_elemwise = schedule_injective
index 27cadfa..9b7624e 100644 (file)
@@ -111,6 +111,10 @@ def _declaration_dense_nopack(cfg, data, weight, bias=None, out_dtype=None):
 
 @autotvm.register_topi_schedule(generic.schedule_dense, "cpu", "direct")
 def _schedule_dense(cfg, outs):
+    target = tvm.target.current_target()
+    if "cblas" in target.libs:
+        return generic.schedule_extern(outs)
+
     s = tvm.create_schedule([x.op for x in outs])
 
     def _callback(op):
index 37080e0..5bcb179 100644 (file)
@@ -20,6 +20,32 @@ from __future__ import absolute_import as _abs
 import tvm
 from .. import generic
 
+@generic.schedule_injective_from_existing.register(["cpu"])
+def schedule_injective_from_existing(sch, out):
+    """Schedule for injective op from existing schedule.
+
+    Parameters
+    ----------
+    sch: Schedule
+         The schedule to update.
+    out: Tensor
+         The tensor representing the injective op.
+
+    Returns
+    -------
+    sch: Schedule
+         The updated schedule.
+    """
+    if len(sch[out].op.axis) >= 5:
+        fused = sch[out].fuse(sch[out].op.axis[0], sch[out].op.axis[1], sch[out].op.axis[2])
+        sch[out].parallel(fused)
+    elif len(sch[out].op.axis) >= 3:
+        fused = sch[out].fuse(sch[out].op.axis[0], sch[out].op.axis[1])
+        sch[out].parallel(fused)
+    elif len(sch[out].op.axis) >= 1:
+        sch[out].parallel(sch[out].op.axis[0])
+    return sch
+
 @generic.schedule_injective.register(["cpu"])
 def schedule_injective(outs):
     """X86 schedule for injective op.
@@ -39,14 +65,7 @@ def schedule_injective(outs):
     x = outs[0]
     s = tvm.create_schedule([x.op for x in outs])
     tvm.schedule.AutoInlineInjective(s)
-    if len(s[x].op.axis) >= 5:
-        fused = s[x].fuse(s[x].op.axis[0], s[x].op.axis[1], s[x].op.axis[2])
-        s[x].parallel(fused)
-    elif len(s[x].op.axis) >= 3:
-        fused = s[x].fuse(s[x].op.axis[0], s[x].op.axis[1])
-        s[x].parallel(fused)
-    elif len(s[x].op.axis) >= 1:
-        s[x].parallel(s[x].op.axis[0])
+    schedule_injective_from_existing(s, x)
     return s
 
 @generic.schedule_concatenate.register(["cpu"])
index b070939..7114f4d 100644 (file)
@@ -56,7 +56,6 @@
 #include <topi/generic/injective.h>
 
 #include <topi/cuda/dense.h>
-#include <topi/cuda/extern.h>
 #include <topi/cuda/injective.h>
 #include <topi/cuda/pooling.h>
 #include <topi/cuda/reduction.h>
@@ -586,6 +585,11 @@ TVM_REGISTER_GLOBAL("topi.generic.schedule_injective")
   *rv = topi::generic::schedule_injective(args[0], args[1]);
   });
 
+TVM_REGISTER_GLOBAL("topi.generic.schedule_injective_from_existing")
+.set_body([](TVMArgs args, TVMRetValue *rv) {
+  *rv = topi::generic::schedule_injective_from_existing(args[0], args[1]);
+ });
+
 /* x86 schedules */
 TVM_REGISTER_GLOBAL("topi.x86.schedule_binarize_pack")
 .set_body([](TVMArgs args, TVMRetValue *rv) {
@@ -611,6 +615,11 @@ TVM_REGISTER_GLOBAL("topi.x86.schedule_injective")
   *rv = topi::x86::schedule_injective(args[0], args[1]);
   });
 
+TVM_REGISTER_GLOBAL("topi.x86.schedule_injective_from_existing")
+.set_body([](TVMArgs args, TVMRetValue *rv) {
+  *rv = topi::x86::schedule_injective_from_existing(args[0], args[1]);
+ });
+
 /* ROCm schedules */
 TVM_REGISTER_GLOBAL("topi.rocm.dense_cuda")
 .set_body([](TVMArgs args, TVMRetValue *rv) {
@@ -643,16 +652,16 @@ TVM_REGISTER_GLOBAL("topi.cuda.schedule_dense")
   *rv = topi::cuda::schedule_dense(args[0], args[1]);
   });
 
-TVM_REGISTER_GLOBAL("topi.cuda.schedule_extern")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
-  *rv = topi::cuda::schedule_extern(args[0], args[1]);
-  });
-
 TVM_REGISTER_GLOBAL("topi.cuda.schedule_injective")
 .set_body([](TVMArgs args, TVMRetValue *rv) {
   *rv = topi::cuda::schedule_injective(args[0], args[1]);
   });
 
+TVM_REGISTER_GLOBAL("topi.cuda.schedule_injective_from_existing")
+.set_body([](TVMArgs args, TVMRetValue *rv) {
+  *rv = topi::cuda::schedule_injective_from_existing(args[0], args[1]);
+ });
+
 TVM_REGISTER_GLOBAL("topi.cuda.schedule_pool")
 .set_body([](TVMArgs args, TVMRetValue *rv) {
   *rv = topi::cuda::schedule_pool(args[0], args[1]);
@@ -752,6 +761,30 @@ TVM_REGISTER_GENERIC_FUNC(schedule_binary_dense)
 .set_default(WrapSchedule(topi::generic::default_schedule))
 .register_func({ "cpu" }, WrapSchedule(topi::x86::schedule_binary_dense));
 
+/*! \brief Builder function for instantiating schedules from existing schedules. */
+using FTVMScheduleFromExistingBuilder = std::function<
+  tvm::Schedule(tvm::Schedule sch, const tvm::Tensor& out)>;
+
+/*!
+ * \brief Helper function for registering generic functions matching the
+ * FTVMScheduleFromExistingBuilder signature. The schedule builder function is wrapped
+ * with a PackedFunc suitable for passing to a tvm::GenericFunc.
+ *
+ * \param builder The schedule builder to wrap.
+ *
+ * \return The wrapped schedule builder
+ */
+inline PackedFunc WrapScheduleFromExisting(FTVMScheduleFromExistingBuilder builder) {
+  return PackedFunc([builder](TVMArgs args, TVMRetValue* ret) {
+    *ret = builder(args[0], args[1]);
+  });
+}
+
+TVM_REGISTER_GENERIC_FUNC(schedule_injective_from_existing)
+.set_default(WrapScheduleFromExisting(topi::generic::schedule_injective_from_existing))
+.register_func({ "cpu" }, WrapScheduleFromExisting(topi::x86::schedule_injective_from_existing))
+.register_func({ "cuda", "gpu" }, WrapScheduleFromExisting(topi::cuda::schedule_injective_from_existing));
+
 /*! \brief Builder function for instantiating dense ops. */
 using FTVMDenseOpBuilder = std::function<tvm::Tensor(const Target& target,
                                                      const tvm::Tensor& data,