[REFACTOR][TIR] Migrate most of low-level build to use the Pass Manager. (#5225)
authorTianqi Chen <tqchen@users.noreply.github.com>
Fri, 3 Apr 2020 22:50:11 +0000 (15:50 -0700)
committerGitHub <noreply@github.com>
Fri, 3 Apr 2020 22:50:11 +0000 (15:50 -0700)
* [REFACTOR][TIR] Migrate most of low-level build to use the Pass Manager.

- SplitHostDevice
- ThreadSync
- BindDevice
- LowerThreadAllreduce
- Provide a temp fix for printing IRModule with PrimFunc before the formal text printer.

* Address comments, fix tests.

* Fix relay tests

* Explicit move

28 files changed:
include/tvm/ir/function.h
include/tvm/ir/module.h
include/tvm/tir/analysis.h
include/tvm/tir/ir_pass.h
include/tvm/tir/transform.h
python/tvm/driver/build_module.py
python/tvm/ir/__init__.py
python/tvm/ir/function.py
python/tvm/ir/module.py
python/tvm/tir/transform/function_pass.py
python/tvm/tir/transform/transform.py
src/driver/driver_api.cc
src/printer/relay_text_printer.cc
src/target/codegen.cc
src/target/llvm/codegen_cpu.cc
src/tir/ir/transform.cc
src/tir/pass/ffi_api.cc
src/tir/pass/make_api.cc
src/tir/transforms/bind_device_type.cc [new file with mode: 0644]
src/tir/transforms/lower_thread_allreduce.cc
src/tir/transforms/split_host_device.cc [moved from src/tir/pass/split_host_device.cc with 61% similarity]
src/tir/transforms/tensorcore_infer_fragment.cc
src/tir/transforms/thread_storage_sync.cc
tests/python/unittest/test_tir_analysis_usedef.py [moved from tests/python/unittest/test_tir_pass_split_host_device.py with 98% similarity]
tests/python/unittest/test_tir_pass_inject_double_buffer.py
tests/python/unittest/test_tir_pass_storage_flatten.py
tests/python/unittest/test_tir_transform_lower_warp_memory.py
tests/python/unittest/test_tir_transform_thread_sync.py

index ecf7c19..dc7a2b2 100644 (file)
@@ -47,19 +47,19 @@ enum class CallingConv : int {
    */
   kDefault = 0,
   /*!
+   * \brief PackedFunc that exposes a CPackedFunc signature.
+   *
+   * - Calling by PackedFunc calling convention.
+   * - Implementation: Expose a function with the CPackedFunc signature.
+   */
+  kCPackedFunc = 1,
+  /*!
    * \brief Device kernel launch
    *
    * - Call by PackedFunc calling convention.
    * - Implementation: defined by device runtime(e.g. runtime/cuda)
    */
   kDeviceKernelLaunch = 2,
-  /*!
-   * \brief PackedFunc that exposes a CPackedFunc signature.
-   *
-   * - Calling by PackedFunc calling convention.
-   * - Implementation: Expose a function with the CPackedFunc signature.
-   */
-  kCPackedFunc = 3,
 };
 
 /*!
index f6ea918..f63bf96 100644 (file)
@@ -324,6 +324,8 @@ class IRModule : public ObjectRef {
 
   /*! \brief Declare the container type. */
   using ContainerType = IRModuleNode;
+  // allow copy on write.
+  TVM_DEFINE_OBJECT_REF_COW_METHOD(IRModuleNode);
 };
 
 /*!
index 6bab44e..fe74a96 100644 (file)
@@ -49,6 +49,16 @@ struct ExprDeepEqual {
  public:
   TVM_DLL bool operator()(const PrimExpr& lhs, const PrimExpr& rhs) const;
 };
+
+
+/*!
+ * \brief Find undefined vars in the statment.
+ * \param stmt The function to be checked.
+ * \param defs The vars that is defined.
+ * \return Array of undefined vars.
+ */
+Array<Var> UndefinedVars(const Stmt& stmt, const Array<Var>& defs);
+
 }  // namespace tir
 }  // namespace tvm
 #endif  // TVM_TIR_ANALYSIS_H_
index 6a1a178..8ba008b 100644 (file)
@@ -407,56 +407,6 @@ LoweredFunc MakeAPI(Stmt body,
                     bool is_restricted);
 
 /*!
- * \brief Bind the device type of host function to be device_type.
- * \param func The function to be binded.
- * \param device_type The device type to be binded.
- * \return The binded function.
- */
-LoweredFunc BindDeviceType(LoweredFunc func,
-                           int device_type);
-/*!
- * \brief Find undefined vars in the statment.
- * \param stmt The function to be checked.
- * \param defs The vars that is defined.
- * \return Array of undefined vars.
- */
-Array<Var> UndefinedVars(const Stmt& stmt, const Array<Var>& defs);
-
-/*!
- * \brief Split the function into a host function and device functions.
- * \param func The function to be splitted.
- *
- * \return Array of functions, the first one is host function,
- *     the others are device functions.
- */
-Array<LoweredFunc> SplitHostDevice(LoweredFunc func);
-
-/*!
- * \brief Insert sync between parallel read/write of shared buffers.
- *
- * \param stmt The stmt to be trasnformed.
- * \param storage_scope The storage scope considered.
- */
-LoweredFunc ThreadSync(LoweredFunc stmt, std::string storage_scope);
-
-/*!
- * \brief Lower cross thread alleduce in the stmt.
- * \param f The device function to be lowered.
- * \param warp_size the size of warp where no sync is needed.
- * \return Transformed function.
- */
-LoweredFunc LowerThreadAllreduce(LoweredFunc f, int warp_size);
-
-/*!
- * \brief Lower warp memory in stmt.
- * \param f The device function to be lowered.
- * \param warp_size the size of warp where no sync is needed.
- *        this function will only take in effect if warp_size is bigger than one.
- * \return Transformed function.
- */
-LoweredFunc LowerWarpMemory(LoweredFunc f, int warp_size);
-
-/*!
  * \brief Remap the thread axis
  *
  *  This can be used to get equivalent program which uses
@@ -471,26 +421,6 @@ LoweredFunc LowerWarpMemory(LoweredFunc f, int warp_size);
 LoweredFunc RemapThreadAxis(LoweredFunc f, Map<PrimExpr, IterVar> axis_map);
 
 /*!
- * \brief Lower packed function call.
- * \param f The function to be lowered.
- * \return Transformed function.
- */
-LoweredFunc LowerTVMBuiltin(LoweredFunc f);
-
-
-/*!
- * \brief Rewrite the pointer content type of arguments,
- *  as well as Alloc internal to the function to use
- *  the most frequently accessed type for load/store
- *  to avoid pointer casting in backend when possible.
- *
- * \note implemeneted in storage_rewrite.cc
- * \param f The function to be trasnformed
- * \return Transformed function.
- */
-LoweredFunc PointerValueTypeRewrite(LoweredFunc f);
-
-/*!
  * \brief Rewrite the pointer content type of arguments,
  *  as well as Alloc internal to the function to use
  *  the most frequently accessed type for load/store
@@ -514,14 +444,6 @@ PrimFunc PointerValueTypeRewrite(PrimFunc f);
 LoweredFunc LowerCustomDatatypes(LoweredFunc f, const std::string& target);
 
 /*!
- * \brief Infer the TensorCore fragment infomation using tensor intrinsics
- *
- * \param f The device function to be lowered.
- * \return Transformed function.
- */
-LoweredFunc InferFragment(LoweredFunc f);
-
-/*!
  * \brief Verify if memory accesses are legal for a specific target device type.
  *
  *  In the case that tgt is cuda, if not all workload is bound with
index d809e07..211e344 100644 (file)
@@ -59,6 +59,21 @@ TVM_DLL Pass CreatePrimFuncPass(const runtime::TypedPackedFunc<
                                 const tvm::Array<tvm::PrimExpr>& required);
 
 /*!
+ * \brief Bind the device type ofthe function to be
+ *        the device_type specified in the target attribute.
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass BindDeviceType();
+
+/*!
+ * \brief Split the function into a host function and device functions.
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass SplitHostDevice();
+
+/*!
  * \brief skip assert stmt.
  *
  * \return The pass.
index 7eda40d..e4bd200 100644 (file)
@@ -14,6 +14,8 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+
+# pylint: disable=invalid-name
 """The build utils in python.
 
 This module provides the functions to transform schedule to
@@ -25,6 +27,7 @@ import tvm.tir
 
 from tvm.runtime import ndarray
 from tvm.ir import container
+from tvm.ir import CallingConv
 from tvm.target import codegen, BuildConfig
 from tvm.tir import ir_pass
 from tvm.tir.stmt import LoweredFunc
@@ -222,75 +225,59 @@ def _build_for_device(flist, target, target_host):
     mdev : tvm.module
         A module that contains device code.
     """
-    @tvm.tir.transform.prim_func_pass(opt_level=0)
-    class BindTarget:
-        def __init__(self, target):
-            self.target = target
-
-        # pylint: disable=unused-argument
-        def transform_function(self, func, mod, ctx):
-            return func.with_attr("target", self.target)
-
     target = _target.create(target)
+    target_host = _target.create(target_host)
     device_type = ndarray.context(target.target_name, 0).device_type
-    fhost = []
-    fdevice = []
+
     for func in flist:
         if not ir_pass.VerifyMemory(func, device_type):
             raise ValueError(
                 "Direct host side access to device memory is detected in %s. "
                 "Did you forget to bind?" % func.name)
-        if func.func_type == LoweredFunc.MixedFunc:
-            if BuildConfig.current().detect_global_barrier:
-                func = ir_pass.ThreadSync(func, "global")
-            func = ir_pass.ThreadSync(func, "shared")
-            func = ir_pass.ThreadSync(func, "warp")
-            func = ir_pass.InferFragment(func)
-            warp_size = target.thread_warp_size
-            func = ir_pass.LowerThreadAllreduce(func, warp_size)
-            fsplits = list(ir_pass.SplitHostDevice(func))
-            fhost.append(fsplits[0])
-            for x in fsplits[1:]:
-                fdevice.append(x)
-        elif func.func_type == LoweredFunc.HostFunc:
-            fhost.append(func)
-        elif func.func_type == LoweredFunc.DeviceFunc:
-            fdevice.append(func)
-        else:
-            raise ValueError("unknown function type %d" % func.func_type)
-
-    if "gpu" in target.keys and not fdevice:
-        warnings.warn(
-            "Specified target %s, but cannot find device code, did you do "
-            "bind?" % target)
 
-    fhost = [ir_pass.BindDeviceType(x, device_type) for x in fhost]
+    mod_mixed = tvm.testing.LoweredFuncsToIRModule(flist)
+    opt_mixed = [tvm.tir.transform.Apply(lambda f: f.with_attr("target", target))]
+    if BuildConfig.current().detect_global_barrier:
+        opt_mixed += [tvm.tir.transform.ThreadSync("global")]
+    opt_mixed += [tvm.tir.transform.ThreadSync("shared"),
+                  tvm.tir.transform.ThreadSync("warp"),
+                  tvm.tir.transform.InferFragment(),
+                  tvm.tir.transform.LowerThreadAllreduce(),
+                  tvm.tir.transform.BindDeviceType(),
+                  tvm.tir.transform.SplitHostDevice()]
+    mod_mixed = tvm.ir.transform.Sequential(opt_mixed)(mod_mixed)
 
-    if device_type == ndarray.cpu(0).device_type and target_host == target:
-        assert not fdevice
-
-    target_host = _target.create(target_host)
 
     # device optimizations
-    mod_dev = tvm.testing.LoweredFuncsToIRModule(fdevice)
     opt_device = tvm.ir.transform.Sequential(
-        [BindTarget(target),
+        [tvm.tir.transform.Filter(
+            lambda f: "calling_conv" in f.attrs and
+            f.attrs["calling_conv"].value == CallingConv.DEVICE_KERNEL_LAUNCH),
          tvm.tir.transform.LowerWarpMemory(),
          tvm.tir.transform.LowerDeviceStorageAccessInfo(),
          tvm.tir.transform.LowerIntrin()])
-    mod_dev = opt_device(mod_dev)
+    mod_dev = opt_device(mod_mixed)
 
     # host optimizations
-    mod_host = tvm.testing.LoweredFuncsToIRModule(fhost)
     opt_host = tvm.ir.transform.Sequential(
-        [BindTarget(target_host),
+        [tvm.tir.transform.Filter(
+            lambda f: "calling_conv" not in f.attrs or
+            f.attrs["calling_conv"].value != CallingConv.DEVICE_KERNEL_LAUNCH),
+         tvm.tir.transform.Apply(lambda f: f.with_attr("target", target)),
          tvm.tir.transform.LowerTVMBuiltin(),
          tvm.tir.transform.LowerDeviceStorageAccessInfo(),
          tvm.tir.transform.LowerIntrin(),
          tvm.tir.transform.CombineContextCall()])
-    mod_host = opt_host(mod_host)
+    mod_host = opt_host(mod_mixed)
+
+    if device_type == ndarray.cpu(0).device_type and target_host == target:
+        assert len(mod_dev.functions) == 0
+    if "gpu" in target.keys and len(mod_dev.functions) == 0:
+        warnings.warn(
+            "Specified target %s, but cannot find device code, did you do "
+            "bind?" % target)
 
-    rt_mod_dev = codegen.build_module(mod_dev, target) if fdevice else None
+    rt_mod_dev = codegen.build_module(mod_dev, target) if len(mod_dev.functions) != 0 else None
     return mod_host, rt_mod_dev
 
 
index b3efd6b..1aabf3e 100644 (file)
@@ -23,7 +23,7 @@ from .type import TypeConstraint, FuncType, IncompleteType, RelayRefType
 from .tensor_type import TensorType
 from .type_relation import TypeCall, TypeRelation
 from .expr import BaseExpr, PrimExpr, RelayExpr, GlobalVar, Range
-from .function import BaseFunc
+from .function import CallingConv, BaseFunc
 from .adt import Constructor, TypeData
 from .module import IRModule
 from .attrs import Attrs, DictAttrs, make_node
index 70eb51a..afc8c10 100644 (file)
 # specific language governing permissions and limitations
 # under the License.
 """Function defintiions."""
+from enum import IntEnum
 from .expr import RelayExpr
 from . import _ffi_api
 
 
+class CallingConv(IntEnum):
+    """Possible kinds of calling conventions."""
+    DEFAULT = 0
+    C_PACKED_FUNC = 1
+    DEVICE_KERNEL_LAUNCH = 2
+
+
 class BaseFunc(RelayExpr):
     """Base class of all functions."""
     @property
index 24f5211..8d75d8e 100644 (file)
@@ -60,7 +60,6 @@ class IRModule(Node):
             type_definitions = mapped_type_defs
         self.__init_handle_by_constructor__(_ffi_api.IRModule, functions, type_definitions)
 
-
     def __setitem__(self, var, val):
         """Add a mapping to the module.
 
index 93bb996..a19cc2f 100644 (file)
@@ -16,6 +16,7 @@
 # under the License.
 """TIR specific function pass support."""
 import inspect
+import types
 import functools
 
 import tvm._ffi
@@ -142,7 +143,7 @@ def prim_func_pass(pass_func=None, opt_level=None, name=None, required=None):
             return _wrap_class_function_pass(pass_arg, info)
         if not isinstance(pass_arg, (types.FunctionType, types.LambdaType)):
             raise TypeError("pass_func must be a callable for Module pass")
-        return _ffi_api.MakeFunctionPass(pass_arg, info)
+        return _ffi_api.CreatePrimFuncPass(pass_arg, info)
 
     if pass_func:
         return create_function_pass(pass_func)
index 6be4a38..c823c1a 100644 (file)
 """Wrapping existing transformations."""
 # pylint: disable=invalid-name
 from . import _ffi_api
+from . import function_pass as _fpass
+
+
+def Apply(ftransform):
+    """Apply ftransform to each function in the Module.
+
+    This function is a thin wrapper around tvm.tir.transform.prim_func_pass
+
+    Parameters
+    ----------
+    ftransform: tvm.tir.PrimFunc -> tvm.tir.PrimFunc
+       The transformation pass.
+
+    Returns
+    -------
+    fpass : tvm.ir.transform.Pass
+        The result pass
+    """
+    # pylint: disable=unused-argument
+    def _transform(func, mod, ctx):
+        return ftransform(func)
+    return _fpass.prim_func_pass(_transform, opt_level=0)
+
+
+def Filter(fcond):
+    """Filter functions by the calling convention attribute.
+
+    Parameters
+    ----------
+    fcond : tvm.tir.PrimFunc -> bool
+        The condition of the filtering.
+
+    Returns
+    -------
+    fpass : tvm.ir.transform.Pass
+        The result pass
+    """
+    # pylint: disable=unused-argument
+    def _transform(func, mod, ctx):
+        return func if fcond(func) else None
+    return _fpass.prim_func_pass(_transform, opt_level=0)
+
+
+def BindDeviceType():
+    """Bind the device type of the function to be
+       the device_type specified in the target attribute.
+
+    Returns
+    -------
+    fpass : tvm.ir.transform.Pass
+        The result pass
+    """
+    return _ffi_api.BindDeviceType()
+
+
+def SplitHostDevice():
+    """Split the function into a host function and device functions.
+
+    Returns
+    -------
+    fpass : tvm.ir.transform.Pass
+        The result pass
+    """
+    return _ffi_api.SplitHostDevice()
 
 
 def SkipAssert():
index f59e764..d54d6f8 100644 (file)
@@ -185,75 +185,50 @@ transform::Pass BindTarget(Target target) {
 }
 
 
+template<typename FCond>
+transform::Pass FilterBy(FCond fcond) {
+  auto fpass = [fcond](tir::PrimFunc f, IRModule m, transform::PassContext ctx) {
+    if (fcond(f)) {
+      return f;
+    } else {
+      return tir::PrimFunc(nullptr);
+    }
+  };
+  return tir::transform::CreatePrimFuncPass(fpass, 0, "FilterBy", {});
+}
+
+
 std::pair<IRModule, IRModule>
 split_dev_host_funcs(const Array<LoweredFunc>& funcs,
                      const Target& target,
                      const Target& target_host,
                      const BuildConfig& config) {
-  std::unordered_set<std::string> all_names;
-  for (const auto& x : funcs) {
-    CHECK(all_names.count(x->name) == 0)
-        << "Duplicate function name " << x->name;
-    all_names.insert(x->name);
-  }
-
-  Array<LoweredFunc> fhost;
-  Array<LoweredFunc> fdevice;
-
   for (const auto& x : funcs) {
     CHECK(tir::VerifyMemory(x, target->device_type))
         << "Direct host side access to device memory is detected in "
         << x->func_name() << ". Did you forget to bind?";
-
-    if (x->func_type == tir::kMixedFunc) {
-      auto func = x;
-      if (config->detect_global_barrier) {
-        func = tir::ThreadSync(func, "global");
-      }
-
-      func = tir::ThreadSync(func, "shared");
-      func = tir::ThreadSync(func, "warp");
-      func = tir::InferFragment(func);
-      func = tir::LowerThreadAllreduce(func, target->thread_warp_size);
-      auto fsplits = tir::SplitHostDevice(func);
-      fhost.push_back(fsplits[0]);
-      for (auto f = fsplits.begin() + 1; f != fsplits.end(); ++f) {
-        fdevice.push_back(*f);
-      }
-    } else if (x->func_type == tir::kHostFunc) {
-      fhost.push_back(x);
-    } else if (x->func_type == tir::kDeviceFunc) {
-      fdevice.push_back(x);
-    } else {
-      LOG(FATAL) << "unknown function type " << x->func_type;
-    }
-  }
-
-  auto keys = target->keys();
-  bool target_is_gpu = std::find(keys.begin(), keys.end(), "gpu") != keys.end();
-  if (target_is_gpu && fdevice.size() == 0) {
-    LOG(WARNING) << "Specified target "
-                 << target->str()
-                 << " but cannot find device code. Did you forget to bind?";
   }
 
+  IRModule mod_mixed = codegen::ToIRModule(funcs);
 
-  if (target->device_type == target::llvm()->device_type &&
-      target_host == target) {
-    CHECK(fdevice.empty()) << "No device code should be generated when target "
-                           << "and host_target are both llvm target."
-                           << "\n";
-  }
-
-  for (size_t i = 0; i < fhost.size(); ++i) {
-    auto func = fhost[i];
-    func = tir::BindDeviceType(func, target->device_type);
-    fhost.Set(i, func);
+  Array<tvm::transform::Pass> mixed_pass_list = {BindTarget(target)};
+  if (config->detect_global_barrier) {
+    mixed_pass_list.push_back(tir::transform::ThreadSync("global"));
   }
+  mixed_pass_list.push_back(tir::transform::ThreadSync("shared"));
+  mixed_pass_list.push_back(tir::transform::ThreadSync("warp"));
+  mixed_pass_list.push_back(tir::transform::InferFragment());
+  mixed_pass_list.push_back(tir::transform::LowerThreadAllreduce());
+  mixed_pass_list.push_back(tir::transform::BindDeviceType());
+  mixed_pass_list.push_back(tir::transform::SplitHostDevice());
+  auto opt_mixed = transform::Sequential(mixed_pass_list);
+  mod_mixed = opt_mixed(std::move(mod_mixed));
 
-  // host pipeline
-  auto mhost = codegen::ToIRModule(fhost);
   auto host_pass_list = {
+    FilterBy([](const tir::PrimFunc& f) {
+      int64_t value = f->GetAttr<Integer>(tvm::attr::kCallingConv, 0)->value;
+      return value != static_cast<int>(CallingConv::kDeviceKernelLaunch);
+    }),
     BindTarget(target_host),
     tir::transform::LowerTVMBuiltin(),
     tir::transform::LowerIntrin(),
@@ -261,18 +236,38 @@ split_dev_host_funcs(const Array<LoweredFunc>& funcs,
     tir::transform::CombineContextCall(),
   };
   auto opt_host = transform::Sequential(host_pass_list);
-  mhost = opt_host(mhost);
+  auto mhost = opt_host(mod_mixed);
 
   // device pipeline
-  auto mdevice = codegen::ToIRModule(fdevice);
   auto device_pass_list = {
+    FilterBy([](const tir::PrimFunc& f) {
+      int64_t value = f->GetAttr<Integer>(tvm::attr::kCallingConv, 0)->value;
+      return value == static_cast<int>(CallingConv::kDeviceKernelLaunch);
+    }),
     BindTarget(target),
     tir::transform::LowerWarpMemory(),
     tir::transform::LowerIntrin(),
     tir::transform::LowerDeviceStorageAccessInfo(),
   };
   auto opt_device = transform::Sequential(device_pass_list);
-  mdevice = opt_device(mdevice);
+  auto mdevice = opt_device(mod_mixed);
+
+  // some final misc checks.
+  auto keys = target->keys();
+  bool target_is_gpu = std::find(keys.begin(), keys.end(), "gpu") != keys.end();
+  if (target_is_gpu && mdevice->functions.size() == 0) {
+    LOG(WARNING) << "Specified target "
+                 << target->str()
+                 << " but cannot find device code. Did you forget to bind?";
+  }
+
+  if (target->device_type == target::llvm()->device_type &&
+      target_host == target) {
+    CHECK(mdevice->functions.empty())
+        << "No device code should be generated when target "
+        << "and host_target are both llvm target."
+        << "\n";
+  }
 
   return {mhost, mdevice};
 }
index 56e77b7..bda997a 100644 (file)
@@ -34,6 +34,7 @@
  */
 #include <tvm/ir/type_functor.h>
 #include <tvm/ir/module.h>
+#include <tvm/tir/function.h>
 #include <tvm/relay/expr_functor.h>
 #include <tvm/relay/pattern_functor.h>
 #include "doc.h"
@@ -434,6 +435,10 @@ class RelayTextPrinter :
   Doc PrintFunc(const Doc& prefix, const BaseFunc& base_func) {
     if (auto* n = base_func.as<relay::FunctionNode>()) {
       return PrintFunc(prefix, GetRef<relay::Function>(n));
+    } else if (auto* n = base_func.as<tir::PrimFuncNode>()) {
+      std::ostringstream os;
+      os << GetRef<tir::PrimFunc>(n);
+      return Doc::RawText(os.str());
     } else {
       // def @xyz = meta['ExternalFunc'][id]
       Doc doc;
@@ -455,8 +460,9 @@ class RelayTextPrinter :
     }
     // functions
     for (const auto& kv : mod->functions) {
-      dg_ = DependencyGraph::Create(&arena_, kv.second);
-
+      if (kv.second.as<relay::FunctionNode>()) {
+        dg_ = DependencyGraph::Create(&arena_, kv.second);
+      }
       if (counter++ != 0) {
         doc << Doc::NewLine();
       }
index a977d35..703328f 100644 (file)
@@ -50,9 +50,10 @@ tir::PrimFunc ToPrimFunc(tir::LoweredFunc from) {
   Map<tir::Var, PrimExpr> remap_vars;
 
   for (auto var : from->args) {
-    if (from->handle_data_type.count(var)) {
+    auto it = from->handle_data_type.find(var);
+    if (it != from->handle_data_type.end()) {
       tir::Var new_var(var->name_hint,
-                       PointerType(PrimType(var->dtype)));
+                       PointerType(PrimType((*it).second->dtype)));
       args.push_back(new_var);
       remap_vars.Set(var, new_var);
     } else {
index 70bcfe8..33a3e17 100644 (file)
@@ -24,6 +24,7 @@
 
 #include <tvm/runtime/c_runtime_api.h>
 #include <tvm/tir/ir_pass.h>
+#include <tvm/tir/analysis.h>
 #include <memory>
 #include <unordered_map>
 #include "codegen_cpu.h"
index f991e90..773c67d 100644 (file)
@@ -108,8 +108,13 @@ IRModule PrimFuncPassNode::operator()(const IRModule& mod,
       updates.push_back({it.first, updated_func});
     }
   }
+  // automatic removal of None
   for (const auto& pair : updates) {
-    updated_mod->Add(pair.first, pair.second, true);
+    if (pair.second.defined()) {
+      updated_mod->Add(pair.first, pair.second, true);
+    } else {
+      updated_mod->Remove(pair.first);
+    }
   }
   pass_ctx.Trace(updated_mod, pass_info, false);
   return updated_mod;
index ff821fe..83db1a9 100644 (file)
@@ -128,10 +128,7 @@ REGISTER_PASS(VectorizeLoop);
 REGISTER_PASS(SkipVectorize);
 REGISTER_PASS(UnrollLoop);
 REGISTER_PASS(InjectCopyIntrin);
-REGISTER_PASS(ThreadSync);
 REGISTER_PASS(MakeAPI);
-REGISTER_PASS(BindDeviceType);
-REGISTER_PASS(SplitHostDevice);
 REGISTER_PASS(StorageRewrite);
 REGISTER_PASS(CoProcSync);
 REGISTER_PASS(LowerStorageAccessInfo);
@@ -141,7 +138,6 @@ REGISTER_PASS(InjectDoubleBuffer);
 REGISTER_PASS(LoopPartition);
 REGISTER_PASS(RemoveNoOp);
 REGISTER_PASS(LiftAttrScope);
-REGISTER_PASS(LowerThreadAllreduce);
 REGISTER_PASS(RemapThreadAxis);
 REGISTER_PASS(LowerCustomDatatypes);
 REGISTER_PASS(VerifyMemory);
@@ -150,7 +146,6 @@ REGISTER_PASS(DecorateDeviceScope);
 REGISTER_PASS(InstrumentBoundCheckers);
 REGISTER_PASS(VerifyCompactBuffer);
 REGISTER_PASS(HoistIfThenElse);
-REGISTER_PASS(InferFragment)
 REGISTER_PASS(NarrowDataType);
 }  // namespace tir
 }  // namespace tvm
index f8eae64..861cd43 100644 (file)
@@ -218,69 +218,6 @@ LoweredFunc MakeAPI(Stmt body,
   return f;
 }
 
-class DeviceTypeBinder: public StmtExprMutator {
- public:
-  explicit DeviceTypeBinder(int device_type)
-      : device_type_(device_type) {}
-
-  Stmt VisitStmt_(const AttrStmtNode* op) final {
-    if (op->attr_key == attr::device_context_type) {
-      if (const VarNode* var = op->value.as<VarNode>()) {
-        var_ = var;
-        PrimExpr value = make_const(op->value.dtype(), device_type_);
-        Stmt body = StmtExprMutator::VisitStmt_(op);
-        var_ = nullptr;
-        std::ostringstream os;
-        os << "device_type need to be " << device_type_;
-        return AssertStmtNode::make(op->value == value, os.str(), body);
-      }
-    }
-    return StmtExprMutator::VisitStmt_(op);
-  }
-
-  Stmt VisitStmt_(const IfThenElseNode* op) final {
-    // eager simplify if guard.
-    Stmt res = StmtExprMutator::VisitStmt_(op);
-    op = res.as<IfThenElseNode>();
-    if (is_zero(op->condition)) {
-      if (op->else_case.defined()) return op->else_case;
-      return EvaluateNode::make(0);
-    }
-    if (is_one(op->condition)) {
-      return op->then_case;
-    }
-    return res;
-  }
-
-  PrimExpr VisitExpr_(const NENode* op) final {
-    // eager check NE for device check
-    PrimExpr res = StmtExprMutator::VisitExpr_(op);
-    op = res.as<NENode>();
-    if (tir::ExprDeepEqual()(op->a, op->b)) {
-      return make_const(op->dtype, false);
-    }
-    return res;
-  }
-
-  PrimExpr VisitExpr_(const VarNode* op) final {
-    if (op == var_) {
-      return make_const(op->dtype, device_type_);
-    } else {
-      return GetRef<PrimExpr>(op);
-    }
-  }
-
- public:
-  const VarNode* var_{nullptr};
-  int device_type_;
-};
-
-LoweredFunc BindDeviceType(LoweredFunc f,
-                           int device_type) {
-  auto n = make_object<LoweredFuncNode>(*f.operator->());
-  n->body = DeviceTypeBinder(device_type)(n->body);
-  return LoweredFunc(n);
-}
 
 }  // namespace tir
 }  // namespace tvm
diff --git a/src/tir/transforms/bind_device_type.cc b/src/tir/transforms/bind_device_type.cc
new file mode 100644 (file)
index 0000000..486f21c
--- /dev/null
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file bind_device_type.cc
+ * \brief Bind the device type according to the target field.
+ */
+#include <tvm/ir/transform.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/op.h>
+#include <tvm/tir/transform.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/target/target.h>
+#include <tvm/runtime/registry.h>
+
+namespace tvm {
+namespace tir {
+
+class DeviceTypeBinder: public StmtExprMutator {
+ public:
+  explicit DeviceTypeBinder(int device_type)
+      : device_type_(device_type) {}
+
+  Stmt VisitStmt_(const AttrStmtNode* op) final {
+    if (op->attr_key == attr::device_context_type) {
+      if (const VarNode* var = op->value.as<VarNode>()) {
+        var_ = var;
+        PrimExpr value = make_const(op->value.dtype(), device_type_);
+        Stmt body = StmtExprMutator::VisitStmt_(op);
+        var_ = nullptr;
+        std::ostringstream os;
+        os << "device_type need to be " << device_type_;
+        return AssertStmtNode::make(op->value == value, os.str(), body);
+      }
+    }
+    return StmtExprMutator::VisitStmt_(op);
+  }
+
+  Stmt VisitStmt_(const IfThenElseNode* op) final {
+    // eager simplify if guard.
+    Stmt res = StmtExprMutator::VisitStmt_(op);
+    op = res.as<IfThenElseNode>();
+    if (is_zero(op->condition)) {
+      if (op->else_case.defined()) return op->else_case;
+      return EvaluateNode::make(0);
+    }
+    if (is_one(op->condition)) {
+      return op->then_case;
+    }
+    return res;
+  }
+
+  PrimExpr VisitExpr_(const NENode* op) final {
+    // eager check NE for device check
+    PrimExpr res = StmtExprMutator::VisitExpr_(op);
+    op = res.as<NENode>();
+    if (tir::ExprDeepEqual()(op->a, op->b)) {
+      return make_const(op->dtype, false);
+    }
+    return res;
+  }
+
+  PrimExpr VisitExpr_(const VarNode* op) final {
+    if (op == var_) {
+      return make_const(op->dtype, device_type_);
+    } else {
+      return GetRef<PrimExpr>(op);
+    }
+  }
+
+ public:
+  const VarNode* var_{nullptr};
+  int device_type_;
+};
+
+namespace transform {
+
+Pass BindDeviceType() {
+  auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
+    auto* n = f.CopyOnWrite();
+    auto target = f->GetAttr<Target>(tvm::attr::kTarget);
+    CHECK(target.defined())
+        << "BindDeviceType: Require the target attribute";
+    n->body = DeviceTypeBinder(target->device_type)(std::move(n->body));
+    return f;
+  };
+  return CreatePrimFuncPass(pass_func, 0, "tir.BindDeviceType", {});
+}
+
+TVM_REGISTER_GLOBAL("tir.transform.BindDeviceType")
+.set_body_typed(BindDeviceType);
+
+}  // namespace transform
+}  // namespace tir
+}  // namespace tvm
index e7e89f8..c4df2dc 100644 (file)
@@ -340,14 +340,6 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
   std::unordered_map<const VarNode *, Stmt> alloc_remap_;
 };
 
-LoweredFunc
-LowerThreadAllreduce(LoweredFunc f, int warp_size) {
-  CHECK_NE(f->func_type, kHostFunc);
-  auto n = make_object<LoweredFuncNode>(*f.operator->());
-  n->body = ThreadAllreduceBuilder(warp_size)(n->body);
-  return LoweredFunc(n);
-}
-
 namespace transform {
 
 Pass LowerThreadAllreduce() {
@@ -356,10 +348,6 @@ Pass LowerThreadAllreduce() {
     auto target = f->GetAttr<Target>(tvm::attr::kTarget);
     CHECK(target.defined())
         << "LowerThreadAllreduce: Require the target attribute";
-    auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
-    CHECK(calling_conv.defined() &&
-          calling_conv->value == static_cast<int>(CallingConv::kDeviceKernelLaunch))
-        << "LowerThreadAllreeduce: expect calling_conv equals CallingConv::kDeviceKernelLaunch";
     n->body = ThreadAllreduceBuilder(target->thread_warp_size)(n->body);
     return f;
   };
similarity index 61%
rename from src/tir/pass/split_host_device.cc
rename to src/tir/transforms/split_host_device.cc
index 519101f..838ad82 100644 (file)
  * \file split_host_device.cc
  * \brief Split device function from host.
  */
+#include <tvm/ir/transform.h>
 #include <tvm/tir/expr.h>
-#include <tvm/tir/lowered_func.h>
+#include <tvm/tir/transform.h>
 #include <tvm/tir/ir_pass.h>
 #include <tvm/tir/stmt_functor.h>
-#include <tvm/runtime/module.h>
+#include <tvm/target/target.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/runtime/container.h>
+
 #include <unordered_map>
 
 namespace tvm {
 namespace tir {
 
 // use/def analysis, also delete unreferenced lets
-class IRUseDefAnalysis : public StmtExprMutator {
+class VarUseDefAnalysis : public StmtExprMutator {
  public:
   Stmt VisitStmt_(const AttrStmtNode* op) final {
     if (op->attr_key == attr::thread_extent) {
@@ -156,8 +160,27 @@ class IRUseDefAnalysis : public StmtExprMutator {
   std::unordered_map<const VarNode*, int> def_count_;
 };
 
+
+Array<Var> UndefinedVars(const Stmt& stmt, const Array<Var>& args) {
+  VarUseDefAnalysis m;
+  for (Var arg : args) {
+    m.use_count_[arg.get()] = 0;
+  }
+  m(stmt);
+  return m.undefined_;
+}
+
+
 class HostDeviceSplitter : public StmtMutator {
  public:
+  explicit HostDeviceSplitter(IRModuleNode* device_mod,
+                              Target device_target,
+                              std::string name_prefix)
+      : device_mod_(device_mod),
+        device_target_(device_target),
+        name_prefix_(name_prefix) {
+  }
+
   Stmt VisitStmt_(const AllocateNode* op) final {
     handle_data_type_[op->buffer_var.get()] = make_const(op->dtype, 0);
     return StmtMutator::VisitStmt_(op);
@@ -172,86 +195,128 @@ class HostDeviceSplitter : public StmtMutator {
     return StmtMutator::VisitStmt_(op);
   }
 
-  Array<LoweredFunc> Split(LoweredFunc f) {
-    CHECK_EQ(f->func_type, kMixedFunc);
-    for (auto kv : f->handle_data_type) {
-      handle_data_type_[kv.first.get()] = kv.second;
-    }
-    name_ = f->name;
-    ObjectPtr<LoweredFuncNode> n =
-        make_object<LoweredFuncNode>(*f.operator->());
-    n->body = operator()(f->body);
-    n->func_type = kHostFunc;
-    Array<LoweredFunc> ret{LoweredFunc(n)};
-    for (LoweredFunc x : device_funcs_) {
-      ret.push_back(x);
-    }
-    return ret;
-  }
-
  private:
   Stmt SplitDeviceFunc(Stmt body) {
     std::ostringstream os;
-    os << name_ << "_kernel" << device_funcs_.size();
-    ObjectPtr<LoweredFuncNode> n = make_object<LoweredFuncNode>();
+    os << name_prefix_ << "_kernel" << device_func_counter_++;
+    std::string kernel_symbol = os.str();
     // isolate the device function.
-    IRUseDefAnalysis m;
+    VarUseDefAnalysis m;
     m.visit_thread_extent_ = false;
-    n->body = m(std::move(body));
-    n->name = os.str();
-    n->func_type = kDeviceFunc;
-    n->thread_axis = m.thread_axis_;
+    body = m(std::move(body));
+
+    Array<Var> params;
+    Array<PrimExpr> arguments;
+    Map<tir::Var, PrimExpr> remap_vars;
+
     // Strictly order the arguments: Var pointers, positional arguments.
-    for (Var v : m.undefined_) {
-      if (v.dtype().is_handle()) {
-        n->args.push_back(v);
-        // mark handle data type.
-        auto it = handle_data_type_.find(v.get());
+    for (Var var : m.undefined_) {
+      if (var.dtype().is_handle()) {
+        // Create a new version of v.
+        auto it = handle_data_type_.find(var.get());
         if (it != handle_data_type_.end()) {
-          n->handle_data_type.Set(v, it->second);
+          tir::Var new_var(var->name_hint,
+                           PointerType(PrimType((*it).second->dtype)));
+          params.push_back(new_var);
+          remap_vars.Set(var, new_var);
+        } else {
+          params.push_back(var);
         }
+        arguments.push_back(var);
       }
     }
-    for (Var v : m.undefined_) {
-      if (!v.dtype().is_handle()) {
-        n->args.push_back(v);
+    // positional arguments
+    for (Var var : m.undefined_) {
+      if (!var.dtype().is_handle()) {
+        params.push_back(var);
+        arguments.push_back(var);
       }
     }
-    LoweredFunc f_device(n);
+    PrimFunc device_func(params, Substitute(body, remap_vars));
+    device_func = WithAttr(std::move(device_func), tir::attr::kDeviceThreadAxis, m.thread_axis_);
+    device_func = WithAttr(std::move(device_func), tvm::attr::kCallingConv,
+                           Integer(CallingConv::kDeviceKernelLaunch));
+    device_func = WithAttr(std::move(device_func), tvm::attr::kGlobalSymbol,
+                           runtime::String(kernel_symbol));
+    device_func = WithAttr(std::move(device_func), tir::attr::kNoAlias, Integer(1));
+    device_func = WithAttr(std::move(device_func), tvm::attr::kTarget, device_target_);
+    device_mod_->Add(GlobalVar(kernel_symbol), device_func);
+
+    // generate calls to the device function
     Array<PrimExpr> call_args;
-    call_args.push_back(StringImmNode::make(f_device->name));
-    for (Var arg : n->args) {
+    call_args.push_back(StringImmNode::make(kernel_symbol));
+    for (PrimExpr arg : arguments) {
       call_args.push_back(arg);
     }
     for (PrimExpr ext : m.thread_extent_) {
       call_args.push_back(ext);
     }
-    device_funcs_.emplace_back(f_device);
     return EvaluateNode::make(CallNode::make(
         DataType::Int(32), intrinsic::tvm_call_packed,
         call_args, CallNode::Intrinsic));
   }
 
-  // function name
-  std::string name_;
-  // the device functions
+  // target ir module
+  IRModuleNode* device_mod_;
+  // Device target
+  Target device_target_;
+  // function name hint
+  std::string name_prefix_;
+  // Number of device functions.
+  int device_func_counter_{0};
   std::vector<LoweredFunc> device_funcs_;
   std::unordered_map<const VarNode*, PrimExpr> handle_data_type_;
 };
 
 
-Array<Var> UndefinedVars(const Stmt& stmt, const Array<Var>& args) {
-  IRUseDefAnalysis m;
-  for (Var arg : args) {
-    m.use_count_[arg.get()] = 0;
-  }
-  m(stmt);
-  return m.undefined_;
+PrimFunc SplitHostDevice(PrimFunc&& func, IRModuleNode* device_mod) {
+  auto target = func->GetAttr<Target>(tvm::attr::kTarget);
+  CHECK(target.defined())
+      << "SplitHostDevice: Require the target attribute";
+  auto global_symbol = func->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
+  CHECK(global_symbol.defined())
+      << "SplitHostDevice: Expect PrimFunc to have the global_symbol attribute";
+
+  HostDeviceSplitter splitter(
+      device_mod, target, static_cast<std::string>(global_symbol));
+
+  auto* n = func.CopyOnWrite();
+  n->body = splitter(std::move(n->body));
+  // set the host target to None.
+  func = WithAttr(std::move(func), tvm::attr::kTarget, Target(nullptr));
+  return std::move(func);
 }
 
-Array<LoweredFunc> SplitHostDevice(LoweredFunc func) {
-  return HostDeviceSplitter().Split(func);
+
+
+namespace transform {
+
+Pass SplitHostDevice() {
+  auto pass_func = [](IRModule m, PassContext ctx) {
+    IRModuleNode* mptr = m.CopyOnWrite();
+    std::vector<std::pair<GlobalVar, PrimFunc> > updates;
+
+    for (const auto& kv : mptr->functions) {
+      if (auto* n = kv.second.as<PrimFuncNode>()) {
+        PrimFunc func = GetRef<PrimFunc>(n);
+        auto updated_func = SplitHostDevice(std::move(func), mptr);
+        updates.push_back({kv.first, updated_func});
+      }
+    }
+
+    for (const auto& pair : updates) {
+      mptr->Add(pair.first, pair.second, true);
+    }
+    return m;
+  };
+
+  return tvm::transform::CreateModulePass(
+      pass_func, 0, "tir.SplitHostDevice", {});
 }
 
+TVM_REGISTER_GLOBAL("tir.transform.SplitHostDevice")
+.set_body_typed(SplitHostDevice);
+
+}  // namespace transform
 }  // namespace tir
 }  // namespace tvm
index fad4233..1ece078 100644 (file)
@@ -218,26 +218,19 @@ Stmt InferFragment(Stmt stmt) {
   return stmt;
 }
 
-LoweredFunc InferFragment(LoweredFunc f) {
-  CHECK_NE(f->func_type, kHostFunc);
-  auto n = make_object<LoweredFuncNode>(*f.operator->());
-  n->body = InferFragment(f->body);
-  return LoweredFunc(n);
-}
-
 namespace transform {
 
-Pass InferFragement() {
+Pass InferFragment() {
   auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
     auto* n = f.CopyOnWrite();
     n->body = InferFragment(std::move(n->body));
     return f;
   };
-  return CreatePrimFuncPass(pass_func, 0, "tir.InferFragement", {});
+  return CreatePrimFuncPass(pass_func, 0, "tir.InferFragment", {});
 }
 
-TVM_REGISTER_GLOBAL("tir.transform.InferFragement")
-.set_body_typed(InferFragement);
+TVM_REGISTER_GLOBAL("tir.transform.InferFragment")
+.set_body_typed(InferFragment);
 
 }  // namespace transform
 }  // namespace tir
index b631a62..f464af6 100644 (file)
@@ -374,13 +374,6 @@ Stmt ThreadSync(Stmt stmt, std::string storage_scope) {
   return ThreadSyncInserter(sync_scope, planner.syncs_inserted_)(std::move(stmt));
 }
 
-LoweredFunc ThreadSync(LoweredFunc f, std::string storage_scope) {
-  CHECK_NE(f->func_type, kHostFunc);
-  auto n = make_object<LoweredFuncNode>(*f.operator->());
-  n->body = ThreadSync(f->body, storage_scope);
-  return LoweredFunc(n);
-}
-
 namespace transform {
 
 Pass ThreadSync(std::string storage_scope) {
@@ -28,7 +28,7 @@ def test_loop_dependent_allocate():
     s[AA].compute_at(s[C], s[C].op.axis[0])
     # this line should fail due to IRUseDefAnalysis sees an allocate statement
     # referencing undefined variable
-    tvm.lower(s, [A,C])
+    tvm.lower(s, [A, C])
 
 if __name__ == "__main__":
     test_loop_dependent_allocate()
index 0fe3f61..94e29c6 100644 (file)
@@ -41,7 +41,9 @@ def test_double_buffer():
     assert isinstance(stmt.body.body, tvm.tir.Allocate)
     assert stmt.body.body.extents[0].value == 2
     f = tvm.tir.ir_pass.MakeAPI(stmt, "db", [A.asobject(), C.asobject()], 2, True)
-    f = tvm.tir.ir_pass.ThreadSync(f, "shared")
+    mod = tvm.testing.LoweredFuncsToIRModule([f])
+    f = tvm.tir.transform.ThreadSync("shared")(mod)["db"]
+
     count = [0]
     def count_sync(op):
         if isinstance(op, tvm.tir.Call) and op.name == "tvm_storage_sync":
index e8a78cb..dbfcd20 100644 (file)
@@ -93,7 +93,10 @@ def test_flatten_double_buffer():
     assert isinstance(stmt.body.body, tvm.tir.Allocate)
     assert stmt.body.body.extents[0].value == 2
     f = tvm.tir.ir_pass.MakeAPI(stmt, "db", [A.asobject(), C.asobject()], 2, True)
-    f = tvm.tir.ir_pass.ThreadSync(f, "shared")
+    f = tvm.tir.ir_pass.MakeAPI(stmt, "db", [A.asobject(), C.asobject()], 2, True)
+    mod = tvm.testing.LoweredFuncsToIRModule([f])
+    f = tvm.tir.transform.ThreadSync("shared")(mod)["db"]
+
     count = [0]
     def count_sync(op):
         if isinstance(op, tvm.tir.Call) and op.name == "tvm_storage_sync":
index 66d3cfb..167899a 100644 (file)
@@ -33,16 +33,15 @@ def test_lower_warp_mem():
     xo, xi = s[AA].split(s[AA].op.axis[0], 32)
     s[AA].bind(xi, tx)
 
-    f = tvm.lower(s, [A, B])
-    fhost, fdevice = tvm.tir.ir_pass.SplitHostDevice(f)
-
-    # temp adapter to convert loweredFunc to IRModule
-    # to test passes in the new style.
-    fname = fdevice.name
-    mod = tvm.testing.LoweredFuncsToIRModule([fdevice])
     cuda_target = tvm.target.create("cuda")
     assert cuda_target.thread_warp_size == 32
-    mod = tvm.IRModule.from_expr(mod[fname].with_attr("target", cuda_target))
+    f = tvm.lower(s, [A, B], name="f")
+
+
+    mod = tvm.testing.LoweredFuncsToIRModule([f])
+    mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", cuda_target))(mod)
+    fdevice = tvm.tir.transform.SplitHostDevice()(mod)["f_kernel0"]
+    mod = tvm.IRModule.from_expr(fdevice)
     fdevice = tvm.tir.transform.LowerWarpMemory()(mod)["main"]
     assert(fdevice.body.body.value.value == "local")
     assert(fdevice.body.body.body.extents[0].value == 2)
index e692e23..6c9e7f9 100644 (file)
@@ -38,13 +38,13 @@ def test_thread_storage_sync():
     A2b = tvm.tir.decl_buffer(A2.shape, A2.dtype, name='A2')
     stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, A2: A2b}, 64)
     f = tvm.tir.ir_pass.MakeAPI(stmt, "test", [Ab, A2b], 0, True)
-    flist = tvm.tir.ir_pass.SplitHostDevice(f)
-    f = flist[1]
-    fname = f.name
-    mod = tvm.testing.LoweredFuncsToIRModule([f])
+    cuda_target = tvm.target.create("cuda")
 
+    mod = tvm.testing.LoweredFuncsToIRModule([f])
+    mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", cuda_target))(mod)
+    fdevice = tvm.tir.transform.SplitHostDevice()(mod)["test_kernel0"]
+    mod = tvm.IRModule.from_expr(fdevice)
     cuda_target = tvm.target.create("cuda")
-    mod = tvm.IRModule.from_expr(mod[fname].with_attr("target", cuda_target))
     f = tvm.tir.transform.ThreadSync("shared")(mod)["main"]
     body_list = tvm.tir.stmt_list(f.body.body.body.body)
     assert(body_list[1].value.name == "tvm_storage_sync")