[REFACTOR] Polish ffi convention. (#4912)
authorTianqi Chen <tqchen@users.noreply.github.com>
Wed, 19 Feb 2020 21:14:42 +0000 (13:14 -0800)
committerGitHub <noreply@github.com>
Wed, 19 Feb 2020 21:14:42 +0000 (13:14 -0800)
* [REFACTOR] Polish ffi convention.

- Remove the src/api, keep registration local to the c++ function.
- Remove the api_internal as it is no longer needed.

* Update the codebase walk through

36 files changed:
CMakeLists.txt
docs/dev/codebase_walkthrough.rst
python/tvm/_api_internal.py [deleted file]
python/tvm/_ffi/registry.py
src/api/api_arith.cc [deleted file]
src/api/api_ir.cc [deleted file]
src/api/api_lang.cc [deleted file]
src/api/api_pass.cc [deleted file]
src/api/api_schedule.cc [deleted file]
src/api/api_test.cc [deleted file]
src/arith/analyzer.cc
src/arith/bound_deducer.cc
src/arith/const_int_bound.cc
src/arith/detect_linear_equation.cc
src/arith/domain_touched.cc
src/arith/int_set.cc
src/arith/modular_set.cc
src/ir/expr.cc
src/support/ffi_testing.cc [new file with mode: 0644]
src/te/operation/compute_op.cc
src/te/operation/extern_op.cc
src/te/operation/hybrid_op.cc
src/te/operation/placeholder_op.cc
src/te/operation/scan_op.cc
src/te/operation/tensor_compute_op.cc
src/te/schedule/auto_inline_elem_wise.cc
src/te/schedule/bound.cc
src/te/schedule/graph.cc
src/te/schedule/schedule_lang.cc
src/te/schedule/schedule_ops.cc
src/te/tensor.cc
src/tir/ir/expr.cc
src/tir/ir/op.cc
src/tir/ir/stmt.cc
src/tir/pass/ffi_api.cc [new file with mode: 0644]
tests/python/unittest/test_runtime_error.py

index 8540a661f99d2985fef993d69db9d83188c2f2db..9d25e4a9ba583fd8226c59dc61e7156244719659 100644 (file)
@@ -133,7 +133,7 @@ file(GLOB_RECURSE COMPILER_SRCS
     src/tir/*.cc
     src/driver/*.cc
     src/printer/*.cc
-    src/api/*.cc
+    src/support/*.cc
     )
 
 file(GLOB CODEGEN_SRCS
index 0732c26f0c5883ff8c8cd7be5979d1133004f176..8513ce5bd89d27dcd32f1dafaa6e7602861582b2 100644 (file)
@@ -55,7 +55,7 @@ We use a simple example that uses the low level TVM API directly. The example is
    B = tvm.placeholder((n,), name='B')
    C = tvm.compute(A.shape, lambda i: A[i] + B[i], name="C")
 
-Here, types of ``A``, ``B``, ``C`` are ``tvm.tensor.Tensor``, defined in ``python/tvm/tensor.py``. The Python ``Tensor`` is backed by C++ ``Tensor``, implemented in ``include/tvm/tensor.h`` and ``src/lang/tensor.cc``. All Python types in TVM can be thought of as a handle to the underlying C++ type with the same name. If you look at the definition of Python ``Tensor`` type below, you can see it is a subclass of ``Object``.
+Here, types of ``A``, ``B``, ``C`` are ``tvm.tensor.Tensor``, defined in ``python/tvm/te/tensor.py``. The Python ``Tensor`` is backed by C++ ``Tensor``, implemented in ``include/tvm/te/tensor.h`` and ``src/te/tensor.cc``. All Python types in TVM can be thought of as a handle to the underlying C++ type with the same name. If you look at the definition of Python ``Tensor`` type below, you can see it is a subclass of ``Object``.
 
 ::
 
@@ -68,24 +68,12 @@ Here, types of ``A``, ``B``, ``C`` are ``tvm.tensor.Tensor``, defined in ``pytho
 
 The object protocol is the basis of exposing C++ types to frontend languages, including Python. The way TVM implements Python wrapping is not straightforward. It is briefly covered in `this document <https://docs.tvm.ai/dev/runtime.html#tvm-node-and-compiler-stack>`_, and details are in ``python/tvm/_ffi/`` if you are interested.
 
-``Tensor`` is created by functions in ``python/tvm/api.py``, which in turn calls into C++ functions exposed in ``src/api/api_lang.cc``. All C++ functions that are callable from Python are exposed in the ``src/api`` subdirectory. For example, the ``tvm.compute()`` function above calls into ``_ComputeOp`` API exposed in ``src/api/api_lang.cc``:
-
-::
-
-   TVM_REGISTER_GLOBAL("_ComputeOp")
-   .set_body([](TVMArgs args,  TVMRetValue* ret) {
-       *ret = ComputeOpNode::make(args[0],
-                                  args[1],
-                                  args[2],
-                                  args[3],
-                                  args[4]);
-     });
-
 We use the ``TVM_REGISTER_*`` macro to expose C++ functions to frontend languages, in the form of a `PackedFunc <https://docs.tvm.ai/dev/runtime.html#packedfunc>`_. A ``PackedFunc`` is another mechanism by which TVM implements interoperability between C++ and Python. In particular, this is what makes calling Python functions from the C++ codebase very easy.
+You can also checkout `FFI Navigator <https://github.com/tqchen/ffi-navigator>`_ which allows you to navigate between python and c++ FFI calls.
 
-A ``Tensor`` object has an ``Operation`` object associated with it, defined in ``python/tvm/tensor.py``, ``include/tvm/operation.h``, and ``src/tvm/op`` subdirectory. A ``Tensor`` is an output of its ``Operation`` object. Each ``Operation`` object has in turn ``input_tensors()`` method, which returns a list of input ``Tensor`` to it. This way we can keep track of dependencies between ``Operation``.
+A ``Tensor`` object has an ``Operation`` object associated with it, defined in ``python/tvm/te/tensor.py``, ``include/tvm/te/operation.h``, and ``src/tvm/te/operation`` subdirectory. A ``Tensor`` is an output of its ``Operation`` object. Each ``Operation`` object has in turn ``input_tensors()`` method, which returns a list of input ``Tensor`` to it. This way we can keep track of dependencies between ``Operation``.
 
-We pass the operation corresponding to the output tensor ``C`` to ``tvm.create_schedule()`` function in ``python/tvm/schedule.py``.
+We pass the operation corresponding to the output tensor ``C`` to ``tvm.create_schedule()`` function in ``python/tvm/te/schedule.py``.
 
 ::
 
@@ -103,7 +91,7 @@ This function is mapped to the C++ function in ``include/tvm/schedule.h``.
 
 ``Stage`` corresponds to one ``Operation``. In the vector add example above, there are two placeholder ops and one compute op, so the schedule ``s`` contains three stages. Each ``Stage`` holds information about a loop nest structure, types of each loop (``Parallel``, ``Vectorized``, ``Unrolled``), and where to execute its computation in the loop nest of the next ``Stage``, if any.
 
-``Schedule`` and ``Stage`` are defined in ``tvm/python/schedule.py``, ``include/tvm/schedule.h``, and ``src/schedule/schedule_ops.cc``.
+``Schedule`` and ``Stage`` are defined in ``tvm/python/te/schedule.py``, ``include/tvm/te/schedule.h``, and ``src/te/schedule/schedule_ops.cc``.
 
 To keep it simple, we call ``tvm.build(...)`` on the default schedule created by ``create_schedule()`` function above.
 
@@ -112,7 +100,7 @@ To keep it simple, we call ``tvm.build(...)`` on the default schedule created by
    target = "cuda"
    fadd = tvm.build(s, [A, B, C], target)
 
-``tvm.build()``, defined in ``python/tvm/build_module.py``, takes a schedule, input and output ``Tensor``, and a target, and returns a ``tvm.Module`` object, defined in ``python/tvm/module.py``. A ``Module`` object contains a compiled function which can be invoked with function call syntax.
+``tvm.build()``, defined in ``python/tvm/driver/build_module.py``, takes a schedule, input and output ``Tensor``, and a target, and returns a :py:class:`tvm.runtime.Module` object. A :py:class:`tvm.runtime.Module` object contains a compiled function which can be invoked with function call syntax.
 
 The process of ``tvm.build()`` can be divided into two steps:
 
@@ -133,14 +121,14 @@ Lowering is done by ``tvm.lower()`` function, defined in ``python/tvm/build_modu
       stmt = schedule.ScheduleOps(sch, bounds)
       ...
 
-Bound inference is the process where all loop bounds and sizes of intermediate buffers are inferred. If you target the CUDA backend and you use shared memory, its required minimum size is automatically determined here. Bound inference is implemented in ``src/schedule/bound.cc``, ``src/schedule/graph.cc`` and ``src/schedule/message_passing.cc``. For more information on how bound inference works, see `InferBound Pass`_.
+Bound inference is the process where all loop bounds and sizes of intermediate buffers are inferred. If you target the CUDA backend and you use shared memory, its required minimum size is automatically determined here. Bound inference is implemented in ``src/te/schedule/bound.cc``, ``src/te/schedule/graph.cc`` and ``src/te/schedule/message_passing.cc``. For more information on how bound inference works, see `InferBound Pass`_.
 
 .. _InferBound Pass: http://docs.tvm.ai/dev/inferbound.html
 
 
-``stmt``, which is the output of ``ScheduleOps()``, represents an initial loop nest structure. If you have applied ``reorder`` or ``split`` primitives to your schedule, then the initial loop nest already reflects those changes. ``ScheduleOps()`` is defined in ``src/schedule/schedule_ops.cc``.
+``stmt``, which is the output of ``ScheduleOps()``, represents an initial loop nest structure. If you have applied ``reorder`` or ``split`` primitives to your schedule, then the initial loop nest already reflects those changes. ``ScheduleOps()`` is defined in ``src/te/schedule/schedule_ops.cc``.
 
-Next, we apply a number of lowering passes to ``stmt``. These passes are implemented in ``src/pass`` subdirectory. For example, if you have applied ``vectorize`` or ``unroll`` primitives to your schedule, they are applied in loop vectorization and unrolling passes below.
+Next, we apply a number of lowering passes to ``stmt``. These passes are implemented in ``src/tir/pass`` subdirectory. For example, if you have applied ``vectorize`` or ``unroll`` primitives to your schedule, they are applied in loop vectorization and unrolling passes below.
 
 ::
 
@@ -157,7 +145,7 @@ Next, we apply a number of lowering passes to ``stmt``. These passes are impleme
 
 After lowering is done, ``build()`` function generates target machine code from the lowered function. This code can contain SSE or AVX instructions if you target x86, or PTX instructions for CUDA target. In addition to target specific machine code, TVM also generates host side code that is responsible for memory management, kernel launch etc.
 
-Code generation is done by ``build_module()`` function, defined in ``python/tvm/codegen.py``. On the C++ side, code generation is implemented in ``src/codegen`` subdirectory. ``build_module()`` Python function will reach ``Build()`` function below in ``src/codegen/codegen.cc``:
+Code generation is done by ``build_module()`` function, defined in ``python/tvm/target/codegen.py``. On the C++ side, code generation is implemented in ``src/target/codegen`` subdirectory. ``build_module()`` Python function will reach ``Build()`` function below in ``src/target/codegen/codegen.cc``:
 
 ::
 
diff --git a/python/tvm/_api_internal.py b/python/tvm/_api_internal.py
deleted file mode 100644 (file)
index 5715237..0000000
+++ /dev/null
@@ -1,25 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""Namespace of internal API
-
-The functions in this namespace are automatically exported from C++ side via PackedFunc
-that is registered by "TVM_REGISTER_*" macro. This way makes calling Python functions from C++
-side very easily.
-
-Each string starts with "_" in the "TVM_REGISTER_*" macro is an internal API. You can find
-all the functions in "api_lang.cc", "api_base.cc", "api_arith.cc" and "api_ir.cc" under "src/api".
-"""
index be1578550a3b6c401fda0680de1b936d27b4c51b..e4b8b18b480580c8db8753c456f4cd5f59ee4255 100644 (file)
@@ -19,7 +19,6 @@
 """FFI registry to register function and objects."""
 import sys
 import ctypes
-from .. import _api_internal
 
 from .base import _LIB, check_call, py_str, c_str, string_types, _FFI_MODE, _RUNTIME_ONLY
 
@@ -288,17 +287,11 @@ def _init_api_prefix(module_name, prefix):
     module = sys.modules[module_name]
 
     for name in list_global_func_names():
-        if prefix == "api":
-            fname = name
-            if name.startswith("_"):
-                target_module = sys.modules["tvm._api_internal"]
-            else:
-                target_module = module
-        else:
-            if not name.startswith(prefix):
-                continue
-            fname = name[len(prefix)+1:]
-            target_module = module
+        if not name.startswith(prefix):
+            continue
+
+        fname = name[len(prefix)+1:]
+        target_module = module
 
         if fname.find(".") != -1:
             continue
diff --git a/src/api/api_arith.cc b/src/api/api_arith.cc
deleted file mode 100644 (file)
index 3942f6e..0000000
+++ /dev/null
@@ -1,153 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- *  Implementation of API functions related to arith
- * \file api_arith.cc
- */
-#include <tvm/arith/bound.h>
-#include <tvm/arith/int_set.h>
-#include <tvm/arith/pattern.h>
-#include <tvm/arith/analyzer.h>
-
-#include <tvm/tir/expr.h>
-#include <tvm/tir/expr.h>
-#include <tvm/runtime/registry.h>
-
-#include <tvm/te/tensor.h>
-
-namespace tvm {
-namespace arith {
-
-TVM_REGISTER_GLOBAL("arith.intset_single_point")
-.set_body_typed(IntSet::single_point);
-
-TVM_REGISTER_GLOBAL("arith.intset_vector")
-.set_body_typed(IntSet::vector);
-
-TVM_REGISTER_GLOBAL("arith.intset_interval")
-.set_body_typed(IntSet::interval);
-
-
-TVM_REGISTER_GLOBAL("arith.DetectLinearEquation")
-.set_body_typed(DetectLinearEquation);
-
-TVM_REGISTER_GLOBAL("arith.DetectClipBound")
-.set_body_typed(DetectClipBound);
-
-TVM_REGISTER_GLOBAL("arith.DeduceBound")
-.set_body_typed([](
-  PrimExpr v, PrimExpr cond,
-  const Map<Var, IntSet> hint_map,
-  const Map<Var, IntSet> relax_map
-) {
-  return DeduceBound(v, cond, hint_map, relax_map);
-});
-
-
-TVM_REGISTER_GLOBAL("arith.DomainTouched")
-.set_body_typed(DomainTouched);
-
-TVM_REGISTER_GLOBAL("arith.IntervalSetGetMin")
-.set_body_method(&IntSet::min);
-
-TVM_REGISTER_GLOBAL("arith.IntervalSetGetMax")
-.set_body_method(&IntSet::max);
-
-TVM_REGISTER_GLOBAL("arith.IntSetIsNothing")
-.set_body_method(&IntSet::is_nothing);
-
-TVM_REGISTER_GLOBAL("arith.IntSetIsEverything")
-.set_body_method(&IntSet::is_everything);
-
-ConstIntBound MakeConstIntBound(int64_t min_value, int64_t max_value) {
-  return ConstIntBound(min_value, max_value);
-}
-
-TVM_REGISTER_GLOBAL("arith.ConstIntBound")
-.set_body_typed(MakeConstIntBound);
-
-ModularSet MakeModularSet(int64_t coeff, int64_t base) {
-  return ModularSet(coeff, base);
-}
-
-TVM_REGISTER_GLOBAL("arith.ModularSet")
-.set_body_typed(MakeModularSet);
-
-TVM_REGISTER_GLOBAL("arith.CreateAnalyzer")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
-    using runtime::PackedFunc;
-    using runtime::TypedPackedFunc;
-    auto self = std::make_shared<Analyzer>();
-    auto f = [self](std::string name) -> PackedFunc {
-      if (name == "const_int_bound") {
-        return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
-            *ret = self->const_int_bound(args[0]);
-          });
-      } else if (name == "modular_set") {
-        return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
-            *ret = self->modular_set(args[0]);
-        });
-      } else if (name == "const_int_bound_update") {
-        return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
-            self->const_int_bound.Update(args[0], args[1], args[2]);
-        });
-      } else if (name == "Simplify") {
-        return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
-            *ret = self->Simplify(args[0]);
-        });
-      } else if (name == "rewrite_simplify") {
-        return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
-            *ret = self->rewrite_simplify(args[0]);
-        });
-      } else if (name == "canonical_simplify") {
-        return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
-            *ret = self->canonical_simplify(args[0]);
-        });
-      } else if (name == "int_set") {
-        return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
-            *ret = self->int_set(args[0], args[1]);
-        });
-      } else if (name == "bind") {
-        return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
-            if (args[1].IsObjectRef<Range>()) {
-              self->Bind(args[0], args[1].operator Range());
-            } else {
-              self->Bind(args[0], args[1].operator PrimExpr());
-            }
-        });
-      } else if (name == "enter_constraint_context") {
-        return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
-            // can't use make_shared due to noexcept(false) decl in destructor,
-            // see https://stackoverflow.com/a/43907314
-            auto ctx = std::shared_ptr<With<ConstraintContext> >(
-                new With<ConstraintContext>(self.get(), args[0]));
-            auto fexit = [ctx](TVMArgs, TVMRetValue*) mutable {
-              ctx.reset();
-            };
-            *ret = PackedFunc(fexit);
-        });
-      }
-      return PackedFunc();
-    };
-    *ret = TypedPackedFunc<PackedFunc(std::string)>(f);
-});
-
-}  // namespace arith
-}  // namespace tvm
diff --git a/src/api/api_ir.cc b/src/api/api_ir.cc
deleted file mode 100644 (file)
index 1e71baf..0000000
+++ /dev/null
@@ -1,237 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- *  Implementation of API functions related to IR build
- * \file api_ir.cc
- */
-#include <tvm/tir/expr.h>
-#include <tvm/tir/expr.h>
-#include <tvm/runtime/registry.h>
-
-#include <tvm/tir/op.h>
-
-namespace tvm {
-namespace tir {
-
-TVM_REGISTER_GLOBAL("tir.Var")
-.set_body_typed([](std::string s, DataType t) {
-    return Var(s, t);
-  });
-
-TVM_REGISTER_GLOBAL("tir.SizeVar")
-.set_body_typed([](std::string s, DataType t) {
-    return SizeVar(s, t);
-  });
-
-TVM_REGISTER_GLOBAL("tir.abs")
-.set_body_typed(tvm::abs);
-
-TVM_REGISTER_GLOBAL("tir.isnan")
-.set_body_typed(tvm::isnan);
-
-TVM_REGISTER_GLOBAL("tir.floor")
-.set_body_typed(tvm::floor);
-
-TVM_REGISTER_GLOBAL("tir.ceil")
-.set_body_typed(tvm::ceil);
-
-TVM_REGISTER_GLOBAL("tir.round")
-.set_body_typed(tvm::round);
-
-TVM_REGISTER_GLOBAL("tir.nearbyint")
-.set_body_typed(tvm::nearbyint);
-
-TVM_REGISTER_GLOBAL("tir.trunc")
-.set_body_typed(tvm::trunc);
-
-TVM_REGISTER_GLOBAL("tir._cast")
-.set_body_typed(tvm::cast);
-
-TVM_REGISTER_GLOBAL("ir.range_by_min_extent")
-.set_body_typed(Range::make_by_min_extent);
-
-
-TVM_REGISTER_GLOBAL("tir.SeqStmt")
-.set_body_typed([](Array<Stmt> seq) {
-  return SeqStmt(std::move(seq));
-});
-
-TVM_REGISTER_GLOBAL("tir.For")
-.set_body_typed([](
-  Var loop_var, PrimExpr min, PrimExpr extent,
-  int for_type, int device_api, Stmt body) {
-  return ForNode::make(loop_var,
-                   min,
-                   extent,
-                   static_cast<ForType>(for_type),
-                   static_cast<DeviceAPI>(device_api),
-                   body);
-});
-
-TVM_REGISTER_GLOBAL("tir.Load")
-.set_body([](TVMArgs args,  TVMRetValue *ret) {
-    DataType t = args[0];
-    if (args.size() == 3) {
-      *ret = LoadNode::make(t, args[1], args[2], const_true(t.lanes()));
-    } else {
-      *ret = LoadNode::make(t, args[1], args[2], args[3]);
-    }
-  });
-
-TVM_REGISTER_GLOBAL("tir.Store")
-.set_body([](TVMArgs args,  TVMRetValue *ret) {
-    PrimExpr value = args[1];
-    if (args.size() == 3) {
-      *ret = StoreNode::make(args[0], value, args[2], const_true(value.dtype().lanes()));
-    } else {
-      *ret = StoreNode::make(args[0], value, args[2], args[3]);
-    }
-  });
-
-TVM_REGISTER_GLOBAL("tir.Realize")
-.set_body_typed(RealizeNode::make);
-
-TVM_REGISTER_GLOBAL("tir.Call")
-.set_body_typed([](
-  DataType type, std::string name,
-  Array<PrimExpr> args, int call_type,
-  FunctionRef func, int value_index
-) {
-  return CallNode::make(type,
-                    name,
-                    args,
-                    static_cast<CallNode::CallType>(call_type),
-                    func,
-                    value_index);
-});
-
-TVM_REGISTER_GLOBAL("tir.CommReducer")
-.set_body_typed(CommReducerNode::make);
-
-// make from two arguments
-#define REGISTER_MAKE(NodeName)                                     \
-  TVM_REGISTER_GLOBAL("tir."#NodeName)                             \
-  .set_body_typed(NodeName ## Node::make);                          \
-
-
-REGISTER_MAKE(Reduce);
-REGISTER_MAKE(AttrStmt);
-
-REGISTER_MAKE(StringImm);
-
-REGISTER_MAKE(Add);
-REGISTER_MAKE(Sub);
-REGISTER_MAKE(Mul);
-REGISTER_MAKE(Div);
-REGISTER_MAKE(Mod);
-REGISTER_MAKE(FloorDiv);
-REGISTER_MAKE(FloorMod);
-REGISTER_MAKE(Min);
-REGISTER_MAKE(Max);
-REGISTER_MAKE(EQ);
-REGISTER_MAKE(NE);
-REGISTER_MAKE(LT);
-REGISTER_MAKE(LE);
-REGISTER_MAKE(GT);
-REGISTER_MAKE(GE);
-REGISTER_MAKE(And);
-REGISTER_MAKE(Or);
-
-REGISTER_MAKE(Not);
-REGISTER_MAKE(Select);
-REGISTER_MAKE(Ramp);
-REGISTER_MAKE(Cast);
-REGISTER_MAKE(Broadcast);
-REGISTER_MAKE(Shuffle);
-REGISTER_MAKE(Let);
-REGISTER_MAKE(LetStmt);
-REGISTER_MAKE(AssertStmt);
-REGISTER_MAKE(ProducerConsumer);
-REGISTER_MAKE(Provide);
-REGISTER_MAKE(Prefetch);
-REGISTER_MAKE(Free);
-REGISTER_MAKE(IfThenElse);
-REGISTER_MAKE(Evaluate);
-
-// overloaded, needs special handling
-// has default args
-TVM_REGISTER_GLOBAL("tir.Allocate")
-  .set_body_typed([](
-    Var buffer_var, DataType type, Array<PrimExpr> extents, PrimExpr condition, Stmt body
-  ){
-    return AllocateNode::make(buffer_var, type, extents, condition, body);
-  });
-
-// operator overloading, smarter than make
-#define REGISTER_MAKE_BINARY_OP(Node, Func)                    \
-  TVM_REGISTER_GLOBAL("tir."#Node)                              \
-  .set_body_typed([](PrimExpr a, PrimExpr b) {                  \
-    return (Func(a, b));                                        \
-  })
-
-#define REGISTER_MAKE_BIT_OP(Node, Func)                               \
-  TVM_REGISTER_GLOBAL("tir."#Node)                                      \
-  .set_body([](TVMArgs args,  TVMRetValue *ret) {                       \
-    bool lhs_is_int = args[0].type_code() == kDLInt;                    \
-    bool rhs_is_int = args[1].type_code() == kDLInt;                    \
-    if (lhs_is_int) {                                                   \
-      *ret = (Func(args[0].operator int(), args[1].operator PrimExpr())); \
-    } else if (rhs_is_int) {                                            \
-      *ret = (Func(args[0].operator PrimExpr(), args[1].operator int())); \
-    } else {                                                            \
-      *ret = (Func(args[0].operator PrimExpr(), args[1].operator PrimExpr())); \
-    }                                                                   \
-  })
-
-
-REGISTER_MAKE_BINARY_OP(_OpAdd, operator+);
-REGISTER_MAKE_BINARY_OP(_OpSub, operator-);
-REGISTER_MAKE_BINARY_OP(_OpMul, operator*);
-REGISTER_MAKE_BINARY_OP(_OpDiv, div);
-REGISTER_MAKE_BINARY_OP(_OpMod, truncmod);
-REGISTER_MAKE_BINARY_OP(_OpIndexDiv, indexdiv);
-REGISTER_MAKE_BINARY_OP(_OpIndexMod, indexmod);
-REGISTER_MAKE_BINARY_OP(_OpFloorDiv, floordiv);
-REGISTER_MAKE_BINARY_OP(_OpFloorMod, floormod);
-REGISTER_MAKE_BINARY_OP(_OpTruncDiv, truncdiv);
-REGISTER_MAKE_BINARY_OP(_OpTruncMod, truncmod);
-REGISTER_MAKE_BINARY_OP(_OpPow, pow);
-REGISTER_MAKE_BINARY_OP(_OpMin, min);
-REGISTER_MAKE_BINARY_OP(_OpMax, max);
-REGISTER_MAKE_BINARY_OP(_OpEQ, operator==);
-REGISTER_MAKE_BINARY_OP(_OpNE, operator!=);
-REGISTER_MAKE_BINARY_OP(_OpLT, operator<); // NOLINT(*)
-REGISTER_MAKE_BINARY_OP(_OpLE, operator<=); // NOLINT(*)
-REGISTER_MAKE_BINARY_OP(_OpGT, operator>);  // NOLINT(*)
-REGISTER_MAKE_BINARY_OP(_OpGE, operator>=);
-REGISTER_MAKE_BINARY_OP(_OpAnd, operator&&);
-REGISTER_MAKE_BINARY_OP(_OpOr, operator||);
-REGISTER_MAKE_BIT_OP(bitwise_and, operator&);
-REGISTER_MAKE_BIT_OP(bitwise_or, operator|);
-REGISTER_MAKE_BIT_OP(bitwise_xor, operator^);
-REGISTER_MAKE_BIT_OP(left_shift, operator<<); // NOLINT(*)
-REGISTER_MAKE_BIT_OP(right_shift, operator>>);
-TVM_REGISTER_GLOBAL("tir._OpIfThenElse")
-.set_body_typed([] (PrimExpr cond, PrimExpr true_value, PrimExpr false_value) {
-  return if_then_else(cond, true_value, false_value);
-});
-
-}  // namespace tir
-}  // namespace tvm
diff --git a/src/api/api_lang.cc b/src/api/api_lang.cc
deleted file mode 100644 (file)
index 613b823..0000000
+++ /dev/null
@@ -1,223 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- *  Implementation of API functions related to Higher DSL build.
- * \file api_lang.cc
- */
-#include <tvm/runtime/registry.h>
-#include <tvm/tir/expr.h>
-#include <tvm/te/tensor.h>
-#include <tvm/te/operation.h>
-#include <tvm/tir/buffer.h>
-#include <tvm/te/schedule.h>
-#include <tvm/runtime/registry.h>
-
-#include <tvm/driver/driver_api.h>
-#include <tvm/tir/data_layout.h>
-
-namespace tvm {
-
-TVM_REGISTER_GLOBAL("tir.min_value")
-.set_body_typed(min_value);
-
-TVM_REGISTER_GLOBAL("tir.max_value")
-.set_body_typed(max_value);
-
-TVM_REGISTER_GLOBAL("ir.Range")
-.set_body([](TVMArgs args,  TVMRetValue* ret) {
-  *ret = Range(args[0], args[1]);
-  });
-
-namespace tir {
-TVM_REGISTER_GLOBAL("tir.IterVar")
-.set_body_typed([](Range dom, Var var, int iter_type, std::string thread_tag) {
-  return IterVarNode::make(
-      dom, var,
-      static_cast<IterVarType>(iter_type),
-      thread_tag);
-});
-}
-
-namespace te {
-TVM_REGISTER_GLOBAL("te.Tensor")
-.set_body_typed(TensorNode::make);
-
-TVM_REGISTER_GLOBAL("te.TensorIntrin")
-.set_body_typed(TensorIntrinNode::make);
-
-TVM_REGISTER_GLOBAL("te.TensorIntrinCall")
-.set_body_typed(TensorIntrinCallNode::make);
-
-TVM_REGISTER_GLOBAL("te.TensorEqual")
-.set_body_method(&Tensor::operator==);
-
-TVM_REGISTER_GLOBAL("te.TensorHash")
-.set_body_typed([](Tensor tensor) -> int64_t {
-    return static_cast<int64_t>(std::hash<Tensor>()(tensor));
-  });
-
-TVM_REGISTER_GLOBAL("te.Placeholder")
-.set_body_typed([](Array<PrimExpr> shape, DataType dtype, std::string name) {
-  return placeholder(shape, dtype, name);
-});
-
-TVM_REGISTER_GLOBAL("te.ComputeOp")
-.set_body_typed(ComputeOpNode::make);
-
-TVM_REGISTER_GLOBAL("te.ScanOp")
-.set_body_typed(ScanOpNode::make);
-
-TVM_REGISTER_GLOBAL("te.TensorComputeOp")
-.set_body_typed(TensorComputeOpNode::make);
-
-TVM_REGISTER_GLOBAL("te.ExternOp")
-.set_body_typed(ExternOpNode::make);
-
-TVM_REGISTER_GLOBAL("te.HybridOp")
-.set_body_typed(HybridOpNode::make);
-
-TVM_REGISTER_GLOBAL("te.OpGetOutput")
-.set_body_typed([](Operation op, int64_t output) {
-  return op.output(static_cast<size_t>(output));
-});
-
-TVM_REGISTER_GLOBAL("te.OpNumOutputs")
-.set_body_method<Operation>(&OperationNode::num_outputs);
-
-TVM_REGISTER_GLOBAL("te.OpInputTensors")
-.set_body_method<Operation>(&OperationNode::InputTensors);
-
-TVM_REGISTER_GLOBAL("te.CreateSchedule")
-.set_body_typed(create_schedule);
-
-TVM_REGISTER_GLOBAL("te.StageSetScope")
-.set_body_method(&Stage::set_scope);
-
-TVM_REGISTER_GLOBAL("te.StageBind")
-.set_body_method(&Stage::bind);
-
-TVM_REGISTER_GLOBAL("te.StageSplitByFactor")
-.set_body_typed([](Stage stage, IterVar parent, PrimExpr factor) {
-  IterVar outer, inner;
-  stage.split(parent, factor, &outer, &inner);
-  return Array<IterVar>({outer, inner});
-});
-
-TVM_REGISTER_GLOBAL("te.StageSplitByNParts")
-.set_body_typed([](Stage stage, IterVar parent, PrimExpr nparts) {
-  IterVar outer, inner;
-  stage.split_by_nparts(parent, nparts, &outer, &inner);
-  return Array<IterVar>({outer, inner});
-});
-
-TVM_REGISTER_GLOBAL("te.StageFuse")
-.set_body_typed([](Stage stage, Array<IterVar> axes) {
-    IterVar fused;
-    stage.fuse(axes, &fused);
-    return fused;
-  });
-
-TVM_REGISTER_GLOBAL("te.StageComputeAt")
-.set_body_method(&Stage::compute_at);
-
-TVM_REGISTER_GLOBAL("te.StageComputeInline")
-.set_body_method(&Stage::compute_inline);
-
-TVM_REGISTER_GLOBAL("te.StageComputeRoot")
-.set_body_method(&Stage::compute_root);
-
-TVM_REGISTER_GLOBAL("te.StageReorder")
-.set_body_method(&Stage::reorder);
-
-TVM_REGISTER_GLOBAL("te.StageTile")
-.set_body_typed([](
-  Stage stage,
-  IterVar x_parent, IterVar y_parent,
-  PrimExpr x_factor, PrimExpr y_factor
-) {
-    IterVar x_outer, y_outer, x_inner, y_inner;
-    stage.tile(x_parent, y_parent,
-               x_factor, y_factor,
-               &x_outer, &y_outer,
-               &x_inner, &y_inner);
-    return Array<IterVar>({x_outer, y_outer, x_inner, y_inner});
-  });
-
-TVM_REGISTER_GLOBAL("te.StageEnvThreads")
-.set_body_method(&Stage::env_threads);
-
-TVM_REGISTER_GLOBAL("te.StageSetStorePredicate")
-.set_body_method(&Stage::set_store_predicate);
-
-TVM_REGISTER_GLOBAL("te.StageUnroll")
-.set_body_method(&Stage::unroll);
-
-TVM_REGISTER_GLOBAL("te.StageVectorize")
-.set_body_method(&Stage::vectorize);
-
-TVM_REGISTER_GLOBAL("te.StageTensorize")
-.set_body_method(&Stage::tensorize);
-
-TVM_REGISTER_GLOBAL("te.StageParallel")
-.set_body_method(&Stage::parallel);
-
-TVM_REGISTER_GLOBAL("te.StagePragma")
-.set_body_method(&Stage::pragma);
-
-TVM_REGISTER_GLOBAL("te.StagePrefetch")
-.set_body_method(&Stage::prefetch);
-
-TVM_REGISTER_GLOBAL("te.StageStorageAlign")
-.set_body_method(&Stage::storage_align);
-
-TVM_REGISTER_GLOBAL("te.StageDoubleBuffer")
-.set_body_method(&Stage::double_buffer);
-
-TVM_REGISTER_GLOBAL("te.StageOpenGL")
-.set_body_method(&Stage::opengl);
-
-TVM_REGISTER_GLOBAL("te.ScheduleNormalize")
-.set_body_method(&Schedule::normalize);
-
-TVM_REGISTER_GLOBAL("te.ScheduleCreateGroup")
-.set_body_method(&Schedule::create_group);
-
-TVM_REGISTER_GLOBAL("te.ScheduleCacheRead")
-.set_body_method(&Schedule::cache_read);
-
-TVM_REGISTER_GLOBAL("te.ScheduleCacheWrite")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
-    if (args[1].IsObjectRef<Tensor>()) {
-      *ret = args[0].operator Schedule()
-          .cache_write(args[1].operator Tensor(), args[2]);
-    } else {
-      *ret = args[0].operator Schedule()
-          .cache_write(args[1].operator Array<Tensor>(), args[2]);
-    }
-  });
-
-TVM_REGISTER_GLOBAL("te.ScheduleRFactor")
-.set_body_method(&Schedule::rfactor);
-}  // namespace te
-
-TVM_REGISTER_GLOBAL("te.CommReducerCombine")
-.set_body_method<tir::CommReducer>(&tir::CommReducerNode::operator());
-
-}  // namespace tvm
diff --git a/src/api/api_pass.cc b/src/api/api_pass.cc
deleted file mode 100644 (file)
index 75d5439..0000000
+++ /dev/null
@@ -1,181 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- *  Exposure of pass functions.
- * \file api_pass.cc
- */
-#include <tvm/tir/expr.h>
-#include <tvm/tir/stmt.h>
-#include <tvm/ir/attrs.h>
-#include <tvm/tir/ir_pass.h>
-#include <tvm/tir/expr_functor.h>
-#include <tvm/tir/stmt_functor.h>
-#include <tvm/runtime/registry.h>
-
-namespace tvm {
-namespace tir {
-
-TVM_REGISTER_GLOBAL("ir_pass.Simplify")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
-    if (args[0].IsObjectRef<Stmt>()) {
-      if (args.size() > 1) {
-        *ret = Simplify(args[0].operator Stmt(), args[1]);
-      } else {
-        *ret = Simplify(args[0].operator Stmt());
-      }
-    } else {
-      if (args.size() > 1) {
-        *ret = Simplify(args[0].operator PrimExpr(), args[1]);
-      } else {
-        *ret = Simplify(args[0].operator PrimExpr());
-      }
-    }
-  });
-
-TVM_REGISTER_GLOBAL("ir_pass.CanonicalSimplify")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
-    if (args[0].IsObjectRef<Stmt>()) {
-      if (args.size() > 1) {
-        *ret = CanonicalSimplify(args[0].operator Stmt(), args[1]);
-      } else {
-        *ret = CanonicalSimplify(args[0].operator Stmt());
-      }
-    } else {
-      if (args.size() > 1) {
-        *ret = CanonicalSimplify(args[0].operator PrimExpr(), args[1]);
-      } else {
-        *ret = CanonicalSimplify(args[0].operator PrimExpr());
-      }
-    }
-  });
-
-TVM_REGISTER_GLOBAL("ir_pass.Substitute")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
-    if (args[0].IsObjectRef<Stmt>()) {
-      *ret = Substitute(args[0].operator Stmt(), args[1].operator Map<Var, PrimExpr>());
-    } else {
-      *ret = Substitute(args[0].operator PrimExpr(), args[1].operator Map<Var, PrimExpr>());
-    }
-  });
-
-TVM_REGISTER_GLOBAL("ir_pass.Equal")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
-    if (args[0].IsObjectRef<Stmt>()) {
-      *ret = Equal(args[0].operator Stmt(), args[1].operator Stmt());
-    } else {
-      *ret = Equal(args[0].operator PrimExpr(), args[1].operator PrimExpr());
-    }
-  });
-
-TVM_REGISTER_GLOBAL("ir_pass.StorageFlatten")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
-    if (args.size() <= 3) {
-      *ret = StorageFlatten(args[0], args[1], args[2]);
-    } else {
-      *ret = StorageFlatten(args[0], args[1], args[2], args[3]);
-    }
-  });
-
-TVM_REGISTER_GLOBAL("ir_pass.RewriteForTensorCore")
-.set_body_typed
-  ([](const Stmt& stmt,
-      const te::Schedule& schedule,
-      const Map<te::Tensor, Buffer>& extern_buffer) {
-      return RewriteForTensorCore(stmt, schedule, extern_buffer);
-  });
-
-TVM_REGISTER_GLOBAL("ir_pass.AttrsEqual")
-.set_body_typed(
-  [](const ObjectRef& lhs, const ObjectRef& rhs) {
-    return AttrsEqual()(lhs, rhs);
-  });
-
-TVM_REGISTER_GLOBAL("ir_pass.AttrsHash")
-.set_body_typed([](const ObjectRef &node) -> int64_t {
-    return AttrsHash()(node);
-});
-
-
-TVM_REGISTER_GLOBAL("ir_pass.ExprUseVar")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
-    *ret = ExprUseVar(args[0].operator PrimExpr(), args[1].operator Var());
-  });
-
-TVM_REGISTER_GLOBAL("ir_pass.PostOrderVisit")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
-    PackedFunc f = args[1];
-    tir::PostOrderVisit(args[0], [f](const ObjectRef& n) {
-        f(n);
-      });
-  });
-
-TVM_REGISTER_GLOBAL("ir_pass.LowerStorageAccess")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
-  LoweredFunc f = args[0];
-  auto n = make_object<LoweredFuncNode>(*f.operator->());
-  n->body = LowerStorageAccessInfo(f->body);
-  *ret = LoweredFunc(n);
-});
-
-// make from two arguments
-#define REGISTER_PASS(PassName)                                   \
-  TVM_REGISTER_GLOBAL("ir_pass."#PassName)                           \
-  .set_body_typed(PassName);                                     \
-
-
-REGISTER_PASS(ConvertSSA);
-REGISTER_PASS(VerifySSA);
-REGISTER_PASS(RewriteUnsafeSelect);
-REGISTER_PASS(Inline);
-REGISTER_PASS(IRTransform);
-REGISTER_PASS(VectorizeLoop);
-REGISTER_PASS(SkipVectorize);
-REGISTER_PASS(UnrollLoop);
-REGISTER_PASS(InjectCopyIntrin);
-REGISTER_PASS(ThreadSync);
-REGISTER_PASS(MakeAPI);
-REGISTER_PASS(BindDeviceType);
-REGISTER_PASS(SplitHostDevice);
-REGISTER_PASS(StorageRewrite);
-REGISTER_PASS(CoProcSync);
-REGISTER_PASS(LowerStorageAccessInfo);
-REGISTER_PASS(LowerDeviceStorageAccessInfo)
-REGISTER_PASS(InjectVirtualThread);
-REGISTER_PASS(InjectPrefetch);
-REGISTER_PASS(InjectDoubleBuffer);
-REGISTER_PASS(LoopPartition);
-REGISTER_PASS(RemoveNoOp);
-REGISTER_PASS(LiftAttrScope);
-REGISTER_PASS(LowerThreadAllreduce);
-REGISTER_PASS(LowerWarpMemory);
-REGISTER_PASS(RemapThreadAxis);
-REGISTER_PASS(LowerIntrin);
-REGISTER_PASS(LowerCustomDatatypes);
-REGISTER_PASS(LowerTVMBuiltin);
-REGISTER_PASS(CombineContextCall);
-REGISTER_PASS(VerifyMemory);
-REGISTER_PASS(VerifyGPUCode);
-REGISTER_PASS(DecorateDeviceScope);
-REGISTER_PASS(InstrumentBoundCheckers);
-REGISTER_PASS(VerifyCompactBuffer);
-REGISTER_PASS(HoistIfThenElse);
-REGISTER_PASS(InferFragment)
-}  // namespace tir
-}  // namespace tvm
diff --git a/src/api/api_schedule.cc b/src/api/api_schedule.cc
deleted file mode 100644 (file)
index a53c6e9..0000000
+++ /dev/null
@@ -1,63 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- *  Implementation of API functions related to schedule pass.
- * \file api_schedule.cc
- */
-#include <tvm/tir/expr.h>
-#include <tvm/te/tensor.h>
-#include <tvm/te/schedule.h>
-#include <tvm/te/schedule_pass.h>
-#include <tvm/runtime/registry.h>
-
-#include "../te/schedule/graph.h"
-
-namespace tvm {
-namespace te {
-
-TVM_REGISTER_GLOBAL("schedule.AutoInlineElemWise")
-.set_body_typed(AutoInlineElemWise);
-
-
-TVM_REGISTER_GLOBAL("schedule.AutoInlineInjective")
-.set_body_typed(AutoInlineInjective);
-
-TVM_REGISTER_GLOBAL("schedule.ScheduleOps")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
-  if (args.size() == 2)
-    *ret = ScheduleOps(args[0], args[1], false);
-  else
-    *ret = ScheduleOps(args[0], args[1], args[2]);
-});
-
-#define REGISTER_SCHEDULE_PASS(PassName)                             \
-  TVM_REGISTER_GLOBAL("schedule."#PassName)                          \
-  .set_body_typed(PassName);                                         \
-
-
-REGISTER_SCHEDULE_PASS(InferBound);
-REGISTER_SCHEDULE_PASS(CreateReadGraph);
-REGISTER_SCHEDULE_PASS(PostDFSOrder);
-REGISTER_SCHEDULE_PASS(CreateAttachPath);
-REGISTER_SCHEDULE_PASS(ScanGetBody);
-REGISTER_SCHEDULE_PASS(ScanFixPointAnalysis);
-
-}  // namespace te
-}  // namespace tvm
diff --git a/src/api/api_test.cc b/src/api/api_test.cc
deleted file mode 100644 (file)
index 2a1e605..0000000
+++ /dev/null
@@ -1,117 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
- /*!
- *  Code mainly used for test purposes.
- * \file api_test.cc
- */
-#include <tvm/tir/expr.h>
-#include <tvm/te/tensor.h>
-#include <tvm/ir/attrs.h>
-#include <tvm/runtime/registry.h>
-#include <tvm/ir/env_func.h>
-
-namespace tvm {
-// Attrs used to python API
-struct TestAttrs : public AttrsNode<TestAttrs> {
-  int axis;
-  std::string name;
-  Array<PrimExpr> padding;
-  TypedEnvFunc<int(int)> func;
-
-  TVM_DECLARE_ATTRS(TestAttrs, "attrs.TestAttrs") {
-    TVM_ATTR_FIELD(axis)
-        .set_default(10)
-        .set_lower_bound(1)
-        .set_upper_bound(10)
-        .describe("axis field");
-    TVM_ATTR_FIELD(name)
-        .describe("name");
-    TVM_ATTR_FIELD(padding)
-        .describe("padding of input")
-        .set_default(Array<PrimExpr>({0, 0}));
-    TVM_ATTR_FIELD(func)
-        .describe("some random env function")
-        .set_default(TypedEnvFunc<int(int)>(nullptr));
-  }
-};
-
-TVM_REGISTER_NODE_TYPE(TestAttrs);
-
-TVM_REGISTER_GLOBAL("testing.nop")
-.set_body([](TVMArgs args,  TVMRetValue *ret) {
-  });
-
-TVM_REGISTER_GLOBAL("testing.test_wrap_callback")
-.set_body([](TVMArgs args,  TVMRetValue *ret) {
-    PackedFunc pf = args[0];
-    *ret = runtime::TypedPackedFunc<void()>([pf](){
-        pf();
-      });
-  });
-
-TVM_REGISTER_GLOBAL("testing.test_raise_error_callback")
-.set_body([](TVMArgs args,  TVMRetValue *ret) {
-    std::string msg = args[0];
-    *ret = runtime::TypedPackedFunc<void()>([msg](){
-        LOG(FATAL) << msg;
-      });
-  });
-
-TVM_REGISTER_GLOBAL("testing.test_check_eq_callback")
-.set_body([](TVMArgs args,  TVMRetValue *ret) {
-    std::string msg = args[0];
-    *ret = runtime::TypedPackedFunc<void(int x, int y)>([msg](int x, int y){
-        CHECK_EQ(x, y) << msg;
-      });
-  });
-
-TVM_REGISTER_GLOBAL("testing.context_test")
-.set_body([](TVMArgs args,  TVMRetValue *ret) {
-    DLContext ctx = args[0];
-    int dtype = args[1];
-    int did = args[2];
-    CHECK_EQ(static_cast<int>(ctx.device_type), dtype);
-    CHECK_EQ(static_cast<int>(ctx.device_id), did);
-    *ret = ctx;
-  });
-
-
-// in src/api_test.cc
-void ErrorTest(int x, int y) {
-  // raise ValueError
-  CHECK_EQ(x, y) << "ValueError: expect x and y to be equal.";
-  if (x == 1) {
-    // raise InternalError.
-    LOG(FATAL) << "InternalError: cannot reach here";
-  }
-}
-
-TVM_REGISTER_GLOBAL("testing.ErrorTest")
-.set_body_typed(ErrorTest);
-
-// internal function used for debug and testing purposes
-TVM_REGISTER_GLOBAL("testing.ndarray_use_count")
-.set_body([](TVMArgs args,  TVMRetValue *ret) {
-    runtime::NDArray nd = args[0];
-    // substract the current one
-    *ret = (nd.use_count() - 1);
-  });
-
-}  // namespace tvm
index b12e5f51f4fbbcb6715f6f2554c814f9aafc49aa..9df5aa2d246deeb76a433ac85e01984e0b4093ae 100644 (file)
@@ -20,6 +20,7 @@
 /*!
  * \file tvm/arith/analyzer.cc
  */
+#include <tvm/runtime/registry.h>
 #include <tvm/tir/expr.h>
 #include <tvm/arith/analyzer.h>
 #include <tvm/tir/op.h>
@@ -109,5 +110,64 @@ PrimExpr Analyzer::Simplify(const PrimExpr& expr) {
   return res;
 }
 
+TVM_REGISTER_GLOBAL("arith.CreateAnalyzer")
+.set_body([](TVMArgs args, TVMRetValue* ret) {
+    using runtime::PackedFunc;
+    using runtime::TypedPackedFunc;
+    auto self = std::make_shared<Analyzer>();
+    auto f = [self](std::string name) -> PackedFunc {
+      if (name == "const_int_bound") {
+        return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
+            *ret = self->const_int_bound(args[0]);
+          });
+      } else if (name == "modular_set") {
+        return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
+            *ret = self->modular_set(args[0]);
+        });
+      } else if (name == "const_int_bound_update") {
+        return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
+            self->const_int_bound.Update(args[0], args[1], args[2]);
+        });
+      } else if (name == "Simplify") {
+        return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
+            *ret = self->Simplify(args[0]);
+        });
+      } else if (name == "rewrite_simplify") {
+        return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
+            *ret = self->rewrite_simplify(args[0]);
+        });
+      } else if (name == "canonical_simplify") {
+        return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
+            *ret = self->canonical_simplify(args[0]);
+        });
+      } else if (name == "int_set") {
+        return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
+            *ret = self->int_set(args[0], args[1]);
+        });
+      } else if (name == "bind") {
+        return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
+            if (args[1].IsObjectRef<Range>()) {
+              self->Bind(args[0], args[1].operator Range());
+            } else {
+              self->Bind(args[0], args[1].operator PrimExpr());
+            }
+        });
+      } else if (name == "enter_constraint_context") {
+        return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
+            // can't use make_shared due to noexcept(false) decl in destructor,
+            // see https://stackoverflow.com/a/43907314
+            auto ctx = std::shared_ptr<With<ConstraintContext> >(
+                new With<ConstraintContext>(self.get(), args[0]));
+            auto fexit = [ctx](TVMArgs, TVMRetValue*) mutable {
+              ctx.reset();
+            };
+            *ret = PackedFunc(fexit);
+        });
+      }
+      return PackedFunc();
+    };
+    *ret = TypedPackedFunc<PackedFunc(std::string)>(f);
+});
+
 }  // namespace arith
 }  // namespace tvm
index df8f40230e0469c431d9f55fca55ce67f8655bc8..26be5d51115f7c07ff1be8352ae0f8145d290cf2 100644 (file)
  * \file bound_deducer.cc
  * \brief Utility to deduce bound of expression
  */
+#include <tvm/runtime/registry.h>
 #include <tvm/tir/expr.h>
 #include <tvm/tir/ir_pass.h>
 #include <tvm/tir/expr_functor.h>
 #include <tvm/arith/analyzer.h>
-#include <tvm/runtime/registry.h>
 
 #include <unordered_set>
 #include <unordered_map>
@@ -362,5 +362,16 @@ IntSet DeduceBound(PrimExpr v, PrimExpr e,
   return DeduceBound(v, e, hmap, rmap);
 }
 
+
+TVM_REGISTER_GLOBAL("arith.DeduceBound")
+.set_body_typed([](
+  PrimExpr v, PrimExpr cond,
+  const Map<Var, IntSet> hint_map,
+  const Map<Var, IntSet> relax_map
+) {
+  return DeduceBound(v, cond, hint_map, relax_map);
+});
+
+
 }  // namespace arith
 }  // namespace tvm
index 7fb90a5e87c146d56b382ace93d23cb0b1f4b1b1..9ef5723e153ea5332449cf0ccb7be44e8f37bccc 100644 (file)
@@ -20,6 +20,7 @@
 /*!
  * \file tvm/arith/const_int_bound.cc
  */
+#include <tvm/runtime/registry.h>
 #include <tvm/arith/analyzer.h>
 #include <tvm/tir/expr_functor.h>
 #include <algorithm>
@@ -41,6 +42,13 @@ ConstIntBound::ConstIntBound(
   data_ = std::move(node);
 }
 
+ConstIntBound MakeConstIntBound(int64_t min_value, int64_t max_value) {
+  return ConstIntBound(min_value, max_value);
+}
+
+TVM_REGISTER_GLOBAL("arith.ConstIntBound")
+.set_body_typed(MakeConstIntBound);
+
 inline void PrintBoundValue(std::ostream& os, int64_t val) {
   if (val == ConstIntBound::kPosInf) {
     os << "pos_inf";
index 53adf35eb6ee926b5adf8ef08b71ca4735e25b91..cc9c745a24b8cff37fdd61ddfa5edb7e3d52fc5a 100644 (file)
@@ -21,6 +21,7 @@
  * \file detect_linear_equation.cc
  * \brief Utility to detect patterns in the expression.
  */
+#include <tvm/runtime/registry.h>
 #include <tvm/tir/expr.h>
 #include <tvm/tir/ir_pass.h>
 #include <tvm/tir/expr_functor.h>
@@ -268,6 +269,12 @@ Array<PrimExpr> DetectClipBound(const PrimExpr& e, const Array<Var>& vars) {
   return ret;
 }
 
+TVM_REGISTER_GLOBAL("arith.DetectLinearEquation")
+.set_body_typed(DetectLinearEquation);
 
+TVM_REGISTER_GLOBAL("arith.DetectClipBound")
+.set_body_typed([](const PrimExpr& e, const Array<Var>& vars) {
+  return DetectClipBound(e, vars);
+});
 }  // namespace arith
 }  // namespace tvm
index aa1ba4eb67be664d52ea811d2f8cbcb81aab845e..4eecabdb6d8cb2112df3646bb3c7a83922b44025 100644 (file)
@@ -119,5 +119,8 @@ Domain DomainTouched(Stmt stmt,
   return FuncTouchedDomain(tensor, consider_calls, consider_provides).Find(stmt);
 }
 
+TVM_REGISTER_GLOBAL("arith.DomainTouched")
+.set_body_typed(DomainTouched);
+
 }  // namespace arith
 }  // namespace tvm
index adb38799fdf2c59e6955a3cb5efc06693ac0cd53..8c5afb1be8b5c678efdcdb32275ddf23cc05a2ca 100644 (file)
@@ -820,5 +820,28 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
               << "[" << op->min_value << ", "
               << op->max_value << ']';
   });
+
+
+TVM_REGISTER_GLOBAL("arith.intset_single_point")
+.set_body_typed(IntSet::single_point);
+
+TVM_REGISTER_GLOBAL("arith.intset_vector")
+.set_body_typed(IntSet::vector);
+
+TVM_REGISTER_GLOBAL("arith.intset_interval")
+.set_body_typed(IntSet::interval);
+
+TVM_REGISTER_GLOBAL("arith.IntervalSetGetMin")
+.set_body_method(&IntSet::min);
+
+TVM_REGISTER_GLOBAL("arith.IntervalSetGetMax")
+.set_body_method(&IntSet::max);
+
+TVM_REGISTER_GLOBAL("arith.IntSetIsNothing")
+.set_body_method(&IntSet::is_nothing);
+
+TVM_REGISTER_GLOBAL("arith.IntSetIsEverything")
+.set_body_method(&IntSet::is_everything);
+
 }  // namespace arith
 }  // namespace tvm
index c3031ca0edfc1a4f82000ff541603af13db9b880..40cd7f8793ee928b14352a757454f84185a8f911 100644 (file)
@@ -21,6 +21,7 @@
  * \file modular_set.cc
  * \brief Modular set analysis
  */
+#include <tvm/runtime/registry.h>
 #include <tvm/arith/analyzer.h>
 #include <tvm/tir/op.h>
 #include <tvm/tir/expr_functor.h>
@@ -52,6 +53,12 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
               << op->base << ')';
   });
 
+ModularSet MakeModularSet(int64_t coeff, int64_t base) {
+  return ModularSet(coeff, base);
+}
+
+TVM_REGISTER_GLOBAL("arith.ModularSet")
+.set_body_typed(MakeModularSet);
 
 // internal entry for const int bound
 struct ModularSetAnalyzer::Entry {
index 4feabeb8e50548cfd5a5761962a3afbf64f38314..6244c7645accb56560773ad05f58f0cbbe95e97b 100644 (file)
@@ -134,6 +134,14 @@ Range Range::make_by_min_extent(PrimExpr min, PrimExpr extent) {
   return Range(make_object<RangeNode>(min, extent));
 }
 
+TVM_REGISTER_GLOBAL("ir.range_by_min_extent")
+.set_body_typed(Range::make_by_min_extent);
+
+TVM_REGISTER_GLOBAL("ir.Range")
+.set_body([](TVMArgs args,  TVMRetValue* ret) {
+  *ret = Range(args[0], args[1]);
+  });
+
 TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
 .set_dispatch<RangeNode>([](const ObjectRef& node, ReprPrinter* p) {
     auto* op = static_cast<const RangeNode*>(node.get());
diff --git a/src/support/ffi_testing.cc b/src/support/ffi_testing.cc
new file mode 100644 (file)
index 0000000..9053f62
--- /dev/null
@@ -0,0 +1,117 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+ /*!
+ *  FFI registration code used for frontend testing purposes.
+ * \file ffi_testing.cc
+ */
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/expr.h>
+#include <tvm/te/tensor.h>
+#include <tvm/ir/attrs.h>
+#include <tvm/ir/env_func.h>
+
+namespace tvm {
+// Attrs used to python API
+struct TestAttrs : public AttrsNode<TestAttrs> {
+  int axis;
+  std::string name;
+  Array<PrimExpr> padding;
+  TypedEnvFunc<int(int)> func;
+
+  TVM_DECLARE_ATTRS(TestAttrs, "attrs.TestAttrs") {
+    TVM_ATTR_FIELD(axis)
+        .set_default(10)
+        .set_lower_bound(1)
+        .set_upper_bound(10)
+        .describe("axis field");
+    TVM_ATTR_FIELD(name)
+        .describe("name");
+    TVM_ATTR_FIELD(padding)
+        .describe("padding of input")
+        .set_default(Array<PrimExpr>({0, 0}));
+    TVM_ATTR_FIELD(func)
+        .describe("some random env function")
+        .set_default(TypedEnvFunc<int(int)>(nullptr));
+  }
+};
+
+TVM_REGISTER_NODE_TYPE(TestAttrs);
+
+TVM_REGISTER_GLOBAL("testing.nop")
+.set_body([](TVMArgs args,  TVMRetValue *ret) {
+  });
+
+TVM_REGISTER_GLOBAL("testing.test_wrap_callback")
+.set_body([](TVMArgs args,  TVMRetValue *ret) {
+    PackedFunc pf = args[0];
+    *ret = runtime::TypedPackedFunc<void()>([pf](){
+        pf();
+      });
+  });
+
+TVM_REGISTER_GLOBAL("testing.test_raise_error_callback")
+.set_body([](TVMArgs args,  TVMRetValue *ret) {
+    std::string msg = args[0];
+    *ret = runtime::TypedPackedFunc<void()>([msg](){
+        LOG(FATAL) << msg;
+      });
+  });
+
+TVM_REGISTER_GLOBAL("testing.test_check_eq_callback")
+.set_body([](TVMArgs args,  TVMRetValue *ret) {
+    std::string msg = args[0];
+    *ret = runtime::TypedPackedFunc<void(int x, int y)>([msg](int x, int y){
+        CHECK_EQ(x, y) << msg;
+      });
+  });
+
+TVM_REGISTER_GLOBAL("testing.context_test")
+.set_body([](TVMArgs args,  TVMRetValue *ret) {
+    DLContext ctx = args[0];
+    int dtype = args[1];
+    int did = args[2];
+    CHECK_EQ(static_cast<int>(ctx.device_type), dtype);
+    CHECK_EQ(static_cast<int>(ctx.device_id), did);
+    *ret = ctx;
+  });
+
+
+// in src/api_test.cc
+void ErrorTest(int x, int y) {
+  // raise ValueError
+  CHECK_EQ(x, y) << "ValueError: expect x and y to be equal.";
+  if (x == 1) {
+    // raise InternalError.
+    LOG(FATAL) << "InternalError: cannot reach here";
+  }
+}
+
+TVM_REGISTER_GLOBAL("testing.ErrorTest")
+.set_body_typed(ErrorTest);
+
+// internal function used for debug and testing purposes
+TVM_REGISTER_GLOBAL("testing.ndarray_use_count")
+.set_body([](TVMArgs args,  TVMRetValue *ret) {
+    runtime::NDArray nd = args[0];
+    // substract the current one
+    *ret = (nd.use_count() - 1);
+  });
+
+}  // namespace tvm
index 1886d976555b66d518957b46a512d559dfe5fbcf..6123c613d0bdf48315e52e0c4db82086fe26c9be 100644 (file)
@@ -21,6 +21,7 @@
  * \brief Compute Op.
  * \file compute_op.cc
  */
+#include <tvm/runtime/registry.h>
 #include <tvm/te/operation.h>
 #include <tvm/arith/analyzer.h>
 #include <tvm/tir/expr.h>
@@ -156,6 +157,10 @@ Operation ComputeOpNode::make(std::string name,
   return Operation(n);
 }
 
+TVM_REGISTER_GLOBAL("te.ComputeOp")
+.set_body_typed(ComputeOpNode::make);
+
+
 // The schedule related logics
 Array<Tensor> ComputeOpNode::InputTensors() const {
   Array<Tensor> ret;
index c1e55046102bf22357c568affe52ae3a6850a259..62c8dfd30d49032b1e4219a73c024ca04ff4bcd9 100644 (file)
@@ -21,6 +21,7 @@
  * \brief External computation rule.
  * \file extern_op.cc
  */
+#include <tvm/runtime/registry.h>
 #include <tvm/te/operation.h>
 #include <tvm/arith/analyzer.h>
 #include <tvm/tir/expr.h>
@@ -86,6 +87,10 @@ Operation ExternOpNode::make(std::string name,
   return Operation(n);
 }
 
+TVM_REGISTER_GLOBAL("te.ExternOp")
+.set_body_typed(ExternOpNode::make);
+
+
 Array<Tensor> ExternOpNode::InputTensors() const {
   return inputs;
 }
index bb883ae4700421df82e17c3dda5fc8c742cbe1f2..70abf34523b982e016504ada6dc873224904ab00 100644 (file)
@@ -21,6 +21,7 @@
  * \brief Hybrid computation rule.
  * \file hybrid_op.cc
  */
+#include <tvm/runtime/registry.h>
 #include <tvm/te/operation.h>
 #include <tvm/arith/analyzer.h>
 #include <tvm/tir/expr.h>
@@ -83,6 +84,10 @@ Operation HybridOpNode::make(std::string name,
   return res;
 }
 
+TVM_REGISTER_GLOBAL("te.HybridOp")
+.set_body_typed(HybridOpNode::make);
+
+
 Array<Tensor> HybridOpNode::InputTensors() const {
   // Because input tensors could be potentially inlined into hybrid scripts,
   // we need to check if all input tensors are used in the body.
index 866ef949cf494dd71e9205e06e3edc1776654dba..d48be4c53668164865394ff67ad7fad4e59fb533 100644 (file)
@@ -21,6 +21,7 @@
  * \brief Placeholder op.
  * \file placeholder_op.cc
  */
+#include <tvm/runtime/registry.h>
 #include <tvm/te/operation.h>
 
 namespace tvm {
@@ -67,6 +68,11 @@ Tensor placeholder(Array<PrimExpr> shape, DataType dtype, std::string name) {
   return PlaceholderOpNode::make(name, shape, dtype).output(0);
 }
 
+TVM_REGISTER_GLOBAL("te.Placeholder")
+.set_body_typed([](Array<PrimExpr> shape, DataType dtype, std::string name) {
+  return placeholder(shape, dtype, name);
+});
+
 Array<Tensor> PlaceholderOpNode::InputTensors() const {
   return {};
 }
index cacfd8c4a4f149473b7836878a2f7305fdd77a96..956a297f5b3c6a71183bb3d81fe80e4001d3e7f0 100644 (file)
@@ -21,6 +21,7 @@
  * \brief Scan Operator.
  * \file scan_op.cc
  */
+#include <tvm/runtime/registry.h>
 #include <tvm/te/operation.h>
 #include <tvm/tir/expr.h>
 #include <tvm/tir/ir_pass.h>
@@ -120,6 +121,10 @@ Operation ScanOpNode::make(std::string name,
   return Operation(n);
 }
 
+TVM_REGISTER_GLOBAL("te.ScanOp")
+.set_body_typed(ScanOpNode::make);
+
+
 Array<Tensor> scan(Array<Tensor> init,
                    Array<Tensor> update,
                    Array<Tensor> state_placeholder,
index 8ce621ccc55beb8ad97dfd98e6e73d58f8b17b5b..4cdc9e1f8d328862061876364d65c2e9f44ac1aa 100644 (file)
@@ -21,6 +21,7 @@
  * \brief Tensor Compute Op.
  * \file tensor_compute_op.cc
  */
+#include <tvm/runtime/registry.h>
 #include <tvm/te/operation.h>
 #include <tvm/arith/analyzer.h>
 #include <tvm/tir/expr.h>
@@ -72,6 +73,10 @@ Operation TensorComputeOpNode::make(std::string name,
   return Operation(n);
 }
 
+TVM_REGISTER_GLOBAL("te.TensorComputeOp")
+.set_body_typed(TensorComputeOpNode::make);
+
+
 Array<Tensor> TensorComputeOpNode::InputTensors() const {
   return inputs;
 }
index 3a2226780f20f9d65ce67dbab7d1ab599694748e..6d79f4a8d1d67f9107ba0f6bc28df65cc6286fd9 100644 (file)
@@ -20,6 +20,7 @@
 /*!
  * \file auto_inline_elem_wise.cc
  */
+#include <tvm/runtime/registry.h>
 #include <tvm/te/schedule_pass.h>
 #include <tvm/te/operation.h>
 #include <tvm/tir/expr_functor.h>
@@ -111,5 +112,12 @@ void AutoInlineInjective(Schedule sch) {
   }
 }
 
+TVM_REGISTER_GLOBAL("schedule.AutoInlineElemWise")
+.set_body_typed(AutoInlineElemWise);
+
+
+TVM_REGISTER_GLOBAL("schedule.AutoInlineInjective")
+.set_body_typed(AutoInlineInjective);
+
 }  // namespace te
 }  // namespace tvm
index 27896e6738a84dab74b072094771914b0156be47..50cbafd2b654026766d6d1247c61742820200fe9 100644 (file)
@@ -21,6 +21,7 @@
  * \file bound.cc
  * \brief The bound inference logic.
  */
+#include <tvm/runtime/registry.h>
 #include <tvm/te/schedule_pass.h>
 #include <tvm/te/operation.h>
 #include <tvm/tir/ir_pass.h>
@@ -259,5 +260,8 @@ Map<IterVar, Range> InferBound(const Schedule& sch) {
   return Map<IterVar, Range>(ret.begin(), ret.end());
 }
 
+TVM_REGISTER_GLOBAL("schedule.InferBound")
+.set_body_typed(InferBound);
+
 }  // namespace te
 }  // namespace tvm
index eff0a25c569aff995ed76e74da47007f3c7a8064..9dce36f220ef85b85eefc2ff91adf5ba48516302 100644 (file)
@@ -21,6 +21,7 @@
  * \file graph.cc
  * \brief Utilities to get information about schedule graph.
  */
+#include <tvm/runtime/registry.h>
 #include <tvm/tir/expr.h>
 #include <tvm/tir/stmt_functor.h>
 #include <tvm/te/operation.h>
@@ -429,5 +430,24 @@ Map<IterVar, PrimExpr> ScanFixPointAnalysis(const Operation& scan_op) {
   return ret;
 }
 
+
+TVM_REGISTER_GLOBAL("schedule.CreateReadGraph")
+.set_body_typed(CreateReadGraph);
+
+TVM_REGISTER_GLOBAL("schedule.PostDFSOrder")
+.set_body_typed([](const Array<Operation>& roots,
+                   const ReadGraph& g) {
+  return PostDFSOrder(roots, g);
+});
+
+TVM_REGISTER_GLOBAL("schedule.CreateAttachPath")
+.set_body_typed(CreateAttachPath);
+
+TVM_REGISTER_GLOBAL("schedule.ScanGetBody")
+.set_body_typed(ScanGetBody);
+
+TVM_REGISTER_GLOBAL("schedule.ScanFixPointAnalysis")
+.set_body_typed(ScanFixPointAnalysis);
+
 }  // namespace te
 }  // namespace tvm
index 1763bd64c15ffb924fc44318976ce62da3f3e300..d3b448d37790032ebd81e577f7e12118b88649d7 100644 (file)
@@ -20,6 +20,7 @@
 /*!
  * \file schedule_lang.cc
  */
+#include <tvm/runtime/registry.h>
 #include <tvm/te/schedule.h>
 #include <tvm/te/operation.h>
 #include <unordered_set>
@@ -848,5 +849,118 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
     auto* op = static_cast<const ScheduleNode*>(node.get());
     p->stream << "schedule(" << op << ")";
   });
+
+
+TVM_REGISTER_GLOBAL("te.CreateSchedule")
+.set_body_typed(create_schedule);
+
+TVM_REGISTER_GLOBAL("te.StageSetScope")
+.set_body_method(&Stage::set_scope);
+
+TVM_REGISTER_GLOBAL("te.StageBind")
+.set_body_method(&Stage::bind);
+
+TVM_REGISTER_GLOBAL("te.StageSplitByFactor")
+.set_body_typed([](Stage stage, IterVar parent, PrimExpr factor) {
+  IterVar outer, inner;
+  stage.split(parent, factor, &outer, &inner);
+  return Array<IterVar>({outer, inner});
+});
+
+TVM_REGISTER_GLOBAL("te.StageSplitByNParts")
+.set_body_typed([](Stage stage, IterVar parent, PrimExpr nparts) {
+  IterVar outer, inner;
+  stage.split_by_nparts(parent, nparts, &outer, &inner);
+  return Array<IterVar>({outer, inner});
+});
+
+TVM_REGISTER_GLOBAL("te.StageFuse")
+.set_body_typed([](Stage stage, Array<IterVar> axes) {
+    IterVar fused;
+    stage.fuse(axes, &fused);
+    return fused;
+  });
+
+TVM_REGISTER_GLOBAL("te.StageComputeAt")
+.set_body_method(&Stage::compute_at);
+
+TVM_REGISTER_GLOBAL("te.StageComputeInline")
+.set_body_method(&Stage::compute_inline);
+
+TVM_REGISTER_GLOBAL("te.StageComputeRoot")
+.set_body_method(&Stage::compute_root);
+
+TVM_REGISTER_GLOBAL("te.StageReorder")
+.set_body_method(&Stage::reorder);
+
+TVM_REGISTER_GLOBAL("te.StageTile")
+.set_body_typed([](
+  Stage stage,
+  IterVar x_parent, IterVar y_parent,
+  PrimExpr x_factor, PrimExpr y_factor
+) {
+    IterVar x_outer, y_outer, x_inner, y_inner;
+    stage.tile(x_parent, y_parent,
+               x_factor, y_factor,
+               &x_outer, &y_outer,
+               &x_inner, &y_inner);
+    return Array<IterVar>({x_outer, y_outer, x_inner, y_inner});
+  });
+
+TVM_REGISTER_GLOBAL("te.StageEnvThreads")
+.set_body_method(&Stage::env_threads);
+
+TVM_REGISTER_GLOBAL("te.StageSetStorePredicate")
+.set_body_method(&Stage::set_store_predicate);
+
+TVM_REGISTER_GLOBAL("te.StageUnroll")
+.set_body_method(&Stage::unroll);
+
+TVM_REGISTER_GLOBAL("te.StageVectorize")
+.set_body_method(&Stage::vectorize);
+
+TVM_REGISTER_GLOBAL("te.StageTensorize")
+.set_body_method(&Stage::tensorize);
+
+TVM_REGISTER_GLOBAL("te.StageParallel")
+.set_body_method(&Stage::parallel);
+
+TVM_REGISTER_GLOBAL("te.StagePragma")
+.set_body_method(&Stage::pragma);
+
+TVM_REGISTER_GLOBAL("te.StagePrefetch")
+.set_body_method(&Stage::prefetch);
+
+TVM_REGISTER_GLOBAL("te.StageStorageAlign")
+.set_body_method(&Stage::storage_align);
+
+TVM_REGISTER_GLOBAL("te.StageDoubleBuffer")
+.set_body_method(&Stage::double_buffer);
+
+TVM_REGISTER_GLOBAL("te.StageOpenGL")
+.set_body_method(&Stage::opengl);
+
+TVM_REGISTER_GLOBAL("te.ScheduleNormalize")
+.set_body_method(&Schedule::normalize);
+
+TVM_REGISTER_GLOBAL("te.ScheduleCreateGroup")
+.set_body_method(&Schedule::create_group);
+
+TVM_REGISTER_GLOBAL("te.ScheduleCacheRead")
+.set_body_method(&Schedule::cache_read);
+
+TVM_REGISTER_GLOBAL("te.ScheduleCacheWrite")
+.set_body([](TVMArgs args, TVMRetValue* ret) {
+    if (args[1].IsObjectRef<Tensor>()) {
+      *ret = args[0].operator Schedule()
+          .cache_write(args[1].operator Tensor(), args[2]);
+    } else {
+      *ret = args[0].operator Schedule()
+          .cache_write(args[1].operator Array<Tensor>(), args[2]);
+    }
+  });
+
+TVM_REGISTER_GLOBAL("te.ScheduleRFactor")
+.set_body_method(&Schedule::rfactor);
 }  // namespace te
 }  // namespace tvm
index 0930f26372c47dd6b8450b34b6b7c389e68311a0..a110bc458fe9b77a997259ef3f82b05265651481 100644 (file)
@@ -20,6 +20,7 @@
 /*!
  * \file schedule_ops.cc
  */
+#include <tvm/runtime/registry.h>
 #include <tvm/tir/expr.h>
 #include <tvm/tir/ir_pass.h>
 #include <tvm/tir/stmt_functor.h>
@@ -423,5 +424,13 @@ Stmt ScheduleOps(
   return post_proc(std::move(body));
 }
 
+TVM_REGISTER_GLOBAL("schedule.ScheduleOps")
+.set_body([](TVMArgs args, TVMRetValue* ret) {
+  if (args.size() == 2)
+    *ret = ScheduleOps(args[0], args[1], false);
+  else
+    *ret = ScheduleOps(args[0], args[1], args[2]);
+});
+
 }  // namespace te
 }  // namespace tvm
index f200514468cb141a3eec1bd567a938af61dab1e7..cb14f6a35270dadb8b08c27197699ea4d1cabf66 100644 (file)
@@ -20,6 +20,7 @@
 /*!
  * \file tensor.cc
  */
+#include <tvm/runtime/registry.h>
 #include <tvm/te/tensor.h>
 #include <tvm/te/operation.h>
 #include <tvm/te/tensor_intrin.h>
@@ -147,5 +148,33 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
 
 TVM_REGISTER_NODE_TYPE(TensorIntrinCallNode);
 
+TVM_REGISTER_GLOBAL("te.Tensor")
+.set_body_typed(TensorNode::make);
+
+TVM_REGISTER_GLOBAL("te.TensorIntrin")
+.set_body_typed(TensorIntrinNode::make);
+
+TVM_REGISTER_GLOBAL("te.TensorIntrinCall")
+.set_body_typed(TensorIntrinCallNode::make);
+
+TVM_REGISTER_GLOBAL("te.TensorEqual")
+.set_body_method(&Tensor::operator==);
+
+TVM_REGISTER_GLOBAL("te.TensorHash")
+.set_body_typed([](Tensor tensor) -> int64_t {
+    return static_cast<int64_t>(std::hash<Tensor>()(tensor));
+  });
+
+TVM_REGISTER_GLOBAL("te.OpGetOutput")
+.set_body_typed([](Operation op, int64_t output) {
+  return op.output(static_cast<size_t>(output));
+});
+
+TVM_REGISTER_GLOBAL("te.OpNumOutputs")
+.set_body_method<Operation>(&OperationNode::num_outputs);
+
+TVM_REGISTER_GLOBAL("te.OpInputTensors")
+.set_body_method<Operation>(&OperationNode::InputTensors);
+
 }  // namespace te
 }  // namespace tvm
index d06c33f79dcc67b9c7d98810fd8082c0ae0194f6..22844745982fd91d040586d9f4e423110de68af8 100644 (file)
@@ -20,6 +20,7 @@
 /*!
  * \file expr.cc
  */
+#include <tvm/runtime/registry.h>
 #include <tvm/tir/expr.h>
 #include <tvm/tir/stmt.h>
 #include <tvm/tir/op.h>
@@ -45,6 +46,17 @@ SizeVar::SizeVar(std::string name_hint, DataType t)
 SizeVarNode::SizeVarNode(DataType t, std::string name_hint)
     : VarNode(t, std::move(name_hint)) {}
 
+
+TVM_REGISTER_GLOBAL("tir.Var")
+.set_body_typed([](std::string s, DataType t) {
+    return Var(s, t);
+  });
+
+TVM_REGISTER_GLOBAL("tir.SizeVar")
+.set_body_typed([](std::string s, DataType t) {
+    return SizeVar(s, t);
+  });
+
 IterVar IterVarNode::make(Range dom,
                           Var var,
                           IterVarType t,
@@ -57,6 +69,14 @@ IterVar IterVarNode::make(Range dom,
   return IterVar(n);
 }
 
+TVM_REGISTER_GLOBAL("tir.IterVar")
+.set_body_typed([](Range dom, Var var, int iter_type, std::string thread_tag) {
+  return IterVarNode::make(
+      dom, var,
+      static_cast<IterVarType>(iter_type),
+      thread_tag);
+});
+
 TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
 .set_dispatch<IterVarNode>([](const ObjectRef& node, ReprPrinter* p) {
     auto* op = static_cast<const IterVarNode*>(node.get());
@@ -83,6 +103,9 @@ PrimExpr StringImmNode::make(std::string value) {
   return PrimExpr(node);
 }
 
+TVM_REGISTER_GLOBAL("tir.StringImm")
+.set_body_typed(StringImmNode::make);
+
 PrimExpr CastNode::make(DataType t, PrimExpr value) {
   CHECK(value.defined());
   CHECK_EQ(t.lanes(), value.dtype().lanes());
@@ -311,6 +334,13 @@ Array<PrimExpr> CommReducerNode::operator()(Array<PrimExpr> a, Array<PrimExpr> b
     });
 }
 
+TVM_REGISTER_GLOBAL("tir.CommReducer")
+.set_body_typed(CommReducerNode::make);
+
+TVM_REGISTER_GLOBAL("tir.CommReducerCombine")
+.set_body_method<tir::CommReducer>(&tir::CommReducerNode::operator());
+
+
 PrimExpr ReduceNode::make(CommReducer combiner, Array<PrimExpr> source,
                   Array<IterVar> axis, PrimExpr condition, int value_index) {
   for (size_t i = 0; i < axis.size(); ++i) {
@@ -334,6 +364,11 @@ PrimExpr ReduceNode::make(CommReducer combiner, Array<PrimExpr> source,
   return PrimExpr(n);
 }
 
+
+TVM_REGISTER_GLOBAL("tir.Reduce")
+.set_body_typed(ReduceNode::make);
+
+
 PrimExpr AnyNode::make() {
   auto n = make_object<AnyNode>();
   return PrimExpr(n);
@@ -659,5 +694,104 @@ TVM_REGISTER_NODE_TYPE(CommReducerNode);
 TVM_REGISTER_NODE_TYPE(ReduceNode);
 TVM_REGISTER_NODE_TYPE(AnyNode);
 
+
+TVM_REGISTER_GLOBAL("tir.Add")
+.set_body_typed(AddNode::make);
+
+TVM_REGISTER_GLOBAL("tir.Sub")
+.set_body_typed(SubNode::make);
+
+TVM_REGISTER_GLOBAL("tir.Mul")
+.set_body_typed(MulNode::make);
+
+TVM_REGISTER_GLOBAL("tir.Div")
+.set_body_typed(DivNode::make);
+
+TVM_REGISTER_GLOBAL("tir.Mod")
+.set_body_typed(ModNode::make);
+
+TVM_REGISTER_GLOBAL("tir.FloorDiv")
+.set_body_typed(FloorDivNode::make);
+
+TVM_REGISTER_GLOBAL("tir.FloorMod")
+.set_body_typed(FloorModNode::make);
+
+TVM_REGISTER_GLOBAL("tir.Min")
+.set_body_typed(MinNode::make);
+
+TVM_REGISTER_GLOBAL("tir.Max")
+.set_body_typed(MaxNode::make);
+
+TVM_REGISTER_GLOBAL("tir.EQ")
+.set_body_typed(EQNode::make);
+
+TVM_REGISTER_GLOBAL("tir.NE")
+.set_body_typed(NENode::make);
+
+TVM_REGISTER_GLOBAL("tir.LT")
+.set_body_typed(LTNode::make);
+
+TVM_REGISTER_GLOBAL("tir.LE")
+.set_body_typed(LENode::make);
+
+TVM_REGISTER_GLOBAL("tir.GT")
+.set_body_typed(GTNode::make);
+
+TVM_REGISTER_GLOBAL("tir.GE")
+.set_body_typed(GENode::make);
+
+TVM_REGISTER_GLOBAL("tir.And")
+.set_body_typed(AndNode::make);
+
+TVM_REGISTER_GLOBAL("tir.Or")
+.set_body_typed(OrNode::make);
+
+TVM_REGISTER_GLOBAL("tir.Not")
+.set_body_typed(NotNode::make);
+
+TVM_REGISTER_GLOBAL("tir.Select")
+.set_body_typed(SelectNode::make);
+
+TVM_REGISTER_GLOBAL("tir.Ramp")
+.set_body_typed(RampNode::make);
+
+TVM_REGISTER_GLOBAL("tir.Cast")
+.set_body_typed(CastNode::make);
+
+TVM_REGISTER_GLOBAL("tir.Broadcast")
+.set_body_typed(BroadcastNode::make);
+
+TVM_REGISTER_GLOBAL("tir.Shuffle")
+.set_body_typed(ShuffleNode::make);
+
+TVM_REGISTER_GLOBAL("tir.Let")
+.set_body_typed(LetNode::make);
+
+TVM_REGISTER_GLOBAL("tir.Load")
+.set_body([](TVMArgs args,  TVMRetValue *ret) {
+    DataType t = args[0];
+    if (args.size() == 3) {
+      *ret = LoadNode::make(t, args[1], args[2], const_true(t.lanes()));
+    } else {
+      *ret = LoadNode::make(t, args[1], args[2], args[3]);
+    }
+  });
+
+
+
+TVM_REGISTER_GLOBAL("tir.Call")
+.set_body_typed([](
+  DataType type, std::string name,
+  Array<PrimExpr> args, int call_type,
+  FunctionRef func, int value_index
+) {
+  return CallNode::make(type,
+                    name,
+                    args,
+                    static_cast<CallNode::CallType>(call_type),
+                    func,
+                    value_index);
+});
+
 }  // namespace tir
 }  // namespace tvm
index 58f8b6b76da8fee2cd3e3a1ac1fe705ed904c58a..452c3bbc68a23ca0b88cfd7748081fbd691d9cd5 100644 (file)
@@ -662,4 +662,90 @@ TVM_REGISTER_GLOBAL("node.LargeUIntImm")
 TVM_REGISTER_GLOBAL("node.String")
 .set_body_typed(tir::StringImmNode::make);
 
+TVM_REGISTER_GLOBAL("tir.min_value")
+.set_body_typed(min_value);
+
+TVM_REGISTER_GLOBAL("tir.max_value")
+.set_body_typed(max_value);
+
+TVM_REGISTER_GLOBAL("tir.abs")
+.set_body_typed(tvm::abs);
+
+TVM_REGISTER_GLOBAL("tir.isnan")
+.set_body_typed(tvm::isnan);
+
+TVM_REGISTER_GLOBAL("tir.floor")
+.set_body_typed(tvm::floor);
+
+TVM_REGISTER_GLOBAL("tir.ceil")
+.set_body_typed(tvm::ceil);
+
+TVM_REGISTER_GLOBAL("tir.round")
+.set_body_typed(tvm::round);
+
+TVM_REGISTER_GLOBAL("tir.nearbyint")
+.set_body_typed(tvm::nearbyint);
+
+TVM_REGISTER_GLOBAL("tir.trunc")
+.set_body_typed(tvm::trunc);
+
+TVM_REGISTER_GLOBAL("tir._cast")
+.set_body_typed(tvm::cast);
+
+
+
+// operator overloading, smarter than make
+#define REGISTER_MAKE_BINARY_OP(Node, Func)                     \
+  TVM_REGISTER_GLOBAL("tir."#Node)                              \
+  .set_body_typed([](PrimExpr a, PrimExpr b) {                  \
+    return (Func(a, b));                                        \
+  })
+
+#define REGISTER_MAKE_BIT_OP(Node, Func)                                \
+  TVM_REGISTER_GLOBAL("tir."#Node)                                      \
+  .set_body([](TVMArgs args,  TVMRetValue *ret) {                       \
+    bool lhs_is_int = args[0].type_code() == kDLInt;                    \
+    bool rhs_is_int = args[1].type_code() == kDLInt;                    \
+    if (lhs_is_int) {                                                   \
+      *ret = (Func(args[0].operator int(), args[1].operator PrimExpr())); \
+    } else if (rhs_is_int) {                                            \
+      *ret = (Func(args[0].operator PrimExpr(), args[1].operator int())); \
+    } else {                                                            \
+      *ret = (Func(args[0].operator PrimExpr(), args[1].operator PrimExpr())); \
+    }                                                                   \
+  })
+
+
+REGISTER_MAKE_BINARY_OP(_OpAdd, operator+);
+REGISTER_MAKE_BINARY_OP(_OpSub, operator-);
+REGISTER_MAKE_BINARY_OP(_OpMul, operator*);
+REGISTER_MAKE_BINARY_OP(_OpDiv, div);
+REGISTER_MAKE_BINARY_OP(_OpMod, truncmod);
+REGISTER_MAKE_BINARY_OP(_OpIndexDiv, indexdiv);
+REGISTER_MAKE_BINARY_OP(_OpIndexMod, indexmod);
+REGISTER_MAKE_BINARY_OP(_OpFloorDiv, floordiv);
+REGISTER_MAKE_BINARY_OP(_OpFloorMod, floormod);
+REGISTER_MAKE_BINARY_OP(_OpTruncDiv, truncdiv);
+REGISTER_MAKE_BINARY_OP(_OpTruncMod, truncmod);
+REGISTER_MAKE_BINARY_OP(_OpPow, pow);
+REGISTER_MAKE_BINARY_OP(_OpMin, min);
+REGISTER_MAKE_BINARY_OP(_OpMax, max);
+REGISTER_MAKE_BINARY_OP(_OpEQ, operator==);
+REGISTER_MAKE_BINARY_OP(_OpNE, operator!=);
+REGISTER_MAKE_BINARY_OP(_OpLT, operator<); // NOLINT(*)
+REGISTER_MAKE_BINARY_OP(_OpLE, operator<=); // NOLINT(*)
+REGISTER_MAKE_BINARY_OP(_OpGT, operator>);  // NOLINT(*)
+REGISTER_MAKE_BINARY_OP(_OpGE, operator>=);
+REGISTER_MAKE_BINARY_OP(_OpAnd, operator&&);
+REGISTER_MAKE_BINARY_OP(_OpOr, operator||);
+REGISTER_MAKE_BIT_OP(bitwise_and, operator&);
+REGISTER_MAKE_BIT_OP(bitwise_or, operator|);
+REGISTER_MAKE_BIT_OP(bitwise_xor, operator^);
+REGISTER_MAKE_BIT_OP(left_shift, operator<<); // NOLINT(*)
+REGISTER_MAKE_BIT_OP(right_shift, operator>>);
+
+TVM_REGISTER_GLOBAL("tir._OpIfThenElse")
+.set_body_typed([] (PrimExpr cond, PrimExpr true_value, PrimExpr false_value) {
+  return if_then_else(cond, true_value, false_value);
+});
 }  // namespace tvm
index 0cd2aba319ee1d82ea0230de0d6829a6b3bc4a6b..a8fe9cd2bad31f66c22a6381a85f3a852c41c9dd 100644 (file)
@@ -20,7 +20,7 @@
 /*!
  * \file tvm/tir/stmt.cc
  */
-
+#include <tvm/runtime/registry.h>
 #include <tvm/tir/stmt.h>
 #include <tvm/tir/ir_pass.h>
 #include "../pass/ir_util.h"
@@ -40,6 +40,9 @@ Stmt LetStmtNode::make(Var var, PrimExpr value, Stmt body) {
   return Stmt(node);
 }
 
+TVM_REGISTER_GLOBAL("tir.LetStmt")
+.set_body_typed(LetStmtNode::make);
+
 Stmt AttrStmtNode::make(ObjectRef node,
                     std::string attr_key,
                     PrimExpr value,
@@ -52,6 +55,10 @@ Stmt AttrStmtNode::make(ObjectRef node,
   return Stmt(n);
 }
 
+TVM_REGISTER_GLOBAL("tir.AttrStmt")
+.set_body_typed(AttrStmtNode::make);
+
+
 Stmt AssertStmtNode::make(PrimExpr condition, PrimExpr message, Stmt body) {
   CHECK(condition.defined());
   CHECK(message.dtype() == DataType::Int(32) ||
@@ -66,6 +73,10 @@ Stmt AssertStmtNode::make(PrimExpr condition, PrimExpr message, Stmt body) {
   return Stmt(node);
 }
 
+TVM_REGISTER_GLOBAL("tir.AssertStmt")
+.set_body_typed(AssertStmtNode::make);
+
+
 Stmt ProducerConsumerNode::make(FunctionRef func, bool is_producer, Stmt body) {
   CHECK(body.defined());
 
@@ -76,6 +87,10 @@ Stmt ProducerConsumerNode::make(FunctionRef func, bool is_producer, Stmt body) {
   return Stmt(node);
 }
 
+TVM_REGISTER_GLOBAL("tir.ProducerConsumer")
+.set_body_typed(ProducerConsumerNode::make);
+
+
 Stmt ForNode::make(Var loop_var,
                PrimExpr min,
                PrimExpr extent,
@@ -99,6 +114,19 @@ Stmt ForNode::make(Var loop_var,
   return Stmt(node);
 }
 
+TVM_REGISTER_GLOBAL("tir.For")
+.set_body_typed([](
+  Var loop_var, PrimExpr min, PrimExpr extent,
+  int for_type, int device_api, Stmt body) {
+  return ForNode::make(loop_var,
+                   min,
+                   extent,
+                   static_cast<ForType>(for_type),
+                   static_cast<DeviceAPI>(device_api),
+                   body);
+});
+
+
 Stmt StoreNode::make(Var buffer_var, PrimExpr value, PrimExpr index, PrimExpr predicate) {
   CHECK(value.defined());
   CHECK(index.defined());
@@ -114,6 +142,18 @@ Stmt StoreNode::make(Var buffer_var, PrimExpr value, PrimExpr index, PrimExpr pr
   return Stmt(node);
 }
 
+
+TVM_REGISTER_GLOBAL("tir.Store")
+.set_body([](TVMArgs args,  TVMRetValue *ret) {
+    PrimExpr value = args[1];
+    if (args.size() == 3) {
+      *ret = StoreNode::make(args[0], value, args[2], const_true(value.dtype().lanes()));
+    } else {
+      *ret = StoreNode::make(args[0], value, args[2], args[3]);
+    }
+  });
+
+
 Stmt ProvideNode::make(FunctionRef func, int value_index, PrimExpr value, Array<PrimExpr> args) {
   CHECK(value_index >=0 && value_index < func->num_outputs())
       << "value index output function return value bound";
@@ -131,6 +171,10 @@ Stmt ProvideNode::make(FunctionRef func, int value_index, PrimExpr value, Array<
   return Stmt(node);
 }
 
+TVM_REGISTER_GLOBAL("tir.Provide")
+.set_body_typed(ProvideNode::make);
+
+
 Stmt AllocateNode::make(Var buffer_var,
                     DataType dtype,
                     Array<PrimExpr> extents,
@@ -157,6 +201,15 @@ Stmt AllocateNode::make(Var buffer_var,
     return Stmt(node);
 }
 
+// overloaded, needs special handling
+// has default args
+TVM_REGISTER_GLOBAL("tir.Allocate")
+.set_body_typed([](
+    Var buffer_var, DataType type, Array<PrimExpr> extents, PrimExpr condition, Stmt body
+                   ){
+  return AllocateNode::make(buffer_var, type, extents, condition, body);
+});
+
 int32_t AllocateNode::constant_allocation_size(const Array<PrimExpr>& extents) {
   int64_t result = 1;
   for (size_t i = 0; i < extents.size(); ++i) {
@@ -178,12 +231,16 @@ Stmt FreeNode::make(Var buffer_var) {
   return Stmt(node);
 }
 
+TVM_REGISTER_GLOBAL("tir.Free")
+.set_body_typed(FreeNode::make);
+
+
 Stmt RealizeNode::make(FunctionRef func,
-                   int value_index,
-                   DataType dtype,
-                   Region bounds,
-                   PrimExpr condition,
-                   Stmt body) {
+                       int value_index,
+                       DataType dtype,
+                       Region bounds,
+                       PrimExpr condition,
+                       Stmt body) {
   for (size_t i = 0; i < bounds.size(); ++i) {
     CHECK(bounds[i]->min.defined());
     CHECK(bounds[i]->extent.defined());
@@ -204,6 +261,11 @@ Stmt RealizeNode::make(FunctionRef func,
   return Stmt(node);
 }
 
+
+TVM_REGISTER_GLOBAL("tir.Realize")
+.set_body_typed(RealizeNode::make);
+
+
 Stmt PrefetchNode::make(FunctionRef func, int value_index, DataType dtype, Region bounds) {
   for (size_t i = 0; i < bounds.size(); ++i) {
     CHECK(bounds[i]->min.defined());
@@ -220,12 +282,21 @@ Stmt PrefetchNode::make(FunctionRef func, int value_index, DataType dtype, Regio
   return Stmt(node);
 }
 
+TVM_REGISTER_GLOBAL("tir.Prefetch")
+.set_body_typed(PrefetchNode::make);
+
+
 SeqStmt::SeqStmt(Array<Stmt> seq) {
   auto node = make_object<SeqStmtNode>();
   node->seq = std::move(seq);
   data_ = std::move(node);
 }
 
+TVM_REGISTER_GLOBAL("tir.SeqStmt")
+.set_body_typed([](Array<Stmt> seq) {
+  return SeqStmt(std::move(seq));
+});
+
 Stmt IfThenElseNode::make(PrimExpr condition, Stmt then_case, Stmt else_case) {
   CHECK(condition.defined());
   CHECK(then_case.defined());
@@ -238,6 +309,10 @@ Stmt IfThenElseNode::make(PrimExpr condition, Stmt then_case, Stmt else_case) {
   return Stmt(node);
 }
 
+TVM_REGISTER_GLOBAL("tir.IfThenElse")
+.set_body_typed(IfThenElseNode::make);
+
+
 Stmt EvaluateNode::make(PrimExpr value) {
   CHECK(value.defined());
 
@@ -246,6 +321,9 @@ Stmt EvaluateNode::make(PrimExpr value) {
   return Stmt(node);
 }
 
+TVM_REGISTER_GLOBAL("tir.Evaluate")
+.set_body_typed(EvaluateNode::make);
+
 // Printers
 
 TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
diff --git a/src/tir/pass/ffi_api.cc b/src/tir/pass/ffi_api.cc
new file mode 100644 (file)
index 0000000..233bfa5
--- /dev/null
@@ -0,0 +1,181 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *  Exposure of pass functions.
+ * \file ffi_api.cc
+ */
+#include <tvm/tir/expr.h>
+#include <tvm/tir/stmt.h>
+#include <tvm/ir/attrs.h>
+#include <tvm/tir/ir_pass.h>
+#include <tvm/tir/expr_functor.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/runtime/registry.h>
+
+namespace tvm {
+namespace tir {
+
+TVM_REGISTER_GLOBAL("ir_pass.Simplify")
+.set_body([](TVMArgs args, TVMRetValue *ret) {
+    if (args[0].IsObjectRef<Stmt>()) {
+      if (args.size() > 1) {
+        *ret = Simplify(args[0].operator Stmt(), args[1]);
+      } else {
+        *ret = Simplify(args[0].operator Stmt());
+      }
+    } else {
+      if (args.size() > 1) {
+        *ret = Simplify(args[0].operator PrimExpr(), args[1]);
+      } else {
+        *ret = Simplify(args[0].operator PrimExpr());
+      }
+    }
+  });
+
+TVM_REGISTER_GLOBAL("ir_pass.CanonicalSimplify")
+.set_body([](TVMArgs args, TVMRetValue *ret) {
+    if (args[0].IsObjectRef<Stmt>()) {
+      if (args.size() > 1) {
+        *ret = CanonicalSimplify(args[0].operator Stmt(), args[1]);
+      } else {
+        *ret = CanonicalSimplify(args[0].operator Stmt());
+      }
+    } else {
+      if (args.size() > 1) {
+        *ret = CanonicalSimplify(args[0].operator PrimExpr(), args[1]);
+      } else {
+        *ret = CanonicalSimplify(args[0].operator PrimExpr());
+      }
+    }
+  });
+
+TVM_REGISTER_GLOBAL("ir_pass.Substitute")
+.set_body([](TVMArgs args, TVMRetValue *ret) {
+    if (args[0].IsObjectRef<Stmt>()) {
+      *ret = Substitute(args[0].operator Stmt(), args[1].operator Map<Var, PrimExpr>());
+    } else {
+      *ret = Substitute(args[0].operator PrimExpr(), args[1].operator Map<Var, PrimExpr>());
+    }
+  });
+
+TVM_REGISTER_GLOBAL("ir_pass.Equal")
+.set_body([](TVMArgs args, TVMRetValue *ret) {
+    if (args[0].IsObjectRef<Stmt>()) {
+      *ret = Equal(args[0].operator Stmt(), args[1].operator Stmt());
+    } else {
+      *ret = Equal(args[0].operator PrimExpr(), args[1].operator PrimExpr());
+    }
+  });
+
+TVM_REGISTER_GLOBAL("ir_pass.StorageFlatten")
+.set_body([](TVMArgs args, TVMRetValue *ret) {
+    if (args.size() <= 3) {
+      *ret = StorageFlatten(args[0], args[1], args[2]);
+    } else {
+      *ret = StorageFlatten(args[0], args[1], args[2], args[3]);
+    }
+  });
+
+TVM_REGISTER_GLOBAL("ir_pass.RewriteForTensorCore")
+.set_body_typed
+  ([](const Stmt& stmt,
+      const te::Schedule& schedule,
+      const Map<te::Tensor, Buffer>& extern_buffer) {
+      return RewriteForTensorCore(stmt, schedule, extern_buffer);
+  });
+
+TVM_REGISTER_GLOBAL("ir_pass.AttrsEqual")
+.set_body_typed(
+  [](const ObjectRef& lhs, const ObjectRef& rhs) {
+    return AttrsEqual()(lhs, rhs);
+  });
+
+TVM_REGISTER_GLOBAL("ir_pass.AttrsHash")
+.set_body_typed([](const ObjectRef &node) -> int64_t {
+    return AttrsHash()(node);
+});
+
+
+TVM_REGISTER_GLOBAL("ir_pass.ExprUseVar")
+.set_body([](TVMArgs args, TVMRetValue *ret) {
+    *ret = ExprUseVar(args[0].operator PrimExpr(), args[1].operator Var());
+  });
+
+TVM_REGISTER_GLOBAL("ir_pass.PostOrderVisit")
+.set_body([](TVMArgs args, TVMRetValue *ret) {
+    PackedFunc f = args[1];
+    tir::PostOrderVisit(args[0], [f](const ObjectRef& n) {
+        f(n);
+      });
+  });
+
+TVM_REGISTER_GLOBAL("ir_pass.LowerStorageAccess")
+.set_body([](TVMArgs args, TVMRetValue *ret) {
+  LoweredFunc f = args[0];
+  auto n = make_object<LoweredFuncNode>(*f.operator->());
+  n->body = LowerStorageAccessInfo(f->body);
+  *ret = LoweredFunc(n);
+});
+
+// make from two arguments
+#define REGISTER_PASS(PassName)                                   \
+  TVM_REGISTER_GLOBAL("ir_pass."#PassName)                        \
+  .set_body_typed(PassName);                                      \
+
+
+REGISTER_PASS(ConvertSSA);
+REGISTER_PASS(VerifySSA);
+REGISTER_PASS(RewriteUnsafeSelect);
+REGISTER_PASS(Inline);
+REGISTER_PASS(IRTransform);
+REGISTER_PASS(VectorizeLoop);
+REGISTER_PASS(SkipVectorize);
+REGISTER_PASS(UnrollLoop);
+REGISTER_PASS(InjectCopyIntrin);
+REGISTER_PASS(ThreadSync);
+REGISTER_PASS(MakeAPI);
+REGISTER_PASS(BindDeviceType);
+REGISTER_PASS(SplitHostDevice);
+REGISTER_PASS(StorageRewrite);
+REGISTER_PASS(CoProcSync);
+REGISTER_PASS(LowerStorageAccessInfo);
+REGISTER_PASS(LowerDeviceStorageAccessInfo)
+REGISTER_PASS(InjectVirtualThread);
+REGISTER_PASS(InjectPrefetch);
+REGISTER_PASS(InjectDoubleBuffer);
+REGISTER_PASS(LoopPartition);
+REGISTER_PASS(RemoveNoOp);
+REGISTER_PASS(LiftAttrScope);
+REGISTER_PASS(LowerThreadAllreduce);
+REGISTER_PASS(LowerWarpMemory);
+REGISTER_PASS(RemapThreadAxis);
+REGISTER_PASS(LowerIntrin);
+REGISTER_PASS(LowerCustomDatatypes);
+REGISTER_PASS(LowerTVMBuiltin);
+REGISTER_PASS(CombineContextCall);
+REGISTER_PASS(VerifyMemory);
+REGISTER_PASS(VerifyGPUCode);
+REGISTER_PASS(DecorateDeviceScope);
+REGISTER_PASS(InstrumentBoundCheckers);
+REGISTER_PASS(VerifyCompactBuffer);
+REGISTER_PASS(HoistIfThenElse);
+REGISTER_PASS(InferFragment)
+}  // namespace tir
+}  // namespace tvm
index d1a2d983ff25102371a3766b81fb9192030b46e1..ac019a0aab40cc38e9f661e375d78d8ee0951f82 100644 (file)
@@ -27,7 +27,7 @@ def test_op_translation():
     except tvm.error.OpNotImplemented as e:
         msg = str(e)
         assert isinstance(e, NotImplementedError)
-        assert msg.find("api_test.cc") != -1
+        assert msg.find("ffi_testing.cc") != -1
 
     fchk_eq = tvm.testing.test_check_eq_callback(
         "InternalError: myop")
@@ -36,14 +36,14 @@ def test_op_translation():
         assert False
     except tvm.error.InternalError as e:
         msg = str(e)
-        assert msg.find("api_test.cc") != -1
+        assert msg.find("ffi_testing.cc") != -1
 
     try:
         tvm.testing.ErrorTest(0, 1)
         assert False
     except ValueError as e:
         msg = str(e)
-        assert msg.find("api_test.cc") != -1
+        assert msg.find("ffi_testing.cc") != -1
 
 
 def test_deep_callback():