[REFACTOR] Use more TypedPackedFuncs (#2981)
authorJames Gilles <jhgilles@mit.edu>
Wed, 10 Apr 2019 21:28:25 +0000 (17:28 -0400)
committerTianqi Chen <tqchen@users.noreply.github.com>
Wed, 10 Apr 2019 21:28:25 +0000 (14:28 -0700)
* Add `set_body_simple` to Registry, refactor a lot of code to use it

* Add more types to Relay PackedFuncs

* Add Registry::set_body_method to easily make Node methods into
PackedFuncs

* Add set_body_method, set_body_node_method; start typing api_lang

* Add some docs, remove unused script

* Fix mysterious linter problem

* Touch up api_ir.cc

* Fix some issues with TOPI argument counts

* Revert changes to topi.cc to avoid problems with optional arguments

* A little more cleanup

* Type more of the api _ functions

* Whitespace

* Finalize names and docs for new registry helpers

* Update docs

68 files changed:
include/tvm/runtime/registry.h
nnvm/src/compiler/compile_engine.cc
nnvm/src/compiler/graph_hash.cc
src/api/api_arith.cc
src/api/api_base.cc
src/api/api_codegen.cc
src/api/api_ir.cc
src/api/api_lang.cc
src/api/api_pass.cc
src/api/api_schedule.cc
src/codegen/codegen_opencl.cc
src/codegen/codegen_opengl.cc
src/codegen/codegen_vhls.cc
src/codegen/llvm/codegen_amdgpu.cc
src/codegen/llvm/codegen_nvptx.cc
src/codegen/opt/build_cuda_on.cc
src/codegen/source_module.cc
src/codegen/spirv/build_vulkan.cc
src/codegen/stackvm/codegen_stackvm.cc
src/relay/backend/interpreter.cc
src/relay/ir/adt.cc
src/relay/ir/alpha_equal.cc
src/relay/ir/base.cc
src/relay/ir/expr.cc
src/relay/ir/expr_functor.cc
src/relay/ir/hash.cc
src/relay/ir/module.cc
src/relay/ir/type.cc
src/relay/op/debug.cc
src/relay/op/image/resize.cc
src/relay/op/nn/convolution.cc
src/relay/op/nn/nn.cc
src/relay/op/nn/pad.cc
src/relay/op/nn/pooling.cc
src/relay/op/nn/upsampling.cc
src/relay/op/tensor/reduce.cc
src/relay/op/tensor/transform.cc
src/relay/op/vision/multibox_op.cc
src/relay/op/vision/nms.cc
src/relay/op/vision/rcnn_op.cc
src/relay/op/vision/yolo.cc
src/relay/pass/canonicalize_ops.cc
src/relay/pass/combine_parallel_conv2d.cc
src/relay/pass/dead_code.cc
src/relay/pass/device_annotation.cc
src/relay/pass/fold_constant.cc
src/relay/pass/fuse_ops.cc
src/relay/pass/gradient.cc
src/relay/pass/mac_count.cc
src/relay/pass/pass_manager.cc
src/relay/pass/quantize.cc
src/relay/pass/simplify_inference.cc
src/relay/pass/to_a_normal_form.cc
src/relay/pass/to_graph_normal_form.cc
src/relay/pass/type_infer.cc
src/relay/pass/util.cc
src/relay/pass/well_formed.cc
src/runtime/cuda/cuda_module.cc
src/runtime/metal/metal_module.mm
src/runtime/opencl/aocl/aocl_module.cc
src/runtime/opencl/opencl_module.cc
src/runtime/opencl/sdaccel/sdaccel_module.cc
src/runtime/rocm/rocm_module.cc
src/runtime/rpc/rpc_event_impl.cc
src/runtime/rpc/rpc_socket_impl.cc
src/runtime/stackvm/stackvm_module.cc
src/runtime/vulkan/vulkan_module.cc
web/web_runtime.cc

index 50bb5c5..40e1a52 100644 (file)
@@ -83,6 +83,169 @@ class Registry {
   Registry& set_body_typed(FLambda f) {
     return set_body(TypedPackedFunc<FType>(f).packed());
   }
+
+  /*!
+   * \brief set the body of the function to the given function pointer.
+   *        Note that this doesn't work with lambdas, you need to
+   *        explicitly give a type for those.
+   *        Note that this will ignore default arg values and always require all arguments to be provided.
+   *
+   * \code
+   * 
+   * int multiply(int x, int y) {
+   *   return x * y;
+   * }
+   *
+   * TVM_REGISTER_API("multiply")
+   * .set_body_typed(multiply); // will have type int(int, int)
+   *
+   * \endcode
+   *
+   * \param f The function to forward to.
+   * \tparam R the return type of the function (inferred).
+   * \tparam Args the argument types of the function (inferred).
+   */
+  template<typename R, typename ...Args>
+  Registry& set_body_typed(R (*f)(Args...)) {
+    return set_body(TypedPackedFunc<R(Args...)>(f));
+  }
+
+  /*!
+   * \brief set the body of the function to be the passed method pointer.
+   *        Note that this will ignore default arg values and always require all arguments to be provided.
+   *
+   * \code
+   * 
+   * // node subclass:
+   * struct Example {
+   *    int doThing(int x);
+   * }
+   * TVM_REGISTER_API("Example_doThing")
+   * .set_body_method(&Example::doThing); // will have type int(Example, int)
+   *
+   * \endcode
+   *
+   * \param f the method pointer to forward to.
+   * \tparam T the type containing the method (inferred).
+   * \tparam R the return type of the function (inferred).
+   * \tparam Args the argument types of the function (inferred).
+   */
+  template<typename T, typename R, typename ...Args>
+  Registry& set_body_method(R (T::*f)(Args...)) {
+    return set_body_typed<R(T, Args...)>([f](T target, Args... params) -> R {
+      // call method pointer
+      return (target.*f)(params...);
+    });
+  }
+
+  /*!
+   * \brief set the body of the function to be the passed method pointer.
+   *        Note that this will ignore default arg values and always require all arguments to be provided.
+   *
+   * \code
+   * 
+   * // node subclass:
+   * struct Example {
+   *    int doThing(int x);
+   * }
+   * TVM_REGISTER_API("Example_doThing")
+   * .set_body_method(&Example::doThing); // will have type int(Example, int)
+   *
+   * \endcode
+   *
+   * \param f the method pointer to forward to.
+   * \tparam T the type containing the method (inferred).
+   * \tparam R the return type of the function (inferred).
+   * \tparam Args the argument types of the function (inferred).
+   */
+  template<typename T, typename R, typename ...Args>
+  Registry& set_body_method(R (T::*f)(Args...) const) {
+    return set_body_typed<R(T, Args...)>([f](const T target, Args... params) -> R {
+      // call method pointer
+      return (target.*f)(params...);
+    });
+  }
+
+  /*!
+   * \brief set the body of the function to be the passed method pointer.
+   *        Used when calling a method on a Node subclass through a NodeRef subclass.
+   *        Note that this will ignore default arg values and always require all arguments to be provided.
+   *
+   * \code
+   * 
+   * // node subclass:
+   * struct ExampleNode: BaseNode {
+   *    int doThing(int x);
+   * }
+   * 
+   * // noderef subclass
+   * struct Example; 
+   *
+   * TVM_REGISTER_API("Example_doThing")
+   * .set_body_method<Example>(&ExampleNode::doThing); // will have type int(Example, int)
+   * 
+   * // note that just doing:
+   * // .set_body_method(&ExampleNode::doThing);
+   * // wouldn't work, because ExampleNode can't be taken from a TVMArgValue.
+   *
+   * \endcode
+   *
+   * \param f the method pointer to forward to.
+   * \tparam TNodeRef the node reference type to call the method on
+   * \tparam TNode the node type containing the method (inferred).
+   * \tparam R the return type of the function (inferred).
+   * \tparam Args the argument types of the function (inferred).
+   */
+  template<typename TNodeRef, typename TNode, typename R, typename ...Args,
+    typename = typename std::enable_if<std::is_base_of<NodeRef, TNodeRef>::value>::type>
+  Registry& set_body_method(R (TNode::*f)(Args...)) {
+    return set_body_typed<R(TNodeRef, Args...)>([f](TNodeRef ref, Args... params) {
+      TNode* target = ref.operator->();
+      // call method pointer
+      return (target->*f)(params...);
+    });
+  }
+
+  /*!
+   * \brief set the body of the function to be the passed method pointer.
+   *        Used when calling a method on a Node subclass through a NodeRef subclass.
+   *        Note that this will ignore default arg values and always require all arguments to be provided.
+   *
+   * \code
+   * 
+   * // node subclass:
+   * struct ExampleNode: BaseNode {
+   *    int doThing(int x);
+   * }
+   * 
+   * // noderef subclass
+   * struct Example; 
+   *
+   * TVM_REGISTER_API("Example_doThing")
+   * .set_body_method<Example>(&ExampleNode::doThing); // will have type int(Example, int)
+   * 
+   * // note that just doing:
+   * // .set_body_method(&ExampleNode::doThing);
+   * // wouldn't work, because ExampleNode can't be taken from a TVMArgValue.
+   *
+   * \endcode
+   *
+   * \param f the method pointer to forward to.
+   * \tparam TNodeRef the node reference type to call the method on
+   * \tparam TNode the node type containing the method (inferred).
+   * \tparam R the return type of the function (inferred).
+   * \tparam Args the argument types of the function (inferred).
+   */
+  template<typename TNodeRef, typename TNode, typename R, typename ...Args,
+    typename = typename std::enable_if<std::is_base_of<NodeRef, TNodeRef>::value>::type>
+  Registry& set_body_method(R (TNode::*f)(Args...) const) {
+    return set_body_typed<R(TNodeRef, Args...)>([f](TNodeRef ref, Args... params) {
+      const TNode* target = ref.operator->();
+      // call method pointer
+      return (target->*f)(params...);
+    });
+  }
+
   /*!
    * \brief Register a function with given name
    * \param name The name of the function.
index a1422d7..5424559 100644 (file)
@@ -360,9 +360,7 @@ TVM_REGISTER_GLOBAL("nnvm.compiler.GraphKeyGetGraph")
   });
 
 TVM_REGISTER_GLOBAL("nnvm.compiler.MakeGraphKey")
-.set_body([](tvm::runtime::TVMArgs args, tvm::runtime::TVMRetValue *rv) {
-    *rv = GraphKeyNode::make(args[0], args[1], args[2]);
-  });
+.set_body_typed(GraphKeyNode::make);
 
 // This can be used to extract workloads from nnvm compiler
 TVM_REGISTER_GLOBAL("nnvm.compiler.CacheItem2ScheduleArgs")
index e825ef4..b76f99f 100644 (file)
@@ -235,8 +235,6 @@ std::string GraphDeepCompare(const Graph& a,
 }
 
 TVM_REGISTER_GLOBAL("nnvm.graph.DeepCompare")
-.set_body([](tvm::runtime::TVMArgs args, tvm::runtime::TVMRetValue *rv) {
-    *rv = GraphDeepCompare(args[0], args[1], args[2]);
-  });
+.set_body_typed(GraphDeepCompare);
 }  // namespace compiler
 }  // namespace nnvm
index ca0bed1..fce73aa 100644 (file)
@@ -31,73 +31,51 @@ namespace tvm {
 namespace arith {
 
 TVM_REGISTER_API("arith.intset_single_point")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
-    *ret = IntSet::single_point(args[0]);
-  });
+.set_body_typed(IntSet::single_point);
 
 TVM_REGISTER_API("arith.intset_vector")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
-    *ret = IntSet::vector(args[0]);
-  });
+.set_body_typed(IntSet::vector);
 
 TVM_REGISTER_API("arith.intset_interval")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
-    *ret = IntSet::interval(args[0], args[1]);
-  });
+.set_body_typed(IntSet::interval);
 
 TVM_REGISTER_API("arith.DetectLinearEquation")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
-    *ret = DetectLinearEquation(args[0], args[1]);
-  });
+.set_body_typed(DetectLinearEquation);
 
 TVM_REGISTER_API("arith.DetectClipBound")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
-    *ret = DetectClipBound(args[0], args[1]);
-  });
+.set_body_typed(DetectClipBound);
 
 TVM_REGISTER_API("arith.DeduceBound")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
-    *ret = DeduceBound(args[0], args[1],
-        args[2].operator Map<Var, IntSet>(),
-        args[3].operator Map<Var, IntSet>());
-  });
+.set_body_typed<IntSet(Expr, Expr, Map<Var, IntSet>, Map<Var, IntSet>)>([](
+  Expr v, Expr cond,
+  const Map<Var, IntSet> hint_map,
+  const Map<Var, IntSet> relax_map
+) {
+  return DeduceBound(v, cond, hint_map, relax_map);
+});
 
 
 TVM_REGISTER_API("arith.DomainTouched")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
-    *ret = DomainTouched(args[0], args[1], args[2], args[3]);
-  });
+.set_body_typed(DomainTouched);
 
 
 TVM_REGISTER_API("_IntervalSetGetMin")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
-    *ret = args[0].operator IntSet().min();
-  });
+.set_body_method(&IntSet::min);
 
 TVM_REGISTER_API("_IntervalSetGetMax")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
-    *ret = args[0].operator IntSet().max();
-  });
+.set_body_method(&IntSet::max);
 
 TVM_REGISTER_API("_IntSetIsNothing")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
-    *ret = args[0].operator IntSet().is_nothing();
-  });
+.set_body_method(&IntSet::is_nothing);
 
 TVM_REGISTER_API("_IntSetIsEverything")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
-    *ret = args[0].operator IntSet().is_everything();
-  });
+.set_body_method(&IntSet::is_everything);
 
 TVM_REGISTER_API("arith._make_ConstIntBound")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
-    *ret = ConstIntBoundNode::make(args[0], args[1]);
-  });
+.set_body_typed(ConstIntBoundNode::make);
 
 TVM_REGISTER_API("arith._make_ModularSet")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
-    *ret = ModularSetNode::make(args[0], args[1]);
-  });
+.set_body_typed(ModularSetNode::make);
 
 TVM_REGISTER_API("arith._CreateAnalyzer")
 .set_body([](TVMArgs args, TVMRetValue* ret) {
index 23d1f5c..28ebb4d 100644 (file)
@@ -50,9 +50,8 @@ TVM_REGISTER_API("_load_json")
 .set_body_typed<NodeRef(std::string)>(LoadJSON<NodeRef>);
 
 TVM_REGISTER_API("_TVMSetStream")
-.set_body([](TVMArgs args,  TVMRetValue *ret) {
-    TVMSetStream(args[0], args[1], args[2]);
-  });
+.set_body_typed(TVMSetStream);
+
 TVM_REGISTER_API("_save_param_dict")
 .set_body([](TVMArgs args, TVMRetValue *rv) {
     CHECK_EQ(args.size() % 2, 0u);
index e44ebbe..73e2671 100644 (file)
@@ -41,8 +41,6 @@ TVM_REGISTER_API("codegen._Build")
   });
 
 TVM_REGISTER_API("module._PackImportsToC")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
-    *ret = PackImportsToC(args[0], args[1]);
-  });
+.set_body_typed(PackImportsToC);
 }  // namespace codegen
 }  // namespace tvm
index c5680bb..2525059 100644 (file)
@@ -31,54 +31,43 @@ namespace tvm {
 namespace ir {
 
 TVM_REGISTER_API("_Var")
-.set_body([](TVMArgs args,  TVMRetValue *ret) {
-    *ret = Variable::make(args[1], args[0]);
+.set_body_typed<VarExpr(std::string, Type)>([](std::string s, Type t) {
+    return Variable::make(t, s);
   });
 
 TVM_REGISTER_API("make.abs")
-.set_body([](TVMArgs args,  TVMRetValue *ret) {
-    *ret = tvm::abs(args[0]);
-  });
+.set_body_typed(tvm::abs);
 
 TVM_REGISTER_API("make.floor")
-.set_body([](TVMArgs args,  TVMRetValue *ret) {
-    *ret = tvm::floor(args[0]);
-  });
+.set_body_typed(tvm::floor);
 
 TVM_REGISTER_API("make.ceil")
-.set_body([](TVMArgs args,  TVMRetValue *ret) {
-    *ret = tvm::ceil(args[0]);
-  });
+.set_body_typed(tvm::ceil);
 
 TVM_REGISTER_API("make.round")
-.set_body([](TVMArgs args,  TVMRetValue *ret) {
-    *ret = tvm::round(args[0]);
-  });
+.set_body_typed(tvm::round);
 
 TVM_REGISTER_API("make.trunc")
-.set_body([](TVMArgs args,  TVMRetValue *ret) {
-    *ret = tvm::trunc(args[0]);
-  });
+.set_body_typed(tvm::trunc);
 
 TVM_REGISTER_API("make._cast")
-.set_body([](TVMArgs args,  TVMRetValue *ret) {
-    *ret = tvm::cast(args[0], args[1]);
-  });
+.set_body_typed(tvm::cast);
 
 TVM_REGISTER_API("make._range_by_min_extent")
-.set_body([](TVMArgs args,  TVMRetValue *ret) {
-    *ret = Range::make_by_min_extent(args[0], args[1]);
-  });
+.set_body_typed(Range::make_by_min_extent);
 
 TVM_REGISTER_API("make.For")
-.set_body([](TVMArgs args,  TVMRetValue *ret) {
-    *ret = For::make(args[0],
-                     args[1],
-                     args[2],
-                     static_cast<ForType>(args[3].operator int()),
-                     static_cast<HalideIR::DeviceAPI>(args[4].operator int()),
-                     args[5]);
-  });
+.set_body_typed<Stmt(VarExpr, Expr, Expr, int, int, Stmt)>([](
+  VarExpr loop_var, Expr min, Expr extent,
+  int for_type, int device_api, Stmt body
+) {
+  return For::make(loop_var,
+                    min,
+                    extent,
+                    static_cast<ForType>(for_type),
+                    static_cast<HalideIR::DeviceAPI>(device_api),
+                    body);
+});
 
 TVM_REGISTER_API("make.Load")
 .set_body([](TVMArgs args,  TVMRetValue *ret) {
@@ -101,114 +90,87 @@ TVM_REGISTER_API("make.Store")
   });
 
 TVM_REGISTER_API("make.Realize")
-.set_body([](TVMArgs args,  TVMRetValue *ret) {
-    *ret = Realize::make(args[0],
-                         args[1],
-                         args[2],
-                         args[3],
-                         args[4],
-                         args[5]);
-  });
-
+.set_body_typed(Realize::make);
 
 TVM_REGISTER_API("make.Call")
-.set_body([](TVMArgs args,  TVMRetValue *ret) {
-    *ret = Call::make(args[0],
-                      args[1],
-                      args[2],
-                      static_cast<Call::CallType>(args[3].operator int()),
-                      args[4],
-                      args[5]);
-  });
+.set_body_typed<Expr(Type, std::string, Array<Expr>, int, FunctionRef, int)>([](
+  Type type, std::string name,
+  Array<Expr> args, int call_type,
+  FunctionRef func, int value_index
+) {
+  return Call::make(type,
+                    name,
+                    args,
+                    static_cast<Call::CallType>(call_type),
+                    func,
+                    value_index);
+});
 
 TVM_REGISTER_API("make.CommReducer")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
-    *ret = CommReducerNode::make(args[0],
-                                 args[1],
-                                 args[2],
-                                 args[3]);
-  });
+.set_body_typed(CommReducerNode::make);
 
 // make from two arguments
-#define REGISTER_MAKE1(Node)                                 \
-  TVM_REGISTER_API("make."#Node)                             \
-  .set_body([](TVMArgs args,  TVMRetValue *ret) {            \
-      *ret = Node::make(args[0]);                            \
-    })                                                       \
-
-#define REGISTER_MAKE2(Node)                                 \
+#define REGISTER_MAKE(Node)                                  \
   TVM_REGISTER_API("make."#Node)                             \
-  .set_body([](TVMArgs args,  TVMRetValue *ret) {            \
-      *ret = Node::make(args[0], args[1]);                   \
-    })                                                       \
-
-#define REGISTER_MAKE3(Node)                                 \
-  TVM_REGISTER_API("make."#Node)                             \
-  .set_body([](TVMArgs args,  TVMRetValue *ret) {            \
-      *ret = Node::make(args[0], args[1], args[2]);          \
-    })                                                       \
-
-#define REGISTER_MAKE4(Node)                                            \
-  TVM_REGISTER_API("make."#Node)                                        \
-  .set_body([](TVMArgs args,  TVMRetValue *ret) {                       \
-      *ret = Node::make(args[0], args[1], args[2], args[3]);            \
-    })                                                                  \
-
-#define REGISTER_MAKE5(Node)                                            \
-  TVM_REGISTER_API("make."#Node)                                        \
-  .set_body([](TVMArgs args,  TVMRetValue *ret) {                       \
-      *ret = Node::make(args[0], args[1], args[2], args[3], args[4]);   \
-    })                                                                  \
-
-
-REGISTER_MAKE5(Reduce);
-REGISTER_MAKE4(AttrStmt);
-
-REGISTER_MAKE2(IntImm);
-REGISTER_MAKE2(UIntImm);
-REGISTER_MAKE2(FloatImm);
-REGISTER_MAKE1(StringImm);
-
-REGISTER_MAKE2(Add);
-REGISTER_MAKE2(Sub);
-REGISTER_MAKE2(Mul);
-REGISTER_MAKE2(Div);
-REGISTER_MAKE2(Mod);
-REGISTER_MAKE2(Min);
-REGISTER_MAKE2(Max);
-REGISTER_MAKE2(EQ);
-REGISTER_MAKE2(NE);
-REGISTER_MAKE2(LT);
-REGISTER_MAKE2(LE);
-REGISTER_MAKE2(GT);
-REGISTER_MAKE2(GE);
-REGISTER_MAKE2(And);
-REGISTER_MAKE2(Or);
-
-REGISTER_MAKE1(Not);
-REGISTER_MAKE3(Select);
-REGISTER_MAKE3(Ramp);
-REGISTER_MAKE2(Cast);
-REGISTER_MAKE2(Broadcast);
-REGISTER_MAKE2(Shuffle);
-REGISTER_MAKE3(Let);
-REGISTER_MAKE3(LetStmt);
-REGISTER_MAKE3(AssertStmt);
-REGISTER_MAKE3(ProducerConsumer);
-REGISTER_MAKE5(Allocate);
-REGISTER_MAKE4(Provide);
-REGISTER_MAKE4(Prefetch);
-REGISTER_MAKE1(Free);
-REGISTER_MAKE2(Block);
-REGISTER_MAKE3(IfThenElse);
-REGISTER_MAKE1(Evaluate);
+  .set_body_typed(Node::make);                              \
+
+REGISTER_MAKE(Reduce);
+REGISTER_MAKE(AttrStmt);
+
+REGISTER_MAKE(IntImm);
+REGISTER_MAKE(UIntImm);
+REGISTER_MAKE(FloatImm);
+REGISTER_MAKE(StringImm);
+
+REGISTER_MAKE(Add);
+REGISTER_MAKE(Sub);
+REGISTER_MAKE(Mul);
+REGISTER_MAKE(Div);
+REGISTER_MAKE(Mod);
+REGISTER_MAKE(Min);
+REGISTER_MAKE(Max);
+REGISTER_MAKE(EQ);
+REGISTER_MAKE(NE);
+REGISTER_MAKE(LT);
+REGISTER_MAKE(LE);
+REGISTER_MAKE(GT);
+REGISTER_MAKE(GE);
+REGISTER_MAKE(And);
+REGISTER_MAKE(Or);
+
+REGISTER_MAKE(Not);
+REGISTER_MAKE(Select);
+REGISTER_MAKE(Ramp);
+REGISTER_MAKE(Cast);
+REGISTER_MAKE(Broadcast);
+REGISTER_MAKE(Shuffle);
+REGISTER_MAKE(Let);
+REGISTER_MAKE(LetStmt);
+REGISTER_MAKE(AssertStmt);
+REGISTER_MAKE(ProducerConsumer);
+REGISTER_MAKE(Provide);
+REGISTER_MAKE(Prefetch);
+REGISTER_MAKE(Free);
+REGISTER_MAKE(IfThenElse);
+REGISTER_MAKE(Evaluate);
+
+// overloaded, needs special handling
+TVM_REGISTER_API("make.Block")
+  .set_body_typed(static_cast<Stmt (*)(Stmt, Stmt)>(Block::make));
+
+// has default args
+TVM_REGISTER_API("make.Allocate")
+  .set_body_typed<Stmt(VarExpr, Type, Array<Expr>, Expr, Stmt)>([](
+    VarExpr buffer_var, Type type, Array<Expr> extents, Expr condition, Stmt body
+  ){
+    return Allocate::make(buffer_var, type, extents, condition, body);
+  });
 
 // operator overloading, smarter than make
 #define REGISTER_MAKE_BINARY_OP(Node, Func)                  \
   TVM_REGISTER_API("make."#Node)                             \
-  .set_body([](TVMArgs args,  TVMRetValue *ret) {            \
-      Expr a = args[0], b = args[1];                         \
-      *ret = (Func(a, b));                                   \
+  .set_body_typed<Expr(Expr, Expr)>([](Expr a, Expr b) {     \
+      return (Func(a, b));                                   \
     })
 
 #define REGISTER_MAKE_BIT_OP(Node, Func)                                \
index aac73f1..42d60b8 100644 (file)
 #include <tvm/build_module.h>
 #include <tvm/data_layout.h>
 
+
 namespace tvm {
 
 TVM_REGISTER_API("_min_value")
-.set_body([](TVMArgs args,  TVMRetValue* ret) {
-    Type t = args[0].operator Type();
-    *ret = t.min();
-  });
+.set_body_method(&Type::min);
 
 TVM_REGISTER_API("_max_value")
-.set_body([](TVMArgs args,  TVMRetValue* ret) {
-    Type t = args[0].operator Type();
-    *ret = t.max();
-  });
+.set_body_method(&Type::max);
 
 TVM_REGISTER_API("_const")
 .set_body([](TVMArgs args,  TVMRetValue* ret) {
@@ -58,9 +53,7 @@ TVM_REGISTER_API("_const")
   });
 
 TVM_REGISTER_API("_str")
-.set_body([](TVMArgs args,  TVMRetValue* ret) {
-  *ret = ir::StringImm::make(args[0]);
-});
+.set_body_typed(ir::StringImm::make);
 
 
 TVM_REGISTER_API("_Array")
@@ -214,373 +207,217 @@ TVM_REGISTER_API("Range")
   });
 
 TVM_REGISTER_API("_Buffer")
-.set_body([](TVMArgs args,  TVMRetValue* ret) {
-    *ret = BufferNode::make(args[0],
-                            args[1],
-                            args[2],
-                            args[3],
-                            args[4],
-                            args[5],
-                            args[6],
-                            args[7],
-                            args[8]);
-  });
+.set_body_typed(BufferNode::make);
 
 TVM_REGISTER_API("_BufferAccessPtr")
-.set_body([](TVMArgs args,  TVMRetValue* ret) {
-    *ret = args[0].operator Buffer()
-        .access_ptr(args[1], args[2], args[3], args[4]);
-  });
+.set_body_method(&Buffer::access_ptr);
 
 TVM_REGISTER_API("_BufferVLoad")
-.set_body([](TVMArgs args,  TVMRetValue* ret) {
-    *ret = args[0].operator Buffer()
-        .vload(args[1], args[2]);
-  });
+.set_body_method(&Buffer::vload);
 
 TVM_REGISTER_API("_BufferVStore")
-.set_body([](TVMArgs args,  TVMRetValue* ret) {
-    *ret = args[0].operator Buffer()
-        .vstore(args[1], args[2]);
-  });
+.set_body_method(&Buffer::vstore);
 
 TVM_REGISTER_API("_Layout")
-.set_body([](TVMArgs args,  TVMRetValue* ret) {
-    *ret = LayoutNode::make(args[0]);
-  });
+.set_body_typed(LayoutNode::make);
 
 TVM_REGISTER_API("_LayoutIndexOf")
-.set_body([](TVMArgs args,  TVMRetValue* ret) {
-  *ret = args[0].operator Layout()
-      .IndexOf(LayoutAxis::make(args[1]));
+.set_body_typed<int(Layout, std::string)>([](Layout layout, std::string axis) {
+  return layout.IndexOf(LayoutAxis::make(axis));
 });
 
 TVM_REGISTER_API("_LayoutFactorOf")
-.set_body([](TVMArgs args,  TVMRetValue* ret) {
-  *ret = args[0].operator Layout()
-      .FactorOf(LayoutAxis::make(args[1]));
+.set_body_typed<int(Layout, std::string)>([](Layout layout, std::string axis) {
+  return layout.FactorOf(LayoutAxis::make(axis));
 });
 
 TVM_REGISTER_API("_LayoutNdim")
-.set_body([](TVMArgs args,  TVMRetValue* ret) {
-  *ret = static_cast<int64_t>(args[0].operator Layout().ndim());
+.set_body_typed<int(Layout)>([](Layout layout) {
+  return layout.ndim();
 });
 
 TVM_REGISTER_API("_LayoutGetItem")
-.set_body([](TVMArgs args,  TVMRetValue* ret) {
-  const LayoutAxis& axis = args[0].operator Layout()[args[1]];
-  *ret = axis.name();
+.set_body_typed<std::string(Layout, int)>([](Layout layout, int idx) {
+  const LayoutAxis& axis = layout[idx];
+  return axis.name();
 });
 
 TVM_REGISTER_API("_BijectiveLayout")
-.set_body([](TVMArgs args,  TVMRetValue* ret) {
-    *ret = BijectiveLayoutNode::make(args[0], args[1]);
-  });
+.set_body_typed(BijectiveLayoutNode::make);
 
 TVM_REGISTER_API("_BijectiveLayoutForwardIndex")
-.set_body([](TVMArgs args,  TVMRetValue* ret) {
-    *ret = args[0].operator BijectiveLayout()
-        .ForwardIndex(args[1]);
-  });
+.set_body_method(&BijectiveLayout::ForwardIndex);
 
 TVM_REGISTER_API("_BijectiveLayoutBackwardIndex")
-.set_body([](TVMArgs args,  TVMRetValue* ret) {
-    *ret = args[0].operator BijectiveLayout()
-        .BackwardIndex(args[1]);
-  });
+.set_body_method(&BijectiveLayout::BackwardIndex);
 
 TVM_REGISTER_API("_BijectiveLayoutForwardShape")
-.set_body([](TVMArgs args,  TVMRetValue* ret) {
-    *ret = args[0].operator BijectiveLayout()
-        .ForwardShape(args[1]);
-  });
+.set_body_method(&BijectiveLayout::ForwardShape);
 
 TVM_REGISTER_API("_BijectiveLayoutBackwardShape")
-.set_body([](TVMArgs args,  TVMRetValue* ret) {
-    *ret = args[0].operator BijectiveLayout()
-        .BackwardShape(args[1]);
-  });
+.set_body_method(&BijectiveLayout::BackwardShape);
 
 TVM_REGISTER_API("_Tensor")
-.set_body([](TVMArgs args,  TVMRetValue* ret) {
-    *ret = TensorNode::make(args[0],
-                            args[1],
-                            args[2],
-                            args[3]);
-  });
+.set_body_typed(TensorNode::make);
 
 TVM_REGISTER_API("_TensorIntrin")
-.set_body([](TVMArgs args,  TVMRetValue* ret) {
-    *ret = TensorIntrinNode::make(args[0],
-                                  args[1],
-                                  args[2],
-                                  args[3],
-                                  args[4],
-                                  args[5],
-                                  args[6]);
-  });
+.set_body_typed(TensorIntrinNode::make);
 
 TVM_REGISTER_API("_TensorIntrinCall")
-.set_body([](TVMArgs args,  TVMRetValue* ret) {
-    *ret = TensorIntrinCallNode::make(args[0],
-                                      args[1],
-                                      args[2],
-                                      args[3]);
-  });
+.set_body_typed(TensorIntrinCallNode::make);
 
 TVM_REGISTER_API("_TensorEqual")
-.set_body([](TVMArgs args,  TVMRetValue* ret) {
-    *ret = args[0].operator Tensor() == args[1].operator Tensor();
-  });
+.set_body_method(&Tensor::operator==);
 
 TVM_REGISTER_API("_TensorHash")
-.set_body([](TVMArgs args,  TVMRetValue* ret) {
-    *ret = static_cast<int64_t>(
-        std::hash<Tensor>()(args[0].operator Tensor()));
+.set_body_typed<int64_t(Tensor)>([](Tensor tensor) {
+    return static_cast<int64_t>(std::hash<Tensor>()(tensor));
   });
 
 TVM_REGISTER_API("_Placeholder")
-.set_body([](TVMArgs args,  TVMRetValue* ret) {
-    *ret = placeholder(args[0],
-                       args[1],
-                       args[2]);
-  });
+.set_body_typed<Tensor(Array<Expr>, Type, std::string)>([](
+  Array<Expr> shape, Type dtype, std::string name
+) {
+  return placeholder(shape, dtype, name);
+});
 
 TVM_REGISTER_API("_ComputeOp")
-.set_body([](TVMArgs args,  TVMRetValue* ret) {
-    *ret = ComputeOpNode::make(args[0],
-                               args[1],
-                               args[2],
-                               args[3],
-                               args[4]);
-  });
+.set_body_typed(ComputeOpNode::make);
 
 TVM_REGISTER_API("_ScanOp")
-.set_body([](TVMArgs args,  TVMRetValue* ret) {
-    *ret = ScanOpNode::make(args[0],
-                            args[1],
-                            args[2],
-                            args[3],
-                            args[4],
-                            args[5],
-                            args[6],
-                            args[7]);
-  });
+.set_body_typed(ScanOpNode::make);
 
 TVM_REGISTER_API("_TensorComputeOp")
-.set_body([](TVMArgs args,  TVMRetValue* ret) {
-    *ret = TensorComputeOpNode::make(args[0],
-                                     args[1],
-                                     args[2],
-                                     args[3],
-                                     args[4],
-                                     args[5],
-                                     args[6],
-                                     args[7]);
-  });
+.set_body_typed(TensorComputeOpNode::make);
 
 TVM_REGISTER_API("_ExternOp")
-.set_body([](TVMArgs args,  TVMRetValue* ret) {
-    *ret = ExternOpNode::make(args[0],
-                              args[1],
-                              args[2],
-                              args[3],
-                              args[4],
-                              args[5],
-                              args[6]);
-  });
+.set_body_typed(ExternOpNode::make);
 
 TVM_REGISTER_API("_HybridOp")
-.set_body([](TVMArgs args,  TVMRetValue* ret) {
-    *ret = HybridOpNode::make(args[0],
-                              args[1],
-                              args[2],
-                              args[3],
-                              args[4],
-                              args[5]);
-  });
+.set_body_typed(HybridOpNode::make);
 
 TVM_REGISTER_API("_OpGetOutput")
-.set_body([](TVMArgs args,  TVMRetValue* ret) {
-    *ret = args[0].operator Operation().output(
-        static_cast<size_t>(args[1].operator int64_t()));
-  });
+.set_body_typed<Tensor(Operation, int64_t)>([](Operation op, int64_t output) {
+  return op.output(static_cast<size_t>(output));
+});
 
 TVM_REGISTER_API("_OpNumOutputs")
-.set_body([](TVMArgs args,  TVMRetValue* ret) {
-    *ret = args[0].operator Operation()->num_outputs();
-  });
+.set_body_method<Operation>(&OperationNode::num_outputs);
 
 TVM_REGISTER_API("_OpInputTensors")
-.set_body([](TVMArgs args,  TVMRetValue* ret) {
-    *ret = args[0].operator Operation()->InputTensors();
-  });
+.set_body_method<Operation>(&OperationNode::InputTensors);
 
 TVM_REGISTER_API("_IterVar")
-.set_body([](TVMArgs args,  TVMRetValue* ret) {
-    *ret = IterVarNode::make(
-        args[0], args[1],
-        static_cast<IterVarType>(args[2].operator int()),
-        args[3]);
-  });
+.set_body_typed<IterVar(Range, Var, int, std::string)>([](
+  Range dom, Var var, int iter_type, std::string thread_tag
+) {
+  return IterVarNode::make(
+      dom, var,
+      static_cast<IterVarType>(iter_type),
+      thread_tag);
+});
 
 TVM_REGISTER_API("_CreateSchedule")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
-    *ret = create_schedule(args[0].operator Array<Operation>());
-  });
+.set_body_typed(create_schedule);
 
 TVM_REGISTER_API("_StageSetScope")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
-    args[0].operator Stage()
-        .set_scope(args[1]);
-  });
+.set_body_method(&Stage::set_scope);
 
 TVM_REGISTER_API("_StageBind")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
-    args[0].operator Stage()
-        .bind(args[1], args[2]);
-  });
+.set_body_method(&Stage::bind);
 
 TVM_REGISTER_API("_StageSplitByFactor")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
-    IterVar outer, inner;
-    args[0].operator Stage()
-        .split(args[1], args[2], &outer, &inner);
-    *ret = Array<IterVar>({outer, inner});
-  });
+.set_body_typed<Array<IterVar>(Stage, IterVar, Expr)>([](
+  Stage stage, IterVar parent, Expr factor
+) {
+  IterVar outer, inner;
+  stage.split(parent, factor, &outer, &inner);
+  return Array<IterVar>({outer, inner});
+});
 
 TVM_REGISTER_API("_StageSplitByNParts")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
-    IterVar outer, inner;
-    args[0].operator Stage()
-        .split_by_nparts(args[1], args[2], &outer, &inner);
-    *ret = Array<IterVar>({outer, inner});
-  });
+.set_body_typed<Array<IterVar>(Stage, IterVar, Expr)>([](
+  Stage stage, IterVar parent, Expr nparts
+) {
+  IterVar outer, inner;
+  stage.split_by_nparts(parent, nparts, &outer, &inner);
+  return Array<IterVar>({outer, inner});
+});
 
 TVM_REGISTER_API("_StageFuse")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
+.set_body_typed<IterVar(Stage, Array<IterVar>)>([](Stage stage, Array<IterVar> axes) {
     IterVar fused;
-    args[0].operator Stage()
-        .fuse(args[1], &fused);
-    *ret = fused;
+    stage.fuse(axes, &fused);
+    return fused;
   });
 
 TVM_REGISTER_API("_StageComputeAt")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
-    args[0].operator Stage()
-        .compute_at(args[1], args[2]);
-  });
+.set_body_method(&Stage::compute_at);
 
 TVM_REGISTER_API("_StageComputeInline")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
-    args[0].operator Stage()
-        .compute_inline();
-  });
+.set_body_method(&Stage::compute_inline);
 
 TVM_REGISTER_API("_StageComputeRoot")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
-    args[0].operator Stage()
-        .compute_root();
-  });
+.set_body_method(&Stage::compute_root);
 
 TVM_REGISTER_API("_StageReorder")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
-    args[0].operator Stage()
-        .reorder(args[1]);
-  });
+.set_body_method(&Stage::reorder);
 
 TVM_REGISTER_API("_StageTile")
-  .set_body([](TVMArgs args, TVMRetValue* ret) {
+.set_body_typed<Array<IterVar>(Stage, IterVar, IterVar, Expr, Expr)>([](
+  Stage stage,
+  IterVar x_parent, IterVar y_parent,
+  Expr x_factor, Expr y_factor
+) {
     IterVar x_outer, y_outer, x_inner, y_inner;
-    args[0].operator Stage()
-        .tile(args[1], args[2],
-              args[3], args[4],
-              &x_outer, &y_outer,
-              &x_inner, &y_inner);
-    *ret = Array<IterVar>({x_outer, y_outer, x_inner, y_inner});
+    stage.tile(x_parent, y_parent,
+               x_factor, y_factor,
+               &x_outer, &y_outer,
+               &x_inner, &y_inner);
+    return Array<IterVar>({x_outer, y_outer, x_inner, y_inner});
   });
 
 TVM_REGISTER_API("_StageEnvThreads")
-  .set_body([](TVMArgs args, TVMRetValue* ret) {
-    args[0].operator Stage()
-        .env_threads(args[1]);
-  });
+.set_body_method(&Stage::env_threads);
 
 TVM_REGISTER_API("_StageSetStorePredicate")
-  .set_body([](TVMArgs args, TVMRetValue* ret) {
-    args[0].operator Stage()
-        .set_store_predicate(args[1]);
-  });
+.set_body_method(&Stage::set_store_predicate);
 
 TVM_REGISTER_API("_StageUnroll")
-  .set_body([](TVMArgs args, TVMRetValue* ret) {
-    args[0].operator Stage()
-        .unroll(args[1]);
-  });
+.set_body_method(&Stage::unroll);
 
 TVM_REGISTER_API("_StageVectorize")
-  .set_body([](TVMArgs args, TVMRetValue* ret) {
-    args[0].operator Stage()
-        .vectorize(args[1]);
-  });
+.set_body_method(&Stage::vectorize);
 
 TVM_REGISTER_API("_StageTensorize")
-  .set_body([](TVMArgs args, TVMRetValue* ret) {
-    args[0].operator Stage()
-        .tensorize(args[1], args[2]);
-  });
+.set_body_method(&Stage::tensorize);
 
 TVM_REGISTER_API("_StageParallel")
-  .set_body([](TVMArgs args, TVMRetValue* ret) {
-    args[0].operator Stage()
-        .parallel(args[1]);
-  });
+.set_body_method(&Stage::parallel);
 
 TVM_REGISTER_API("_StagePragma")
-  .set_body([](TVMArgs args, TVMRetValue* ret) {
-    args[0].operator Stage()
-        .pragma(args[1], args[2], args[3]);
-  });
+.set_body_method(&Stage::pragma);
 
 TVM_REGISTER_API("_StagePrefetch")
-  .set_body([](TVMArgs args, TVMRetValue *ret) {
-    args[0].operator Stage()
-        .prefetch(args[1], args[2], args[3]);
-  });
+.set_body_method(&Stage::prefetch);
 
 TVM_REGISTER_API("_StageStorageAlign")
-  .set_body([](TVMArgs args, TVMRetValue *ret) {
-    args[0].operator Stage()
-        .storage_align(args[1], args[2], args[3]);
-  });
+.set_body_method(&Stage::storage_align);
 
 TVM_REGISTER_API("_StageDoubleBuffer")
-  .set_body([](TVMArgs args, TVMRetValue *ret) {
-    args[0].operator Stage().double_buffer();
-  });
+.set_body_method(&Stage::double_buffer);
 
 TVM_REGISTER_API("_StageOpenGL")
-  .set_body([](TVMArgs args, TVMRetValue *ret) {
-    args[0].operator Stage().opengl();
-  });
+.set_body_method(&Stage::opengl);
 
 TVM_REGISTER_API("_ScheduleNormalize")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
-    *ret = args[0].operator Schedule()
-        .normalize();
-  });
+.set_body_method(&Schedule::normalize);
 
 TVM_REGISTER_API("_ScheduleCreateGroup")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
-    *ret = args[0].operator Schedule()
-        .create_group(args[1], args[2], args[3]);
-  });
+.set_body_method(&Schedule::create_group);
 
 TVM_REGISTER_API("_ScheduleCacheRead")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
-    *ret = args[0].operator Schedule()
-        .cache_read(args[1], args[2], args[3]);
-  });
+.set_body_method(&Schedule::cache_read);
 
 TVM_REGISTER_API("_ScheduleCacheWrite")
 .set_body([](TVMArgs args, TVMRetValue* ret) {
@@ -594,16 +431,9 @@ TVM_REGISTER_API("_ScheduleCacheWrite")
   });
 
 TVM_REGISTER_API("_ScheduleRFactor")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
-    *ret = args[0].operator Schedule()
-        .rfactor(args[1], args[2], args[3]);
-  });
+.set_body_method(&Schedule::rfactor);
 
 TVM_REGISTER_API("_CommReducerCombine")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
-    const ir::CommReducerNode* combiner =
-      args[0].operator ir::CommReducer().as<ir::CommReducerNode>();
-    *ret = (*combiner)(args[1], args[2]);
-  });
+.set_body_method<ir::CommReducer>(&ir::CommReducerNode::operator());
 
 }  // namespace tvm
index 2e1ab42..6195aac 100644 (file)
@@ -119,68 +119,43 @@ TVM_REGISTER_API("ir_pass.PostOrderVisit")
   });
 
 // make from two arguments
-#define REGISTER_PASS1(PassName)                                  \
+#define REGISTER_PASS(PassName)                                   \
   TVM_REGISTER_API("ir_pass."#PassName)                           \
-  .set_body([](TVMArgs args,  TVMRetValue *ret) {                 \
-      *ret = PassName(args[0]);                                   \
-    })                                                            \
-
-#define REGISTER_PASS2(PassName)                                  \
-  TVM_REGISTER_API("ir_pass."#PassName)                           \
-  .set_body([](TVMArgs args,  TVMRetValue *ret) {                 \
-      *ret = PassName(args[0], args[1]);                          \
-    })                                                            \
-
-#define REGISTER_PASS3(PassName)                                        \
-  TVM_REGISTER_API("ir_pass."#PassName)                                 \
-  .set_body([](TVMArgs args,  TVMRetValue *ret) {                       \
-      *ret = PassName(args[0], args[1], args[2]);                       \
-    })                                                                  \
-
-#define REGISTER_PASS4(PassName)                                        \
-  TVM_REGISTER_API("ir_pass."#PassName)                                 \
-  .set_body([](TVMArgs args,  TVMRetValue *ret) {                       \
-      *ret = PassName(args[0], args[1], args[2], args[3]);              \
-    })                                                                  \
-
-#define REGISTER_PASS5(PassName)                                        \
-  TVM_REGISTER_API("ir_pass."#PassName)                                 \
-  .set_body([](TVMArgs args,  TVMRetValue *ret) {                       \
-      *ret = PassName(args[0], args[1], args[2], args[3], args[4]);     \
-    })                                                                  \
-
-REGISTER_PASS1(ConvertSSA);
-REGISTER_PASS1(VerifySSA);
-REGISTER_PASS1(RewriteUnsafeSelect);
-REGISTER_PASS4(Inline);
-REGISTER_PASS4(IRTransform);
-REGISTER_PASS1(VectorizeLoop);
-REGISTER_PASS5(UnrollLoop);
-REGISTER_PASS3(InjectCopyIntrin);
-REGISTER_PASS2(ThreadSync);
-REGISTER_PASS5(MakeAPI);
-REGISTER_PASS2(BindDeviceType);
-REGISTER_PASS1(SplitHostDevice);
-REGISTER_PASS1(StorageRewrite);
-REGISTER_PASS1(CoProcSync);
-REGISTER_PASS1(LowerStorageAccessInfo);
-REGISTER_PASS1(InjectVirtualThread);
-REGISTER_PASS1(InjectPrefetch);
-REGISTER_PASS2(InjectDoubleBuffer);
-REGISTER_PASS2(LoopPartition);
-REGISTER_PASS1(RemoveNoOp);
-REGISTER_PASS2(SplitPipeline);
-REGISTER_PASS2(LiftAttrScope);
-REGISTER_PASS1(NarrowChannelAccess);
-REGISTER_PASS2(LowerThreadAllreduce);
-REGISTER_PASS2(LowerWarpMemory);
-REGISTER_PASS2(RemapThreadAxis);
-REGISTER_PASS2(LowerIntrin);
-REGISTER_PASS1(LowerTVMBuiltin);
-REGISTER_PASS1(CombineContextCall);
-REGISTER_PASS2(VerifyMemory);
-REGISTER_PASS2(VerifyGPUCode);
-REGISTER_PASS1(DecorateDeviceScope);
-REGISTER_PASS1(InstrumentBoundCheckers);
+  .set_body_typed(PassName);                                     \
+
+
+REGISTER_PASS(ConvertSSA);
+REGISTER_PASS(VerifySSA);
+REGISTER_PASS(RewriteUnsafeSelect);
+REGISTER_PASS(Inline);
+REGISTER_PASS(IRTransform);
+REGISTER_PASS(VectorizeLoop);
+REGISTER_PASS(UnrollLoop);
+REGISTER_PASS(InjectCopyIntrin);
+REGISTER_PASS(ThreadSync);
+REGISTER_PASS(MakeAPI);
+REGISTER_PASS(BindDeviceType);
+REGISTER_PASS(SplitHostDevice);
+REGISTER_PASS(StorageRewrite);
+REGISTER_PASS(CoProcSync);
+REGISTER_PASS(LowerStorageAccessInfo);
+REGISTER_PASS(InjectVirtualThread);
+REGISTER_PASS(InjectPrefetch);
+REGISTER_PASS(InjectDoubleBuffer);
+REGISTER_PASS(LoopPartition);
+REGISTER_PASS(RemoveNoOp);
+REGISTER_PASS(SplitPipeline);
+REGISTER_PASS(LiftAttrScope);
+REGISTER_PASS(NarrowChannelAccess);
+REGISTER_PASS(LowerThreadAllreduce);
+REGISTER_PASS(LowerWarpMemory);
+REGISTER_PASS(RemapThreadAxis);
+REGISTER_PASS(LowerIntrin);
+REGISTER_PASS(LowerTVMBuiltin);
+REGISTER_PASS(CombineContextCall);
+REGISTER_PASS(VerifyMemory);
+REGISTER_PASS(VerifyGPUCode);
+REGISTER_PASS(DecorateDeviceScope);
+REGISTER_PASS(InstrumentBoundCheckers);
 }  // namespace ir
 }  // namespace tvm
index 45e2eb4..177360b 100644 (file)
@@ -33,15 +33,11 @@ namespace tvm {
 namespace schedule {
 
 TVM_REGISTER_API("schedule.AutoInlineElemWise")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
-    AutoInlineElemWise(args[0]);
-  });
+.set_body_typed(AutoInlineElemWise);
 
 
 TVM_REGISTER_API("schedule.AutoInlineInjective")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
-    AutoInlineInjective(args[0]);
-  });
+.set_body_typed(AutoInlineInjective);
 
 TVM_REGISTER_API("schedule.ScheduleOps")
 .set_body([](TVMArgs args, TVMRetValue* ret) {
@@ -51,25 +47,17 @@ TVM_REGISTER_API("schedule.ScheduleOps")
     *ret = ScheduleOps(args[0], args[1], args[2]);
 });
 
-#define REGISTER_SCHEDULE_PASS1(PassName)                         \
-  TVM_REGISTER_API("schedule."#PassName)                          \
-  .set_body([](TVMArgs args,  TVMRetValue *ret) {                 \
-      *ret = PassName(args[0]);                                   \
-    })                                                            \
-
-#define REGISTER_SCHEDULE_PASS2(PassName)                         \
+#define REGISTER_SCHEDULE_PASS(PassName)                          \
   TVM_REGISTER_API("schedule."#PassName)                          \
-  .set_body([](TVMArgs args,  TVMRetValue *ret) {                 \
-      *ret = PassName(args[0], args[1]);                          \
-    })                                                            \
+  .set_body_typed(PassName);                                     \
 
 
-REGISTER_SCHEDULE_PASS1(InferBound);
-REGISTER_SCHEDULE_PASS1(CreateReadGraph);
-REGISTER_SCHEDULE_PASS2(PostDFSOrder);
-REGISTER_SCHEDULE_PASS1(CreateAttachPath);
-REGISTER_SCHEDULE_PASS1(ScanGetBody);
-REGISTER_SCHEDULE_PASS1(ScanFixPointAnalysis);
+REGISTER_SCHEDULE_PASS(InferBound);
+REGISTER_SCHEDULE_PASS(CreateReadGraph);
+REGISTER_SCHEDULE_PASS(PostDFSOrder);
+REGISTER_SCHEDULE_PASS(CreateAttachPath);
+REGISTER_SCHEDULE_PASS(ScanGetBody);
+REGISTER_SCHEDULE_PASS(ScanFixPointAnalysis);
 
 }  // namespace schedule
 }  // namespace tvm
index 96e1b9e..382124a 100644 (file)
@@ -263,8 +263,6 @@ runtime::Module BuildOpenCL(Array<LoweredFunc> funcs) {
 }
 
 TVM_REGISTER_API("codegen.build_opencl")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
-    *rv = BuildOpenCL(args[0]);
-  });
+.set_body_typed(BuildOpenCL);
 }  // namespace codegen
 }  // namespace tvm
index 27d910e..797a7d1 100644 (file)
@@ -302,9 +302,7 @@ runtime::Module BuildOpenGL(Array<LoweredFunc> funcs) {
 }
 
 TVM_REGISTER_API("codegen.build_opengl")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
-  *rv = BuildOpenGL(args[0]);
-});
+.set_body_typed(BuildOpenGL);
 
 }  // namespace codegen
 }  // namespace tvm
index 460647a..a18312f 100644 (file)
@@ -164,9 +164,7 @@ runtime::Module BuildSDAccel(Array<LoweredFunc> funcs, std::string target_str) {
 }
 
 TVM_REGISTER_API("codegen.build_sdaccel")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
-    *rv = BuildSDAccel(args[0], args[1]);
-  });
+.set_body_typed(BuildSDAccel);
 
 }  // namespace codegen
 }  // namespace tvm
index 22c432c..396ae59 100644 (file)
@@ -265,9 +265,7 @@ runtime::Module BuildAMDGPU(Array<LoweredFunc> funcs, std::string target) {
 }
 
 TVM_REGISTER_API("codegen.build_rocm")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
-    *rv = BuildAMDGPU(args[0], args[1]);
-  });
+.set_body_typed(BuildAMDGPU);
 
 }  // namespace codegen
 }  // namespace tvm
index b1b541d..290727f 100644 (file)
@@ -243,9 +243,7 @@ runtime::Module BuildNVPTX(Array<LoweredFunc> funcs, std::string target) {
 }
 
 TVM_REGISTER_API("codegen.build_nvptx")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
-    *rv = BuildNVPTX(args[0], args[1]);
-  });
+.set_body_typed(BuildNVPTX);
 
 }  // namespace codegen
 }  // namespace tvm
index 581c330..fda239f 100644 (file)
@@ -155,8 +155,6 @@ runtime::Module BuildCUDA(Array<LoweredFunc> funcs) {
 }
 
 TVM_REGISTER_API("codegen.build_cuda")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
-    *rv = BuildCUDA(args[0]);
-  });
+.set_body_typed(BuildCUDA);
 }  // namespace codegen
 }  // namespace tvm
index 70047a6..88be7fe 100644 (file)
@@ -188,8 +188,6 @@ runtime::Module DeviceSourceModuleCreate(
 }
 
 TVM_REGISTER_GLOBAL("module.source_module_create")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
-    *rv = SourceModuleCreate(args[0], args[1]);
-  });
+.set_body_typed(SourceModuleCreate);
 }  // namespace codegen
 }  // namespace tvm
index 2b1ef66..18ffad1 100644 (file)
@@ -103,9 +103,7 @@ runtime::Module BuildSPIRV(Array<LoweredFunc> funcs) {
 }
 
 TVM_REGISTER_API("codegen.build_vulkan")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
-    *rv = BuildSPIRV(args[0]);
-  });
+.set_body_typed(BuildSPIRV);
 
 }  // namespace codegen
 }  // namespace tvm
index 8c4c258..2d71a20 100644 (file)
@@ -522,8 +522,6 @@ runtime::Module BuildStackVM(const Array<LoweredFunc>& funcs) {
 }
 
 TVM_REGISTER_API("codegen.build_stackvm")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
-    *rv = BuildStackVM(args[0]);
-  });
+.set_body_typed(BuildStackVM);
 }  // namespace codegen
 }  // namespace tvm
index 735f183..9af3f82 100644 (file)
@@ -51,9 +51,7 @@ Closure ClosureNode::make(tvm::Map<Var, Value> env, Function func) {
 }
 
 TVM_REGISTER_API("relay._make.Closure")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
-    *ret = ClosureNode::make(args[0], args[1]);
-  });
+.set_body_typed(ClosureNode::make);
 
 TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
 .set_dispatch<ClosureNode>([](const ClosureNode* node, tvm::IRPrinter* p) {
@@ -67,9 +65,7 @@ TupleValue TupleValueNode::make(tvm::Array<Value> value) {
 }
 
 TVM_REGISTER_API("relay._make.TupleValue")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
-    *ret = TupleValueNode::make(args[0]);
-  });
+.set_body_typed(TupleValueNode::make);
 
 TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
 .set_dispatch<TupleValueNode>([](const TupleValueNode* node, tvm::IRPrinter* p) {
@@ -90,10 +86,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
   });
 
 TVM_REGISTER_API("relay._make.TensorValue")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
-    runtime::NDArray data = args[0];
-    *ret = TensorValueNode::make(data);
-  });
+.set_body_typed(TensorValueNode::make);
 
 RefValue RefValueNode::make(Value value) {
   NodePtr<RefValueNode> n = make_node<RefValueNode>();
@@ -102,9 +95,7 @@ RefValue RefValueNode::make(Value value) {
 }
 
 TVM_REGISTER_API("relay._make.RefValue")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
-    *ret = RefValueNode::make(args[0]);
-  });
+.set_body_typed(RefValueNode::make);
 
 TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
 .set_dispatch<RefValueNode>([](const RefValueNode* node,
@@ -121,9 +112,7 @@ ConstructorValue ConstructorValueNode::make(Constructor constructor,
 }
 
 TVM_REGISTER_API("relay._make.ConstructorValue")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
-    *ret = ConstructorValueNode::make(args[0], args[1]);
-  });
+.set_body_typed(ConstructorValueNode::make);
 
 TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
 .set_dispatch<ConstructorValueNode>([](const ConstructorValueNode* node,
@@ -614,9 +603,7 @@ CreateInterpreter(
 }
 
 TVM_REGISTER_API("relay.backend.CreateInterpreter")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
-    *ret = CreateInterpreter(args[0], args[1], args[2]);
-  });
+.set_body_typed(CreateInterpreter);
 
 TVM_REGISTER_NODE_TYPE(ClosureNode);
 TVM_REGISTER_NODE_TYPE(TupleValueNode);
index 2e7d854..b59281a 100644 (file)
@@ -36,9 +36,7 @@ PatternWildcard PatternWildcardNode::make() {
 TVM_REGISTER_NODE_TYPE(PatternWildcardNode);
 
 TVM_REGISTER_API("relay._make.PatternWildcard")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
-    *ret = PatternWildcardNode::make();
-  });
+.set_body_typed(PatternWildcardNode::make);
 
 TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
 .set_dispatch<PatternWildcardNode>([](const PatternWildcardNode* node,
@@ -55,9 +53,7 @@ PatternVar PatternVarNode::make(tvm::relay::Var var) {
 TVM_REGISTER_NODE_TYPE(PatternVarNode);
 
 TVM_REGISTER_API("relay._make.PatternVar")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
-    *ret = PatternVarNode::make(args[0]);
-  });
+.set_body_typed(PatternVarNode::make);
 
 TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
 .set_dispatch<PatternVarNode>([](const PatternVarNode* node,
@@ -76,9 +72,7 @@ PatternConstructor PatternConstructorNode::make(Constructor constructor,
 TVM_REGISTER_NODE_TYPE(PatternConstructorNode);
 
 TVM_REGISTER_API("relay._make.PatternConstructor")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
-    *ret = PatternConstructorNode::make(args[0], args[1]);
-  });
+.set_body_typed(PatternConstructorNode::make);
 
 TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
 .set_dispatch<PatternConstructorNode>([](const PatternConstructorNode* node,
@@ -100,9 +94,7 @@ Constructor ConstructorNode::make(std::string name_hint,
 TVM_REGISTER_NODE_TYPE(ConstructorNode);
 
 TVM_REGISTER_API("relay._make.Constructor")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
-    *ret = ConstructorNode::make(args[0], args[1], args[2]);
-  });
+.set_body_typed(ConstructorNode::make);
 
 TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
 .set_dispatch<ConstructorNode>([](const ConstructorNode* node,
@@ -124,9 +116,7 @@ TypeData TypeDataNode::make(GlobalTypeVar header,
 TVM_REGISTER_NODE_TYPE(TypeDataNode);
 
 TVM_REGISTER_API("relay._make.TypeData")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
-    *ret = TypeDataNode::make(args[0], args[1], args[2]);
-  });
+.set_body_typed(TypeDataNode::make);
 
 TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
 .set_dispatch<TypeDataNode>([](const TypeDataNode* node,
@@ -145,9 +135,7 @@ Clause ClauseNode::make(Pattern lhs, Expr rhs) {
 TVM_REGISTER_NODE_TYPE(ClauseNode);
 
 TVM_REGISTER_API("relay._make.Clause")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
-    *ret = ClauseNode::make(args[0], args[1]);
-  });
+.set_body_typed(ClauseNode::make);
 
 TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
 .set_dispatch<ClauseNode>([](const ClauseNode* node,
@@ -166,9 +154,7 @@ Match MatchNode::make(Expr data, tvm::Array<Clause> clauses) {
 TVM_REGISTER_NODE_TYPE(MatchNode);
 
 TVM_REGISTER_API("relay._make.Match")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
-    *ret = MatchNode::make(args[0], args[1]);
-  });
+.set_body_typed(MatchNode::make);
 
 TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
 .set_dispatch<MatchNode>([](const MatchNode* node,
index 9670345..81017d4 100644 (file)
@@ -505,18 +505,18 @@ bool AlphaEqual(const Expr& lhs, const Expr& rhs) {
 
 // TODO(@jroesch): move to correct namespace?
 TVM_REGISTER_API("relay._make._alpha_equal")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
-    *ret = AlphaEqualHandler(false).Equal(args[0], args[1]);
+.set_body_typed<bool(NodeRef, NodeRef)>([](NodeRef a, NodeRef b) {
+    return AlphaEqualHandler(false).Equal(a, b);
   });
 
 TVM_REGISTER_API("relay._make._type_alpha_equal")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
-    *ret = AlphaEqualHandler(false).TypeEqual(args[0], args[1]);
+.set_body_typed<bool(Type, Type)>([](Type a, Type b) {
+    return AlphaEqualHandler(false).TypeEqual(a, b);
   });
 
 TVM_REGISTER_API("relay._make._graph_equal")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
-    *ret = AlphaEqualHandler(true).Equal(args[0], args[1]);
+.set_body_typed<bool(NodeRef, NodeRef)>([](NodeRef a, NodeRef b) {
+    return AlphaEqualHandler(true).Equal(a, b);
   });
 }  // namespace relay
 }  // namespace tvm
index 9c35173..f60f659 100644 (file)
@@ -52,9 +52,7 @@ SourceName SourceName::Get(const std::string& name) {
 }
 
 TVM_REGISTER_API("relay._make.SourceName")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
-    *ret = SourceName::Get(args[0]);
-  });
+.set_body_typed(SourceName::Get);
 
 TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
 .set_dispatch<SourceNameNode>([](const SourceNameNode* node, tvm::IRPrinter* p) {
@@ -78,9 +76,7 @@ Span SpanNode::make(SourceName source, int lineno, int col_offset) {
 TVM_REGISTER_NODE_TYPE(SpanNode);
 
 TVM_REGISTER_API("relay._make.Span")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
-    *ret = SpanNode::make(args[0], args[1], args[2]);
-  });
+.set_body_typed(SpanNode::make);
 
 TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
 .set_dispatch<SpanNode>([](const SpanNode* node, tvm::IRPrinter* p) {
@@ -91,11 +87,9 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
 TVM_REGISTER_NODE_TYPE(IdNode);
 
 TVM_REGISTER_API("relay._base.set_span")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
-    NodeRef node_ref = args[0];
+.set_body_typed<void(NodeRef, Span)>([](NodeRef node_ref, Span sp) {
     auto rn = node_ref.as_derived<RelayNode>();
     CHECK(rn);
-    Span sp = args[1];
     rn->span = sp;
 });
 
index 3108bc2..63d41c4 100644 (file)
@@ -39,9 +39,7 @@ Constant ConstantNode::make(runtime::NDArray data) {
 TVM_REGISTER_NODE_TYPE(ConstantNode);
 
 TVM_REGISTER_API("relay._make.Constant")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
-    *ret = ConstantNode::make(args[0]);
-  });
+.set_body_typed(ConstantNode::make);
 
 TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
 .set_dispatch<ConstantNode>([](const ConstantNode* node, tvm::IRPrinter* p) {
@@ -73,9 +71,7 @@ Tuple TupleNode::make(tvm::Array<relay::Expr> fields) {
 TVM_REGISTER_NODE_TYPE(TupleNode);
 
 TVM_REGISTER_API("relay._make.Tuple")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
-    *ret = TupleNode::make(args[0]);
-  });
+.set_body_typed(TupleNode::make);
 
 TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
 .set_dispatch<TupleNode>([](const TupleNode* node, tvm::IRPrinter* p) {
@@ -99,9 +95,7 @@ Var VarNode::make(std::string name_hint, Type type_annotation) {
 TVM_REGISTER_NODE_TYPE(VarNode);
 
 TVM_REGISTER_API("relay._make.Var")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
-    *ret = VarNode::make(args[0].operator std::string(), args[1]);
-  });
+.set_body_typed(static_cast<Var (*)(std::string, Type)>(VarNode::make));
 
 TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
 .set_dispatch<VarNode>([](const VarNode* node, tvm::IRPrinter* p) {
@@ -122,9 +116,7 @@ GlobalVar GlobalVarNode::make(std::string name_hint) {
 TVM_REGISTER_NODE_TYPE(GlobalVarNode);
 
 TVM_REGISTER_API("relay._make.GlobalVar")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
-    *ret = GlobalVarNode::make(args[0]);
-  });
+.set_body_typed(GlobalVarNode::make);
 
 TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
 .set_dispatch<GlobalVarNode>([](const GlobalVarNode* node, tvm::IRPrinter* p) {
@@ -201,9 +193,7 @@ Function FunctionSetAttr(const Function& func, const std::string& key, const Nod
 TVM_REGISTER_NODE_TYPE(FunctionNode);
 
 TVM_REGISTER_API("relay._make.Function")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
-  *ret = FunctionNode::make(args[0], args[1], args[2], args[3], args[4]);
-});
+.set_body_typed(FunctionNode::make);
 
 TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
 .set_dispatch<FunctionNode>([](const FunctionNode* node,
@@ -226,9 +216,7 @@ Call CallNode::make(Expr op, Array<Expr> args, Attrs attrs,
 TVM_REGISTER_NODE_TYPE(CallNode);
 
 TVM_REGISTER_API("relay._make.Call")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
-  *ret = CallNode::make(args[0], args[1], args[2], args[3]);
-});
+.set_body_typed(CallNode::make);
 
 TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
 .set_dispatch<CallNode>([](const CallNode* node, tvm::IRPrinter* p) {
@@ -247,9 +235,7 @@ Let LetNode::make(Var var, Expr value, Expr body) {
 TVM_REGISTER_NODE_TYPE(LetNode);
 
 TVM_REGISTER_API("relay._make.Let")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
-    *ret = LetNode::make(args[0], args[1], args[2]);
-  });
+.set_body_typed(LetNode::make);
 
 TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
 .set_dispatch<LetNode>([](const LetNode* node, tvm::IRPrinter* p) {
@@ -267,9 +253,8 @@ If IfNode::make(Expr cond, Expr true_branch, Expr false_branch) {
 
 TVM_REGISTER_NODE_TYPE(IfNode);
 
-TVM_REGISTER_API("relay._make.If").set_body([](TVMArgs args, TVMRetValue* ret) {
-  *ret = IfNode::make(args[0], args[1], args[2]);
-});
+TVM_REGISTER_API("relay._make.If")
+.set_body_typed(IfNode::make);
 
 TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
 .set_dispatch<IfNode>([](const IfNode* node, tvm::IRPrinter* p) {
@@ -286,9 +271,8 @@ TupleGetItem TupleGetItemNode::make(Expr tuple, int index) {
 
 TVM_REGISTER_NODE_TYPE(TupleGetItemNode);
 
-TVM_REGISTER_API("relay._make.TupleGetItem").set_body([](TVMArgs args, TVMRetValue* ret) {
-  *ret = TupleGetItemNode::make(args[0], args[1]);
-});
+TVM_REGISTER_API("relay._make.TupleGetItem")
+.set_body_typed(TupleGetItemNode::make);
 
 TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
 .set_dispatch<TupleGetItemNode>([](const TupleGetItemNode* node, tvm::IRPrinter* p) {
@@ -301,9 +285,8 @@ RefCreate RefCreateNode::make(Expr value) {
   return RefCreate(n);
 }
 
-TVM_REGISTER_API("relay._make.RefCreate").set_body([](TVMArgs args, TVMRetValue* ret) {
-  *ret = RefCreateNode::make(args[0]);
-});
+TVM_REGISTER_API("relay._make.RefCreate")
+.set_body_typed(RefCreateNode::make);
 
 TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
 .set_dispatch<RefCreateNode>([](const RefCreateNode* node, tvm::IRPrinter* p) {
@@ -317,9 +300,7 @@ RefRead RefReadNode::make(Expr ref) {
 }
 
 TVM_REGISTER_API("relay._make.RefRead")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
-  *ret = RefReadNode::make(args[0]);
-});
+.set_body_typed(RefReadNode::make);
 
 TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
 .set_dispatch<RefReadNode>([](const RefReadNode* node, tvm::IRPrinter* p) {
@@ -334,9 +315,7 @@ RefWrite RefWriteNode::make(Expr ref, Expr value) {
 }
 
 TVM_REGISTER_API("relay._make.RefWrite")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
-  *ret = RefWriteNode::make(args[0], args[1]);
-});
+.set_body_typed(RefWriteNode::make);
 
 TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
 .set_dispatch<RefWriteNode>([](const RefWriteNode* node, tvm::IRPrinter* p) {
@@ -344,9 +323,8 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
 });
 
 TVM_REGISTER_API("relay._expr.TempExprRealize")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
-  TempExpr temp = args[0];
-  *ret = temp->Realize();
+.set_body_typed<Expr(TempExpr)>([](TempExpr temp) {
+  return temp->Realize();
 });
 
 }  // namespace relay
index d0cd30a..7a6250c 100644 (file)
@@ -346,9 +346,8 @@ void PostOrderVisit(const Expr& e, std::function<void(const Expr&)> fvisit) {
 }
 
 TVM_REGISTER_API("relay._ir_pass.post_order_visit")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
-    PackedFunc f = args[1];
-    PostOrderVisit(args[0], [f](const Expr& n) {
+.set_body_typed<void(Expr, PackedFunc)>([](Expr expr, PackedFunc f) {
+    PostOrderVisit(expr, [f](const Expr& n) {
         f(n);
       });
   });
index cb2be8b..89ad608 100644 (file)
@@ -410,14 +410,14 @@ size_t StructuralHash::operator()(const Expr& expr) const {
 }
 
 TVM_REGISTER_API("relay._ir_pass._expr_hash")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
-    *ret = static_cast<int64_t>(RelayHashHandler().Hash(args[0]));
-  });
+.set_body_typed<int64_t(NodeRef)>([](NodeRef ref) {
+  return static_cast<int64_t>(RelayHashHandler().Hash(ref));
+});
 
 TVM_REGISTER_API("relay._ir_pass._type_hash")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
-    *ret = static_cast<int64_t>(RelayHashHandler().TypeHash(args[0]));
-  });
+.set_body_typed<int64_t(Type)>([](Type type) {
+  return static_cast<int64_t>(RelayHashHandler().TypeHash(type));
+});
 
 }  // namespace relay
 }  // namespace tvm
index 38c9756..eabea2e 100644 (file)
@@ -181,66 +181,43 @@ Module ModuleNode::FromExpr(
 TVM_REGISTER_NODE_TYPE(ModuleNode);
 
 TVM_REGISTER_API("relay._make.Module")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
-    *ret = ModuleNode::make(args[0], args[1]);
-  });
+.set_body_typed(ModuleNode::make);
 
 TVM_REGISTER_API("relay._make.Module_Add")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
-    Module mod = args[0];
-    mod->Add(args[1], args[2], args[3]);
-  });
+.set_body_method<Module>(&ModuleNode::Add);
 
 TVM_REGISTER_API("relay._module.Module_AddDef")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
-    Module mod = args[0];
-    mod->AddDef(args[1], args[2]);
-  });
+.set_body_method<Module>(&ModuleNode::AddDef);
 
 TVM_REGISTER_API("relay._module.Module_GetGlobalVar")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
-    Module mod = args[0];
-    *ret = mod->GetGlobalVar(args[1]);
-  });
+.set_body_method<Module>(&ModuleNode::GetGlobalVar);
 
 TVM_REGISTER_API("relay._module.Module_GetGlobalTypeVar")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
-    Module mod = args[0];
-    *ret = mod->GetGlobalTypeVar(args[1]);
-  });
+.set_body_method<Module>(&ModuleNode::GetGlobalTypeVar);
 
 TVM_REGISTER_API("relay._module.Module_Lookup")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
-    Module mod = args[0];
-    GlobalVar var = args[1];
-    *ret = mod->Lookup(var);
+.set_body_typed<Function(Module, GlobalVar)>([](Module mod, GlobalVar var) {
+    return mod->Lookup(var);
   });
 
 TVM_REGISTER_API("relay._module.Module_Lookup_str")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
-    Module mod = args[0];
-    std::string var_name = args[1];
-    *ret = mod->Lookup(var_name);
+.set_body_typed<Function(Module, std::string)>([](Module mod, std::string var) {
+    return mod->Lookup(var);
   });
 
 TVM_REGISTER_API("relay._module.Module_LookupDef")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
-    Module mod = args[0];
-    GlobalTypeVar var = args[1];
-    *ret = mod->LookupDef(var);
+.set_body_typed<TypeData(Module, GlobalTypeVar)>([](Module mod, GlobalTypeVar var) {
+    return mod->LookupDef(var);
   });
 
 TVM_REGISTER_API("relay._module.Module_LookupDef_str")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
-    Module mod = args[0];
-    std::string var_name = args[1];
-    *ret = mod->LookupDef(var_name);
+.set_body_typed<TypeData(Module, std::string)>([](Module mod, std::string var) {
+    return mod->LookupDef(var);
   });
 
 TVM_REGISTER_API("relay._module.Module_Update")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
-    Module mod = args[0];
-    mod->Update(args[1]);
+.set_body_typed<void(Module, Module)>([](Module mod, Module from) {
+    mod->Update(from);
   });
 
 TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
index fb0d919..8f0bdcb 100644 (file)
@@ -56,10 +56,7 @@ IndexExpr TensorTypeNode::Size() const {
 TVM_REGISTER_NODE_TYPE(TensorTypeNode);
 
 TVM_REGISTER_API("relay._make.TensorType")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
-  Array<IndexExpr> shape = args[0];
-  *ret = TensorTypeNode::make(shape, args[1]);
-});
+.set_body_typed(TensorTypeNode::make);
 
 TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
 .set_dispatch<TensorTypeNode>([](const TensorTypeNode* node,
@@ -77,10 +74,8 @@ TypeVar TypeVarNode::make(std::string name, Kind kind) {
 TVM_REGISTER_NODE_TYPE(TypeVarNode);
 
 TVM_REGISTER_API("relay._make.TypeVar")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
-  int kind = args[1];
-  *ret =
-    TypeVarNode::make(args[0], static_cast<Kind>(kind));
+.set_body_typed<TypeVar(std::string, int)>([](std::string name, int kind) {
+    return TypeVarNode::make(name, static_cast<Kind>(kind));
     });
 
 TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
@@ -100,10 +95,9 @@ GlobalTypeVar GlobalTypeVarNode::make(std::string name, Kind kind) {
 TVM_REGISTER_NODE_TYPE(GlobalTypeVarNode);
 
 TVM_REGISTER_API("relay._make.GlobalTypeVar")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
-  int kind = args[1];
-  *ret = GlobalTypeVarNode::make(args[0], static_cast<Kind>(kind));
-});
+.set_body_typed<GlobalTypeVar(std::string, int)>([](std::string name, int kind) {
+    return GlobalTypeVarNode::make(name, static_cast<Kind>(kind));
+    });
 
 TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
 .set_dispatch<GlobalTypeVarNode>([](const GlobalTypeVarNode *node,
@@ -122,9 +116,7 @@ TypeCall TypeCallNode::make(Type func, tvm::Array<Type> args) {
 TVM_REGISTER_NODE_TYPE(TypeCallNode);
 
 TVM_REGISTER_API("relay._make.TypeCall")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
-  *ret = TypeCallNode::make(args[0], args[1]);
-});
+.set_body_typed(TypeCallNode::make);
 
 TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
 .set_dispatch<TypeCallNode>([](const TypeCallNode* node,
@@ -142,9 +134,8 @@ IncompleteType IncompleteTypeNode::make(Kind kind) {
 TVM_REGISTER_NODE_TYPE(IncompleteTypeNode);
 
 TVM_REGISTER_API("relay._make.IncompleteType")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
-    int kind = args[0];
-    *ret = IncompleteTypeNode::make(static_cast<Kind>(kind));
+.set_body_typed<IncompleteType(int)>([](int kind) {
+    return IncompleteTypeNode::make(static_cast<Kind>(kind));
   });
 
 TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
@@ -169,9 +160,7 @@ FuncType FuncTypeNode::make(tvm::Array<Type> arg_types,
 TVM_REGISTER_NODE_TYPE(FuncTypeNode);
 
 TVM_REGISTER_API("relay._make.FuncType")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
-  *ret = FuncTypeNode::make(args[0], args[1], args[2], args[3]);
-});
+.set_body_typed(FuncTypeNode::make);
 
 TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
 .set_dispatch<FuncTypeNode>([](const FuncTypeNode* node,
@@ -196,9 +185,7 @@ TypeRelation TypeRelationNode::make(TypeRelationFn func,
 TVM_REGISTER_NODE_TYPE(TypeRelationNode);
 
 TVM_REGISTER_API("relay._make.TypeRelation")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
-    *ret = TypeRelationNode::make(args[0], args[1], args[2], args[3]);
-});
+.set_body_typed(TypeRelationNode::make);
 
 TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
 .set_dispatch<TypeRelationNode>([](const TypeRelationNode* node, tvm::IRPrinter* p) {
@@ -216,9 +203,7 @@ TupleType TupleTypeNode::make(Array<Type> fields) {
 TVM_REGISTER_NODE_TYPE(TupleTypeNode);
 
 TVM_REGISTER_API("relay._make.TupleType")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
-    *ret = TupleTypeNode::make(args[0]);
-});
+.set_body_typed(TupleTypeNode::make);
 
 TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
 .set_dispatch<TupleTypeNode>([](const TupleTypeNode* node,
@@ -233,9 +218,7 @@ RefType RefTypeNode::make(Type value) {
 }
 
 TVM_REGISTER_API("relay._make.RefType")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
-  *ret = RefTypeNode::make(args[0]);
-});
+.set_body_typed(RefTypeNode::make);
 
 TVM_REGISTER_NODE_TYPE(RefTypeNode);
 
index 3aea0c0..37fb090 100644 (file)
@@ -64,9 +64,7 @@ Expr MakeDebug(Expr expr, std::string name) {
 }
 
 TVM_REGISTER_API("relay.op._make.debug")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
-    runtime::detail::unpack_call<Expr, 2>(MakeDebug, args, rv);
-  });
+.set_body_typed(MakeDebug);
 
 }  // namespace relay
 }  // namespace tvm
index 7ca762e..ffa489e 100644 (file)
@@ -105,9 +105,7 @@ Expr MakeResize(Expr data,
 
 
 TVM_REGISTER_API("relay.op.image._make.resize")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
-    runtime::detail::unpack_call<Expr, 5>(MakeResize, args, rv);
-  });
+.set_body_typed(MakeResize);
 
 
 RELAY_REGISTER_OP("image.resize")
index f2c0a27..97cba79 100644 (file)
@@ -170,9 +170,7 @@ Expr MakeConv2D(Expr data,
 
 
 TVM_REGISTER_API("relay.op.nn._make.conv2d")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
-    runtime::detail::unpack_call<Expr, 12>(MakeConv2D, args, rv);
-  });
+.set_body_typed(MakeConv2D);
 
 
 RELAY_REGISTER_OP("nn.conv2d")
@@ -324,9 +322,7 @@ Expr MakeConv2DTranspose(Expr data,
 
 
 TVM_REGISTER_API("relay.op.nn._make.conv2d_transpose")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
-    runtime::detail::unpack_call<Expr, 12>(MakeConv2DTranspose, args, rv);
-  });
+.set_body_typed(MakeConv2DTranspose);
 
 RELAY_REGISTER_OP("nn.conv2d_transpose")
 .describe(R"code(Transposed 2D convolution layer (sometimes called Deconvolution).
@@ -465,9 +461,7 @@ Expr MakeConv2DWinograd(Expr data,
 
 
 TVM_REGISTER_API("relay.op.nn._make.contrib_conv2d_winograd_without_weight_transform")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
-    runtime::detail::unpack_call<Expr, 13>(MakeConv2DWinograd, args, rv);
-  });
+.set_body_typed(MakeConv2DWinograd);
 
 
 RELAY_REGISTER_OP("nn.contrib_conv2d_winograd_without_weight_transform")
@@ -530,9 +524,7 @@ Expr MakeConv2DWinogradWeightTransform(Expr weight,
 
 
 TVM_REGISTER_API("relay.op.nn._make.contrib_conv2d_winograd_weight_transform")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
-  runtime::detail::unpack_call<Expr, 2>(MakeConv2DWinogradWeightTransform, args, rv);
-});
+.set_body_typed(MakeConv2DWinogradWeightTransform);
 
 
 RELAY_REGISTER_OP("nn.contrib_conv2d_winograd_weight_transform")
@@ -580,9 +572,7 @@ Expr MakeConv2DWinogradNNPACK(Expr data,
 }
 
 TVM_REGISTER_API("relay.op.nn._make.contrib_conv2d_winograd_nnpack_without_weight_transform")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
-  runtime::detail::unpack_call<Expr, 12>(MakeConv2DWinogradNNPACK, args, rv);
-});
+.set_body_typed(MakeConv2DWinogradNNPACK);
 
 RELAY_REGISTER_OP("nn.contrib_conv2d_winograd_nnpack_without_weight_transform")
 .describe(R"code(Compute conv2d with winograd nnpack. Only supports NCHW layout.
@@ -649,9 +639,7 @@ Expr MakeConv2DWinogradNNPACKWeightTransform(Expr weight,
 }
 
 TVM_REGISTER_API("relay.op.nn._make.contrib_conv2d_winograd_nnpack_weight_transform")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
-  runtime::detail::unpack_call<Expr, 3>(MakeConv2DWinogradNNPACKWeightTransform, args, rv);
-});
+.set_body_typed(MakeConv2DWinogradNNPACKWeightTransform);
 
 RELAY_REGISTER_OP("nn.contrib_conv2d_winograd_nnpack_weight_transform")
 .describe(R"code(Weight transformation of winograd fast convolution algorithm with NNPACK.
@@ -698,9 +686,7 @@ Expr MakeConv2DNCHWc(Expr data,
 }
 
 TVM_REGISTER_API("relay.op.nn._make.contrib_conv2d_NCHWc")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
-    runtime::detail::unpack_call<Expr, 12>(MakeConv2DNCHWc, args, rv);
-  });
+.set_body_typed(MakeConv2DNCHWc);
 
 
 RELAY_REGISTER_OP("nn.contrib_conv2d_NCHWc")
@@ -750,9 +736,7 @@ Expr MakeDepthwiseConv2DNCHWc(Expr data,
 }
 
 TVM_REGISTER_API("relay.op.nn._make.contrib_depthwise_conv2d_NCHWc")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
-    runtime::detail::unpack_call<Expr, 12>(MakeDepthwiseConv2DNCHWc, args, rv);
-  });
+.set_body_typed(MakeDepthwiseConv2DNCHWc);
 
 
 RELAY_REGISTER_OP("nn.contrib_depthwise_conv2d_NCHWc")
@@ -910,9 +894,7 @@ Expr MakeDeformableConv2D(Expr data,
 }
 
 TVM_REGISTER_API("relay.op.nn._make.deformable_conv2d")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
-    runtime::detail::unpack_call<Expr, 14>(MakeDeformableConv2D, args, rv);
-  });
+.set_body_typed(MakeDeformableConv2D);
 
 
 }  // namespace relay
index d244313..2356634 100644 (file)
@@ -78,9 +78,7 @@ Expr MakeBiasAdd(Expr data,
 
 
 TVM_REGISTER_API("relay.op.nn._make.bias_add")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
-    runtime::detail::unpack_call<Expr, 3>(MakeBiasAdd, args, rv);
-  });
+.set_body_typed(MakeBiasAdd);
 
 
 RELAY_REGISTER_OP("nn.bias_add")
@@ -145,9 +143,7 @@ Expr MakeDense(Expr data,
 
 
 TVM_REGISTER_API("relay.op.nn._make.dense")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
-    runtime::detail::unpack_call<Expr, 3>(MakeDense, args, rv);
-  });
+.set_body_typed(MakeDense);
 
 
 RELAY_REGISTER_OP("nn.dense")
@@ -179,9 +175,7 @@ Expr MakeLeakyRelu(Expr data,
 
 
 TVM_REGISTER_API("relay.op.nn._make.leaky_relu")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
-    runtime::detail::unpack_call<Expr, 2>(MakeLeakyRelu, args, rv);
-  });
+.set_body_typed(MakeLeakyRelu);
 
 
 RELAY_REGISTER_OP("nn.leaky_relu")
@@ -244,9 +238,7 @@ Expr MakePRelu(Expr data,
 
 
 TVM_REGISTER_API("relay.op.nn._make.prelu")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
-    runtime::detail::unpack_call<Expr, 3>(MakePRelu, args, rv);
-  });
+.set_body_typed(MakePRelu);
 
 
 RELAY_REGISTER_OP("nn.prelu")
@@ -276,17 +268,14 @@ where :math:`*` is an channelwise multiplication for each sample in the batch.
 TVM_REGISTER_NODE_TYPE(SoftmaxAttrs);
 
 TVM_REGISTER_API("relay.op.nn._make.softmax")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
-  auto make_func = [](Expr data, int axis) {
-    auto attrs = make_node<SoftmaxAttrs>();
-    attrs->axis = axis;
-    static const Op& op = Op::Get("nn.softmax");
-    return CallNode::make(op, {data}, Attrs(attrs), {});
-  };
-
-  runtime::detail::unpack_call<Expr, 2>(make_func, args, rv);
+.set_body_typed<Call(Expr, int)>([](Expr data, int axis) {
+  auto attrs = make_node<SoftmaxAttrs>();
+  attrs->axis = axis;
+  static const Op& op = Op::Get("nn.softmax");
+  return CallNode::make(op, {data}, Attrs(attrs), {});
 });
 
+
 RELAY_REGISTER_OP("nn.softmax")
     .describe(R"code(Softmax layer.
 
@@ -314,15 +303,11 @@ RELAY_REGISTER_OP("nn.softmax")
 
 // relay.nn.log_softmax
 TVM_REGISTER_API("relay.op.nn._make.log_softmax")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
-  auto make_func = [](Expr data, int axis) {
-    auto attrs = make_node<SoftmaxAttrs>();
-    attrs->axis = axis;
-    static const Op& op = Op::Get("nn.log_softmax");
-    return CallNode::make(op, {data}, Attrs(attrs), {});
-  };
-
-  runtime::detail::unpack_call<Expr, 2>(make_func, args, rv);
+.set_body_typed<Call(Expr, int)>([](Expr data, int axis) {
+  auto attrs = make_node<SoftmaxAttrs>();
+  attrs->axis = axis;
+  static const Op& op = Op::Get("nn.log_softmax");
+  return CallNode::make(op, {data}, Attrs(attrs), {});
 });
 
 RELAY_REGISTER_OP("nn.log_softmax")
@@ -382,9 +367,7 @@ Expr MakeBatchFlatten(Expr data) {
 
 
 TVM_REGISTER_API("relay.op.nn._make.batch_flatten")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
-    runtime::detail::unpack_call<Expr, 1>(MakeBatchFlatten, args, rv);
-  });
+.set_body_typed(MakeBatchFlatten);
 
 
 RELAY_REGISTER_OP("nn.batch_flatten")
@@ -424,7 +407,7 @@ Example::
 
 // relu
 TVM_REGISTER_API("relay.op.nn._make.relu")
-.set_body_typed<Expr(Expr)>([](Expr data) {
+.set_body_typed<Call(Expr)>([](Expr data) {
     static const Op& op = Op::Get("nn.relu");
     return CallNode::make(op, {data}, Attrs(), {});
   });
@@ -469,9 +452,7 @@ Expr MakeLRN(Expr data,
 }
 
 TVM_REGISTER_API("relay.op.nn._make.lrn")
-  .set_body([](const TVMArgs& args, TVMRetValue* rv) {
-      runtime::detail::unpack_call<Expr, 6>(MakeLRN, args, rv);
-  });
+.set_body_typed(MakeLRN);
 
 RELAY_REGISTER_OP("nn.lrn")
 .describe(R"code(LRN layer.
@@ -509,9 +490,7 @@ Expr MakeL2Normalize(Expr data,
 }
 
 TVM_REGISTER_API("relay.op.nn._make.l2_normalize")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
-    runtime::detail::unpack_call<Expr, 3>(MakeL2Normalize, args, rv);
-  });
+.set_body_typed(MakeL2Normalize);
 
 RELAY_REGISTER_OP("nn.l2_normalize")
 .describe(R"code(L2 Normalization layer.
@@ -556,9 +535,7 @@ Expr MakeDropout(Expr data, double rate) {
 }
 
 TVM_REGISTER_API("relay.op.nn._make.dropout")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
-    runtime::detail::unpack_call<Expr, 2>(MakeDropout, args, rv);
-  });
+.set_body_typed(MakeDropout);
 
 RELAY_REGISTER_OP("nn.dropout")
 .describe(R"code(Applies the dropout operation to the input array.
@@ -622,9 +599,7 @@ Expr MakeBatchNorm(Expr data, Expr gamma, Expr beta, Expr moving_mean, Expr movi
 }
 
 TVM_REGISTER_API("relay.op.nn._make.batch_norm")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
-    runtime::detail::unpack_call<Expr, 9>(MakeBatchNorm, args, rv);
-  });
+.set_body_typed(MakeBatchNorm);
 
 RELAY_REGISTER_OP("nn.batch_norm")
 .describe(R"code(Batch normalization layer (Ioffe and Szegedy, 2014).
@@ -711,9 +686,7 @@ Expr MakeBatchMatmul(Expr x,
 
 
 TVM_REGISTER_API("relay.op.nn._make.batch_matmul")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
-    runtime::detail::unpack_call<Expr, 2>(MakeBatchMatmul, args, rv);
-  });
+.set_body_typed(MakeBatchMatmul);
 
 
 RELAY_REGISTER_OP("nn.batch_matmul")
index c653e3b..98b9d67 100644 (file)
@@ -115,9 +115,7 @@ Expr MakePad(Expr data, Array<Array<IndexExpr> > pad_width, double pad_value) {
 }
 
 TVM_REGISTER_API("relay.op.nn._make.pad")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
-    runtime::detail::unpack_call<Expr, 3>(MakePad, args, rv);
-  });
+.set_body_typed(MakePad);
 
 RELAY_REGISTER_OP("nn.pad")
 .describe(R"code(Pad for n-D tensor.
index 0717ee5..df238b3 100644 (file)
@@ -186,9 +186,7 @@ Array<Tensor> Pool2DCompute(const Attrs& attrs,
 }
 
 TVM_REGISTER_API("relay.op.nn._make.max_pool2d")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
-    runtime::detail::unpack_call<Expr, 6>(MakeMaxPool2D, args, rv);
-  });
+.set_body_typed(MakeMaxPool2D);
 
 
 RELAY_REGISTER_OP("nn.max_pool2d")
@@ -242,9 +240,7 @@ Expr MakeAvgPool2D(Expr data,
 
 
 TVM_REGISTER_API("relay.op.nn._make.avg_pool2d")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
-    runtime::detail::unpack_call<Expr, 7>(MakeAvgPool2D, args, rv);
-  });
+.set_body_typed(MakeAvgPool2D);
 
 
 RELAY_REGISTER_OP("nn.avg_pool2d")
@@ -345,9 +341,7 @@ Expr MakeGlobalAvgPool2D(Expr data,
 
 
 TVM_REGISTER_API("relay.op.nn._make.global_avg_pool2d")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
-    runtime::detail::unpack_call<Expr, 2>(MakeGlobalAvgPool2D, args, rv);
-  });
+.set_body_typed(MakeGlobalAvgPool2D);
 
 // GlobalAvgPool
 RELAY_REGISTER_OP("nn.global_avg_pool2d")
@@ -378,9 +372,7 @@ Expr MakeGlobalMaxPool2D(Expr data,
 }
 
 TVM_REGISTER_API("relay.op.nn._make.global_max_pool2d")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
-    runtime::detail::unpack_call<Expr, 2>(MakeGlobalMaxPool2D, args, rv);
-  });
+.set_body_typed(MakeGlobalMaxPool2D);
 
 
 RELAY_REGISTER_OP("nn.global_max_pool2d")
index 98458b9..acefaf3 100644 (file)
@@ -110,9 +110,7 @@ Expr MakeUpSampling(Expr data,
 
 
 TVM_REGISTER_API("relay.op.nn._make.upsampling")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
-    runtime::detail::unpack_call<Expr, 4>(MakeUpSampling, args, rv);
-  });
+.set_body_typed(MakeUpSampling);
 
 
 RELAY_REGISTER_OP("nn.upsampling")
index 7bade46..b889b6c 100644 (file)
@@ -265,8 +265,8 @@ bool ReduceRel(const Array<Type>& types,
 
 #define RELAY_REGISTER_REDUCE_OP(OpName)                           \
   TVM_REGISTER_API("relay.op._make." OpName)                       \
-  .set_body([](const TVMArgs& args, TVMRetValue* rv) {             \
-    auto make_func = [](Expr data,                                 \
+  .set_body_typed<Call(Expr, Array<Integer>, bool, bool)>([](      \
+                        Expr data,                                 \
                         Array<Integer> axis,                       \
                         bool keepdims,                             \
                         bool exclude) {                            \
@@ -276,8 +276,6 @@ bool ReduceRel(const Array<Type>& types,
       attrs->exclude = exclude;                                    \
       static const Op& op = Op::Get(OpName);                       \
       return CallNode::make(op, {data}, Attrs(attrs), {});         \
-    };                                                             \
-    runtime::detail::unpack_call<Expr, 4>(make_func, args, rv);    \
     });                                                            \
   RELAY_REGISTER_OP(OpName)                                        \
   .set_num_inputs(1)                                               \
index f86156b..873e75d 100644 (file)
@@ -81,9 +81,7 @@ Expr MakeCast(Expr data,
 }
 
 TVM_REGISTER_API("relay._make.cast")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
-    runtime::detail::unpack_call<Expr, 2>(MakeCast, args, rv);
-});
+.set_body_typed(MakeCast);
 
 RELAY_REGISTER_OP("cast")
 .describe(R"code(Cast the data into a new data type.
@@ -161,9 +159,7 @@ Expr MakeExpandDims(Expr data,
 }
 
 TVM_REGISTER_API("relay.op._make.expand_dims")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
-    runtime::detail::unpack_call<Expr, 3>(MakeExpandDims, args, rv);
-});
+.set_body_typed(MakeExpandDims);
 
 RELAY_REGISTER_OP("expand_dims")
 .describe(R"code(Insert `num_newaxis` axises at the position given by `axis`
@@ -279,9 +275,7 @@ Expr MakeConcatenate(Expr data,
 }
 
 TVM_REGISTER_API("relay.op._make.concatenate")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
-    runtime::detail::unpack_call<Expr, 2>(MakeConcatenate, args, rv);
-});
+.set_body_typed(MakeConcatenate);
 
 RELAY_REGISTER_OP("concatenate")
 .describe(R"code(Concatenate the input tensors along the given axis.
@@ -367,9 +361,7 @@ Expr MakeStack(Expr data,
 }
 
 TVM_REGISTER_API("relay.op._make.stack")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
-    runtime::detail::unpack_call<Expr, 2>(MakeStack, args, rv);
-});
+.set_body_typed(MakeStack);
 
 RELAY_REGISTER_OP("stack")
 .describe(R"code(Stack the input tensors along the given axis.
@@ -461,9 +453,7 @@ Expr MakeTranspose(Expr data,
 }
 
 TVM_REGISTER_API("relay.op._make.transpose")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
-    runtime::detail::unpack_call<Expr, 2>(MakeTranspose, args, rv);
-});
+.set_body_typed(MakeTranspose);
 
 RELAY_REGISTER_OP("transpose")
 .describe(R"code(Permutes the dimensions of an array.
@@ -598,9 +588,7 @@ Expr MakeReshape(Expr data,
 }
 
 TVM_REGISTER_API("relay.op._make.reshape")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
-    runtime::detail::unpack_call<Expr, 2>(MakeReshape, args, rv);
-});
+.set_body_typed(MakeReshape);
 
 RELAY_REGISTER_OP("reshape")
 .describe(R"code(Reshapes the input array.
@@ -698,9 +686,7 @@ Expr MakeReshapeLike(Expr data,
 
 
 TVM_REGISTER_API("relay.op._make.reshape_like")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
-    runtime::detail::unpack_call<Expr, 2>(MakeReshapeLike, args, rv);
-});
+.set_body_typed(MakeReshapeLike);
 
 
 RELAY_REGISTER_OP("reshape_like")
@@ -790,9 +776,7 @@ Expr MakeTake(Expr data,
 }
 
 TVM_REGISTER_API("relay.op._make.take")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
-    runtime::detail::unpack_call<Expr, 4>(MakeTake, args, rv);
-});
+.set_body_typed(MakeTake);
 
 RELAY_REGISTER_OP("take")
 .describe(R"code(Take elements from an array along an axis.
@@ -873,9 +857,7 @@ Expr MakeFull(Expr fill_value,
 }
 
 TVM_REGISTER_API("relay.op._make.full")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
-    runtime::detail::unpack_call<Expr, 3>(MakeFull, args, rv);
-});
+.set_body_typed(MakeFull);
 
 RELAY_REGISTER_OP("full")
 .describe(R"code(Fill array with scalar value.
@@ -910,9 +892,7 @@ Expr MakeZeros(Array<IndexExpr> shape,
 }
 
 TVM_REGISTER_API("relay.op._make.zeros")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
-    runtime::detail::unpack_call<Expr, 2>(MakeZeros, args, rv);
-  });
+.set_body_typed(MakeZeros);
 
 RELAY_REGISTER_OP("zeros")
 .describe(R"code(Fill array with zeros.
@@ -933,9 +913,7 @@ Expr MakeOnes(Array<IndexExpr> shape,
 }
 
 TVM_REGISTER_API("relay.op._make.ones")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
-    runtime::detail::unpack_call<Expr, 2>(MakeOnes, args, rv);
-  });
+.set_body_typed(MakeOnes);
 
 RELAY_REGISTER_OP("ones")
 .describe(R"code(Fill array with ones.
@@ -982,9 +960,7 @@ Expr MakeFullLike(Expr data,
 }
 
 TVM_REGISTER_API("relay.op._make.full_like")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
-    runtime::detail::unpack_call<Expr, 2>(MakeFullLike, args, rv);
-  });
+.set_body_typed(MakeFullLike);
 
 RELAY_REGISTER_OP("full_like")
 .describe(R"code(Return an scalar value array with the same shape
@@ -1041,9 +1017,7 @@ Expr MakeArange(tvm::Expr start,
 }
 
 TVM_REGISTER_API("relay.op._make.arange")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
-    runtime::detail::unpack_call<Expr, 4>(MakeArange, args, rv);
-});
+.set_body_typed(MakeArange);
 
 RELAY_REGISTER_OP("arange")
 .describe(R"code(Returns evenly spaced values within a given interval.
@@ -1117,9 +1091,7 @@ Expr MakeRepeat(Expr data,
 }
 
 TVM_REGISTER_API("relay.op._make.repeat")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
-    runtime::detail::unpack_call<Expr, 3>(MakeRepeat, args, rv);
-});
+.set_body_typed(MakeRepeat);
 
 RELAY_REGISTER_OP("repeat")
 .describe(R"code(Repeat elements of an array `repeats` times along axis `axis`
@@ -1217,9 +1189,7 @@ Expr MakeTile(Expr data,
 }
 
 TVM_REGISTER_API("relay.op._make.tile")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
-    runtime::detail::unpack_call<Expr, 2>(MakeTile, args, rv);
-});
+.set_body_typed(MakeTile);
 
 RELAY_REGISTER_OP("tile")
 .describe(R"code(Repeat the whole array multiple times.
@@ -1280,9 +1250,7 @@ Expr MakeReverse(Expr data,
 }
 
 TVM_REGISTER_API("relay.op._make.reverse")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
-    runtime::detail::unpack_call<Expr, 2>(MakeReverse, args, rv);
-});
+.set_body_typed(MakeReverse);
 
 RELAY_REGISTER_OP("reverse")
 .describe(R"code(Reverses the order of elements along given `axis` while preserving array shape.
@@ -1345,9 +1313,7 @@ Array<Tensor> WhereCompute(const Attrs& attrs,
 }
 
 TVM_REGISTER_API("relay.op._make.where")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
-  runtime::detail::unpack_call<Expr, 3>(MakeWhere, args, rv);
-});
+.set_body_typed(MakeWhere);
 
 RELAY_REGISTER_OP("where")
 .describe(R"code(
@@ -1400,9 +1366,7 @@ Expr MakeSqueeze(Expr data,
 }
 
 TVM_REGISTER_API("relay.op._make.squeeze")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
-    runtime::detail::unpack_call<Expr, 2>(MakeSqueeze, args, rv);
-  });
+.set_body_typed(MakeSqueeze);
 
 
 bool SqueezeRel(const Array<Type>& types,
@@ -1507,9 +1471,7 @@ Array<Tensor> CollapseSumLikeCompute(const Attrs& attrs,
 }
 
 TVM_REGISTER_API("relay.op._make.collapse_sum_like")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
-    runtime::detail::unpack_call<Expr, 2>(MakeCollapseSumLike, args, rv);
-  });
+.set_body_typed(MakeCollapseSumLike);
 
 RELAY_REGISTER_OP("collapse_sum_like")
 .describe(R"code(Collapse the first input to match the shape of the second input.
@@ -1554,9 +1516,7 @@ Array<Tensor> BroadCastToCompute(const Attrs& attrs,
 }
 
 TVM_REGISTER_API("relay.op._make.broadcast_to")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
-    runtime::detail::unpack_call<Expr, 2>(MakeBroadCastTo, args, rv);
-  });
+.set_body_typed(MakeBroadCastTo);
 
 RELAY_REGISTER_OP("broadcast_to")
 .describe(R"code(Broadcast the first input to match the shape argument.
@@ -1594,9 +1554,7 @@ Array<Tensor> BroadCastToLikeCompute(const Attrs& attrs,
 }
 
 TVM_REGISTER_API("relay.op._make.broadcast_to_like")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
-    runtime::detail::unpack_call<Expr, 2>(MakeBroadCastToLike, args, rv);
-  });
+.set_body_typed(MakeBroadCastToLike);
 
 RELAY_REGISTER_OP("broadcast_to_like")
 .describe(R"code(Broadcast the first input to match the shape of the second input.
@@ -1806,9 +1764,7 @@ Array<Tensor> StridedSliceCompute(const Attrs& attrs,
 
 
 TVM_REGISTER_API("relay.op._make.strided_slice")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
-    runtime::detail::unpack_call<Expr, 4>(MakeStridedSlice, args, rv);
-  });
+.set_body_typed(MakeStridedSlice);
 
 
 RELAY_REGISTER_OP("strided_slice")
@@ -2081,9 +2037,7 @@ Array<Tensor> SliceLikeCompute(const Attrs& attrs,
 
 
 TVM_REGISTER_API("relay.op._make.slice_like")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
-    runtime::detail::unpack_call<Expr, 3>(MakeSliceLike, args, rv);
-});
+.set_body_typed(MakeSliceLike);
 
 
 RELAY_REGISTER_OP("slice_like")
@@ -2144,9 +2098,7 @@ Expr MakeLayoutTransform(Expr data,
 }
 
 TVM_REGISTER_API("relay.op._make.layout_transform")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
-  runtime::detail::unpack_call<Expr, 3>(MakeLayoutTransform, args, rv);
-});
+.set_body_typed(MakeLayoutTransform);
 
 RELAY_REGISTER_OP("layout_transform")
 .describe(R"code(Transform the input data layout.
@@ -2174,9 +2126,7 @@ Expr MakeReverseReshape(Expr data,
 }
 
 TVM_REGISTER_API("relay.op._make._contrib_reverse_reshape")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
-    runtime::detail::unpack_call<Expr, 2>(MakeReverseReshape, args, rv);
-});
+.set_body_typed(MakeReverseReshape);
 
 RELAY_REGISTER_OP("_contrib_reverse_reshape")
 .describe(R"code(Reshapes the input array where the special values are inferred from
@@ -2250,9 +2200,7 @@ Expr MakeGatherND(Expr data,
 }
 
 TVM_REGISTER_API("relay.op._make.gather_nd")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
-    runtime::detail::unpack_call<Expr, 2>(MakeGatherND, args, rv);
-});
+.set_body_typed(MakeGatherND);
 
 RELAY_REGISTER_OP("gather_nd")
 .describe(R"code(Gather elements or slices from data and store to
index 2c9f76b..56a03ff 100644 (file)
@@ -73,9 +73,7 @@ Expr MakeMultiBoxPrior(Expr data,
 
 
 TVM_REGISTER_API("relay.op.vision._make.multibox_prior")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
-  runtime::detail::unpack_call<Expr, 6>(MakeMultiBoxPrior, args, rv);
-});
+.set_body_typed(MakeMultiBoxPrior);
 
 
 RELAY_REGISTER_OP("vision.multibox_prior")
@@ -147,9 +145,7 @@ Expr MakeMultiBoxTransformLoc(Expr cls_prob,
 }
 
 TVM_REGISTER_API("relay.op.vision._make.multibox_transform_loc")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
-  runtime::detail::unpack_call<Expr, 6>(MakeMultiBoxTransformLoc, args, rv);
-});
+.set_body_typed(MakeMultiBoxTransformLoc);
 
 RELAY_REGISTER_OP("vision.multibox_transform_loc")
 .describe(R"doc("Location transformation for multibox detection."
index 75161bf..5344bce 100644 (file)
@@ -59,9 +59,7 @@ Expr MakeGetValidCounts(Expr data,
 
 
 TVM_REGISTER_API("relay.op.vision._make.get_valid_counts")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
-  runtime::detail::unpack_call<Expr, 2>(MakeGetValidCounts, args, rv);
-});
+.set_body_typed(MakeGetValidCounts);
 
 
 RELAY_REGISTER_OP("vision.get_valid_counts")
@@ -125,9 +123,7 @@ Expr MakeNMS(Expr data,
 
 
 TVM_REGISTER_API("relay.op.vision._make.non_max_suppression")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
-  runtime::detail::unpack_call<Expr, 9>(MakeNMS, args, rv);
-});
+.set_body_typed(MakeNMS);
 
 
 RELAY_REGISTER_OP("vision.non_max_suppression")
index 70fe292..0522ab8 100644 (file)
@@ -62,9 +62,7 @@ Expr MakeROIAlign(Expr data, Expr rois, Array<IndexExpr> pooled_size, double spa
 }
 
 TVM_REGISTER_API("relay.op.vision._make.roi_align")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
-    runtime::detail::unpack_call<Expr, 6>(MakeROIAlign, args, rv);
-  });
+.set_body_typed(MakeROIAlign);
 
 RELAY_REGISTER_OP("vision.roi_align")
     .describe(R"doc(ROI Align operator.
@@ -114,9 +112,7 @@ Expr MakeROIPool(Expr data, Expr rois, Array<IndexExpr> pooled_size, double spat
 }
 
 TVM_REGISTER_API("relay.op.vision._make.roi_pool")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
-    runtime::detail::unpack_call<Expr, 5>(MakeROIPool, args, rv);
-  });
+.set_body_typed(MakeROIPool);
 
 RELAY_REGISTER_OP("vision.roi_pool")
     .describe(R"doc(ROI Pool operator.
@@ -182,9 +178,7 @@ Expr MakeProposal(Expr cls_prob, Expr bbox_pred, Expr im_info, Array<IndexExpr>
 }
 
 TVM_REGISTER_API("relay.op.vision._make.proposal")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
-    runtime::detail::unpack_call<Expr, 11>(MakeProposal, args, rv);
-  });
+.set_body_typed(MakeProposal);
 
 RELAY_REGISTER_OP("vision.proposal")
     .describe(R"code(Generate region proposals via RPN.
index 310e30a..0a1d961 100644 (file)
@@ -71,9 +71,7 @@ Expr MakeYoloReorg(Expr data,
 
 
 TVM_REGISTER_API("relay.op.vision._make.yolo_reorg")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
-  runtime::detail::unpack_call<Expr, 2>(MakeYoloReorg, args, rv);
-});
+.set_body_typed(MakeYoloReorg);
 
 
 RELAY_REGISTER_OP("vision.yolo_reorg")
index c4350cc..9a46027 100644 (file)
@@ -61,9 +61,7 @@ Expr CanonicalizeOps(const Expr& e) {
 }
 
 TVM_REGISTER_API("relay._ir_pass.canonicalize_ops")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
-*ret = CanonicalizeOps(args[0]);
-});
+.set_body_typed(CanonicalizeOps);
 
 }  // namespace relay
 }  // namespace tvm
index cd7a852..7e76322 100644 (file)
@@ -355,9 +355,7 @@ Expr CombineParallelConv2D(const Expr& expr, uint64_t min_num_branches) {
 }
 
 TVM_REGISTER_API("relay._ir_pass.CombineParallelConv2D")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
-  *ret = CombineParallelConv2D(args[0], args[1]);
-});
+.set_body_typed(CombineParallelConv2D);
 
 }  // namespace relay
 }  // namespace tvm
index 06cd909..c5c4f33 100644 (file)
@@ -148,9 +148,7 @@ Expr DeadCodeElimination(const Expr& e) {
 }
 
 TVM_REGISTER_API("relay._ir_pass.dead_code_elimination")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
-    *ret = DeadCodeElimination(args[0]);
-  });
+.set_body_typed(DeadCodeElimination);
 
 }  // namespace relay
 }  // namespace tvm
index 6f06383..46f4268 100644 (file)
@@ -493,19 +493,13 @@ Map<Expr, Integer> CollectDeviceAnnotationOps(const Expr& expr) {
 }
 
 TVM_REGISTER_API("relay._ir_pass.CollectDeviceInfo")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
-  *ret = CollectDeviceInfo(args[0]);
-});
+.set_body_typed(CollectDeviceInfo);
 
 TVM_REGISTER_API("relay._ir_pass.RewriteDeviceAnnotation")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
-  *ret = RewriteAnnotatedOps(args[0], args[1]);
-});
+.set_body_typed(RewriteAnnotatedOps);
 
 TVM_REGISTER_API("relay._ir_pass.CollectDeviceAnnotationOps")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
-  *ret = CollectDeviceAnnotationOps(args[0]);
-});
+.set_body_typed(CollectDeviceAnnotationOps);
 
 }  // namespace relay
 }  // namespace tvm
index 9d55a54..5bfee6c 100644 (file)
@@ -210,9 +210,7 @@ Expr FoldConstant(const Expr& expr) {
 }
 
 TVM_REGISTER_API("relay._ir_pass.FoldConstant")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
-    *ret = FoldConstant(args[0]);
-});
+.set_body_typed(FoldConstant);
 
 }  // namespace relay
 }  // namespace tvm
index 4b50c64..6de9c2d 100644 (file)
@@ -912,8 +912,6 @@ Expr FuseOps(const Expr& expr, int fuse_opt_level) {
 }
 
 TVM_REGISTER_API("relay._ir_pass.FuseOps")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
-    *ret = FuseOps(args[0], args[1]);
-});
+.set_body_typed(FuseOps);
 }  // namespace relay
 }  // namespace tvm
index 8a5d1df..5c5ea01 100644 (file)
@@ -247,10 +247,7 @@ Expr FirstOrderGradient(const Expr& re, const Module& mod) {
 }
 
 TVM_REGISTER_API("relay._ir_pass.first_order_gradient")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
-  CHECK_EQ(args.size(), 2);
-  *ret = FirstOrderGradient(args[0], args[1]);
-});
+.set_body_typed(FirstOrderGradient);
 
 struct ReverseADType : TypeMutator {
   Type VisitType_(const TensorTypeNode* ttn) final {
@@ -263,7 +260,7 @@ struct ReverseAD : ExprMutator {
   Var bp;
   const OpMap<FPrimalGradient> rev_map = Op::GetAttr<FPrimalGradient>("FPrimalGradient");
 
-  ReverseAD(const Var& bp) : bp(bp) { }
+  ReverseAD(const Var& bp) : bp(bp) { } /// NOLINT(*)
 
   Expr VisitExpr_(const OpNode* op) final {
     LOG(FATAL) << "op should only be inside call";
@@ -349,10 +346,7 @@ Expr Gradient(const Expr& re, const Module& mod) {
 }
 
 TVM_REGISTER_API("relay._ir_pass.gradient")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
-  CHECK_EQ(args.size(), 2);
-  *ret = Gradient(args[0], args[1]);
-});
+.set_body_typed(Gradient);
 
 }  // namespace relay
 }  // namespace tvm
index 702e703..c9ee4ee 100644 (file)
@@ -147,9 +147,7 @@ int64_t GetTotalMacNumber(const Expr& expr) {
 }
 
 TVM_REGISTER_API("relay._ir_pass.GetTotalMacNumber")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
-  *ret = GetTotalMacNumber(args[0]);
-});
+.set_body_typed(GetTotalMacNumber);
 
 }  // namespace mac_count
 }  // namespace relay
index fad3728..d607247 100644 (file)
@@ -426,12 +426,7 @@ Pass CreateSequentialPass(const tvm::Array<Pass>& passes,
 TVM_REGISTER_NODE_TYPE(PassInfoNode);
 
 TVM_REGISTER_API("relay._ir_pass.PassInfo")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
-  int opt_level = args[0];
-  std::string name = args[1];
-  tvm::Array<tvm::Expr> required = args[2];
-  *ret = PassInfoNode::make(opt_level, name, required);
-});
+.set_body_typed(PassInfoNode::make);
 
 TVM_REGISTER_API("relay._ir_pass.Info")
 .set_body([](TVMArgs args, TVMRetValue* ret) {
@@ -456,13 +451,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
 TVM_REGISTER_NODE_TYPE(ModulePassNode);
 
 TVM_REGISTER_API("relay._ir_pass.CreateModulePass")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
-  PackedFunc pass_func = args[0];
-  int opt_level = args[1];
-  std::string name = args[2];
-  tvm::Array<tvm::Expr> required = args[3];
-  *ret = CreateModulePass(pass_func, opt_level, name, required);
-});
+.set_body_typed(CreateModulePass);
 
 TVM_REGISTER_API("relay._ir_pass.RunPass")
 .set_body([](TVMArgs args, TVMRetValue* ret) {
@@ -487,13 +476,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
 TVM_REGISTER_NODE_TYPE(FunctionPassNode);
 
 TVM_REGISTER_API("relay._ir_pass.CreateFunctionPass")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
-  PackedFunc pass_func = args[0];
-  int opt_level = args[1];
-  std::string name = args[2];
-  tvm::Array<tvm::Expr> required = args[3];
-  *ret = CreateFunctionPass(pass_func, opt_level, name, required);
-});
+.set_body_typed(CreateFunctionPass);
 
 TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
 .set_dispatch<FunctionPassNode>([](const FunctionPassNode* node,
@@ -541,9 +524,7 @@ TVM_REGISTER_API("relay._ir_pass.SetContext")
 TVM_REGISTER_NODE_TYPE(PassContextNode);
 
 TVM_REGISTER_API("relay._ir_pass.PassContext")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
-  *ret = PassContextNode::make();
-});
+.set_body_typed(PassContextNode::make);
 
 TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
 .set_dispatch<PassContextNode>([](const PassContextNode* node,
index cb0f9d9..5fa3053 100644 (file)
@@ -571,20 +571,13 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
 });
 
 TVM_REGISTER_API("relay._quantize._GetCurrentQConfig")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
-  *ret = QConfig::Current();
-  });
+.set_body_typed(QConfig::Current);
 
 TVM_REGISTER_API("relay._quantize._EnterQConfigScope")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
-  QConfig target = args[0];
-  QConfig::EnterQConfigScope(target);
-  });
+.set_body_typed(QConfig::EnterQConfigScope);
 
 TVM_REGISTER_API("relay._quantize._ExitQConfigScope")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
-  QConfig::ExitQConfigScope();
-  });
+.set_body_typed(QConfig::ExitQConfigScope);
 
 }  // namespace quantize
 }  // namespace relay
index 28ebaaa..cecebc5 100644 (file)
@@ -103,9 +103,7 @@ Expr SimplifyInference(const Expr& e) {
 }
 
 TVM_REGISTER_API("relay._ir_pass.simplify_inference")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
-    *ret = SimplifyInference(args[0]);
-  });
+.set_body_typed(SimplifyInference);
 
 }  // namespace relay
 }  // namespace tvm
index bac6fd2..5507de4 100644 (file)
@@ -491,9 +491,7 @@ Expr ToANormalForm(const Expr& e, const Module& m) {
 }
 
 TVM_REGISTER_API("relay._ir_pass.to_a_normal_form")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
-    *ret = ToANormalForm(args[0], args[1]);
-  });
+.set_body_typed(static_cast<Expr (*)(const Expr&, const Module&)>(ToANormalForm));
 
 }  // namespace relay
 }  // namespace tvm
index cc7e1a4..490a80f 100644 (file)
@@ -77,9 +77,7 @@ Expr ToGraphNormalForm(const Expr& e) {
 }
 
 TVM_REGISTER_API("relay._ir_pass.to_graph_normal_form")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
-  *ret = ToGraphNormalForm(args[0]);
-});
+.set_body_typed(ToGraphNormalForm);
 
 }  // namespace relay
 }  // namespace tvm
index 5abf0b7..30d4d79 100644 (file)
@@ -801,8 +801,8 @@ Function InferType(const Function& func,
 }
 
 TVM_REGISTER_API("relay._ir_pass.infer_type")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
-    *ret = InferType(args[0], args[1]);
+.set_body_typed<Expr(const Expr&, const Module&)>([](const Expr& expr, const Module& mod_ref) {
+    return InferType(expr, mod_ref);
   });
 }  // namespace relay
 }  // namespace tvm
index fa655a7..8e02cf1 100644 (file)
@@ -275,9 +275,7 @@ tvm::Array<Var> AllVars(const Expr& expr) {
 }
 
 TVM_REGISTER_API("relay._ir_pass.free_vars")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
-    *ret = FreeVars(args[0]);
-  });
+.set_body_typed(FreeVars);
 
 TVM_REGISTER_API("relay._ir_pass.bound_vars")
   .set_body([](TVMArgs args, TVMRetValue* ret) {
@@ -290,9 +288,7 @@ TVM_REGISTER_API("relay._ir_pass.bound_vars")
     });
 
 TVM_REGISTER_API("relay._ir_pass.all_vars")
-  .set_body([](TVMArgs args, TVMRetValue* ret) {
-      *ret = AllVars(args[0]);
-    });
+.set_body_typed(AllVars);
 
 TVM_REGISTER_API("relay._ir_pass.free_type_vars")
 .set_body([](TVMArgs args, TVMRetValue* ret) {
index 86107d6..4eaaa93 100644 (file)
@@ -79,10 +79,7 @@ bool WellFormed(const Expr& e) {
 }
 
 TVM_REGISTER_API("relay._ir_pass.well_formed")
-  .set_body([](TVMArgs args, TVMRetValue *ret) {
-      Expr e = args[0];
-      *ret = WellFormed(e);
-    });
+.set_body_typed(WellFormed);
 
 }  // namespace relay
 }  // namespace tvm
index a46f0eb..55d9e64 100644 (file)
@@ -308,18 +308,12 @@ Module CUDAModuleLoadBinary(void* strm) {
 }
 
 TVM_REGISTER_GLOBAL("module.loadfile_cubin")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
-    *rv = CUDAModuleLoadFile(args[0], args[1]);
-  });
+.set_body_typed(CUDAModuleLoadFile);
 
 TVM_REGISTER_GLOBAL("module.loadfile_ptx")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
-    *rv = CUDAModuleLoadFile(args[0], args[1]);
-  });
+.set_body_typed(CUDAModuleLoadFile);
 
 TVM_REGISTER_GLOBAL("module.loadbinary_cuda")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
-    *rv = CUDAModuleLoadBinary(args[0]);
-  });
+.set_body_typed(CUDAModuleLoadBinary);
 }  // namespace runtime
 }  // namespace tvm
index e1f0e3f..af809d7 100644 (file)
@@ -310,13 +310,9 @@ Module MetalModuleLoadBinary(void* strm) {
 }
 
 TVM_REGISTER_GLOBAL("module.loadfile_metal")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
-    *rv = MetalModuleLoadFile(args[0], args[1]);
-  });
+.set_body_typed(MetalModuleLoadFile);
 
 TVM_REGISTER_GLOBAL("module.loadbinary_metal")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
-    *rv = MetalModuleLoadBinary(args[0]);
-  });
+.set_body_typed(MetalModuleLoadBinary);
 }  // namespace runtime
 }  // namespace tvm
index 38e82ed..d9a3aa2 100644 (file)
@@ -69,9 +69,7 @@ Module AOCLModuleLoadFile(const std::string& file_name,
 }
 
 TVM_REGISTER_GLOBAL("module.loadfile_aocx")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
-    *rv = AOCLModuleLoadFile(args[0], args[1]);
-  });
+.set_body_typed(AOCLModuleLoadFile);
 
 }  // namespace runtime
 }  // namespace tvm
index 543ffb9..971ae34 100644 (file)
@@ -281,18 +281,12 @@ Module OpenCLModuleLoadBinary(void* strm) {
 }
 
 TVM_REGISTER_GLOBAL("module.loadfile_cl")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
-    *rv = OpenCLModuleLoadFile(args[0], args[1]);
-  });
+.set_body_typed(OpenCLModuleLoadFile);
 
 TVM_REGISTER_GLOBAL("module.loadfile_clbin")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
-    *rv = OpenCLModuleLoadFile(args[0], args[1]);
-  });
+.set_body_typed(OpenCLModuleLoadFile);
 
 TVM_REGISTER_GLOBAL("module.loadbinary_opencl")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
-    *rv = OpenCLModuleLoadBinary(args[0]);
-  });
+.set_body_typed(OpenCLModuleLoadBinary);
 }  // namespace runtime
 }  // namespace tvm
index 9bfc9d2..900d564 100644 (file)
@@ -80,13 +80,9 @@ Module SDAccelModuleLoadBinary(void* strm) {
 }
 
 TVM_REGISTER_GLOBAL("module.loadfile_xclbin")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
-    *rv = SDAccelModuleLoadFile(args[0], args[1]);
-  });
+.set_body_typed(SDAccelModuleLoadFile);
 
 TVM_REGISTER_GLOBAL("module.loadfile_awsxclbin")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
-    *rv = SDAccelModuleLoadFile(args[0], args[1]);
-  });
+.set_body_typed(SDAccelModuleLoadFile);
 }  // namespace runtime
 }  // namespace tvm
index b7b93c7..6531f97 100644 (file)
@@ -243,14 +243,10 @@ Module ROCMModuleLoadBinary(void* strm) {
 
 
 TVM_REGISTER_GLOBAL("module.loadbinary_hsaco")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
-    *rv = ROCMModuleLoadBinary(args[0]);
-  });
+.set_body_typed(ROCMModuleLoadBinary);
 
 
 TVM_REGISTER_GLOBAL("module.loadbinary_hip")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
-    *rv = ROCMModuleLoadBinary(args[0]);
-  });
+.set_body_typed(ROCMModuleLoadBinary);
 }  // namespace runtime
 }  // namespace tvm
index dfbdb26..7a142f3 100644 (file)
@@ -64,8 +64,6 @@ PackedFunc CreateEventDrivenServer(PackedFunc fsend,
 }
 
 TVM_REGISTER_GLOBAL("rpc._CreateEventDrivenServer")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
-    *rv = CreateEventDrivenServer(args[0], args[1], args[2]);
-  });
+.set_body_typed(CreateEventDrivenServer);
 }  // namespace runtime
 }  // namespace tvm
index 33d852f..16528bc 100644 (file)
@@ -110,9 +110,7 @@ void RPCServerLoop(int sockfd) {
 }
 
 TVM_REGISTER_GLOBAL("rpc._Connect")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
-    *rv = RPCClientConnect(args[0], args[1], args[2]);
-  });
+.set_body_typed(RPCClientConnect);
 
 TVM_REGISTER_GLOBAL("rpc._ServerLoop")
 .set_body([](TVMArgs args, TVMRetValue* rv) {
index 5e6f96b..4e7d422 100644 (file)
@@ -142,9 +142,7 @@ Module StackVMModuleCreate(std::unordered_map<std::string, StackVM> fmap,
 }
 
 TVM_REGISTER_GLOBAL("module.loadfile_stackvm")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
-    *rv = StackVMModuleNode::LoadFromFile(args[0], args[1]);
-  });
+.set_body_typed(StackVMModuleNode::LoadFromFile);
 
 }  // namespace runtime
 }  // namespace tvm
index cfa80be..c1db14d 100644 (file)
@@ -427,13 +427,9 @@ Module VulkanModuleLoadBinary(void* strm) {
 }
 
 TVM_REGISTER_GLOBAL("module.loadfile_vulkan")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
-    *rv = VulkanModuleLoadFile(args[0], args[1]);
-  });
+.set_body_typed(VulkanModuleLoadFile);
 
 TVM_REGISTER_GLOBAL("module.loadbinary_vulkan")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
-    *rv = VulkanModuleLoadBinary(args[0]);
-  });
+.set_body_typed(VulkanModuleLoadBinary);
 }  // namespace runtime
 }  // namespace tvm
index 273d43b..12bc53c 100644 (file)
@@ -60,16 +60,16 @@ struct RPCEnv {
 };
 
 TVM_REGISTER_GLOBAL("tvm.rpc.server.workpath")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
+.set_body_typed<std::string(std::string)>([](std::string path) {
     static RPCEnv env;
-    *rv = env.GetPath(args[0]);
+    return env.GetPath(path);
   });
 
 TVM_REGISTER_GLOBAL("tvm.rpc.server.load_module")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
-    std::string file_name = "/rpc/" + args[0].operator std::string();
-    *rv = Module::LoadFromFile(file_name, "");
+.set_body_typed<Module(std::string)>([](std::string path) {
+    std::string file_name = "/rpc/" + path;
     LOG(INFO) << "Load module from " << file_name << " ...";
+    return Module::LoadFromFile(file_name, "");
   });
 }  // namespace contrib
 }  // namespace tvm