* [REFACTOR][TIR] Migrate all low-level passes to the Pass Manager.
This PR migrates tvm.lower to return an IRModule of PrimFuncs
instead of LoweredFuncs.
* Remove LoweredFunc.
"tvm::IterVarAttr",
"tvm::IterVarRelation",
"tvm::Layout",
- "tir::LoweredFunc",
"tvm::Map",
"tvm::Map",
"tvm::MemoryInfo",
Code generation is done by the ``build_module()`` function, defined in ``python/tvm/target/codegen.py``. On the C++ side, code generation is implemented in the ``src/target/codegen`` subdirectory. The ``build_module()`` Python function will reach the ``Build()`` function below in ``src/target/codegen/codegen.cc``:
-::
-
- runtime::Module Build(const Array<LoweredFunc>& funcs,
- const std::string& target) {
- std::string build_f_name = "codegen.build_" + target;
- const PackedFunc* bf = runtime::Registry::Get(build_f_name);
- runtime::Module m = (*bf)(funcs, target);
- return m;
- }
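+With this refactor, ``Build()`` takes an ``IRModule`` of PrimFuncs instead of an
+``Array<LoweredFunc>``. A sketch of the updated entry point (see
+``src/target/codegen/codegen.cc`` for the exact code)::
+
+  runtime::Module Build(IRModule mod, const Target& target) {
+    std::string build_f_name = "codegen.build_" + target->target_name;
+    const PackedFunc* bf = runtime::Registry::Get(build_f_name);
+    runtime::Module m = (*bf)(mod, target);
+    return m;
+  }
+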
The ``Build()`` function looks up the code generator for the given target in the ``PackedFunc`` registry, and invokes the function found. For example, ``codegen.build_cuda`` function is registered in ``src/codegen/build_cuda_on.cc``, like this:
#include <tvm/runtime/packed_func.h>
#include <tvm/target/target.h>
#include <tvm/support/with.h>
+#include <tvm/ir/module.h>
#include <tvm/te/schedule_pass.h>
-#include <tvm/tir/lowered_func.h>
#include <string>
#include <vector>
namespace tvm {
/*!
-* \brief Build a LoweredFunc given a schedule, args and binds
+* \brief Build an IRModule given a schedule, args and binds
* \param sch The schedule to lower.
* \param args The arguments to the function.
* \param name The name of the lowered function.
* \param binds Buffer assignments.
* \param config The build configuration.
-* \return The lowered function.
+* \return The result module.
*/
-TVM_DLL Array<tir::LoweredFunc> lower(
+TVM_DLL IRModule lower(
te::Schedule sch,
const Array<te::Tensor>& args,
const std::string& name,
const BuildConfig& config);
/*!
-* \brief Build a device and host module for a specific target from an array of lowered functions.
+* \brief Build a device and host module for a specific target from an IRModule.
* \param funcs The functions to be built.
* \param target The target device to build for.
* \param target_host The target for building host code. To use the default, pass Target()
* \param config The build configuration.
* \return The built module.
*/
-TVM_DLL runtime::Module build(const Array<tir::LoweredFunc>& funcs,
+TVM_DLL runtime::Module build(const IRModule& funcs,
const Target& target,
const Target& target_host,
const BuildConfig& config);
/*!
* \brief Build a device and host module for a specific target from a map
- * contains target to a list of lowered functions pairs. This function is used
+ * that maps targets to IRModules. This function is used
* for heterogeneous build.
- * \param input The map contains target to a list of lowered functions pairs.
+ * \param input The map from target to IRModule.
* \param target_host The target for building host code. To use the default,
* pass Target().
* \param config The build configuration.
* \return The built module that contains code for different processors.
*/
-TVM_DLL runtime::Module build(const Map<Target, Array<tir::LoweredFunc>>& input,
+TVM_DLL runtime::Module build(const Map<Target, IRModule>& input,
const Target& target_host,
const BuildConfig& config);
/*!
* \brief Build a device and host module for a specific target from a map
- * contains target to a list of lowered functions pairs. This function is used
+ * that maps targets to IRModules. This function is used
* for heterogeneous build.
- * \param input The map contains target string to a list of lowered functions
- * pairs.
+ * \param input The map from target string to IRModule.
* \param target_host The target for building host code. To use the default,
* pass Target().
* \param config The build configuration.
* \return The built module that contains code for different processors.
*/
-TVM_DLL runtime::Module build(const Map<std::string, Array<tir::LoweredFunc>>& input,
+TVM_DLL runtime::Module build(const Map<std::string, IRModule>& input,
const Target& target_host,
const BuildConfig& config);
} // namespace tvm
CHECK(ptr != nullptr);
return static_cast<IRModuleNode*>(ptr);
}
+
+ /*!
+ * \brief Construct an empty module.
+ *
+ * \returns The constructed module
+ */
+ static IRModule Empty() {
+ return IRModule(Map<GlobalVar, BaseFunc>());
+ }
/*!
* \brief Construct a module from a standalone expression.
*
#include <tvm/runtime/packed_func.h>
#include <tvm/ir/module.h>
#include <tvm/tir/expr.h>
-#include <tvm/tir/lowered_func.h>
#include <tvm/target/target.h>
#include <string>
using runtime::TVMRetValue;
/*!
- * \brief Temporary backward compatible function to convert a list
- * of LoweredFunc to a IRModule of PrimfFuncs
- * \param funcs The input lowered function.
- * \return The IRModule.
- *
- * \note This function is only used for code refactor and will be
- * removed once the refactor completes.
- */
-IRModule ToIRModule(const Array<tir::LoweredFunc>& funcs);
-
-/*!
 * \brief Build a module from an IRModule of PrimFuncs.
* \param mod The Module to be built
* \param target The target to be built.
#ifndef TVM_TIR_ANALYSIS_H_
#define TVM_TIR_ANALYSIS_H_
+#include <tvm/ir/module.h>
#include <tvm/tir/expr.h>
+#include <tvm/tir/function.h>
#include <tvm/tir/stmt.h>
+
namespace tvm {
namespace tir {
*/
Array<Var> UndefinedVars(const Stmt& stmt, const Array<Var>& defs);
+/*!
+ * \brief Verify if memory accesses are legal for a specific target device type.
+ *
+ * In the case that tgt is cuda, if not all workload is bound with
+ * threads, CPU code is generated that tries to access GPU memory,
+ * which is illegal. This pass performs verification for this case.
+ *
+ * \param mod The module to be verified.
+ * \note A fatal error is raised when an illegal memory access is detected.
+ */
+void VerifyMemory(const IRModule& mod);
+
} // namespace tir
} // namespace tvm
#endif // TVM_TIR_ANALYSIS_H_
#include <tvm/tir/expr.h>
#include <tvm/tir/buffer.h>
#include <tvm/tir/function.h>
-#include <tvm/tir/lowered_func.h>
#include <unordered_map>
#include <unordered_set>
Stmt NarrowDataType(Stmt stmt, int target_bits);
/*!
- * \brief Make an user callable API LoweredFunc.
- *
- * The main task of this function is to create code to :
- * - Map the values in the api_args to Var that is required by body.
- * - Insert assertions to check type/value of the passed arguments.
- *
- * \param body The body of the function.
- * \param name The name of the function.
- * \param api_args Arguments to the function, can be either Var, or Buffer
- * \param num_unpacked_args Number of arguments that
- * are processed in plain form instead of packed form.
- * \param is_restricted Whether the caller can guarantee that each buffer argument do not overlap.
- * It is recommended to set to true for optimized code if such invariant holds.
- *
- * \return a LoweredFunc with the specified signiture.
- *
- * \note
- * The function signature have two cases
- *
- * let num_packed_args = len(api_args) - num_unpacked_args;
- *
- * if num_packed_args is zero:
- * f(api_arg_0, api_arg_1, .., api_arg_n) where n == len(api_args)
- *
- * if num_packed_args is not zero:
- * f(TVMArg* packed_args, int* packed_arg_type_ids, int num_packed_args,
- * api_arg_k, api_arg_k+1, ... api_arg_n,
- * TVMValue* out_ret_val, int* out_ret_tcode)
- *
- * where n == len(api_args), k == num_packed_args
- *
- * There is no thread_axis in generated function.
- */
-LoweredFunc MakeAPI(Stmt body,
- std::string name,
- Array<ObjectRef> api_args,
- int num_unpacked_args,
- bool is_restricted);
-
-/*!
- * \brief Remap the thread axis
- *
- * This can be used to get equivalent program which uses
- * threadIdx.y in place of threadIdx.x by passing
- * {"threadIdx.x": thread_axis("threadIdx.y")}
- *
- *
- * \param f The device function to be lowered.
- * \param axis_map The map from StringImm -> ItrVar
- * \return Transformed function.
- */
-LoweredFunc RemapThreadAxis(LoweredFunc f, Map<PrimExpr, IterVar> axis_map);
-
-/*!
* \brief Rewrite the pointer content type of arguments,
* as well as Alloc internal to the function to use
* the most frequently accessed type for load/store
PrimFunc PointerValueTypeRewrite(PrimFunc f);
/*!
- * \brief Lower custom datatypes.
- *
- * See tvm::datatypes::Registry for more information on adding custom datatypes.
- *
- * \param f The device function to be lowered.
- * \param target The target device.
- * \return Transformed function.
- */
-LoweredFunc LowerCustomDatatypes(LoweredFunc f, const std::string& target);
-
-/*!
- * \brief Verify if memory accesses are legal for a specific target device type.
- *
- * In the case that tgt is cuda, if not all workload is bound with
- * threads, CPU code is generated that tries to access GPU memory,
- * which is illegal. This pass performs verification for this case.
- *
- * \param func The function to be verified.
- * \param device_type The target device type.
- * \return Success of memory verification.
- */
-bool VerifyMemory(LoweredFunc func, int device_type);
-
-
-/*!
* \brief Verify the correctness of a GPU code
* It will check the whether the amount of memory usage or the number of threads
* in a block exceeds the limit
+++ /dev/null
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file tvm/tir/lowered_func.h
- * \brief Information about a lowered TVM function.
- * This data structure is final step toward codegen.
- */
-#ifndef TVM_TIR_LOWERED_FUNC_H_
-#define TVM_TIR_LOWERED_FUNC_H_
-
-#include <tvm/node/container.h>
-#include <tvm/tir/expr.h>
-#include <tvm/tir/stmt.h>
-#include <string>
-
-namespace tvm {
-namespace tir {
-
-// Internal node container of lowered function.
-class LoweredFuncNode;
-
-/*!
- * \brief LoweredFunc represents function after lowering.
- * This is the final IR representation before codegen.
- */
-class LoweredFunc : public FunctionRef {
- public:
- LoweredFunc() {}
- explicit LoweredFunc(ObjectPtr<Object> n) : FunctionRef(n) {}
- /*!
- * \brief access the internal node container
- * \return the pointer to the internal node container
- */
- inline const LoweredFuncNode* operator->() const;
- /*! \brief specify container node */
- using ContainerType = LoweredFuncNode;
-};
-
-/*! \brief specific type of lowered function */
-enum LoweredFuncType : int {
- /*! \brief Function that can mix device and host calls */
- kMixedFunc = 0,
- /*! \brief Only contains host code */
- kHostFunc = 1,
- /*! \brief Only contains device code */
- kDeviceFunc = 2
-};
-
-/*! \brief Node container of LoweredFunc */
-class LoweredFuncNode : public tir::FunctionBaseNode {
- public:
- /*! \brief The name of the function */
- std::string name;
- /*!
- * \brief The arguments of the function
- * This function can only take pod type(int, float) and void* as arguments.
- */
- Array<Var> args;
- /*!
- * \brief The IterVar axis of threads
- * Each axis need host function to specify a size.
- * \note Calling convention into LoweredFunc
- *
- * Assume we have a LoweredFunc f, a call into f
- * Call(f, arg1, arg2, ..., arg_n,
- * size_axis_1, size_axis_2, ... size_axis_m)
- *
- * Here n = len(args), m = len(thread_axis)
- *
- * The CodeGen should take this and translate this call
- * to corresponding API specific kernel launchs or function calls.
- */
- Array<IterVar> thread_axis;
- /*!
- * \brief The hint data type of Var handles defined in LetStmt
- * Can be used as hint when generating type signiture.
- * The creation rule is given by
- * handle_data_type[var_handle] = make_const(the_type, 0);
- *
- * \note Expr is used instead Type, because Type cannot be hold by Map.
- * constant Expr of given type is used.
- */
- Map<Var, PrimExpr> handle_data_type;
- /*! \brief The type of the function */
- LoweredFuncType func_type{kMixedFunc};
- /*! \brief Whether this function is packed function */
- bool is_packed_func{true};
- /*!
- * \brief Whether function ensures that argument pointers do not alias.
- * This corresponds to restrict keyword in C.
- */
- bool is_restricted{true};
- /*! \brief The body statment of the function */
- Stmt body;
- /*! \return name of the operation */
- const std::string& func_name() const final {
- return name;
- }
- // there is no return value, but return 1
- // to enable Call into this function.
- int num_outputs() const final {
- return 1;
- }
- void VisitAttrs(AttrVisitor* v) {
- v->Visit("name", &name);
- v->Visit("args", &args);
- v->Visit("thread_axis", &thread_axis);
- v->Visit("handle_data_type", &handle_data_type);
- v->Visit("func_type", &func_type);
- v->Visit("is_packed_func", &is_packed_func);
- v->Visit("is_restricted", &is_restricted);
- v->Visit("body", &body);
- }
-
- static constexpr const char* _type_key = "LoweredFunc";
- TVM_DECLARE_FINAL_OBJECT_INFO(LoweredFuncNode, Object);
-};
-
-// Implementations of inline functions
-inline const LoweredFuncNode* LoweredFunc::operator->() const {
- return static_cast<const LoweredFuncNode*>(get());
-}
-} // namespace tir
-} // namespace tvm
-
-namespace std {
-template <>
-struct hash<::tvm::tir::LoweredFunc> : public tvm::ObjectHash {
-};
-}
-
-#endif // TVM_TIR_LOWERED_FUNC_H_
const tvm::Array<tvm::PrimExpr>& required);
/*!
+ * \brief Transform the high-level PrimFunc to a low-level version
+ * that can be used as an API function.
+ *
+ * The main task of this pass is to create code to:
+ * - Map the values in api_args to the Vars required by the body.
+ * - Insert assertions to check type/value of the passed arguments.
+ *
+ * \param num_unpacked_args Number of arguments that
+ * are processed in plain form instead of packed form.
+ *
+ * \note
+ * The function signature has two cases
+ *
+ * let num_packed_args = len(api_args) - num_unpacked_args;
+ *
+ * if num_packed_args is zero:
+ * f(api_arg_0, api_arg_1, .., api_arg_n) where n == len(api_args)
+ *
+ * if num_packed_args is not zero:
+ * f(TVMArg* packed_args, int* packed_arg_type_ids, int num_packed_args,
+ * api_arg_k, api_arg_k+1, ... api_arg_n,
+ * TVMValue* out_ret_val, int* out_ret_tcode)
+ *
+ * where n == len(api_args), k == num_packed_args
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass MakePackedAPI(int num_unpacked_args);
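+// A minimal usage sketch (as done in the driver): rewrite every PrimFunc in a
+// module to the packed calling convention, keeping all arguments packed:
+//
+//   mod = tir::transform::MakePackedAPI(0)(mod);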
+
+
+/*!
+ * \brief Remap the thread axis
+ *
+ * This can be used to get an equivalent program that uses
+ * threadIdx.y in place of threadIdx.x by passing
+ * {"threadIdx.x": thread_axis("threadIdx.y")}.
+ *
+ * \param axis_map The map from StringImm -> IterVar.
+ * \return The pass.
+ */
+TVM_DLL Pass RemapThreadAxis(Map<PrimExpr, IterVar> axis_map);
+
+
+/*!
+ * \brief Lower custom datatypes.
+ *
+ * See tvm::datatypes::Registry for more information on adding custom datatypes.
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass LowerCustomDatatypes();
+
+
+/*!
 * \brief Bind the device type of the function to be
* the device_type specified in the target attribute.
*
# pylint: disable=invalid-name
"""The build utils in python.
-
-This module provides the functions to transform schedule to
-LoweredFunc and compiled Module.
"""
import warnings
from tvm.ir import CallingConv
from tvm.target import codegen, BuildConfig
from tvm.tir import ir_pass
-from tvm.tir.stmt import LoweredFunc
from tvm.te import tensor
from tvm.te import schedule
from tvm import target as _target
Returns
-------
- f : LoweredFunc or Stmt
- The result function, if with_api_wrapper=False
+ m : IRModule or Stmt
+        The result IRModule, if simple_mode=False;
        otherwise the Stmt before MakePackedAPI is returned.
"""
cfg = BuildConfig.current()
if simple_mode:
return stmt
- return ir_pass.MakeAPI(stmt, name, arg_list, 0, cfg.restricted_func)
+ f = tvm.tir.PrimFunc(arg_list, stmt).with_attr(
+ "global_symbol", tvm.runtime.String(name))
+ if cfg.restricted_func:
+ f = f.with_attr("tir.no_alias", True)
+ mod = tvm.IRModule({name: f})
+ return tvm.tir.transform.MakePackedAPI()(mod)
-def _build_for_device(flist, target, target_host):
+def _build_for_device(input_mod, target, target_host):
"""Build the lowered functions for a device with the given compilation
target.
Parameters
----------
- flist : list of LoweredFunc
+ input_mod : IRModule
        The input IRModule to be built.
target : str or :any:`tvm.target.Target`
Returns
-------
- fhost : list of LoweredFunc
- A list of lowered functions for the host.
+ fhost : IRModule
+ The host IRModule.
mdev : tvm.module
A module that contains device code.
target_host = _target.create(target_host)
device_type = ndarray.context(target.target_name, 0).device_type
- for func in flist:
- if not ir_pass.VerifyMemory(func, device_type):
- raise ValueError(
- "Direct host side access to device memory is detected in %s. "
- "Did you forget to bind?" % func.name)
+ mod_mixed = input_mod
+ mod_mixed = tvm.tir.transform.Apply(lambda f: f.with_attr("target", target))(mod_mixed)
+ tvm.tir.analysis.verify_memory(mod_mixed)
- mod_mixed = tvm.testing.LoweredFuncsToIRModule(flist)
- opt_mixed = [tvm.tir.transform.Apply(lambda f: f.with_attr("target", target))]
+ opt_mixed = []
+ if len(mod_mixed.functions) == 1:
+ opt_mixed += [tvm.tir.transform.Apply(lambda f: f.with_attr("tir.is_entry_func", True))]
if BuildConfig.current().detect_global_barrier:
opt_mixed += [tvm.tir.transform.ThreadSync("global")]
opt_mixed += [tvm.tir.transform.ThreadSync("shared"),
Parameters
----------
- inputs : tvm.te.Schedule, LoweredFunc, or dict of target to LoweredFunc list
+ inputs : tvm.te.Schedule, IRModule, or dict of target to IRModule
The schedule to be built
args : list of Buffer or Tensor or Var, optional
________
There are two typical example uses of this function depending on the type
of the argument `inputs`:
- 1. it is a list of lowered functions:
+    1. it is an IRModule:
.. code-block:: python
B = te.placeholder((n,), name='B')
C = te.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
s = tvm.te.create_schedule(C.op)
- f = tvm.lower(s, [A, B, C], name="test_add")
- m = tvm.build(f, target="llvm")
+ m = tvm.lower(s, [A, B, C], name="test_add")
+ rt_mod = tvm.build(m, target="llvm")
- 2. it is a dict of compilation target to list of lowered functions:
+    2. it is a dict of compilation target to IRModule:
.. code-block:: python
s1 = tvm.te.create_schedule(C.op)
with tvm.target.cuda() as cuda_tgt:
s2 = topi.cuda.schedule_injective(cuda_tgt, [C])
- f1 = tvm.lower(s1, [A, B, C], name="test_add1")
- f2 = tvm.lower(s2, [A, B, C], name="test_add2")
- m = tvm.build({"llvm": [f1], "cuda": [f2]}, target_host="llvm")
+ m1 = tvm.lower(s1, [A, B, C], name="test_add1")
+ m2 = tvm.lower(s2, [A, B, C], name="test_add2")
+ rt_mod = tvm.build({"llvm": m1, "cuda": m2}, target_host="llvm")
Note
----
if isinstance(inputs, schedule.Schedule):
if args is None:
raise ValueError("args must be given for build from schedule")
- flist = lower(inputs, args,
- name=name,
- binds=binds)
- if isinstance(flist, LoweredFunc):
- flist = [flist]
- elif isinstance(inputs, LoweredFunc):
- if args:
- raise ValueError("args must be done when build from LoweredFunc.")
- flist = [inputs]
+ input_mod = lower(inputs, args,
+ name=name,
+ binds=binds)
elif isinstance(inputs, (list, tuple, container.Array)):
- flist = inputs
+ merged_mod = tvm.IRModule({})
+ for x in inputs:
+ merged_mod.update(x)
+ input_mod = merged_mod
+ elif isinstance(inputs, tvm.IRModule):
+ input_mod = inputs
elif not isinstance(inputs, (dict, container.Map)):
- raise ValueError("inputs must be Schedule, LoweredFunc, list of "
- "LoweredFunc, or dict of target to list of "
- "LoweredFunc.")
+ raise ValueError("inputs must be Schedule, IRModule or dict of target to IRModule")
if not isinstance(inputs, (dict, container.Map)):
target = _target.Target.current() if target is None else target
target = target if target else "llvm"
- target_flist = {target: flist}
+ target_input_mod = {target: input_mod}
else:
- target_flist = inputs
+ target_input_mod = inputs
- for tar, flist in target_flist.items():
+ for tar, mod in target_input_mod.items():
if not isinstance(tar, (str, _target.Target)):
raise ValueError("The key of inputs must be str or "
"_target.Target when inputs is dict.")
- fname_set = set()
- for x in flist:
- if not isinstance(x, LoweredFunc):
- raise ValueError("inputs must be Schedule, LoweredFunc, list "
- "of LoweredFunc, or dict of str to list of "
- "LoweredFunc.")
- if x.name in fname_set:
- raise ValueError("Duplicate function name %s" % x.name)
- fname_set.add(x.name)
+ if not isinstance(mod, tvm.IRModule):
+ raise ValueError("inputs must be Schedule, IRModule,"
+ "or dict of str to IRModule.")
if not target_host:
- for tar, _ in target_flist.items():
+ for tar, _ in target_input_mod.items():
tar = _target.create(tar)
device_type = ndarray.context(tar.target_name, 0).device_type
if device_type == ndarray.cpu(0).device_type:
mod_host_all = tvm.IRModule({})
device_modules = []
- for tar, flist in target_flist.items():
- mod_host, mdev = _build_for_device(flist, tar, target_host)
+ for tar, input_mod in target_input_mod.items():
+ mod_host, mdev = _build_for_device(input_mod, tar, target_host)
mod_host_all.update(mod_host)
device_modules.append(mdev)
"""The interface of expr function exposed from C++."""
import tvm._ffi
import tvm.driver
-from tvm.ir import container as _container
@tvm._ffi.register_func("relay.backend.lower")
Returns
-------
- lowered_funcs : List[tvm.LoweredFunc]
+ mod : tvm.IRModule
The result of lowering.
"""
# pylint: disable=broad-except, import-outside-toplevel
msg += "-----------------------------\n"
msg += source_func.astext()
raise RuntimeError(msg)
- return f if isinstance(
- f, (_container.Array, tuple, list)) else [f]
+ return f
@tvm._ffi.register_func("relay.backend.build")
-def build(funcs, target, target_host=None):
+def build(mod, target, target_host=None):
"""Backend build function.
Parameters
----------
- funcs : List[tvm.LoweredFunc] or Dict[str, List[tvm.LoweredFunc]]
- A list of lowered functions or dictionary mapping from targets to
- lowered functions.
-
+ mod : tvm.IRModule or Dict[str, tvm.IRModule]
+        The input module, or a dict from target name to module.
+
target : tvm.Target
The target to run the code on.
"""
if target_host == "":
target_host = None
- return tvm.driver.build(funcs, target=target, target_host=target_host)
+ return tvm.driver.build(mod, target=target, target_host=target_host)
@tvm._ffi.register_func("relay._tensor_value_repr")
self._get_graph_json = self._mod["get_graph_json"]
self._list_params_name = self._mod["list_params_name"]
self._get_param_by_name = self._mod["get_param_by_name"]
- self._get_lowered_funcs = self._mod["get_lowered_funcs"]
+ self._get_irmodule = self._mod["get_irmodule"]
self._setup(mod, target)
def _setup(self, mod, target):
-------
graph_json : str
The graph json that can be consumed by runtime.
- lowered_funcs : List[tvm.LoweredFunc] or Dict[str, List[tvm.LoweredFunc]]
+ mod : IRModule or Dict[str, IRModule]
The lowered functions.
params : Dict[str, tvm.nd.NDArray]
Additional constant parameters.
"""
self._codegen(func)
graph_json = self._get_graph_json()
- lowered_func = self._get_lowered_funcs()
+ lowered_func = self._get_irmodule()
param_names = self._list_params_name()
params = {}
for name in param_names:
from .ndarray import context, cpu, gpu, opencl, cl, vulkan, metal, mtl
from .ndarray import vpi, rocm, opengl, ext_dev, micro_dev
from .module import load_module, enabled, system_lib
+from .container import String
import tvm.ir
from tvm.runtime import Object
-from tvm.ir import container
from tvm.tir import Stmt
-from tvm.tir.stmt import LoweredFunc
from . import _ffi_api
def dump(*args, **kwargs):
"""dump function"""
retv = func(*args, **kwargs)
- if not isinstance(retv, (Stmt, LoweredFunc, container.Array)):
+ if not isinstance(retv, (Stmt,)):
return retv
fname = func.func_name if hasattr(func, 'func_name') else func.__name__
pname = str(self._pass_id) + "_" + fname + "_ir.cc"
with open(pname, "a") as f:
- out = retv.body if isinstance(retv, LoweredFunc) else retv
+ out = retv
f.write(str(out))
- if isinstance(retv, container.Array):
- for x in retv:
- out = x.body if isinstance(x, LoweredFunc) else x
- f.write("---------%s\n%s\n-----------\n"%(x.name, str(out)))
self._pass_id += 1
return retv
return dump
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+
+# pylint: disable=invalid-name
""" TVM testing utilities """
import logging
import numpy as np
+import tvm
import tvm._ffi
x_name, grad.shape, dist, max_diff, avg_diff)
+def MakeAPILegacy(stmt, name, args, num_unpacked_args, noalias):
+ """Legacy adapter to build a Module from statement.
+
+ Used for migrating existing test cases only.
+
+ Parameters
+ ----------
+ stmt: Stmt
+ The input statement.
+
+ name: str
+        The name of the function.
+
+ args: list of Buffer or Vars
+ The function arguments
+
+ num_unpacked_args: int
+ Number of unpacked arguments.
+
+    noalias: bool
+        Whether to allow noalias.
+
+ Returns
+ -------
+ mod : IRModule
+ The created IRModule.
+ """
+ f = tvm.tir.PrimFunc(args, stmt).with_attr(
+ "global_symbol", tvm.runtime.String(name))
+ f = f.with_attr("tir.is_entry_func", True)
+ if noalias:
+ f = f.with_attr("tir.no_alias", True)
+ mod = tvm.IRModule({name: f})
+ return tvm.tir.transform.MakePackedAPI(num_unpacked_args)(mod)
+
+
tvm._ffi._init_api("testing", __name__)
from .stmt import Stmt, LetStmt, AssertStmt, ProducerConsumer, For
from .stmt import BufferStore, Store, Provide, Allocate, AttrStmt, Free, Realize, SeqStmt
-from .stmt import IfThenElse, Evaluate, Prefetch, LoweredFunc, stmt_seq, stmt_list
+from .stmt import IfThenElse, Evaluate, Prefetch, stmt_seq, stmt_list
from .function import PrimFunc
tvm.ir.structural_equal
"""
return _ffi_api.expr_deep_equal(lhs, rhs)
+
+
+def verify_memory(mod):
+ """Verify if module contains illegal host side direct memory access.
+
+ Parameters
+ ----------
+ mod: tvm.IRModule
+ The module to be verified.
+ """
+ _ffi_api.verify_memory(mod)
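+
+# A minimal usage sketch (mirroring the driver; ``target`` is a tvm.target.Target):
+# every PrimFunc needs a "target" attribute before verification, and an error is
+# raised when an illegal host-side access to device memory is found.
+#
+#   mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", target))(mod)
+#   tvm.tir.analysis.verify_memory(mod)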
import tvm._ffi
import tvm.runtime
+from tvm.runtime import Object
from tvm.ir import BaseFunc
from .buffer import Buffer
from .expr import Var
param_list = []
buffer_map = {} if buffer_map is None else buffer_map
for x in params:
+ x = tvm.runtime.convert(x) if not isinstance(x, Object) else x
if isinstance(x, Buffer):
var = Var(x.name, dtype="handle")
param_list.append(var)
_ffi_api.Prefetch, func, value_index, dtype, bounds)
-@tvm._ffi.register_object
-class LoweredFunc(Object):
- """Represent a LoweredFunc in TVM."""
- MixedFunc = 0
- HostFunc = 1
- DeviceFunc = 2
-
-
def stmt_seq(*args):
"""Make sequence of statements
return _fpass.prim_func_pass(_transform, opt_level=0)
+def LowerCustomDatatypes():
+ """Lower custom datatypes.
+
+ See tvm::datatypes::Registry for more information on adding custom datatypes.
+
+ Returns
+ -------
+ fpass : tvm.ir.transform.Pass
+ The result pass
+ """
+ return _ffi_api.LowerCustomDatatypes()
+
+
+def MakePackedAPI(num_unpacked_params=0):
+ """Transform the PrimFuncs in the module to a packed func API.
+
+ Parameters
+ ----------
+ num_unpacked_params : int
+ Number of parameters that we hope to directly pass via normal arguments
+ following the PackedFunc input signature.
+
+ Returns
+ -------
+ fpass : tvm.ir.transform.Pass
+ The result pass
+ """
+ return _ffi_api.MakePackedAPI(num_unpacked_params)
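+
+# A minimal usage sketch (names are illustrative): wrap a Stmt into a PrimFunc,
+# place it in an IRModule, then lower it to the packed calling convention.
+#
+#   f = tvm.tir.PrimFunc(arg_list, stmt).with_attr(
+#       "global_symbol", tvm.runtime.String("main"))
+#   mod = tvm.IRModule({"main": f})
+#   mod = MakePackedAPI()(mod)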
+
+
def BindDeviceType():
"""Bind the device type of the function to be
the device_type specified in the target attribute.
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/target/codegen.h>
-#include <tvm/tir/lowered_func.h>
#include <tvm/te/schedule.h>
#include <map>
#include <string>
#include <tvm/te/operation.h>
#include <tvm/tir/transform.h>
+#include <tvm/tir/analysis.h>
#include <tvm/tir/ir_pass.h>
#include <tvm/target/codegen.h>
+#include <tvm/runtime/container.h>
#include <tvm/runtime/registry.h>
#include <algorithm>
using runtime::TVMArgs;
using runtime::TVMRetValue;
using runtime::PackedFunc;
-using tir::LoweredFunc;
bool LLVMEnabled() {
const runtime::PackedFunc* pf = runtime::Registry::Get("target.build.llvm");
return stmt;
}
-Array<LoweredFunc> lower(te::Schedule sch,
- const Array<te::Tensor>& args,
- const std::string& name,
- const std::unordered_map<te::Tensor, tir::Buffer>& binds,
- const BuildConfig& config) {
- Array<ObjectRef> out_arg_list;
- auto stmt = BuildStmt(sch, args, binds, true, &out_arg_list, config);
- return Array<LoweredFunc>({ tir::MakeAPI(stmt, name, out_arg_list, 0, config->restricted_func) });
-}
-
-
transform::Pass BindTarget(Target target) {
auto fpass = [target](tir::PrimFunc f, IRModule m, transform::PassContext ctx) {
return WithAttr(std::move(f), tvm::attr::kTarget, target);
}
+IRModule lower(te::Schedule sch,
+ const Array<te::Tensor>& args,
+ const std::string& name,
+ const std::unordered_map<te::Tensor, tir::Buffer>& binds,
+ const BuildConfig& config) {
+ Array<ObjectRef> out_arg_list;
+ auto stmt = BuildStmt(sch, args, binds, true, &out_arg_list, config);
+
+ Array<tir::Var> params;
+ Map<tir::Var, tir::Buffer> buffer_map;
+
+ for (auto var : out_arg_list) {
+ if (auto* n = var.as<tir::VarNode>()) {
+ params.push_back(GetRef<tir::Var>(n));
+ } else {
+ tir::Buffer buffer = Downcast<tir::Buffer>(var);
+ tir::Var bptr(buffer->name, DataType::Handle());
+ params.push_back(bptr);
+ buffer_map.Set(bptr, buffer);
+ }
+ }
+
+ auto f = tir::PrimFunc(params, stmt, VoidType(), buffer_map);
+ f = WithAttr(std::move(f), "global_symbol", runtime::String(name));
+
+ if (config->restricted_func) {
+ f = WithAttr(std::move(f), "tir.no_alias", Integer(1));
+ }
+ auto mod = IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(name), f}}));
+ return tir::transform::MakePackedAPI(0)(mod);
+}
+
+
std::pair<IRModule, IRModule>
-split_dev_host_funcs(const Array<LoweredFunc>& funcs,
+split_dev_host_funcs(IRModule mod_mixed,
const Target& target,
const Target& target_host,
const BuildConfig& config) {
- for (const auto& x : funcs) {
- CHECK(tir::VerifyMemory(x, target->device_type))
- << "Direct host side access to device memory is detected in "
- << x->func_name() << ". Did you forget to bind?";
- }
-
- IRModule mod_mixed = codegen::ToIRModule(funcs);
+ mod_mixed = BindTarget(target)(std::move(mod_mixed));
+ tir::VerifyMemory(mod_mixed);
Array<tvm::transform::Pass> mixed_pass_list = {BindTarget(target)};
if (config->detect_global_barrier) {
// Build for heterogeneous execution.
-runtime::Module build(const Map<Target, Array<LoweredFunc>>& inputs,
+runtime::Module build(const Map<Target, IRModule>& inputs,
const Target& target_host,
const BuildConfig& config) {
- Array<LoweredFunc> fhost_all;
std::vector<runtime::Module> device_modules;
Target target_host_val = target_host;
}
// Build for heterogeneous execution when target is a string.
-runtime::Module build(const Map<std::string, Array<LoweredFunc>>& inputs,
+runtime::Module build(const Map<std::string, IRModule>& inputs,
const Target& target_host,
const BuildConfig& config) {
- Map<Target, Array<LoweredFunc>> updated_input;
+ Map<Target, IRModule> updated_input;
for (const auto& it : inputs) {
auto target = Target::Create(it.first);
if (target->device_name == "vta") {
}
// Build for homogeneous execution.
-runtime::Module build(const Array<LoweredFunc>& funcs,
+runtime::Module build(const IRModule& funcs,
const Target& target,
const Target& target_host,
const BuildConfig& config) {
- Map<Target, Array<LoweredFunc>> inputs = {{target, funcs}};
+ Map<Target, IRModule> inputs = {{target, funcs}};
return build(inputs, target_host, config);
}
namespace relay {
namespace backend {
-using tir::LoweredFunc;
using TargetsMap = Map<tvm::Integer, tvm::Target>;
using namespace tvm::relay::transform;
}
Array<tvm::runtime::Module> GetExternalModules() {
- return CallFunc<Array<tvm::runtime::Module> >("get_external_modules", nullptr);
+ return CallFunc<Array<tvm::runtime::Module>>("get_external_modules", nullptr);
}
- Map<std::string, Array<LoweredFunc> > GetLoweredFunc() {
- return CallFunc<Map<std::string, Array<LoweredFunc> > >("get_lowered_funcs", nullptr);
+ Map<std::string, IRModule> GetIRModule() {
+ return CallFunc<Map<std::string, IRModule>>("get_irmodule", nullptr);
}
std::unordered_map<std::string, tvm::runtime::NDArray> GetParams() {
std::unordered_map<std::string, tvm::runtime::NDArray> ret;
- auto names = CallFunc<Array<tvm::PrimExpr> >("list_params_name", nullptr);
+ auto names = CallFunc<Array<tvm::PrimExpr>>("list_params_name", nullptr);
for (auto expr : names) {
auto key = expr.as<tir::StringImmNode>()->value;
ret[key] = CallFunc<runtime::NDArray>("get_param_by_name", key);
this->SetParam(kv.first, kv.second->data);
}
});
- } else if (name == "get_lowered_funcs") {
+ } else if (name == "get_irmodule") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
- *rv = this->graph_codegen_->GetLoweredFunc();
+ *rv = this->graph_codegen_->GetIRModule();
});
} else if (name == "get_external_modules") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
ret_.graph_json = graph_codegen_->GetJSON();
ret_.params = graph_codegen_->GetParams();
- auto lowered_funcs = graph_codegen_->GetLoweredFunc();
+ auto lowered_funcs = graph_codegen_->GetIRModule();
// When there is no lowered_funcs due to reasons such as optimization.
if (lowered_funcs.size() == 0) {
#include <tvm/node/structural_equal.h>
#include <tvm/node/structural_hash.h>
-#include <tvm/tir/lowered_func.h>
#include <tvm/runtime/module.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr.h>
/*! \brief The schedule to the function */
te::Schedule schedule;
/*! \brief The lowered functions to support the function. */
- tvm::Array<tir::LoweredFunc> funcs;
+ IRModule funcs = IRModule::Empty();
+
/*! \brief Parameter usage states in the shape function. */
tvm::Array<Integer> shape_func_param_states;
/*! \brief Lowered outputs */
struct LoweredOutput {
std::string graph_json;
- Map<std::string, Array<tir::LoweredFunc> > lowered_funcs;
+ Map<std::string, IRModule> lowered_funcs;
Array<tvm::runtime::Module> external_mods;
std::unordered_map<std::string, tvm::runtime::NDArray> params;
};
LoweredOutput ret;
ret.graph_json = os.str();
ret.params = params_;
+
for (auto& kv : lowered_funcs_) {
if (ret.lowered_funcs.count(kv.first) == 0) {
- ret.lowered_funcs.Set(kv.first, Array<tir::LoweredFunc>());
- }
- auto& vec = ret.lowered_funcs[kv.first];
- Array<tir::LoweredFunc> tmp;
- for (auto f : kv.second) {
- tmp.push_back(f);
- }
- for (auto f : vec) {
- tmp.push_back(f);
+ ret.lowered_funcs.Set(kv.first, IRModule::Empty());
}
- ret.lowered_funcs.Set(kv.first, tmp);
+ auto& mod = ret.lowered_funcs[kv.first];
+ mod->Update(kv.second);
+ ret.lowered_funcs.Set(kv.first, mod);
}
ret.external_mods = compile_engine_->LowerExternalFunctions();
return ret;
CCacheKey key = (*pf0)(func, target);
CachedFunc lowered_func = (*pf1)(compile_engine_, key);
if (!lowered_funcs_.count(target->str())) {
- lowered_funcs_[target->str()] = {};
+ lowered_funcs_[target->str()] = IRModule::Empty();
}
- for (auto f : lowered_func->funcs) {
- lowered_funcs_[target->str()].insert(f);
- }
-
+ lowered_funcs_[target->str()]->Update(lowered_func->funcs);
return GraphAddCallNode(op,
_GetUniqueName(lowered_func->func_name),
lowered_func->func_name);
/*! \brief plan memory of device result */
Map<Expr, Array<IntegerArray>> storage_device_map_;
/*! \brief lowered funcs */
- std::unordered_map<std::string, std::unordered_set<tir::LoweredFunc, ObjectHash, ObjectEqual>>
- lowered_funcs_;
+ std::unordered_map<std::string, IRModule> lowered_funcs_;
/*! \brief name map */
std::unordered_map<std::string, size_t> name_map_;
/*! \brief compile engine */
CHECK_GT(this->output_.params.count(key), 0);
*rv = this->output_.params[key];
});
- } else if (name == "get_lowered_funcs") {
+ } else if (name == "get_irmodule") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
*rv = this->output_.lowered_funcs;
});
return raw_shape;
}
+
class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
public:
VMFunctionCompiler(VMCompilerContext* context, TargetsMap targets, Target target_host)
CCacheKey key(func, target_host_);
auto cfunc = engine_->LowerShapeFunc(key);
int op_index = -1;
- if (context_->seen_funcs.count(cfunc->funcs[0]) == 0) {
+      // pick the only function inside the module
+ CHECK_EQ(cfunc->funcs->functions.size(), 1);
+ auto pfunc = Downcast<tir::PrimFunc>((*cfunc->funcs->functions.begin()).second);
+ if (context_->seen_funcs.count(pfunc) == 0) {
op_index = context_->cached_funcs.size();
context_->cached_funcs.push_back(cfunc);
- context_->seen_funcs[cfunc->funcs[0]] = op_index;
+ context_->seen_funcs[pfunc] = op_index;
} else {
- op_index = context_->seen_funcs[cfunc->funcs[0]];
+ op_index = context_->seen_funcs[pfunc];
}
// Prepare input and output registers
context_->cached_funcs.push_back(cfunc);
} else {
// TODO(jroesch): support lowered funcs for multiple targets
- CHECK_EQ(cfunc->funcs.size(), 1);
- if (context_->seen_funcs.find(cfunc->funcs[0]) == context_->seen_funcs.end()) {
+ CHECK_EQ(cfunc->funcs->functions.size(), 1);
+ auto pfunc = Downcast<tir::PrimFunc>((*cfunc->funcs->functions.begin()).second);
+ if (context_->seen_funcs.find(pfunc) == context_->seen_funcs.end()) {
op_index = context_->cached_funcs.size();
context_->cached_funcs.push_back(cfunc);
- context_->seen_funcs[cfunc->funcs[0]] = op_index;
+ context_->seen_funcs[pfunc] = op_index;
} else {
- op_index = context_->seen_funcs[cfunc->funcs[0]];
+ op_index = context_->seen_funcs[pfunc];
}
}
// update primitive function map
size_t primitive_index = 0;
for (const auto& cfunc : context_.cached_funcs) {
- if (cfunc->target->str() == "ext_dev") {
- exec_->primitive_map.insert({cfunc->func_name, primitive_index++});
- } else {
- exec_->primitive_map.insert({cfunc->funcs[0]->name, primitive_index++});
- }
+ exec_->primitive_map.insert({cfunc->func_name, primitive_index++});
}
}
}
void VMCompiler::Codegen() {
- using tir::LoweredFunc;
-
if (!context_.module.defined()) {
LOG(WARNING) << "Did you forget to call VMCompiler::Lower?";
return;
if (cached_funcs.size() == 0) {
return;
}
- std::unordered_map<std::string, Array<LoweredFunc>> funcs;
+ std::unordered_map<std::string, IRModule> funcs;
+
for (auto& cfunc : cached_funcs) {
std::string target_str = cfunc->target->str();
+    // NOTE: because the module is mutable, we need to make an
+ // explicit copy of the IRModule.
+ IRModule mod = cfunc->funcs;
+ mod.CopyOnWrite();
+
if (target_str == "ext_dev") {
continue;
} else if (funcs.count(target_str) == 0) {
- funcs.emplace(target_str, Array<LoweredFunc>{cfunc->funcs[0]});
+ funcs.emplace(target_str, mod);
} else {
- funcs[target_str].push_back(cfunc->funcs[0]);
+ funcs[target_str]->Update(mod);
}
}
// List of cached functions
std::vector<CachedFunc> cached_funcs;
// The functions that have been lowered.
- std::unordered_map<tir::LoweredFunc, size_t, ObjectHash, ObjectEqual> seen_funcs;
+ std::unordered_map<tir::PrimFunc, size_t, ObjectHash, ObjectEqual> seen_funcs;
};
* \brief API for Automatic Differentiation for the Relay IR.
*/
#include <tvm/ir/type_functor.h>
-#include <tvm/tir/lowered_func.h>
#include <tvm/te/operation.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/analysis.h>
#include <tvm/tir/function.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt.h>
-#include <tvm/tir/lowered_func.h>
#include <unordered_map>
#include <string>
#include "../runtime/meta_data.h"
namespace tvm {
namespace codegen {
-// Extract function information from device function.
-inline std::unordered_map<std::string, runtime::FunctionInfo>
-ExtractFuncInfo(const Array<tir::LoweredFunc>& funcs) {
- std::unordered_map<std::string, runtime::FunctionInfo> fmap;
- for (tir::LoweredFunc f : funcs) {
- runtime::FunctionInfo info;
- for (size_t i = 0; i < f->args.size(); ++i) {
- info.arg_types.push_back(f->args[i].dtype());
- }
- for (size_t i = 0; i < f->thread_axis.size(); ++i) {
- info.thread_axis_tags.push_back(f->thread_axis[i]->thread_tag);
- }
- fmap[f->name] = info;
- }
- return fmap;
-}
inline std::unordered_map<std::string, runtime::FunctionInfo>
ExtractFuncInfo(const IRModule& mod) {
namespace tvm {
namespace codegen {
-// convert legacy LoweredFunc to PrimFunc.
-tir::PrimFunc ToPrimFunc(tir::LoweredFunc from) {
- // remap args to attach type annotations.
- Array<tir::Var> args;
- Map<tir::Var, PrimExpr> remap_vars;
-
- for (auto var : from->args) {
- auto it = from->handle_data_type.find(var);
- if (it != from->handle_data_type.end()) {
- tir::Var new_var(var->name_hint,
- PointerType(PrimType((*it).second->dtype)));
- args.push_back(new_var);
- remap_vars.Set(var, new_var);
- } else {
- args.push_back(var);
- }
- }
- tir::PrimFunc func(args, Substitute(from->body, remap_vars));
-
- func = WithAttr(std::move(func), attr::kGlobalSymbol, runtime::String(from->name));
- func = WithAttr(std::move(func), tir::attr::kDeviceThreadAxis, from->thread_axis);
- if (from->func_type == tir::LoweredFuncType::kDeviceFunc) {
- func = WithAttr(std::move(func),
- attr::kCallingConv, Integer(CallingConv::kDeviceKernelLaunch));
- }
- if (from->is_restricted) {
- func = WithAttr(std::move(func), tir::attr::kNoAlias, Integer(1));
- }
- return func;
-}
-
-IRModule ToIRModule(const Array<tir::LoweredFunc>& funcs) {
- Map<GlobalVar, BaseFunc> functions;
- for (size_t i = 0; i < funcs.size(); ++i) {
- auto f = funcs[i];
- tir::PrimFunc pf = ToPrimFunc(f);
- if (i == 0) {
- pf = WithAttr(std::move(pf), tir::attr::kIsEntryFunc, Integer(1));
- }
- functions.Set(GlobalVar(f->name), pf);
- }
- return IRModule(functions);
-}
-
runtime::Module Build(IRModule mod, const Target& target) {
if (BuildConfig::Current()->disable_assert) {
mod = tir::transform::SkipAssert()(mod);
TVM_REGISTER_GLOBAL("target.Build")
.set_body_typed(Build);
-TVM_REGISTER_GLOBAL("testing.LoweredFuncsToIRModule")
-.set_body_typed(ToIRModule);
-
// Export two auxiliary function to the runtime namespace.
TVM_REGISTER_GLOBAL("runtime.ModulePackImportsToC")
.set_body_typed(PackImportsToC);
auto debug_info = llvm::make_unique<CodeGenLLVM::DebugInfo>();
debug_info->di_builder_ = llvm::make_unique<llvm::DIBuilder>(*module);
#endif
- // TODO(tulloch): pass this information through relay::Span classes to the LoweredFunc instance?
+ // TODO(tulloch): pass this information through relay::Span classes to the IRModule instance?
debug_info->file_ = debug_info->di_builder_->createFile("model.tvm", "/tmp/");
debug_info->compilation_unit_ = debug_info->di_builder_->createCompileUnit(
llvm::dwarf::DW_LANG_C, debug_info->file_, "TVM", 0, "", 0, "",
} else if (name == "_get_target_triple") {
std::string target_triple = tm_->getTargetTriple().str();
return PackedFunc([target_triple](TVMArgs args, TVMRetValue *rv) {
- * rv = target_triple;
+ *rv = target_triple;
});
}
if (ee_ == nullptr) LazyInitJIT();
- // This LLVMModule is empty and no function can be retrieved.
- if (entry_func_.empty()) return nullptr;
-
std::lock_guard<std::mutex> lock(mutex_);
- const std::string& fname = (name == runtime::symbol::tvm_module_main ?
- entry_func_ : name);
- TVMBackendPackedCFunc faddr =
- reinterpret_cast<TVMBackendPackedCFunc>(GetFunctionAddr(fname));
+ TVMBackendPackedCFunc faddr;
+ if (name == runtime::symbol::tvm_module_main) {
+ const char* entry_name = reinterpret_cast<const char*>(
+ GetGlobalAddr(runtime::symbol::tvm_module_main));
+ CHECK(entry_name != nullptr)
+ << "Symbol " << runtime::symbol::tvm_module_main << " is not presented";
+ faddr = reinterpret_cast<TVMBackendPackedCFunc>(GetFunctionAddr(entry_name));
+ } else {
+ faddr = reinterpret_cast<TVMBackendPackedCFunc>(GetFunctionAddr(name));
+ }
if (faddr == nullptr) return PackedFunc();
return WrapPackedFunc(faddr, sptr_to_self);
}
std::unique_ptr<CodeGenLLVM> cg = CodeGenLLVM::Create(tm_.get());
std::vector<PrimFunc> funcs;
+ std::string entry_func;
for (auto kv : mod->functions) {
CHECK(kv.second->IsInstance<PrimFuncNode>())
<< "Can only lower IR Module with PrimFuncs";
if (f->HasNonzeroAttr(tir::attr::kIsEntryFunc)) {
auto global_symbol = f->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
CHECK(global_symbol.defined());
- entry_func_ = global_symbol;
+ entry_func = global_symbol;
}
funcs.push_back(f);
}
cg->AddFunction(f);
}
- if (entry_func_.length() != 0) {
- cg->AddMainFunction(entry_func_);
+ if (entry_func.length() != 0) {
+ cg->AddMainFunction(entry_func);
}
module_ = cg->Finish();
CHECK(ee_ != nullptr)
<< "Failed to initialize jit engine for " << mptr_->getTargetTriple();
ee_->runStaticConstructorsDestructors(false);
- // setup context address.
- // we will skip context setup if this LLVMModule is empty.
- if (GetGlobalAddr(runtime::symbol::tvm_module_main) == 0)
- return;
- entry_func_ =
- reinterpret_cast<const char*>(GetGlobalAddr(runtime::symbol::tvm_module_main));
if (void** ctx_addr = reinterpret_cast<void**>(
GetGlobalAddr(runtime::symbol::tvm_module_ctx))) {
*ctx_addr = this;
// The target configuration string
std::string target_;
- // Name of entry function.
- std::string entry_func_;
// JIT lock
std::mutex mutex_;
// execution engine
#include <tvm/tir/function.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/target/codegen.h>
-#include <tvm/tir/lowered_func.h>
#include <tvm/runtime/container.h>
#include <string>
#include <vector>
#include <tvm/arith/analyzer.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
-#include <tvm/tir/lowered_func.h>
#include <vector>
#include <memory>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
-#include <tvm/tir/lowered_func.h>
#include <tvm/target/codegen.h>
#include <string>
#include <vector>
* \brief Pass to check if memory accesses are legal.
*/
#include <tvm/tir/expr.h>
-#include <tvm/tir/ir_pass.h>
+#include <tvm/tir/analysis.h>
#include <tvm/tir/stmt_functor.h>
+#include <tvm/target/target.h>
+#include <tvm/runtime/registry.h>
namespace tvm {
public:
/// Special member functions
//@{
- explicit MemoryAccessVerifier(LoweredFunc f, int device_type)
+ explicit MemoryAccessVerifier(PrimFunc f, int device_type)
: func_(f), dev_type_(device_type) {}
virtual ~MemoryAccessVerifier() = default;
MemoryAccessVerifier(const MemoryAccessVerifier &) = delete;
CHECK(V) << "Invalid Variable\n";
// Variable is from function args. Return true.
- if (V == func_->args[0].get()) return true;
+ if (V == func_->params[0].get()) return true;
// The value is expected to come from a tvm_struct_get Call.
// Get the first argument of tvm_struct_get, and continue.
const ProducerConsumerNode *pc_{nullptr};
bool failure_{false}; ///< If the verification fails (i.e. has illegal access)
//@}
- LoweredFunc func_{nullptr}; ///< Function to be verified.
+ tir::PrimFunc func_{nullptr}; ///< Function to be verified.
int dev_type_{kDLCPU}; ///< Device type
std::unordered_map<const VarNode *, PrimExpr> defs_; ///< Variable definitions
};
} // namespace
/// Interface of VerifyMemory pass
-bool VerifyMemory(LoweredFunc func, int device_type) {
- MemoryAccessVerifier v(func, device_type);
- v.Run();
- return !v.Failed();
+void VerifyMemory(const IRModule& mod) {
+ for (auto kv : mod->functions) {
+ if (auto* n = kv.second.as<PrimFuncNode>()) {
+ PrimFunc func = GetRef<PrimFunc>(n);
+ auto target = func->GetAttr<Target>(tvm::attr::kTarget);
+ CHECK(target.defined())
+ << "LowerWarpMemory: Require the target attribute";
+ MemoryAccessVerifier v(func, target->device_type);
+ v.Run();
+ if (v.Failed()) {
+ LOG(FATAL)
+ << "ValueError: Direct host side access to device memory is detected."
+ << " Did you forget to bind?\n"
+ << func;
+ }
+ }
+ }
}
+TVM_REGISTER_GLOBAL("tir.analysis.verify_memory")
+.set_body_typed(VerifyMemory);
} // namespace tir
} // namespace tvm
DataType dtype,
std::string name) {
return BufferNode::make(
- Var(name, DataType::Handle()),
+ Var(name, PointerType(PrimType(dtype))),
dtype,
shape,
Array<PrimExpr>(),
+++ /dev/null
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file lowered_func.cc
- */
-#include <tvm/tir/lowered_func.h>
-
-namespace tvm {
-namespace tir {
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<LoweredFuncNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const LoweredFuncNode*>(node.get());
- p->stream << "LoweredFunc(" << op->name << ", " << op << ")";
-});
-
-TVM_REGISTER_NODE_TYPE(LoweredFuncNode);
-
-
-} // namespace tir
-} // namespace tvm
});
});
-TVM_REGISTER_GLOBAL("ir_pass.LowerStorageAccess")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
- LoweredFunc f = args[0];
- auto n = make_object<LoweredFuncNode>(*f.operator->());
- n->body = LowerStorageAccessInfo(f->body);
- *ret = LoweredFunc(n);
-});
// make from two arguments
#define REGISTER_PASS(PassName) \
REGISTER_PASS(SkipVectorize);
REGISTER_PASS(UnrollLoop);
REGISTER_PASS(InjectCopyIntrin);
-REGISTER_PASS(MakeAPI);
REGISTER_PASS(StorageRewrite);
REGISTER_PASS(CoProcSync);
REGISTER_PASS(LowerStorageAccessInfo);
REGISTER_PASS(LoopPartition);
REGISTER_PASS(RemoveNoOp);
REGISTER_PASS(LiftAttrScope);
-REGISTER_PASS(RemapThreadAxis);
-REGISTER_PASS(LowerCustomDatatypes);
-REGISTER_PASS(VerifyMemory);
REGISTER_PASS(VerifyGPUCode);
REGISTER_PASS(DecorateDeviceScope);
REGISTER_PASS(InstrumentBoundCheckers);
};
-LoweredFunc PointerValueTypeRewrite(LoweredFunc f) {
- auto n = make_object<LoweredFuncNode>(*f.operator->());
- VectorAllocRewriter rewriter;
- n->body = rewriter(n->body);
- for (Var arg : f->args) {
- if (arg.dtype().is_handle()) {
- const auto& tvec = rewriter.acc_map_[arg.get()];
- if (tvec.size() == 1) {
- PrimExpr dtype = make_const(tvec[0], 0);
- n->handle_data_type.Set(arg, dtype);
- } else {
- // always set data type to be non vectorized so
- // load/store can still work via scalarization
- if (tvec.size() != 0 && !n->handle_data_type.count(arg)) {
- PrimExpr dtype = make_const(tvec[0].with_lanes(1), 0);
- n->handle_data_type.Set(arg, dtype);
- }
- }
- }
- }
- return LoweredFunc(n);
-}
-
PrimFunc PointerValueTypeRewrite(PrimFunc f) {
auto* n = f.CopyOnWrite();
VectorAllocRewriter rewriter;
*/
#include <tvm/tir/stmt_functor.h>
-#include <tvm/tir/ir_pass.h>
+#include <tvm/tir/transform.h>
+#include <tvm/target/target.h>
+#include <tvm/runtime/registry.h>
#include "../../target/datatype/registry.h"
namespace tvm {
std::string target_;
};
-LoweredFunc LowerCustomDatatypes(LoweredFunc f, const std::string& target) {
- auto n = make_object<LoweredFuncNode>(*f.operator->());
- n->body = CustomDatatypesLowerer(target)(n->body);
- return LoweredFunc(n);
+
+namespace transform {
+
+Pass LowerCustomDatatypes() {
+ auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
+ auto* n = f.CopyOnWrite();
+ auto target = f->GetAttr<Target>(tvm::attr::kTarget);
+ CHECK(target.defined())
+ << "LowerCustomDatatypes: Require the target attribute";
+
+ n->body = CustomDatatypesLowerer(target->target_name)(std::move(n->body));
+ return f;
+ };
+ return CreatePrimFuncPass(pass_func, 0, "tir.LowerCustomDatatypes", {});
}
+TVM_REGISTER_GLOBAL("tir.transform.LowerCustomDatatypes")
+.set_body_typed(LowerCustomDatatypes);
+
+} // namespace transform
+
} // namespace tir
} // namespace tvm
*/
/*!
- * \file make_api.cc Build API function.
+ * \file make_packed_api.cc Lower PrimFunc to use the packed function API.
*/
#include <tvm/tir/ir_pass.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/analysis.h>
+#include <tvm/tir/transform.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/buffer.h>
#include <tvm/runtime/device_api.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/runtime/container.h>
+
#include <vector>
#include <utility>
#include <unordered_set>
-#include "ir_util.h"
-#include "arg_binder.h"
+#include "../pass/ir_util.h"
+#include "../pass/arg_binder.h"
namespace tvm {
namespace tir {
return AssertStmtNode::make(lhs == rhs, msg, EvaluateNode::make(0));
}
-LoweredFunc MakeAPI(Stmt body,
- std::string name,
- Array<ObjectRef> api_args,
- int num_unpacked_args,
- bool is_restricted) {
+PrimFunc MakePackedAPI(PrimFunc&& func,
+ int num_unpacked_args) {
+ auto global_symbol = func->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
+ CHECK(global_symbol.defined())
+ << "MakePackedAPI: Expect PrimFunc to have the global_symbol attribute";
+ std::string name_hint = global_symbol;
+
+ auto* func_ptr = func.CopyOnWrite();
const Stmt nop = EvaluateNode::make(0);
- int num_args = static_cast<int>(api_args.size());
+ int num_args = static_cast<int>(func_ptr->params.size());
CHECK_LE(num_unpacked_args, num_args);
+
int num_packed_args = num_args - num_unpacked_args;
// Data field definitions
// The packed fields
// local function definitions
// load i-th argument as type t
auto f_arg_value = [&](DataType t, int i) {
- Array<PrimExpr> call_args{v_packed_args,
- IntImm(DataType::Int(32), i),
- IntImm(DataType::Int(32), intrinsic::kTVMValueContent)};
+ Array<PrimExpr> call_args{
+ v_packed_args,
+ IntImm(DataType::Int(32), i),
+ IntImm(DataType::Int(32), intrinsic::kTVMValueContent)};
// load 64 bit version
DataType api_type = APIType(t);
PrimExpr res = CallNode::make(
}
return res;
};
- // get declaration of argument i
- auto f_arg_decl = [&](int i) {
- std::ostringstream os;
- os << "arg" << i;
- const VarNode* v = api_args[i].as<VarNode>();
- return Var(os.str(), v ? v->dtype: DataType::Handle());
- };
+
// ---------------------------
// start of logics
  // add signature for packed arguments.
args.push_back(v_num_packed_args);
std::ostringstream os;
- os << name << ": num_args should be " << num_packed_args;
+ os << name_hint << ": num_args should be " << num_packed_args;
seq_init.emplace_back(
MakeAssertEQ(v_num_packed_args, num_packed_args, os.str()));
}
- // Save the input variables and buffers that will be bound later.
- std::vector<std::pair<Var, Var> > var_defs;
- std::vector<std::pair<Buffer, Var> > buf_defs;
- for (int i = 0; i < static_cast<int>(api_args.size()); ++i) {
- Var v_arg = f_arg_decl(i);
+  // Need to re-declare vars, in case some arguments also appear in the buffer_map.
+ std::vector<std::pair<Var, Var> > var_def;
+ std::vector<std::pair<Var, Buffer> > buffer_def;
+
+ for (int i = 0; i < static_cast<int>(func_ptr->params.size()); ++i) {
+ Var param = func_ptr->params[i];
+ Var v_arg = Var("arg" + std::to_string(i), param->dtype);
+
+ auto it = func_ptr->buffer_map.find(param);
+ if (it != func_ptr->buffer_map.end()) {
+ buffer_def.emplace_back(v_arg, (*it).second);
+ } else {
+ var_def.emplace_back(v_arg, param);
+ }
if (i < num_packed_args) {
// Value loads
seq_init.emplace_back(LetStmtNode::make(
DataType t = v_arg.dtype();
if (t.is_handle()) {
std::ostringstream msg;
- msg << name << ": Expect arg[" << i << "] to be pointer";
+ msg << name_hint << ": Expect arg[" << i << "] to be pointer";
seq_check.emplace_back(
AssertStmtNode::make(tcode == kTVMOpaqueHandle ||
- tcode == kTVMNDArrayHandle ||
- tcode == kTVMDLTensorHandle ||
- tcode == kTVMNullptr, msg.str(), nop));
+ tcode == kTVMNDArrayHandle ||
+ tcode == kTVMDLTensorHandle ||
+ tcode == kTVMNullptr, msg.str(), nop));
} else if (t.is_int() || t.is_uint()) {
std::ostringstream msg;
- msg << name << ": Expect arg[" << i << "] to be int";
+ msg << name_hint << ": Expect arg[" << i << "] to be int";
seq_check.emplace_back(AssertStmtNode::make(tcode == kDLInt, msg.str(), nop));
} else {
CHECK(t.is_float());
std::ostringstream msg;
- msg << name << ": Expect arg[" << i << "] to be float";
+ msg << name_hint << ": Expect arg[" << i << "] to be float";
seq_check.emplace_back(
AssertStmtNode::make(tcode == kDLFloat, msg.str(), nop));
}
} else {
args.push_back(v_arg);
}
- // add checks for functions.
- if (api_args[i].as<VarNode>()) {
- var_defs.emplace_back(std::make_pair(Downcast<Var>(api_args[i]), v_arg));
- } else {
- // Buffer checks
- CHECK(api_args[i].as<BufferNode>())
- << "api_args can only be Buffer or Var";
- buf_defs.emplace_back(std::make_pair(Downcast<Buffer>(api_args[i]), v_arg));
- }
}
// allow return value if the function is packed.
// either 0 or the original stride will be correctly used. Checks here have
  // to use the args that may have no let binding yet. Therefore, hoisting let
// binding for args before buffer declaration is needed.
- for (const auto& arg : var_defs) {
- binder.Bind(arg.first, arg.second, arg.second->name_hint, true);
+ for (const auto& kv : var_def) {
+ binder.Bind(kv.second, kv.first, kv.first->name_hint, true);
+ }
+
+ for (const auto& kv : buffer_def) {
+ binder.BindDLTensor(kv.second, device_type, device_id,
+ kv.first, kv.first->name_hint);
}
- for (const auto& buf_arg : buf_defs) {
- binder.BindDLTensor(buf_arg.first, device_type, device_id,
- buf_arg.second, buf_arg.second->name_hint);
+ if (num_unpacked_args == 0) {
+ func = WithAttr(std::move(func), tvm::attr::kCallingConv, Integer(CallingConv::kCPackedFunc));
}
- ObjectPtr<LoweredFuncNode> n = make_object<LoweredFuncNode>();
- n->name = name;
- n->args = args;
- n->handle_data_type = binder.def_handle_dtype();
- n->is_packed_func = num_unpacked_args == 0;
- n->is_restricted = is_restricted;
- body = AttrStmtNode::make(
+ auto body = AttrStmtNode::make(
make_zero(DataType::Int(32)), attr::compute_scope,
- StringImmNode::make(name + "_compute_"), body);
+ StringImmNode::make(name_hint + "_compute_"), func_ptr->body);
// Set device context
if (vmap.count(device_id.get())) {
PrimExpr node = StringImmNode::make("default");
device_type, device_id}, CallNode::Intrinsic)));
body = SeqStmt({set_device, body});
}
- n->body = MergeNest(
+ func_ptr->body = MergeNest(
{seq_init, binder.init_nest(), seq_check, binder.asserts()}, body);
- LoweredFunc f(n);
- Array<Var> undefined = UndefinedVars(f->body, f->args);
+ func_ptr->params = args;
+
+ Array<Var> undefined = UndefinedVars(func_ptr->body, func_ptr->params);
if (undefined.size() != 0) {
std::ostringstream os;
for (Var v : undefined) {
os << " \'" << v->name_hint << "\' ";
}
- os << " does not appear in api_args";
+ os << " is not bound to any variables";
    LOG(FATAL) << "Not all Vars are passed in the function arguments: " << os.str();
}
- return f;
+
+
+ func_ptr->buffer_map = Map<Var, Buffer>();
+ func_ptr->checked_type_ = func_ptr->func_type_annotation();
+ func_ptr->ret_type = PrimType(DataType::Int(32));
+
+ // return the function.
+ return std::move(func);
}
+namespace transform {
+
+Pass MakePackedAPI(int num_unpacked_args) {
+ auto pass_func = [num_unpacked_args](IRModule m, PassContext ctx) {
+ IRModuleNode* mptr = m.CopyOnWrite();
+ std::vector<std::pair<GlobalVar, PrimFunc> > updates;
+
+ for (const auto& kv : mptr->functions) {
+ if (auto* n = kv.second.as<PrimFuncNode>()) {
+ PrimFunc func = GetRef<PrimFunc>(n);
+ if (func->GetAttr<Integer>(tvm::attr::kCallingConv, 0)->value
+ == static_cast<int>(CallingConv::kDefault)) {
+ auto updated_func = MakePackedAPI(std::move(func), num_unpacked_args);
+ updates.push_back({kv.first, updated_func});
+ }
+ }
+ }
+
+ for (const auto& pair : updates) {
+ mptr->AddUnchecked(pair.first, pair.second);
+ }
+ return m;
+ };
+
+ return tvm::transform::CreateModulePass(
+ pass_func, 0, "tir.MakePackedAPI", {});
+}
+TVM_REGISTER_GLOBAL("tir.transform.MakePackedAPI")
+.set_body_typed(MakePackedAPI);
+} // namespace transform
} // namespace tir
} // namespace tvm
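
With the pass registered above, the packed-API rewrite can be driven from Python
like any other TIR transform. Below is a minimal sketch (the tiny function body is
only illustrative); ``MakePackedAPI`` requires the ``PrimFunc`` to carry a
``global_symbol`` attribute and, when applied as a module pass, only touches
functions with the default calling convention::

    import tvm
    from tvm import te

    # Build a trivial PrimFunc by hand and wrap it into an IRModule.
    ib = tvm.tir.ir_builder.create()
    n = te.size_var("n")
    Ab = tvm.tir.decl_buffer((n,), "float32")
    A = ib.buffer_ptr(Ab)
    with ib.for_range(0, n, "i") as i:
        A[i] = A[i] + 1.0

    func = tvm.tir.PrimFunc([Ab], ib.get()).with_attr(
        "global_symbol", tvm.runtime.String("main"))
    mod = tvm.IRModule.from_expr(func)

    # Rewrite the signature to the packed calling convention (0 unpacked args).
    mod = tvm.tir.transform.MakePackedAPI(0)(mod)
    print(mod["main"])
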
*/
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
-#include <tvm/tir/ir_pass.h>
+#include <tvm/tir/transform.h>
+#include <tvm/runtime/registry.h>
#include <unordered_map>
std::unordered_map<const VarNode*, Var> vmap_;
};
-LoweredFunc
-RemapThreadAxis(LoweredFunc f, Map<PrimExpr, IterVar> thread_map) {
+
+PrimFunc RemapThreadAxis(PrimFunc&& f, Map<PrimExpr, IterVar> thread_map) {
std::unordered_map<std::string, IterVar> tmap;
for (const auto& kv : thread_map) {
const StringImmNode* str = kv.first.as<StringImmNode>();
tmap[str->value] = kv.second;
}
- CHECK_EQ(f->func_type, kDeviceFunc);
- auto n = make_object<LoweredFuncNode>(*f.operator->());
+ auto thread_axis = f->GetAttr<Array<IterVar> >(tir::attr::kDeviceThreadAxis);
+ auto* n = f.CopyOnWrite();
+
// replace the thread axis
- for (size_t i = 0; i < n->thread_axis.size(); ++i) {
- auto it = tmap.find(n->thread_axis[i]->thread_tag);
+ for (size_t i = 0; i < thread_axis.size(); ++i) {
+ auto it = tmap.find(thread_axis[i]->thread_tag);
if (it != tmap.end()) {
- n->thread_axis.Set(i, it->second);
+ thread_axis.Set(i, it->second);
}
}
- n->body = ThreadAxisRewriter(tmap).Rewrite(n->body);
- return LoweredFunc(n);
+ n->body = ThreadAxisRewriter(tmap).Rewrite(std::move(n->body));
+ return WithAttr(std::move(f), tir::attr::kDeviceThreadAxis, thread_axis);
}
+
+namespace transform {
+
+Pass RemapThreadAxis(Map<PrimExpr, IterVar> thread_map) {
+ auto pass_func = [thread_map](PrimFunc f, IRModule m, PassContext ctx) {
+ return RemapThreadAxis(std::move(f), thread_map);
+ };
+ return CreatePrimFuncPass(pass_func, 0, "tir.RemapThreadAxis", {});
+}
+
+TVM_REGISTER_GLOBAL("tir.transform.RemapThreadAxis")
+.set_body_typed(RemapThreadAxis);
+
+} // namespace transform
} // namespace tir
} // namespace tvm
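
The ``RemapThreadAxis`` rewrite above follows the same pattern: the thread axes
that previously lived on ``LoweredFuncNode::thread_axis`` now travel as a function
attribute, and the rewrite is exposed as a ``PrimFunc`` pass. A minimal sketch of
constructing the pass from Python via the registry entry added above (the thread
tags are illustrative; the pass expects device functions that already carry the
device thread-axis attribute)::

    import tvm
    from tvm import te

    # Thread tags to rewrite, mapped to their replacement IterVars.
    thread_map = tvm.runtime.convert({
        tvm.tir.StringImm("threadIdx.x"): te.thread_axis("threadIdx.y"),
        tvm.tir.StringImm("threadIdx.y"): te.thread_axis("threadIdx.x"),
    })
    # Look up the registered pass and instantiate it with the mapping.
    remap = tvm.get_global_func("tir.transform.RemapThreadAxis")(thread_map)
    # `remap` is a Pass; applying it to an IRModule rewrites matching thread
    # axes in each device PrimFunc.
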
std::string name_prefix_;
// Number of device functions.
int device_func_counter_{0};
- std::vector<LoweredFunc> device_funcs_;
std::unordered_map<const VarNode*, PrimExpr> handle_data_type_;
};
std::unordered_map<Tensor, Buffer> binds;
auto lowered_s1 = lower(s1, args1, "elemwise_add", binds, config);
auto lowered_s2 = lower(s2, args2, "elemwise_sub", binds, config);
- Map<tvm::Target, Array<LoweredFunc>> inputs = {{target_cuda, lowered_s1},
- {target_llvm, lowered_s2}};
+ Map<tvm::Target, IRModule> inputs = {{target_cuda, lowered_s1},
+ {target_llvm, lowered_s2}};
auto module = build(inputs, Target(), config);
// Assertion for build.
from tvm import te
import numpy as np
-def lower(s, args, name="mydot"):
- binds = {}
- arg_list = []
-
- for x in args:
- assert isinstance(x, te.tensor.Tensor)
- buf = tvm.tir.decl_buffer(x.shape, dtype=x.dtype, name=x.op.name)
- binds[x] = buf
- arg_list.append(buf)
- s = s.normalize()
- bounds = tvm.te.schedule.InferBound(s)
- stmt = tvm.te.schedule.ScheduleOps(s, bounds)
- stmt = tvm.tir.ir_pass.StorageFlatten(stmt, binds, 16)
- stmt = tvm.tir.ir_pass.CanonicalSimplify(stmt)
- stmt = tvm.tir.ir_pass.Simplify(stmt)
- fapi = tvm.tir.ir_pass.MakeAPI(stmt, name, arg_list, 0, True)
- fapi = tvm.tir.ir_pass.LowerTVMBuiltin(fapi)
- return fapi
-
-
-def mybuild(fapi, target="llvm"):
- return
-
def test_dot():
nn = 12
with ib.for_range(0, n - 1, "i") as i:
A[i + 1] = A[i] + 1
stmt = ib.get()
- fapi = tvm.tir.ir_pass.MakeAPI(stmt, "arange", [Ab], 0, True)
- mod = tvm.testing.LoweredFuncsToIRModule([fapi])
+
+
+ mod = tvm.testing.MakeAPILegacy(stmt, "arange", [Ab], 0, True)
mod = tvm.tir.transform.LowerTVMBuiltin()(mod)
f = tvm.target.codegen.build_module(mod, "stackvm")
a = tvm.nd.array(np.zeros(10, dtype=dtype))
elemwise_sub],
name="elemwise_sub")
- target_flist = {target_device: [lower_add], target_host: [lower_sub]}
+ target_flist = {target_device: lower_add, target_host: lower_sub}
mhost = tvm.build(target_flist, target_host=target_host)
ctx = [host_ctx, device_ctx]
mod = graph_runtime.create(graph, mhost, ctx)
elemwise_sub],
name="elemwise_sub")
- target_flist = {target_device: [lower_add0, lower_add1], target_host:
- [lower_sub]}
+ lower_add0.update(lower_add1)
+ target_flist = {target_device: lower_add0, target_host:
+ lower_sub}
mhost = tvm.build(target_flist, target_host=target_host)
ctx = [host_ctx, device_ctx]
params = {}
tvm.tir.Store(Ab.data,
tvm.tir.Load(dtype, Ab.data, i) + 1,
i + 1))
- fapi = tvm.tir.ir_pass.MakeAPI(stmt, "ramp", [Ab], 0, True)
- m = tvm.driver.build(fapi, target="llvm")
+ m = tvm.testing.MakeAPILegacy(stmt, "ramp", [Ab], 0, True)
+ m = tvm.driver.build(m, target="llvm")
for name in names:
m.save(name)
import ctypes
import math
+
def test_llvm_intrin():
ib = tvm.tir.ir_builder.create()
n = tvm.runtime.convert(4)
tvm.tir.Call(
"int32", "prefetch", args, tvm.tir.Call.Intrinsic, None, 0)))
body = ib.get()
- func = tvm.tir.ir_pass.MakeAPI(body, "prefetch", [A], 0, True)
+
+ func = tvm.testing.MakeAPILegacy(body, "prefetch", [A], 0, True)
fcode = tvm.build(func, None, "llvm")
x = tvm.tir.call_llvm_intrin("uint8x8", "llvm.ctpop.i8", tvm.tir.const(1, 'uint32'), A)
ib.emit(x)
body = ib.get()
- func = tvm.tir.ir_pass.MakeAPI(body, "ctpop", [A], 1, True)
+ func = tvm.testing.MakeAPILegacy(body, "ctpop", [A], 1, True)
fcode = tvm.build(func, None, "llvm")
f2 = tvm.lower(s, [A, B, C], name="fadd1")
f1 = tvm.lower(s, [A, B, C], name="fadd2")
m = tvm.build([f1, f2], "llvm")
- fadd1 = m['fadd1']
fadd2 = m['fadd2']
+ fadd1 = m['fadd1']
+
ctx = tvm.cpu(0)
# launch the kernel.
n = nn
tvm.testing.assert_allclose(c_.asnumpy(), (a_.asnumpy() * 2).astype('int32'))
if __name__ == "__main__":
+ test_multiple_func()
test_llvm_large_uintimm()
test_llvm_import()
test_alignment()
test_llvm_vadd_pipeline()
test_llvm_add_pipeline()
test_llvm_intrin()
- test_multiple_func()
test_llvm_flip_pipeline()
test_llvm_madd_pipeline()
test_llvm_temp_space()
import ctypes
import numpy as np
+
+def MakeAPILegacy(stmt, name, args, num_unpacked_args, noalias):
+    """Legacy adapter to create an API"""
+ f = tvm.tir.PrimFunc(args, stmt).with_attr(
+ "global_symbol", tvm.runtime.String(name))
+ f = f.with_attr("tir.is_entry_func", True)
+ if noalias:
+        f = f.with_attr("tir.noalias", True)
+ mod = tvm.IRModule.from_expr(f)
+ return tvm.tir.transform.MakePackedAPI()(mod)
+
+
def test_static_callback():
dtype = 'int64'
n = te.size_var('n')
with ib.for_range(0, n, "i", for_type="parallel") as i:
A[i] = A[i] + 1
stmt = ib.get()
- fapi = tvm.tir.ir_pass.MakeAPI(stmt, "ramp", [Ab], 0, True)
+ fapi = tvm.testing.MakeAPILegacy(stmt, "ramp", [Ab], 0, True)
f = tvm.driver.build(fapi, target="llvm")
a = tvm.nd.array(np.zeros(10, dtype=dtype))
f(a)
return sh
stmt = ib.get()
- fapi = tvm.tir.ir_pass.MakeAPI(stmt, "ramp", [Ab], 0, True)
+ fapi = tvm.testing.MakeAPILegacy(stmt, "ramp", [Ab], 0, True)
f = tvm.driver.build(fapi, target="llvm")
a = tvm.nd.array(np.zeros(10, dtype=dtype))
f(a)
s = f.get_source()
check(f)
+
+def MakeAPILegacy(stmt, name, args, num_unpacked_args, noalias):
+    """Legacy adapter to create an API"""
+ f = tvm.tir.PrimFunc(args, stmt).with_attr(
+ "global_symbol", tvm.runtime.String(name))
+ f = f.with_attr("tir.is_entry_func", True)
+ if noalias:
+        f = f.with_attr("tir.noalias", True)
+ mod = tvm.IRModule.from_expr(f)
+ return tvm.tir.transform.MakePackedAPI()(mod)
+
+
def test_stack_vm_basic():
a = tvm.nd.array(np.zeros(10, dtype='float32'))
@tvm.register_func
n = te.size_var('n')
Ab = tvm.tir.decl_buffer((n, ), "float32")
stmt = tvm.tir.Evaluate(tvm.tir.call_packed("tvm_call_back_get_shape", Ab.shape[0]))
- fapi = tvm.tir.ir_pass.MakeAPI(stmt, "print_shape", [Ab], 0, True)
+ fapi = tvm.testing.MakeAPILegacy(stmt, "print_shape", [Ab], 0, True)
run_jit(fapi, lambda f: f(a))
ib.emit(tvm.tir.call_packed("tvm_stack_vm_print", i))
stmt = ib.get()
- fapi = tvm.tir.ir_pass.MakeAPI(stmt, "ramp", [Ab], 0, True)
+ fapi = tvm.testing.MakeAPILegacy(stmt, "ramp", [Ab], 0, True)
a = tvm.nd.array(np.zeros(10, dtype=dtype))
def check(f):
f(a)
A[i + 1] = A[i] + 2
stmt = ib.get()
- fapi = tvm.tir.ir_pass.MakeAPI(stmt, "test", [Ab], 0, True)
+ fapi = tvm.testing.MakeAPILegacy(stmt, "test", [Ab], 0, True)
def check(f):
a = tvm.nd.array(np.zeros(10, dtype=dtype))
f(a)
with ib.for_range(0, n, "i", for_type="parallel") as i:
A[i] = A[i] + 1
stmt = ib.get()
- fapi = tvm.tir.ir_pass.MakeAPI(stmt, "ramp", [Ab], 0, True)
+ fapi = tvm.testing.MakeAPILegacy(stmt, "ramp", [Ab], 0, True)
def check(f):
a = tvm.nd.array(np.zeros(10, dtype=dtype))
f(a)
from tvm import te
from ctypes import *
import topi
-import tvm.tir.ir_pass as ir_pass
import numpy as np
tgt = "llvm"
Once datatype lowering is integrated directly into TVM's lower/build
process, we won't need to do this manually.
TODO(gus) integrate datatype lowering into build process; change this test"""
- flist = tvm.lower(schedule, args)
- flist = [flist]
- flist = [ir_pass.LowerCustomDatatypes(func, tgt) for func in flist]
- return tvm.build(flist[0], target=tgt)
+ mod = tvm.lower(schedule, args)
+ target = tvm.target.create(tgt)
+ mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", target))(mod)
+ mod = tvm.tir.transform.LowerCustomDatatypes()(mod)
+ return tvm.build(mod, target=tgt)
+
def test_bfloat_add_and_cast_1():
X = te.placeholder((3, ), name="X")
# specific language governing permissions and limitations
# under the License.
import tvm
+import pytest
from tvm import te
# The following DLDeviceType/TVMDeviceExtType values
# are originally defined in dlpack.h and c_runtime_api.h.
-gpu_devices = [2, 4, 7, 8, 10, 11]
-other_devices = [1, 3, 9, 12]
+gpu_devices = ["cuda", "opencl", "metal", "vulkan"]
+other_devices = ["llvm", "ext_dev"]
def lower(sch, args):
stmt = tvm.te.schedule.ScheduleOps(sch, bounds)
stmt = tvm.tir.ir_pass.LoopPartition(stmt, False)
stmt = tvm.tir.ir_pass.StorageFlatten(stmt, binds, 64)
- func = tvm.tir.ir_pass.MakeAPI(stmt, "myadd", arg_list, 0, True)
- return func
+
+ f = tvm.tir.PrimFunc(arg_list, stmt).with_attr(
+ "global_symbol", tvm.runtime.String("test"))
+ mod = tvm.IRModule({"test": f})
+ return tvm.tir.transform.MakePackedAPI()(mod)
# All computations are bound.
s[B].bind(bx, te.thread_axis("blockIdx.x"))
s[B].bind(tx, te.thread_axis("threadIdx.x"))
- func = lower(s, [A, B])
+ mod = lower(s, [A, B])
for dev_type in gpu_devices + other_devices:
- assert tvm.tir.ir_pass.VerifyMemory(func, dev_type)
+ binded_mod = tvm.tir.transform.Apply(
+ lambda f: f.with_attr("target", tvm.target.create(dev_type)))(mod)
+ tvm.tir.analysis.verify_memory(binded_mod)
+
# Computations are not bound.
# B is not bound to threads.
s = te.create_schedule(B.op)
- func = lower(s, [A, B])
+ mod = lower(s, [A, B])
for dev_type in gpu_devices:
- assert not tvm.tir.ir_pass.VerifyMemory(func, dev_type)
+ binded_mod = tvm.tir.transform.Apply(
+ lambda f: f.with_attr("target", tvm.target.create(dev_type)))(mod)
+ with pytest.raises(ValueError):
+ tvm.tir.analysis.verify_memory(binded_mod)
+
for dev_type in other_devices:
- assert tvm.tir.ir_pass.VerifyMemory(func, dev_type)
+ binded_mod = tvm.tir.transform.Apply(
+ lambda f: f.with_attr("target", tvm.target.create(dev_type)))(mod)
+ tvm.tir.analysis.verify_memory(binded_mod)
# Computations are partially bound.
s[C].bind(bx, te.thread_axis("blockIdx.x"))
s[C].bind(tx, te.thread_axis("threadIdx.x"))
- func = lower(s, [A, B, C, D])
+ mod = lower(s, [A, B, C, D])
for dev_type in gpu_devices:
- assert not tvm.tir.ir_pass.VerifyMemory(func, dev_type)
+ binded_mod = tvm.tir.transform.Apply(
+ lambda f: f.with_attr("target", tvm.target.create(dev_type)))(mod)
+ with pytest.raises(ValueError):
+ tvm.tir.analysis.verify_memory(binded_mod)
+
for dev_type in other_devices:
- assert tvm.tir.ir_pass.VerifyMemory(func, dev_type)
+ binded_mod = tvm.tir.transform.Apply(
+ lambda f: f.with_attr("target", tvm.target.create(dev_type)))(mod)
+ tvm.tir.analysis.verify_memory(binded_mod)
+
if __name__ == "__main__":
test_verify_memory_all_bind()
test_verify_memory_not_bind()
test_verify_memory_partially_bind()
-
s[B].vectorize(xi)
# build and invoke the kernel.
lowered_func = tvm.lower (s, [A, C], "llvm", simple_mode=False)
- print (lowered_func.body)
f = tvm.build(s, [A, C], "llvm")
ctx = tvm.cpu(0)
# launch the kernel.
s = te.create_schedule(T.op)
xo, xi = s[T].split(T.op.axis[0], factor=4)
lowered_func = tvm.lower (s, [A, B, T], "llvm", simple_mode=False)
- print (lowered_func.body)
ctx = tvm.cpu(0)
f = tvm.build(s, [A, B, T], "llvm")
s = te.create_schedule(T.op)
xo, xi = s[T].split(T.op.axis[0], factor=4)
lowered_func = tvm.lower (s, [A, B, T], "llvm", simple_mode=False)
- print (lowered_func.body)
ctx = tvm.cpu(0)
f = tvm.build(s, [A, B, T], "llvm")
# after instrumentation
assert_bound_instrumentation(stmt, check_attr_stmt, 2 * 3)
assert_bound_instrumentation(stmt, check_branch_stmt, 2)
- print (stmt)
+
branch_collector = list()
collect_visit(stmt, collect_branch_stmt)
assert(len(branch_collector) == 2)
- print (branch_collector[0].condition)
- print (branch_collector[1].condition)
+
def test_in_bounds_const_loop_partition_llvm():
with tvm.target.build_config(instrument_bound_checkers=True, partition_const_loop=True):
s = te.create_schedule(T.op)
xo, xi = s[T].split(T.op.axis[0], factor=4)
lowered_func = tvm.lower (s, [A, B, T], "llvm", simple_mode=False)
- print (lowered_func.body)
ctx = tvm.cpu(0)
f = tvm.build(s, [A, B, T], "llvm")
s = te.create_schedule(T.op)
xo, xi = s[T].split(T.op.axis[0], factor=4)
lowered_func = tvm.lower (s, [A, B, T], "llvm", simple_mode=False)
- print (lowered_func.body)
ctx = tvm.cpu(0)
f = tvm.build(s, [A, B, T], "llvm")
if loop_tiling:
oho, owo, ohi, owi = s[conv].tile(oh, ow, 16, 16)
lowered_func = tvm.lower(s, [data, kernel, conv], simple_mode=True)
- print (lowered_func.body)
ctx = tvm.cpu (0)
f = tvm.build(s, [data, kernel, conv], "llvm")
if loop_tiling:
oho, owo, ohi, owi = s[conv].tile(oh, ow, 16, 16)
lowered_func = tvm.lower(s, [data, kernel, conv], simple_mode=True)
- print (lowered_func.body)
ctx = tvm.cpu (0)
f = tvm.build(s, [data, kernel, conv], "llvm")
T = te.compute((m, ), lambda i: A[i]*B[i])
s = te.create_schedule(T.op)
lowered_func = tvm.lower (s, [A, B, T], "llvm", simple_mode=False)
- print (lowered_func.body)
ctx = tvm.cpu(0)
f = tvm.build(s, [A, B, T], "llvm")
T = te.compute((m, ), lambda i: A[i]*B[i])
s = te.create_schedule(T.op)
lowered_func = tvm.lower (s, [A, B, T], "llvm", simple_mode=False)
- print (lowered_func.body)
ctx = tvm.cpu(0)
f = tvm.build(s, [A, B, T], "llvm")
T = te.compute((m, m), lambda i, j: A[i][j]*B[i][j])
s = te.create_schedule(T.op)
lowered_func = tvm.lower (s, [A, B, T], "llvm", simple_mode=False)
- print (lowered_func.body)
ctx = tvm.cpu(0)
f = tvm.build(s, [A, B, T], "llvm")
T = te.compute((m, m), lambda i, j: A[i][j]*B[i][j])
s = te.create_schedule(T.op)
lowered_func = tvm.lower (s, [A, B, T], "llvm", simple_mode=False)
- print (lowered_func.body)
ctx = tvm.cpu(0)
f = tvm.build(s, [A, B, T], "llvm")
T = te.compute((m, m, m), lambda i, j, p: A[i][j][p]*B[i][j][p])
s = te.create_schedule(T.op)
lowered_func = tvm.lower (s, [A, B, T], "llvm", simple_mode=False)
- print (lowered_func.body)
+
ctx = tvm.cpu(0)
f = tvm.build(s, [A, B, T], "llvm")
T = te.compute((m, m, m), lambda i, j, p: A[i][j][p]*B[i][j][p])
s = te.create_schedule(T.op)
lowered_func = tvm.lower (s, [A, B, T], "llvm", simple_mode=False)
- print (lowered_func.body)
+
ctx = tvm.cpu(0)
f = tvm.build(s, [A, B, T], "llvm")
D = te.compute((), lambda : C + 1)
s = te.create_schedule(D.op)
stmt = tvm.lower (s, [A, scale, D], simple_mode=True)
- print (stmt)
+
# build and invoke the kernel.
f = tvm.build(s, [A, scale, D], "llvm")
ctx = tvm.cpu(0)
stmt = tvm.tir.ir_pass.Simplify(stmt)
assert isinstance(stmt.body.body, tvm.tir.Allocate)
assert stmt.body.body.extents[0].value == 2
- f = tvm.tir.ir_pass.MakeAPI(stmt, "db", [A.asobject(), C.asobject()], 2, True)
- mod = tvm.testing.LoweredFuncsToIRModule([f])
+ mod = tvm.testing.MakeAPILegacy(stmt, "db", [A.asobject(), C.asobject()], 2, True)
f = tvm.tir.transform.ThreadSync("shared")(mod)["db"]
count = [0]
## But this does the right thing.
with tvm.target.build_config(partition_const_loop=True):
- lowered_body = tvm.lower(s, [A, B]).body
+ lowered_body = tvm.lower(s, [A, B], name="x")["x"].body
def visit_stmt(op):
return(isinstance(op, tvm.tir.Max))
num_max = collect_visit(lowered_body, visit_stmt)
# Find the beginning of the Halide IR corresponding to kernel code
# and make sure it doesn't have an if statements left
- top_produce = find_top_produce(f.body)
+ top_produce = find_top_produce(f["fadd1"].body)
assert(not any(collect_visit(top_produce, lambda x: isinstance(x, tvm.tir.IfThenElse))))
# check functional correctness of generated code
stmt = tvm.tir.ir_pass.Simplify(stmt)
assert isinstance(stmt.body.body, tvm.tir.Allocate)
assert stmt.body.body.extents[0].value == 2
- f = tvm.tir.ir_pass.MakeAPI(stmt, "db", [A.asobject(), C.asobject()], 2, True)
- f = tvm.tir.ir_pass.MakeAPI(stmt, "db", [A.asobject(), C.asobject()], 2, True)
- mod = tvm.testing.LoweredFuncsToIRModule([f])
+ mod = tvm.testing.MakeAPILegacy(stmt, "db", [A.asobject(), C.asobject()], 2, True)
f = tvm.tir.transform.ThreadSync("shared")(mod)["db"]
count = [0]
ib.emit(tvm.tir.call_extern
("int32", "fadd", device_context(0), A))
body = ib.get()
- f = tvm.tir.ir_pass.MakeAPI(body, "func", [dev_type, n], 2, True)
-
- # temp adapter to convert loweredFunc to IRModule
- # to test passes in the new style.x
- mod = tvm.testing.LoweredFuncsToIRModule([f])
-
+ mod = tvm.testing.MakeAPILegacy(body, "func", [dev_type, n], 2, True)
mod = tvm.tir.transform.CombineContextCall()(mod)
assert mod["func"].body.value.dtype == "handle"
cuda_target = tvm.target.create("cuda")
assert cuda_target.thread_warp_size == 32
- f = tvm.lower(s, [A, B], name="f")
+ mod = tvm.lower(s, [A, B], name="f")
-
- mod = tvm.testing.LoweredFuncsToIRModule([f])
mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", cuda_target))(mod)
fdevice = tvm.tir.transform.SplitHostDevice()(mod)["f_kernel0"]
mod = tvm.IRModule.from_expr(fdevice)
stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, B:Bb, C:Cb}, 64)
num_unpacked_args = 2
- f = tvm.tir.ir_pass.MakeAPI(
- stmt, "myadd", [n, Ab, Bb, Cb], num_unpacked_args, True)
- assert(f.handle_data_type[Ab.data].dtype == Ab.dtype)
- assert(len(f.args) == 7)
- output_ssa = False
+ f = tvm.tir.PrimFunc([n, Ab, Bb, Cb], stmt).with_attr(
+    "tir.noalias", True).with_attr("global_symbol", tvm.runtime.String("myadd"))
+ mod = tvm.IRModule.from_expr(f)
+ f = tvm.tir.transform.MakePackedAPI(num_unpacked_args)(mod)["main"]
+ assert(len(f.params) == 7)
if __name__ == "__main__":
Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A')
A2b = tvm.tir.decl_buffer(A2.shape, A2.dtype, name='A2')
stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, A2: A2b}, 64)
- f = tvm.tir.ir_pass.MakeAPI(stmt, "test", [Ab, A2b], 0, True)
- cuda_target = tvm.target.create("cuda")
- mod = tvm.testing.LoweredFuncsToIRModule([f])
+ cuda_target = tvm.target.create("cuda")
+ mod = tvm.testing.MakeAPILegacy(stmt, "test", [Ab, A2b], 0, True)
mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", cuda_target))(mod)
fdevice = tvm.tir.transform.SplitHostDevice()(mod)["test_kernel0"]
mod = tvm.IRModule.from_expr(fdevice)
- Visitor design pattern. Otherwise, check the
`Python AST module <https://docs.python.org/3/library/ast.html>`_ to see how an AST
visitor is implemented.
-- How a HalideIR/Schedule is lowered to either a LoweredFunc class or a LLVM module. Otherwise,
+- How a Schedule is lowered to either an IRModule class or an LLVM module. Otherwise,
take a look at ``python/tvm/build_module.py`` to get some basics.
"""