Registry& set_body_typed(FLambda f) {
return set_body(TypedPackedFunc<FType>(f).packed());
}
+
+ /*!
+ * \brief set the body of the function to the given function pointer.
+ * Note that this doesn't work with lambdas, you need to
+ * explicitly give a type for those.
+ * Note that this will ignore default arg values and always require all arguments to be provided.
+ *
+ * \code
+ *
+ * int multiply(int x, int y) {
+ * return x * y;
+ * }
+ *
+ * TVM_REGISTER_API("multiply")
+ * .set_body_typed(multiply); // will have type int(int, int)
+ *
+ * \endcode
+ *
+ * \param f The function to forward to.
+ * \tparam R the return type of the function (inferred).
+ * \tparam Args the argument types of the function (inferred).
+ */
+ template<typename R, typename ...Args>
+ Registry& set_body_typed(R (*f)(Args...)) {
+ return set_body(TypedPackedFunc<R(Args...)>(f));
+ }
+
+ /*!
+ * \brief set the body of the function to be the passed method pointer.
+ * Note that this will ignore default arg values and always require all arguments to be provided.
+ *
+ * \code
+ *
+ * // node subclass:
+ * struct Example {
+ * int doThing(int x);
+ * }
+ * TVM_REGISTER_API("Example_doThing")
+ * .set_body_method(&Example::doThing); // will have type int(Example, int)
+ *
+ * \endcode
+ *
+ * \param f the method pointer to forward to.
+ * \tparam T the type containing the method (inferred).
+ * \tparam R the return type of the function (inferred).
+ * \tparam Args the argument types of the function (inferred).
+ */
+ template<typename T, typename R, typename ...Args>
+ Registry& set_body_method(R (T::*f)(Args...)) {
+ return set_body_typed<R(T, Args...)>([f](T target, Args... params) -> R {
+ // call method pointer
+ return (target.*f)(params...);
+ });
+ }
+
+ /*!
+ * \brief set the body of the function to be the passed method pointer.
+ * Note that this will ignore default arg values and always require all arguments to be provided.
+ *
+ * \code
+ *
+ * // node subclass:
+ * struct Example {
+ * int doThing(int x);
+ * }
+ * TVM_REGISTER_API("Example_doThing")
+ * .set_body_method(&Example::doThing); // will have type int(Example, int)
+ *
+ * \endcode
+ *
+ * \param f the method pointer to forward to.
+ * \tparam T the type containing the method (inferred).
+ * \tparam R the return type of the function (inferred).
+ * \tparam Args the argument types of the function (inferred).
+ */
+ template<typename T, typename R, typename ...Args>
+ Registry& set_body_method(R (T::*f)(Args...) const) {
+ return set_body_typed<R(T, Args...)>([f](const T target, Args... params) -> R {
+ // call method pointer
+ return (target.*f)(params...);
+ });
+ }
+
+ /*!
+ * \brief set the body of the function to be the passed method pointer.
+ * Used when calling a method on a Node subclass through a NodeRef subclass.
+ * Note that this will ignore default arg values and always require all arguments to be provided.
+ *
+ * \code
+ *
+ * // node subclass:
+ * struct ExampleNode: BaseNode {
+ * int doThing(int x);
+ * }
+ *
+ * // noderef subclass
+ * struct Example;
+ *
+ * TVM_REGISTER_API("Example_doThing")
+ * .set_body_method<Example>(&ExampleNode::doThing); // will have type int(Example, int)
+ *
+ * // note that just doing:
+ * // .set_body_method(&ExampleNode::doThing);
+ * // wouldn't work, because ExampleNode can't be taken from a TVMArgValue.
+ *
+ * \endcode
+ *
+ * \param f the method pointer to forward to.
+ * \tparam TNodeRef the node reference type to call the method on
+ * \tparam TNode the node type containing the method (inferred).
+ * \tparam R the return type of the function (inferred).
+ * \tparam Args the argument types of the function (inferred).
+ */
+ template<typename TNodeRef, typename TNode, typename R, typename ...Args,
+ typename = typename std::enable_if<std::is_base_of<NodeRef, TNodeRef>::value>::type>
+ Registry& set_body_method(R (TNode::*f)(Args...)) {
+ return set_body_typed<R(TNodeRef, Args...)>([f](TNodeRef ref, Args... params) {
+ TNode* target = ref.operator->();
+ // call method pointer
+ return (target->*f)(params...);
+ });
+ }
+
+ /*!
+ * \brief set the body of the function to be the passed method pointer.
+ * Used when calling a method on a Node subclass through a NodeRef subclass.
+ * Note that this will ignore default arg values and always require all arguments to be provided.
+ *
+ * \code
+ *
+ * // node subclass:
+ * struct ExampleNode: BaseNode {
+ * int doThing(int x);
+ * }
+ *
+ * // noderef subclass
+ * struct Example;
+ *
+ * TVM_REGISTER_API("Example_doThing")
+ * .set_body_method<Example>(&ExampleNode::doThing); // will have type int(Example, int)
+ *
+ * // note that just doing:
+ * // .set_body_method(&ExampleNode::doThing);
+ * // wouldn't work, because ExampleNode can't be taken from a TVMArgValue.
+ *
+ * \endcode
+ *
+ * \param f the method pointer to forward to.
+ * \tparam TNodeRef the node reference type to call the method on
+ * \tparam TNode the node type containing the method (inferred).
+ * \tparam R the return type of the function (inferred).
+ * \tparam Args the argument types of the function (inferred).
+ */
+ template<typename TNodeRef, typename TNode, typename R, typename ...Args,
+ typename = typename std::enable_if<std::is_base_of<NodeRef, TNodeRef>::value>::type>
+ Registry& set_body_method(R (TNode::*f)(Args...) const) {
+ return set_body_typed<R(TNodeRef, Args...)>([f](TNodeRef ref, Args... params) {
+ const TNode* target = ref.operator->();
+ // call method pointer
+ return (target->*f)(params...);
+ });
+ }
+
/*!
* \brief Register a function with given name
* \param name The name of the function.
});
TVM_REGISTER_GLOBAL("nnvm.compiler.MakeGraphKey")
-.set_body([](tvm::runtime::TVMArgs args, tvm::runtime::TVMRetValue *rv) {
- *rv = GraphKeyNode::make(args[0], args[1], args[2]);
- });
+.set_body_typed(GraphKeyNode::make);
// This can be used to extract workloads from nnvm compiler
TVM_REGISTER_GLOBAL("nnvm.compiler.CacheItem2ScheduleArgs")
}
TVM_REGISTER_GLOBAL("nnvm.graph.DeepCompare")
-.set_body([](tvm::runtime::TVMArgs args, tvm::runtime::TVMRetValue *rv) {
- *rv = GraphDeepCompare(args[0], args[1], args[2]);
- });
+.set_body_typed(GraphDeepCompare);
} // namespace compiler
} // namespace nnvm
namespace arith {
TVM_REGISTER_API("arith.intset_single_point")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
- *ret = IntSet::single_point(args[0]);
- });
+.set_body_typed(IntSet::single_point);
TVM_REGISTER_API("arith.intset_vector")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
- *ret = IntSet::vector(args[0]);
- });
+.set_body_typed(IntSet::vector);
TVM_REGISTER_API("arith.intset_interval")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
- *ret = IntSet::interval(args[0], args[1]);
- });
+.set_body_typed(IntSet::interval);
TVM_REGISTER_API("arith.DetectLinearEquation")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
- *ret = DetectLinearEquation(args[0], args[1]);
- });
+.set_body_typed(DetectLinearEquation);
TVM_REGISTER_API("arith.DetectClipBound")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
- *ret = DetectClipBound(args[0], args[1]);
- });
+.set_body_typed(DetectClipBound);
TVM_REGISTER_API("arith.DeduceBound")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
- *ret = DeduceBound(args[0], args[1],
- args[2].operator Map<Var, IntSet>(),
- args[3].operator Map<Var, IntSet>());
- });
+.set_body_typed<IntSet(Expr, Expr, Map<Var, IntSet>, Map<Var, IntSet>)>([](
+ Expr v, Expr cond,
+ const Map<Var, IntSet> hint_map,
+ const Map<Var, IntSet> relax_map
+) {
+ return DeduceBound(v, cond, hint_map, relax_map);
+});
TVM_REGISTER_API("arith.DomainTouched")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
- *ret = DomainTouched(args[0], args[1], args[2], args[3]);
- });
+.set_body_typed(DomainTouched);
TVM_REGISTER_API("_IntervalSetGetMin")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
- *ret = args[0].operator IntSet().min();
- });
+.set_body_method(&IntSet::min);
TVM_REGISTER_API("_IntervalSetGetMax")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
- *ret = args[0].operator IntSet().max();
- });
+.set_body_method(&IntSet::max);
TVM_REGISTER_API("_IntSetIsNothing")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
- *ret = args[0].operator IntSet().is_nothing();
- });
+.set_body_method(&IntSet::is_nothing);
TVM_REGISTER_API("_IntSetIsEverything")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
- *ret = args[0].operator IntSet().is_everything();
- });
+.set_body_method(&IntSet::is_everything);
TVM_REGISTER_API("arith._make_ConstIntBound")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- *ret = ConstIntBoundNode::make(args[0], args[1]);
- });
+.set_body_typed(ConstIntBoundNode::make);
TVM_REGISTER_API("arith._make_ModularSet")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- *ret = ModularSetNode::make(args[0], args[1]);
- });
+.set_body_typed(ModularSetNode::make);
TVM_REGISTER_API("arith._CreateAnalyzer")
.set_body([](TVMArgs args, TVMRetValue* ret) {
.set_body_typed<NodeRef(std::string)>(LoadJSON<NodeRef>);
TVM_REGISTER_API("_TVMSetStream")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
- TVMSetStream(args[0], args[1], args[2]);
- });
+.set_body_typed(TVMSetStream);
+
TVM_REGISTER_API("_save_param_dict")
.set_body([](TVMArgs args, TVMRetValue *rv) {
CHECK_EQ(args.size() % 2, 0u);
});
TVM_REGISTER_API("module._PackImportsToC")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
- *ret = PackImportsToC(args[0], args[1]);
- });
+.set_body_typed(PackImportsToC);
} // namespace codegen
} // namespace tvm
namespace ir {
TVM_REGISTER_API("_Var")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
- *ret = Variable::make(args[1], args[0]);
+.set_body_typed<VarExpr(std::string, Type)>([](std::string s, Type t) {
+ return Variable::make(t, s);
});
TVM_REGISTER_API("make.abs")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
- *ret = tvm::abs(args[0]);
- });
+.set_body_typed(tvm::abs);
TVM_REGISTER_API("make.floor")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
- *ret = tvm::floor(args[0]);
- });
+.set_body_typed(tvm::floor);
TVM_REGISTER_API("make.ceil")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
- *ret = tvm::ceil(args[0]);
- });
+.set_body_typed(tvm::ceil);
TVM_REGISTER_API("make.round")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
- *ret = tvm::round(args[0]);
- });
+.set_body_typed(tvm::round);
TVM_REGISTER_API("make.trunc")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
- *ret = tvm::trunc(args[0]);
- });
+.set_body_typed(tvm::trunc);
TVM_REGISTER_API("make._cast")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
- *ret = tvm::cast(args[0], args[1]);
- });
+.set_body_typed(tvm::cast);
TVM_REGISTER_API("make._range_by_min_extent")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
- *ret = Range::make_by_min_extent(args[0], args[1]);
- });
+.set_body_typed(Range::make_by_min_extent);
TVM_REGISTER_API("make.For")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
- *ret = For::make(args[0],
- args[1],
- args[2],
- static_cast<ForType>(args[3].operator int()),
- static_cast<HalideIR::DeviceAPI>(args[4].operator int()),
- args[5]);
- });
+.set_body_typed<Stmt(VarExpr, Expr, Expr, int, int, Stmt)>([](
+ VarExpr loop_var, Expr min, Expr extent,
+ int for_type, int device_api, Stmt body
+) {
+ return For::make(loop_var,
+ min,
+ extent,
+ static_cast<ForType>(for_type),
+ static_cast<HalideIR::DeviceAPI>(device_api),
+ body);
+});
TVM_REGISTER_API("make.Load")
.set_body([](TVMArgs args, TVMRetValue *ret) {
});
TVM_REGISTER_API("make.Realize")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
- *ret = Realize::make(args[0],
- args[1],
- args[2],
- args[3],
- args[4],
- args[5]);
- });
-
+.set_body_typed(Realize::make);
TVM_REGISTER_API("make.Call")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
- *ret = Call::make(args[0],
- args[1],
- args[2],
- static_cast<Call::CallType>(args[3].operator int()),
- args[4],
- args[5]);
- });
+.set_body_typed<Expr(Type, std::string, Array<Expr>, int, FunctionRef, int)>([](
+ Type type, std::string name,
+ Array<Expr> args, int call_type,
+ FunctionRef func, int value_index
+) {
+ return Call::make(type,
+ name,
+ args,
+ static_cast<Call::CallType>(call_type),
+ func,
+ value_index);
+});
TVM_REGISTER_API("make.CommReducer")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
- *ret = CommReducerNode::make(args[0],
- args[1],
- args[2],
- args[3]);
- });
+.set_body_typed(CommReducerNode::make);
// make from two arguments
-#define REGISTER_MAKE1(Node) \
- TVM_REGISTER_API("make."#Node) \
- .set_body([](TVMArgs args, TVMRetValue *ret) { \
- *ret = Node::make(args[0]); \
- }) \
-
-#define REGISTER_MAKE2(Node) \
+#define REGISTER_MAKE(Node) \
TVM_REGISTER_API("make."#Node) \
- .set_body([](TVMArgs args, TVMRetValue *ret) { \
- *ret = Node::make(args[0], args[1]); \
- }) \
-
-#define REGISTER_MAKE3(Node) \
- TVM_REGISTER_API("make."#Node) \
- .set_body([](TVMArgs args, TVMRetValue *ret) { \
- *ret = Node::make(args[0], args[1], args[2]); \
- }) \
-
-#define REGISTER_MAKE4(Node) \
- TVM_REGISTER_API("make."#Node) \
- .set_body([](TVMArgs args, TVMRetValue *ret) { \
- *ret = Node::make(args[0], args[1], args[2], args[3]); \
- }) \
-
-#define REGISTER_MAKE5(Node) \
- TVM_REGISTER_API("make."#Node) \
- .set_body([](TVMArgs args, TVMRetValue *ret) { \
- *ret = Node::make(args[0], args[1], args[2], args[3], args[4]); \
- }) \
-
-
-REGISTER_MAKE5(Reduce);
-REGISTER_MAKE4(AttrStmt);
-
-REGISTER_MAKE2(IntImm);
-REGISTER_MAKE2(UIntImm);
-REGISTER_MAKE2(FloatImm);
-REGISTER_MAKE1(StringImm);
-
-REGISTER_MAKE2(Add);
-REGISTER_MAKE2(Sub);
-REGISTER_MAKE2(Mul);
-REGISTER_MAKE2(Div);
-REGISTER_MAKE2(Mod);
-REGISTER_MAKE2(Min);
-REGISTER_MAKE2(Max);
-REGISTER_MAKE2(EQ);
-REGISTER_MAKE2(NE);
-REGISTER_MAKE2(LT);
-REGISTER_MAKE2(LE);
-REGISTER_MAKE2(GT);
-REGISTER_MAKE2(GE);
-REGISTER_MAKE2(And);
-REGISTER_MAKE2(Or);
-
-REGISTER_MAKE1(Not);
-REGISTER_MAKE3(Select);
-REGISTER_MAKE3(Ramp);
-REGISTER_MAKE2(Cast);
-REGISTER_MAKE2(Broadcast);
-REGISTER_MAKE2(Shuffle);
-REGISTER_MAKE3(Let);
-REGISTER_MAKE3(LetStmt);
-REGISTER_MAKE3(AssertStmt);
-REGISTER_MAKE3(ProducerConsumer);
-REGISTER_MAKE5(Allocate);
-REGISTER_MAKE4(Provide);
-REGISTER_MAKE4(Prefetch);
-REGISTER_MAKE1(Free);
-REGISTER_MAKE2(Block);
-REGISTER_MAKE3(IfThenElse);
-REGISTER_MAKE1(Evaluate);
+ .set_body_typed(Node::make); \
+
+REGISTER_MAKE(Reduce);
+REGISTER_MAKE(AttrStmt);
+
+REGISTER_MAKE(IntImm);
+REGISTER_MAKE(UIntImm);
+REGISTER_MAKE(FloatImm);
+REGISTER_MAKE(StringImm);
+
+REGISTER_MAKE(Add);
+REGISTER_MAKE(Sub);
+REGISTER_MAKE(Mul);
+REGISTER_MAKE(Div);
+REGISTER_MAKE(Mod);
+REGISTER_MAKE(Min);
+REGISTER_MAKE(Max);
+REGISTER_MAKE(EQ);
+REGISTER_MAKE(NE);
+REGISTER_MAKE(LT);
+REGISTER_MAKE(LE);
+REGISTER_MAKE(GT);
+REGISTER_MAKE(GE);
+REGISTER_MAKE(And);
+REGISTER_MAKE(Or);
+
+REGISTER_MAKE(Not);
+REGISTER_MAKE(Select);
+REGISTER_MAKE(Ramp);
+REGISTER_MAKE(Cast);
+REGISTER_MAKE(Broadcast);
+REGISTER_MAKE(Shuffle);
+REGISTER_MAKE(Let);
+REGISTER_MAKE(LetStmt);
+REGISTER_MAKE(AssertStmt);
+REGISTER_MAKE(ProducerConsumer);
+REGISTER_MAKE(Provide);
+REGISTER_MAKE(Prefetch);
+REGISTER_MAKE(Free);
+REGISTER_MAKE(IfThenElse);
+REGISTER_MAKE(Evaluate);
+
+// overloaded, needs special handling
+TVM_REGISTER_API("make.Block")
+ .set_body_typed(static_cast<Stmt (*)(Stmt, Stmt)>(Block::make));
+
+// has default args
+TVM_REGISTER_API("make.Allocate")
+ .set_body_typed<Stmt(VarExpr, Type, Array<Expr>, Expr, Stmt)>([](
+ VarExpr buffer_var, Type type, Array<Expr> extents, Expr condition, Stmt body
+ ){
+ return Allocate::make(buffer_var, type, extents, condition, body);
+ });
// operator overloading, smarter than make
#define REGISTER_MAKE_BINARY_OP(Node, Func) \
TVM_REGISTER_API("make."#Node) \
- .set_body([](TVMArgs args, TVMRetValue *ret) { \
- Expr a = args[0], b = args[1]; \
- *ret = (Func(a, b)); \
+ .set_body_typed<Expr(Expr, Expr)>([](Expr a, Expr b) { \
+ return (Func(a, b)); \
})
#define REGISTER_MAKE_BIT_OP(Node, Func) \
#include <tvm/build_module.h>
#include <tvm/data_layout.h>
+
namespace tvm {
TVM_REGISTER_API("_min_value")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- Type t = args[0].operator Type();
- *ret = t.min();
- });
+.set_body_method(&Type::min);
TVM_REGISTER_API("_max_value")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- Type t = args[0].operator Type();
- *ret = t.max();
- });
+.set_body_method(&Type::max);
TVM_REGISTER_API("_const")
.set_body([](TVMArgs args, TVMRetValue* ret) {
});
TVM_REGISTER_API("_str")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- *ret = ir::StringImm::make(args[0]);
-});
+.set_body_typed(ir::StringImm::make);
TVM_REGISTER_API("_Array")
});
TVM_REGISTER_API("_Buffer")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- *ret = BufferNode::make(args[0],
- args[1],
- args[2],
- args[3],
- args[4],
- args[5],
- args[6],
- args[7],
- args[8]);
- });
+.set_body_typed(BufferNode::make);
TVM_REGISTER_API("_BufferAccessPtr")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- *ret = args[0].operator Buffer()
- .access_ptr(args[1], args[2], args[3], args[4]);
- });
+.set_body_method(&Buffer::access_ptr);
TVM_REGISTER_API("_BufferVLoad")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- *ret = args[0].operator Buffer()
- .vload(args[1], args[2]);
- });
+.set_body_method(&Buffer::vload);
TVM_REGISTER_API("_BufferVStore")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- *ret = args[0].operator Buffer()
- .vstore(args[1], args[2]);
- });
+.set_body_method(&Buffer::vstore);
TVM_REGISTER_API("_Layout")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- *ret = LayoutNode::make(args[0]);
- });
+.set_body_typed(LayoutNode::make);
TVM_REGISTER_API("_LayoutIndexOf")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- *ret = args[0].operator Layout()
- .IndexOf(LayoutAxis::make(args[1]));
+.set_body_typed<int(Layout, std::string)>([](Layout layout, std::string axis) {
+ return layout.IndexOf(LayoutAxis::make(axis));
});
TVM_REGISTER_API("_LayoutFactorOf")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- *ret = args[0].operator Layout()
- .FactorOf(LayoutAxis::make(args[1]));
+.set_body_typed<int(Layout, std::string)>([](Layout layout, std::string axis) {
+ return layout.FactorOf(LayoutAxis::make(axis));
});
TVM_REGISTER_API("_LayoutNdim")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- *ret = static_cast<int64_t>(args[0].operator Layout().ndim());
+.set_body_typed<int(Layout)>([](Layout layout) {
+ return layout.ndim();
});
TVM_REGISTER_API("_LayoutGetItem")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- const LayoutAxis& axis = args[0].operator Layout()[args[1]];
- *ret = axis.name();
+.set_body_typed<std::string(Layout, int)>([](Layout layout, int idx) {
+ const LayoutAxis& axis = layout[idx];
+ return axis.name();
});
TVM_REGISTER_API("_BijectiveLayout")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- *ret = BijectiveLayoutNode::make(args[0], args[1]);
- });
+.set_body_typed(BijectiveLayoutNode::make);
TVM_REGISTER_API("_BijectiveLayoutForwardIndex")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- *ret = args[0].operator BijectiveLayout()
- .ForwardIndex(args[1]);
- });
+.set_body_method(&BijectiveLayout::ForwardIndex);
TVM_REGISTER_API("_BijectiveLayoutBackwardIndex")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- *ret = args[0].operator BijectiveLayout()
- .BackwardIndex(args[1]);
- });
+.set_body_method(&BijectiveLayout::BackwardIndex);
TVM_REGISTER_API("_BijectiveLayoutForwardShape")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- *ret = args[0].operator BijectiveLayout()
- .ForwardShape(args[1]);
- });
+.set_body_method(&BijectiveLayout::ForwardShape);
TVM_REGISTER_API("_BijectiveLayoutBackwardShape")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- *ret = args[0].operator BijectiveLayout()
- .BackwardShape(args[1]);
- });
+.set_body_method(&BijectiveLayout::BackwardShape);
TVM_REGISTER_API("_Tensor")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- *ret = TensorNode::make(args[0],
- args[1],
- args[2],
- args[3]);
- });
+.set_body_typed(TensorNode::make);
TVM_REGISTER_API("_TensorIntrin")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- *ret = TensorIntrinNode::make(args[0],
- args[1],
- args[2],
- args[3],
- args[4],
- args[5],
- args[6]);
- });
+.set_body_typed(TensorIntrinNode::make);
TVM_REGISTER_API("_TensorIntrinCall")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- *ret = TensorIntrinCallNode::make(args[0],
- args[1],
- args[2],
- args[3]);
- });
+.set_body_typed(TensorIntrinCallNode::make);
TVM_REGISTER_API("_TensorEqual")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- *ret = args[0].operator Tensor() == args[1].operator Tensor();
- });
+.set_body_method(&Tensor::operator==);
TVM_REGISTER_API("_TensorHash")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- *ret = static_cast<int64_t>(
- std::hash<Tensor>()(args[0].operator Tensor()));
+.set_body_typed<int64_t(Tensor)>([](Tensor tensor) {
+ return static_cast<int64_t>(std::hash<Tensor>()(tensor));
});
TVM_REGISTER_API("_Placeholder")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- *ret = placeholder(args[0],
- args[1],
- args[2]);
- });
+.set_body_typed<Tensor(Array<Expr>, Type, std::string)>([](
+ Array<Expr> shape, Type dtype, std::string name
+) {
+ return placeholder(shape, dtype, name);
+});
TVM_REGISTER_API("_ComputeOp")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- *ret = ComputeOpNode::make(args[0],
- args[1],
- args[2],
- args[3],
- args[4]);
- });
+.set_body_typed(ComputeOpNode::make);
TVM_REGISTER_API("_ScanOp")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- *ret = ScanOpNode::make(args[0],
- args[1],
- args[2],
- args[3],
- args[4],
- args[5],
- args[6],
- args[7]);
- });
+.set_body_typed(ScanOpNode::make);
TVM_REGISTER_API("_TensorComputeOp")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- *ret = TensorComputeOpNode::make(args[0],
- args[1],
- args[2],
- args[3],
- args[4],
- args[5],
- args[6],
- args[7]);
- });
+.set_body_typed(TensorComputeOpNode::make);
TVM_REGISTER_API("_ExternOp")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- *ret = ExternOpNode::make(args[0],
- args[1],
- args[2],
- args[3],
- args[4],
- args[5],
- args[6]);
- });
+.set_body_typed(ExternOpNode::make);
TVM_REGISTER_API("_HybridOp")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- *ret = HybridOpNode::make(args[0],
- args[1],
- args[2],
- args[3],
- args[4],
- args[5]);
- });
+.set_body_typed(HybridOpNode::make);
TVM_REGISTER_API("_OpGetOutput")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- *ret = args[0].operator Operation().output(
- static_cast<size_t>(args[1].operator int64_t()));
- });
+.set_body_typed<Tensor(Operation, int64_t)>([](Operation op, int64_t output) {
+ return op.output(static_cast<size_t>(output));
+});
TVM_REGISTER_API("_OpNumOutputs")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- *ret = args[0].operator Operation()->num_outputs();
- });
+.set_body_method<Operation>(&OperationNode::num_outputs);
TVM_REGISTER_API("_OpInputTensors")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- *ret = args[0].operator Operation()->InputTensors();
- });
+.set_body_method<Operation>(&OperationNode::InputTensors);
TVM_REGISTER_API("_IterVar")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- *ret = IterVarNode::make(
- args[0], args[1],
- static_cast<IterVarType>(args[2].operator int()),
- args[3]);
- });
+.set_body_typed<IterVar(Range, Var, int, std::string)>([](
+ Range dom, Var var, int iter_type, std::string thread_tag
+) {
+ return IterVarNode::make(
+ dom, var,
+ static_cast<IterVarType>(iter_type),
+ thread_tag);
+});
TVM_REGISTER_API("_CreateSchedule")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- *ret = create_schedule(args[0].operator Array<Operation>());
- });
+.set_body_typed(create_schedule);
TVM_REGISTER_API("_StageSetScope")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- args[0].operator Stage()
- .set_scope(args[1]);
- });
+.set_body_method(&Stage::set_scope);
TVM_REGISTER_API("_StageBind")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- args[0].operator Stage()
- .bind(args[1], args[2]);
- });
+.set_body_method(&Stage::bind);
TVM_REGISTER_API("_StageSplitByFactor")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- IterVar outer, inner;
- args[0].operator Stage()
- .split(args[1], args[2], &outer, &inner);
- *ret = Array<IterVar>({outer, inner});
- });
+.set_body_typed<Array<IterVar>(Stage, IterVar, Expr)>([](
+ Stage stage, IterVar parent, Expr factor
+) {
+ IterVar outer, inner;
+ stage.split(parent, factor, &outer, &inner);
+ return Array<IterVar>({outer, inner});
+});
TVM_REGISTER_API("_StageSplitByNParts")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- IterVar outer, inner;
- args[0].operator Stage()
- .split_by_nparts(args[1], args[2], &outer, &inner);
- *ret = Array<IterVar>({outer, inner});
- });
+.set_body_typed<Array<IterVar>(Stage, IterVar, Expr)>([](
+ Stage stage, IterVar parent, Expr nparts
+) {
+ IterVar outer, inner;
+ stage.split_by_nparts(parent, nparts, &outer, &inner);
+ return Array<IterVar>({outer, inner});
+});
TVM_REGISTER_API("_StageFuse")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
+.set_body_typed<IterVar(Stage, Array<IterVar>)>([](Stage stage, Array<IterVar> axes) {
IterVar fused;
- args[0].operator Stage()
- .fuse(args[1], &fused);
- *ret = fused;
+ stage.fuse(axes, &fused);
+ return fused;
});
TVM_REGISTER_API("_StageComputeAt")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- args[0].operator Stage()
- .compute_at(args[1], args[2]);
- });
+.set_body_method(&Stage::compute_at);
TVM_REGISTER_API("_StageComputeInline")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- args[0].operator Stage()
- .compute_inline();
- });
+.set_body_method(&Stage::compute_inline);
TVM_REGISTER_API("_StageComputeRoot")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- args[0].operator Stage()
- .compute_root();
- });
+.set_body_method(&Stage::compute_root);
TVM_REGISTER_API("_StageReorder")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- args[0].operator Stage()
- .reorder(args[1]);
- });
+.set_body_method(&Stage::reorder);
TVM_REGISTER_API("_StageTile")
- .set_body([](TVMArgs args, TVMRetValue* ret) {
+.set_body_typed<Array<IterVar>(Stage, IterVar, IterVar, Expr, Expr)>([](
+ Stage stage,
+ IterVar x_parent, IterVar y_parent,
+ Expr x_factor, Expr y_factor
+) {
IterVar x_outer, y_outer, x_inner, y_inner;
- args[0].operator Stage()
- .tile(args[1], args[2],
- args[3], args[4],
- &x_outer, &y_outer,
- &x_inner, &y_inner);
- *ret = Array<IterVar>({x_outer, y_outer, x_inner, y_inner});
+ stage.tile(x_parent, y_parent,
+ x_factor, y_factor,
+ &x_outer, &y_outer,
+ &x_inner, &y_inner);
+ return Array<IterVar>({x_outer, y_outer, x_inner, y_inner});
});
TVM_REGISTER_API("_StageEnvThreads")
- .set_body([](TVMArgs args, TVMRetValue* ret) {
- args[0].operator Stage()
- .env_threads(args[1]);
- });
+.set_body_method(&Stage::env_threads);
TVM_REGISTER_API("_StageSetStorePredicate")
- .set_body([](TVMArgs args, TVMRetValue* ret) {
- args[0].operator Stage()
- .set_store_predicate(args[1]);
- });
+.set_body_method(&Stage::set_store_predicate);
TVM_REGISTER_API("_StageUnroll")
- .set_body([](TVMArgs args, TVMRetValue* ret) {
- args[0].operator Stage()
- .unroll(args[1]);
- });
+.set_body_method(&Stage::unroll);
TVM_REGISTER_API("_StageVectorize")
- .set_body([](TVMArgs args, TVMRetValue* ret) {
- args[0].operator Stage()
- .vectorize(args[1]);
- });
+.set_body_method(&Stage::vectorize);
TVM_REGISTER_API("_StageTensorize")
- .set_body([](TVMArgs args, TVMRetValue* ret) {
- args[0].operator Stage()
- .tensorize(args[1], args[2]);
- });
+.set_body_method(&Stage::tensorize);
TVM_REGISTER_API("_StageParallel")
- .set_body([](TVMArgs args, TVMRetValue* ret) {
- args[0].operator Stage()
- .parallel(args[1]);
- });
+.set_body_method(&Stage::parallel);
TVM_REGISTER_API("_StagePragma")
- .set_body([](TVMArgs args, TVMRetValue* ret) {
- args[0].operator Stage()
- .pragma(args[1], args[2], args[3]);
- });
+.set_body_method(&Stage::pragma);
TVM_REGISTER_API("_StagePrefetch")
- .set_body([](TVMArgs args, TVMRetValue *ret) {
- args[0].operator Stage()
- .prefetch(args[1], args[2], args[3]);
- });
+.set_body_method(&Stage::prefetch);
TVM_REGISTER_API("_StageStorageAlign")
- .set_body([](TVMArgs args, TVMRetValue *ret) {
- args[0].operator Stage()
- .storage_align(args[1], args[2], args[3]);
- });
+.set_body_method(&Stage::storage_align);
TVM_REGISTER_API("_StageDoubleBuffer")
- .set_body([](TVMArgs args, TVMRetValue *ret) {
- args[0].operator Stage().double_buffer();
- });
+.set_body_method(&Stage::double_buffer);
TVM_REGISTER_API("_StageOpenGL")
- .set_body([](TVMArgs args, TVMRetValue *ret) {
- args[0].operator Stage().opengl();
- });
+.set_body_method(&Stage::opengl);
TVM_REGISTER_API("_ScheduleNormalize")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- *ret = args[0].operator Schedule()
- .normalize();
- });
+.set_body_method(&Schedule::normalize);
TVM_REGISTER_API("_ScheduleCreateGroup")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- *ret = args[0].operator Schedule()
- .create_group(args[1], args[2], args[3]);
- });
+.set_body_method(&Schedule::create_group);
TVM_REGISTER_API("_ScheduleCacheRead")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- *ret = args[0].operator Schedule()
- .cache_read(args[1], args[2], args[3]);
- });
+.set_body_method(&Schedule::cache_read);
TVM_REGISTER_API("_ScheduleCacheWrite")
.set_body([](TVMArgs args, TVMRetValue* ret) {
});
TVM_REGISTER_API("_ScheduleRFactor")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- *ret = args[0].operator Schedule()
- .rfactor(args[1], args[2], args[3]);
- });
+.set_body_method(&Schedule::rfactor);
TVM_REGISTER_API("_CommReducerCombine")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- const ir::CommReducerNode* combiner =
- args[0].operator ir::CommReducer().as<ir::CommReducerNode>();
- *ret = (*combiner)(args[1], args[2]);
- });
+.set_body_method<ir::CommReducer>(&ir::CommReducerNode::operator());
} // namespace tvm
});
// make from two arguments
-#define REGISTER_PASS1(PassName) \
+#define REGISTER_PASS(PassName) \
TVM_REGISTER_API("ir_pass."#PassName) \
- .set_body([](TVMArgs args, TVMRetValue *ret) { \
- *ret = PassName(args[0]); \
- }) \
-
-#define REGISTER_PASS2(PassName) \
- TVM_REGISTER_API("ir_pass."#PassName) \
- .set_body([](TVMArgs args, TVMRetValue *ret) { \
- *ret = PassName(args[0], args[1]); \
- }) \
-
-#define REGISTER_PASS3(PassName) \
- TVM_REGISTER_API("ir_pass."#PassName) \
- .set_body([](TVMArgs args, TVMRetValue *ret) { \
- *ret = PassName(args[0], args[1], args[2]); \
- }) \
-
-#define REGISTER_PASS4(PassName) \
- TVM_REGISTER_API("ir_pass."#PassName) \
- .set_body([](TVMArgs args, TVMRetValue *ret) { \
- *ret = PassName(args[0], args[1], args[2], args[3]); \
- }) \
-
-#define REGISTER_PASS5(PassName) \
- TVM_REGISTER_API("ir_pass."#PassName) \
- .set_body([](TVMArgs args, TVMRetValue *ret) { \
- *ret = PassName(args[0], args[1], args[2], args[3], args[4]); \
- }) \
-
-REGISTER_PASS1(ConvertSSA);
-REGISTER_PASS1(VerifySSA);
-REGISTER_PASS1(RewriteUnsafeSelect);
-REGISTER_PASS4(Inline);
-REGISTER_PASS4(IRTransform);
-REGISTER_PASS1(VectorizeLoop);
-REGISTER_PASS5(UnrollLoop);
-REGISTER_PASS3(InjectCopyIntrin);
-REGISTER_PASS2(ThreadSync);
-REGISTER_PASS5(MakeAPI);
-REGISTER_PASS2(BindDeviceType);
-REGISTER_PASS1(SplitHostDevice);
-REGISTER_PASS1(StorageRewrite);
-REGISTER_PASS1(CoProcSync);
-REGISTER_PASS1(LowerStorageAccessInfo);
-REGISTER_PASS1(InjectVirtualThread);
-REGISTER_PASS1(InjectPrefetch);
-REGISTER_PASS2(InjectDoubleBuffer);
-REGISTER_PASS2(LoopPartition);
-REGISTER_PASS1(RemoveNoOp);
-REGISTER_PASS2(SplitPipeline);
-REGISTER_PASS2(LiftAttrScope);
-REGISTER_PASS1(NarrowChannelAccess);
-REGISTER_PASS2(LowerThreadAllreduce);
-REGISTER_PASS2(LowerWarpMemory);
-REGISTER_PASS2(RemapThreadAxis);
-REGISTER_PASS2(LowerIntrin);
-REGISTER_PASS1(LowerTVMBuiltin);
-REGISTER_PASS1(CombineContextCall);
-REGISTER_PASS2(VerifyMemory);
-REGISTER_PASS2(VerifyGPUCode);
-REGISTER_PASS1(DecorateDeviceScope);
-REGISTER_PASS1(InstrumentBoundCheckers);
+ .set_body_typed(PassName); \
+
+
+REGISTER_PASS(ConvertSSA);
+REGISTER_PASS(VerifySSA);
+REGISTER_PASS(RewriteUnsafeSelect);
+REGISTER_PASS(Inline);
+REGISTER_PASS(IRTransform);
+REGISTER_PASS(VectorizeLoop);
+REGISTER_PASS(UnrollLoop);
+REGISTER_PASS(InjectCopyIntrin);
+REGISTER_PASS(ThreadSync);
+REGISTER_PASS(MakeAPI);
+REGISTER_PASS(BindDeviceType);
+REGISTER_PASS(SplitHostDevice);
+REGISTER_PASS(StorageRewrite);
+REGISTER_PASS(CoProcSync);
+REGISTER_PASS(LowerStorageAccessInfo);
+REGISTER_PASS(InjectVirtualThread);
+REGISTER_PASS(InjectPrefetch);
+REGISTER_PASS(InjectDoubleBuffer);
+REGISTER_PASS(LoopPartition);
+REGISTER_PASS(RemoveNoOp);
+REGISTER_PASS(SplitPipeline);
+REGISTER_PASS(LiftAttrScope);
+REGISTER_PASS(NarrowChannelAccess);
+REGISTER_PASS(LowerThreadAllreduce);
+REGISTER_PASS(LowerWarpMemory);
+REGISTER_PASS(RemapThreadAxis);
+REGISTER_PASS(LowerIntrin);
+REGISTER_PASS(LowerTVMBuiltin);
+REGISTER_PASS(CombineContextCall);
+REGISTER_PASS(VerifyMemory);
+REGISTER_PASS(VerifyGPUCode);
+REGISTER_PASS(DecorateDeviceScope);
+REGISTER_PASS(InstrumentBoundCheckers);
} // namespace ir
} // namespace tvm
namespace schedule {
TVM_REGISTER_API("schedule.AutoInlineElemWise")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- AutoInlineElemWise(args[0]);
- });
+.set_body_typed(AutoInlineElemWise);
TVM_REGISTER_API("schedule.AutoInlineInjective")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- AutoInlineInjective(args[0]);
- });
+.set_body_typed(AutoInlineInjective);
TVM_REGISTER_API("schedule.ScheduleOps")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = ScheduleOps(args[0], args[1], args[2]);
});
-#define REGISTER_SCHEDULE_PASS1(PassName) \
- TVM_REGISTER_API("schedule."#PassName) \
- .set_body([](TVMArgs args, TVMRetValue *ret) { \
- *ret = PassName(args[0]); \
- }) \
-
-#define REGISTER_SCHEDULE_PASS2(PassName) \
+#define REGISTER_SCHEDULE_PASS(PassName) \
TVM_REGISTER_API("schedule."#PassName) \
- .set_body([](TVMArgs args, TVMRetValue *ret) { \
- *ret = PassName(args[0], args[1]); \
- }) \
+ .set_body_typed(PassName); \
-REGISTER_SCHEDULE_PASS1(InferBound);
-REGISTER_SCHEDULE_PASS1(CreateReadGraph);
-REGISTER_SCHEDULE_PASS2(PostDFSOrder);
-REGISTER_SCHEDULE_PASS1(CreateAttachPath);
-REGISTER_SCHEDULE_PASS1(ScanGetBody);
-REGISTER_SCHEDULE_PASS1(ScanFixPointAnalysis);
+REGISTER_SCHEDULE_PASS(InferBound);
+REGISTER_SCHEDULE_PASS(CreateReadGraph);
+REGISTER_SCHEDULE_PASS(PostDFSOrder);
+REGISTER_SCHEDULE_PASS(CreateAttachPath);
+REGISTER_SCHEDULE_PASS(ScanGetBody);
+REGISTER_SCHEDULE_PASS(ScanFixPointAnalysis);
} // namespace schedule
} // namespace tvm
}
TVM_REGISTER_API("codegen.build_opencl")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
- *rv = BuildOpenCL(args[0]);
- });
+.set_body_typed(BuildOpenCL);
} // namespace codegen
} // namespace tvm
}
TVM_REGISTER_API("codegen.build_opengl")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
- *rv = BuildOpenGL(args[0]);
-});
+.set_body_typed(BuildOpenGL);
} // namespace codegen
} // namespace tvm
}
TVM_REGISTER_API("codegen.build_sdaccel")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
- *rv = BuildSDAccel(args[0], args[1]);
- });
+.set_body_typed(BuildSDAccel);
} // namespace codegen
} // namespace tvm
}
TVM_REGISTER_API("codegen.build_rocm")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
- *rv = BuildAMDGPU(args[0], args[1]);
- });
+.set_body_typed(BuildAMDGPU);
} // namespace codegen
} // namespace tvm
}
TVM_REGISTER_API("codegen.build_nvptx")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
- *rv = BuildNVPTX(args[0], args[1]);
- });
+.set_body_typed(BuildNVPTX);
} // namespace codegen
} // namespace tvm
}
TVM_REGISTER_API("codegen.build_cuda")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
- *rv = BuildCUDA(args[0]);
- });
+.set_body_typed(BuildCUDA);
} // namespace codegen
} // namespace tvm
}
TVM_REGISTER_GLOBAL("module.source_module_create")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
- *rv = SourceModuleCreate(args[0], args[1]);
- });
+.set_body_typed(SourceModuleCreate);
} // namespace codegen
} // namespace tvm
}
TVM_REGISTER_API("codegen.build_vulkan")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
- *rv = BuildSPIRV(args[0]);
- });
+.set_body_typed(BuildSPIRV);
} // namespace codegen
} // namespace tvm
}
TVM_REGISTER_API("codegen.build_stackvm")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
- *rv = BuildStackVM(args[0]);
- });
+.set_body_typed(BuildStackVM);
} // namespace codegen
} // namespace tvm
}
TVM_REGISTER_API("relay._make.Closure")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- *ret = ClosureNode::make(args[0], args[1]);
- });
+.set_body_typed(ClosureNode::make);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<ClosureNode>([](const ClosureNode* node, tvm::IRPrinter* p) {
}
TVM_REGISTER_API("relay._make.TupleValue")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- *ret = TupleValueNode::make(args[0]);
- });
+.set_body_typed(TupleValueNode::make);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<TupleValueNode>([](const TupleValueNode* node, tvm::IRPrinter* p) {
});
TVM_REGISTER_API("relay._make.TensorValue")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- runtime::NDArray data = args[0];
- *ret = TensorValueNode::make(data);
- });
+.set_body_typed(TensorValueNode::make);
RefValue RefValueNode::make(Value value) {
NodePtr<RefValueNode> n = make_node<RefValueNode>();
}
TVM_REGISTER_API("relay._make.RefValue")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- *ret = RefValueNode::make(args[0]);
- });
+.set_body_typed(RefValueNode::make);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<RefValueNode>([](const RefValueNode* node,
}
TVM_REGISTER_API("relay._make.ConstructorValue")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- *ret = ConstructorValueNode::make(args[0], args[1]);
- });
+.set_body_typed(ConstructorValueNode::make);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<ConstructorValueNode>([](const ConstructorValueNode* node,
}
TVM_REGISTER_API("relay.backend.CreateInterpreter")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- *ret = CreateInterpreter(args[0], args[1], args[2]);
- });
+.set_body_typed(CreateInterpreter);
TVM_REGISTER_NODE_TYPE(ClosureNode);
TVM_REGISTER_NODE_TYPE(TupleValueNode);
TVM_REGISTER_NODE_TYPE(PatternWildcardNode);
TVM_REGISTER_API("relay._make.PatternWildcard")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- *ret = PatternWildcardNode::make();
- });
+.set_body_typed(PatternWildcardNode::make);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<PatternWildcardNode>([](const PatternWildcardNode* node,
TVM_REGISTER_NODE_TYPE(PatternVarNode);
TVM_REGISTER_API("relay._make.PatternVar")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- *ret = PatternVarNode::make(args[0]);
- });
+.set_body_typed(PatternVarNode::make);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<PatternVarNode>([](const PatternVarNode* node,
TVM_REGISTER_NODE_TYPE(PatternConstructorNode);
TVM_REGISTER_API("relay._make.PatternConstructor")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- *ret = PatternConstructorNode::make(args[0], args[1]);
- });
+.set_body_typed(PatternConstructorNode::make);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<PatternConstructorNode>([](const PatternConstructorNode* node,
TVM_REGISTER_NODE_TYPE(ConstructorNode);
TVM_REGISTER_API("relay._make.Constructor")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- *ret = ConstructorNode::make(args[0], args[1], args[2]);
- });
+.set_body_typed(ConstructorNode::make);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<ConstructorNode>([](const ConstructorNode* node,
TVM_REGISTER_NODE_TYPE(TypeDataNode);
TVM_REGISTER_API("relay._make.TypeData")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- *ret = TypeDataNode::make(args[0], args[1], args[2]);
- });
+.set_body_typed(TypeDataNode::make);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<TypeDataNode>([](const TypeDataNode* node,
TVM_REGISTER_NODE_TYPE(ClauseNode);
TVM_REGISTER_API("relay._make.Clause")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- *ret = ClauseNode::make(args[0], args[1]);
- });
+.set_body_typed(ClauseNode::make);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<ClauseNode>([](const ClauseNode* node,
TVM_REGISTER_NODE_TYPE(MatchNode);
TVM_REGISTER_API("relay._make.Match")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- *ret = MatchNode::make(args[0], args[1]);
- });
+.set_body_typed(MatchNode::make);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<MatchNode>([](const MatchNode* node,
// TODO(@jroesch): move to correct namespace?
TVM_REGISTER_API("relay._make._alpha_equal")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- *ret = AlphaEqualHandler(false).Equal(args[0], args[1]);
+.set_body_typed<bool(NodeRef, NodeRef)>([](NodeRef a, NodeRef b) {
+ return AlphaEqualHandler(false).Equal(a, b);
});
TVM_REGISTER_API("relay._make._type_alpha_equal")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- *ret = AlphaEqualHandler(false).TypeEqual(args[0], args[1]);
+.set_body_typed<bool(Type, Type)>([](Type a, Type b) {
+ return AlphaEqualHandler(false).TypeEqual(a, b);
});
TVM_REGISTER_API("relay._make._graph_equal")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- *ret = AlphaEqualHandler(true).Equal(args[0], args[1]);
+.set_body_typed<bool(NodeRef, NodeRef)>([](NodeRef a, NodeRef b) {
+ return AlphaEqualHandler(true).Equal(a, b);
});
} // namespace relay
} // namespace tvm
}
TVM_REGISTER_API("relay._make.SourceName")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- *ret = SourceName::Get(args[0]);
- });
+.set_body_typed(SourceName::Get);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<SourceNameNode>([](const SourceNameNode* node, tvm::IRPrinter* p) {
TVM_REGISTER_NODE_TYPE(SpanNode);
TVM_REGISTER_API("relay._make.Span")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- *ret = SpanNode::make(args[0], args[1], args[2]);
- });
+.set_body_typed(SpanNode::make);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<SpanNode>([](const SpanNode* node, tvm::IRPrinter* p) {
TVM_REGISTER_NODE_TYPE(IdNode);
TVM_REGISTER_API("relay._base.set_span")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- NodeRef node_ref = args[0];
+.set_body_typed<void(NodeRef, Span)>([](NodeRef node_ref, Span sp) {
auto rn = node_ref.as_derived<RelayNode>();
CHECK(rn);
- Span sp = args[1];
rn->span = sp;
});
TVM_REGISTER_NODE_TYPE(ConstantNode);
TVM_REGISTER_API("relay._make.Constant")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- *ret = ConstantNode::make(args[0]);
- });
+.set_body_typed(ConstantNode::make);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<ConstantNode>([](const ConstantNode* node, tvm::IRPrinter* p) {
TVM_REGISTER_NODE_TYPE(TupleNode);
TVM_REGISTER_API("relay._make.Tuple")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- *ret = TupleNode::make(args[0]);
- });
+.set_body_typed(TupleNode::make);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<TupleNode>([](const TupleNode* node, tvm::IRPrinter* p) {
TVM_REGISTER_NODE_TYPE(VarNode);
TVM_REGISTER_API("relay._make.Var")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- *ret = VarNode::make(args[0].operator std::string(), args[1]);
- });
+.set_body_typed(static_cast<Var (*)(std::string, Type)>(VarNode::make));
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<VarNode>([](const VarNode* node, tvm::IRPrinter* p) {
TVM_REGISTER_NODE_TYPE(GlobalVarNode);
TVM_REGISTER_API("relay._make.GlobalVar")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- *ret = GlobalVarNode::make(args[0]);
- });
+.set_body_typed(GlobalVarNode::make);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<GlobalVarNode>([](const GlobalVarNode* node, tvm::IRPrinter* p) {
TVM_REGISTER_NODE_TYPE(FunctionNode);
TVM_REGISTER_API("relay._make.Function")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- *ret = FunctionNode::make(args[0], args[1], args[2], args[3], args[4]);
-});
+.set_body_typed(FunctionNode::make);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<FunctionNode>([](const FunctionNode* node,
TVM_REGISTER_NODE_TYPE(CallNode);
TVM_REGISTER_API("relay._make.Call")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- *ret = CallNode::make(args[0], args[1], args[2], args[3]);
-});
+.set_body_typed(CallNode::make);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<CallNode>([](const CallNode* node, tvm::IRPrinter* p) {
TVM_REGISTER_NODE_TYPE(LetNode);
TVM_REGISTER_API("relay._make.Let")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- *ret = LetNode::make(args[0], args[1], args[2]);
- });
+.set_body_typed(LetNode::make);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<LetNode>([](const LetNode* node, tvm::IRPrinter* p) {
TVM_REGISTER_NODE_TYPE(IfNode);
-TVM_REGISTER_API("relay._make.If").set_body([](TVMArgs args, TVMRetValue* ret) {
- *ret = IfNode::make(args[0], args[1], args[2]);
-});
+TVM_REGISTER_API("relay._make.If")
+.set_body_typed(IfNode::make);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<IfNode>([](const IfNode* node, tvm::IRPrinter* p) {
TVM_REGISTER_NODE_TYPE(TupleGetItemNode);
-TVM_REGISTER_API("relay._make.TupleGetItem").set_body([](TVMArgs args, TVMRetValue* ret) {
- *ret = TupleGetItemNode::make(args[0], args[1]);
-});
+TVM_REGISTER_API("relay._make.TupleGetItem")
+.set_body_typed(TupleGetItemNode::make);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<TupleGetItemNode>([](const TupleGetItemNode* node, tvm::IRPrinter* p) {
return RefCreate(n);
}
-TVM_REGISTER_API("relay._make.RefCreate").set_body([](TVMArgs args, TVMRetValue* ret) {
- *ret = RefCreateNode::make(args[0]);
-});
+TVM_REGISTER_API("relay._make.RefCreate")
+.set_body_typed(RefCreateNode::make);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<RefCreateNode>([](const RefCreateNode* node, tvm::IRPrinter* p) {
}
TVM_REGISTER_API("relay._make.RefRead")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- *ret = RefReadNode::make(args[0]);
-});
+.set_body_typed(RefReadNode::make);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<RefReadNode>([](const RefReadNode* node, tvm::IRPrinter* p) {
}
TVM_REGISTER_API("relay._make.RefWrite")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- *ret = RefWriteNode::make(args[0], args[1]);
-});
+.set_body_typed(RefWriteNode::make);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<RefWriteNode>([](const RefWriteNode* node, tvm::IRPrinter* p) {
});
TVM_REGISTER_API("relay._expr.TempExprRealize")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- TempExpr temp = args[0];
- *ret = temp->Realize();
+.set_body_typed<Expr(TempExpr)>([](TempExpr temp) {
+ return temp->Realize();
});
} // namespace relay
}
TVM_REGISTER_API("relay._ir_pass.post_order_visit")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
- PackedFunc f = args[1];
- PostOrderVisit(args[0], [f](const Expr& n) {
+.set_body_typed<void(Expr, PackedFunc)>([](Expr expr, PackedFunc f) {
+ PostOrderVisit(expr, [f](const Expr& n) {
f(n);
});
});
}
TVM_REGISTER_API("relay._ir_pass._expr_hash")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- *ret = static_cast<int64_t>(RelayHashHandler().Hash(args[0]));
- });
+.set_body_typed<int64_t(NodeRef)>([](NodeRef ref) {
+ return static_cast<int64_t>(RelayHashHandler().Hash(ref));
+});
TVM_REGISTER_API("relay._ir_pass._type_hash")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- *ret = static_cast<int64_t>(RelayHashHandler().TypeHash(args[0]));
- });
+.set_body_typed<int64_t(Type)>([](Type type) {
+ return static_cast<int64_t>(RelayHashHandler().TypeHash(type));
+});
} // namespace relay
} // namespace tvm
TVM_REGISTER_NODE_TYPE(ModuleNode);
TVM_REGISTER_API("relay._make.Module")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
- *ret = ModuleNode::make(args[0], args[1]);
- });
+.set_body_typed(ModuleNode::make);
TVM_REGISTER_API("relay._make.Module_Add")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
- Module mod = args[0];
- mod->Add(args[1], args[2], args[3]);
- });
+.set_body_method<Module>(&ModuleNode::Add);
TVM_REGISTER_API("relay._module.Module_AddDef")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
- Module mod = args[0];
- mod->AddDef(args[1], args[2]);
- });
+.set_body_method<Module>(&ModuleNode::AddDef);
TVM_REGISTER_API("relay._module.Module_GetGlobalVar")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
- Module mod = args[0];
- *ret = mod->GetGlobalVar(args[1]);
- });
+.set_body_method<Module>(&ModuleNode::GetGlobalVar);
TVM_REGISTER_API("relay._module.Module_GetGlobalTypeVar")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
- Module mod = args[0];
- *ret = mod->GetGlobalTypeVar(args[1]);
- });
+.set_body_method<Module>(&ModuleNode::GetGlobalTypeVar);
TVM_REGISTER_API("relay._module.Module_Lookup")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
- Module mod = args[0];
- GlobalVar var = args[1];
- *ret = mod->Lookup(var);
+.set_body_typed<Function(Module, GlobalVar)>([](Module mod, GlobalVar var) {
+ return mod->Lookup(var);
});
TVM_REGISTER_API("relay._module.Module_Lookup_str")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
- Module mod = args[0];
- std::string var_name = args[1];
- *ret = mod->Lookup(var_name);
+.set_body_typed<Function(Module, std::string)>([](Module mod, std::string var) {
+ return mod->Lookup(var);
});
TVM_REGISTER_API("relay._module.Module_LookupDef")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
- Module mod = args[0];
- GlobalTypeVar var = args[1];
- *ret = mod->LookupDef(var);
+.set_body_typed<TypeData(Module, GlobalTypeVar)>([](Module mod, GlobalTypeVar var) {
+ return mod->LookupDef(var);
});
TVM_REGISTER_API("relay._module.Module_LookupDef_str")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
- Module mod = args[0];
- std::string var_name = args[1];
- *ret = mod->LookupDef(var_name);
+.set_body_typed<TypeData(Module, std::string)>([](Module mod, std::string var) {
+ return mod->LookupDef(var);
});
TVM_REGISTER_API("relay._module.Module_Update")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
- Module mod = args[0];
- mod->Update(args[1]);
+.set_body_typed<void(Module, Module)>([](Module mod, Module from) {
+ mod->Update(from);
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
TVM_REGISTER_NODE_TYPE(TensorTypeNode);
TVM_REGISTER_API("relay._make.TensorType")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- Array<IndexExpr> shape = args[0];
- *ret = TensorTypeNode::make(shape, args[1]);
-});
+.set_body_typed(TensorTypeNode::make);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<TensorTypeNode>([](const TensorTypeNode* node,
TVM_REGISTER_NODE_TYPE(TypeVarNode);
TVM_REGISTER_API("relay._make.TypeVar")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- int kind = args[1];
- *ret =
- TypeVarNode::make(args[0], static_cast<Kind>(kind));
+.set_body_typed<TypeVar(std::string, int)>([](std::string name, int kind) {
+ return TypeVarNode::make(name, static_cast<Kind>(kind));
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
TVM_REGISTER_NODE_TYPE(GlobalTypeVarNode);
TVM_REGISTER_API("relay._make.GlobalTypeVar")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- int kind = args[1];
- *ret = GlobalTypeVarNode::make(args[0], static_cast<Kind>(kind));
-});
+.set_body_typed<GlobalTypeVar(std::string, int)>([](std::string name, int kind) {
+ return GlobalTypeVarNode::make(name, static_cast<Kind>(kind));
+ });
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<GlobalTypeVarNode>([](const GlobalTypeVarNode *node,
TVM_REGISTER_NODE_TYPE(TypeCallNode);
TVM_REGISTER_API("relay._make.TypeCall")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- *ret = TypeCallNode::make(args[0], args[1]);
-});
+.set_body_typed(TypeCallNode::make);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<TypeCallNode>([](const TypeCallNode* node,
TVM_REGISTER_NODE_TYPE(IncompleteTypeNode);
TVM_REGISTER_API("relay._make.IncompleteType")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- int kind = args[0];
- *ret = IncompleteTypeNode::make(static_cast<Kind>(kind));
+.set_body_typed<IncompleteType(int)>([](int kind) {
+ return IncompleteTypeNode::make(static_cast<Kind>(kind));
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
TVM_REGISTER_NODE_TYPE(FuncTypeNode);
TVM_REGISTER_API("relay._make.FuncType")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- *ret = FuncTypeNode::make(args[0], args[1], args[2], args[3]);
-});
+.set_body_typed(FuncTypeNode::make);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<FuncTypeNode>([](const FuncTypeNode* node,
TVM_REGISTER_NODE_TYPE(TypeRelationNode);
TVM_REGISTER_API("relay._make.TypeRelation")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- *ret = TypeRelationNode::make(args[0], args[1], args[2], args[3]);
-});
+.set_body_typed(TypeRelationNode::make);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<TypeRelationNode>([](const TypeRelationNode* node, tvm::IRPrinter* p) {
TVM_REGISTER_NODE_TYPE(TupleTypeNode);
TVM_REGISTER_API("relay._make.TupleType")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- *ret = TupleTypeNode::make(args[0]);
-});
+.set_body_typed(TupleTypeNode::make);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<TupleTypeNode>([](const TupleTypeNode* node,
}
TVM_REGISTER_API("relay._make.RefType")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- *ret = RefTypeNode::make(args[0]);
-});
+.set_body_typed(RefTypeNode::make);
TVM_REGISTER_NODE_TYPE(RefTypeNode);
}
TVM_REGISTER_API("relay.op._make.debug")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
- runtime::detail::unpack_call<Expr, 2>(MakeDebug, args, rv);
- });
+.set_body_typed(MakeDebug);
} // namespace relay
} // namespace tvm
TVM_REGISTER_API("relay.op.image._make.resize")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
- runtime::detail::unpack_call<Expr, 5>(MakeResize, args, rv);
- });
+.set_body_typed(MakeResize);
RELAY_REGISTER_OP("image.resize")
TVM_REGISTER_API("relay.op.nn._make.conv2d")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
- runtime::detail::unpack_call<Expr, 12>(MakeConv2D, args, rv);
- });
+.set_body_typed(MakeConv2D);
RELAY_REGISTER_OP("nn.conv2d")
TVM_REGISTER_API("relay.op.nn._make.conv2d_transpose")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
- runtime::detail::unpack_call<Expr, 12>(MakeConv2DTranspose, args, rv);
- });
+.set_body_typed(MakeConv2DTranspose);
RELAY_REGISTER_OP("nn.conv2d_transpose")
.describe(R"code(Transposed 2D convolution layer (sometimes called Deconvolution).
TVM_REGISTER_API("relay.op.nn._make.contrib_conv2d_winograd_without_weight_transform")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
- runtime::detail::unpack_call<Expr, 13>(MakeConv2DWinograd, args, rv);
- });
+.set_body_typed(MakeConv2DWinograd);
RELAY_REGISTER_OP("nn.contrib_conv2d_winograd_without_weight_transform")
TVM_REGISTER_API("relay.op.nn._make.contrib_conv2d_winograd_weight_transform")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
- runtime::detail::unpack_call<Expr, 2>(MakeConv2DWinogradWeightTransform, args, rv);
-});
+.set_body_typed(MakeConv2DWinogradWeightTransform);
RELAY_REGISTER_OP("nn.contrib_conv2d_winograd_weight_transform")
}
TVM_REGISTER_API("relay.op.nn._make.contrib_conv2d_winograd_nnpack_without_weight_transform")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
- runtime::detail::unpack_call<Expr, 12>(MakeConv2DWinogradNNPACK, args, rv);
-});
+.set_body_typed(MakeConv2DWinogradNNPACK);
RELAY_REGISTER_OP("nn.contrib_conv2d_winograd_nnpack_without_weight_transform")
.describe(R"code(Compute conv2d with winograd nnpack. Only supports NCHW layout.
}
TVM_REGISTER_API("relay.op.nn._make.contrib_conv2d_winograd_nnpack_weight_transform")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
- runtime::detail::unpack_call<Expr, 3>(MakeConv2DWinogradNNPACKWeightTransform, args, rv);
-});
+.set_body_typed(MakeConv2DWinogradNNPACKWeightTransform);
RELAY_REGISTER_OP("nn.contrib_conv2d_winograd_nnpack_weight_transform")
.describe(R"code(Weight transformation of winograd fast convolution algorithm with NNPACK.
}
TVM_REGISTER_API("relay.op.nn._make.contrib_conv2d_NCHWc")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
- runtime::detail::unpack_call<Expr, 12>(MakeConv2DNCHWc, args, rv);
- });
+.set_body_typed(MakeConv2DNCHWc);
RELAY_REGISTER_OP("nn.contrib_conv2d_NCHWc")
}
TVM_REGISTER_API("relay.op.nn._make.contrib_depthwise_conv2d_NCHWc")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
- runtime::detail::unpack_call<Expr, 12>(MakeDepthwiseConv2DNCHWc, args, rv);
- });
+.set_body_typed(MakeDepthwiseConv2DNCHWc);
RELAY_REGISTER_OP("nn.contrib_depthwise_conv2d_NCHWc")
}
TVM_REGISTER_API("relay.op.nn._make.deformable_conv2d")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
- runtime::detail::unpack_call<Expr, 14>(MakeDeformableConv2D, args, rv);
- });
+.set_body_typed(MakeDeformableConv2D);
} // namespace relay
TVM_REGISTER_API("relay.op.nn._make.bias_add")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
- runtime::detail::unpack_call<Expr, 3>(MakeBiasAdd, args, rv);
- });
+.set_body_typed(MakeBiasAdd);
RELAY_REGISTER_OP("nn.bias_add")
TVM_REGISTER_API("relay.op.nn._make.dense")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
- runtime::detail::unpack_call<Expr, 3>(MakeDense, args, rv);
- });
+.set_body_typed(MakeDense);
RELAY_REGISTER_OP("nn.dense")
TVM_REGISTER_API("relay.op.nn._make.leaky_relu")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
- runtime::detail::unpack_call<Expr, 2>(MakeLeakyRelu, args, rv);
- });
+.set_body_typed(MakeLeakyRelu);
RELAY_REGISTER_OP("nn.leaky_relu")
TVM_REGISTER_API("relay.op.nn._make.prelu")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
- runtime::detail::unpack_call<Expr, 3>(MakePRelu, args, rv);
- });
+.set_body_typed(MakePRelu);
RELAY_REGISTER_OP("nn.prelu")
TVM_REGISTER_NODE_TYPE(SoftmaxAttrs);
TVM_REGISTER_API("relay.op.nn._make.softmax")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
- auto make_func = [](Expr data, int axis) {
- auto attrs = make_node<SoftmaxAttrs>();
- attrs->axis = axis;
- static const Op& op = Op::Get("nn.softmax");
- return CallNode::make(op, {data}, Attrs(attrs), {});
- };
-
- runtime::detail::unpack_call<Expr, 2>(make_func, args, rv);
+.set_body_typed<Call(Expr, int)>([](Expr data, int axis) {
+ auto attrs = make_node<SoftmaxAttrs>();
+ attrs->axis = axis;
+ static const Op& op = Op::Get("nn.softmax");
+ return CallNode::make(op, {data}, Attrs(attrs), {});
});
+
RELAY_REGISTER_OP("nn.softmax")
.describe(R"code(Softmax layer.
// relay.nn.log_softmax
TVM_REGISTER_API("relay.op.nn._make.log_softmax")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
- auto make_func = [](Expr data, int axis) {
- auto attrs = make_node<SoftmaxAttrs>();
- attrs->axis = axis;
- static const Op& op = Op::Get("nn.log_softmax");
- return CallNode::make(op, {data}, Attrs(attrs), {});
- };
-
- runtime::detail::unpack_call<Expr, 2>(make_func, args, rv);
+.set_body_typed<Call(Expr, int)>([](Expr data, int axis) {
+ auto attrs = make_node<SoftmaxAttrs>();
+ attrs->axis = axis;
+ static const Op& op = Op::Get("nn.log_softmax");
+ return CallNode::make(op, {data}, Attrs(attrs), {});
});
RELAY_REGISTER_OP("nn.log_softmax")
TVM_REGISTER_API("relay.op.nn._make.batch_flatten")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
- runtime::detail::unpack_call<Expr, 1>(MakeBatchFlatten, args, rv);
- });
+.set_body_typed(MakeBatchFlatten);
RELAY_REGISTER_OP("nn.batch_flatten")
// relu
TVM_REGISTER_API("relay.op.nn._make.relu")
-.set_body_typed<Expr(Expr)>([](Expr data) {
+.set_body_typed<Call(Expr)>([](Expr data) {
static const Op& op = Op::Get("nn.relu");
return CallNode::make(op, {data}, Attrs(), {});
});
}
TVM_REGISTER_API("relay.op.nn._make.lrn")
- .set_body([](const TVMArgs& args, TVMRetValue* rv) {
- runtime::detail::unpack_call<Expr, 6>(MakeLRN, args, rv);
- });
+.set_body_typed(MakeLRN);
RELAY_REGISTER_OP("nn.lrn")
.describe(R"code(LRN layer.
}
TVM_REGISTER_API("relay.op.nn._make.l2_normalize")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
- runtime::detail::unpack_call<Expr, 3>(MakeL2Normalize, args, rv);
- });
+.set_body_typed(MakeL2Normalize);
RELAY_REGISTER_OP("nn.l2_normalize")
.describe(R"code(L2 Normalization layer.
}
TVM_REGISTER_API("relay.op.nn._make.dropout")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
- runtime::detail::unpack_call<Expr, 2>(MakeDropout, args, rv);
- });
+.set_body_typed(MakeDropout);
RELAY_REGISTER_OP("nn.dropout")
.describe(R"code(Applies the dropout operation to the input array.
}
TVM_REGISTER_API("relay.op.nn._make.batch_norm")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
- runtime::detail::unpack_call<Expr, 9>(MakeBatchNorm, args, rv);
- });
+.set_body_typed(MakeBatchNorm);
RELAY_REGISTER_OP("nn.batch_norm")
.describe(R"code(Batch normalization layer (Ioffe and Szegedy, 2014).
TVM_REGISTER_API("relay.op.nn._make.batch_matmul")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
- runtime::detail::unpack_call<Expr, 2>(MakeBatchMatmul, args, rv);
- });
+.set_body_typed(MakeBatchMatmul);
RELAY_REGISTER_OP("nn.batch_matmul")
}
TVM_REGISTER_API("relay.op.nn._make.pad")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
- runtime::detail::unpack_call<Expr, 3>(MakePad, args, rv);
- });
+.set_body_typed(MakePad);
RELAY_REGISTER_OP("nn.pad")
.describe(R"code(Pad for n-D tensor.
}
TVM_REGISTER_API("relay.op.nn._make.max_pool2d")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
- runtime::detail::unpack_call<Expr, 6>(MakeMaxPool2D, args, rv);
- });
+.set_body_typed(MakeMaxPool2D);
RELAY_REGISTER_OP("nn.max_pool2d")
TVM_REGISTER_API("relay.op.nn._make.avg_pool2d")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
- runtime::detail::unpack_call<Expr, 7>(MakeAvgPool2D, args, rv);
- });
+.set_body_typed(MakeAvgPool2D);
RELAY_REGISTER_OP("nn.avg_pool2d")
TVM_REGISTER_API("relay.op.nn._make.global_avg_pool2d")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
- runtime::detail::unpack_call<Expr, 2>(MakeGlobalAvgPool2D, args, rv);
- });
+.set_body_typed(MakeGlobalAvgPool2D);
// GlobalAvgPool
RELAY_REGISTER_OP("nn.global_avg_pool2d")
}
TVM_REGISTER_API("relay.op.nn._make.global_max_pool2d")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
- runtime::detail::unpack_call<Expr, 2>(MakeGlobalMaxPool2D, args, rv);
- });
+.set_body_typed(MakeGlobalMaxPool2D);
RELAY_REGISTER_OP("nn.global_max_pool2d")
TVM_REGISTER_API("relay.op.nn._make.upsampling")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
- runtime::detail::unpack_call<Expr, 4>(MakeUpSampling, args, rv);
- });
+.set_body_typed(MakeUpSampling);
RELAY_REGISTER_OP("nn.upsampling")
#define RELAY_REGISTER_REDUCE_OP(OpName) \
TVM_REGISTER_API("relay.op._make." OpName) \
- .set_body([](const TVMArgs& args, TVMRetValue* rv) { \
- auto make_func = [](Expr data, \
+ .set_body_typed<Call(Expr, Array<Integer>, bool, bool)>([]( \
+ Expr data, \
Array<Integer> axis, \
bool keepdims, \
bool exclude) { \
attrs->exclude = exclude; \
static const Op& op = Op::Get(OpName); \
return CallNode::make(op, {data}, Attrs(attrs), {}); \
- }; \
- runtime::detail::unpack_call<Expr, 4>(make_func, args, rv); \
}); \
RELAY_REGISTER_OP(OpName) \
.set_num_inputs(1) \
}
TVM_REGISTER_API("relay._make.cast")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
- runtime::detail::unpack_call<Expr, 2>(MakeCast, args, rv);
-});
+.set_body_typed(MakeCast);
RELAY_REGISTER_OP("cast")
.describe(R"code(Cast the data into a new data type.
}
TVM_REGISTER_API("relay.op._make.expand_dims")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
- runtime::detail::unpack_call<Expr, 3>(MakeExpandDims, args, rv);
-});
+.set_body_typed(MakeExpandDims);
RELAY_REGISTER_OP("expand_dims")
.describe(R"code(Insert `num_newaxis` axises at the position given by `axis`
}
TVM_REGISTER_API("relay.op._make.concatenate")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
- runtime::detail::unpack_call<Expr, 2>(MakeConcatenate, args, rv);
-});
+.set_body_typed(MakeConcatenate);
RELAY_REGISTER_OP("concatenate")
.describe(R"code(Concatenate the input tensors along the given axis.
}
TVM_REGISTER_API("relay.op._make.stack")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
- runtime::detail::unpack_call<Expr, 2>(MakeStack, args, rv);
-});
+.set_body_typed(MakeStack);
RELAY_REGISTER_OP("stack")
.describe(R"code(Stack the input tensors along the given axis.
}
TVM_REGISTER_API("relay.op._make.transpose")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
- runtime::detail::unpack_call<Expr, 2>(MakeTranspose, args, rv);
-});
+.set_body_typed(MakeTranspose);
RELAY_REGISTER_OP("transpose")
.describe(R"code(Permutes the dimensions of an array.
}
TVM_REGISTER_API("relay.op._make.reshape")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
- runtime::detail::unpack_call<Expr, 2>(MakeReshape, args, rv);
-});
+.set_body_typed(MakeReshape);
RELAY_REGISTER_OP("reshape")
.describe(R"code(Reshapes the input array.
TVM_REGISTER_API("relay.op._make.reshape_like")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
- runtime::detail::unpack_call<Expr, 2>(MakeReshapeLike, args, rv);
-});
+.set_body_typed(MakeReshapeLike);
RELAY_REGISTER_OP("reshape_like")
}
TVM_REGISTER_API("relay.op._make.take")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
- runtime::detail::unpack_call<Expr, 4>(MakeTake, args, rv);
-});
+.set_body_typed(MakeTake);
RELAY_REGISTER_OP("take")
.describe(R"code(Take elements from an array along an axis.
}
TVM_REGISTER_API("relay.op._make.full")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
- runtime::detail::unpack_call<Expr, 3>(MakeFull, args, rv);
-});
+.set_body_typed(MakeFull);
RELAY_REGISTER_OP("full")
.describe(R"code(Fill array with scalar value.
}
TVM_REGISTER_API("relay.op._make.zeros")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
- runtime::detail::unpack_call<Expr, 2>(MakeZeros, args, rv);
- });
+.set_body_typed(MakeZeros);
RELAY_REGISTER_OP("zeros")
.describe(R"code(Fill array with zeros.
}
TVM_REGISTER_API("relay.op._make.ones")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
- runtime::detail::unpack_call<Expr, 2>(MakeOnes, args, rv);
- });
+.set_body_typed(MakeOnes);
RELAY_REGISTER_OP("ones")
.describe(R"code(Fill array with ones.
}
TVM_REGISTER_API("relay.op._make.full_like")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
- runtime::detail::unpack_call<Expr, 2>(MakeFullLike, args, rv);
- });
+.set_body_typed(MakeFullLike);
RELAY_REGISTER_OP("full_like")
.describe(R"code(Return an scalar value array with the same shape
}
TVM_REGISTER_API("relay.op._make.arange")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
- runtime::detail::unpack_call<Expr, 4>(MakeArange, args, rv);
-});
+.set_body_typed(MakeArange);
RELAY_REGISTER_OP("arange")
.describe(R"code(Returns evenly spaced values within a given interval.
}
TVM_REGISTER_API("relay.op._make.repeat")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
- runtime::detail::unpack_call<Expr, 3>(MakeRepeat, args, rv);
-});
+.set_body_typed(MakeRepeat);
RELAY_REGISTER_OP("repeat")
.describe(R"code(Repeat elements of an array `repeats` times along axis `axis`
}
TVM_REGISTER_API("relay.op._make.tile")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
- runtime::detail::unpack_call<Expr, 2>(MakeTile, args, rv);
-});
+.set_body_typed(MakeTile);
RELAY_REGISTER_OP("tile")
.describe(R"code(Repeat the whole array multiple times.
}
TVM_REGISTER_API("relay.op._make.reverse")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
- runtime::detail::unpack_call<Expr, 2>(MakeReverse, args, rv);
-});
+.set_body_typed(MakeReverse);
RELAY_REGISTER_OP("reverse")
.describe(R"code(Reverses the order of elements along given `axis` while preserving array shape.
}
TVM_REGISTER_API("relay.op._make.where")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
- runtime::detail::unpack_call<Expr, 3>(MakeWhere, args, rv);
-});
+.set_body_typed(MakeWhere);
RELAY_REGISTER_OP("where")
.describe(R"code(
}
TVM_REGISTER_API("relay.op._make.squeeze")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
- runtime::detail::unpack_call<Expr, 2>(MakeSqueeze, args, rv);
- });
+.set_body_typed(MakeSqueeze);
bool SqueezeRel(const Array<Type>& types,
}
TVM_REGISTER_API("relay.op._make.collapse_sum_like")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
- runtime::detail::unpack_call<Expr, 2>(MakeCollapseSumLike, args, rv);
- });
+.set_body_typed(MakeCollapseSumLike);
RELAY_REGISTER_OP("collapse_sum_like")
.describe(R"code(Collapse the first input to match the shape of the second input.
}
TVM_REGISTER_API("relay.op._make.broadcast_to")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
- runtime::detail::unpack_call<Expr, 2>(MakeBroadCastTo, args, rv);
- });
+.set_body_typed(MakeBroadCastTo);
RELAY_REGISTER_OP("broadcast_to")
.describe(R"code(Broadcast the first input to match the shape argument.
}
TVM_REGISTER_API("relay.op._make.broadcast_to_like")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
- runtime::detail::unpack_call<Expr, 2>(MakeBroadCastToLike, args, rv);
- });
+.set_body_typed(MakeBroadCastToLike);
RELAY_REGISTER_OP("broadcast_to_like")
.describe(R"code(Broadcast the first input to match the shape of the second input.
TVM_REGISTER_API("relay.op._make.strided_slice")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
- runtime::detail::unpack_call<Expr, 4>(MakeStridedSlice, args, rv);
- });
+.set_body_typed(MakeStridedSlice);
RELAY_REGISTER_OP("strided_slice")
TVM_REGISTER_API("relay.op._make.slice_like")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
- runtime::detail::unpack_call<Expr, 3>(MakeSliceLike, args, rv);
-});
+.set_body_typed(MakeSliceLike);
RELAY_REGISTER_OP("slice_like")
}
TVM_REGISTER_API("relay.op._make.layout_transform")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
- runtime::detail::unpack_call<Expr, 3>(MakeLayoutTransform, args, rv);
-});
+.set_body_typed(MakeLayoutTransform);
RELAY_REGISTER_OP("layout_transform")
.describe(R"code(Transform the input data layout.
}
TVM_REGISTER_API("relay.op._make._contrib_reverse_reshape")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
- runtime::detail::unpack_call<Expr, 2>(MakeReverseReshape, args, rv);
-});
+.set_body_typed(MakeReverseReshape);
RELAY_REGISTER_OP("_contrib_reverse_reshape")
.describe(R"code(Reshapes the input array where the special values are inferred from
}
TVM_REGISTER_API("relay.op._make.gather_nd")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
- runtime::detail::unpack_call<Expr, 2>(MakeGatherND, args, rv);
-});
+.set_body_typed(MakeGatherND);
RELAY_REGISTER_OP("gather_nd")
.describe(R"code(Gather elements or slices from data and store to
TVM_REGISTER_API("relay.op.vision._make.multibox_prior")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
- runtime::detail::unpack_call<Expr, 6>(MakeMultiBoxPrior, args, rv);
-});
+.set_body_typed(MakeMultiBoxPrior);
RELAY_REGISTER_OP("vision.multibox_prior")
}
TVM_REGISTER_API("relay.op.vision._make.multibox_transform_loc")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
- runtime::detail::unpack_call<Expr, 6>(MakeMultiBoxTransformLoc, args, rv);
-});
+.set_body_typed(MakeMultiBoxTransformLoc);
RELAY_REGISTER_OP("vision.multibox_transform_loc")
.describe(R"doc("Location transformation for multibox detection."
TVM_REGISTER_API("relay.op.vision._make.get_valid_counts")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
- runtime::detail::unpack_call<Expr, 2>(MakeGetValidCounts, args, rv);
-});
+.set_body_typed(MakeGetValidCounts);
RELAY_REGISTER_OP("vision.get_valid_counts")
TVM_REGISTER_API("relay.op.vision._make.non_max_suppression")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
- runtime::detail::unpack_call<Expr, 9>(MakeNMS, args, rv);
-});
+.set_body_typed(MakeNMS);
RELAY_REGISTER_OP("vision.non_max_suppression")
}
TVM_REGISTER_API("relay.op.vision._make.roi_align")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
- runtime::detail::unpack_call<Expr, 6>(MakeROIAlign, args, rv);
- });
+.set_body_typed(MakeROIAlign);
RELAY_REGISTER_OP("vision.roi_align")
.describe(R"doc(ROI Align operator.
}
TVM_REGISTER_API("relay.op.vision._make.roi_pool")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
- runtime::detail::unpack_call<Expr, 5>(MakeROIPool, args, rv);
- });
+.set_body_typed(MakeROIPool);
RELAY_REGISTER_OP("vision.roi_pool")
.describe(R"doc(ROI Pool operator.
}
TVM_REGISTER_API("relay.op.vision._make.proposal")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
- runtime::detail::unpack_call<Expr, 11>(MakeProposal, args, rv);
- });
+.set_body_typed(MakeProposal);
RELAY_REGISTER_OP("vision.proposal")
.describe(R"code(Generate region proposals via RPN.
TVM_REGISTER_API("relay.op.vision._make.yolo_reorg")
-.set_body([](const TVMArgs& args, TVMRetValue* rv) {
- runtime::detail::unpack_call<Expr, 2>(MakeYoloReorg, args, rv);
-});
+.set_body_typed(MakeYoloReorg);
RELAY_REGISTER_OP("vision.yolo_reorg")
}
TVM_REGISTER_API("relay._ir_pass.canonicalize_ops")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
-*ret = CanonicalizeOps(args[0]);
-});
+.set_body_typed(CanonicalizeOps);
} // namespace relay
} // namespace tvm
}
TVM_REGISTER_API("relay._ir_pass.CombineParallelConv2D")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- *ret = CombineParallelConv2D(args[0], args[1]);
-});
+.set_body_typed(CombineParallelConv2D);
} // namespace relay
} // namespace tvm
}
TVM_REGISTER_API("relay._ir_pass.dead_code_elimination")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- *ret = DeadCodeElimination(args[0]);
- });
+.set_body_typed(DeadCodeElimination);
} // namespace relay
} // namespace tvm
}
TVM_REGISTER_API("relay._ir_pass.CollectDeviceInfo")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
- *ret = CollectDeviceInfo(args[0]);
-});
+.set_body_typed(CollectDeviceInfo);
TVM_REGISTER_API("relay._ir_pass.RewriteDeviceAnnotation")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
- *ret = RewriteAnnotatedOps(args[0], args[1]);
-});
+.set_body_typed(RewriteAnnotatedOps);
TVM_REGISTER_API("relay._ir_pass.CollectDeviceAnnotationOps")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
- *ret = CollectDeviceAnnotationOps(args[0]);
-});
+.set_body_typed(CollectDeviceAnnotationOps);
} // namespace relay
} // namespace tvm
}
TVM_REGISTER_API("relay._ir_pass.FoldConstant")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
- *ret = FoldConstant(args[0]);
-});
+.set_body_typed(FoldConstant);
} // namespace relay
} // namespace tvm
}
TVM_REGISTER_API("relay._ir_pass.FuseOps")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
- *ret = FuseOps(args[0], args[1]);
-});
+.set_body_typed(FuseOps);
} // namespace relay
} // namespace tvm
}
TVM_REGISTER_API("relay._ir_pass.first_order_gradient")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- CHECK_EQ(args.size(), 2);
- *ret = FirstOrderGradient(args[0], args[1]);
-});
+.set_body_typed(FirstOrderGradient);
struct ReverseADType : TypeMutator {
Type VisitType_(const TensorTypeNode* ttn) final {
Var bp;
const OpMap<FPrimalGradient> rev_map = Op::GetAttr<FPrimalGradient>("FPrimalGradient");
- ReverseAD(const Var& bp) : bp(bp) { }
+ ReverseAD(const Var& bp) : bp(bp) { } /// NOLINT(*)
Expr VisitExpr_(const OpNode* op) final {
LOG(FATAL) << "op should only be inside call";
}
TVM_REGISTER_API("relay._ir_pass.gradient")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- CHECK_EQ(args.size(), 2);
- *ret = Gradient(args[0], args[1]);
-});
+.set_body_typed(Gradient);
} // namespace relay
} // namespace tvm
}
TVM_REGISTER_API("relay._ir_pass.GetTotalMacNumber")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
- *ret = GetTotalMacNumber(args[0]);
-});
+.set_body_typed(GetTotalMacNumber);
} // namespace mac_count
} // namespace relay
TVM_REGISTER_NODE_TYPE(PassInfoNode);
TVM_REGISTER_API("relay._ir_pass.PassInfo")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- int opt_level = args[0];
- std::string name = args[1];
- tvm::Array<tvm::Expr> required = args[2];
- *ret = PassInfoNode::make(opt_level, name, required);
-});
+.set_body_typed(PassInfoNode::make);
TVM_REGISTER_API("relay._ir_pass.Info")
.set_body([](TVMArgs args, TVMRetValue* ret) {
TVM_REGISTER_NODE_TYPE(ModulePassNode);
TVM_REGISTER_API("relay._ir_pass.CreateModulePass")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- PackedFunc pass_func = args[0];
- int opt_level = args[1];
- std::string name = args[2];
- tvm::Array<tvm::Expr> required = args[3];
- *ret = CreateModulePass(pass_func, opt_level, name, required);
-});
+.set_body_typed(CreateModulePass);
TVM_REGISTER_API("relay._ir_pass.RunPass")
.set_body([](TVMArgs args, TVMRetValue* ret) {
TVM_REGISTER_NODE_TYPE(FunctionPassNode);
TVM_REGISTER_API("relay._ir_pass.CreateFunctionPass")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- PackedFunc pass_func = args[0];
- int opt_level = args[1];
- std::string name = args[2];
- tvm::Array<tvm::Expr> required = args[3];
- *ret = CreateFunctionPass(pass_func, opt_level, name, required);
-});
+.set_body_typed(CreateFunctionPass);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<FunctionPassNode>([](const FunctionPassNode* node,
TVM_REGISTER_NODE_TYPE(PassContextNode);
TVM_REGISTER_API("relay._ir_pass.PassContext")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- *ret = PassContextNode::make();
-});
+.set_body_typed(PassContextNode::make);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<PassContextNode>([](const PassContextNode* node,
});
TVM_REGISTER_API("relay._quantize._GetCurrentQConfig")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- *ret = QConfig::Current();
- });
+.set_body_typed(QConfig::Current);
TVM_REGISTER_API("relay._quantize._EnterQConfigScope")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- QConfig target = args[0];
- QConfig::EnterQConfigScope(target);
- });
+.set_body_typed(QConfig::EnterQConfigScope);
TVM_REGISTER_API("relay._quantize._ExitQConfigScope")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- QConfig::ExitQConfigScope();
- });
+.set_body_typed(QConfig::ExitQConfigScope);
} // namespace quantize
} // namespace relay
}
TVM_REGISTER_API("relay._ir_pass.simplify_inference")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- *ret = SimplifyInference(args[0]);
- });
+.set_body_typed(SimplifyInference);
} // namespace relay
} // namespace tvm
}
TVM_REGISTER_API("relay._ir_pass.to_a_normal_form")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- *ret = ToANormalForm(args[0], args[1]);
- });
+.set_body_typed(static_cast<Expr (*)(const Expr&, const Module&)>(ToANormalForm));
} // namespace relay
} // namespace tvm
}
TVM_REGISTER_API("relay._ir_pass.to_graph_normal_form")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- *ret = ToGraphNormalForm(args[0]);
-});
+.set_body_typed(ToGraphNormalForm);
} // namespace relay
} // namespace tvm
}
TVM_REGISTER_API("relay._ir_pass.infer_type")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- *ret = InferType(args[0], args[1]);
+.set_body_typed<Expr(const Expr&, const Module&)>([](const Expr& expr, const Module& mod_ref) {
+ return InferType(expr, mod_ref);
});
} // namespace relay
} // namespace tvm
}
TVM_REGISTER_API("relay._ir_pass.free_vars")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- *ret = FreeVars(args[0]);
- });
+.set_body_typed(FreeVars);
TVM_REGISTER_API("relay._ir_pass.bound_vars")
.set_body([](TVMArgs args, TVMRetValue* ret) {
});
TVM_REGISTER_API("relay._ir_pass.all_vars")
- .set_body([](TVMArgs args, TVMRetValue* ret) {
- *ret = AllVars(args[0]);
- });
+.set_body_typed(AllVars);
TVM_REGISTER_API("relay._ir_pass.free_type_vars")
.set_body([](TVMArgs args, TVMRetValue* ret) {
}
TVM_REGISTER_API("relay._ir_pass.well_formed")
- .set_body([](TVMArgs args, TVMRetValue *ret) {
- Expr e = args[0];
- *ret = WellFormed(e);
- });
+.set_body_typed(WellFormed);
} // namespace relay
} // namespace tvm
}
TVM_REGISTER_GLOBAL("module.loadfile_cubin")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
- *rv = CUDAModuleLoadFile(args[0], args[1]);
- });
+.set_body_typed(CUDAModuleLoadFile);
TVM_REGISTER_GLOBAL("module.loadfile_ptx")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
- *rv = CUDAModuleLoadFile(args[0], args[1]);
- });
+.set_body_typed(CUDAModuleLoadFile);
TVM_REGISTER_GLOBAL("module.loadbinary_cuda")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
- *rv = CUDAModuleLoadBinary(args[0]);
- });
+.set_body_typed(CUDAModuleLoadBinary);
} // namespace runtime
} // namespace tvm
}
TVM_REGISTER_GLOBAL("module.loadfile_metal")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
- *rv = MetalModuleLoadFile(args[0], args[1]);
- });
+.set_body_typed(MetalModuleLoadFile);
TVM_REGISTER_GLOBAL("module.loadbinary_metal")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
- *rv = MetalModuleLoadBinary(args[0]);
- });
+.set_body_typed(MetalModuleLoadBinary);
} // namespace runtime
} // namespace tvm
}
TVM_REGISTER_GLOBAL("module.loadfile_aocx")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
- *rv = AOCLModuleLoadFile(args[0], args[1]);
- });
+.set_body_typed(AOCLModuleLoadFile);
} // namespace runtime
} // namespace tvm
}
TVM_REGISTER_GLOBAL("module.loadfile_cl")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
- *rv = OpenCLModuleLoadFile(args[0], args[1]);
- });
+.set_body_typed(OpenCLModuleLoadFile);
TVM_REGISTER_GLOBAL("module.loadfile_clbin")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
- *rv = OpenCLModuleLoadFile(args[0], args[1]);
- });
+.set_body_typed(OpenCLModuleLoadFile);
TVM_REGISTER_GLOBAL("module.loadbinary_opencl")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
- *rv = OpenCLModuleLoadBinary(args[0]);
- });
+.set_body_typed(OpenCLModuleLoadBinary);
} // namespace runtime
} // namespace tvm
}
TVM_REGISTER_GLOBAL("module.loadfile_xclbin")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
- *rv = SDAccelModuleLoadFile(args[0], args[1]);
- });
+.set_body_typed(SDAccelModuleLoadFile);
TVM_REGISTER_GLOBAL("module.loadfile_awsxclbin")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
- *rv = SDAccelModuleLoadFile(args[0], args[1]);
- });
+.set_body_typed(SDAccelModuleLoadFile);
} // namespace runtime
} // namespace tvm
TVM_REGISTER_GLOBAL("module.loadbinary_hsaco")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
- *rv = ROCMModuleLoadBinary(args[0]);
- });
+.set_body_typed(ROCMModuleLoadBinary);
TVM_REGISTER_GLOBAL("module.loadbinary_hip")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
- *rv = ROCMModuleLoadBinary(args[0]);
- });
+.set_body_typed(ROCMModuleLoadBinary);
} // namespace runtime
} // namespace tvm
}
TVM_REGISTER_GLOBAL("rpc._CreateEventDrivenServer")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
- *rv = CreateEventDrivenServer(args[0], args[1], args[2]);
- });
+.set_body_typed(CreateEventDrivenServer);
} // namespace runtime
} // namespace tvm
}
TVM_REGISTER_GLOBAL("rpc._Connect")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
- *rv = RPCClientConnect(args[0], args[1], args[2]);
- });
+.set_body_typed(RPCClientConnect);
TVM_REGISTER_GLOBAL("rpc._ServerLoop")
.set_body([](TVMArgs args, TVMRetValue* rv) {
}
TVM_REGISTER_GLOBAL("module.loadfile_stackvm")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
- *rv = StackVMModuleNode::LoadFromFile(args[0], args[1]);
- });
+.set_body_typed(StackVMModuleNode::LoadFromFile);
} // namespace runtime
} // namespace tvm
}
TVM_REGISTER_GLOBAL("module.loadfile_vulkan")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
- *rv = VulkanModuleLoadFile(args[0], args[1]);
- });
+.set_body_typed(VulkanModuleLoadFile);
TVM_REGISTER_GLOBAL("module.loadbinary_vulkan")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
- *rv = VulkanModuleLoadBinary(args[0]);
- });
+.set_body_typed(VulkanModuleLoadBinary);
} // namespace runtime
} // namespace tvm
};
TVM_REGISTER_GLOBAL("tvm.rpc.server.workpath")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
+.set_body_typed<std::string(std::string)>([](std::string path) {
static RPCEnv env;
- *rv = env.GetPath(args[0]);
+ return env.GetPath(path);
});
TVM_REGISTER_GLOBAL("tvm.rpc.server.load_module")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
- std::string file_name = "/rpc/" + args[0].operator std::string();
- *rv = Module::LoadFromFile(file_name, "");
+.set_body_typed<Module(std::string)>([](std::string path) {
+ std::string file_name = "/rpc/" + path;
LOG(INFO) << "Load module from " << file_name << " ...";
+ return Module::LoadFromFile(file_name, "");
});
} // namespace contrib
} // namespace tvm