[REFACTOR][TIR] Migrate all low-level passes to the Pass Manager. (#5233)
authorTianqi Chen <tqchen@users.noreply.github.com>
Sun, 5 Apr 2020 00:36:49 +0000 (17:36 -0700)
committerGitHub <noreply@github.com>
Sun, 5 Apr 2020 00:36:49 +0000 (17:36 -0700)
* [REFACTOR][TIR] Migrate all low-level passes to the Pass Manager.

This PR migrates the tvm.lower to return IRModule of PrimFuncs
instead of the LoweredFuncs.

* Remove LoweredFunc.

63 files changed:
apps/lldb/tvm.py
docs/dev/codebase_walkthrough.rst
include/tvm/driver/driver_api.h
include/tvm/ir/module.h
include/tvm/target/codegen.h
include/tvm/tir/analysis.h
include/tvm/tir/ir_pass.h
include/tvm/tir/lowered_func.h [deleted file]
include/tvm/tir/transform.h
python/tvm/driver/build_module.py
python/tvm/relay/backend/_backend.py
python/tvm/relay/backend/graph_runtime_codegen.py
python/tvm/runtime/__init__.py
python/tvm/target/build_config.py
python/tvm/testing.py
python/tvm/tir/__init__.py
python/tvm/tir/analysis/analysis.py
python/tvm/tir/function.py
python/tvm/tir/stmt.py
python/tvm/tir/transform/transform.py
src/contrib/hybrid/codegen_hybrid.h
src/driver/driver_api.cc
src/relay/backend/build_module.cc
src/relay/backend/compile_engine.h
src/relay/backend/graph_runtime_codegen.cc
src/relay/backend/vm/compiler.cc
src/relay/backend/vm/compiler.h
src/relay/transforms/gradient.cc
src/target/build_common.h
src/target/codegen.cc
src/target/llvm/codegen_llvm.cc
src/target/llvm/llvm_module.cc
src/target/source/codegen_c.h
src/target/spirv/codegen_spirv.h
src/target/stackvm/codegen_stackvm.h
src/tir/analysis/verify_memory.cc [moved from src/tir/pass/verify_memory.cc with 86% similarity]
src/tir/ir/buffer.cc
src/tir/ir/lowered_func.cc [deleted file]
src/tir/pass/ffi_api.cc
src/tir/pass/storage_rewrite.cc
src/tir/transforms/lower_custom_datatypes.cc [moved from src/tir/pass/lower_custom_datatypes.cc with 88% similarity]
src/tir/transforms/make_packed_api.cc [moved from src/tir/pass/make_api.cc with 61% similarity]
src/tir/transforms/remap_thread_axis.cc [moved from src/tir/pass/remap_thread_axis.cc with 73% similarity]
src/tir/transforms/split_host_device.cc
tests/cpp/build_module_test.cc
tests/python/integration/test_dot.py
tests/python/unittest/test_runtime_extension.py
tests/python/unittest/test_runtime_heterogeneous.py
tests/python/unittest/test_runtime_module_load.py
tests/python/unittest/test_target_codegen_llvm.py
tests/python/unittest/test_target_codegen_static_init.py
tests/python/unittest/test_target_codegen_vm_basic.py
tests/python/unittest/test_target_custom_datatypes.py
tests/python/unittest/test_tir_analysis_verify_memory.py [moved from tests/python/unittest/test_tir_pass_verify_memory.py with 70% similarity]
tests/python/unittest/test_tir_pass_bound_checkers.py
tests/python/unittest/test_tir_pass_inject_double_buffer.py
tests/python/unittest/test_tir_pass_loop_partition.py
tests/python/unittest/test_tir_pass_storage_flatten.py
tests/python/unittest/test_tir_transform_combine_context_call.py
tests/python/unittest/test_tir_transform_lower_warp_memory.py
tests/python/unittest/test_tir_transform_make_packed_api.py [moved from tests/python/unittest/test_tir_pass_makeapi.py with 84% similarity]
tests/python/unittest/test_tir_transform_thread_sync.py
tutorials/dev/low_level_custom_pass.py

index 811d32d..135aeff 100644 (file)
@@ -46,7 +46,6 @@ def __lldb_init_module(debugger, _):
         "tvm::IterVarAttr",
         "tvm::IterVarRelation",
         "tvm::Layout",
-        "tir::LoweredFunc",
         "tvm::Map",
         "tvm::Map",
         "tvm::MemoryInfo",
index b7eb06b..a66328f 100644 (file)
@@ -145,15 +145,6 @@ After lowering is done, ``build()`` function generates target machine code from
 
 Code generation is done by ``build_module()`` function, defined in ``python/tvm/target/codegen.py``. On the C++ side, code generation is implemented in ``src/target/codegen`` subdirectory. ``build_module()`` Python function will reach ``Build()`` function below in ``src/target/codegen/codegen.cc``:
 
-::
-
-   runtime::Module Build(const Array<LoweredFunc>& funcs,
-                         const std::string& target) {
-     std::string build_f_name = "codegen.build_" + target;
-     const PackedFunc* bf = runtime::Registry::Get(build_f_name);
-     runtime::Module m = (*bf)(funcs, target);
-     return m;
-   }
 
 
 The ``Build()`` function looks up the code generator for the given target in the ``PackedFunc`` registry, and invokes the function found. For example, ``codegen.build_cuda`` function is registered in ``src/codegen/build_cuda_on.cc``, like this:
index 64d5173..e6d4427 100644 (file)
@@ -32,8 +32,8 @@
 #include <tvm/runtime/packed_func.h>
 #include <tvm/target/target.h>
 #include <tvm/support/with.h>
+#include <tvm/ir/module.h>
 #include <tvm/te/schedule_pass.h>
-#include <tvm/tir/lowered_func.h>
 
 #include <string>
 #include <vector>
 
 namespace tvm {
 /*!
-* \brief Build a LoweredFunc given a schedule, args and binds
+* \brief Build an IRModule given a schedule, args and binds
 * \param sch The schedule to lower.
 * \param args The arguments to the function.
 * \param name The name of the lowered function.
 * \param binds Buffer assignments.
 * \param config The build configuration.
-* \return The lowered function.
+* \return The result module.
 */
-TVM_DLL Array<tir::LoweredFunc> lower(
+TVM_DLL IRModule lower(
     te::Schedule sch,
     const Array<te::Tensor>& args,
     const std::string& name,
@@ -59,44 +59,43 @@ TVM_DLL Array<tir::LoweredFunc> lower(
     const BuildConfig& config);
 
 /*!
-* \brief Build a device and host module for a specific target from an array of lowered functions.
+* \brief Build a device and host module for a specific target from an IRModule.
 * \param funcs The functions to be built.
 * \param target The target device to build for.
 * \param target_host The target for building host code. To use the default, pass Target()
 * \param config The build configuration.
 * \return The built module.
 */
-TVM_DLL runtime::Module build(const Array<tir::LoweredFunc>& funcs,
+TVM_DLL runtime::Module build(const IRModule& funcs,
                               const Target& target,
                               const Target& target_host,
                               const BuildConfig& config);
 
 /*!
  * \brief Build a device and host module for a specific target from a map
- * contains target to a list of lowered functions pairs. This function is used
+ * contains target to IRModule. This function is used
  * for heterogeneous build.
- * \param input The map contains target to a list of lowered functions pairs.
+ * \param input The map contains target to an IRModule.
  * \param target_host The target for building host code. To use the default,
  *        pass Target().
  * \param config The build configuration.
  * \return The built module that contains code for different processors.
  */
-TVM_DLL runtime::Module build(const Map<Target, Array<tir::LoweredFunc>>& input,
+TVM_DLL runtime::Module build(const Map<Target, IRModule>& input,
                               const Target& target_host,
                               const BuildConfig& config);
 
 /*!
  * \brief Build a device and host module for a specific target from a map
- * contains target to a list of lowered functions pairs. This function is used
+ * contains target to IRModule. This function is used
  * for heterogeneous build.
- * \param input The map contains target string to a list of lowered functions
- *        pairs.
+ * \param input The map contains target string to an  IRModule.
  * \param target_host The target for building host code. To use the default,
  *        pass Target().
  * \param config The build configuration.
  * \return The built module that contains code for different processors.
  */
-TVM_DLL runtime::Module build(const Map<std::string, Array<tir::LoweredFunc>>& input,
+TVM_DLL runtime::Module build(const Map<std::string, IRModule>& input,
                               const Target& target_host,
                               const BuildConfig& config);
 }  // namespace tvm
index f63bf96..b0776de 100644 (file)
@@ -297,6 +297,15 @@ class IRModule : public ObjectRef {
     CHECK(ptr != nullptr);
     return static_cast<IRModuleNode*>(ptr);
   }
+
+  /*!
+   * \brief Construct an empty module.
+   *
+   * \returns The constructed module
+   */
+  static IRModule Empty() {
+    return IRModule(Map<GlobalVar, BaseFunc>());
+  }
   /*!
    * \brief Construct a module from a standalone expression.
    *
index c604eb5..4b7ea56 100644 (file)
@@ -27,7 +27,6 @@
 #include <tvm/runtime/packed_func.h>
 #include <tvm/ir/module.h>
 #include <tvm/tir/expr.h>
-#include <tvm/tir/lowered_func.h>
 #include <tvm/target/target.h>
 
 #include <string>
@@ -42,17 +41,6 @@ using runtime::TVMArgs;
 using runtime::TVMRetValue;
 
 /*!
- * \brief Temporary backward compatible function to convert a list
- *  of LoweredFunc to a IRModule of PrimfFuncs
- * \param funcs The input lowered function.
- * \return The IRModule.
- *
- * \note This function is only used for code refactor and will be
- *       removed once the refactor completes.
- */
-IRModule ToIRModule(const Array<tir::LoweredFunc>& funcs);
-
-/*!
  * \brief Build a module from array of lowered function.
  * \param mod The Module to be built
  * \param target The target to be built.
index fe74a96..6af9958 100644 (file)
 #ifndef TVM_TIR_ANALYSIS_H_
 #define TVM_TIR_ANALYSIS_H_
 
+#include <tvm/ir/module.h>
 #include <tvm/tir/expr.h>
+#include <tvm/tir/function.h>
 #include <tvm/tir/stmt.h>
 
+
 namespace tvm {
 namespace tir {
 
@@ -59,6 +62,18 @@ struct ExprDeepEqual {
  */
 Array<Var> UndefinedVars(const Stmt& stmt, const Array<Var>& defs);
 
+/*!
+ * \brief Verify if memory accesses are legal for a specific target device type.
+ *
+ *  In the case that tgt is cuda, if not all workload is bound with
+ *  threads, CPU code is generated that tries to access GPU memory,
+ *  which is illegal. This pass performs verification for this case.
+ *
+ * \param mod The module to be verified.
+ * \return Success of memory verification.
+ */
+void VerifyMemory(const IRModule& mod);
+
 }  // namespace tir
 }  // namespace tvm
 #endif  // TVM_TIR_ANALYSIS_H_
index 8ba008b..e228ce3 100644 (file)
@@ -31,7 +31,6 @@
 #include <tvm/tir/expr.h>
 #include <tvm/tir/buffer.h>
 #include <tvm/tir/function.h>
-#include <tvm/tir/lowered_func.h>
 
 #include <unordered_map>
 #include <unordered_set>
@@ -367,60 +366,6 @@ Stmt HoistIfThenElse(Stmt stmt);
 Stmt NarrowDataType(Stmt stmt, int target_bits);
 
 /*!
- * \brief Make an user callable API LoweredFunc.
- *
- *  The main task of this function is to create code to :
- *   - Map the values in the api_args to Var that is required by body.
- *   - Insert assertions to check type/value of the passed arguments.
- *
- * \param body The body of the function.
- * \param name The name of the function.
- * \param api_args Arguments to the function, can be either Var, or Buffer
- * \param num_unpacked_args Number of arguments that
- *         are processed in plain form instead of packed form.
- * \param is_restricted Whether the caller can guarantee that each buffer argument do not overlap.
- *  It is recommended to set to true for optimized code if such invariant holds.
- *
- * \return a LoweredFunc with the specified signiture.
- *
- * \note
- *  The function signature have two cases
- *
- *  let num_packed_args = len(api_args) - num_unpacked_args;
- *
- *  if num_packed_args is zero:
- *     f(api_arg_0, api_arg_1, .., api_arg_n) where n == len(api_args)
- *
- *  if num_packed_args is not zero:
- *       f(TVMArg* packed_args, int* packed_arg_type_ids, int num_packed_args,
- *         api_arg_k, api_arg_k+1, ... api_arg_n,
- *         TVMValue* out_ret_val, int* out_ret_tcode)
- *
- *       where n == len(api_args), k == num_packed_args
- *
- *  There is no thread_axis in generated function.
- */
-LoweredFunc MakeAPI(Stmt body,
-                    std::string name,
-                    Array<ObjectRef> api_args,
-                    int num_unpacked_args,
-                    bool is_restricted);
-
-/*!
- * \brief Remap the thread axis
- *
- *  This can be used to get equivalent program which uses
- *  threadIdx.y in place of threadIdx.x by passing
- *  {"threadIdx.x": thread_axis("threadIdx.y")}
- *
- *
- * \param f The device function to be lowered.
- * \param axis_map The map from StringImm -> ItrVar
- * \return Transformed function.
- */
-LoweredFunc RemapThreadAxis(LoweredFunc f, Map<PrimExpr, IterVar> axis_map);
-
-/*!
  * \brief Rewrite the pointer content type of arguments,
  *  as well as Alloc internal to the function to use
  *  the most frequently accessed type for load/store
@@ -433,31 +378,6 @@ LoweredFunc RemapThreadAxis(LoweredFunc f, Map<PrimExpr, IterVar> axis_map);
 PrimFunc PointerValueTypeRewrite(PrimFunc f);
 
 /*!
- * \brief Lower custom datatypes.
- *
- * See tvm::datatypes::Registry for more information on adding custom datatypes.
- *
- * \param f The device function to be lowered.
- * \param target The target device.
- * \return Transformed function.
- */
-LoweredFunc LowerCustomDatatypes(LoweredFunc f, const std::string& target);
-
-/*!
- * \brief Verify if memory accesses are legal for a specific target device type.
- *
- *  In the case that tgt is cuda, if not all workload is bound with
- *  threads, CPU code is generated that tries to access GPU memory,
- *  which is illegal. This pass performs verification for this case.
- *
- * \param func The function to be verified.
- * \param device_type The target device type.
- * \return Success of memory verification.
- */
-bool VerifyMemory(LoweredFunc func, int device_type);
-
-
-/*!
  * \brief Verify the correctness of a GPU code
  *        It will check the whether the amount of memory usage or the number of threads
  *        in a block exceeds the limit
diff --git a/include/tvm/tir/lowered_func.h b/include/tvm/tir/lowered_func.h
deleted file mode 100644 (file)
index 2d01c89..0000000
+++ /dev/null
@@ -1,149 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file tvm/tir/lowered_func.h
- * \brief Information about a lowered TVM function.
- *  This data structure is final step toward codegen.
- */
-#ifndef TVM_TIR_LOWERED_FUNC_H_
-#define TVM_TIR_LOWERED_FUNC_H_
-
-#include <tvm/node/container.h>
-#include <tvm/tir/expr.h>
-#include <tvm/tir/stmt.h>
-#include <string>
-
-namespace tvm {
-namespace tir {
-
-// Internal node container of lowered function.
-class LoweredFuncNode;
-
-/*!
- * \brief LoweredFunc represents function after lowering.
- *  This is the final IR representation before codegen.
- */
-class LoweredFunc : public FunctionRef {
- public:
-  LoweredFunc() {}
-  explicit LoweredFunc(ObjectPtr<Object> n) : FunctionRef(n) {}
-  /*!
-   * \brief access the internal node container
-   * \return the pointer to the internal node container
-   */
-  inline const LoweredFuncNode* operator->() const;
-  /*! \brief specify container node */
-  using ContainerType = LoweredFuncNode;
-};
-
-/*! \brief specific type of lowered function */
-enum LoweredFuncType : int {
-  /*! \brief Function that can mix device and host calls */
-  kMixedFunc = 0,
-  /*! \brief Only contains host code */
-  kHostFunc = 1,
-  /*! \brief Only contains device code */
-  kDeviceFunc = 2
-};
-
-/*! \brief Node container of LoweredFunc */
-class LoweredFuncNode : public tir::FunctionBaseNode {
- public:
-  /*! \brief The name of the function */
-  std::string name;
-  /*!
-   * \brief The arguments of the function
-   *  This function can only take pod type(int, float) and void* as arguments.
-   */
-  Array<Var> args;
-  /*!
-   * \brief The IterVar axis of threads
-   *  Each axis need host function to specify a size.
-   * \note Calling convention into LoweredFunc
-   *
-   * Assume we have a LoweredFunc f, a call into f
-   *   Call(f, arg1, arg2, ..., arg_n,
-   *        size_axis_1, size_axis_2, ... size_axis_m)
-   *
-   * Here n = len(args), m = len(thread_axis)
-   *
-   * The CodeGen should take this and translate this call
-   * to corresponding API specific kernel launchs or function calls.
-   */
-  Array<IterVar> thread_axis;
-  /*!
-   * \brief The hint data type of Var handles defined in LetStmt
-   *  Can be used as hint when generating type signiture.
-   *  The creation rule is given by
-   *  handle_data_type[var_handle] = make_const(the_type, 0);
-   *
-   * \note Expr is used instead Type, because Type cannot be hold by Map.
-   *  constant Expr of given type is used.
-   */
-  Map<Var, PrimExpr> handle_data_type;
-  /*! \brief The type of the function */
-  LoweredFuncType func_type{kMixedFunc};
-  /*! \brief Whether this function is packed function */
-  bool is_packed_func{true};
-  /*!
-   * \brief Whether function ensures that argument pointers do not alias.
-   *  This corresponds to restrict keyword in C.
-   */
-  bool is_restricted{true};
-  /*! \brief The body statment of the function */
-  Stmt body;
-  /*! \return name of the operation */
-  const std::string& func_name() const final {
-    return name;
-  }
-  // there is no return value, but return 1
-  // to enable Call into this function.
-  int num_outputs() const final {
-    return 1;
-  }
-  void VisitAttrs(AttrVisitor* v) {
-    v->Visit("name", &name);
-    v->Visit("args", &args);
-    v->Visit("thread_axis", &thread_axis);
-    v->Visit("handle_data_type", &handle_data_type);
-    v->Visit("func_type", &func_type);
-    v->Visit("is_packed_func", &is_packed_func);
-    v->Visit("is_restricted", &is_restricted);
-    v->Visit("body", &body);
-  }
-
-  static constexpr const char* _type_key = "LoweredFunc";
-  TVM_DECLARE_FINAL_OBJECT_INFO(LoweredFuncNode, Object);
-};
-
-// Implementations of inline functions
-inline const LoweredFuncNode* LoweredFunc::operator->() const {
-  return static_cast<const LoweredFuncNode*>(get());
-}
-}  // namespace tir
-}  // namespace tvm
-
-namespace std {
-template <>
-struct hash<::tvm::tir::LoweredFunc> : public tvm::ObjectHash {
-};
-}
-
-#endif  // TVM_TIR_LOWERED_FUNC_H_
index 211e344..860014d 100644 (file)
@@ -59,6 +59,61 @@ TVM_DLL Pass CreatePrimFuncPass(const runtime::TypedPackedFunc<
                                 const tvm::Array<tvm::PrimExpr>& required);
 
 /*!
+ * \brief Transform the high-level PrimFunc to a low-level version
+ *        that can be used as an API function.
+ *
+ *
+ *  The main task of this function is to create code to :
+ *   - Map the values in the api_args to Var that is required by body.
+ *   - Insert assertions to check type/value of the passed arguments.
+ *
+ * \param num_unpacked_args Number of arguments that
+ *         are processed in plain form instead of packed form.
+ *
+ * \note
+ *  The function signature have two cases
+ *
+ *  let num_packed_args = len(api_args) - num_unpacked_args;
+ *
+ *  if num_packed_args is zero:
+ *     f(api_arg_0, api_arg_1, .., api_arg_n) where n == len(api_args)
+ *
+ *  if num_packed_args is not zero:
+ *       f(TVMArg* packed_args, int* packed_arg_type_ids, int num_packed_args,
+ *         api_arg_k, api_arg_k+1, ... api_arg_n,
+ *         TVMValue* out_ret_val, int* out_ret_tcode)
+ *
+ *       where n == len(api_args), k == num_packed_args
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass MakePackedAPI(int num_unpacked_args);
+
+
+/*!
+ * \brief Remap the thread axis
+ *
+ *  This can be used to get equivalent program which uses
+ *  threadIdx.y in place of threadIdx.x by passing
+ *  {"threadIdx.x": thread_axis("threadIdx.y")}
+ *
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass RemapThreadAxis(Map<PrimExpr, IterVar> axis_map);
+
+
+/*!
+ * \brief Lower custom datatypes.
+ *
+ * See tvm::datatypes::Registry for more information on adding custom datatypes.
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass LowerCustomDatatypes();
+
+
+/*!
  * \brief Bind the device type ofthe function to be
  *        the device_type specified in the target attribute.
  *
index e4bd200..0dd01e1 100644 (file)
@@ -17,9 +17,6 @@
 
 # pylint: disable=invalid-name
 """The build utils in python.
-
-This module provides the functions to transform schedule to
-LoweredFunc and compiled Module.
 """
 import warnings
 
@@ -30,7 +27,6 @@ from tvm.ir import container
 from tvm.ir import CallingConv
 from tvm.target import codegen, BuildConfig
 from tvm.tir import ir_pass
-from tvm.tir.stmt import LoweredFunc
 from tvm.te import tensor
 from tvm.te import schedule
 from tvm import target as _target
@@ -136,8 +132,8 @@ def lower(sch,
 
     Returns
     -------
-    f : LoweredFunc or Stmt
-       The result function, if with_api_wrapper=False
+    m : IRModule or Stmt
+       The result IRModule, if simple_mode=False
        Then the Stmt before make api is returned.
     """
     cfg = BuildConfig.current()
@@ -199,16 +195,21 @@ def lower(sch,
     if simple_mode:
         return stmt
 
-    return ir_pass.MakeAPI(stmt, name, arg_list, 0, cfg.restricted_func)
+    f = tvm.tir.PrimFunc(arg_list, stmt).with_attr(
+        "global_symbol", tvm.runtime.String(name))
+    if cfg.restricted_func:
+        f = f.with_attr("tir.no_alias", True)
+    mod = tvm.IRModule({name: f})
+    return tvm.tir.transform.MakePackedAPI()(mod)
 
 
-def _build_for_device(flist, target, target_host):
+def _build_for_device(input_mod, target, target_host):
     """Build the lowered functions for a device with the given compilation
     target.
 
     Parameters
     ----------
-    flist : list of LoweredFunc
+    input_mod : IRModule
         The schedule to be built.
 
     target : str or :any:`tvm.target.Target`
@@ -219,8 +220,8 @@ def _build_for_device(flist, target, target_host):
 
     Returns
     -------
-    fhost : list of LoweredFunc
-        A list of lowered functions for the host.
+    fhost : IRModule
+        The host IRModule.
 
     mdev : tvm.module
         A module that contains device code.
@@ -229,14 +230,13 @@ def _build_for_device(flist, target, target_host):
     target_host = _target.create(target_host)
     device_type = ndarray.context(target.target_name, 0).device_type
 
-    for func in flist:
-        if not ir_pass.VerifyMemory(func, device_type):
-            raise ValueError(
-                "Direct host side access to device memory is detected in %s. "
-                "Did you forget to bind?" % func.name)
+    mod_mixed = input_mod
+    mod_mixed = tvm.tir.transform.Apply(lambda f: f.with_attr("target", target))(mod_mixed)
+    tvm.tir.analysis.verify_memory(mod_mixed)
 
-    mod_mixed = tvm.testing.LoweredFuncsToIRModule(flist)
-    opt_mixed = [tvm.tir.transform.Apply(lambda f: f.with_attr("target", target))]
+    opt_mixed = []
+    if len(mod_mixed.functions) == 1:
+        opt_mixed += [tvm.tir.transform.Apply(lambda f: f.with_attr("tir.is_entry_func", True))]
     if BuildConfig.current().detect_global_barrier:
         opt_mixed += [tvm.tir.transform.ThreadSync("global")]
     opt_mixed += [tvm.tir.transform.ThreadSync("shared"),
@@ -292,7 +292,7 @@ def build(inputs,
 
     Parameters
     ----------
-    inputs : tvm.te.Schedule, LoweredFunc, or dict of target to LoweredFunc list
+    inputs : tvm.te.Schedule, IRModule, or dict of target to IRModule
         The schedule to be built
 
     args : list of Buffer or Tensor or Var, optional
@@ -326,7 +326,7 @@ def build(inputs,
     ________
     There are two typical example uses of this function depending on the type
     of the argument `inputs`:
-    1. it is a list of lowered functions:
+    1. it is an IRModule.
 
     .. code-block:: python
 
@@ -335,10 +335,10 @@ def build(inputs,
         B = te.placeholder((n,), name='B')
         C = te.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
         s = tvm.te.create_schedule(C.op)
-        f = tvm.lower(s, [A, B, C], name="test_add")
-        m = tvm.build(f, target="llvm")
+        m = tvm.lower(s, [A, B, C], name="test_add")
+        rt_mod = tvm.build(m, target="llvm")
 
-    2. it is a dict of compilation target to list of lowered functions:
+    2. it is a dict of compilation target to IRModule.
 
     .. code-block:: python
 
@@ -349,9 +349,9 @@ def build(inputs,
         s1 = tvm.te.create_schedule(C.op)
         with tvm.target.cuda() as cuda_tgt:
           s2 = topi.cuda.schedule_injective(cuda_tgt, [C])
-          f1 = tvm.lower(s1, [A, B, C], name="test_add1")
-          f2 = tvm.lower(s2, [A, B, C], name="test_add2")
-          m = tvm.build({"llvm": [f1], "cuda": [f2]}, target_host="llvm")
+          m1 = tvm.lower(s1, [A, B, C], name="test_add1")
+          m2 = tvm.lower(s2, [A, B, C], name="test_add2")
+          rt_mod = tvm.build({"llvm": m1, "cuda": m2}, target_host="llvm")
 
     Note
     ----
@@ -360,45 +360,36 @@ def build(inputs,
     if isinstance(inputs, schedule.Schedule):
         if args is None:
             raise ValueError("args must be given for build from schedule")
-        flist = lower(inputs, args,
-                      name=name,
-                      binds=binds)
-        if isinstance(flist, LoweredFunc):
-            flist = [flist]
-    elif isinstance(inputs, LoweredFunc):
-        if args:
-            raise ValueError("args must be done when build from LoweredFunc.")
-        flist = [inputs]
+        input_mod = lower(inputs, args,
+                          name=name,
+                          binds=binds)
     elif isinstance(inputs, (list, tuple, container.Array)):
-        flist = inputs
+        merged_mod = tvm.IRModule({})
+        for x in inputs:
+            merged_mod.update(x)
+        input_mod = merged_mod
+    elif isinstance(inputs, tvm.IRModule):
+        input_mod = inputs
     elif not isinstance(inputs, (dict, container.Map)):
-        raise ValueError("inputs must be Schedule, LoweredFunc, list of "
-                         "LoweredFunc, or dict of target to list of "
-                         "LoweredFunc.")
+        raise ValueError("inputs must be Schedule, IRModule or dict of target to IRModule")
 
     if not isinstance(inputs, (dict, container.Map)):
         target = _target.Target.current() if target is None else target
         target = target if target else "llvm"
-        target_flist = {target: flist}
+        target_input_mod = {target: input_mod}
     else:
-        target_flist = inputs
+        target_input_mod = inputs
 
-    for tar, flist in target_flist.items():
+    for tar, mod in target_input_mod.items():
         if not isinstance(tar, (str, _target.Target)):
             raise ValueError("The key of inputs must be str or "
                              "_target.Target when inputs is dict.")
-        fname_set = set()
-        for x in flist:
-            if not isinstance(x, LoweredFunc):
-                raise ValueError("inputs must be Schedule, LoweredFunc, list "
-                                 "of LoweredFunc, or dict of str to list of "
-                                 "LoweredFunc.")
-            if x.name in fname_set:
-                raise ValueError("Duplicate function name %s" % x.name)
-            fname_set.add(x.name)
+        if not isinstance(mod, tvm.IRModule):
+            raise ValueError("inputs must be Schedule, IRModule,"
+                             "or dict of str to IRModule.")
 
     if not target_host:
-        for tar, _ in target_flist.items():
+        for tar, _ in target_input_mod.items():
             tar = _target.create(tar)
             device_type = ndarray.context(tar.target_name, 0).device_type
             if device_type == ndarray.cpu(0).device_type:
@@ -410,8 +401,8 @@ def build(inputs,
     mod_host_all = tvm.IRModule({})
 
     device_modules = []
-    for tar, flist in target_flist.items():
-        mod_host, mdev = _build_for_device(flist, tar, target_host)
+    for tar, input_mod in target_input_mod.items():
+        mod_host, mdev = _build_for_device(input_mod, tar, target_host)
         mod_host_all.update(mod_host)
         device_modules.append(mdev)
 
index df0347b..641ff04 100644 (file)
@@ -17,7 +17,6 @@
 """The interface of expr function exposed from C++."""
 import tvm._ffi
 import tvm.driver
-from tvm.ir import container as _container
 
 
 @tvm._ffi.register_func("relay.backend.lower")
@@ -40,7 +39,7 @@ def lower(sch, inputs, func_name, source_func):
 
     Returns
     -------
-    lowered_funcs : List[tvm.LoweredFunc]
+    mod : tvm.IRModule
         The result of lowering.
     """
     # pylint: disable=broad-except, import-outside-toplevel
@@ -56,20 +55,17 @@ def lower(sch, inputs, func_name, source_func):
         msg += "-----------------------------\n"
         msg += source_func.astext()
         raise RuntimeError(msg)
-    return f if isinstance(
-        f, (_container.Array, tuple, list)) else [f]
+    return f
 
 
 @tvm._ffi.register_func("relay.backend.build")
-def build(funcs, target, target_host=None):
+def build(mod, target, target_host=None):
     """Backend build function.
 
     Parameters
     ----------
-    funcs : List[tvm.LoweredFunc] or Dict[str, List[tvm.LoweredFunc]]
-        A list of lowered functions or dictionary mapping from targets to
-        lowered functions.
-
+    mod : tvm.IRModule or Dict[str, tvm.IRModule]
+        Input module
 
     target : tvm.Target
         The target to run the code on.
@@ -84,7 +80,7 @@ def build(funcs, target, target_host=None):
     """
     if target_host == "":
         target_host = None
-    return tvm.driver.build(funcs, target=target, target_host=target_host)
+    return tvm.driver.build(mod, target=target, target_host=target_host)
 
 
 @tvm._ffi.register_func("relay._tensor_value_repr")
index 762210d..3e5f015 100644 (file)
@@ -48,7 +48,7 @@ class GraphRuntimeCodegen(object):
         self._get_graph_json = self._mod["get_graph_json"]
         self._list_params_name = self._mod["list_params_name"]
         self._get_param_by_name = self._mod["get_param_by_name"]
-        self._get_lowered_funcs = self._mod["get_lowered_funcs"]
+        self._get_irmodule = self._mod["get_irmodule"]
         self._setup(mod, target)
 
     def _setup(self, mod, target):
@@ -74,14 +74,14 @@ class GraphRuntimeCodegen(object):
         -------
         graph_json : str
             The graph json that can be consumed by runtime.
-        lowered_funcs : List[tvm.LoweredFunc] or Dict[str, List[tvm.LoweredFunc]]
+        mod : IRModule or Dict[str, IRModule]
             The lowered functions.
         params : Dict[str, tvm.nd.NDArray]
             Additional constant parameters.
         """
         self._codegen(func)
         graph_json = self._get_graph_json()
-        lowered_func = self._get_lowered_funcs()
+        lowered_func = self._get_irmodule()
         param_names = self._list_params_name()
         params = {}
         for name in param_names:
index 24db0e8..235ef0c 100644 (file)
@@ -28,3 +28,4 @@ from .object_generic import convert_to_object, convert, const
 from .ndarray import context, cpu, gpu, opencl, cl, vulkan, metal, mtl
 from .ndarray import vpi, rocm, opengl, ext_dev, micro_dev
 from .module import load_module, enabled, system_lib
+from .container import String
index c105175..6a0dcf7 100644 (file)
@@ -20,9 +20,7 @@ import tvm._ffi
 import tvm.ir
 
 from tvm.runtime import Object
-from tvm.ir import container
 from tvm.tir import Stmt
-from tvm.tir.stmt import LoweredFunc
 from . import _ffi_api
 
 
@@ -48,17 +46,13 @@ class DumpIR(object):
         def dump(*args, **kwargs):
             """dump function"""
             retv = func(*args, **kwargs)
-            if not isinstance(retv, (Stmt, LoweredFunc, container.Array)):
+            if not isinstance(retv, (Stmt,)):
                 return retv
             fname = func.func_name if hasattr(func, 'func_name') else func.__name__
             pname = str(self._pass_id) + "_" + fname + "_ir.cc"
             with open(pname, "a") as f:
-                out = retv.body if isinstance(retv, LoweredFunc) else retv
+                out = retv
                 f.write(str(out))
-                if isinstance(retv, container.Array):
-                    for x in retv:
-                        out = x.body if isinstance(x, LoweredFunc) else x
-                        f.write("---------%s\n%s\n-----------\n"%(x.name, str(out)))
                 self._pass_id += 1
             return retv
         return dump
index 077ac35..9c42930 100644 (file)
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+
+# pylint: disable=invalid-name
 """ TVM testing utilities """
 import logging
 import numpy as np
+import tvm
 import tvm._ffi
 
 
@@ -165,4 +168,40 @@ def check_numerical_grads(function, input_values, grad_values, function_value=No
                      x_name, grad.shape, dist, max_diff, avg_diff)
 
 
+def MakeAPILegacy(stmt, name, args, num_unpacked_args, noalias):
+    """Legacy adapter to build a Module from statement.
+
+    Used for migrating existing test cases only.
+
+    Parameters
+    ----------
+    stmt: Stmt
+        The input statement.
+
+    name: str
+        The name of the funciton.
+
+    args: list of Buffer or Vars
+        The function arguments
+
+    num_unpacked_args: int
+        Number of unpacked arguments.
+
+    nolias: bool
+        Whether allow noalias.
+
+    Returns
+    -------
+    mod : IRModule
+        The created IRModule.
+    """
+    f = tvm.tir.PrimFunc(args, stmt).with_attr(
+        "global_symbol", tvm.runtime.String(name))
+    f = f.with_attr("tir.is_entry_func", True)
+    if noalias:
+        f = f.with_attr("tir.no_alias", True)
+    mod = tvm.IRModule({name: f})
+    return tvm.tir.transform.MakePackedAPI(num_unpacked_args)(mod)
+
+
 tvm._ffi._init_api("testing", __name__)
index bd8e33f..b5d9fb1 100644 (file)
@@ -29,7 +29,7 @@ from .expr import IterVar, Any
 
 from .stmt import Stmt, LetStmt, AssertStmt, ProducerConsumer, For
 from .stmt import BufferStore, Store, Provide, Allocate, AttrStmt, Free, Realize, SeqStmt
-from .stmt import IfThenElse, Evaluate, Prefetch, LoweredFunc, stmt_seq, stmt_list
+from .stmt import IfThenElse, Evaluate, Prefetch, stmt_seq, stmt_list
 
 from .function import PrimFunc
 
index 84eeaac..448d0e6 100644 (file)
@@ -55,3 +55,14 @@ def expr_deep_equal(lhs, rhs):
     tvm.ir.structural_equal
     """
     return _ffi_api.expr_deep_equal(lhs, rhs)
+
+
+def verify_memory(mod):
+    """Verify if module contains illegal host side direct memory access.
+
+    Parameters
+    ----------
+    mod: tvm.IRModule
+        The module to be verified.
+    """
+    _ffi_api.verify_memory(mod)
index 37946f6..0ed1762 100644 (file)
@@ -18,6 +18,7 @@
 
 import tvm._ffi
 import tvm.runtime
+from tvm.runtime import Object
 from tvm.ir import BaseFunc
 from .buffer import Buffer
 from .expr import Var
@@ -54,6 +55,7 @@ class PrimFunc(BaseFunc):
         param_list = []
         buffer_map = {} if buffer_map is None else buffer_map
         for x in params:
+            x = tvm.runtime.convert(x) if not isinstance(x, Object) else x
             if isinstance(x, Buffer):
                 var = Var(x.name, dtype="handle")
                 param_list.append(var)
index 0badad3..4531cdf 100644 (file)
@@ -385,14 +385,6 @@ class Prefetch(Stmt):
             _ffi_api.Prefetch, func, value_index, dtype, bounds)
 
 
-@tvm._ffi.register_object
-class LoweredFunc(Object):
-    """Represent a LoweredFunc in TVM."""
-    MixedFunc = 0
-    HostFunc = 1
-    DeviceFunc = 2
-
-
 def stmt_seq(*args):
     """Make sequence of statements
 
index c823c1a..64c31a5 100644 (file)
@@ -60,6 +60,36 @@ def Filter(fcond):
     return _fpass.prim_func_pass(_transform, opt_level=0)
 
 
+def LowerCustomDatatypes():
+    """Lower custom datatypes.
+
+    See tvm::datatypes::Registry for more information on adding custom datatypes.
+
+    Returns
+    -------
+    fpass : tvm.ir.transform.Pass
+        The result pass
+    """
+    return _ffi_api.LowerCustomDatatypes()
+
+
+def MakePackedAPI(num_unpacked_params=0):
+    """Transform the PrimFuncs in the module to a packed func API.
+
+    Parameters
+    ----------
+    num_unpacked_params : int
+        Number of parameters that we hope to directly pass via normal arguments
+        following the PackedFunc input signature.
+
+    Returns
+    -------
+    fpass : tvm.ir.transform.Pass
+        The result pass
+    """
+    return _ffi_api.MakePackedAPI(num_unpacked_params)
+
+
 def BindDeviceType():
     """Bind the device type of the function to be
        the device_type specified in the target attribute.
index 6491491..9784def 100644 (file)
@@ -27,7 +27,6 @@
 #include <tvm/tir/expr.h>
 #include <tvm/tir/stmt_functor.h>
 #include <tvm/target/codegen.h>
-#include <tvm/tir/lowered_func.h>
 #include <tvm/te/schedule.h>
 #include <map>
 #include <string>
index d54d6f8..ae1d539 100644 (file)
 #include <tvm/te/operation.h>
 
 #include <tvm/tir/transform.h>
+#include <tvm/tir/analysis.h>
 #include <tvm/tir/ir_pass.h>
 #include <tvm/target/codegen.h>
+#include <tvm/runtime/container.h>
 #include <tvm/runtime/registry.h>
 
 #include <algorithm>
@@ -39,7 +41,6 @@ namespace tvm {
 using runtime::TVMArgs;
 using runtime::TVMRetValue;
 using runtime::PackedFunc;
-using tir::LoweredFunc;
 
 bool LLVMEnabled() {
   const runtime::PackedFunc* pf = runtime::Registry::Get("target.build.llvm");
@@ -166,17 +167,6 @@ tir::Stmt BuildStmt(te::Schedule sch,
   return stmt;
 }
 
-Array<LoweredFunc> lower(te::Schedule sch,
-                         const Array<te::Tensor>& args,
-                         const std::string& name,
-                         const std::unordered_map<te::Tensor, tir::Buffer>& binds,
-                         const BuildConfig& config) {
-  Array<ObjectRef> out_arg_list;
-  auto stmt = BuildStmt(sch, args, binds, true, &out_arg_list, config);
-  return Array<LoweredFunc>({ tir::MakeAPI(stmt, name, out_arg_list, 0, config->restricted_func) });
-}
-
-
 transform::Pass BindTarget(Target target) {
   auto fpass = [target](tir::PrimFunc f, IRModule m, transform::PassContext ctx) {
     return WithAttr(std::move(f), tvm::attr::kTarget, target);
@@ -198,18 +188,46 @@ transform::Pass FilterBy(FCond fcond) {
 }
 
 
+IRModule lower(te::Schedule sch,
+               const Array<te::Tensor>& args,
+               const std::string& name,
+               const std::unordered_map<te::Tensor, tir::Buffer>& binds,
+               const BuildConfig& config) {
+  Array<ObjectRef> out_arg_list;
+  auto stmt = BuildStmt(sch, args, binds, true, &out_arg_list, config);
+
+  Array<tir::Var> params;
+  Map<tir::Var, tir::Buffer> buffer_map;
+
+  for (auto var : out_arg_list) {
+    if (auto* n = var.as<tir::VarNode>()) {
+      params.push_back(GetRef<tir::Var>(n));
+    } else {
+      tir::Buffer buffer = Downcast<tir::Buffer>(var);
+      tir::Var bptr(buffer->name, DataType::Handle());
+      params.push_back(bptr);
+      buffer_map.Set(bptr, buffer);
+    }
+  }
+
+  auto f = tir::PrimFunc(params, stmt, VoidType(), buffer_map);
+  f = WithAttr(std::move(f), "global_symbol", runtime::String(name));
+
+  if (config->restricted_func) {
+    f = WithAttr(std::move(f), "tir.no_alias", Integer(1));
+  }
+  auto mod = IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(name), f}}));
+  return tir::transform::MakePackedAPI(0)(mod);
+}
+
+
 std::pair<IRModule, IRModule>
-split_dev_host_funcs(const Array<LoweredFunc>& funcs,
+split_dev_host_funcs(IRModule mod_mixed,
                      const Target& target,
                      const Target& target_host,
                      const BuildConfig& config) {
-  for (const auto& x : funcs) {
-    CHECK(tir::VerifyMemory(x, target->device_type))
-        << "Direct host side access to device memory is detected in "
-        << x->func_name() << ". Did you forget to bind?";
-  }
-
-  IRModule mod_mixed = codegen::ToIRModule(funcs);
+  mod_mixed = BindTarget(target)(std::move(mod_mixed));
+  tir::VerifyMemory(mod_mixed);
 
   Array<tvm::transform::Pass> mixed_pass_list = {BindTarget(target)};
   if (config->detect_global_barrier) {
@@ -274,10 +292,9 @@ split_dev_host_funcs(const Array<LoweredFunc>& funcs,
 
 
 // Build for heterogeneous execution.
-runtime::Module build(const Map<Target, Array<LoweredFunc>>& inputs,
+runtime::Module build(const Map<Target, IRModule>& inputs,
                       const Target& target_host,
                       const BuildConfig& config) {
-  Array<LoweredFunc> fhost_all;
   std::vector<runtime::Module> device_modules;
 
   Target target_host_val = target_host;
@@ -319,10 +336,10 @@ runtime::Module build(const Map<Target, Array<LoweredFunc>>& inputs,
 }
 
 // Build for heterogeneous execution when target is a string.
-runtime::Module build(const Map<std::string, Array<LoweredFunc>>& inputs,
+runtime::Module build(const Map<std::string, IRModule>& inputs,
                       const Target& target_host,
                       const BuildConfig& config) {
-  Map<Target, Array<LoweredFunc>> updated_input;
+  Map<Target, IRModule> updated_input;
   for (const auto& it : inputs) {
     auto target = Target::Create(it.first);
     if (target->device_name == "vta") {
@@ -334,11 +351,11 @@ runtime::Module build(const Map<std::string, Array<LoweredFunc>>& inputs,
 }
 
 // Build for homogeneous execution.
-runtime::Module build(const Array<LoweredFunc>& funcs,
+runtime::Module build(const IRModule& funcs,
                       const Target& target,
                       const Target& target_host,
                       const BuildConfig& config) {
-  Map<Target, Array<LoweredFunc>> inputs = {{target, funcs}};
+  Map<Target, IRModule> inputs = {{target, funcs}};
   return build(inputs, target_host, config);
 }
 
index 4073271..eaf78bc 100644 (file)
@@ -38,7 +38,6 @@ namespace tvm {
 namespace relay {
 namespace backend {
 
-using tir::LoweredFunc;
 
 using TargetsMap = Map<tvm::Integer, tvm::Target>;
 using namespace tvm::relay::transform;
@@ -78,16 +77,16 @@ struct GraphCodegen {
   }
 
   Array<tvm::runtime::Module> GetExternalModules() {
-    return CallFunc<Array<tvm::runtime::Module> >("get_external_modules", nullptr);
+    return CallFunc<Array<tvm::runtime::Module>>("get_external_modules", nullptr);
   }
 
-  Map<std::string, Array<LoweredFunc> > GetLoweredFunc() {
-    return CallFunc<Map<std::string, Array<LoweredFunc> > >("get_lowered_funcs", nullptr);
+  Map<std::string, IRModule> GetIRModule() {
+    return CallFunc<Map<std::string, IRModule>>("get_irmodule", nullptr);
   }
 
   std::unordered_map<std::string, tvm::runtime::NDArray> GetParams() {
     std::unordered_map<std::string, tvm::runtime::NDArray> ret;
-    auto names = CallFunc<Array<tvm::PrimExpr> >("list_params_name", nullptr);
+    auto names = CallFunc<Array<tvm::PrimExpr>>("list_params_name", nullptr);
     for (auto expr : names) {
       auto key = expr.as<tir::StringImmNode>()->value;
       ret[key] = CallFunc<runtime::NDArray>("get_param_by_name", key);
@@ -152,9 +151,9 @@ class RelayBuildModule : public runtime::ModuleNode {
           this->SetParam(kv.first, kv.second->data);
         }
       });
-    } else if (name == "get_lowered_funcs") {
+    } else if (name == "get_irmodule") {
       return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
-          *rv = this->graph_codegen_->GetLoweredFunc();
+          *rv = this->graph_codegen_->GetIRModule();
       });
     } else if (name == "get_external_modules") {
       return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
@@ -452,7 +451,7 @@ class RelayBuildModule : public runtime::ModuleNode {
     ret_.graph_json = graph_codegen_->GetJSON();
     ret_.params = graph_codegen_->GetParams();
 
-    auto lowered_funcs = graph_codegen_->GetLoweredFunc();
+    auto lowered_funcs = graph_codegen_->GetIRModule();
 
     // When there is no lowered_funcs due to reasons such as optimization.
     if (lowered_funcs.size() == 0) {
index 9bd6a4e..4a3a04d 100644 (file)
@@ -27,7 +27,6 @@
 
 #include <tvm/node/structural_equal.h>
 #include <tvm/node/structural_hash.h>
-#include <tvm/tir/lowered_func.h>
 #include <tvm/runtime/module.h>
 #include <tvm/relay/analysis.h>
 #include <tvm/relay/expr.h>
@@ -82,7 +81,8 @@ struct CachedFuncNode : public Object {
   /*! \brief The schedule to the function */
   te::Schedule schedule;
   /*! \brief The lowered functions to support the function. */
-  tvm::Array<tir::LoweredFunc> funcs;
+  IRModule funcs = IRModule::Empty();
+
   /*! \brief Parameter usage states in the shape function. */
   tvm::Array<Integer> shape_func_param_states;
 
index 0587cd2..c7f1be8 100644 (file)
@@ -55,7 +55,7 @@ using TargetsMap = std::unordered_map<int, Target>;
 /*! \brief Lowered outputs */
 struct LoweredOutput {
   std::string graph_json;
-  Map<std::string, Array<tir::LoweredFunc> > lowered_funcs;
+  Map<std::string, IRModule> lowered_funcs;
   Array<tvm::runtime::Module> external_mods;
   std::unordered_map<std::string, tvm::runtime::NDArray> params;
 };
@@ -214,19 +214,14 @@ class GraphRuntimeCodegen
     LoweredOutput ret;
     ret.graph_json = os.str();
     ret.params = params_;
+
     for (auto& kv : lowered_funcs_) {
       if (ret.lowered_funcs.count(kv.first) == 0) {
-        ret.lowered_funcs.Set(kv.first, Array<tir::LoweredFunc>());
-      }
-      auto& vec = ret.lowered_funcs[kv.first];
-      Array<tir::LoweredFunc> tmp;
-      for (auto f : kv.second) {
-        tmp.push_back(f);
-      }
-      for (auto f : vec) {
-        tmp.push_back(f);
+        ret.lowered_funcs.Set(kv.first, IRModule::Empty());
       }
-      ret.lowered_funcs.Set(kv.first, tmp);
+      auto& mod = ret.lowered_funcs[kv.first];
+      mod->Update(kv.second);
+      ret.lowered_funcs.Set(kv.first, mod);
     }
     ret.external_mods = compile_engine_->LowerExternalFunctions();
     return ret;
@@ -457,12 +452,9 @@ class GraphRuntimeCodegen
     CCacheKey key = (*pf0)(func, target);
     CachedFunc lowered_func = (*pf1)(compile_engine_, key);
     if (!lowered_funcs_.count(target->str())) {
-      lowered_funcs_[target->str()] = {};
+      lowered_funcs_[target->str()] = IRModule::Empty();
     }
-    for (auto f : lowered_func->funcs) {
-      lowered_funcs_[target->str()].insert(f);
-    }
-
+    lowered_funcs_[target->str()]->Update(lowered_func->funcs);
     return GraphAddCallNode(op,
                            _GetUniqueName(lowered_func->func_name),
                            lowered_func->func_name);
@@ -602,8 +594,7 @@ class GraphRuntimeCodegen
   /*! \brief plan memory of device result */
   Map<Expr, Array<IntegerArray>> storage_device_map_;
   /*! \brief lowered funcs */
-  std::unordered_map<std::string, std::unordered_set<tir::LoweredFunc, ObjectHash, ObjectEqual>>
-      lowered_funcs_;
+  std::unordered_map<std::string, IRModule> lowered_funcs_;
   /*! \brief name map */
   std::unordered_map<std::string, size_t> name_map_;
   /*! \brief compile engine */
@@ -655,7 +646,7 @@ class GraphRuntimeCodegenModule : public runtime::ModuleNode {
         CHECK_GT(this->output_.params.count(key), 0);
         *rv = this->output_.params[key];
       });
-    } else if (name == "get_lowered_funcs") {
+    } else if (name == "get_irmodule") {
       return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
         *rv = this->output_.lowered_funcs;
       });
index 4d15c76..78ebb0f 100644 (file)
@@ -226,6 +226,7 @@ std::vector<int64_t> ToAllocTensorShape32(NDArray shape) {
   return raw_shape;
 }
 
+
 class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
  public:
   VMFunctionCompiler(VMCompilerContext* context, TargetsMap targets, Target target_host)
@@ -407,12 +408,15 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
     CCacheKey key(func, target_host_);
     auto cfunc = engine_->LowerShapeFunc(key);
     int op_index = -1;
-    if (context_->seen_funcs.count(cfunc->funcs[0]) == 0) {
+    // pick the only function inside the context
+    CHECK_EQ(cfunc->funcs->functions.size(), 1);
+    auto pfunc = Downcast<tir::PrimFunc>((*cfunc->funcs->functions.begin()).second);
+    if (context_->seen_funcs.count(pfunc) == 0) {
       op_index = context_->cached_funcs.size();
       context_->cached_funcs.push_back(cfunc);
-      context_->seen_funcs[cfunc->funcs[0]] = op_index;
+      context_->seen_funcs[pfunc] = op_index;
     } else {
-      op_index = context_->seen_funcs[cfunc->funcs[0]];
+      op_index = context_->seen_funcs[pfunc];
     }
 
     // Prepare input and output registers
@@ -494,13 +498,14 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
       context_->cached_funcs.push_back(cfunc);
     } else {
       // TODO(jroesch): support lowered funcs for multiple targets
-      CHECK_EQ(cfunc->funcs.size(), 1);
-      if (context_->seen_funcs.find(cfunc->funcs[0]) == context_->seen_funcs.end()) {
+      CHECK_EQ(cfunc->funcs->functions.size(), 1);
+      auto pfunc = Downcast<tir::PrimFunc>((*cfunc->funcs->functions.begin()).second);
+      if (context_->seen_funcs.find(pfunc) == context_->seen_funcs.end()) {
         op_index = context_->cached_funcs.size();
         context_->cached_funcs.push_back(cfunc);
-        context_->seen_funcs[cfunc->funcs[0]] = op_index;
+        context_->seen_funcs[pfunc] = op_index;
       } else {
-        op_index = context_->seen_funcs[cfunc->funcs[0]];
+        op_index = context_->seen_funcs[pfunc];
       }
     }
 
@@ -862,11 +867,7 @@ void VMCompiler::Lower(IRModule mod,
   // update primitive function map
   size_t primitive_index = 0;
   for (const auto& cfunc : context_.cached_funcs) {
-    if (cfunc->target->str() == "ext_dev") {
-      exec_->primitive_map.insert({cfunc->func_name, primitive_index++});
-    } else {
-      exec_->primitive_map.insert({cfunc->funcs[0]->name, primitive_index++});
-    }
+    exec_->primitive_map.insert({cfunc->func_name, primitive_index++});
   }
 }
 
@@ -961,8 +962,6 @@ void VMCompiler::PopulateGlobalMap() {
 }
 
 void VMCompiler::Codegen() {
-  using tir::LoweredFunc;
-
   if (!context_.module.defined()) {
     LOG(WARNING) << "Did you forget to call VMCompiler::Lower?";
     return;
@@ -971,15 +970,21 @@ void VMCompiler::Codegen() {
   if (cached_funcs.size() == 0) {
     return;
   }
-  std::unordered_map<std::string, Array<LoweredFunc>> funcs;
+  std::unordered_map<std::string, IRModule> funcs;
+
   for (auto& cfunc : cached_funcs) {
     std::string target_str = cfunc->target->str();
+    // NOTE: because module, is mutable, we need to make an
+    // explicit copy of the IRModule.
+    IRModule mod = cfunc->funcs;
+    mod.CopyOnWrite();
+
     if (target_str == "ext_dev") {
       continue;
     } else if (funcs.count(target_str) == 0) {
-      funcs.emplace(target_str, Array<LoweredFunc>{cfunc->funcs[0]});
+      funcs.emplace(target_str, mod);
     } else {
-      funcs[target_str].push_back(cfunc->funcs[0]);
+      funcs[target_str]->Update(mod);
     }
   }
 
index f18e2c0..c1040f1 100644 (file)
@@ -76,7 +76,7 @@ struct VMCompilerContext {
   // List of cached functions
   std::vector<CachedFunc> cached_funcs;
   // The functions that have been lowered.
-  std::unordered_map<tir::LoweredFunc, size_t, ObjectHash, ObjectEqual> seen_funcs;
+  std::unordered_map<tir::PrimFunc, size_t, ObjectHash, ObjectEqual> seen_funcs;
 };
 
 
index a3728e9..d0ff169 100644 (file)
@@ -22,7 +22,6 @@
  * \brief API for Automatic Differentiation for the Relay IR.
  */
 #include <tvm/ir/type_functor.h>
-#include <tvm/tir/lowered_func.h>
 #include <tvm/te/operation.h>
 #include <tvm/relay/expr_functor.h>
 #include <tvm/relay/analysis.h>
index 47ec8f0..fc45cef 100644 (file)
 #include <tvm/tir/function.h>
 #include <tvm/tir/expr.h>
 #include <tvm/tir/stmt.h>
-#include <tvm/tir/lowered_func.h>
 #include <unordered_map>
 #include <string>
 #include "../runtime/meta_data.h"
 
 namespace tvm {
 namespace codegen {
-// Extract function information from device function.
-inline std::unordered_map<std::string, runtime::FunctionInfo>
-ExtractFuncInfo(const Array<tir::LoweredFunc>& funcs) {
-  std::unordered_map<std::string, runtime::FunctionInfo> fmap;
-  for (tir::LoweredFunc f : funcs) {
-    runtime::FunctionInfo info;
-    for (size_t i = 0; i < f->args.size(); ++i) {
-      info.arg_types.push_back(f->args[i].dtype());
-    }
-    for (size_t i = 0; i < f->thread_axis.size(); ++i) {
-      info.thread_axis_tags.push_back(f->thread_axis[i]->thread_tag);
-    }
-    fmap[f->name] = info;
-  }
-  return fmap;
-}
 
 inline std::unordered_map<std::string, runtime::FunctionInfo>
 ExtractFuncInfo(const IRModule& mod) {
index 703328f..0eceea8 100644 (file)
 namespace tvm {
 namespace codegen {
 
-// convert legacy LoweredFunc to PrimFunc.
-tir::PrimFunc ToPrimFunc(tir::LoweredFunc from) {
-  // remap args to attach type annotations.
-  Array<tir::Var> args;
-  Map<tir::Var, PrimExpr> remap_vars;
-
-  for (auto var : from->args) {
-    auto it = from->handle_data_type.find(var);
-    if (it != from->handle_data_type.end()) {
-      tir::Var new_var(var->name_hint,
-                       PointerType(PrimType((*it).second->dtype)));
-      args.push_back(new_var);
-      remap_vars.Set(var, new_var);
-    } else {
-      args.push_back(var);
-    }
-  }
-  tir::PrimFunc func(args, Substitute(from->body, remap_vars));
-
-  func = WithAttr(std::move(func), attr::kGlobalSymbol, runtime::String(from->name));
-  func = WithAttr(std::move(func), tir::attr::kDeviceThreadAxis, from->thread_axis);
-  if (from->func_type == tir::LoweredFuncType::kDeviceFunc) {
-    func = WithAttr(std::move(func),
-                    attr::kCallingConv, Integer(CallingConv::kDeviceKernelLaunch));
-  }
-  if (from->is_restricted) {
-    func = WithAttr(std::move(func), tir::attr::kNoAlias, Integer(1));
-  }
-  return func;
-}
-
-IRModule ToIRModule(const Array<tir::LoweredFunc>& funcs) {
-  Map<GlobalVar, BaseFunc> functions;
-  for (size_t i = 0; i < funcs.size(); ++i) {
-    auto f = funcs[i];
-    tir::PrimFunc pf = ToPrimFunc(f);
-    if (i == 0) {
-      pf = WithAttr(std::move(pf), tir::attr::kIsEntryFunc, Integer(1));
-    }
-    functions.Set(GlobalVar(f->name), pf);
-  }
-  return IRModule(functions);
-}
-
 runtime::Module Build(IRModule mod, const Target& target) {
   if (BuildConfig::Current()->disable_assert) {
     mod = tir::transform::SkipAssert()(mod);
@@ -284,9 +240,6 @@ runtime::Module PackImportsToLLVM(const runtime::Module& mod,
 TVM_REGISTER_GLOBAL("target.Build")
 .set_body_typed(Build);
 
-TVM_REGISTER_GLOBAL("testing.LoweredFuncsToIRModule")
-.set_body_typed(ToIRModule);
-
 // Export two auxiliary function to the runtime namespace.
 TVM_REGISTER_GLOBAL("runtime.ModulePackImportsToC")
 .set_body_typed(PackImportsToC);
index 31465cd..450ebbc 100644 (file)
@@ -448,7 +448,7 @@ CodeGenLLVM::CreateDebugInfo(llvm::Module* module) {
   auto debug_info = llvm::make_unique<CodeGenLLVM::DebugInfo>();
   debug_info->di_builder_ = llvm::make_unique<llvm::DIBuilder>(*module);
 #endif
-  // TODO(tulloch): pass this information through relay::Span classes to the LoweredFunc instance?
+  // TODO(tulloch): pass this information through relay::Span classes to the IRModule instance?
   debug_info->file_ = debug_info->di_builder_->createFile("model.tvm", "/tmp/");
   debug_info->compilation_unit_ = debug_info->di_builder_->createCompileUnit(
       llvm::dwarf::DW_LANG_C, debug_info->file_, "TVM", 0, "", 0, "",
index 3f508a5..9ea77ac 100644 (file)
@@ -67,20 +67,23 @@ class LLVMModuleNode final : public runtime::ModuleNode {
     } else if (name == "_get_target_triple") {
       std::string target_triple = tm_->getTargetTriple().str();
       return PackedFunc([target_triple](TVMArgs args, TVMRetValue *rv) {
-        * rv = target_triple;
+        *rv = target_triple;
       });
     }
     if (ee_ == nullptr) LazyInitJIT();
 
-    // This LLVMModule is empty and no function can be retrieved.
-    if (entry_func_.empty()) return nullptr;
-
     std::lock_guard<std::mutex> lock(mutex_);
-    const std::string& fname = (name == runtime::symbol::tvm_module_main ?
-                                entry_func_ : name);
 
-    TVMBackendPackedCFunc faddr =
-        reinterpret_cast<TVMBackendPackedCFunc>(GetFunctionAddr(fname));
+    TVMBackendPackedCFunc faddr;
+    if (name == runtime::symbol::tvm_module_main) {
+      const char* entry_name = reinterpret_cast<const char*>(
+          GetGlobalAddr(runtime::symbol::tvm_module_main));
+      CHECK(entry_name != nullptr)
+          << "Symbol " << runtime::symbol::tvm_module_main << " is not presented";
+      faddr = reinterpret_cast<TVMBackendPackedCFunc>(GetFunctionAddr(entry_name));
+    } else {
+      faddr = reinterpret_cast<TVMBackendPackedCFunc>(GetFunctionAddr(name));
+    }
     if (faddr == nullptr) return PackedFunc();
     return WrapPackedFunc(faddr, sptr_to_self);
   }
@@ -205,6 +208,7 @@ class LLVMModuleNode final : public runtime::ModuleNode {
     std::unique_ptr<CodeGenLLVM> cg = CodeGenLLVM::Create(tm_.get());
 
     std::vector<PrimFunc> funcs;
+    std::string entry_func;
     for (auto kv :  mod->functions) {
       CHECK(kv.second->IsInstance<PrimFuncNode>())
           << "Can only lower IR Module with PrimFuncs";
@@ -212,7 +216,7 @@ class LLVMModuleNode final : public runtime::ModuleNode {
       if (f->HasNonzeroAttr(tir::attr::kIsEntryFunc)) {
         auto global_symbol = f->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
         CHECK(global_symbol.defined());
-        entry_func_ = global_symbol;
+        entry_func = global_symbol;
       }
       funcs.push_back(f);
     }
@@ -225,8 +229,8 @@ class LLVMModuleNode final : public runtime::ModuleNode {
       cg->AddFunction(f);
     }
 
-    if (entry_func_.length() != 0) {
-      cg->AddMainFunction(entry_func_);
+    if (entry_func.length() != 0) {
+      cg->AddMainFunction(entry_func);
     }
 
     module_ = cg->Finish();
@@ -321,13 +325,7 @@ class LLVMModuleNode final : public runtime::ModuleNode {
     CHECK(ee_ != nullptr)
         << "Failed to initialize jit engine for " << mptr_->getTargetTriple();
     ee_->runStaticConstructorsDestructors(false);
-    // setup context address.
-    // we will skip context setup if this LLVMModule is empty.
-    if (GetGlobalAddr(runtime::symbol::tvm_module_main) == 0)
-      return;
 
-    entry_func_ =
-        reinterpret_cast<const char*>(GetGlobalAddr(runtime::symbol::tvm_module_main));
     if (void** ctx_addr = reinterpret_cast<void**>(
             GetGlobalAddr(runtime::symbol::tvm_module_ctx))) {
       *ctx_addr = this;
@@ -356,8 +354,6 @@ class LLVMModuleNode final : public runtime::ModuleNode {
 
   // The target configuration string
   std::string target_;
-  // Name of entry function.
-  std::string entry_func_;
   // JIT lock
   std::mutex mutex_;
   // execution engine
index c1894a3..30ad890 100644 (file)
@@ -29,7 +29,6 @@
 #include <tvm/tir/function.h>
 #include <tvm/tir/stmt_functor.h>
 #include <tvm/target/codegen.h>
-#include <tvm/tir/lowered_func.h>
 #include <tvm/runtime/container.h>
 #include <string>
 #include <vector>
index a5ccd54..edcee20 100644 (file)
@@ -27,7 +27,6 @@
 #include <tvm/arith/analyzer.h>
 #include <tvm/tir/expr.h>
 #include <tvm/tir/stmt_functor.h>
-#include <tvm/tir/lowered_func.h>
 
 #include <vector>
 #include <memory>
index 041c7a7..fd370d2 100644 (file)
@@ -26,7 +26,6 @@
 
 #include <tvm/tir/expr.h>
 #include <tvm/tir/stmt_functor.h>
-#include <tvm/tir/lowered_func.h>
 #include <tvm/target/codegen.h>
 #include <string>
 #include <vector>
similarity index 86%
rename from src/tir/pass/verify_memory.cc
rename to src/tir/analysis/verify_memory.cc
index 5e805f8..d6a521f 100644 (file)
  * \brief Pass to check if memory accesses are legal.
  */
 #include <tvm/tir/expr.h>
-#include <tvm/tir/ir_pass.h>
+#include <tvm/tir/analysis.h>
 #include <tvm/tir/stmt_functor.h>
+#include <tvm/target/target.h>
+#include <tvm/runtime/registry.h>
 
 
 namespace tvm {
@@ -44,7 +46,7 @@ class MemoryAccessVerifier final : protected StmtExprVisitor {
  public:
   /// Special member functions
   //@{
-  explicit MemoryAccessVerifier(LoweredFunc f, int device_type)
+  explicit MemoryAccessVerifier(PrimFunc f, int device_type)
       : func_(f), dev_type_(device_type) {}
   virtual ~MemoryAccessVerifier() = default;
   MemoryAccessVerifier(const MemoryAccessVerifier &) = delete;
@@ -116,7 +118,7 @@ class MemoryAccessVerifier final : protected StmtExprVisitor {
       CHECK(V) << "Invalid Variable\n";
 
       // Variable is from function args. Return true.
-      if (V == func_->args[0].get()) return true;
+      if (V == func_->params[0].get()) return true;
 
       // The value is expected to come from a tvm_struct_get Call.
       // Get the first argument of tvm_struct_get, and continue.
@@ -179,18 +181,33 @@ class MemoryAccessVerifier final : protected StmtExprVisitor {
   const ProducerConsumerNode *pc_{nullptr};
   bool failure_{false};  ///< If the verification fails (i.e. has illegal access)
   //@}
-  LoweredFunc func_{nullptr};  ///< Function to be verified.
+  tir::PrimFunc func_{nullptr};  ///< Function to be verified.
   int dev_type_{kDLCPU};       ///< Device type
   std::unordered_map<const VarNode *, PrimExpr> defs_;  ///< Variable definitions
 };
 }  // namespace
 
 /// Interface of VerifyMemory pass
-bool VerifyMemory(LoweredFunc func, int device_type) {
-  MemoryAccessVerifier v(func, device_type);
-  v.Run();
-  return !v.Failed();
+void VerifyMemory(const IRModule& mod) {
+  for (auto kv : mod->functions) {
+    if (auto* n = kv.second.as<PrimFuncNode>()) {
+      PrimFunc func = GetRef<PrimFunc>(n);
+      auto target = func->GetAttr<Target>(tvm::attr::kTarget);
+      CHECK(target.defined())
+          << "LowerWarpMemory: Require the target attribute";
+      MemoryAccessVerifier v(func, target->device_type);
+      v.Run();
+      if (v.Failed()) {
+        LOG(FATAL)
+            << "ValueError: Direct host side access to device memory is detected."
+            << " Did you forget to bind?\n"
+            << func;
+      }
+    }
+  }
 }
 
+TVM_REGISTER_GLOBAL("tir.analysis.verify_memory")
+.set_body_typed(VerifyMemory);
 }  // namespace tir
 }  // namespace tvm
index eec7c10..6bbf645 100644 (file)
@@ -48,7 +48,7 @@ Buffer decl_buffer(Array<PrimExpr> shape,
                    DataType dtype,
                    std::string name) {
   return BufferNode::make(
-      Var(name, DataType::Handle()),
+      Var(name, PointerType(PrimType(dtype))),
       dtype,
       shape,
       Array<PrimExpr>(),
diff --git a/src/tir/ir/lowered_func.cc b/src/tir/ir/lowered_func.cc
deleted file mode 100644 (file)
index 8790f2b..0000000
+++ /dev/null
@@ -1,37 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file lowered_func.cc
- */
-#include <tvm/tir/lowered_func.h>
-
-namespace tvm {
-namespace tir {
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<LoweredFuncNode>([](const ObjectRef& node, ReprPrinter* p) {
-    auto* op = static_cast<const LoweredFuncNode*>(node.get());
-    p->stream << "LoweredFunc(" << op->name << ", " << op << ")";
-});
-
-TVM_REGISTER_NODE_TYPE(LoweredFuncNode);
-
-
-}  // namespace tir
-}  // namespace tvm
index 83db1a9..3083b68 100644 (file)
@@ -105,13 +105,6 @@ TVM_REGISTER_GLOBAL("ir_pass.PostOrderVisit")
       });
   });
 
-TVM_REGISTER_GLOBAL("ir_pass.LowerStorageAccess")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
-  LoweredFunc f = args[0];
-  auto n = make_object<LoweredFuncNode>(*f.operator->());
-  n->body = LowerStorageAccessInfo(f->body);
-  *ret = LoweredFunc(n);
-});
 
 // make from two arguments
 #define REGISTER_PASS(PassName)                                   \
@@ -128,7 +121,6 @@ REGISTER_PASS(VectorizeLoop);
 REGISTER_PASS(SkipVectorize);
 REGISTER_PASS(UnrollLoop);
 REGISTER_PASS(InjectCopyIntrin);
-REGISTER_PASS(MakeAPI);
 REGISTER_PASS(StorageRewrite);
 REGISTER_PASS(CoProcSync);
 REGISTER_PASS(LowerStorageAccessInfo);
@@ -138,9 +130,6 @@ REGISTER_PASS(InjectDoubleBuffer);
 REGISTER_PASS(LoopPartition);
 REGISTER_PASS(RemoveNoOp);
 REGISTER_PASS(LiftAttrScope);
-REGISTER_PASS(RemapThreadAxis);
-REGISTER_PASS(LowerCustomDatatypes);
-REGISTER_PASS(VerifyMemory);
 REGISTER_PASS(VerifyGPUCode);
 REGISTER_PASS(DecorateDeviceScope);
 REGISTER_PASS(InstrumentBoundCheckers);
index b4e6061..f3604b6 100644 (file)
@@ -994,29 +994,6 @@ class VectorAllocRewriter : public StmtExprMutator {
 };
 
 
-LoweredFunc PointerValueTypeRewrite(LoweredFunc f) {
-  auto n = make_object<LoweredFuncNode>(*f.operator->());
-  VectorAllocRewriter rewriter;
-  n->body = rewriter(n->body);
-  for (Var arg : f->args) {
-    if (arg.dtype().is_handle()) {
-      const auto& tvec = rewriter.acc_map_[arg.get()];
-      if (tvec.size() == 1) {
-        PrimExpr dtype = make_const(tvec[0], 0);
-        n->handle_data_type.Set(arg, dtype);
-      } else {
-        // always set data type to be non vectorized so
-        // load/store can still work via scalarization
-        if (tvec.size() != 0 && !n->handle_data_type.count(arg)) {
-          PrimExpr dtype = make_const(tvec[0].with_lanes(1), 0);
-          n->handle_data_type.Set(arg, dtype);
-        }
-      }
-    }
-  }
-  return LoweredFunc(n);
-}
-
 PrimFunc PointerValueTypeRewrite(PrimFunc f) {
   auto* n = f.CopyOnWrite();
   VectorAllocRewriter rewriter;
similarity index 88%
rename from src/tir/pass/lower_custom_datatypes.cc
rename to src/tir/transforms/lower_custom_datatypes.cc
index b24fdf1..6026f8c 100644 (file)
@@ -22,7 +22,9 @@
  */
 
 #include <tvm/tir/stmt_functor.h>
-#include <tvm/tir/ir_pass.h>
+#include <tvm/tir/transform.h>
+#include <tvm/target/target.h>
+#include <tvm/runtime/registry.h>
 #include "../../target/datatype/registry.h"
 
 namespace tvm {
@@ -129,11 +131,26 @@ class CustomDatatypesLowerer : public StmtExprMutator {
   std::string target_;
 };
 
-LoweredFunc LowerCustomDatatypes(LoweredFunc f, const std::string& target) {
-  auto n = make_object<LoweredFuncNode>(*f.operator->());
-  n->body = CustomDatatypesLowerer(target)(n->body);
-  return LoweredFunc(n);
+
+namespace transform {
+
+Pass LowerCustomDatatypes() {
+  auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
+    auto* n = f.CopyOnWrite();
+    auto target = f->GetAttr<Target>(tvm::attr::kTarget);
+    CHECK(target.defined())
+        << "LowerCustomDatatypes: Require the target attribute";
+
+    n->body = CustomDatatypesLowerer(target->target_name)(std::move(n->body));
+    return f;
+  };
+  return CreatePrimFuncPass(pass_func, 0, "tir.LowerCustomDatatypes", {});
 }
 
+TVM_REGISTER_GLOBAL("tir.transform.LowerCustomDatatypes")
+.set_body_typed(LowerCustomDatatypes);
+
+}  // namespace transform
+
 }  // namespace tir
 }  // namespace tvm
similarity index 61%
rename from src/tir/pass/make_api.cc
rename to src/tir/transforms/make_packed_api.cc
index 861cd43..c49b044 100644 (file)
  */
 
 /*!
- * \file make_api.cc Build API function.
+ * \file make_packed_api.cc Lower PrimFunc to use the packed function API.
  */
 #include <tvm/tir/ir_pass.h>
 #include <tvm/tir/expr.h>
 #include <tvm/tir/analysis.h>
+#include <tvm/tir/transform.h>
 #include <tvm/tir/stmt_functor.h>
 #include <tvm/tir/buffer.h>
 #include <tvm/runtime/device_api.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/runtime/container.h>
+
 #include <vector>
 #include <utility>
 #include <unordered_set>
 
-#include "ir_util.h"
-#include "arg_binder.h"
+#include "../pass/ir_util.h"
+#include "../pass/arg_binder.h"
 
 namespace tvm {
 namespace tir {
@@ -40,14 +44,18 @@ inline Stmt MakeAssertEQ(PrimExpr lhs, PrimExpr rhs, std::string msg) {
   return AssertStmtNode::make(lhs == rhs, msg, EvaluateNode::make(0));
 }
 
-LoweredFunc MakeAPI(Stmt body,
-                    std::string name,
-                    Array<ObjectRef> api_args,
-                    int num_unpacked_args,
-                    bool is_restricted) {
+PrimFunc MakePackedAPI(PrimFunc&& func,
+                       int num_unpacked_args) {
+  auto global_symbol = func->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
+  CHECK(global_symbol.defined())
+      << "MakePackedAPI: Expect PrimFunc to have the global_symbol attribute";
+  std::string name_hint = global_symbol;
+
+  auto* func_ptr = func.CopyOnWrite();
   const Stmt nop = EvaluateNode::make(0);
-  int num_args = static_cast<int>(api_args.size());
+  int num_args = static_cast<int>(func_ptr->params.size());
   CHECK_LE(num_unpacked_args, num_args);
+
   int num_packed_args = num_args - num_unpacked_args;
   // Data field definitions
   // The packed fields
@@ -69,9 +77,10 @@ LoweredFunc MakeAPI(Stmt body,
   // local function definitions
   // load i-th argument as type t
   auto f_arg_value = [&](DataType t, int i) {
-    Array<PrimExpr> call_args{v_packed_args,
-                          IntImm(DataType::Int(32), i),
-                          IntImm(DataType::Int(32), intrinsic::kTVMValueContent)};
+    Array<PrimExpr> call_args{
+      v_packed_args,
+      IntImm(DataType::Int(32), i),
+      IntImm(DataType::Int(32), intrinsic::kTVMValueContent)};
     // load 64 bit version
     DataType api_type = APIType(t);
     PrimExpr res = CallNode::make(
@@ -83,13 +92,7 @@ LoweredFunc MakeAPI(Stmt body,
     }
     return res;
   };
-  // get declaration of argument i
-  auto f_arg_decl = [&](int i) {
-    std::ostringstream os;
-    os << "arg" << i;
-    const VarNode* v = api_args[i].as<VarNode>();
-    return Var(os.str(), v ? v->dtype: DataType::Handle());
-  };
+
   // ---------------------------
   // start of logics
   // add signiture for packed arguments.
@@ -99,16 +102,25 @@ LoweredFunc MakeAPI(Stmt body,
     args.push_back(v_num_packed_args);
     std::ostringstream os;
 
-    os << name << ": num_args should be " << num_packed_args;
+    os << name_hint << ": num_args should be " << num_packed_args;
     seq_init.emplace_back(
         MakeAssertEQ(v_num_packed_args, num_packed_args, os.str()));
   }
 
-  // Save the input variables and buffers that will be bound later.
-  std::vector<std::pair<Var, Var> > var_defs;
-  std::vector<std::pair<Buffer, Var> > buf_defs;
-  for (int i = 0; i < static_cast<int>(api_args.size()); ++i) {
-    Var v_arg = f_arg_decl(i);
+  // Need to re-declare vars, in case some arguments also appears in the buffer.
+  std::vector<std::pair<Var, Var> > var_def;
+  std::vector<std::pair<Var, Buffer> > buffer_def;
+
+  for (int i = 0; i < static_cast<int>(func_ptr->params.size()); ++i) {
+    Var param = func_ptr->params[i];
+    Var v_arg = Var("arg" + std::to_string(i), param->dtype);
+
+    auto it = func_ptr->buffer_map.find(param);
+    if (it != func_ptr->buffer_map.end()) {
+      buffer_def.emplace_back(v_arg, (*it).second);
+    } else {
+      var_def.emplace_back(v_arg, param);
+    }
     if (i < num_packed_args) {
       // Value loads
       seq_init.emplace_back(LetStmtNode::make(
@@ -123,35 +135,26 @@ LoweredFunc MakeAPI(Stmt body,
       DataType t = v_arg.dtype();
       if (t.is_handle()) {
         std::ostringstream msg;
-        msg << name << ": Expect arg[" << i << "] to be pointer";
+        msg << name_hint << ": Expect arg[" << i << "] to be pointer";
         seq_check.emplace_back(
             AssertStmtNode::make(tcode == kTVMOpaqueHandle ||
-                             tcode == kTVMNDArrayHandle ||
-                             tcode == kTVMDLTensorHandle ||
-                             tcode == kTVMNullptr, msg.str(), nop));
+                                 tcode == kTVMNDArrayHandle ||
+                                 tcode == kTVMDLTensorHandle ||
+                                 tcode == kTVMNullptr, msg.str(), nop));
       } else if (t.is_int() || t.is_uint()) {
         std::ostringstream msg;
-        msg << name << ": Expect arg[" << i << "] to be int";
+        msg << name_hint << ": Expect arg[" << i << "] to be int";
         seq_check.emplace_back(AssertStmtNode::make(tcode == kDLInt, msg.str(), nop));
       } else {
         CHECK(t.is_float());
         std::ostringstream msg;
-        msg << name << ": Expect arg[" << i << "] to be float";
+        msg << name_hint << ": Expect arg[" << i << "] to be float";
         seq_check.emplace_back(
             AssertStmtNode::make(tcode == kDLFloat, msg.str(), nop));
       }
     } else {
       args.push_back(v_arg);
     }
-    // add checks for functions.
-    if (api_args[i].as<VarNode>()) {
-      var_defs.emplace_back(std::make_pair(Downcast<Var>(api_args[i]), v_arg));
-    } else {
-      // Buffer checks
-      CHECK(api_args[i].as<BufferNode>())
-          << "api_args can only be Buffer or Var";
-      buf_defs.emplace_back(std::make_pair(Downcast<Buffer>(api_args[i]), v_arg));
-    }
   }
 
   // allow return value if the function is packed.
@@ -170,24 +173,22 @@ LoweredFunc MakeAPI(Stmt body,
   // either 0 or the original stride will be correctly used. Checks here have
   // to use the args that may have no let bining yet. Therefore, hoisting let
   // binding for args before buffer declaration is needed.
-  for (const auto& arg : var_defs) {
-    binder.Bind(arg.first, arg.second, arg.second->name_hint, true);
+  for (const auto& kv : var_def) {
+    binder.Bind(kv.second, kv.first, kv.first->name_hint, true);
+  }
+
+  for (const auto& kv : buffer_def) {
+    binder.BindDLTensor(kv.second, device_type, device_id,
+                        kv.first, kv.first->name_hint);
   }
 
-  for (const auto& buf_arg : buf_defs) {
-    binder.BindDLTensor(buf_arg.first, device_type, device_id,
-                        buf_arg.second, buf_arg.second->name_hint);
+  if (num_unpacked_args == 0) {
+    func = WithAttr(std::move(func), tvm::attr::kCallingConv, Integer(CallingConv::kCPackedFunc));
   }
 
-  ObjectPtr<LoweredFuncNode> n = make_object<LoweredFuncNode>();
-  n->name = name;
-  n->args = args;
-  n->handle_data_type = binder.def_handle_dtype();
-  n->is_packed_func = num_unpacked_args == 0;
-  n->is_restricted = is_restricted;
-  body = AttrStmtNode::make(
+  auto body = AttrStmtNode::make(
       make_zero(DataType::Int(32)), attr::compute_scope,
-      StringImmNode::make(name + "_compute_"), body);
+      StringImmNode::make(name_hint + "_compute_"), func_ptr->body);
   // Set device context
   if (vmap.count(device_id.get())) {
     PrimExpr node = StringImmNode::make("default");
@@ -203,21 +204,59 @@ LoweredFunc MakeAPI(Stmt body,
              device_type, device_id}, CallNode::Intrinsic)));
     body = SeqStmt({set_device, body});
   }
-  n->body = MergeNest(
+  func_ptr->body = MergeNest(
       {seq_init, binder.init_nest(), seq_check, binder.asserts()}, body);
-  LoweredFunc f(n);
-  Array<Var> undefined = UndefinedVars(f->body, f->args);
+  func_ptr->params = args;
+
+  Array<Var> undefined = UndefinedVars(func_ptr->body, func_ptr->params);
   if (undefined.size() != 0) {
     std::ostringstream os;
     for (Var v : undefined) {
       os << " \'" << v->name_hint << "\' ";
     }
-    os << " does not appear in api_args";
+    os << " is not bound to any variables";
     LOG(FATAL) << "Not all Vars are passed in api_args: " << os.str();
   }
-  return f;
+
+
+  func_ptr->buffer_map = Map<Var, Buffer>();
+  func_ptr->checked_type_ = func_ptr->func_type_annotation();
+  func_ptr->ret_type = PrimType(DataType::Int(32));
+
+  // return the function.
+  return std::move(func);
 }
 
+namespace transform {
+
+Pass MakePackedAPI(int num_unpacked_args) {
+  auto pass_func = [num_unpacked_args](IRModule m, PassContext ctx) {
+    IRModuleNode* mptr = m.CopyOnWrite();
+    std::vector<std::pair<GlobalVar, PrimFunc> > updates;
+
+    for (const auto& kv : mptr->functions) {
+      if (auto* n = kv.second.as<PrimFuncNode>()) {
+        PrimFunc func = GetRef<PrimFunc>(n);
+        if (func->GetAttr<Integer>(tvm::attr::kCallingConv, 0)->value
+            == static_cast<int>(CallingConv::kDefault)) {
+          auto updated_func = MakePackedAPI(std::move(func), num_unpacked_args);
+          updates.push_back({kv.first, updated_func});
+        }
+      }
+    }
+
+    for (const auto& pair : updates) {
+      mptr->AddUnchecked(pair.first, pair.second);
+    }
+    return m;
+  };
+
+  return tvm::transform::CreateModulePass(
+      pass_func, 0, "tir.MakePackedAPI", {});
+}
 
+TVM_REGISTER_GLOBAL("tir.transform.MakePackedAPI")
+.set_body_typed(MakePackedAPI);
+}  // namespace transform
 }  // namespace tir
 }  // namespace tvm
similarity index 73%
rename from src/tir/pass/remap_thread_axis.cc
rename to src/tir/transforms/remap_thread_axis.cc
index 4fa5dd3..f695b3c 100644 (file)
@@ -22,7 +22,8 @@
  */
 #include <tvm/tir/expr.h>
 #include <tvm/tir/stmt_functor.h>
-#include <tvm/tir/ir_pass.h>
+#include <tvm/tir/transform.h>
+#include <tvm/runtime/registry.h>
 #include <unordered_map>
 
 
@@ -74,8 +75,8 @@ class ThreadAxisRewriter : private StmtExprMutator {
   std::unordered_map<const VarNode*, Var> vmap_;
 };
 
-LoweredFunc
-RemapThreadAxis(LoweredFunc f, Map<PrimExpr, IterVar> thread_map) {
+
+PrimFunc RemapThreadAxis(PrimFunc&& f, Map<PrimExpr, IterVar> thread_map) {
   std::unordered_map<std::string, IterVar> tmap;
   for (const auto& kv : thread_map) {
     const StringImmNode* str = kv.first.as<StringImmNode>();
@@ -83,18 +84,33 @@ RemapThreadAxis(LoweredFunc f, Map<PrimExpr, IterVar> thread_map) {
     tmap[str->value] = kv.second;
   }
 
-  CHECK_EQ(f->func_type, kDeviceFunc);
-  auto n = make_object<LoweredFuncNode>(*f.operator->());
+  auto thread_axis = f->GetAttr<Array<IterVar> >(tir::attr::kDeviceThreadAxis);
+  auto* n = f.CopyOnWrite();
+
   // replace the thread axis
-  for (size_t i = 0; i < n->thread_axis.size(); ++i) {
-    auto it = tmap.find(n->thread_axis[i]->thread_tag);
+  for (size_t i = 0; i < thread_axis.size(); ++i) {
+    auto it = tmap.find(thread_axis[i]->thread_tag);
     if (it != tmap.end()) {
-      n->thread_axis.Set(i, it->second);
+      thread_axis.Set(i, it->second);
     }
   }
-  n->body = ThreadAxisRewriter(tmap).Rewrite(n->body);
-  return LoweredFunc(n);
+  n->body = ThreadAxisRewriter(tmap).Rewrite(std::move(n->body));
+  return WithAttr(std::move(f), tir::attr::kDeviceThreadAxis, thread_axis);
 }
 
+
+namespace transform {
+
+Pass RemapThreadAxis(Map<PrimExpr, IterVar> thread_map) {
+  auto pass_func = [thread_map](PrimFunc f, IRModule m, PassContext ctx) {
+    return RemapThreadAxis(std::move(f), thread_map);
+  };
+  return CreatePrimFuncPass(pass_func, 0, "tir.RemapThreadAxis", {});
+}
+
+TVM_REGISTER_GLOBAL("tir.transform.RemapThreadAxis")
+.set_body_typed(RemapThreadAxis);
+
+}  // namespace transform
 }  // namespace tir
 }  // namespace tvm
index 838ad82..ae32bdc 100644 (file)
@@ -264,7 +264,6 @@ class HostDeviceSplitter : public StmtMutator {
   std::string name_prefix_;
   // Number of device functions.
   int device_func_counter_{0};
-  std::vector<LoweredFunc> device_funcs_;
   std::unordered_map<const VarNode*, PrimExpr> handle_data_type_;
 };
 
index c2c808f..9333a34 100644 (file)
@@ -117,8 +117,8 @@ TEST(BuildModule, Heterogeneous) {
   std::unordered_map<Tensor, Buffer> binds;
   auto lowered_s1 = lower(s1, args1, "elemwise_add", binds, config);
   auto lowered_s2 = lower(s2, args2, "elemwise_sub", binds, config);
-  Map<tvm::Target, Array<LoweredFunc>> inputs = {{target_cuda, lowered_s1},
-                                                 {target_llvm, lowered_s2}};
+  Map<tvm::Target, IRModule> inputs = {{target_cuda, lowered_s1},
+                                       {target_llvm, lowered_s2}};
   auto module = build(inputs, Target(), config);
 
   // Assertion for build.
index 4f2b6aa..27f3788 100644 (file)
@@ -18,29 +18,6 @@ import tvm
 from tvm import te
 import numpy as np
 
-def lower(s, args, name="mydot"):
-    binds = {}
-    arg_list = []
-
-    for x in args:
-        assert isinstance(x, te.tensor.Tensor)
-        buf = tvm.tir.decl_buffer(x.shape, dtype=x.dtype, name=x.op.name)
-        binds[x] = buf
-        arg_list.append(buf)
-    s = s.normalize()
-    bounds = tvm.te.schedule.InferBound(s)
-    stmt = tvm.te.schedule.ScheduleOps(s, bounds)
-    stmt = tvm.tir.ir_pass.StorageFlatten(stmt, binds, 16)
-    stmt = tvm.tir.ir_pass.CanonicalSimplify(stmt)
-    stmt = tvm.tir.ir_pass.Simplify(stmt)
-    fapi = tvm.tir.ir_pass.MakeAPI(stmt, name, arg_list, 0, True)
-    fapi = tvm.tir.ir_pass.LowerTVMBuiltin(fapi)
-    return fapi
-
-
-def mybuild(fapi, target="llvm"):
-    return
-
 
 def test_dot():
     nn = 12
index 13de67e..d9088b6 100644 (file)
@@ -38,8 +38,9 @@ def test_dltensor_compatible():
     with ib.for_range(0, n - 1, "i") as i:
         A[i + 1] = A[i] + 1
     stmt = ib.get()
-    fapi = tvm.tir.ir_pass.MakeAPI(stmt, "arange", [Ab], 0, True)
-    mod = tvm.testing.LoweredFuncsToIRModule([fapi])
+
+
+    mod = tvm.testing.MakeAPILegacy(stmt, "arange", [Ab], 0, True)
     mod = tvm.tir.transform.LowerTVMBuiltin()(mod)
     f = tvm.target.codegen.build_module(mod, "stackvm")
     a = tvm.nd.array(np.zeros(10, dtype=dtype))
index 8ca61c1..343b867 100644 (file)
@@ -156,7 +156,7 @@ def test_simplex_data_transferring():
                                              elemwise_sub],
                               name="elemwise_sub")
 
-        target_flist = {target_device: [lower_add], target_host: [lower_sub]}
+        target_flist = {target_device: lower_add, target_host: lower_sub}
         mhost = tvm.build(target_flist, target_host=target_host)
         ctx = [host_ctx, device_ctx]
         mod = graph_runtime.create(graph, mhost, ctx)
@@ -354,8 +354,9 @@ def test_duplex_data_transferring():
                                              elemwise_sub],
                               name="elemwise_sub")
 
-        target_flist = {target_device: [lower_add0, lower_add1], target_host:
-                        [lower_sub]}
+        lower_add0.update(lower_add1)
+        target_flist = {target_device: lower_add0, target_host:
+                        lower_sub}
         mhost = tvm.build(target_flist, target_host=target_host)
         ctx = [host_ctx, device_ctx]
         params = {}
index 37ccb5e..f6abebd 100644 (file)
@@ -57,8 +57,8 @@ def test_dso_module_load():
             tvm.tir.Store(Ab.data,
                            tvm.tir.Load(dtype, Ab.data, i) + 1,
                            i + 1))
-        fapi = tvm.tir.ir_pass.MakeAPI(stmt, "ramp", [Ab], 0, True)
-        m = tvm.driver.build(fapi, target="llvm")
+        m = tvm.testing.MakeAPILegacy(stmt, "ramp", [Ab], 0, True)
+        m = tvm.driver.build(m, target="llvm")
         for name in names:
             m.save(name)
 
index 45554c5..34135c6 100644 (file)
@@ -22,6 +22,7 @@ import numpy as np
 import ctypes
 import math
 
+
 def test_llvm_intrin():
     ib = tvm.tir.ir_builder.create()
     n = tvm.runtime.convert(4)
@@ -34,7 +35,8 @@ def test_llvm_intrin():
         tvm.tir.Call(
             "int32", "prefetch", args, tvm.tir.Call.Intrinsic, None, 0)))
     body = ib.get()
-    func = tvm.tir.ir_pass.MakeAPI(body, "prefetch", [A], 0, True)
+
+    func = tvm.testing.MakeAPILegacy(body, "prefetch", [A], 0, True)
     fcode = tvm.build(func, None, "llvm")
 
 
@@ -85,7 +87,7 @@ def test_llvm_lookup_intrin():
     x = tvm.tir.call_llvm_intrin("uint8x8", "llvm.ctpop.i8", tvm.tir.const(1, 'uint32'), A)
     ib.emit(x)
     body = ib.get()
-    func = tvm.tir.ir_pass.MakeAPI(body, "ctpop", [A], 1, True)
+    func = tvm.testing.MakeAPILegacy(body, "ctpop", [A], 1, True)
     fcode = tvm.build(func, None, "llvm")
 
 
@@ -307,8 +309,9 @@ def test_multiple_func():
         f2 = tvm.lower(s, [A, B, C], name="fadd1")
         f1 = tvm.lower(s, [A, B, C], name="fadd2")
         m = tvm.build([f1, f2], "llvm")
-        fadd1 = m['fadd1']
         fadd2 = m['fadd2']
+        fadd1 = m['fadd1']
+
         ctx = tvm.cpu(0)
         # launch the kernel.
         n = nn
@@ -665,6 +668,7 @@ def test_llvm_shuffle():
         tvm.testing.assert_allclose(c_.asnumpy(), (a_.asnumpy() * 2).astype('int32'))
 
 if __name__ == "__main__":
+    test_multiple_func()
     test_llvm_large_uintimm()
     test_llvm_import()
     test_alignment()
@@ -676,7 +680,6 @@ if __name__ == "__main__":
     test_llvm_vadd_pipeline()
     test_llvm_add_pipeline()
     test_llvm_intrin()
-    test_multiple_func()
     test_llvm_flip_pipeline()
     test_llvm_madd_pipeline()
     test_llvm_temp_space()
index a9fa35f..bd4d0d8 100644 (file)
@@ -19,6 +19,18 @@ from tvm import te
 import ctypes
 import numpy as np
 
+
+def MakeAPILegacy(stmt, name, args, num_unpacked_args, noalias):
+    """Legacy adapter to create a API"""
+    f = tvm.tir.PrimFunc(args, stmt).with_attr(
+        "global_symbol", tvm.runtime.String(name))
+    f = f.with_attr("tir.is_entry_func", True)
+    if noalias:
+        f = f.with_attr("tir.no_alias", True)
+    mod = tvm.IRModule.from_expr(f)
+    return tvm.tir.transform.MakePackedAPI()(mod)
+
+
 def test_static_callback():
     dtype = 'int64'
     n = te.size_var('n')
@@ -32,7 +44,7 @@ def test_static_callback():
     with ib.for_range(0, n, "i", for_type="parallel") as i:
         A[i] = A[i] + 1
     stmt = ib.get()
-    fapi = tvm.tir.ir_pass.MakeAPI(stmt, "ramp", [Ab], 0, True)
+    fapi = tvm.testing.MakeAPILegacy(stmt, "ramp", [Ab], 0, True)
     f = tvm.driver.build(fapi, target="llvm")
     a = tvm.nd.array(np.zeros(10, dtype=dtype))
     f(a)
@@ -55,7 +67,7 @@ def test_static_init():
         return sh
 
     stmt = ib.get()
-    fapi = tvm.tir.ir_pass.MakeAPI(stmt, "ramp", [Ab], 0, True)
+    fapi = tvm.testing.MakeAPILegacy(stmt, "ramp", [Ab], 0, True)
     f = tvm.driver.build(fapi, target="llvm")
     a = tvm.nd.array(np.zeros(10, dtype=dtype))
     f(a)
index 26464ce..ee0d89b 100644 (file)
@@ -26,6 +26,18 @@ def run_jit(fapi, check):
         s = f.get_source()
         check(f)
 
+
+def MakeAPILegacy(stmt, name, args, num_unpacked_args, noalias):
+    """Legacy adapter to create a API"""
+    f = tvm.tir.PrimFunc(args, stmt).with_attr(
+        "global_symbol", tvm.runtime.String(name))
+    f = f.with_attr("tir.is_entry_func", True)
+    if noalias:
+        f = f.with_attr("tir.no_alias", True)
+    mod = tvm.IRModule.from_expr(f)
+    return tvm.tir.transform.MakePackedAPI()(mod)
+
+
 def test_stack_vm_basic():
     a = tvm.nd.array(np.zeros(10, dtype='float32'))
     @tvm.register_func
@@ -36,7 +48,7 @@ def test_stack_vm_basic():
     n = te.size_var('n')
     Ab = tvm.tir.decl_buffer((n, ), "float32")
     stmt = tvm.tir.Evaluate(tvm.tir.call_packed("tvm_call_back_get_shape", Ab.shape[0]))
-    fapi = tvm.tir.ir_pass.MakeAPI(stmt, "print_shape", [Ab], 0, True)
+    fapi = tvm.testing.MakeAPILegacy(stmt, "print_shape", [Ab], 0, True)
     run_jit(fapi, lambda f: f(a))
 
 
@@ -57,7 +69,7 @@ def test_stack_vm_loop():
         ib.emit(tvm.tir.call_packed("tvm_stack_vm_print", i))
 
     stmt = ib.get()
-    fapi = tvm.tir.ir_pass.MakeAPI(stmt, "ramp", [Ab], 0, True)
+    fapi = tvm.testing.MakeAPILegacy(stmt, "ramp", [Ab], 0, True)
     a = tvm.nd.array(np.zeros(10, dtype=dtype))
     def check(f):
         f(a)
@@ -79,7 +91,7 @@ def test_stack_vm_cond():
             A[i + 1] = A[i] + 2
 
     stmt = ib.get()
-    fapi = tvm.tir.ir_pass.MakeAPI(stmt, "test", [Ab], 0, True)
+    fapi = tvm.testing.MakeAPILegacy(stmt, "test", [Ab], 0, True)
     def check(f):
         a = tvm.nd.array(np.zeros(10, dtype=dtype))
         f(a)
@@ -98,7 +110,7 @@ def test_vm_parallel():
     with ib.for_range(0, n, "i", for_type="parallel") as i:
         A[i] = A[i] + 1
     stmt = ib.get()
-    fapi = tvm.tir.ir_pass.MakeAPI(stmt, "ramp", [Ab], 0, True)
+    fapi = tvm.testing.MakeAPILegacy(stmt, "ramp", [Ab], 0, True)
     def check(f):
         a = tvm.nd.array(np.zeros(10, dtype=dtype))
         f(a)
index 32f6e18..f6723e2 100644 (file)
@@ -19,7 +19,6 @@ import tvm
 from tvm import te
 from ctypes import *
 import topi
-import tvm.tir.ir_pass as ir_pass
 import numpy as np
 
 tgt = "llvm"
@@ -51,10 +50,12 @@ def lower_datatypes_and_build(schedule, args):
     Once datatype lowering is integrated directly into TVM's lower/build
     process, we won't need to do this manually.
     TODO(gus) integrate datatype lowering into build process; change this test"""
-    flist = tvm.lower(schedule, args)
-    flist = [flist]
-    flist = [ir_pass.LowerCustomDatatypes(func, tgt) for func in flist]
-    return tvm.build(flist[0], target=tgt)
+    mod = tvm.lower(schedule, args)
+    target = tvm.target.create(tgt)
+    mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", target))(mod)
+    mod = tvm.tir.transform.LowerCustomDatatypes()(mod)
+    return tvm.build(mod, target=tgt)
+
 
 def test_bfloat_add_and_cast_1():
     X = te.placeholder((3, ), name="X")
 # specific language governing permissions and limitations
 # under the License.
 import tvm
+import pytest
 from tvm import te
 
 # The following DLDeviceType/TVMDeviceExtType values
 # are originally defined in dlpack.h and c_runtime_api.h.
-gpu_devices = [2, 4, 7, 8, 10, 11]
-other_devices = [1, 3, 9, 12]
+gpu_devices = ["cuda", "opencl", "metal", "vulkan"]
+other_devices = ["llvm", "ext_dev"]
 
 
 def lower(sch, args):
@@ -39,8 +40,11 @@ def lower(sch, args):
     stmt = tvm.te.schedule.ScheduleOps(sch, bounds)
     stmt = tvm.tir.ir_pass.LoopPartition(stmt, False)
     stmt = tvm.tir.ir_pass.StorageFlatten(stmt, binds, 64)
-    func = tvm.tir.ir_pass.MakeAPI(stmt, "myadd", arg_list, 0, True)
-    return func
+
+    f = tvm.tir.PrimFunc(arg_list, stmt).with_attr(
+        "global_symbol", tvm.runtime.String("test"))
+    mod = tvm.IRModule({"test": f})
+    return tvm.tir.transform.MakePackedAPI()(mod)
 
 
 # All computations are bound.
@@ -57,10 +61,13 @@ def test_verify_memory_all_bind():
   s[B].bind(bx, te.thread_axis("blockIdx.x"))
   s[B].bind(tx, te.thread_axis("threadIdx.x"))
 
-  func = lower(s, [A, B])
+  mod = lower(s, [A, B])
 
   for dev_type in gpu_devices + other_devices:
-    assert tvm.tir.ir_pass.VerifyMemory(func, dev_type)
+      binded_mod = tvm.tir.transform.Apply(
+          lambda f: f.with_attr("target", tvm.target.create(dev_type)))(mod)
+      tvm.tir.analysis.verify_memory(binded_mod)
+
 
 
 # Computations are not bound.
@@ -74,12 +81,18 @@ def test_verify_memory_not_bind():
   # B is not bound to threads.
   s = te.create_schedule(B.op)
 
-  func = lower(s, [A, B])
+  mod = lower(s, [A, B])
 
   for dev_type in gpu_devices:
-    assert not tvm.tir.ir_pass.VerifyMemory(func, dev_type)
+      binded_mod = tvm.tir.transform.Apply(
+          lambda f: f.with_attr("target", tvm.target.create(dev_type)))(mod)
+      with pytest.raises(ValueError):
+          tvm.tir.analysis.verify_memory(binded_mod)
+
   for dev_type in other_devices:
-    assert tvm.tir.ir_pass.VerifyMemory(func, dev_type)
+      binded_mod = tvm.tir.transform.Apply(
+          lambda f: f.with_attr("target", tvm.target.create(dev_type)))(mod)
+      tvm.tir.analysis.verify_memory(binded_mod)
 
 
 # Computations are partially bound.
@@ -98,16 +111,22 @@ def test_verify_memory_partially_bind():
   s[C].bind(bx, te.thread_axis("blockIdx.x"))
   s[C].bind(tx, te.thread_axis("threadIdx.x"))
 
-  func = lower(s, [A, B, C, D])
+  mod = lower(s, [A, B, C, D])
 
   for dev_type in gpu_devices:
-    assert not tvm.tir.ir_pass.VerifyMemory(func, dev_type)
+      binded_mod = tvm.tir.transform.Apply(
+          lambda f: f.with_attr("target", tvm.target.create(dev_type)))(mod)
+      with pytest.raises(ValueError):
+          tvm.tir.analysis.verify_memory(binded_mod)
+
   for dev_type in other_devices:
-    assert tvm.tir.ir_pass.VerifyMemory(func, dev_type)
+      binded_mod = tvm.tir.transform.Apply(
+          lambda f: f.with_attr("target", tvm.target.create(dev_type)))(mod)
+      tvm.tir.analysis.verify_memory(binded_mod)
+
 
 
 if __name__ == "__main__":
   test_verify_memory_all_bind()
   test_verify_memory_not_bind()
   test_verify_memory_partially_bind()
-
index b339097..d6c89b2 100644 (file)
@@ -118,7 +118,6 @@ def test_in_bounds_vectorize_llvm():
     s[B].vectorize(xi)
     # build and invoke the kernel.
     lowered_func = tvm.lower (s, [A, C], "llvm", simple_mode=False)
-    print (lowered_func.body)
     f = tvm.build(s, [A, C], "llvm")
     ctx = tvm.cpu(0)
     # launch the kernel.
@@ -137,7 +136,6 @@ def test_in_bounds_loop_partition_basic_llvm():
     s = te.create_schedule(T.op)
     xo, xi = s[T].split(T.op.axis[0], factor=4)
     lowered_func = tvm.lower (s, [A, B, T], "llvm", simple_mode=False)
-    print (lowered_func.body)
     ctx = tvm.cpu(0)
 
     f = tvm.build(s, [A, B, T], "llvm")
@@ -156,7 +154,6 @@ def test_out_of_bounds_loop_partition_basic_llvm(index_a, index_b):
     s = te.create_schedule(T.op)
     xo, xi = s[T].split(T.op.axis[0], factor=4)
     lowered_func = tvm.lower (s, [A, B, T], "llvm", simple_mode=False)
-    print (lowered_func.body)
     ctx = tvm.cpu(0)
 
     f = tvm.build(s, [A, B, T], "llvm")
@@ -205,12 +202,11 @@ def test_in_bounds_const_loop_partition_ir():
     # after instrumentation
     assert_bound_instrumentation(stmt, check_attr_stmt, 2 * 3)
     assert_bound_instrumentation(stmt, check_branch_stmt, 2)
-    print (stmt)
+
     branch_collector = list()
     collect_visit(stmt, collect_branch_stmt)
     assert(len(branch_collector) ==  2)
-    print (branch_collector[0].condition)
-    print (branch_collector[1].condition)
+
 
 def test_in_bounds_const_loop_partition_llvm():
     with tvm.target.build_config(instrument_bound_checkers=True, partition_const_loop=True):
@@ -222,7 +218,6 @@ def test_in_bounds_const_loop_partition_llvm():
         s = te.create_schedule(T.op)
         xo, xi = s[T].split(T.op.axis[0], factor=4)
         lowered_func = tvm.lower (s, [A, B, T], "llvm", simple_mode=False)
-        print (lowered_func.body)
         ctx = tvm.cpu(0)
 
         f = tvm.build(s, [A, B, T], "llvm")
@@ -242,7 +237,6 @@ def test_out_of_bounds_const_loop_partition_llvm(index_a, index_b):
         s = te.create_schedule(T.op)
         xo, xi = s[T].split(T.op.axis[0], factor=4)
         lowered_func = tvm.lower (s, [A, B, T], "llvm", simple_mode=False)
-        print (lowered_func.body)
         ctx = tvm.cpu(0)
 
         f = tvm.build(s, [A, B, T], "llvm")
@@ -276,7 +270,6 @@ def test_in_bounds_conv_llvm(loop_tiling=False):
     if loop_tiling:
         oho, owo, ohi, owi = s[conv].tile(oh, ow, 16, 16)
     lowered_func = tvm.lower(s, [data, kernel, conv], simple_mode=True)
-    print (lowered_func.body)
     ctx = tvm.cpu (0)
 
     f = tvm.build(s, [data, kernel, conv], "llvm")
@@ -320,7 +313,6 @@ def test_out_of_bounds_conv_llvm(data_offsets, kernel_offsets, loop_tiling=False
     if loop_tiling:
         oho, owo, ohi, owi = s[conv].tile(oh, ow, 16, 16)
     lowered_func = tvm.lower(s, [data, kernel, conv], simple_mode=True)
-    print (lowered_func.body)
     ctx = tvm.cpu (0)
 
     f = tvm.build(s, [data, kernel, conv], "llvm")
@@ -341,7 +333,6 @@ def test_in_bounds_tensors_with_same_shapes1D_llvm():
     T = te.compute((m, ), lambda i: A[i]*B[i])
     s = te.create_schedule(T.op)
     lowered_func = tvm.lower (s, [A, B, T], "llvm", simple_mode=False)
-    print (lowered_func.body)
     ctx = tvm.cpu(0)
 
     f = tvm.build(s, [A, B, T], "llvm")
@@ -361,7 +352,6 @@ def test_out_of_bounds_tensors_with_diff_shapes1D_llvm(a_shape, b_shape, c_shape
     T = te.compute((m, ), lambda i: A[i]*B[i])
     s = te.create_schedule(T.op)
     lowered_func = tvm.lower (s, [A, B, T], "llvm", simple_mode=False)
-    print (lowered_func.body)
     ctx = tvm.cpu(0)
 
     f = tvm.build(s, [A, B, T], "llvm")
@@ -380,7 +370,6 @@ def test_in_bounds_tensors_with_same_shapes2D_llvm():
     T = te.compute((m, m), lambda i, j: A[i][j]*B[i][j])
     s = te.create_schedule(T.op)
     lowered_func = tvm.lower (s, [A, B, T], "llvm", simple_mode=False)
-    print (lowered_func.body)
     ctx = tvm.cpu(0)
 
     f = tvm.build(s, [A, B, T], "llvm")
@@ -400,7 +389,6 @@ def test_out_of_bounds_tensors_with_diff_shapes2D_llvm(a_shape, b_shape, c_shape
     T = te.compute((m, m), lambda i, j: A[i][j]*B[i][j])
     s = te.create_schedule(T.op)
     lowered_func = tvm.lower (s, [A, B, T], "llvm", simple_mode=False)
-    print (lowered_func.body)
     ctx = tvm.cpu(0)
 
     f = tvm.build(s, [A, B, T], "llvm")
@@ -419,7 +407,7 @@ def test_in_bounds_tensors_with_same_shapes3D_llvm():
     T = te.compute((m, m, m), lambda i, j, p: A[i][j][p]*B[i][j][p])
     s = te.create_schedule(T.op)
     lowered_func = tvm.lower (s, [A, B, T], "llvm", simple_mode=False)
-    print (lowered_func.body)
+
     ctx = tvm.cpu(0)
 
     f = tvm.build(s, [A, B, T], "llvm")
@@ -439,7 +427,7 @@ def test_out_of_bounds_tensors_with_diff_shapes3D_llvm(a_shape, b_shape, c_shape
     T = te.compute((m, m, m), lambda i, j, p: A[i][j][p]*B[i][j][p])
     s = te.create_schedule(T.op)
     lowered_func = tvm.lower (s, [A, B, T], "llvm", simple_mode=False)
-    print (lowered_func.body)
+
     ctx = tvm.cpu(0)
 
     f = tvm.build(s, [A, B, T], "llvm")
@@ -460,7 +448,7 @@ def test_out_of_bounds_tensors_with_zero_shape_op_with_not_zero_shape_llvm():
     D = te.compute((), lambda : C + 1)
     s = te.create_schedule(D.op)
     stmt = tvm.lower (s, [A, scale, D], simple_mode=True)
-    print (stmt)
+
     # build and invoke the kernel.
     f = tvm.build(s, [A, scale, D], "llvm")
     ctx = tvm.cpu(0)
index 94e29c6..95a1054 100644 (file)
@@ -40,8 +40,7 @@ def test_double_buffer():
     stmt = tvm.tir.ir_pass.Simplify(stmt)
     assert isinstance(stmt.body.body, tvm.tir.Allocate)
     assert stmt.body.body.extents[0].value == 2
-    f = tvm.tir.ir_pass.MakeAPI(stmt, "db", [A.asobject(), C.asobject()], 2, True)
-    mod = tvm.testing.LoweredFuncsToIRModule([f])
+    mod = tvm.testing.MakeAPILegacy(stmt, "db", [A.asobject(), C.asobject()], 2, True)
     f = tvm.tir.transform.ThreadSync("shared")(mod)["db"]
 
     count = [0]
index 7e383dd..0818c0e 100644 (file)
@@ -381,7 +381,7 @@ def test_multilevel_splitting_with_indivisble_factors():
 
     ## But this does the right thing.
     with tvm.target.build_config(partition_const_loop=True):
-        lowered_body = tvm.lower(s, [A, B]).body
+        lowered_body = tvm.lower(s, [A, B], name="x")["x"].body
         def visit_stmt(op):
             return(isinstance(op, tvm.tir.Max))
         num_max = collect_visit(lowered_body, visit_stmt)
@@ -407,7 +407,7 @@ def test_double_splitting_with_indivisible_factors():
 
     # Find the beginning of the Halide IR corresponding to kernel code
     # and make sure it doesn't have an if statements left
-    top_produce = find_top_produce(f.body)
+    top_produce = find_top_produce(f["fadd1"].body)
     assert(not any(collect_visit(top_produce, lambda x: isinstance(x, tvm.tir.IfThenElse))))
 
     # check functional correctness of generated code
index dbfcd20..da9253f 100644 (file)
@@ -92,9 +92,7 @@ def test_flatten_double_buffer():
     stmt = tvm.tir.ir_pass.Simplify(stmt)
     assert isinstance(stmt.body.body, tvm.tir.Allocate)
     assert stmt.body.body.extents[0].value == 2
-    f = tvm.tir.ir_pass.MakeAPI(stmt, "db", [A.asobject(), C.asobject()], 2, True)
-    f = tvm.tir.ir_pass.MakeAPI(stmt, "db", [A.asobject(), C.asobject()], 2, True)
-    mod = tvm.testing.LoweredFuncsToIRModule([f])
+    mod = tvm.testing.MakeAPILegacy(stmt, "db", [A.asobject(), C.asobject()], 2, True)
     f = tvm.tir.transform.ThreadSync("shared")(mod)["db"]
 
     count = [0]
index 8140ddb..6f2bc65 100644 (file)
@@ -36,12 +36,7 @@ def test_for():
             ib.emit(tvm.tir.call_extern
                     ("int32", "fadd", device_context(0), A))
     body = ib.get()
-    f = tvm.tir.ir_pass.MakeAPI(body, "func", [dev_type, n], 2, True)
-
-    # temp adapter to convert loweredFunc to IRModule
-    # to test passes in the new style.x
-    mod = tvm.testing.LoweredFuncsToIRModule([f])
-
+    mod = tvm.testing.MakeAPILegacy(body, "func", [dev_type, n], 2, True)
     mod = tvm.tir.transform.CombineContextCall()(mod)
 
     assert mod["func"].body.value.dtype == "handle"
index 167899a..cf6ef72 100644 (file)
@@ -35,10 +35,8 @@ def test_lower_warp_mem():
 
     cuda_target = tvm.target.create("cuda")
     assert cuda_target.thread_warp_size == 32
-    f = tvm.lower(s, [A, B], name="f")
+    mod = tvm.lower(s, [A, B], name="f")
 
-
-    mod = tvm.testing.LoweredFuncsToIRModule([f])
     mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", cuda_target))(mod)
     fdevice = tvm.tir.transform.SplitHostDevice()(mod)["f_kernel0"]
     mod = tvm.IRModule.from_expr(fdevice)
@@ -35,11 +35,11 @@ def test_makeapi():
     stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, B:Bb, C:Cb}, 64)
 
     num_unpacked_args = 2
-    f = tvm.tir.ir_pass.MakeAPI(
-        stmt, "myadd", [n, Ab, Bb, Cb], num_unpacked_args, True)
-    assert(f.handle_data_type[Ab.data].dtype == Ab.dtype)
-    assert(len(f.args) == 7)
-    output_ssa = False
+    f = tvm.tir.PrimFunc([n, Ab, Bb, Cb], stmt).with_attr(
+        "tir.no_alias", True).with_attr("global_symbol", tvm.runtime.String("myadd"))
+    mod = tvm.IRModule.from_expr(f)
+    f = tvm.tir.transform.MakePackedAPI(num_unpacked_args)(mod)["main"]
+    assert(len(f.params) == 7)
 
 
 if __name__ == "__main__":
index 6c9e7f9..64b454f 100644 (file)
@@ -37,10 +37,9 @@ def test_thread_storage_sync():
     Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A')
     A2b = tvm.tir.decl_buffer(A2.shape, A2.dtype, name='A2')
     stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, A2: A2b}, 64)
-    f = tvm.tir.ir_pass.MakeAPI(stmt, "test", [Ab, A2b], 0, True)
-    cuda_target = tvm.target.create("cuda")
 
-    mod = tvm.testing.LoweredFuncsToIRModule([f])
+    cuda_target = tvm.target.create("cuda")
+    mod = tvm.testing.MakeAPILegacy(stmt, "test", [Ab, A2b], 0, True)
     mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", cuda_target))(mod)
     fdevice = tvm.tir.transform.SplitHostDevice()(mod)["test_kernel0"]
     mod = tvm.IRModule.from_expr(fdevice)
index 298b24f..25ca279 100644 (file)
@@ -36,7 +36,7 @@ Before reading this tutorial, we assume readers have already known these topics
 - Visitor design pattern. Otherwise, check the
   `Python AST module <https://docs.python.org/3/library/ast.html>`_ to see how an AST
   visitor is implemented.
-- How a HalideIR/Schedule is lowered to either a LoweredFunc class or a LLVM module. Otherwise,
+- How a Schedule is lowered to either an IRModule class or a LLVM module. Otherwise,
   take a look at ``python/tvm/build_module.py`` to get some basics.
 
 """