[REFACTOR] Polish ffi convention. (#4912)
authorTianqi Chen <tqchen@users.noreply.github.com>
Wed, 19 Feb 2020 21:14:42 +0000 (13:14 -0800)
committerGitHub <noreply@github.com>
Wed, 19 Feb 2020 21:14:42 +0000 (13:14 -0800)
* [REFACTOR] Polish ffi convention.

- Remove the src/api, keep registration local to the c++ function.
- Remove the api_internal as it is no longer needed.

* Update the codebase walk through

34 files changed:
CMakeLists.txt
docs/dev/codebase_walkthrough.rst
python/tvm/_api_internal.py [deleted file]
python/tvm/_ffi/registry.py
src/api/api_arith.cc [deleted file]
src/api/api_ir.cc [deleted file]
src/api/api_lang.cc [deleted file]
src/api/api_schedule.cc [deleted file]
src/arith/analyzer.cc
src/arith/bound_deducer.cc
src/arith/const_int_bound.cc
src/arith/detect_linear_equation.cc
src/arith/domain_touched.cc
src/arith/int_set.cc
src/arith/modular_set.cc
src/ir/expr.cc
src/support/ffi_testing.cc [moved from src/api/api_test.cc with 97% similarity]
src/te/operation/compute_op.cc
src/te/operation/extern_op.cc
src/te/operation/hybrid_op.cc
src/te/operation/placeholder_op.cc
src/te/operation/scan_op.cc
src/te/operation/tensor_compute_op.cc
src/te/schedule/auto_inline_elem_wise.cc
src/te/schedule/bound.cc
src/te/schedule/graph.cc
src/te/schedule/schedule_lang.cc
src/te/schedule/schedule_ops.cc
src/te/tensor.cc
src/tir/ir/expr.cc
src/tir/ir/op.cc
src/tir/ir/stmt.cc
src/tir/pass/ffi_api.cc [moved from src/api/api_pass.cc with 99% similarity]
tests/python/unittest/test_runtime_error.py

index 8540a66..9d25e4a 100644 (file)
@@ -133,7 +133,7 @@ file(GLOB_RECURSE COMPILER_SRCS
     src/tir/*.cc
     src/driver/*.cc
     src/printer/*.cc
-    src/api/*.cc
+    src/support/*.cc
     )
 
 file(GLOB CODEGEN_SRCS
index 0732c26..8513ce5 100644 (file)
@@ -55,7 +55,7 @@ We use a simple example that uses the low level TVM API directly. The example is
    B = tvm.placeholder((n,), name='B')
    C = tvm.compute(A.shape, lambda i: A[i] + B[i], name="C")
 
-Here, types of ``A``, ``B``, ``C`` are ``tvm.tensor.Tensor``, defined in ``python/tvm/tensor.py``. The Python ``Tensor`` is backed by C++ ``Tensor``, implemented in ``include/tvm/tensor.h`` and ``src/lang/tensor.cc``. All Python types in TVM can be thought of as a handle to the underlying C++ type with the same name. If you look at the definition of Python ``Tensor`` type below, you can see it is a subclass of ``Object``.
+Here, types of ``A``, ``B``, ``C`` are ``tvm.tensor.Tensor``, defined in ``python/tvm/te/tensor.py``. The Python ``Tensor`` is backed by C++ ``Tensor``, implemented in ``include/tvm/te/tensor.h`` and ``src/te/tensor.cc``. All Python types in TVM can be thought of as a handle to the underlying C++ type with the same name. If you look at the definition of Python ``Tensor`` type below, you can see it is a subclass of ``Object``.
 
 ::
 
@@ -68,24 +68,12 @@ Here, types of ``A``, ``B``, ``C`` are ``tvm.tensor.Tensor``, defined in ``pytho
 
 The object protocol is the basis of exposing C++ types to frontend languages, including Python. The way TVM implements Python wrapping is not straightforward. It is briefly covered in `this document <https://docs.tvm.ai/dev/runtime.html#tvm-node-and-compiler-stack>`_, and details are in ``python/tvm/_ffi/`` if you are interested.
 
-``Tensor`` is created by functions in ``python/tvm/api.py``, which in turn calls into C++ functions exposed in ``src/api/api_lang.cc``. All C++ functions that are callable from Python are exposed in the ``src/api`` subdirectory. For example, the ``tvm.compute()`` function above calls into ``_ComputeOp`` API exposed in ``src/api/api_lang.cc``:
-
-::
-
-   TVM_REGISTER_GLOBAL("_ComputeOp")
-   .set_body([](TVMArgs args,  TVMRetValue* ret) {
-       *ret = ComputeOpNode::make(args[0],
-                                  args[1],
-                                  args[2],
-                                  args[3],
-                                  args[4]);
-     });
-
 We use the ``TVM_REGISTER_*`` macro to expose C++ functions to frontend languages, in the form of a `PackedFunc <https://docs.tvm.ai/dev/runtime.html#packedfunc>`_. A ``PackedFunc`` is another mechanism by which TVM implements interoperability between C++ and Python. In particular, this is what makes calling Python functions from the C++ codebase very easy.
+You can also checkout `FFI Navigator <https://github.com/tqchen/ffi-navigator>`_ which allows you to navigate between python and c++ FFI calls.
 
-A ``Tensor`` object has an ``Operation`` object associated with it, defined in ``python/tvm/tensor.py``, ``include/tvm/operation.h``, and ``src/tvm/op`` subdirectory. A ``Tensor`` is an output of its ``Operation`` object. Each ``Operation`` object has in turn ``input_tensors()`` method, which returns a list of input ``Tensor`` to it. This way we can keep track of dependencies between ``Operation``.
+A ``Tensor`` object has an ``Operation`` object associated with it, defined in ``python/tvm/te/tensor.py``, ``include/tvm/te/operation.h``, and ``src/tvm/te/operation`` subdirectory. A ``Tensor`` is an output of its ``Operation`` object. Each ``Operation`` object has in turn ``input_tensors()`` method, which returns a list of input ``Tensor`` to it. This way we can keep track of dependencies between ``Operation``.
 
-We pass the operation corresponding to the output tensor ``C`` to ``tvm.create_schedule()`` function in ``python/tvm/schedule.py``.
+We pass the operation corresponding to the output tensor ``C`` to ``tvm.create_schedule()`` function in ``python/tvm/te/schedule.py``.
 
 ::
 
@@ -103,7 +91,7 @@ This function is mapped to the C++ function in ``include/tvm/schedule.h``.
 
 ``Stage`` corresponds to one ``Operation``. In the vector add example above, there are two placeholder ops and one compute op, so the schedule ``s`` contains three stages. Each ``Stage`` holds information about a loop nest structure, types of each loop (``Parallel``, ``Vectorized``, ``Unrolled``), and where to execute its computation in the loop nest of the next ``Stage``, if any.
 
-``Schedule`` and ``Stage`` are defined in ``tvm/python/schedule.py``, ``include/tvm/schedule.h``, and ``src/schedule/schedule_ops.cc``.
+``Schedule`` and ``Stage`` are defined in ``tvm/python/te/schedule.py``, ``include/tvm/te/schedule.h``, and ``src/te/schedule/schedule_ops.cc``.
 
 To keep it simple, we call ``tvm.build(...)`` on the default schedule created by ``create_schedule()`` function above.
 
@@ -112,7 +100,7 @@ To keep it simple, we call ``tvm.build(...)`` on the default schedule created by
    target = "cuda"
    fadd = tvm.build(s, [A, B, C], target)
 
-``tvm.build()``, defined in ``python/tvm/build_module.py``, takes a schedule, input and output ``Tensor``, and a target, and returns a ``tvm.Module`` object, defined in ``python/tvm/module.py``. A ``Module`` object contains a compiled function which can be invoked with function call syntax.
+``tvm.build()``, defined in ``python/tvm/driver/build_module.py``, takes a schedule, input and output ``Tensor``, and a target, and returns a :py:class:`tvm.runtime.Module` object. A :py:class:`tvm.runtime.Module` object contains a compiled function which can be invoked with function call syntax.
 
 The process of ``tvm.build()`` can be divided into two steps:
 
@@ -133,14 +121,14 @@ Lowering is done by ``tvm.lower()`` function, defined in ``python/tvm/build_modu
       stmt = schedule.ScheduleOps(sch, bounds)
       ...
 
-Bound inference is the process where all loop bounds and sizes of intermediate buffers are inferred. If you target the CUDA backend and you use shared memory, its required minimum size is automatically determined here. Bound inference is implemented in ``src/schedule/bound.cc``, ``src/schedule/graph.cc`` and ``src/schedule/message_passing.cc``. For more information on how bound inference works, see `InferBound Pass`_.
+Bound inference is the process where all loop bounds and sizes of intermediate buffers are inferred. If you target the CUDA backend and you use shared memory, its required minimum size is automatically determined here. Bound inference is implemented in ``src/te/schedule/bound.cc``, ``src/te/schedule/graph.cc`` and ``src/te/schedule/message_passing.cc``. For more information on how bound inference works, see `InferBound Pass`_.
 
 .. _InferBound Pass: http://docs.tvm.ai/dev/inferbound.html
 
 
-``stmt``, which is the output of ``ScheduleOps()``, represents an initial loop nest structure. If you have applied ``reorder`` or ``split`` primitives to your schedule, then the initial loop nest already reflects those changes. ``ScheduleOps()`` is defined in ``src/schedule/schedule_ops.cc``.
+``stmt``, which is the output of ``ScheduleOps()``, represents an initial loop nest structure. If you have applied ``reorder`` or ``split`` primitives to your schedule, then the initial loop nest already reflects those changes. ``ScheduleOps()`` is defined in ``src/te/schedule/schedule_ops.cc``.
 
-Next, we apply a number of lowering passes to ``stmt``. These passes are implemented in ``src/pass`` subdirectory. For example, if you have applied ``vectorize`` or ``unroll`` primitives to your schedule, they are applied in loop vectorization and unrolling passes below.
+Next, we apply a number of lowering passes to ``stmt``. These passes are implemented in ``src/tir/pass`` subdirectory. For example, if you have applied ``vectorize`` or ``unroll`` primitives to your schedule, they are applied in loop vectorization and unrolling passes below.
 
 ::
 
@@ -157,7 +145,7 @@ Next, we apply a number of lowering passes to ``stmt``. These passes are impleme
 
 After lowering is done, ``build()`` function generates target machine code from the lowered function. This code can contain SSE or AVX instructions if you target x86, or PTX instructions for CUDA target. In addition to target specific machine code, TVM also generates host side code that is responsible for memory management, kernel launch etc.
 
-Code generation is done by ``build_module()`` function, defined in ``python/tvm/codegen.py``. On the C++ side, code generation is implemented in ``src/codegen`` subdirectory. ``build_module()`` Python function will reach ``Build()`` function below in ``src/codegen/codegen.cc``:
+Code generation is done by ``build_module()`` function, defined in ``python/tvm/target/codegen.py``. On the C++ side, code generation is implemented in ``src/target/codegen`` subdirectory. ``build_module()`` Python function will reach ``Build()`` function below in ``src/target/codegen/codegen.cc``:
 
 ::
 
diff --git a/python/tvm/_api_internal.py b/python/tvm/_api_internal.py
deleted file mode 100644 (file)
index 5715237..0000000
+++ /dev/null
@@ -1,25 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""Namespace of internal API
-
-The functions in this namespace are automatically exported from C++ side via PackedFunc
-that is registered by "TVM_REGISTER_*" macro. This way makes calling Python functions from C++
-side very easily.
-
-Each string starts with "_" in the "TVM_REGISTER_*" macro is an internal API. You can find
-all the functions in "api_lang.cc", "api_base.cc", "api_arith.cc" and "api_ir.cc" under "src/api".
-"""
index be15785..e4b8b18 100644 (file)
@@ -19,7 +19,6 @@
 """FFI registry to register function and objects."""
 import sys
 import ctypes
-from .. import _api_internal
 
 from .base import _LIB, check_call, py_str, c_str, string_types, _FFI_MODE, _RUNTIME_ONLY
 
@@ -288,17 +287,11 @@ def _init_api_prefix(module_name, prefix):
     module = sys.modules[module_name]
 
     for name in list_global_func_names():
-        if prefix == "api":
-            fname = name
-            if name.startswith("_"):
-                target_module = sys.modules["tvm._api_internal"]
-            else:
-                target_module = module
-        else:
-            if not name.startswith(prefix):
-                continue
-            fname = name[len(prefix)+1:]
-            target_module = module
+        if not name.startswith(prefix):
+            continue
+
+        fname = name[len(prefix)+1:]
+        target_module = module
 
         if fname.find(".") != -1:
             continue
diff --git a/src/api/api_arith.cc b/src/api/api_arith.cc
deleted file mode 100644 (file)
index 3942f6e..0000000
+++ /dev/null
@@ -1,153 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- *  Implementation of API functions related to arith
- * \file api_arith.cc
- */
-#include <tvm/arith/bound.h>
-#include <tvm/arith/int_set.h>
-#include <tvm/arith/pattern.h>
-#include <tvm/arith/analyzer.h>
-
-#include <tvm/tir/expr.h>
-#include <tvm/tir/expr.h>
-#include <tvm/runtime/registry.h>
-
-#include <tvm/te/tensor.h>
-
-namespace tvm {
-namespace arith {
-
-TVM_REGISTER_GLOBAL("arith.intset_single_point")
-.set_body_typed(IntSet::single_point);
-
-TVM_REGISTER_GLOBAL("arith.intset_vector")
-.set_body_typed(IntSet::vector);
-
-TVM_REGISTER_GLOBAL("arith.intset_interval")
-.set_body_typed(IntSet::interval);
-
-
-TVM_REGISTER_GLOBAL("arith.DetectLinearEquation")
-.set_body_typed(DetectLinearEquation);
-
-TVM_REGISTER_GLOBAL("arith.DetectClipBound")
-.set_body_typed(DetectClipBound);
-
-TVM_REGISTER_GLOBAL("arith.DeduceBound")
-.set_body_typed([](
-  PrimExpr v, PrimExpr cond,
-  const Map<Var, IntSet> hint_map,
-  const Map<Var, IntSet> relax_map
-) {
-  return DeduceBound(v, cond, hint_map, relax_map);
-});
-
-
-TVM_REGISTER_GLOBAL("arith.DomainTouched")
-.set_body_typed(DomainTouched);
-
-TVM_REGISTER_GLOBAL("arith.IntervalSetGetMin")
-.set_body_method(&IntSet::min);
-
-TVM_REGISTER_GLOBAL("arith.IntervalSetGetMax")
-.set_body_method(&IntSet::max);
-
-TVM_REGISTER_GLOBAL("arith.IntSetIsNothing")
-.set_body_method(&IntSet::is_nothing);
-
-TVM_REGISTER_GLOBAL("arith.IntSetIsEverything")
-.set_body_method(&IntSet::is_everything);
-
-ConstIntBound MakeConstIntBound(int64_t min_value, int64_t max_value) {
-  return ConstIntBound(min_value, max_value);
-}
-
-TVM_REGISTER_GLOBAL("arith.ConstIntBound")
-.set_body_typed(MakeConstIntBound);
-
-ModularSet MakeModularSet(int64_t coeff, int64_t base) {
-  return ModularSet(coeff, base);
-}
-
-TVM_REGISTER_GLOBAL("arith.ModularSet")
-.set_body_typed(MakeModularSet);
-
-TVM_REGISTER_GLOBAL("arith.CreateAnalyzer")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
-    using runtime::PackedFunc;
-    using runtime::TypedPackedFunc;
-    auto self = std::make_shared<Analyzer>();
-    auto f = [self](std::string name) -> PackedFunc {
-      if (name == "const_int_bound") {
-        return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
-            *ret = self->const_int_bound(args[0]);
-          });
-      } else if (name == "modular_set") {
-        return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
-            *ret = self->modular_set(args[0]);
-        });
-      } else if (name == "const_int_bound_update") {
-        return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
-            self->const_int_bound.Update(args[0], args[1], args[2]);
-        });
-      } else if (name == "Simplify") {
-        return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
-            *ret = self->Simplify(args[0]);
-        });
-      } else if (name == "rewrite_simplify") {
-        return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
-            *ret = self->rewrite_simplify(args[0]);
-        });
-      } else if (name == "canonical_simplify") {
-        return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
-            *ret = self->canonical_simplify(args[0]);
-        });
-      } else if (name == "int_set") {
-        return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
-            *ret = self->int_set(args[0], args[1]);
-        });
-      } else if (name == "bind") {
-        return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
-            if (args[1].IsObjectRef<Range>()) {
-              self->Bind(args[0], args[1].operator Range());
-            } else {
-              self->Bind(args[0], args[1].operator PrimExpr());
-            }
-        });
-      } else if (name == "enter_constraint_context") {
-        return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
-            // can't use make_shared due to noexcept(false) decl in destructor,
-            // see https://stackoverflow.com/a/43907314
-            auto ctx = std::shared_ptr<With<ConstraintContext> >(
-                new With<ConstraintContext>(self.get(), args[0]));
-            auto fexit = [ctx](TVMArgs, TVMRetValue*) mutable {
-              ctx.reset();
-            };
-            *ret = PackedFunc(fexit);
-        });
-      }
-      return PackedFunc();
-    };
-    *ret = TypedPackedFunc<PackedFunc(std::string)>(f);
-});
-
-}  // namespace arith
-}  // namespace tvm
diff --git a/src/api/api_ir.cc b/src/api/api_ir.cc
deleted file mode 100644 (file)
index 1e71baf..0000000
+++ /dev/null
@@ -1,237 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- *  Implementation of API functions related to IR build
- * \file api_ir.cc
- */
-#include <tvm/tir/expr.h>
-#include <tvm/tir/expr.h>
-#include <tvm/runtime/registry.h>
-
-#include <tvm/tir/op.h>
-
-namespace tvm {
-namespace tir {
-
-TVM_REGISTER_GLOBAL("tir.Var")
-.set_body_typed([](std::string s, DataType t) {
-    return Var(s, t);
-  });
-
-TVM_REGISTER_GLOBAL("tir.SizeVar")
-.set_body_typed([](std::string s, DataType t) {
-    return SizeVar(s, t);
-  });
-
-TVM_REGISTER_GLOBAL("tir.abs")
-.set_body_typed(tvm::abs);
-
-TVM_REGISTER_GLOBAL("tir.isnan")
-.set_body_typed(tvm::isnan);
-
-TVM_REGISTER_GLOBAL("tir.floor")
-.set_body_typed(tvm::floor);
-
-TVM_REGISTER_GLOBAL("tir.ceil")
-.set_body_typed(tvm::ceil);
-
-TVM_REGISTER_GLOBAL("tir.round")
-.set_body_typed(tvm::round);
-
-TVM_REGISTER_GLOBAL("tir.nearbyint")
-.set_body_typed(tvm::nearbyint);
-
-TVM_REGISTER_GLOBAL("tir.trunc")
-.set_body_typed(tvm::trunc);
-
-TVM_REGISTER_GLOBAL("tir._cast")
-.set_body_typed(tvm::cast);
-
-TVM_REGISTER_GLOBAL("ir.range_by_min_extent")
-.set_body_typed(Range::make_by_min_extent);
-
-
-TVM_REGISTER_GLOBAL("tir.SeqStmt")
-.set_body_typed([](Array<Stmt> seq) {
-  return SeqStmt(std::move(seq));
-});
-
-TVM_REGISTER_GLOBAL("tir.For")
-.set_body_typed([](
-  Var loop_var, PrimExpr min, PrimExpr extent,
-  int for_type, int device_api, Stmt body) {
-  return ForNode::make(loop_var,
-                   min,
-                   extent,
-                   static_cast<ForType>(for_type),
-                   static_cast<DeviceAPI>(device_api),
-                   body);
-});
-
-TVM_REGISTER_GLOBAL("tir.Load")
-.set_body([](TVMArgs args,  TVMRetValue *ret) {
-    DataType t = args[0];
-    if (args.size() == 3) {
-      *ret = LoadNode::make(t, args[1], args[2], const_true(t.lanes()));
-    } else {
-      *ret = LoadNode::make(t, args[1], args[2], args[3]);
-    }
-  });
-
-TVM_REGISTER_GLOBAL("tir.Store")
-.set_body([](TVMArgs args,  TVMRetValue *ret) {
-    PrimExpr value = args[1];
-    if (args.size() == 3) {
-      *ret = StoreNode::make(args[0], value, args[2], const_true(value.dtype().lanes()));
-    } else {
-      *ret = StoreNode::make(args[0], value, args[2], args[3]);
-    }
-  });
-
-TVM_REGISTER_GLOBAL("tir.Realize")
-.set_body_typed(RealizeNode::make);
-
-TVM_REGISTER_GLOBAL("tir.Call")
-.set_body_typed([](
-  DataType type, std::string name,
-  Array<PrimExpr> args, int call_type,
-  FunctionRef func, int value_index
-) {
-  return CallNode::make(type,
-                    name,
-                    args,
-                    static_cast<CallNode::CallType>(call_type),
-                    func,
-                    value_index);
-});
-
-TVM_REGISTER_GLOBAL("tir.CommReducer")
-.set_body_typed(CommReducerNode::make);
-
-// make from two arguments
-#define REGISTER_MAKE(NodeName)                                     \
-  TVM_REGISTER_GLOBAL("tir."#NodeName)                             \
-  .set_body_typed(NodeName ## Node::make);                          \
-
-
-REGISTER_MAKE(Reduce);
-REGISTER_MAKE(AttrStmt);
-
-REGISTER_MAKE(StringImm);
-
-REGISTER_MAKE(Add);
-REGISTER_MAKE(Sub);
-REGISTER_MAKE(Mul);
-REGISTER_MAKE(Div);
-REGISTER_MAKE(Mod);
-REGISTER_MAKE(FloorDiv);
-REGISTER_MAKE(FloorMod);
-REGISTER_MAKE(Min);
-REGISTER_MAKE(Max);
-REGISTER_MAKE(EQ);
-REGISTER_MAKE(NE);
-REGISTER_MAKE(LT);
-REGISTER_MAKE(LE);
-REGISTER_MAKE(GT);
-REGISTER_MAKE(GE);
-REGISTER_MAKE(And);
-REGISTER_MAKE(Or);
-
-REGISTER_MAKE(Not);
-REGISTER_MAKE(Select);
-REGISTER_MAKE(Ramp);
-REGISTER_MAKE(Cast);
-REGISTER_MAKE(Broadcast);
-REGISTER_MAKE(Shuffle);
-REGISTER_MAKE(Let);
-REGISTER_MAKE(LetStmt);
-REGISTER_MAKE(AssertStmt);
-REGISTER_MAKE(ProducerConsumer);
-REGISTER_MAKE(Provide);
-REGISTER_MAKE(Prefetch);
-REGISTER_MAKE(Free);
-REGISTER_MAKE(IfThenElse);
-REGISTER_MAKE(Evaluate);
-
-// overloaded, needs special handling
-// has default args
-TVM_REGISTER_GLOBAL("tir.Allocate")
-  .set_body_typed([](
-    Var buffer_var, DataType type, Array<PrimExpr> extents, PrimExpr condition, Stmt body
-  ){
-    return AllocateNode::make(buffer_var, type, extents, condition, body);
-  });
-
-// operator overloading, smarter than make
-#define REGISTER_MAKE_BINARY_OP(Node, Func)                    \
-  TVM_REGISTER_GLOBAL("tir."#Node)                              \
-  .set_body_typed([](PrimExpr a, PrimExpr b) {                  \
-    return (Func(a, b));                                        \
-  })
-
-#define REGISTER_MAKE_BIT_OP(Node, Func)                               \
-  TVM_REGISTER_GLOBAL("tir."#Node)                                      \
-  .set_body([](TVMArgs args,  TVMRetValue *ret) {                       \
-    bool lhs_is_int = args[0].type_code() == kDLInt;                    \
-    bool rhs_is_int = args[1].type_code() == kDLInt;                    \
-    if (lhs_is_int) {                                                   \
-      *ret = (Func(args[0].operator int(), args[1].operator PrimExpr())); \
-    } else if (rhs_is_int) {                                            \
-      *ret = (Func(args[0].operator PrimExpr(), args[1].operator int())); \
-    } else {                                                            \
-      *ret = (Func(args[0].operator PrimExpr(), args[1].operator PrimExpr())); \
-    }                                                                   \
-  })
-
-
-REGISTER_MAKE_BINARY_OP(_OpAdd, operator+);
-REGISTER_MAKE_BINARY_OP(_OpSub, operator-);
-REGISTER_MAKE_BINARY_OP(_OpMul, operator*);
-REGISTER_MAKE_BINARY_OP(_OpDiv, div);
-REGISTER_MAKE_BINARY_OP(_OpMod, truncmod);
-REGISTER_MAKE_BINARY_OP(_OpIndexDiv, indexdiv);
-REGISTER_MAKE_BINARY_OP(_OpIndexMod, indexmod);
-REGISTER_MAKE_BINARY_OP(_OpFloorDiv, floordiv);
-REGISTER_MAKE_BINARY_OP(_OpFloorMod, floormod);
-REGISTER_MAKE_BINARY_OP(_OpTruncDiv, truncdiv);
-REGISTER_MAKE_BINARY_OP(_OpTruncMod, truncmod);
-REGISTER_MAKE_BINARY_OP(_OpPow, pow);
-REGISTER_MAKE_BINARY_OP(_OpMin, min);
-REGISTER_MAKE_BINARY_OP(_OpMax, max);
-REGISTER_MAKE_BINARY_OP(_OpEQ, operator==);
-REGISTER_MAKE_BINARY_OP(_OpNE, operator!=);
-REGISTER_MAKE_BINARY_OP(_OpLT, operator<); // NOLINT(*)
-REGISTER_MAKE_BINARY_OP(_OpLE, operator<=); // NOLINT(*)
-REGISTER_MAKE_BINARY_OP(_OpGT, operator>);  // NOLINT(*)
-REGISTER_MAKE_BINARY_OP(_OpGE, operator>=);
-REGISTER_MAKE_BINARY_OP(_OpAnd, operator&&);
-REGISTER_MAKE_BINARY_OP(_OpOr, operator||);
-REGISTER_MAKE_BIT_OP(bitwise_and, operator&);
-REGISTER_MAKE_BIT_OP(bitwise_or, operator|);
-REGISTER_MAKE_BIT_OP(bitwise_xor, operator^);
-REGISTER_MAKE_BIT_OP(left_shift, operator<<); // NOLINT(*)
-REGISTER_MAKE_BIT_OP(right_shift, operator>>);
-TVM_REGISTER_GLOBAL("tir._OpIfThenElse")
-.set_body_typed([] (PrimExpr cond, PrimExpr true_value, PrimExpr false_value) {
-  return if_then_else(cond, true_value, false_value);
-});
-
-}  // namespace tir
-}  // namespace tvm
diff --git a/src/api/api_lang.cc b/src/api/api_lang.cc
deleted file mode 100644 (file)
index 613b823..0000000
+++ /dev/null
@@ -1,223 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- *  Implementation of API functions related to Higher DSL build.
- * \file api_lang.cc
- */
-#include <tvm/runtime/registry.h>
-#include <tvm/tir/expr.h>
-#include <tvm/te/tensor.h>
-#include <tvm/te/operation.h>
-#include <tvm/tir/buffer.h>
-#include <tvm/te/schedule.h>
-#include <tvm/runtime/registry.h>
-
-#include <tvm/driver/driver_api.h>
-#include <tvm/tir/data_layout.h>
-
-namespace tvm {
-
-TVM_REGISTER_GLOBAL("tir.min_value")
-.set_body_typed(min_value);
-
-TVM_REGISTER_GLOBAL("tir.max_value")
-.set_body_typed(max_value);
-
-TVM_REGISTER_GLOBAL("ir.Range")
-.set_body([](TVMArgs args,  TVMRetValue* ret) {
-  *ret = Range(args[0], args[1]);
-  });
-
-namespace tir {
-TVM_REGISTER_GLOBAL("tir.IterVar")
-.set_body_typed([](Range dom, Var var, int iter_type, std::string thread_tag) {
-  return IterVarNode::make(
-      dom, var,
-      static_cast<IterVarType>(iter_type),
-      thread_tag);
-});
-}
-
-namespace te {
-TVM_REGISTER_GLOBAL("te.Tensor")
-.set_body_typed(TensorNode::make);
-
-TVM_REGISTER_GLOBAL("te.TensorIntrin")
-.set_body_typed(TensorIntrinNode::make);
-
-TVM_REGISTER_GLOBAL("te.TensorIntrinCall")
-.set_body_typed(TensorIntrinCallNode::make);
-
-TVM_REGISTER_GLOBAL("te.TensorEqual")
-.set_body_method(&Tensor::operator==);
-
-TVM_REGISTER_GLOBAL("te.TensorHash")
-.set_body_typed([](Tensor tensor) -> int64_t {
-    return static_cast<int64_t>(std::hash<Tensor>()(tensor));
-  });
-
-TVM_REGISTER_GLOBAL("te.Placeholder")
-.set_body_typed([](Array<PrimExpr> shape, DataType dtype, std::string name) {
-  return placeholder(shape, dtype, name);
-});
-
-TVM_REGISTER_GLOBAL("te.ComputeOp")
-.set_body_typed(ComputeOpNode::make);
-
-TVM_REGISTER_GLOBAL("te.ScanOp")
-.set_body_typed(ScanOpNode::make);
-
-TVM_REGISTER_GLOBAL("te.TensorComputeOp")
-.set_body_typed(TensorComputeOpNode::make);
-
-TVM_REGISTER_GLOBAL("te.ExternOp")
-.set_body_typed(ExternOpNode::make);
-
-TVM_REGISTER_GLOBAL("te.HybridOp")
-.set_body_typed(HybridOpNode::make);
-
-TVM_REGISTER_GLOBAL("te.OpGetOutput")
-.set_body_typed([](Operation op, int64_t output) {
-  return op.output(static_cast<size_t>(output));
-});
-
-TVM_REGISTER_GLOBAL("te.OpNumOutputs")
-.set_body_method<Operation>(&OperationNode::num_outputs);
-
-TVM_REGISTER_GLOBAL("te.OpInputTensors")
-.set_body_method<Operation>(&OperationNode::InputTensors);
-
-TVM_REGISTER_GLOBAL("te.CreateSchedule")
-.set_body_typed(create_schedule);
-
-TVM_REGISTER_GLOBAL("te.StageSetScope")
-.set_body_method(&Stage::set_scope);
-
-TVM_REGISTER_GLOBAL("te.StageBind")
-.set_body_method(&Stage::bind);
-
-TVM_REGISTER_GLOBAL("te.StageSplitByFactor")
-.set_body_typed([](Stage stage, IterVar parent, PrimExpr factor) {
-  IterVar outer, inner;
-  stage.split(parent, factor, &outer, &inner);
-  return Array<IterVar>({outer, inner});
-});
-
-TVM_REGISTER_GLOBAL("te.StageSplitByNParts")
-.set_body_typed([](Stage stage, IterVar parent, PrimExpr nparts) {
-  IterVar outer, inner;
-  stage.split_by_nparts(parent, nparts, &outer, &inner);
-  return Array<IterVar>({outer, inner});
-});
-
-TVM_REGISTER_GLOBAL("te.StageFuse")
-.set_body_typed([](Stage stage, Array<IterVar> axes) {
-    IterVar fused;
-    stage.fuse(axes, &fused);
-    return fused;
-  });
-
-TVM_REGISTER_GLOBAL("te.StageComputeAt")
-.set_body_method(&Stage::compute_at);
-
-TVM_REGISTER_GLOBAL("te.StageComputeInline")
-.set_body_method(&Stage::compute_inline);
-
-TVM_REGISTER_GLOBAL("te.StageComputeRoot")
-.set_body_method(&Stage::compute_root);
-
-TVM_REGISTER_GLOBAL("te.StageReorder")
-.set_body_method(&Stage::reorder);
-
-TVM_REGISTER_GLOBAL("te.StageTile")
-.set_body_typed([](
-  Stage stage,
-  IterVar x_parent, IterVar y_parent,
-  PrimExpr x_factor, PrimExpr y_factor
-) {
-    IterVar x_outer, y_outer, x_inner, y_inner;
-    stage.tile(x_parent, y_parent,
-               x_factor, y_factor,
-               &x_outer, &y_outer,
-               &x_inner, &y_inner);
-    return Array<IterVar>({x_outer, y_outer, x_inner, y_inner});
-  });
-
-TVM_REGISTER_GLOBAL("te.StageEnvThreads")
-.set_body_method(&Stage::env_threads);
-
-TVM_REGISTER_GLOBAL("te.StageSetStorePredicate")
-.set_body_method(&Stage::set_store_predicate);
-
-TVM_REGISTER_GLOBAL("te.StageUnroll")
-.set_body_method(&Stage::unroll);
-
-TVM_REGISTER_GLOBAL("te.StageVectorize")
-.set_body_method(&Stage::vectorize);
-
-TVM_REGISTER_GLOBAL("te.StageTensorize")
-.set_body_method(&Stage::tensorize);
-
-TVM_REGISTER_GLOBAL("te.StageParallel")
-.set_body_method(&Stage::parallel);
-
-TVM_REGISTER_GLOBAL("te.StagePragma")
-.set_body_method(&Stage::pragma);
-
-TVM_REGISTER_GLOBAL("te.StagePrefetch")
-.set_body_method(&Stage::prefetch);
-
-TVM_REGISTER_GLOBAL("te.StageStorageAlign")
-.set_body_method(&Stage::storage_align);
-
-TVM_REGISTER_GLOBAL("te.StageDoubleBuffer")
-.set_body_method(&Stage::double_buffer);
-
-TVM_REGISTER_GLOBAL("te.StageOpenGL")
-.set_body_method(&Stage::opengl);
-
-TVM_REGISTER_GLOBAL("te.ScheduleNormalize")
-.set_body_method(&Schedule::normalize);
-
-TVM_REGISTER_GLOBAL("te.ScheduleCreateGroup")
-.set_body_method(&Schedule::create_group);
-
-TVM_REGISTER_GLOBAL("te.ScheduleCacheRead")
-.set_body_method(&Schedule::cache_read);
-
-TVM_REGISTER_GLOBAL("te.ScheduleCacheWrite")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
-    if (args[1].IsObjectRef<Tensor>()) {
-      *ret = args[0].operator Schedule()
-          .cache_write(args[1].operator Tensor(), args[2]);
-    } else {
-      *ret = args[0].operator Schedule()
-          .cache_write(args[1].operator Array<Tensor>(), args[2]);
-    }
-  });
-
-TVM_REGISTER_GLOBAL("te.ScheduleRFactor")
-.set_body_method(&Schedule::rfactor);
-}  // namespace te
-
-TVM_REGISTER_GLOBAL("te.CommReducerCombine")
-.set_body_method<tir::CommReducer>(&tir::CommReducerNode::operator());
-
-}  // namespace tvm
diff --git a/src/api/api_schedule.cc b/src/api/api_schedule.cc
deleted file mode 100644 (file)
index a53c6e9..0000000
+++ /dev/null
@@ -1,63 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- *  Implementation of API functions related to schedule pass.
- * \file api_schedule.cc
- */
-#include <tvm/tir/expr.h>
-#include <tvm/te/tensor.h>
-#include <tvm/te/schedule.h>
-#include <tvm/te/schedule_pass.h>
-#include <tvm/runtime/registry.h>
-
-#include "../te/schedule/graph.h"
-
-namespace tvm {
-namespace te {
-
-TVM_REGISTER_GLOBAL("schedule.AutoInlineElemWise")
-.set_body_typed(AutoInlineElemWise);
-
-
-TVM_REGISTER_GLOBAL("schedule.AutoInlineInjective")
-.set_body_typed(AutoInlineInjective);
-
-TVM_REGISTER_GLOBAL("schedule.ScheduleOps")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
-  if (args.size() == 2)
-    *ret = ScheduleOps(args[0], args[1], false);
-  else
-    *ret = ScheduleOps(args[0], args[1], args[2]);
-});
-
-#define REGISTER_SCHEDULE_PASS(PassName)                             \
-  TVM_REGISTER_GLOBAL("schedule."#PassName)                          \
-  .set_body_typed(PassName);                                         \
-
-
-REGISTER_SCHEDULE_PASS(InferBound);
-REGISTER_SCHEDULE_PASS(CreateReadGraph);
-REGISTER_SCHEDULE_PASS(PostDFSOrder);
-REGISTER_SCHEDULE_PASS(CreateAttachPath);
-REGISTER_SCHEDULE_PASS(ScanGetBody);
-REGISTER_SCHEDULE_PASS(ScanFixPointAnalysis);
-
-}  // namespace te
-}  // namespace tvm
index b12e5f5..9df5aa2 100644 (file)
@@ -20,6 +20,7 @@
 /*!
  * \file tvm/arith/analyzer.cc
  */
+#include <tvm/runtime/registry.h>
 #include <tvm/tir/expr.h>
 #include <tvm/arith/analyzer.h>
 #include <tvm/tir/op.h>
@@ -109,5 +110,64 @@ PrimExpr Analyzer::Simplify(const PrimExpr& expr) {
   return res;
 }
 
+TVM_REGISTER_GLOBAL("arith.CreateAnalyzer")
+.set_body([](TVMArgs args, TVMRetValue* ret) {
+    using runtime::PackedFunc;
+    using runtime::TypedPackedFunc;
+    auto self = std::make_shared<Analyzer>();
+    auto f = [self](std::string name) -> PackedFunc {
+      if (name == "const_int_bound") {
+        return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
+            *ret = self->const_int_bound(args[0]);
+          });
+      } else if (name == "modular_set") {
+        return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
+            *ret = self->modular_set(args[0]);
+        });
+      } else if (name == "const_int_bound_update") {
+        return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
+            self->const_int_bound.Update(args[0], args[1], args[2]);
+        });
+      } else if (name == "Simplify") {
+        return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
+            *ret = self->Simplify(args[0]);
+        });
+      } else if (name == "rewrite_simplify") {
+        return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
+            *ret = self->rewrite_simplify(args[0]);
+        });
+      } else if (name == "canonical_simplify") {
+        return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
+            *ret = self->canonical_simplify(args[0]);
+        });
+      } else if (name == "int_set") {
+        return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
+            *ret = self->int_set(args[0], args[1]);
+        });
+      } else if (name == "bind") {
+        return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
+            if (args[1].IsObjectRef<Range>()) {
+              self->Bind(args[0], args[1].operator Range());
+            } else {
+              self->Bind(args[0], args[1].operator PrimExpr());
+            }
+        });
+      } else if (name == "enter_constraint_context") {
+        return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
+            // can't use make_shared due to noexcept(false) decl in destructor,
+            // see https://stackoverflow.com/a/43907314
+            auto ctx = std::shared_ptr<With<ConstraintContext> >(
+                new With<ConstraintContext>(self.get(), args[0]));
+            auto fexit = [ctx](TVMArgs, TVMRetValue*) mutable {
+              ctx.reset();
+            };
+            *ret = PackedFunc(fexit);
+        });
+      }
+      return PackedFunc();
+    };
+    *ret = TypedPackedFunc<PackedFunc(std::string)>(f);
+});
+
 }  // namespace arith
 }  // namespace tvm
index df8f402..26be5d5 100644 (file)
  * \file bound_deducer.cc
  * \brief Utility to deduce bound of expression
  */
+#include <tvm/runtime/registry.h>
 #include <tvm/tir/expr.h>
 #include <tvm/tir/ir_pass.h>
 #include <tvm/tir/expr_functor.h>
 #include <tvm/arith/analyzer.h>
-#include <tvm/runtime/registry.h>
 
 #include <unordered_set>
 #include <unordered_map>
@@ -362,5 +362,16 @@ IntSet DeduceBound(PrimExpr v, PrimExpr e,
   return DeduceBound(v, e, hmap, rmap);
 }
 
+
+TVM_REGISTER_GLOBAL("arith.DeduceBound")
+.set_body_typed([](
+  PrimExpr v, PrimExpr cond,
+  const Map<Var, IntSet> hint_map,
+  const Map<Var, IntSet> relax_map
+) {
+  return DeduceBound(v, cond, hint_map, relax_map);
+});
+
+
 }  // namespace arith
 }  // namespace tvm
index 7fb90a5..9ef5723 100644 (file)
@@ -20,6 +20,7 @@
 /*!
  * \file tvm/arith/const_int_bound.cc
  */
+#include <tvm/runtime/registry.h>
 #include <tvm/arith/analyzer.h>
 #include <tvm/tir/expr_functor.h>
 #include <algorithm>
@@ -41,6 +42,13 @@ ConstIntBound::ConstIntBound(
   data_ = std::move(node);
 }
 
+ConstIntBound MakeConstIntBound(int64_t min_value, int64_t max_value) {
+  return ConstIntBound(min_value, max_value);
+}
+
+TVM_REGISTER_GLOBAL("arith.ConstIntBound")
+.set_body_typed(MakeConstIntBound);
+
 inline void PrintBoundValue(std::ostream& os, int64_t val) {
   if (val == ConstIntBound::kPosInf) {
     os << "pos_inf";
index 53adf35..cc9c745 100644 (file)
@@ -21,6 +21,7 @@
  * \file detect_linear_equation.cc
  * \brief Utility to detect patterns in the expression.
  */
+#include <tvm/runtime/registry.h>
 #include <tvm/tir/expr.h>
 #include <tvm/tir/ir_pass.h>
 #include <tvm/tir/expr_functor.h>
@@ -268,6 +269,12 @@ Array<PrimExpr> DetectClipBound(const PrimExpr& e, const Array<Var>& vars) {
   return ret;
 }
 
+TVM_REGISTER_GLOBAL("arith.DetectLinearEquation")
+.set_body_typed(DetectLinearEquation);
 
+TVM_REGISTER_GLOBAL("arith.DetectClipBound")
+.set_body_typed([](const PrimExpr& e, const Array<Var>& vars) {
+  return DetectClipBound(e, vars);
+});
 }  // namespace arith
 }  // namespace tvm
index aa1ba4e..4eecabd 100644 (file)
@@ -119,5 +119,8 @@ Domain DomainTouched(Stmt stmt,
   return FuncTouchedDomain(tensor, consider_calls, consider_provides).Find(stmt);
 }
 
+TVM_REGISTER_GLOBAL("arith.DomainTouched")
+.set_body_typed(DomainTouched);
+
 }  // namespace arith
 }  // namespace tvm
index adb3879..8c5afb1 100644 (file)
@@ -820,5 +820,28 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
               << "[" << op->min_value << ", "
               << op->max_value << ']';
   });
+
+
+TVM_REGISTER_GLOBAL("arith.intset_single_point")
+.set_body_typed(IntSet::single_point);
+
+TVM_REGISTER_GLOBAL("arith.intset_vector")
+.set_body_typed(IntSet::vector);
+
+TVM_REGISTER_GLOBAL("arith.intset_interval")
+.set_body_typed(IntSet::interval);
+
+TVM_REGISTER_GLOBAL("arith.IntervalSetGetMin")
+.set_body_method(&IntSet::min);
+
+TVM_REGISTER_GLOBAL("arith.IntervalSetGetMax")
+.set_body_method(&IntSet::max);
+
+TVM_REGISTER_GLOBAL("arith.IntSetIsNothing")
+.set_body_method(&IntSet::is_nothing);
+
+TVM_REGISTER_GLOBAL("arith.IntSetIsEverything")
+.set_body_method(&IntSet::is_everything);
+
 }  // namespace arith
 }  // namespace tvm
index c3031ca..40cd7f8 100644 (file)
@@ -21,6 +21,7 @@
  * \file modular_set.cc
  * \brief Modular set analysis
  */
+#include <tvm/runtime/registry.h>
 #include <tvm/arith/analyzer.h>
 #include <tvm/tir/op.h>
 #include <tvm/tir/expr_functor.h>
@@ -52,6 +53,12 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
               << op->base << ')';
   });
 
+ModularSet MakeModularSet(int64_t coeff, int64_t base) {
+  return ModularSet(coeff, base);
+}
+
+TVM_REGISTER_GLOBAL("arith.ModularSet")
+.set_body_typed(MakeModularSet);
 
 // internal entry for const int bound
 struct ModularSetAnalyzer::Entry {
index 4feabeb..6244c76 100644 (file)
@@ -134,6 +134,14 @@ Range Range::make_by_min_extent(PrimExpr min, PrimExpr extent) {
   return Range(make_object<RangeNode>(min, extent));
 }
 
+TVM_REGISTER_GLOBAL("ir.range_by_min_extent")
+.set_body_typed(Range::make_by_min_extent);
+
+TVM_REGISTER_GLOBAL("ir.Range")
+.set_body([](TVMArgs args,  TVMRetValue* ret) {
+  *ret = Range(args[0], args[1]);
+  });
+
 TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
 .set_dispatch<RangeNode>([](const ObjectRef& node, ReprPrinter* p) {
     auto* op = static_cast<const RangeNode*>(node.get());
similarity index 97%
rename from src/api/api_test.cc
rename to src/support/ffi_testing.cc
index 2a1e605..9053f62 100644 (file)
  */
 
  /*!
- *  Code mainly used for test purposes.
- * \file api_test.cc
+ *  FFI registration code used for frontend testing purposes.
+ * \file ffi_testing.cc
  */
+#include <tvm/runtime/registry.h>
 #include <tvm/tir/expr.h>
 #include <tvm/te/tensor.h>
 #include <tvm/ir/attrs.h>
-#include <tvm/runtime/registry.h>
 #include <tvm/ir/env_func.h>
 
 namespace tvm {
index 1886d97..6123c61 100644 (file)
@@ -21,6 +21,7 @@
  * \brief Compute Op.
  * \file compute_op.cc
  */
+#include <tvm/runtime/registry.h>
 #include <tvm/te/operation.h>
 #include <tvm/arith/analyzer.h>
 #include <tvm/tir/expr.h>
@@ -156,6 +157,10 @@ Operation ComputeOpNode::make(std::string name,
   return Operation(n);
 }
 
+TVM_REGISTER_GLOBAL("te.ComputeOp")
+.set_body_typed(ComputeOpNode::make);
+
+
 // The schedule related logics
 Array<Tensor> ComputeOpNode::InputTensors() const {
   Array<Tensor> ret;
index c1e5504..62c8dfd 100644 (file)
@@ -21,6 +21,7 @@
  * \brief External computation rule.
  * \file extern_op.cc
  */
+#include <tvm/runtime/registry.h>
 #include <tvm/te/operation.h>
 #include <tvm/arith/analyzer.h>
 #include <tvm/tir/expr.h>
@@ -86,6 +87,10 @@ Operation ExternOpNode::make(std::string name,
   return Operation(n);
 }
 
+TVM_REGISTER_GLOBAL("te.ExternOp")
+.set_body_typed(ExternOpNode::make);
+
+
 Array<Tensor> ExternOpNode::InputTensors() const {
   return inputs;
 }
index bb883ae..70abf34 100644 (file)
@@ -21,6 +21,7 @@
  * \brief Hybrid computation rule.
  * \file hybrid_op.cc
  */
+#include <tvm/runtime/registry.h>
 #include <tvm/te/operation.h>
 #include <tvm/arith/analyzer.h>
 #include <tvm/tir/expr.h>
@@ -83,6 +84,10 @@ Operation HybridOpNode::make(std::string name,
   return res;
 }
 
+TVM_REGISTER_GLOBAL("te.HybridOp")
+.set_body_typed(HybridOpNode::make);
+
+
 Array<Tensor> HybridOpNode::InputTensors() const {
   // Because input tensors could be potentially inlined into hybrid scripts,
   // we need to check if all input tensors are used in the body.
index 866ef94..d48be4c 100644 (file)
@@ -21,6 +21,7 @@
  * \brief Placeholder op.
  * \file placeholder_op.cc
  */
+#include <tvm/runtime/registry.h>
 #include <tvm/te/operation.h>
 
 namespace tvm {
@@ -67,6 +68,11 @@ Tensor placeholder(Array<PrimExpr> shape, DataType dtype, std::string name) {
   return PlaceholderOpNode::make(name, shape, dtype).output(0);
 }
 
+TVM_REGISTER_GLOBAL("te.Placeholder")
+.set_body_typed([](Array<PrimExpr> shape, DataType dtype, std::string name) {
+  return placeholder(shape, dtype, name);
+});
+
 Array<Tensor> PlaceholderOpNode::InputTensors() const {
   return {};
 }
index cacfd8c..956a297 100644 (file)
@@ -21,6 +21,7 @@
  * \brief Scan Operator.
  * \file scan_op.cc
  */
+#include <tvm/runtime/registry.h>
 #include <tvm/te/operation.h>
 #include <tvm/tir/expr.h>
 #include <tvm/tir/ir_pass.h>
@@ -120,6 +121,10 @@ Operation ScanOpNode::make(std::string name,
   return Operation(n);
 }
 
+TVM_REGISTER_GLOBAL("te.ScanOp")
+.set_body_typed(ScanOpNode::make);
+
+
 Array<Tensor> scan(Array<Tensor> init,
                    Array<Tensor> update,
                    Array<Tensor> state_placeholder,
index 8ce621c..4cdc9e1 100644 (file)
@@ -21,6 +21,7 @@
  * \brief Tensor Compute Op.
  * \file tensor_compute_op.cc
  */
+#include <tvm/runtime/registry.h>
 #include <tvm/te/operation.h>
 #include <tvm/arith/analyzer.h>
 #include <tvm/tir/expr.h>
@@ -72,6 +73,10 @@ Operation TensorComputeOpNode::make(std::string name,
   return Operation(n);
 }
 
+TVM_REGISTER_GLOBAL("te.TensorComputeOp")
+.set_body_typed(TensorComputeOpNode::make);
+
+
 Array<Tensor> TensorComputeOpNode::InputTensors() const {
   return inputs;
 }
index 3a22267..6d79f4a 100644 (file)
@@ -20,6 +20,7 @@
 /*!
  * \file auto_inline_elem_wise.cc
  */
+#include <tvm/runtime/registry.h>
 #include <tvm/te/schedule_pass.h>
 #include <tvm/te/operation.h>
 #include <tvm/tir/expr_functor.h>
@@ -111,5 +112,12 @@ void AutoInlineInjective(Schedule sch) {
   }
 }
 
+TVM_REGISTER_GLOBAL("schedule.AutoInlineElemWise")
+.set_body_typed(AutoInlineElemWise);
+
+
+TVM_REGISTER_GLOBAL("schedule.AutoInlineInjective")
+.set_body_typed(AutoInlineInjective);
+
 }  // namespace te
 }  // namespace tvm
index 27896e6..50cbafd 100644 (file)
@@ -21,6 +21,7 @@
  * \file bound.cc
  * \brief The bound inference logic.
  */
+#include <tvm/runtime/registry.h>
 #include <tvm/te/schedule_pass.h>
 #include <tvm/te/operation.h>
 #include <tvm/tir/ir_pass.h>
@@ -259,5 +260,8 @@ Map<IterVar, Range> InferBound(const Schedule& sch) {
   return Map<IterVar, Range>(ret.begin(), ret.end());
 }
 
+TVM_REGISTER_GLOBAL("schedule.InferBound")
+.set_body_typed(InferBound);
+
 }  // namespace te
 }  // namespace tvm
index eff0a25..9dce36f 100644 (file)
@@ -21,6 +21,7 @@
  * \file graph.cc
  * \brief Utilities to get information about schedule graph.
  */
+#include <tvm/runtime/registry.h>
 #include <tvm/tir/expr.h>
 #include <tvm/tir/stmt_functor.h>
 #include <tvm/te/operation.h>
@@ -429,5 +430,24 @@ Map<IterVar, PrimExpr> ScanFixPointAnalysis(const Operation& scan_op) {
   return ret;
 }
 
+
+TVM_REGISTER_GLOBAL("schedule.CreateReadGraph")
+.set_body_typed(CreateReadGraph);
+
+TVM_REGISTER_GLOBAL("schedule.PostDFSOrder")
+.set_body_typed([](const Array<Operation>& roots,
+                   const ReadGraph& g) {
+  return PostDFSOrder(roots, g);
+});
+
+TVM_REGISTER_GLOBAL("schedule.CreateAttachPath")
+.set_body_typed(CreateAttachPath);
+
+TVM_REGISTER_GLOBAL("schedule.ScanGetBody")
+.set_body_typed(ScanGetBody);
+
+TVM_REGISTER_GLOBAL("schedule.ScanFixPointAnalysis")
+.set_body_typed(ScanFixPointAnalysis);
+
 }  // namespace te
 }  // namespace tvm
index 1763bd6..d3b448d 100644 (file)
@@ -20,6 +20,7 @@
 /*!
  * \file schedule_lang.cc
  */
+#include <tvm/runtime/registry.h>
 #include <tvm/te/schedule.h>
 #include <tvm/te/operation.h>
 #include <unordered_set>
@@ -848,5 +849,118 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
     auto* op = static_cast<const ScheduleNode*>(node.get());
     p->stream << "schedule(" << op << ")";
   });
+
+
+TVM_REGISTER_GLOBAL("te.CreateSchedule")
+.set_body_typed(create_schedule);
+
+TVM_REGISTER_GLOBAL("te.StageSetScope")
+.set_body_method(&Stage::set_scope);
+
+TVM_REGISTER_GLOBAL("te.StageBind")
+.set_body_method(&Stage::bind);
+
+TVM_REGISTER_GLOBAL("te.StageSplitByFactor")
+.set_body_typed([](Stage stage, IterVar parent, PrimExpr factor) {
+  IterVar outer, inner;
+  stage.split(parent, factor, &outer, &inner);
+  return Array<IterVar>({outer, inner});
+});
+
+TVM_REGISTER_GLOBAL("te.StageSplitByNParts")
+.set_body_typed([](Stage stage, IterVar parent, PrimExpr nparts) {
+  IterVar outer, inner;
+  stage.split_by_nparts(parent, nparts, &outer, &inner);
+  return Array<IterVar>({outer, inner});
+});
+
+TVM_REGISTER_GLOBAL("te.StageFuse")
+.set_body_typed([](Stage stage, Array<IterVar> axes) {
+    IterVar fused;
+    stage.fuse(axes, &fused);
+    return fused;
+  });
+
+TVM_REGISTER_GLOBAL("te.StageComputeAt")
+.set_body_method(&Stage::compute_at);
+
+TVM_REGISTER_GLOBAL("te.StageComputeInline")
+.set_body_method(&Stage::compute_inline);
+
+TVM_REGISTER_GLOBAL("te.StageComputeRoot")
+.set_body_method(&Stage::compute_root);
+
+TVM_REGISTER_GLOBAL("te.StageReorder")
+.set_body_method(&Stage::reorder);
+
+TVM_REGISTER_GLOBAL("te.StageTile")
+.set_body_typed([](
+  Stage stage,
+  IterVar x_parent, IterVar y_parent,
+  PrimExpr x_factor, PrimExpr y_factor
+) {
+    IterVar x_outer, y_outer, x_inner, y_inner;
+    stage.tile(x_parent, y_parent,
+               x_factor, y_factor,
+               &x_outer, &y_outer,
+               &x_inner, &y_inner);
+    return Array<IterVar>({x_outer, y_outer, x_inner, y_inner});
+  });
+
+TVM_REGISTER_GLOBAL("te.StageEnvThreads")
+.set_body_method(&Stage::env_threads);
+
+TVM_REGISTER_GLOBAL("te.StageSetStorePredicate")
+.set_body_method(&Stage::set_store_predicate);
+
+TVM_REGISTER_GLOBAL("te.StageUnroll")
+.set_body_method(&Stage::unroll);
+
+TVM_REGISTER_GLOBAL("te.StageVectorize")
+.set_body_method(&Stage::vectorize);
+
+TVM_REGISTER_GLOBAL("te.StageTensorize")
+.set_body_method(&Stage::tensorize);
+
+TVM_REGISTER_GLOBAL("te.StageParallel")
+.set_body_method(&Stage::parallel);
+
+TVM_REGISTER_GLOBAL("te.StagePragma")
+.set_body_method(&Stage::pragma);
+
+TVM_REGISTER_GLOBAL("te.StagePrefetch")
+.set_body_method(&Stage::prefetch);
+
+TVM_REGISTER_GLOBAL("te.StageStorageAlign")
+.set_body_method(&Stage::storage_align);
+
+TVM_REGISTER_GLOBAL("te.StageDoubleBuffer")
+.set_body_method(&Stage::double_buffer);
+
+TVM_REGISTER_GLOBAL("te.StageOpenGL")
+.set_body_method(&Stage::opengl);
+
+TVM_REGISTER_GLOBAL("te.ScheduleNormalize")
+.set_body_method(&Schedule::normalize);
+
+TVM_REGISTER_GLOBAL("te.ScheduleCreateGroup")
+.set_body_method(&Schedule::create_group);
+
+TVM_REGISTER_GLOBAL("te.ScheduleCacheRead")
+.set_body_method(&Schedule::cache_read);
+
+TVM_REGISTER_GLOBAL("te.ScheduleCacheWrite")
+.set_body([](TVMArgs args, TVMRetValue* ret) {
+    if (args[1].IsObjectRef<Tensor>()) {
+      *ret = args[0].operator Schedule()
+          .cache_write(args[1].operator Tensor(), args[2]);
+    } else {
+      *ret = args[0].operator Schedule()
+          .cache_write(args[1].operator Array<Tensor>(), args[2]);
+    }
+  });
+
+TVM_REGISTER_GLOBAL("te.ScheduleRFactor")
+.set_body_method(&Schedule::rfactor);
 }  // namespace te
 }  // namespace tvm
index 0930f26..a110bc4 100644 (file)
@@ -20,6 +20,7 @@
 /*!
  * \file schedule_ops.cc
  */
+#include <tvm/runtime/registry.h>
 #include <tvm/tir/expr.h>
 #include <tvm/tir/ir_pass.h>
 #include <tvm/tir/stmt_functor.h>
@@ -423,5 +424,13 @@ Stmt ScheduleOps(
   return post_proc(std::move(body));
 }
 
+TVM_REGISTER_GLOBAL("schedule.ScheduleOps")
+.set_body([](TVMArgs args, TVMRetValue* ret) {
+  if (args.size() == 2)
+    *ret = ScheduleOps(args[0], args[1], false);
+  else
+    *ret = ScheduleOps(args[0], args[1], args[2]);
+});
+
 }  // namespace te
 }  // namespace tvm
index f200514..cb14f6a 100644 (file)
@@ -20,6 +20,7 @@
 /*!
  * \file tensor.cc
  */
+#include <tvm/runtime/registry.h>
 #include <tvm/te/tensor.h>
 #include <tvm/te/operation.h>
 #include <tvm/te/tensor_intrin.h>
@@ -147,5 +148,33 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
 
 TVM_REGISTER_NODE_TYPE(TensorIntrinCallNode);
 
+TVM_REGISTER_GLOBAL("te.Tensor")
+.set_body_typed(TensorNode::make);
+
+TVM_REGISTER_GLOBAL("te.TensorIntrin")
+.set_body_typed(TensorIntrinNode::make);
+
+TVM_REGISTER_GLOBAL("te.TensorIntrinCall")
+.set_body_typed(TensorIntrinCallNode::make);
+
+TVM_REGISTER_GLOBAL("te.TensorEqual")
+.set_body_method(&Tensor::operator==);
+
+TVM_REGISTER_GLOBAL("te.TensorHash")
+.set_body_typed([](Tensor tensor) -> int64_t {
+    return static_cast<int64_t>(std::hash<Tensor>()(tensor));
+  });
+
+TVM_REGISTER_GLOBAL("te.OpGetOutput")
+.set_body_typed([](Operation op, int64_t output) {
+  return op.output(static_cast<size_t>(output));
+});
+
+TVM_REGISTER_GLOBAL("te.OpNumOutputs")
+.set_body_method<Operation>(&OperationNode::num_outputs);
+
+TVM_REGISTER_GLOBAL("te.OpInputTensors")
+.set_body_method<Operation>(&OperationNode::InputTensors);
+
 }  // namespace te
 }  // namespace tvm
index d06c33f..2284474 100644 (file)
@@ -20,6 +20,7 @@
 /*!
  * \file expr.cc
  */
+#include <tvm/runtime/registry.h>
 #include <tvm/tir/expr.h>
 #include <tvm/tir/stmt.h>
 #include <tvm/tir/op.h>
@@ -45,6 +46,17 @@ SizeVar::SizeVar(std::string name_hint, DataType t)
 SizeVarNode::SizeVarNode(DataType t, std::string name_hint)
     : VarNode(t, std::move(name_hint)) {}
 
+
+TVM_REGISTER_GLOBAL("tir.Var")
+.set_body_typed([](std::string s, DataType t) {
+    return Var(s, t);
+  });
+
+TVM_REGISTER_GLOBAL("tir.SizeVar")
+.set_body_typed([](std::string s, DataType t) {
+    return SizeVar(s, t);
+  });
+
 IterVar IterVarNode::make(Range dom,
                           Var var,
                           IterVarType t,
@@ -57,6 +69,14 @@ IterVar IterVarNode::make(Range dom,
   return IterVar(n);
 }
 
+TVM_REGISTER_GLOBAL("tir.IterVar")
+.set_body_typed([](Range dom, Var var, int iter_type, std::string thread_tag) {
+  return IterVarNode::make(
+      dom, var,
+      static_cast<IterVarType>(iter_type),
+      thread_tag);
+});
+
 TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
 .set_dispatch<IterVarNode>([](const ObjectRef& node, ReprPrinter* p) {
     auto* op = static_cast<const IterVarNode*>(node.get());
@@ -83,6 +103,9 @@ PrimExpr StringImmNode::make(std::string value) {
   return PrimExpr(node);
 }
 
+TVM_REGISTER_GLOBAL("tir.StringImm")
+.set_body_typed(StringImmNode::make);
+
 PrimExpr CastNode::make(DataType t, PrimExpr value) {
   CHECK(value.defined());
   CHECK_EQ(t.lanes(), value.dtype().lanes());
@@ -311,6 +334,13 @@ Array<PrimExpr> CommReducerNode::operator()(Array<PrimExpr> a, Array<PrimExpr> b
     });
 }
 
+TVM_REGISTER_GLOBAL("tir.CommReducer")
+.set_body_typed(CommReducerNode::make);
+
+TVM_REGISTER_GLOBAL("tir.CommReducerCombine")
+.set_body_method<tir::CommReducer>(&tir::CommReducerNode::operator());
+
+
 PrimExpr ReduceNode::make(CommReducer combiner, Array<PrimExpr> source,
                   Array<IterVar> axis, PrimExpr condition, int value_index) {
   for (size_t i = 0; i < axis.size(); ++i) {
@@ -334,6 +364,11 @@ PrimExpr ReduceNode::make(CommReducer combiner, Array<PrimExpr> source,
   return PrimExpr(n);
 }
 
+
+TVM_REGISTER_GLOBAL("tir.Reduce")
+.set_body_typed(ReduceNode::make);
+
+
 PrimExpr AnyNode::make() {
   auto n = make_object<AnyNode>();
   return PrimExpr(n);
@@ -659,5 +694,104 @@ TVM_REGISTER_NODE_TYPE(CommReducerNode);
 TVM_REGISTER_NODE_TYPE(ReduceNode);
 TVM_REGISTER_NODE_TYPE(AnyNode);
 
+
+TVM_REGISTER_GLOBAL("tir.Add")
+.set_body_typed(AddNode::make);
+
+TVM_REGISTER_GLOBAL("tir.Sub")
+.set_body_typed(SubNode::make);
+
+TVM_REGISTER_GLOBAL("tir.Mul")
+.set_body_typed(MulNode::make);
+
+TVM_REGISTER_GLOBAL("tir.Div")
+.set_body_typed(DivNode::make);
+
+TVM_REGISTER_GLOBAL("tir.Mod")
+.set_body_typed(ModNode::make);
+
+TVM_REGISTER_GLOBAL("tir.FloorDiv")
+.set_body_typed(FloorDivNode::make);
+
+TVM_REGISTER_GLOBAL("tir.FloorMod")
+.set_body_typed(FloorModNode::make);
+
+TVM_REGISTER_GLOBAL("tir.Min")
+.set_body_typed(MinNode::make);
+
+TVM_REGISTER_GLOBAL("tir.Max")
+.set_body_typed(MaxNode::make);
+
+TVM_REGISTER_GLOBAL("tir.EQ")
+.set_body_typed(EQNode::make);
+
+TVM_REGISTER_GLOBAL("tir.NE")
+.set_body_typed(NENode::make);
+
+TVM_REGISTER_GLOBAL("tir.LT")
+.set_body_typed(LTNode::make);
+
+TVM_REGISTER_GLOBAL("tir.LE")
+.set_body_typed(LENode::make);
+
+TVM_REGISTER_GLOBAL("tir.GT")
+.set_body_typed(GTNode::make);
+
+TVM_REGISTER_GLOBAL("tir.GE")
+.set_body_typed(GENode::make);
+
+TVM_REGISTER_GLOBAL("tir.And")
+.set_body_typed(AndNode::make);
+
+TVM_REGISTER_GLOBAL("tir.Or")
+.set_body_typed(OrNode::make);
+
+TVM_REGISTER_GLOBAL("tir.Not")
+.set_body_typed(NotNode::make);
+
+TVM_REGISTER_GLOBAL("tir.Select")
+.set_body_typed(SelectNode::make);
+
+TVM_REGISTER_GLOBAL("tir.Ramp")
+.set_body_typed(RampNode::make);
+
+TVM_REGISTER_GLOBAL("tir.Cast")
+.set_body_typed(CastNode::make);
+
+TVM_REGISTER_GLOBAL("tir.Broadcast")
+.set_body_typed(BroadcastNode::make);
+
+TVM_REGISTER_GLOBAL("tir.Shuffle")
+.set_body_typed(ShuffleNode::make);
+
+TVM_REGISTER_GLOBAL("tir.Let")
+.set_body_typed(LetNode::make);
+
+TVM_REGISTER_GLOBAL("tir.Load")
+.set_body([](TVMArgs args,  TVMRetValue *ret) {
+    DataType t = args[0];
+    if (args.size() == 3) {
+      *ret = LoadNode::make(t, args[1], args[2], const_true(t.lanes()));
+    } else {
+      *ret = LoadNode::make(t, args[1], args[2], args[3]);
+    }
+  });
+
+
+
+TVM_REGISTER_GLOBAL("tir.Call")
+.set_body_typed([](
+  DataType type, std::string name,
+  Array<PrimExpr> args, int call_type,
+  FunctionRef func, int value_index
+) {
+  return CallNode::make(type,
+                    name,
+                    args,
+                    static_cast<CallNode::CallType>(call_type),
+                    func,
+                    value_index);
+});
+
 }  // namespace tir
 }  // namespace tvm
index 58f8b6b..452c3bb 100644 (file)
@@ -662,4 +662,90 @@ TVM_REGISTER_GLOBAL("node.LargeUIntImm")
 TVM_REGISTER_GLOBAL("node.String")
 .set_body_typed(tir::StringImmNode::make);
 
+TVM_REGISTER_GLOBAL("tir.min_value")
+.set_body_typed(min_value);
+
+TVM_REGISTER_GLOBAL("tir.max_value")
+.set_body_typed(max_value);
+
+TVM_REGISTER_GLOBAL("tir.abs")
+.set_body_typed(tvm::abs);
+
+TVM_REGISTER_GLOBAL("tir.isnan")
+.set_body_typed(tvm::isnan);
+
+TVM_REGISTER_GLOBAL("tir.floor")
+.set_body_typed(tvm::floor);
+
+TVM_REGISTER_GLOBAL("tir.ceil")
+.set_body_typed(tvm::ceil);
+
+TVM_REGISTER_GLOBAL("tir.round")
+.set_body_typed(tvm::round);
+
+TVM_REGISTER_GLOBAL("tir.nearbyint")
+.set_body_typed(tvm::nearbyint);
+
+TVM_REGISTER_GLOBAL("tir.trunc")
+.set_body_typed(tvm::trunc);
+
+TVM_REGISTER_GLOBAL("tir._cast")
+.set_body_typed(tvm::cast);
+
+
+
+// operator overloading, smarter than make
+#define REGISTER_MAKE_BINARY_OP(Node, Func)                     \
+  TVM_REGISTER_GLOBAL("tir."#Node)                              \
+  .set_body_typed([](PrimExpr a, PrimExpr b) {                  \
+    return (Func(a, b));                                        \
+  })
+
+#define REGISTER_MAKE_BIT_OP(Node, Func)                                \
+  TVM_REGISTER_GLOBAL("tir."#Node)                                      \
+  .set_body([](TVMArgs args,  TVMRetValue *ret) {                       \
+    bool lhs_is_int = args[0].type_code() == kDLInt;                    \
+    bool rhs_is_int = args[1].type_code() == kDLInt;                    \
+    if (lhs_is_int) {                                                   \
+      *ret = (Func(args[0].operator int(), args[1].operator PrimExpr())); \
+    } else if (rhs_is_int) {                                            \
+      *ret = (Func(args[0].operator PrimExpr(), args[1].operator int())); \
+    } else {                                                            \
+      *ret = (Func(args[0].operator PrimExpr(), args[1].operator PrimExpr())); \
+    }                                                                   \
+  })
+
+
+REGISTER_MAKE_BINARY_OP(_OpAdd, operator+);
+REGISTER_MAKE_BINARY_OP(_OpSub, operator-);
+REGISTER_MAKE_BINARY_OP(_OpMul, operator*);
+REGISTER_MAKE_BINARY_OP(_OpDiv, div);
+REGISTER_MAKE_BINARY_OP(_OpMod, truncmod);
+REGISTER_MAKE_BINARY_OP(_OpIndexDiv, indexdiv);
+REGISTER_MAKE_BINARY_OP(_OpIndexMod, indexmod);
+REGISTER_MAKE_BINARY_OP(_OpFloorDiv, floordiv);
+REGISTER_MAKE_BINARY_OP(_OpFloorMod, floormod);
+REGISTER_MAKE_BINARY_OP(_OpTruncDiv, truncdiv);
+REGISTER_MAKE_BINARY_OP(_OpTruncMod, truncmod);
+REGISTER_MAKE_BINARY_OP(_OpPow, pow);
+REGISTER_MAKE_BINARY_OP(_OpMin, min);
+REGISTER_MAKE_BINARY_OP(_OpMax, max);
+REGISTER_MAKE_BINARY_OP(_OpEQ, operator==);
+REGISTER_MAKE_BINARY_OP(_OpNE, operator!=);
+REGISTER_MAKE_BINARY_OP(_OpLT, operator<); // NOLINT(*)
+REGISTER_MAKE_BINARY_OP(_OpLE, operator<=); // NOLINT(*)
+REGISTER_MAKE_BINARY_OP(_OpGT, operator>);  // NOLINT(*)
+REGISTER_MAKE_BINARY_OP(_OpGE, operator>=);
+REGISTER_MAKE_BINARY_OP(_OpAnd, operator&&);
+REGISTER_MAKE_BINARY_OP(_OpOr, operator||);
+REGISTER_MAKE_BIT_OP(bitwise_and, operator&);
+REGISTER_MAKE_BIT_OP(bitwise_or, operator|);
+REGISTER_MAKE_BIT_OP(bitwise_xor, operator^);
+REGISTER_MAKE_BIT_OP(left_shift, operator<<); // NOLINT(*)
+REGISTER_MAKE_BIT_OP(right_shift, operator>>);
+
+TVM_REGISTER_GLOBAL("tir._OpIfThenElse")
+.set_body_typed([] (PrimExpr cond, PrimExpr true_value, PrimExpr false_value) {
+  return if_then_else(cond, true_value, false_value);
+});
 }  // namespace tvm
index 0cd2aba..a8fe9cd 100644 (file)
@@ -20,7 +20,7 @@
 /*!
  * \file tvm/tir/stmt.cc
  */
-
+#include <tvm/runtime/registry.h>
 #include <tvm/tir/stmt.h>
 #include <tvm/tir/ir_pass.h>
 #include "../pass/ir_util.h"
@@ -40,6 +40,9 @@ Stmt LetStmtNode::make(Var var, PrimExpr value, Stmt body) {
   return Stmt(node);
 }
 
+TVM_REGISTER_GLOBAL("tir.LetStmt")
+.set_body_typed(LetStmtNode::make);
+
 Stmt AttrStmtNode::make(ObjectRef node,
                     std::string attr_key,
                     PrimExpr value,
@@ -52,6 +55,10 @@ Stmt AttrStmtNode::make(ObjectRef node,
   return Stmt(n);
 }
 
+TVM_REGISTER_GLOBAL("tir.AttrStmt")
+.set_body_typed(AttrStmtNode::make);
+
+
 Stmt AssertStmtNode::make(PrimExpr condition, PrimExpr message, Stmt body) {
   CHECK(condition.defined());
   CHECK(message.dtype() == DataType::Int(32) ||
@@ -66,6 +73,10 @@ Stmt AssertStmtNode::make(PrimExpr condition, PrimExpr message, Stmt body) {
   return Stmt(node);
 }
 
+TVM_REGISTER_GLOBAL("tir.AssertStmt")
+.set_body_typed(AssertStmtNode::make);
+
+
 Stmt ProducerConsumerNode::make(FunctionRef func, bool is_producer, Stmt body) {
   CHECK(body.defined());
 
@@ -76,6 +87,10 @@ Stmt ProducerConsumerNode::make(FunctionRef func, bool is_producer, Stmt body) {
   return Stmt(node);
 }
 
+TVM_REGISTER_GLOBAL("tir.ProducerConsumer")
+.set_body_typed(ProducerConsumerNode::make);
+
+
 Stmt ForNode::make(Var loop_var,
                PrimExpr min,
                PrimExpr extent,
@@ -99,6 +114,19 @@ Stmt ForNode::make(Var loop_var,
   return Stmt(node);
 }
 
+TVM_REGISTER_GLOBAL("tir.For")
+.set_body_typed([](
+  Var loop_var, PrimExpr min, PrimExpr extent,
+  int for_type, int device_api, Stmt body) {
+  return ForNode::make(loop_var,
+                   min,
+                   extent,
+                   static_cast<ForType>(for_type),
+                   static_cast<DeviceAPI>(device_api),
+                   body);
+});
+
+
 Stmt StoreNode::make(Var buffer_var, PrimExpr value, PrimExpr index, PrimExpr predicate) {
   CHECK(value.defined());
   CHECK(index.defined());
@@ -114,6 +142,18 @@ Stmt StoreNode::make(Var buffer_var, PrimExpr value, PrimExpr index, PrimExpr pr
   return Stmt(node);
 }
 
+
+TVM_REGISTER_GLOBAL("tir.Store")
+.set_body([](TVMArgs args,  TVMRetValue *ret) {
+    PrimExpr value = args[1];
+    if (args.size() == 3) {
+      *ret = StoreNode::make(args[0], value, args[2], const_true(value.dtype().lanes()));
+    } else {
+      *ret = StoreNode::make(args[0], value, args[2], args[3]);
+    }
+  });
+
+
 Stmt ProvideNode::make(FunctionRef func, int value_index, PrimExpr value, Array<PrimExpr> args) {
   CHECK(value_index >=0 && value_index < func->num_outputs())
       << "value index output function return value bound";
@@ -131,6 +171,10 @@ Stmt ProvideNode::make(FunctionRef func, int value_index, PrimExpr value, Array<
   return Stmt(node);
 }
 
+TVM_REGISTER_GLOBAL("tir.Provide")
+.set_body_typed(ProvideNode::make);
+
+
 Stmt AllocateNode::make(Var buffer_var,
                     DataType dtype,
                     Array<PrimExpr> extents,
@@ -157,6 +201,15 @@ Stmt AllocateNode::make(Var buffer_var,
     return Stmt(node);
 }
 
+// overloaded, needs special handling
+// has default args
+TVM_REGISTER_GLOBAL("tir.Allocate")
+.set_body_typed([](
+    Var buffer_var, DataType type, Array<PrimExpr> extents, PrimExpr condition, Stmt body
+                   ){
+  return AllocateNode::make(buffer_var, type, extents, condition, body);
+});
+
 int32_t AllocateNode::constant_allocation_size(const Array<PrimExpr>& extents) {
   int64_t result = 1;
   for (size_t i = 0; i < extents.size(); ++i) {
@@ -178,12 +231,16 @@ Stmt FreeNode::make(Var buffer_var) {
   return Stmt(node);
 }
 
+TVM_REGISTER_GLOBAL("tir.Free")
+.set_body_typed(FreeNode::make);
+
+
 Stmt RealizeNode::make(FunctionRef func,
-                   int value_index,
-                   DataType dtype,
-                   Region bounds,
-                   PrimExpr condition,
-                   Stmt body) {
+                       int value_index,
+                       DataType dtype,
+                       Region bounds,
+                       PrimExpr condition,
+                       Stmt body) {
   for (size_t i = 0; i < bounds.size(); ++i) {
     CHECK(bounds[i]->min.defined());
     CHECK(bounds[i]->extent.defined());
@@ -204,6 +261,11 @@ Stmt RealizeNode::make(FunctionRef func,
   return Stmt(node);
 }
 
+
+TVM_REGISTER_GLOBAL("tir.Realize")
+.set_body_typed(RealizeNode::make);
+
+
 Stmt PrefetchNode::make(FunctionRef func, int value_index, DataType dtype, Region bounds) {
   for (size_t i = 0; i < bounds.size(); ++i) {
     CHECK(bounds[i]->min.defined());
@@ -220,12 +282,21 @@ Stmt PrefetchNode::make(FunctionRef func, int value_index, DataType dtype, Regio
   return Stmt(node);
 }
 
+TVM_REGISTER_GLOBAL("tir.Prefetch")
+.set_body_typed(PrefetchNode::make);
+
+
 SeqStmt::SeqStmt(Array<Stmt> seq) {
   auto node = make_object<SeqStmtNode>();
   node->seq = std::move(seq);
   data_ = std::move(node);
 }
 
+TVM_REGISTER_GLOBAL("tir.SeqStmt")
+.set_body_typed([](Array<Stmt> seq) {
+  return SeqStmt(std::move(seq));
+});
+
 Stmt IfThenElseNode::make(PrimExpr condition, Stmt then_case, Stmt else_case) {
   CHECK(condition.defined());
   CHECK(then_case.defined());
@@ -238,6 +309,10 @@ Stmt IfThenElseNode::make(PrimExpr condition, Stmt then_case, Stmt else_case) {
   return Stmt(node);
 }
 
+TVM_REGISTER_GLOBAL("tir.IfThenElse")
+.set_body_typed(IfThenElseNode::make);
+
+
 Stmt EvaluateNode::make(PrimExpr value) {
   CHECK(value.defined());
 
@@ -246,6 +321,9 @@ Stmt EvaluateNode::make(PrimExpr value) {
   return Stmt(node);
 }
 
+TVM_REGISTER_GLOBAL("tir.Evaluate")
+.set_body_typed(EvaluateNode::make);
+
 // Printers
 
 TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
similarity index 99%
rename from src/api/api_pass.cc
rename to src/tir/pass/ffi_api.cc
index 75d5439..233bfa5 100644 (file)
@@ -19,7 +19,7 @@
 
 /*!
  *  Exposure of pass functions.
- * \file api_pass.cc
+ * \file ffi_api.cc
  */
 #include <tvm/tir/expr.h>
 #include <tvm/tir/stmt.h>
@@ -136,8 +136,8 @@ TVM_REGISTER_GLOBAL("ir_pass.LowerStorageAccess")
 
 // make from two arguments
 #define REGISTER_PASS(PassName)                                   \
-  TVM_REGISTER_GLOBAL("ir_pass."#PassName)                           \
-  .set_body_typed(PassName);                                     \
+  TVM_REGISTER_GLOBAL("ir_pass."#PassName)                        \
+  .set_body_typed(PassName);                                      \
 
 
 REGISTER_PASS(ConvertSSA);
index d1a2d98..ac019a0 100644 (file)
@@ -27,7 +27,7 @@ def test_op_translation():
     except tvm.error.OpNotImplemented as e:
         msg = str(e)
         assert isinstance(e, NotImplementedError)
-        assert msg.find("api_test.cc") != -1
+        assert msg.find("ffi_testing.cc") != -1
 
     fchk_eq = tvm.testing.test_check_eq_callback(
         "InternalError: myop")
@@ -36,14 +36,14 @@ def test_op_translation():
         assert False
     except tvm.error.InternalError as e:
         msg = str(e)
-        assert msg.find("api_test.cc") != -1
+        assert msg.find("ffi_testing.cc") != -1
 
     try:
         tvm.testing.ErrorTest(0, 1)
         assert False
     except ValueError as e:
         msg = str(e)
-        assert msg.find("api_test.cc") != -1
+        assert msg.find("ffi_testing.cc") != -1
 
 
 def test_deep_callback():