* [REFACTOR] Polish the FFI convention.
- Remove src/api; keep each registration local to the C++ function it exposes.
- Remove _api_internal, as it is no longer needed.
* Update the codebase walkthrough.
src/tir/*.cc
src/driver/*.cc
src/printer/*.cc
- src/api/*.cc
+ src/support/*.cc
)
file(GLOB CODEGEN_SRCS
B = tvm.placeholder((n,), name='B')
C = tvm.compute(A.shape, lambda i: A[i] + B[i], name="C")
-Here, types of ``A``, ``B``, ``C`` are ``tvm.tensor.Tensor``, defined in ``python/tvm/tensor.py``. The Python ``Tensor`` is backed by C++ ``Tensor``, implemented in ``include/tvm/tensor.h`` and ``src/lang/tensor.cc``. All Python types in TVM can be thought of as a handle to the underlying C++ type with the same name. If you look at the definition of Python ``Tensor`` type below, you can see it is a subclass of ``Object``.
+Here, types of ``A``, ``B``, ``C`` are ``tvm.tensor.Tensor``, defined in ``python/tvm/te/tensor.py``. The Python ``Tensor`` is backed by C++ ``Tensor``, implemented in ``include/tvm/te/tensor.h`` and ``src/te/tensor.cc``. All Python types in TVM can be thought of as a handle to the underlying C++ type with the same name. If you look at the definition of Python ``Tensor`` type below, you can see it is a subclass of ``Object``.
::
The object protocol is the basis of exposing C++ types to frontend languages, including Python. The way TVM implements Python wrapping is not straightforward. It is briefly covered in `this document <https://docs.tvm.ai/dev/runtime.html#tvm-node-and-compiler-stack>`_, and details are in ``python/tvm/_ffi/`` if you are interested.
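+
+For example, you can verify that the frontend object is just a handle to the underlying C++ object (an illustrative check, using the tensor ``C`` from the snippet above):
+
+::
+
+   import tvm
+
+   # the Python Tensor handle is a tvm.runtime.Object subclass
+   print(isinstance(C, tvm.runtime.Object))  # True
+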
-``Tensor`` is created by functions in ``python/tvm/api.py``, which in turn calls into C++ functions exposed in ``src/api/api_lang.cc``. All C++ functions that are callable from Python are exposed in the ``src/api`` subdirectory. For example, the ``tvm.compute()`` function above calls into ``_ComputeOp`` API exposed in ``src/api/api_lang.cc``:
-
-::
-
- TVM_REGISTER_GLOBAL("_ComputeOp")
- .set_body([](TVMArgs args, TVMRetValue* ret) {
- *ret = ComputeOpNode::make(args[0],
- args[1],
- args[2],
- args[3],
- args[4]);
- });
-
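+``Tensor`` is created by functions in ``python/tvm/te/operation.py``, which in turn call into C++ functions exposed via the FFI. With this refactor there is no longer a central ``src/api`` directory; each C++ function callable from frontend languages is registered next to its implementation. For example, the ``tvm.compute()`` function above calls into the ``te.ComputeOp`` global function, registered in ``src/te/operation/compute_op.cc``:
+
+::
+
+   TVM_REGISTER_GLOBAL("te.ComputeOp")
+   .set_body_typed(ComputeOpNode::make);
+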
We use the ``TVM_REGISTER_*`` macro to expose C++ functions to frontend languages, in the form of a `PackedFunc <https://docs.tvm.ai/dev/runtime.html#packedfunc>`_. A ``PackedFunc`` is another mechanism by which TVM implements interoperability between C++ and Python. In particular, this is what makes calling Python functions from the C++ codebase very easy.
+You can also check out `FFI Navigator <https://github.com/tqchen/ffi-navigator>`_, which allows you to navigate between Python and C++ FFI calls.
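+
+For example, any registered global function can be looked up by name and called from Python with ordinary call syntax (an illustrative snippet):
+
+::
+
+   import tvm
+
+   # retrieve the PackedFunc registered as "tir.abs"
+   fabs = tvm.get_global_func("tir.abs")
+   # fabs is a tvm.runtime.PackedFunc, callable like any Python function
+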
-A ``Tensor`` object has an ``Operation`` object associated with it, defined in ``python/tvm/tensor.py``, ``include/tvm/operation.h``, and ``src/tvm/op`` subdirectory. A ``Tensor`` is an output of its ``Operation`` object. Each ``Operation`` object has in turn ``input_tensors()`` method, which returns a list of input ``Tensor`` to it. This way we can keep track of dependencies between ``Operation``.
+A ``Tensor`` object has an ``Operation`` object associated with it, defined in ``python/tvm/te/tensor.py``, ``include/tvm/te/operation.h``, and the ``src/te/operation`` subdirectory. A ``Tensor`` is an output of its ``Operation`` object. Each ``Operation`` object in turn has an ``input_tensors()`` method, which returns a list of its input ``Tensor``. This way we can keep track of dependencies between ``Operation``.
-We pass the operation corresponding to the output tensor ``C`` to ``tvm.create_schedule()`` function in ``python/tvm/schedule.py``.
+We pass the operation corresponding to the output tensor ``C`` to the ``tvm.create_schedule()`` function in ``python/tvm/te/schedule.py``.
::
``Stage`` corresponds to one ``Operation``. In the vector add example above, there are two placeholder ops and one compute op, so the schedule ``s`` contains three stages. Each ``Stage`` holds information about a loop nest structure, types of each loop (``Parallel``, ``Vectorized``, ``Unrolled``), and where to execute its computation in the loop nest of the next ``Stage``, if any.
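+
+An illustrative way to inspect this (assuming the vector add example above):
+
+::
+
+   s = tvm.create_schedule(C.op)
+   # one Stage per Operation: two placeholder ops and one compute op
+   print(len(s.stages))  # 3
+   print(s[C].op)        # the ComputeOp that produces C
+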
-``Schedule`` and ``Stage`` are defined in ``tvm/python/schedule.py``, ``include/tvm/schedule.h``, and ``src/schedule/schedule_ops.cc``.
+``Schedule`` and ``Stage`` are defined in ``python/tvm/te/schedule.py``, ``include/tvm/te/schedule.h``, and ``src/te/schedule/schedule_lang.cc``.
To keep it simple, we call ``tvm.build(...)`` on the default schedule created by ``create_schedule()`` function above.
target = "cuda"
fadd = tvm.build(s, [A, B, C], target)
-``tvm.build()``, defined in ``python/tvm/build_module.py``, takes a schedule, input and output ``Tensor``, and a target, and returns a ``tvm.Module`` object, defined in ``python/tvm/module.py``. A ``Module`` object contains a compiled function which can be invoked with function call syntax.
+``tvm.build()``, defined in ``python/tvm/driver/build_module.py``, takes a schedule, input and output ``Tensor``, and a target, and returns a :py:class:`tvm.runtime.Module` object. A :py:class:`tvm.runtime.Module` object contains a compiled function which can be invoked with function call syntax.
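+
+For example, with the CUDA target above, the compiled function is invoked like a regular Python function (an illustrative snippet; the array size is arbitrary since the shape is symbolic):
+
+::
+
+   import numpy as np
+
+   ctx = tvm.gpu(0)
+   a = tvm.nd.array(np.random.uniform(size=1024).astype("float32"), ctx)
+   b = tvm.nd.array(np.random.uniform(size=1024).astype("float32"), ctx)
+   c = tvm.nd.array(np.zeros(1024, dtype="float32"), ctx)
+   fadd(a, b, c)
+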
The process of ``tvm.build()`` can be divided into two steps:
stmt = schedule.ScheduleOps(sch, bounds)
...
-Bound inference is the process where all loop bounds and sizes of intermediate buffers are inferred. If you target the CUDA backend and you use shared memory, its required minimum size is automatically determined here. Bound inference is implemented in ``src/schedule/bound.cc``, ``src/schedule/graph.cc`` and ``src/schedule/message_passing.cc``. For more information on how bound inference works, see `InferBound Pass`_.
+Bound inference is the process where all loop bounds and sizes of intermediate buffers are inferred. If you target the CUDA backend and you use shared memory, its required minimum size is automatically determined here. Bound inference is implemented in ``src/te/schedule/bound.cc``, ``src/te/schedule/graph.cc`` and ``src/te/schedule/message_passing.cc``. For more information on how bound inference works, see `InferBound Pass`_.
.. _InferBound Pass: http://docs.tvm.ai/dev/inferbound.html
-``stmt``, which is the output of ``ScheduleOps()``, represents an initial loop nest structure. If you have applied ``reorder`` or ``split`` primitives to your schedule, then the initial loop nest already reflects those changes. ``ScheduleOps()`` is defined in ``src/schedule/schedule_ops.cc``.
+``stmt``, which is the output of ``ScheduleOps()``, represents an initial loop nest structure. If you have applied ``reorder`` or ``split`` primitives to your schedule, then the initial loop nest already reflects those changes. ``ScheduleOps()`` is defined in ``src/te/schedule/schedule_ops.cc``.
-Next, we apply a number of lowering passes to ``stmt``. These passes are implemented in ``src/pass`` subdirectory. For example, if you have applied ``vectorize`` or ``unroll`` primitives to your schedule, they are applied in loop vectorization and unrolling passes below.
+Next, we apply a number of lowering passes to ``stmt``. These passes are implemented in the ``src/tir/pass`` subdirectory. For example, if you have applied ``vectorize`` or ``unroll`` primitives to your schedule, they are applied in the loop vectorization and unrolling passes below.
::
After lowering is done, ``build()`` function generates target machine code from the lowered function. This code can contain SSE or AVX instructions if you target x86, or PTX instructions for CUDA target. In addition to target specific machine code, TVM also generates host side code that is responsible for memory management, kernel launch etc.
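+
+For example, with the CUDA target you can inspect both sides of the generated code (an illustrative snippet):
+
+::
+
+   dev_module = fadd.imported_modules[0]
+   print(dev_module.get_source())  # device code (CUDA source or PTX)
+   print(fadd.get_source())        # host side code
+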
-Code generation is done by ``build_module()`` function, defined in ``python/tvm/codegen.py``. On the C++ side, code generation is implemented in ``src/codegen`` subdirectory. ``build_module()`` Python function will reach ``Build()`` function below in ``src/codegen/codegen.cc``:
+Code generation is done by the ``build_module()`` function, defined in ``python/tvm/target/codegen.py``. On the C++ side, code generation is implemented in the ``src/target`` subdirectory. The ``build_module()`` Python function will reach the ``Build()`` function below in ``src/target/codegen.cc``:
::
--- a/python/tvm/_api_internal.py
+++ /dev/null
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""Namespace of internal API
-
-The functions in this namespace are automatically exported from C++ side via PackedFunc
-that is registered by "TVM_REGISTER_*" macro. This way makes calling Python functions from C++
-side very easily.
-
-Each string starts with "_" in the "TVM_REGISTER_*" macro is an internal API. You can find
-all the functions in "api_lang.cc", "api_base.cc", "api_arith.cc" and "api_ir.cc" under "src/api".
-"""
"""FFI registry to register function and objects."""
import sys
import ctypes
-from .. import _api_internal
from .base import _LIB, check_call, py_str, c_str, string_types, _FFI_MODE, _RUNTIME_ONLY
module = sys.modules[module_name]
for name in list_global_func_names():
- if prefix == "api":
- fname = name
- if name.startswith("_"):
- target_module = sys.modules["tvm._api_internal"]
- else:
- target_module = module
- else:
- if not name.startswith(prefix):
- continue
- fname = name[len(prefix)+1:]
- target_module = module
+ if not name.startswith(prefix):
+ continue
+
+ fname = name[len(prefix)+1:]
+ target_module = module
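+        # e.g. with prefix "te", the global "te.ComputeOp" is exposed as
+        # attribute "ComputeOp" on the target module; globals that do not
+        # share the prefix were skipped above.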
if fname.find(".") != -1:
continue
--- a/src/api/api_arith.cc
+++ /dev/null
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * Implementation of API functions related to arith
- * \file api_arith.cc
- */
-#include <tvm/arith/bound.h>
-#include <tvm/arith/int_set.h>
-#include <tvm/arith/pattern.h>
-#include <tvm/arith/analyzer.h>
-
-#include <tvm/tir/expr.h>
-#include <tvm/tir/expr.h>
-#include <tvm/runtime/registry.h>
-
-#include <tvm/te/tensor.h>
-
-namespace tvm {
-namespace arith {
-
-TVM_REGISTER_GLOBAL("arith.intset_single_point")
-.set_body_typed(IntSet::single_point);
-
-TVM_REGISTER_GLOBAL("arith.intset_vector")
-.set_body_typed(IntSet::vector);
-
-TVM_REGISTER_GLOBAL("arith.intset_interval")
-.set_body_typed(IntSet::interval);
-
-
-TVM_REGISTER_GLOBAL("arith.DetectLinearEquation")
-.set_body_typed(DetectLinearEquation);
-
-TVM_REGISTER_GLOBAL("arith.DetectClipBound")
-.set_body_typed(DetectClipBound);
-
-TVM_REGISTER_GLOBAL("arith.DeduceBound")
-.set_body_typed([](
- PrimExpr v, PrimExpr cond,
- const Map<Var, IntSet> hint_map,
- const Map<Var, IntSet> relax_map
-) {
- return DeduceBound(v, cond, hint_map, relax_map);
-});
-
-
-TVM_REGISTER_GLOBAL("arith.DomainTouched")
-.set_body_typed(DomainTouched);
-
-TVM_REGISTER_GLOBAL("arith.IntervalSetGetMin")
-.set_body_method(&IntSet::min);
-
-TVM_REGISTER_GLOBAL("arith.IntervalSetGetMax")
-.set_body_method(&IntSet::max);
-
-TVM_REGISTER_GLOBAL("arith.IntSetIsNothing")
-.set_body_method(&IntSet::is_nothing);
-
-TVM_REGISTER_GLOBAL("arith.IntSetIsEverything")
-.set_body_method(&IntSet::is_everything);
-
-ConstIntBound MakeConstIntBound(int64_t min_value, int64_t max_value) {
- return ConstIntBound(min_value, max_value);
-}
-
-TVM_REGISTER_GLOBAL("arith.ConstIntBound")
-.set_body_typed(MakeConstIntBound);
-
-ModularSet MakeModularSet(int64_t coeff, int64_t base) {
- return ModularSet(coeff, base);
-}
-
-TVM_REGISTER_GLOBAL("arith.ModularSet")
-.set_body_typed(MakeModularSet);
-
-TVM_REGISTER_GLOBAL("arith.CreateAnalyzer")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- using runtime::PackedFunc;
- using runtime::TypedPackedFunc;
- auto self = std::make_shared<Analyzer>();
- auto f = [self](std::string name) -> PackedFunc {
- if (name == "const_int_bound") {
- return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
- *ret = self->const_int_bound(args[0]);
- });
- } else if (name == "modular_set") {
- return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
- *ret = self->modular_set(args[0]);
- });
- } else if (name == "const_int_bound_update") {
- return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
- self->const_int_bound.Update(args[0], args[1], args[2]);
- });
- } else if (name == "Simplify") {
- return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
- *ret = self->Simplify(args[0]);
- });
- } else if (name == "rewrite_simplify") {
- return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
- *ret = self->rewrite_simplify(args[0]);
- });
- } else if (name == "canonical_simplify") {
- return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
- *ret = self->canonical_simplify(args[0]);
- });
- } else if (name == "int_set") {
- return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
- *ret = self->int_set(args[0], args[1]);
- });
- } else if (name == "bind") {
- return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
- if (args[1].IsObjectRef<Range>()) {
- self->Bind(args[0], args[1].operator Range());
- } else {
- self->Bind(args[0], args[1].operator PrimExpr());
- }
- });
- } else if (name == "enter_constraint_context") {
- return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
- // can't use make_shared due to noexcept(false) decl in destructor,
- // see https://stackoverflow.com/a/43907314
- auto ctx = std::shared_ptr<With<ConstraintContext> >(
- new With<ConstraintContext>(self.get(), args[0]));
- auto fexit = [ctx](TVMArgs, TVMRetValue*) mutable {
- ctx.reset();
- };
- *ret = PackedFunc(fexit);
- });
- }
- return PackedFunc();
- };
- *ret = TypedPackedFunc<PackedFunc(std::string)>(f);
-});
-
-} // namespace arith
-} // namespace tvm
--- a/src/api/api_ir.cc
+++ /dev/null
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * Implementation of API functions related to IR build
- * \file api_ir.cc
- */
-#include <tvm/tir/expr.h>
-#include <tvm/tir/expr.h>
-#include <tvm/runtime/registry.h>
-
-#include <tvm/tir/op.h>
-
-namespace tvm {
-namespace tir {
-
-TVM_REGISTER_GLOBAL("tir.Var")
-.set_body_typed([](std::string s, DataType t) {
- return Var(s, t);
- });
-
-TVM_REGISTER_GLOBAL("tir.SizeVar")
-.set_body_typed([](std::string s, DataType t) {
- return SizeVar(s, t);
- });
-
-TVM_REGISTER_GLOBAL("tir.abs")
-.set_body_typed(tvm::abs);
-
-TVM_REGISTER_GLOBAL("tir.isnan")
-.set_body_typed(tvm::isnan);
-
-TVM_REGISTER_GLOBAL("tir.floor")
-.set_body_typed(tvm::floor);
-
-TVM_REGISTER_GLOBAL("tir.ceil")
-.set_body_typed(tvm::ceil);
-
-TVM_REGISTER_GLOBAL("tir.round")
-.set_body_typed(tvm::round);
-
-TVM_REGISTER_GLOBAL("tir.nearbyint")
-.set_body_typed(tvm::nearbyint);
-
-TVM_REGISTER_GLOBAL("tir.trunc")
-.set_body_typed(tvm::trunc);
-
-TVM_REGISTER_GLOBAL("tir._cast")
-.set_body_typed(tvm::cast);
-
-TVM_REGISTER_GLOBAL("ir.range_by_min_extent")
-.set_body_typed(Range::make_by_min_extent);
-
-
-TVM_REGISTER_GLOBAL("tir.SeqStmt")
-.set_body_typed([](Array<Stmt> seq) {
- return SeqStmt(std::move(seq));
-});
-
-TVM_REGISTER_GLOBAL("tir.For")
-.set_body_typed([](
- Var loop_var, PrimExpr min, PrimExpr extent,
- int for_type, int device_api, Stmt body) {
- return ForNode::make(loop_var,
- min,
- extent,
- static_cast<ForType>(for_type),
- static_cast<DeviceAPI>(device_api),
- body);
-});
-
-TVM_REGISTER_GLOBAL("tir.Load")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
- DataType t = args[0];
- if (args.size() == 3) {
- *ret = LoadNode::make(t, args[1], args[2], const_true(t.lanes()));
- } else {
- *ret = LoadNode::make(t, args[1], args[2], args[3]);
- }
- });
-
-TVM_REGISTER_GLOBAL("tir.Store")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
- PrimExpr value = args[1];
- if (args.size() == 3) {
- *ret = StoreNode::make(args[0], value, args[2], const_true(value.dtype().lanes()));
- } else {
- *ret = StoreNode::make(args[0], value, args[2], args[3]);
- }
- });
-
-TVM_REGISTER_GLOBAL("tir.Realize")
-.set_body_typed(RealizeNode::make);
-
-TVM_REGISTER_GLOBAL("tir.Call")
-.set_body_typed([](
- DataType type, std::string name,
- Array<PrimExpr> args, int call_type,
- FunctionRef func, int value_index
-) {
- return CallNode::make(type,
- name,
- args,
- static_cast<CallNode::CallType>(call_type),
- func,
- value_index);
-});
-
-TVM_REGISTER_GLOBAL("tir.CommReducer")
-.set_body_typed(CommReducerNode::make);
-
-// make from two arguments
-#define REGISTER_MAKE(NodeName) \
- TVM_REGISTER_GLOBAL("tir."#NodeName) \
- .set_body_typed(NodeName ## Node::make); \
-
-
-REGISTER_MAKE(Reduce);
-REGISTER_MAKE(AttrStmt);
-
-REGISTER_MAKE(StringImm);
-
-REGISTER_MAKE(Add);
-REGISTER_MAKE(Sub);
-REGISTER_MAKE(Mul);
-REGISTER_MAKE(Div);
-REGISTER_MAKE(Mod);
-REGISTER_MAKE(FloorDiv);
-REGISTER_MAKE(FloorMod);
-REGISTER_MAKE(Min);
-REGISTER_MAKE(Max);
-REGISTER_MAKE(EQ);
-REGISTER_MAKE(NE);
-REGISTER_MAKE(LT);
-REGISTER_MAKE(LE);
-REGISTER_MAKE(GT);
-REGISTER_MAKE(GE);
-REGISTER_MAKE(And);
-REGISTER_MAKE(Or);
-
-REGISTER_MAKE(Not);
-REGISTER_MAKE(Select);
-REGISTER_MAKE(Ramp);
-REGISTER_MAKE(Cast);
-REGISTER_MAKE(Broadcast);
-REGISTER_MAKE(Shuffle);
-REGISTER_MAKE(Let);
-REGISTER_MAKE(LetStmt);
-REGISTER_MAKE(AssertStmt);
-REGISTER_MAKE(ProducerConsumer);
-REGISTER_MAKE(Provide);
-REGISTER_MAKE(Prefetch);
-REGISTER_MAKE(Free);
-REGISTER_MAKE(IfThenElse);
-REGISTER_MAKE(Evaluate);
-
-// overloaded, needs special handling
-// has default args
-TVM_REGISTER_GLOBAL("tir.Allocate")
- .set_body_typed([](
- Var buffer_var, DataType type, Array<PrimExpr> extents, PrimExpr condition, Stmt body
- ){
- return AllocateNode::make(buffer_var, type, extents, condition, body);
- });
-
-// operator overloading, smarter than make
-#define REGISTER_MAKE_BINARY_OP(Node, Func) \
- TVM_REGISTER_GLOBAL("tir."#Node) \
- .set_body_typed([](PrimExpr a, PrimExpr b) { \
- return (Func(a, b)); \
- })
-
-#define REGISTER_MAKE_BIT_OP(Node, Func) \
- TVM_REGISTER_GLOBAL("tir."#Node) \
- .set_body([](TVMArgs args, TVMRetValue *ret) { \
- bool lhs_is_int = args[0].type_code() == kDLInt; \
- bool rhs_is_int = args[1].type_code() == kDLInt; \
- if (lhs_is_int) { \
- *ret = (Func(args[0].operator int(), args[1].operator PrimExpr())); \
- } else if (rhs_is_int) { \
- *ret = (Func(args[0].operator PrimExpr(), args[1].operator int())); \
- } else { \
- *ret = (Func(args[0].operator PrimExpr(), args[1].operator PrimExpr())); \
- } \
- })
-
-
-REGISTER_MAKE_BINARY_OP(_OpAdd, operator+);
-REGISTER_MAKE_BINARY_OP(_OpSub, operator-);
-REGISTER_MAKE_BINARY_OP(_OpMul, operator*);
-REGISTER_MAKE_BINARY_OP(_OpDiv, div);
-REGISTER_MAKE_BINARY_OP(_OpMod, truncmod);
-REGISTER_MAKE_BINARY_OP(_OpIndexDiv, indexdiv);
-REGISTER_MAKE_BINARY_OP(_OpIndexMod, indexmod);
-REGISTER_MAKE_BINARY_OP(_OpFloorDiv, floordiv);
-REGISTER_MAKE_BINARY_OP(_OpFloorMod, floormod);
-REGISTER_MAKE_BINARY_OP(_OpTruncDiv, truncdiv);
-REGISTER_MAKE_BINARY_OP(_OpTruncMod, truncmod);
-REGISTER_MAKE_BINARY_OP(_OpPow, pow);
-REGISTER_MAKE_BINARY_OP(_OpMin, min);
-REGISTER_MAKE_BINARY_OP(_OpMax, max);
-REGISTER_MAKE_BINARY_OP(_OpEQ, operator==);
-REGISTER_MAKE_BINARY_OP(_OpNE, operator!=);
-REGISTER_MAKE_BINARY_OP(_OpLT, operator<); // NOLINT(*)
-REGISTER_MAKE_BINARY_OP(_OpLE, operator<=); // NOLINT(*)
-REGISTER_MAKE_BINARY_OP(_OpGT, operator>); // NOLINT(*)
-REGISTER_MAKE_BINARY_OP(_OpGE, operator>=);
-REGISTER_MAKE_BINARY_OP(_OpAnd, operator&&);
-REGISTER_MAKE_BINARY_OP(_OpOr, operator||);
-REGISTER_MAKE_BIT_OP(bitwise_and, operator&);
-REGISTER_MAKE_BIT_OP(bitwise_or, operator|);
-REGISTER_MAKE_BIT_OP(bitwise_xor, operator^);
-REGISTER_MAKE_BIT_OP(left_shift, operator<<); // NOLINT(*)
-REGISTER_MAKE_BIT_OP(right_shift, operator>>);
-TVM_REGISTER_GLOBAL("tir._OpIfThenElse")
-.set_body_typed([] (PrimExpr cond, PrimExpr true_value, PrimExpr false_value) {
- return if_then_else(cond, true_value, false_value);
-});
-
-} // namespace tir
-} // namespace tvm
--- a/src/api/api_lang.cc
+++ /dev/null
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * Implementation of API functions related to Higher DSL build.
- * \file api_lang.cc
- */
-#include <tvm/runtime/registry.h>
-#include <tvm/tir/expr.h>
-#include <tvm/te/tensor.h>
-#include <tvm/te/operation.h>
-#include <tvm/tir/buffer.h>
-#include <tvm/te/schedule.h>
-#include <tvm/runtime/registry.h>
-
-#include <tvm/driver/driver_api.h>
-#include <tvm/tir/data_layout.h>
-
-namespace tvm {
-
-TVM_REGISTER_GLOBAL("tir.min_value")
-.set_body_typed(min_value);
-
-TVM_REGISTER_GLOBAL("tir.max_value")
-.set_body_typed(max_value);
-
-TVM_REGISTER_GLOBAL("ir.Range")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- *ret = Range(args[0], args[1]);
- });
-
-namespace tir {
-TVM_REGISTER_GLOBAL("tir.IterVar")
-.set_body_typed([](Range dom, Var var, int iter_type, std::string thread_tag) {
- return IterVarNode::make(
- dom, var,
- static_cast<IterVarType>(iter_type),
- thread_tag);
-});
-}
-
-namespace te {
-TVM_REGISTER_GLOBAL("te.Tensor")
-.set_body_typed(TensorNode::make);
-
-TVM_REGISTER_GLOBAL("te.TensorIntrin")
-.set_body_typed(TensorIntrinNode::make);
-
-TVM_REGISTER_GLOBAL("te.TensorIntrinCall")
-.set_body_typed(TensorIntrinCallNode::make);
-
-TVM_REGISTER_GLOBAL("te.TensorEqual")
-.set_body_method(&Tensor::operator==);
-
-TVM_REGISTER_GLOBAL("te.TensorHash")
-.set_body_typed([](Tensor tensor) -> int64_t {
- return static_cast<int64_t>(std::hash<Tensor>()(tensor));
- });
-
-TVM_REGISTER_GLOBAL("te.Placeholder")
-.set_body_typed([](Array<PrimExpr> shape, DataType dtype, std::string name) {
- return placeholder(shape, dtype, name);
-});
-
-TVM_REGISTER_GLOBAL("te.ComputeOp")
-.set_body_typed(ComputeOpNode::make);
-
-TVM_REGISTER_GLOBAL("te.ScanOp")
-.set_body_typed(ScanOpNode::make);
-
-TVM_REGISTER_GLOBAL("te.TensorComputeOp")
-.set_body_typed(TensorComputeOpNode::make);
-
-TVM_REGISTER_GLOBAL("te.ExternOp")
-.set_body_typed(ExternOpNode::make);
-
-TVM_REGISTER_GLOBAL("te.HybridOp")
-.set_body_typed(HybridOpNode::make);
-
-TVM_REGISTER_GLOBAL("te.OpGetOutput")
-.set_body_typed([](Operation op, int64_t output) {
- return op.output(static_cast<size_t>(output));
-});
-
-TVM_REGISTER_GLOBAL("te.OpNumOutputs")
-.set_body_method<Operation>(&OperationNode::num_outputs);
-
-TVM_REGISTER_GLOBAL("te.OpInputTensors")
-.set_body_method<Operation>(&OperationNode::InputTensors);
-
-TVM_REGISTER_GLOBAL("te.CreateSchedule")
-.set_body_typed(create_schedule);
-
-TVM_REGISTER_GLOBAL("te.StageSetScope")
-.set_body_method(&Stage::set_scope);
-
-TVM_REGISTER_GLOBAL("te.StageBind")
-.set_body_method(&Stage::bind);
-
-TVM_REGISTER_GLOBAL("te.StageSplitByFactor")
-.set_body_typed([](Stage stage, IterVar parent, PrimExpr factor) {
- IterVar outer, inner;
- stage.split(parent, factor, &outer, &inner);
- return Array<IterVar>({outer, inner});
-});
-
-TVM_REGISTER_GLOBAL("te.StageSplitByNParts")
-.set_body_typed([](Stage stage, IterVar parent, PrimExpr nparts) {
- IterVar outer, inner;
- stage.split_by_nparts(parent, nparts, &outer, &inner);
- return Array<IterVar>({outer, inner});
-});
-
-TVM_REGISTER_GLOBAL("te.StageFuse")
-.set_body_typed([](Stage stage, Array<IterVar> axes) {
- IterVar fused;
- stage.fuse(axes, &fused);
- return fused;
- });
-
-TVM_REGISTER_GLOBAL("te.StageComputeAt")
-.set_body_method(&Stage::compute_at);
-
-TVM_REGISTER_GLOBAL("te.StageComputeInline")
-.set_body_method(&Stage::compute_inline);
-
-TVM_REGISTER_GLOBAL("te.StageComputeRoot")
-.set_body_method(&Stage::compute_root);
-
-TVM_REGISTER_GLOBAL("te.StageReorder")
-.set_body_method(&Stage::reorder);
-
-TVM_REGISTER_GLOBAL("te.StageTile")
-.set_body_typed([](
- Stage stage,
- IterVar x_parent, IterVar y_parent,
- PrimExpr x_factor, PrimExpr y_factor
-) {
- IterVar x_outer, y_outer, x_inner, y_inner;
- stage.tile(x_parent, y_parent,
- x_factor, y_factor,
- &x_outer, &y_outer,
- &x_inner, &y_inner);
- return Array<IterVar>({x_outer, y_outer, x_inner, y_inner});
- });
-
-TVM_REGISTER_GLOBAL("te.StageEnvThreads")
-.set_body_method(&Stage::env_threads);
-
-TVM_REGISTER_GLOBAL("te.StageSetStorePredicate")
-.set_body_method(&Stage::set_store_predicate);
-
-TVM_REGISTER_GLOBAL("te.StageUnroll")
-.set_body_method(&Stage::unroll);
-
-TVM_REGISTER_GLOBAL("te.StageVectorize")
-.set_body_method(&Stage::vectorize);
-
-TVM_REGISTER_GLOBAL("te.StageTensorize")
-.set_body_method(&Stage::tensorize);
-
-TVM_REGISTER_GLOBAL("te.StageParallel")
-.set_body_method(&Stage::parallel);
-
-TVM_REGISTER_GLOBAL("te.StagePragma")
-.set_body_method(&Stage::pragma);
-
-TVM_REGISTER_GLOBAL("te.StagePrefetch")
-.set_body_method(&Stage::prefetch);
-
-TVM_REGISTER_GLOBAL("te.StageStorageAlign")
-.set_body_method(&Stage::storage_align);
-
-TVM_REGISTER_GLOBAL("te.StageDoubleBuffer")
-.set_body_method(&Stage::double_buffer);
-
-TVM_REGISTER_GLOBAL("te.StageOpenGL")
-.set_body_method(&Stage::opengl);
-
-TVM_REGISTER_GLOBAL("te.ScheduleNormalize")
-.set_body_method(&Schedule::normalize);
-
-TVM_REGISTER_GLOBAL("te.ScheduleCreateGroup")
-.set_body_method(&Schedule::create_group);
-
-TVM_REGISTER_GLOBAL("te.ScheduleCacheRead")
-.set_body_method(&Schedule::cache_read);
-
-TVM_REGISTER_GLOBAL("te.ScheduleCacheWrite")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- if (args[1].IsObjectRef<Tensor>()) {
- *ret = args[0].operator Schedule()
- .cache_write(args[1].operator Tensor(), args[2]);
- } else {
- *ret = args[0].operator Schedule()
- .cache_write(args[1].operator Array<Tensor>(), args[2]);
- }
- });
-
-TVM_REGISTER_GLOBAL("te.ScheduleRFactor")
-.set_body_method(&Schedule::rfactor);
-} // namespace te
-
-TVM_REGISTER_GLOBAL("te.CommReducerCombine")
-.set_body_method<tir::CommReducer>(&tir::CommReducerNode::operator());
-
-} // namespace tvm
--- a/src/api/api_schedule.cc
+++ /dev/null
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * Implementation of API functions related to schedule pass.
- * \file api_schedule.cc
- */
-#include <tvm/tir/expr.h>
-#include <tvm/te/tensor.h>
-#include <tvm/te/schedule.h>
-#include <tvm/te/schedule_pass.h>
-#include <tvm/runtime/registry.h>
-
-#include "../te/schedule/graph.h"
-
-namespace tvm {
-namespace te {
-
-TVM_REGISTER_GLOBAL("schedule.AutoInlineElemWise")
-.set_body_typed(AutoInlineElemWise);
-
-
-TVM_REGISTER_GLOBAL("schedule.AutoInlineInjective")
-.set_body_typed(AutoInlineInjective);
-
-TVM_REGISTER_GLOBAL("schedule.ScheduleOps")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- if (args.size() == 2)
- *ret = ScheduleOps(args[0], args[1], false);
- else
- *ret = ScheduleOps(args[0], args[1], args[2]);
-});
-
-#define REGISTER_SCHEDULE_PASS(PassName) \
- TVM_REGISTER_GLOBAL("schedule."#PassName) \
- .set_body_typed(PassName); \
-
-
-REGISTER_SCHEDULE_PASS(InferBound);
-REGISTER_SCHEDULE_PASS(CreateReadGraph);
-REGISTER_SCHEDULE_PASS(PostDFSOrder);
-REGISTER_SCHEDULE_PASS(CreateAttachPath);
-REGISTER_SCHEDULE_PASS(ScanGetBody);
-REGISTER_SCHEDULE_PASS(ScanFixPointAnalysis);
-
-} // namespace te
-} // namespace tvm
/*!
* \file tvm/arith/analyzer.cc
*/
+#include <tvm/runtime/registry.h>
#include <tvm/tir/expr.h>
#include <tvm/arith/analyzer.h>
#include <tvm/tir/op.h>
return res;
}
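+// "arith.CreateAnalyzer" exposes the Analyzer as a closure: the returned
+// PackedFunc maps a method name (e.g. "Simplify") to another PackedFunc
+// bound to one shared Analyzer instance.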
+TVM_REGISTER_GLOBAL("arith.CreateAnalyzer")
+.set_body([](TVMArgs args, TVMRetValue* ret) {
+ using runtime::PackedFunc;
+ using runtime::TypedPackedFunc;
+ auto self = std::make_shared<Analyzer>();
+ auto f = [self](std::string name) -> PackedFunc {
+ if (name == "const_int_bound") {
+ return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
+ *ret = self->const_int_bound(args[0]);
+ });
+ } else if (name == "modular_set") {
+ return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
+ *ret = self->modular_set(args[0]);
+ });
+ } else if (name == "const_int_bound_update") {
+ return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
+ self->const_int_bound.Update(args[0], args[1], args[2]);
+ });
+ } else if (name == "Simplify") {
+ return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
+ *ret = self->Simplify(args[0]);
+ });
+ } else if (name == "rewrite_simplify") {
+ return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
+ *ret = self->rewrite_simplify(args[0]);
+ });
+ } else if (name == "canonical_simplify") {
+ return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
+ *ret = self->canonical_simplify(args[0]);
+ });
+ } else if (name == "int_set") {
+ return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
+ *ret = self->int_set(args[0], args[1]);
+ });
+ } else if (name == "bind") {
+ return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
+ if (args[1].IsObjectRef<Range>()) {
+ self->Bind(args[0], args[1].operator Range());
+ } else {
+ self->Bind(args[0], args[1].operator PrimExpr());
+ }
+ });
+ } else if (name == "enter_constraint_context") {
+ return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
+ // can't use make_shared due to noexcept(false) decl in destructor,
+ // see https://stackoverflow.com/a/43907314
+ auto ctx = std::shared_ptr<With<ConstraintContext> >(
+ new With<ConstraintContext>(self.get(), args[0]));
+ auto fexit = [ctx](TVMArgs, TVMRetValue*) mutable {
+ ctx.reset();
+ };
+ *ret = PackedFunc(fexit);
+ });
+ }
+ return PackedFunc();
+ };
+ *ret = TypedPackedFunc<PackedFunc(std::string)>(f);
+});
+
} // namespace arith
} // namespace tvm
* \file bound_deducer.cc
* \brief Utility to deduce bound of expression
*/
+#include <tvm/runtime/registry.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/ir_pass.h>
#include <tvm/tir/expr_functor.h>
#include <tvm/arith/analyzer.h>
-#include <tvm/runtime/registry.h>
#include <unordered_set>
#include <unordered_map>
return DeduceBound(v, e, hmap, rmap);
}
+
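+// DeduceBound is overloaded; the lambda selects the Map-based overload
+// at the FFI boundary.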
+TVM_REGISTER_GLOBAL("arith.DeduceBound")
+.set_body_typed([](
+ PrimExpr v, PrimExpr cond,
+ const Map<Var, IntSet> hint_map,
+ const Map<Var, IntSet> relax_map
+) {
+ return DeduceBound(v, cond, hint_map, relax_map);
+});
+
+
} // namespace arith
} // namespace tvm
/*!
* \file tvm/arith/const_int_bound.cc
*/
+#include <tvm/runtime/registry.h>
#include <tvm/arith/analyzer.h>
#include <tvm/tir/expr_functor.h>
#include <algorithm>
data_ = std::move(node);
}
+ConstIntBound MakeConstIntBound(int64_t min_value, int64_t max_value) {
+ return ConstIntBound(min_value, max_value);
+}
+
+TVM_REGISTER_GLOBAL("arith.ConstIntBound")
+.set_body_typed(MakeConstIntBound);
+
inline void PrintBoundValue(std::ostream& os, int64_t val) {
if (val == ConstIntBound::kPosInf) {
os << "pos_inf";
* \file detect_linear_equation.cc
* \brief Utility to detect patterns in the expression.
*/
+#include <tvm/runtime/registry.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/ir_pass.h>
#include <tvm/tir/expr_functor.h>
return ret;
}
+TVM_REGISTER_GLOBAL("arith.DetectLinearEquation")
+.set_body_typed(DetectLinearEquation);
+TVM_REGISTER_GLOBAL("arith.DetectClipBound")
+.set_body_typed([](const PrimExpr& e, const Array<Var>& vars) {
+ return DetectClipBound(e, vars);
+});
} // namespace arith
} // namespace tvm
return FuncTouchedDomain(tensor, consider_calls, consider_provides).Find(stmt);
}
+TVM_REGISTER_GLOBAL("arith.DomainTouched")
+.set_body_typed(DomainTouched);
+
} // namespace arith
} // namespace tvm
<< "[" << op->min_value << ", "
<< op->max_value << ']';
});
+
+
+TVM_REGISTER_GLOBAL("arith.intset_single_point")
+.set_body_typed(IntSet::single_point);
+
+TVM_REGISTER_GLOBAL("arith.intset_vector")
+.set_body_typed(IntSet::vector);
+
+TVM_REGISTER_GLOBAL("arith.intset_interval")
+.set_body_typed(IntSet::interval);
+
+TVM_REGISTER_GLOBAL("arith.IntervalSetGetMin")
+.set_body_method(&IntSet::min);
+
+TVM_REGISTER_GLOBAL("arith.IntervalSetGetMax")
+.set_body_method(&IntSet::max);
+
+TVM_REGISTER_GLOBAL("arith.IntSetIsNothing")
+.set_body_method(&IntSet::is_nothing);
+
+TVM_REGISTER_GLOBAL("arith.IntSetIsEverything")
+.set_body_method(&IntSet::is_everything);
+
} // namespace arith
} // namespace tvm
* \file modular_set.cc
* \brief Modular set analysis
*/
+#include <tvm/runtime/registry.h>
#include <tvm/arith/analyzer.h>
#include <tvm/tir/op.h>
#include <tvm/tir/expr_functor.h>
<< op->base << ')';
});
+ModularSet MakeModularSet(int64_t coeff, int64_t base) {
+ return ModularSet(coeff, base);
+}
+
+TVM_REGISTER_GLOBAL("arith.ModularSet")
+.set_body_typed(MakeModularSet);
// internal entry for const int bound
struct ModularSetAnalyzer::Entry {
return Range(make_object<RangeNode>(min, extent));
}
+TVM_REGISTER_GLOBAL("ir.range_by_min_extent")
+.set_body_typed(Range::make_by_min_extent);
+
+TVM_REGISTER_GLOBAL("ir.Range")
+.set_body([](TVMArgs args, TVMRetValue* ret) {
+ *ret = Range(args[0], args[1]);
+ });
+
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<RangeNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const RangeNode*>(node.get());
*/
/*!
- * Code mainly used for test purposes.
- * \file api_test.cc
+ * FFI registration code used for frontend testing purposes.
+ * \file ffi_testing.cc
*/
+#include <tvm/runtime/registry.h>
#include <tvm/tir/expr.h>
#include <tvm/te/tensor.h>
#include <tvm/ir/attrs.h>
-#include <tvm/runtime/registry.h>
#include <tvm/ir/env_func.h>
namespace tvm {
* \brief Compute Op.
* \file compute_op.cc
*/
+#include <tvm/runtime/registry.h>
#include <tvm/te/operation.h>
#include <tvm/arith/analyzer.h>
#include <tvm/tir/expr.h>
return Operation(n);
}
+TVM_REGISTER_GLOBAL("te.ComputeOp")
+.set_body_typed(ComputeOpNode::make);
+
+
// The schedule related logics
Array<Tensor> ComputeOpNode::InputTensors() const {
Array<Tensor> ret;
* \brief External computation rule.
* \file extern_op.cc
*/
+#include <tvm/runtime/registry.h>
#include <tvm/te/operation.h>
#include <tvm/arith/analyzer.h>
#include <tvm/tir/expr.h>
return Operation(n);
}
+TVM_REGISTER_GLOBAL("te.ExternOp")
+.set_body_typed(ExternOpNode::make);
+
+
Array<Tensor> ExternOpNode::InputTensors() const {
return inputs;
}
* \brief Hybrid computation rule.
* \file hybrid_op.cc
*/
+#include <tvm/runtime/registry.h>
#include <tvm/te/operation.h>
#include <tvm/arith/analyzer.h>
#include <tvm/tir/expr.h>
return res;
}
+TVM_REGISTER_GLOBAL("te.HybridOp")
+.set_body_typed(HybridOpNode::make);
+
+
Array<Tensor> HybridOpNode::InputTensors() const {
// Because input tensors could be potentially inlined into hybrid scripts,
// we need to check if all input tensors are used in the body.
* \brief Placeholder op.
* \file placeholder_op.cc
*/
+#include <tvm/runtime/registry.h>
#include <tvm/te/operation.h>
namespace tvm {
return PlaceholderOpNode::make(name, shape, dtype).output(0);
}
+TVM_REGISTER_GLOBAL("te.Placeholder")
+.set_body_typed([](Array<PrimExpr> shape, DataType dtype, std::string name) {
+ return placeholder(shape, dtype, name);
+});
+
Array<Tensor> PlaceholderOpNode::InputTensors() const {
return {};
}
* \brief Scan Operator.
* \file scan_op.cc
*/
+#include <tvm/runtime/registry.h>
#include <tvm/te/operation.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/ir_pass.h>
return Operation(n);
}
+TVM_REGISTER_GLOBAL("te.ScanOp")
+.set_body_typed(ScanOpNode::make);
+
+
Array<Tensor> scan(Array<Tensor> init,
Array<Tensor> update,
Array<Tensor> state_placeholder,
* \brief Tensor Compute Op.
* \file tensor_compute_op.cc
*/
+#include <tvm/runtime/registry.h>
#include <tvm/te/operation.h>
#include <tvm/arith/analyzer.h>
#include <tvm/tir/expr.h>
return Operation(n);
}
+TVM_REGISTER_GLOBAL("te.TensorComputeOp")
+.set_body_typed(TensorComputeOpNode::make);
+
+
Array<Tensor> TensorComputeOpNode::InputTensors() const {
return inputs;
}
/*!
* \file auto_inline_elem_wise.cc
*/
+#include <tvm/runtime/registry.h>
#include <tvm/te/schedule_pass.h>
#include <tvm/te/operation.h>
#include <tvm/tir/expr_functor.h>
}
}
+TVM_REGISTER_GLOBAL("schedule.AutoInlineElemWise")
+.set_body_typed(AutoInlineElemWise);
+
+
+TVM_REGISTER_GLOBAL("schedule.AutoInlineInjective")
+.set_body_typed(AutoInlineInjective);
+
} // namespace te
} // namespace tvm
* \file bound.cc
* \brief The bound inference logic.
*/
+#include <tvm/runtime/registry.h>
#include <tvm/te/schedule_pass.h>
#include <tvm/te/operation.h>
#include <tvm/tir/ir_pass.h>
return Map<IterVar, Range>(ret.begin(), ret.end());
}
+TVM_REGISTER_GLOBAL("schedule.InferBound")
+.set_body_typed(InferBound);
+
} // namespace te
} // namespace tvm
* \file graph.cc
* \brief Utilities to get information about schedule graph.
*/
+#include <tvm/runtime/registry.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/te/operation.h>
return ret;
}
+
+TVM_REGISTER_GLOBAL("schedule.CreateReadGraph")
+.set_body_typed(CreateReadGraph);
+
+TVM_REGISTER_GLOBAL("schedule.PostDFSOrder")
+.set_body_typed([](const Array<Operation>& roots,
+ const ReadGraph& g) {
+ return PostDFSOrder(roots, g);
+});
+
+TVM_REGISTER_GLOBAL("schedule.CreateAttachPath")
+.set_body_typed(CreateAttachPath);
+
+TVM_REGISTER_GLOBAL("schedule.ScanGetBody")
+.set_body_typed(ScanGetBody);
+
+TVM_REGISTER_GLOBAL("schedule.ScanFixPointAnalysis")
+.set_body_typed(ScanFixPointAnalysis);
+
} // namespace te
} // namespace tvm
/*!
* \file schedule_lang.cc
*/
+#include <tvm/runtime/registry.h>
#include <tvm/te/schedule.h>
#include <tvm/te/operation.h>
#include <unordered_set>
auto* op = static_cast<const ScheduleNode*>(node.get());
p->stream << "schedule(" << op << ")";
});
+
+
+TVM_REGISTER_GLOBAL("te.CreateSchedule")
+.set_body_typed(create_schedule);
+
+TVM_REGISTER_GLOBAL("te.StageSetScope")
+.set_body_method(&Stage::set_scope);
+
+TVM_REGISTER_GLOBAL("te.StageBind")
+.set_body_method(&Stage::bind);
+
+TVM_REGISTER_GLOBAL("te.StageSplitByFactor")
+.set_body_typed([](Stage stage, IterVar parent, PrimExpr factor) {
+ IterVar outer, inner;
+ stage.split(parent, factor, &outer, &inner);
+ return Array<IterVar>({outer, inner});
+});
+
+TVM_REGISTER_GLOBAL("te.StageSplitByNParts")
+.set_body_typed([](Stage stage, IterVar parent, PrimExpr nparts) {
+ IterVar outer, inner;
+ stage.split_by_nparts(parent, nparts, &outer, &inner);
+ return Array<IterVar>({outer, inner});
+});
+
+TVM_REGISTER_GLOBAL("te.StageFuse")
+.set_body_typed([](Stage stage, Array<IterVar> axes) {
+ IterVar fused;
+ stage.fuse(axes, &fused);
+ return fused;
+ });
+
+TVM_REGISTER_GLOBAL("te.StageComputeAt")
+.set_body_method(&Stage::compute_at);
+
+TVM_REGISTER_GLOBAL("te.StageComputeInline")
+.set_body_method(&Stage::compute_inline);
+
+TVM_REGISTER_GLOBAL("te.StageComputeRoot")
+.set_body_method(&Stage::compute_root);
+
+TVM_REGISTER_GLOBAL("te.StageReorder")
+.set_body_method(&Stage::reorder);
+
+TVM_REGISTER_GLOBAL("te.StageTile")
+.set_body_typed([](
+ Stage stage,
+ IterVar x_parent, IterVar y_parent,
+ PrimExpr x_factor, PrimExpr y_factor
+) {
+ IterVar x_outer, y_outer, x_inner, y_inner;
+ stage.tile(x_parent, y_parent,
+ x_factor, y_factor,
+ &x_outer, &y_outer,
+ &x_inner, &y_inner);
+ return Array<IterVar>({x_outer, y_outer, x_inner, y_inner});
+ });
+
+TVM_REGISTER_GLOBAL("te.StageEnvThreads")
+.set_body_method(&Stage::env_threads);
+
+TVM_REGISTER_GLOBAL("te.StageSetStorePredicate")
+.set_body_method(&Stage::set_store_predicate);
+
+TVM_REGISTER_GLOBAL("te.StageUnroll")
+.set_body_method(&Stage::unroll);
+
+TVM_REGISTER_GLOBAL("te.StageVectorize")
+.set_body_method(&Stage::vectorize);
+
+TVM_REGISTER_GLOBAL("te.StageTensorize")
+.set_body_method(&Stage::tensorize);
+
+TVM_REGISTER_GLOBAL("te.StageParallel")
+.set_body_method(&Stage::parallel);
+
+TVM_REGISTER_GLOBAL("te.StagePragma")
+.set_body_method(&Stage::pragma);
+
+TVM_REGISTER_GLOBAL("te.StagePrefetch")
+.set_body_method(&Stage::prefetch);
+
+TVM_REGISTER_GLOBAL("te.StageStorageAlign")
+.set_body_method(&Stage::storage_align);
+
+TVM_REGISTER_GLOBAL("te.StageDoubleBuffer")
+.set_body_method(&Stage::double_buffer);
+
+TVM_REGISTER_GLOBAL("te.StageOpenGL")
+.set_body_method(&Stage::opengl);
+
+TVM_REGISTER_GLOBAL("te.ScheduleNormalize")
+.set_body_method(&Schedule::normalize);
+
+TVM_REGISTER_GLOBAL("te.ScheduleCreateGroup")
+.set_body_method(&Schedule::create_group);
+
+TVM_REGISTER_GLOBAL("te.ScheduleCacheRead")
+.set_body_method(&Schedule::cache_read);
+
+TVM_REGISTER_GLOBAL("te.ScheduleCacheWrite")
+.set_body([](TVMArgs args, TVMRetValue* ret) {
+ if (args[1].IsObjectRef<Tensor>()) {
+ *ret = args[0].operator Schedule()
+ .cache_write(args[1].operator Tensor(), args[2]);
+ } else {
+ *ret = args[0].operator Schedule()
+ .cache_write(args[1].operator Array<Tensor>(), args[2]);
+ }
+ });
+
+TVM_REGISTER_GLOBAL("te.ScheduleRFactor")
+.set_body_method(&Schedule::rfactor);
} // namespace te
} // namespace tvm
/*!
* \file schedule_ops.cc
*/
+#include <tvm/runtime/registry.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/ir_pass.h>
#include <tvm/tir/stmt_functor.h>
return post_proc(std::move(body));
}
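+// The third (boolean debug) argument is optional from the frontend;
+// default it to false when only (schedule, bounds) are passed.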
+TVM_REGISTER_GLOBAL("schedule.ScheduleOps")
+.set_body([](TVMArgs args, TVMRetValue* ret) {
+ if (args.size() == 2)
+ *ret = ScheduleOps(args[0], args[1], false);
+ else
+ *ret = ScheduleOps(args[0], args[1], args[2]);
+});
+
} // namespace te
} // namespace tvm
/*!
* \file tensor.cc
*/
+#include <tvm/runtime/registry.h>
#include <tvm/te/tensor.h>
#include <tvm/te/operation.h>
#include <tvm/te/tensor_intrin.h>
TVM_REGISTER_NODE_TYPE(TensorIntrinCallNode);
+TVM_REGISTER_GLOBAL("te.Tensor")
+.set_body_typed(TensorNode::make);
+
+TVM_REGISTER_GLOBAL("te.TensorIntrin")
+.set_body_typed(TensorIntrinNode::make);
+
+TVM_REGISTER_GLOBAL("te.TensorIntrinCall")
+.set_body_typed(TensorIntrinCallNode::make);
+
+TVM_REGISTER_GLOBAL("te.TensorEqual")
+.set_body_method(&Tensor::operator==);
+
+TVM_REGISTER_GLOBAL("te.TensorHash")
+.set_body_typed([](Tensor tensor) -> int64_t {
+ return static_cast<int64_t>(std::hash<Tensor>()(tensor));
+ });
+
+TVM_REGISTER_GLOBAL("te.OpGetOutput")
+.set_body_typed([](Operation op, int64_t output) {
+ return op.output(static_cast<size_t>(output));
+});
+
+TVM_REGISTER_GLOBAL("te.OpNumOutputs")
+.set_body_method<Operation>(&OperationNode::num_outputs);
+
+TVM_REGISTER_GLOBAL("te.OpInputTensors")
+.set_body_method<Operation>(&OperationNode::InputTensors);
+
} // namespace te
} // namespace tvm
/*!
* \file expr.cc
*/
+#include <tvm/runtime/registry.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt.h>
#include <tvm/tir/op.h>
SizeVarNode::SizeVarNode(DataType t, std::string name_hint)
: VarNode(t, std::move(name_hint)) {}
+
+TVM_REGISTER_GLOBAL("tir.Var")
+.set_body_typed([](std::string s, DataType t) {
+ return Var(s, t);
+ });
+
+TVM_REGISTER_GLOBAL("tir.SizeVar")
+.set_body_typed([](std::string s, DataType t) {
+ return SizeVar(s, t);
+ });
+
IterVar IterVarNode::make(Range dom,
Var var,
IterVarType t,
return IterVar(n);
}
+TVM_REGISTER_GLOBAL("tir.IterVar")
+.set_body_typed([](Range dom, Var var, int iter_type, std::string thread_tag) {
+ return IterVarNode::make(
+ dom, var,
+ static_cast<IterVarType>(iter_type),
+ thread_tag);
+});
+
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<IterVarNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const IterVarNode*>(node.get());
return PrimExpr(node);
}
+TVM_REGISTER_GLOBAL("tir.StringImm")
+.set_body_typed(StringImmNode::make);
+
PrimExpr CastNode::make(DataType t, PrimExpr value) {
CHECK(value.defined());
CHECK_EQ(t.lanes(), value.dtype().lanes());
});
}
+TVM_REGISTER_GLOBAL("tir.CommReducer")
+.set_body_typed(CommReducerNode::make);
+
+TVM_REGISTER_GLOBAL("tir.CommReducerCombine")
+.set_body_method<tir::CommReducer>(&tir::CommReducerNode::operator());
+
+
PrimExpr ReduceNode::make(CommReducer combiner, Array<PrimExpr> source,
Array<IterVar> axis, PrimExpr condition, int value_index) {
for (size_t i = 0; i < axis.size(); ++i) {
return PrimExpr(n);
}
+
+TVM_REGISTER_GLOBAL("tir.Reduce")
+.set_body_typed(ReduceNode::make);
+
+
PrimExpr AnyNode::make() {
auto n = make_object<AnyNode>();
return PrimExpr(n);
TVM_REGISTER_NODE_TYPE(ReduceNode);
TVM_REGISTER_NODE_TYPE(AnyNode);
+
+TVM_REGISTER_GLOBAL("tir.Add")
+.set_body_typed(AddNode::make);
+
+TVM_REGISTER_GLOBAL("tir.Sub")
+.set_body_typed(SubNode::make);
+
+TVM_REGISTER_GLOBAL("tir.Mul")
+.set_body_typed(MulNode::make);
+
+TVM_REGISTER_GLOBAL("tir.Div")
+.set_body_typed(DivNode::make);
+
+TVM_REGISTER_GLOBAL("tir.Mod")
+.set_body_typed(ModNode::make);
+
+TVM_REGISTER_GLOBAL("tir.FloorDiv")
+.set_body_typed(FloorDivNode::make);
+
+TVM_REGISTER_GLOBAL("tir.FloorMod")
+.set_body_typed(FloorModNode::make);
+
+TVM_REGISTER_GLOBAL("tir.Min")
+.set_body_typed(MinNode::make);
+
+TVM_REGISTER_GLOBAL("tir.Max")
+.set_body_typed(MaxNode::make);
+
+TVM_REGISTER_GLOBAL("tir.EQ")
+.set_body_typed(EQNode::make);
+
+TVM_REGISTER_GLOBAL("tir.NE")
+.set_body_typed(NENode::make);
+
+TVM_REGISTER_GLOBAL("tir.LT")
+.set_body_typed(LTNode::make);
+
+TVM_REGISTER_GLOBAL("tir.LE")
+.set_body_typed(LENode::make);
+
+TVM_REGISTER_GLOBAL("tir.GT")
+.set_body_typed(GTNode::make);
+
+TVM_REGISTER_GLOBAL("tir.GE")
+.set_body_typed(GENode::make);
+
+TVM_REGISTER_GLOBAL("tir.And")
+.set_body_typed(AndNode::make);
+
+TVM_REGISTER_GLOBAL("tir.Or")
+.set_body_typed(OrNode::make);
+
+TVM_REGISTER_GLOBAL("tir.Not")
+.set_body_typed(NotNode::make);
+
+TVM_REGISTER_GLOBAL("tir.Select")
+.set_body_typed(SelectNode::make);
+
+TVM_REGISTER_GLOBAL("tir.Ramp")
+.set_body_typed(RampNode::make);
+
+TVM_REGISTER_GLOBAL("tir.Cast")
+.set_body_typed(CastNode::make);
+
+TVM_REGISTER_GLOBAL("tir.Broadcast")
+.set_body_typed(BroadcastNode::make);
+
+TVM_REGISTER_GLOBAL("tir.Shuffle")
+.set_body_typed(ShuffleNode::make);
+
+TVM_REGISTER_GLOBAL("tir.Let")
+.set_body_typed(LetNode::make);
+
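+// Load takes an optional predicate; when it is omitted, default to an
+// all-true predicate with matching lanes.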
+TVM_REGISTER_GLOBAL("tir.Load")
+.set_body([](TVMArgs args, TVMRetValue *ret) {
+ DataType t = args[0];
+ if (args.size() == 3) {
+ *ret = LoadNode::make(t, args[1], args[2], const_true(t.lanes()));
+ } else {
+ *ret = LoadNode::make(t, args[1], args[2], args[3]);
+ }
+ });
+
+
+
+TVM_REGISTER_GLOBAL("tir.Call")
+.set_body_typed([](
+ DataType type, std::string name,
+ Array<PrimExpr> args, int call_type,
+ FunctionRef func, int value_index
+) {
+ return CallNode::make(type,
+ name,
+ args,
+ static_cast<CallNode::CallType>(call_type),
+ func,
+ value_index);
+});
+
} // namespace tir
} // namespace tvm
TVM_REGISTER_GLOBAL("node.String")
.set_body_typed(tir::StringImmNode::make);
+TVM_REGISTER_GLOBAL("tir.min_value")
+.set_body_typed(min_value);
+
+TVM_REGISTER_GLOBAL("tir.max_value")
+.set_body_typed(max_value);
+
+TVM_REGISTER_GLOBAL("tir.abs")
+.set_body_typed(tvm::abs);
+
+TVM_REGISTER_GLOBAL("tir.isnan")
+.set_body_typed(tvm::isnan);
+
+TVM_REGISTER_GLOBAL("tir.floor")
+.set_body_typed(tvm::floor);
+
+TVM_REGISTER_GLOBAL("tir.ceil")
+.set_body_typed(tvm::ceil);
+
+TVM_REGISTER_GLOBAL("tir.round")
+.set_body_typed(tvm::round);
+
+TVM_REGISTER_GLOBAL("tir.nearbyint")
+.set_body_typed(tvm::nearbyint);
+
+TVM_REGISTER_GLOBAL("tir.trunc")
+.set_body_typed(tvm::trunc);
+
+TVM_REGISTER_GLOBAL("tir._cast")
+.set_body_typed(tvm::cast);
+
+
+
+// Operator overloading: route through the C++ operators, which apply
+// type promotion and constant folding, unlike the raw Node::make constructors.
+#define REGISTER_MAKE_BINARY_OP(Node, Func) \
+ TVM_REGISTER_GLOBAL("tir."#Node) \
+ .set_body_typed([](PrimExpr a, PrimExpr b) { \
+ return (Func(a, b)); \
+ })
+
+#define REGISTER_MAKE_BIT_OP(Node, Func) \
+ TVM_REGISTER_GLOBAL("tir."#Node) \
+ .set_body([](TVMArgs args, TVMRetValue *ret) { \
+ bool lhs_is_int = args[0].type_code() == kDLInt; \
+ bool rhs_is_int = args[1].type_code() == kDLInt; \
+ if (lhs_is_int) { \
+ *ret = (Func(args[0].operator int(), args[1].operator PrimExpr())); \
+ } else if (rhs_is_int) { \
+ *ret = (Func(args[0].operator PrimExpr(), args[1].operator int())); \
+ } else { \
+ *ret = (Func(args[0].operator PrimExpr(), args[1].operator PrimExpr())); \
+ } \
+ })
+
+
+REGISTER_MAKE_BINARY_OP(_OpAdd, operator+);
+REGISTER_MAKE_BINARY_OP(_OpSub, operator-);
+REGISTER_MAKE_BINARY_OP(_OpMul, operator*);
+REGISTER_MAKE_BINARY_OP(_OpDiv, div);
+REGISTER_MAKE_BINARY_OP(_OpMod, truncmod);
+REGISTER_MAKE_BINARY_OP(_OpIndexDiv, indexdiv);
+REGISTER_MAKE_BINARY_OP(_OpIndexMod, indexmod);
+REGISTER_MAKE_BINARY_OP(_OpFloorDiv, floordiv);
+REGISTER_MAKE_BINARY_OP(_OpFloorMod, floormod);
+REGISTER_MAKE_BINARY_OP(_OpTruncDiv, truncdiv);
+REGISTER_MAKE_BINARY_OP(_OpTruncMod, truncmod);
+REGISTER_MAKE_BINARY_OP(_OpPow, pow);
+REGISTER_MAKE_BINARY_OP(_OpMin, min);
+REGISTER_MAKE_BINARY_OP(_OpMax, max);
+REGISTER_MAKE_BINARY_OP(_OpEQ, operator==);
+REGISTER_MAKE_BINARY_OP(_OpNE, operator!=);
+REGISTER_MAKE_BINARY_OP(_OpLT, operator<); // NOLINT(*)
+REGISTER_MAKE_BINARY_OP(_OpLE, operator<=); // NOLINT(*)
+REGISTER_MAKE_BINARY_OP(_OpGT, operator>); // NOLINT(*)
+REGISTER_MAKE_BINARY_OP(_OpGE, operator>=);
+REGISTER_MAKE_BINARY_OP(_OpAnd, operator&&);
+REGISTER_MAKE_BINARY_OP(_OpOr, operator||);
+REGISTER_MAKE_BIT_OP(bitwise_and, operator&);
+REGISTER_MAKE_BIT_OP(bitwise_or, operator|);
+REGISTER_MAKE_BIT_OP(bitwise_xor, operator^);
+REGISTER_MAKE_BIT_OP(left_shift, operator<<); // NOLINT(*)
+REGISTER_MAKE_BIT_OP(right_shift, operator>>);
+
+TVM_REGISTER_GLOBAL("tir._OpIfThenElse")
+.set_body_typed([] (PrimExpr cond, PrimExpr true_value, PrimExpr false_value) {
+ return if_then_else(cond, true_value, false_value);
+});
} // namespace tvm
/*!
* \file tvm/tir/stmt.cc
*/
-
+#include <tvm/runtime/registry.h>
#include <tvm/tir/stmt.h>
#include <tvm/tir/ir_pass.h>
#include "../pass/ir_util.h"
return Stmt(node);
}
+TVM_REGISTER_GLOBAL("tir.LetStmt")
+.set_body_typed(LetStmtNode::make);
+
Stmt AttrStmtNode::make(ObjectRef node,
std::string attr_key,
PrimExpr value,
return Stmt(n);
}
+TVM_REGISTER_GLOBAL("tir.AttrStmt")
+.set_body_typed(AttrStmtNode::make);
+
+
Stmt AssertStmtNode::make(PrimExpr condition, PrimExpr message, Stmt body) {
CHECK(condition.defined());
CHECK(message.dtype() == DataType::Int(32) ||
return Stmt(node);
}
+TVM_REGISTER_GLOBAL("tir.AssertStmt")
+.set_body_typed(AssertStmtNode::make);
+
+
Stmt ProducerConsumerNode::make(FunctionRef func, bool is_producer, Stmt body) {
CHECK(body.defined());
return Stmt(node);
}
+TVM_REGISTER_GLOBAL("tir.ProducerConsumer")
+.set_body_typed(ProducerConsumerNode::make);
+
+
Stmt ForNode::make(Var loop_var,
PrimExpr min,
PrimExpr extent,
return Stmt(node);
}
+TVM_REGISTER_GLOBAL("tir.For")
+.set_body_typed([](
+ Var loop_var, PrimExpr min, PrimExpr extent,
+ int for_type, int device_api, Stmt body) {
+ return ForNode::make(loop_var,
+ min,
+ extent,
+ static_cast<ForType>(for_type),
+ static_cast<DeviceAPI>(device_api),
+ body);
+});
+
+
Stmt StoreNode::make(Var buffer_var, PrimExpr value, PrimExpr index, PrimExpr predicate) {
CHECK(value.defined());
CHECK(index.defined());
return Stmt(node);
}
+
+TVM_REGISTER_GLOBAL("tir.Store")
+.set_body([](TVMArgs args, TVMRetValue *ret) {
+ PrimExpr value = args[1];
+ if (args.size() == 3) {
+ *ret = StoreNode::make(args[0], value, args[2], const_true(value.dtype().lanes()));
+ } else {
+ *ret = StoreNode::make(args[0], value, args[2], args[3]);
+ }
+ });
+
+
Stmt ProvideNode::make(FunctionRef func, int value_index, PrimExpr value, Array<PrimExpr> args) {
CHECK(value_index >=0 && value_index < func->num_outputs())
<< "value index output function return value bound";
return Stmt(node);
}
+TVM_REGISTER_GLOBAL("tir.Provide")
+.set_body_typed(ProvideNode::make);
+
+
Stmt AllocateNode::make(Var buffer_var,
DataType dtype,
Array<PrimExpr> extents,
return Stmt(node);
}
+// AllocateNode::make is overloaded and has default arguments,
+// so expose it through a fixed-signature lambda.
+TVM_REGISTER_GLOBAL("tir.Allocate")
+.set_body_typed([](
+ Var buffer_var, DataType type, Array<PrimExpr> extents, PrimExpr condition, Stmt body
+ ){
+ return AllocateNode::make(buffer_var, type, extents, condition, body);
+});
+
int32_t AllocateNode::constant_allocation_size(const Array<PrimExpr>& extents) {
int64_t result = 1;
for (size_t i = 0; i < extents.size(); ++i) {
return Stmt(node);
}
+TVM_REGISTER_GLOBAL("tir.Free")
+.set_body_typed(FreeNode::make);
+
+
Stmt RealizeNode::make(FunctionRef func,
- int value_index,
- DataType dtype,
- Region bounds,
- PrimExpr condition,
- Stmt body) {
+ int value_index,
+ DataType dtype,
+ Region bounds,
+ PrimExpr condition,
+ Stmt body) {
for (size_t i = 0; i < bounds.size(); ++i) {
CHECK(bounds[i]->min.defined());
CHECK(bounds[i]->extent.defined());
return Stmt(node);
}
+
+TVM_REGISTER_GLOBAL("tir.Realize")
+.set_body_typed(RealizeNode::make);
+
+
Stmt PrefetchNode::make(FunctionRef func, int value_index, DataType dtype, Region bounds) {
for (size_t i = 0; i < bounds.size(); ++i) {
CHECK(bounds[i]->min.defined());
return Stmt(node);
}
+TVM_REGISTER_GLOBAL("tir.Prefetch")
+.set_body_typed(PrefetchNode::make);
+
+
SeqStmt::SeqStmt(Array<Stmt> seq) {
auto node = make_object<SeqStmtNode>();
node->seq = std::move(seq);
data_ = std::move(node);
}
+TVM_REGISTER_GLOBAL("tir.SeqStmt")
+.set_body_typed([](Array<Stmt> seq) {
+ return SeqStmt(std::move(seq));
+});
+
Stmt IfThenElseNode::make(PrimExpr condition, Stmt then_case, Stmt else_case) {
CHECK(condition.defined());
CHECK(then_case.defined());
return Stmt(node);
}
+TVM_REGISTER_GLOBAL("tir.IfThenElse")
+.set_body_typed(IfThenElseNode::make);
+
+
Stmt EvaluateNode::make(PrimExpr value) {
CHECK(value.defined());
return Stmt(node);
}
+TVM_REGISTER_GLOBAL("tir.Evaluate")
+.set_body_typed(EvaluateNode::make);
+
// Printers
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
/*!
* Exposure of pass functions.
- * \file api_pass.cc
+ * \file ffi_api.cc
*/
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt.h>
// make from two arguments
#define REGISTER_PASS(PassName) \
- TVM_REGISTER_GLOBAL("ir_pass."#PassName) \
- .set_body_typed(PassName); \
+ TVM_REGISTER_GLOBAL("ir_pass."#PassName) \
+ .set_body_typed(PassName); \
REGISTER_PASS(ConvertSSA);
except tvm.error.OpNotImplemented as e:
msg = str(e)
assert isinstance(e, NotImplementedError)
- assert msg.find("api_test.cc") != -1
+ assert msg.find("ffi_testing.cc") != -1
fchk_eq = tvm.testing.test_check_eq_callback(
"InternalError: myop")
assert False
except tvm.error.InternalError as e:
msg = str(e)
- assert msg.find("api_test.cc") != -1
+ assert msg.find("ffi_testing.cc") != -1
try:
tvm.testing.ErrorTest(0, 1)
assert False
except ValueError as e:
msg = str(e)
- assert msg.find("api_test.cc") != -1
+ assert msg.find("ffi_testing.cc") != -1
def test_deep_callback():