From 8152360466b57a5c848d3cf05efd373a1690b335 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Sat, 4 Jan 2020 15:38:56 -0800 Subject: [PATCH] [REFACTOR] TVM_REGISTER_API -> TVM_REGISTER_GLOBAL (#4621) TVM_REGSISTER_API is an alias of TVM_REGISTER_GLOBAL. In the spirit of simplify redirections, this PR removes the original TVM_REGISTER_API macro and directly use TVM_REGISTER_GLOBAL. This type of refactor will also simplify the IDE navigation tools such as FFI navigator to provide better code reading experiences. Move EnvFunc's definition to node. --- docs/dev/codebase_walkthrough.rst | 4 +- docs/dev/relay_add_op.rst | 4 +- docs/dev/relay_pass_infra.rst | 32 +++--- include/tvm/codegen.h | 1 - include/tvm/{api_registry.h => node/env_func.h} | 32 ++---- include/tvm/relay/base.h | 2 +- include/tvm/relay/type.h | 6 +- include/tvm/runtime/registry.h | 12 +- src/api/api_arith.cc | 32 +++--- src/api/api_base.cc | 16 +-- src/api/api_codegen.cc | 8 +- src/api/api_ir.cc | 48 ++++---- src/api/api_lang.cc | 144 ++++++++++++------------ src/api/api_pass.cc | 28 ++--- src/api/api_schedule.cc | 12 +- src/api/api_test.cc | 23 ++-- src/arithmetic/bound_deducer.cc | 4 +- src/arithmetic/domain_touched.cc | 4 +- src/arithmetic/int_set.cc | 6 +- src/autotvm/touch_extractor.cc | 6 +- src/autotvm/touch_extractor.h | 4 +- src/codegen/build_common.h | 5 +- src/codegen/build_module.cc | 35 +++--- src/codegen/codegen_aocl.cc | 4 +- src/codegen/codegen_c_host.cc | 2 +- src/codegen/codegen_metal.cc | 2 +- src/codegen/codegen_opencl.cc | 2 +- src/codegen/codegen_opengl.cc | 2 +- src/codegen/codegen_vhls.cc | 2 +- src/codegen/datatype/registry.cc | 4 +- src/codegen/intrin_rule.h | 4 +- src/codegen/llvm/codegen_amdgpu.cc | 2 +- src/codegen/llvm/codegen_nvptx.cc | 2 +- src/codegen/llvm/intrin_rule_llvm.h | 4 +- src/codegen/llvm/intrin_rule_nvptx.cc | 8 +- src/codegen/llvm/intrin_rule_rocm.cc | 8 +- src/codegen/llvm/llvm_module.cc | 10 +- src/codegen/opt/build_cuda_on.cc | 6 +- src/codegen/spirv/build_vulkan.cc | 6 +- src/codegen/stackvm/codegen_stackvm.cc | 6 +- src/contrib/hybrid/codegen_hybrid.cc | 3 + src/lang/attrs.cc | 6 +- src/{lang/api_registry.cc => node/env_func.cc} | 17 ++- src/op/tensorize.cc | 8 +- src/pass/hoist_if_then_else.cc | 4 +- src/pass/lower_intrin.cc | 6 +- src/pass/verify_gpu_code.cc | 4 +- src/relay/backend/compile_engine.cc | 2 +- src/relay/backend/contrib/codegen_c/codegen.cc | 2 +- src/relay/backend/contrib/dnnl/codegen.cc | 2 +- src/relay/backend/interpreter.cc | 14 +-- src/relay/backend/vm/inline_primitives.cc | 2 +- src/relay/backend/vm/lambda_lift.cc | 2 +- src/relay/backend/vm/removed_unused_funcs.cc | 2 +- src/relay/ir/adt.cc | 16 +-- src/relay/ir/alpha_equal.cc | 8 +- src/relay/ir/base.cc | 6 +- src/relay/ir/expr.cc | 32 +++--- src/relay/ir/expr_functor.cc | 4 +- src/relay/ir/hash.cc | 4 +- src/relay/ir/module.cc | 34 +++--- src/relay/ir/op.cc | 12 +- src/relay/ir/pretty_printer.cc | 2 +- src/relay/ir/type.cc | 14 +-- src/relay/op/algorithm/argsort.cc | 2 +- src/relay/op/algorithm/topk.cc | 2 +- src/relay/op/annotation/annotation.cc | 6 +- src/relay/op/debug.cc | 2 +- src/relay/op/device_copy.cc | 2 +- src/relay/op/image/resize.cc | 2 +- src/relay/op/memory/memory.cc | 8 +- src/relay/op/nn/bitserial.cc | 6 +- src/relay/op/nn/convolution.cc | 24 ++-- src/relay/op/nn/nn.cc | 40 +++---- src/relay/op/nn/pad.cc | 4 +- src/relay/op/nn/pooling.cc | 20 ++-- src/relay/op/nn/sparse.cc | 4 +- src/relay/op/nn/upsampling.cc | 4 +- src/relay/op/op_common.h | 6 +- src/relay/op/tensor/reduce.cc | 4 +- src/relay/op/tensor/transform.cc | 66 +++++------ src/relay/op/tensor/unary.cc | 6 +- src/relay/op/vision/multibox_op.cc | 4 +- src/relay/op/vision/nms.cc | 4 +- src/relay/op/vision/rcnn_op.cc | 6 +- src/relay/op/vision/yolo.cc | 2 +- src/relay/pass/alter_op_layout.cc | 2 +- src/relay/pass/canonicalize_cast.cc | 2 +- src/relay/pass/canonicalize_ops.cc | 6 +- src/relay/pass/combine_parallel_conv2d.cc | 2 +- src/relay/pass/combine_parallel_dense.cc | 10 +- src/relay/pass/combine_parallel_op_batch.cc | 12 +- src/relay/pass/convert_layout.cc | 2 +- src/relay/pass/de_duplicate.cc | 2 +- src/relay/pass/dead_code.cc | 2 +- src/relay/pass/device_annotation.cc | 8 +- src/relay/pass/eliminate_common_subexpr.cc | 2 +- src/relay/pass/eta_expand.cc | 2 +- src/relay/pass/feature.cc | 2 +- src/relay/pass/fold_constant.cc | 4 +- src/relay/pass/fold_scale_axis.cc | 6 +- src/relay/pass/fuse_ops.cc | 2 +- src/relay/pass/gradient.cc | 4 +- src/relay/pass/kind_check.cc | 2 +- src/relay/pass/legalize.cc | 2 +- src/relay/pass/mac_count.cc | 8 +- src/relay/pass/match_exhaustion.cc | 2 +- src/relay/pass/partial_eval.cc | 2 +- src/relay/pass/pass_manager.cc | 20 ++-- src/relay/pass/print_ir.cc | 2 +- src/relay/pass/quantize/annotate.cc | 4 +- src/relay/pass/quantize/calibrate.cc | 4 +- src/relay/pass/quantize/partition.cc | 4 +- src/relay/pass/quantize/quantize.cc | 8 +- src/relay/pass/quantize/realize.cc | 2 +- src/relay/pass/simplify_inference.cc | 2 +- src/relay/pass/to_a_normal_form.cc | 2 +- src/relay/pass/to_cps.cc | 8 +- src/relay/pass/to_graph_normal_form.cc | 2 +- src/relay/pass/type_infer.cc | 4 +- src/relay/pass/type_solver.cc | 2 +- src/relay/pass/util.cc | 12 +- src/relay/pass/well_formed.cc | 2 +- src/relay/qnn/op/concatenate.cc | 2 +- src/relay/qnn/op/convolution.cc | 2 +- src/relay/qnn/op/dense.cc | 2 +- src/relay/qnn/op/dequantize.cc | 2 +- src/relay/qnn/op/op_common.h | 2 +- src/relay/qnn/op/quantize.cc | 2 +- src/relay/qnn/op/requantize.cc | 2 +- src/relay/qnn/pass/legalize.cc | 2 +- 131 files changed, 602 insertions(+), 550 deletions(-) rename include/tvm/{api_registry.h => node/env_func.h} (86%) rename src/{lang/api_registry.cc => node/env_func.cc} (87%) diff --git a/docs/dev/codebase_walkthrough.rst b/docs/dev/codebase_walkthrough.rst index 7e78d57..19f185e 100644 --- a/docs/dev/codebase_walkthrough.rst +++ b/docs/dev/codebase_walkthrough.rst @@ -72,7 +72,7 @@ The Node system is the basis of exposing C++ types to frontend languages, includ :: - TVM_REGISTER_API("_ComputeOp") + TVM_REGISTER_GLOBAL("_ComputeOp") .set_body([](TVMArgs args, TVMRetValue* ret) { *ret = ComputeOpNode::make(args[0], args[1], @@ -174,7 +174,7 @@ The ``Build()`` function looks up the code generator for the given target in the :: - TVM_REGISTER_API("codegen.build_cuda") + TVM_REGISTER_GLOBAL("codegen.build_cuda") .set_body([](TVMArgs args, TVMRetValue* rv) { *rv = BuildCUDA(args[0]); }); diff --git a/docs/dev/relay_add_op.rst b/docs/dev/relay_add_op.rst index 466dca0..f494cc6 100644 --- a/docs/dev/relay_add_op.rst +++ b/docs/dev/relay_add_op.rst @@ -96,7 +96,7 @@ the arguments to the call node, as below. .. code:: c - TVM_REGISTER_API("relay.op._make.add") + TVM_REGISTER_GLOBAL("relay.op._make.add") .set_body_typed([](Expr lhs, Expr rhs) { static const Op& op = Op::Get("add"); return CallNode::make(op, {lhs, rhs}, Attrs(), {}); @@ -106,7 +106,7 @@ Including a Python API Hook --------------------------- It is generally the convention in Relay, that functions exported -through ``TVM_REGISTER_API`` should be wrapped in a separate +through ``TVM_REGISTER_GLOBAL`` should be wrapped in a separate Python function rather than called directly in Python. In the case of the functions that produce calls to operators, it may be convenient to bundle them, as in ``python/tvm/relay/op/tensor.py``, where diff --git a/docs/dev/relay_pass_infra.rst b/docs/dev/relay_pass_infra.rst index 6593215..57dcca1 100644 --- a/docs/dev/relay_pass_infra.rst +++ b/docs/dev/relay_pass_infra.rst @@ -131,13 +131,13 @@ Python APIs to create a compilation pipeline using pass context. TVM_DLL static PassContext Create(); TVM_DLL static PassContext Current(); /* Other fields are omitted. */ - + private: // The entry of a pass context scope. TVM_DLL void EnterWithScope(); // The exit of a pass context scope. TVM_DLL void ExitWithScope(); - + // Classes to get the Python `with` like syntax. friend class tvm::With; }; @@ -225,7 +225,7 @@ cannot add or delete a function through these passes as they are not aware of the global information. .. code:: c++ - + class FunctionPassNode : PassNode { PassInfo pass_info; runtime::TypedPackedFunc pass_func; @@ -319,7 +319,7 @@ favorably use Python APIs to create a specific pass object. ModulePass CreateModulePass(std::string name, int opt_level, PassFunc pass_func); - + SequentialPass CreateSequentialPass(std::string name, int opt_level, Array passes, @@ -347,14 +347,14 @@ registration. auto tensor_type = relay::TensorTypeNode::make({}, tvm::Bool()); auto x = relay::VarNode::make("x", relay::Type()); auto f = relay::FunctionNode::make(tvm::Array{ x }, x, relay::Type(), {}); - + auto y = relay::VarNode::make("y", tensor_type); auto call = relay::CallNode::make(f, tvm::Array{ y }); auto fx = relay::FunctionNode::make(tvm::Array{ y }, call, relay::Type(), {}); - + // Create a module for optimization. auto mod = relay::ModuleNode::FromExpr(fx); - + // Create a sequential pass. tvm::Array pass_seqs{ relay::transform::InferType(), @@ -363,7 +363,7 @@ registration. relay::transform::AlterOpLayout() }; relay::transform::Pass seq = relay::transform::Sequential(pass_seqs); - + // Create a pass context for the optimization. auto ctx = relay::transform::PassContext::Create(); ctx->opt_level = 2; @@ -421,7 +421,7 @@ Python when needed. return CreateFunctionPass(pass_func, 2, "FoldConstant", {}); } - TVM_REGISTER_API("relay._transform.FoldConstant") + TVM_REGISTER_GLOBAL("relay._transform.FoldConstant") .set_body_typed(FoldConstant); } // namespace transform @@ -457,10 +457,10 @@ a certain scope. def __enter__(self): _transform.EnterPassContext(self) return self - + def __exit__(self, ptype, value, trace): _transform.ExitPassContext(self) - + @staticmethod def current(): """Return the current pass context.""" @@ -580,18 +580,18 @@ using ``Sequential`` associated with other types of passes. z1 = relay.add(y, c) z2 = relay.add(z, z1) func = relay.Function([x], z2) - - # Customize the optimization pipeline. + + # Customize the optimization pipeline. seq = _transform.Sequential([ relay.transform.InferType(), relay.transform.FoldConstant(), relay.transform.EliminateCommonSubexpr(), relay.transform.AlterOpLayout() ]) - + # Create a module to perform optimizations. mod = relay.Module({"main": func}) - + # Users can disable any passes that they don't want to execute by providing # a list, e.g. disabled_pass=["EliminateCommonSubexpr"]. with relay.build_config(opt_level=3): @@ -629,7 +629,7 @@ For more pass infra related examples in Python and C++, please refer to .. _Block: https://mxnet.incubator.apache.org/api/python/docs/api/gluon/block.html#gluon-block -.. _Relay module: https://docs.tvm.ai/langref/relay_expr.html#module-and-global-functions +.. _Relay module: https://docs.tvm.ai/langref/relay_expr.html#module-and-global-functions .. _include/tvm/relay/transform.h: https://github.com/apache/incubator-tvm/blob/master/include/tvm/relay/transform.h diff --git a/include/tvm/codegen.h b/include/tvm/codegen.h index 2f4058e..78fb7d1 100644 --- a/include/tvm/codegen.h +++ b/include/tvm/codegen.h @@ -28,7 +28,6 @@ #include "base.h" #include "expr.h" #include "lowered_func.h" -#include "api_registry.h" #include "runtime/packed_func.h" namespace tvm { diff --git a/include/tvm/api_registry.h b/include/tvm/node/env_func.h similarity index 86% rename from include/tvm/api_registry.h rename to include/tvm/node/env_func.h index 292e494..c2ea2b4 100644 --- a/include/tvm/api_registry.h +++ b/include/tvm/node/env_func.h @@ -18,33 +18,19 @@ */ /*! - * \file tvm/api_registry.h - * \brief This file contains utilities related to - * the TVM's global function registry. + * \file tvm/node/env_func.h + * \brief Serializable global function. */ -#ifndef TVM_API_REGISTRY_H_ -#define TVM_API_REGISTRY_H_ +#ifndef TVM_NODE_ENV_FUNC_H_ +#define TVM_NODE_ENV_FUNC_H_ + +#include #include #include -#include "base.h" -#include "packed_func_ext.h" -#include "runtime/registry.h" -namespace tvm { -/*! - * \brief Register an API function globally. - * It simply redirects to TVM_REGISTER_GLOBAL - * - * \code - * TVM_REGISTER_API(MyPrint) - * .set_body([](TVMArgs args, TVMRetValue* rv) { - * // my code. - * }); - * \endcode - */ -#define TVM_REGISTER_API(OpName) TVM_REGISTER_GLOBAL(OpName) +namespace tvm { /*! * \brief Node container of EnvFunc * \sa EnvFunc @@ -54,7 +40,7 @@ class EnvFuncNode : public Object { /*! \brief Unique name of the global function */ std::string name; /*! \brief The internal packed function */ - PackedFunc func; + runtime::PackedFunc func; /*! \brief constructor */ EnvFuncNode() {} @@ -154,4 +140,4 @@ class TypedEnvFunc : public ObjectRef { }; } // namespace tvm -#endif // TVM_API_REGISTRY_H_ +#endif // TVM_NODE_ENV_FUNC_H_ diff --git a/include/tvm/relay/base.h b/include/tvm/relay/base.h index 7191e1f..b4164fb 100644 --- a/include/tvm/relay/base.h +++ b/include/tvm/relay/base.h @@ -24,7 +24,7 @@ #ifndef TVM_RELAY_BASE_H_ #define TVM_RELAY_BASE_H_ -#include + #include #include #include diff --git a/include/tvm/relay/type.h b/include/tvm/relay/type.h index c6a560a..c8a02a8 100644 --- a/include/tvm/relay/type.h +++ b/include/tvm/relay/type.h @@ -24,8 +24,12 @@ #ifndef TVM_RELAY_TYPE_H_ #define TVM_RELAY_TYPE_H_ -#include + #include +#include +#include +#include + #include #include diff --git a/include/tvm/runtime/registry.h b/include/tvm/runtime/registry.h index e51b806..a7e8041 100644 --- a/include/tvm/runtime/registry.h +++ b/include/tvm/runtime/registry.h @@ -70,7 +70,7 @@ class Registry { * * \code * - * TVM_REGISTER_API("addone") + * TVM_REGISTER_GLOBAL("addone") * .set_body_typed([](int x) { return x + 1; }); * * \endcode @@ -96,7 +96,7 @@ class Registry { * return x * y; * } * - * TVM_REGISTER_API("multiply") + * TVM_REGISTER_GLOBAL("multiply") * .set_body_typed(multiply); // will have type int(int, int) * * \endcode @@ -120,7 +120,7 @@ class Registry { * struct Example { * int doThing(int x); * } - * TVM_REGISTER_API("Example_doThing") + * TVM_REGISTER_GLOBAL("Example_doThing") * .set_body_method(&Example::doThing); // will have type int(Example, int) * * \endcode @@ -148,7 +148,7 @@ class Registry { * struct Example { * int doThing(int x); * } - * TVM_REGISTER_API("Example_doThing") + * TVM_REGISTER_GLOBAL("Example_doThing") * .set_body_method(&Example::doThing); // will have type int(Example, int) * * \endcode @@ -181,7 +181,7 @@ class Registry { * // noderef subclass * struct Example; * - * TVM_REGISTER_API("Example_doThing") + * TVM_REGISTER_GLOBAL("Example_doThing") * .set_body_method(&ExampleNode::doThing); // will have type int(Example, int) * * // note that just doing: @@ -221,7 +221,7 @@ class Registry { * // noderef subclass * struct Example; * - * TVM_REGISTER_API("Example_doThing") + * TVM_REGISTER_GLOBAL("Example_doThing") * .set_body_method(&ExampleNode::doThing); // will have type int(Example, int) * * // note that just doing: diff --git a/src/api/api_arith.cc b/src/api/api_arith.cc index 499d43d..5eef8db 100644 --- a/src/api/api_arith.cc +++ b/src/api/api_arith.cc @@ -23,29 +23,31 @@ */ #include #include -#include +#include +#include + #include namespace tvm { namespace arith { -TVM_REGISTER_API("arith.intset_single_point") +TVM_REGISTER_GLOBAL("arith.intset_single_point") .set_body_typed(IntSet::single_point); -TVM_REGISTER_API("arith.intset_vector") +TVM_REGISTER_GLOBAL("arith.intset_vector") .set_body_typed(IntSet::vector); -TVM_REGISTER_API("arith.intset_interval") +TVM_REGISTER_GLOBAL("arith.intset_interval") .set_body_typed(IntSet::interval); -TVM_REGISTER_API("arith.DetectLinearEquation") +TVM_REGISTER_GLOBAL("arith.DetectLinearEquation") .set_body_typed(DetectLinearEquation); -TVM_REGISTER_API("arith.DetectClipBound") +TVM_REGISTER_GLOBAL("arith.DetectClipBound") .set_body_typed(DetectClipBound); -TVM_REGISTER_API("arith.DeduceBound") +TVM_REGISTER_GLOBAL("arith.DeduceBound") .set_body_typed, Map)>([]( Expr v, Expr cond, const Map hint_map, @@ -55,36 +57,36 @@ TVM_REGISTER_API("arith.DeduceBound") }); -TVM_REGISTER_API("arith.DomainTouched") +TVM_REGISTER_GLOBAL("arith.DomainTouched") .set_body_typed(DomainTouched); -TVM_REGISTER_API("_IntervalSetGetMin") +TVM_REGISTER_GLOBAL("_IntervalSetGetMin") .set_body_method(&IntSet::min); -TVM_REGISTER_API("_IntervalSetGetMax") +TVM_REGISTER_GLOBAL("_IntervalSetGetMax") .set_body_method(&IntSet::max); -TVM_REGISTER_API("_IntSetIsNothing") +TVM_REGISTER_GLOBAL("_IntSetIsNothing") .set_body_method(&IntSet::is_nothing); -TVM_REGISTER_API("_IntSetIsEverything") +TVM_REGISTER_GLOBAL("_IntSetIsEverything") .set_body_method(&IntSet::is_everything); ConstIntBound MakeConstIntBound(int64_t min_value, int64_t max_value) { return ConstIntBound(min_value, max_value); } -TVM_REGISTER_API("arith._make_ConstIntBound") +TVM_REGISTER_GLOBAL("arith._make_ConstIntBound") .set_body_typed(MakeConstIntBound); ModularSet MakeModularSet(int64_t coeff, int64_t base) { return ModularSet(coeff, base); } -TVM_REGISTER_API("arith._make_ModularSet") +TVM_REGISTER_GLOBAL("arith._make_ModularSet") .set_body_typed(MakeModularSet); -TVM_REGISTER_API("arith._CreateAnalyzer") +TVM_REGISTER_GLOBAL("arith._CreateAnalyzer") .set_body([](TVMArgs args, TVMRetValue* ret) { using runtime::PackedFunc; using runtime::TypedPackedFunc; diff --git a/src/api/api_base.cc b/src/api/api_base.cc index bcfd82b..89dd4fc 100644 --- a/src/api/api_base.cc +++ b/src/api/api_base.cc @@ -24,11 +24,13 @@ #include #include #include -#include +#include +#include + #include namespace tvm { -TVM_REGISTER_API("_format_str") +TVM_REGISTER_GLOBAL("_format_str") .set_body([](TVMArgs args, TVMRetValue *ret) { CHECK(args[0].type_code() == kObjectHandle); std::ostringstream os; @@ -36,22 +38,22 @@ TVM_REGISTER_API("_format_str") *ret = os.str(); }); -TVM_REGISTER_API("_raw_ptr") +TVM_REGISTER_GLOBAL("_raw_ptr") .set_body([](TVMArgs args, TVMRetValue *ret) { CHECK(args[0].type_code() == kObjectHandle); *ret = reinterpret_cast(args[0].value().v_handle); }); -TVM_REGISTER_API("_save_json") +TVM_REGISTER_GLOBAL("_save_json") .set_body_typed(SaveJSON); -TVM_REGISTER_API("_load_json") +TVM_REGISTER_GLOBAL("_load_json") .set_body_typed(LoadJSON); -TVM_REGISTER_API("_TVMSetStream") +TVM_REGISTER_GLOBAL("_TVMSetStream") .set_body_typed(TVMSetStream); -TVM_REGISTER_API("_save_param_dict") +TVM_REGISTER_GLOBAL("_save_param_dict") .set_body([](TVMArgs args, TVMRetValue *rv) { CHECK_EQ(args.size() % 2, 0u); constexpr uint64_t TVMNDArrayListMagic = 0xF7E58D4F05049CB7; diff --git a/src/api/api_codegen.cc b/src/api/api_codegen.cc index 5b6050d..a58e905 100644 --- a/src/api/api_codegen.cc +++ b/src/api/api_codegen.cc @@ -25,12 +25,14 @@ #include #include #include -#include +#include +#include + namespace tvm { namespace codegen { -TVM_REGISTER_API("codegen._Build") +TVM_REGISTER_GLOBAL("codegen._Build") .set_body([](TVMArgs args, TVMRetValue *ret) { if (args[0].IsObjectRef()) { *ret = Build({args[0]}, args[1]); @@ -39,7 +41,7 @@ TVM_REGISTER_API("codegen._Build") } }); -TVM_REGISTER_API("module._PackImportsToC") +TVM_REGISTER_GLOBAL("module._PackImportsToC") .set_body_typed(PackImportsToC); } // namespace codegen } // namespace tvm diff --git a/src/api/api_ir.cc b/src/api/api_ir.cc index 03f37b1..2b7a36f 100644 --- a/src/api/api_ir.cc +++ b/src/api/api_ir.cc @@ -23,45 +23,47 @@ */ #include #include -#include +#include +#include + #include namespace tvm { namespace ir { -TVM_REGISTER_API("_Var") +TVM_REGISTER_GLOBAL("_Var") .set_body_typed([](std::string s, DataType t) { return Variable::make(t, s); }); -TVM_REGISTER_API("make.abs") +TVM_REGISTER_GLOBAL("make.abs") .set_body_typed(tvm::abs); -TVM_REGISTER_API("make.isnan") +TVM_REGISTER_GLOBAL("make.isnan") .set_body_typed(tvm::isnan); -TVM_REGISTER_API("make.floor") +TVM_REGISTER_GLOBAL("make.floor") .set_body_typed(tvm::floor); -TVM_REGISTER_API("make.ceil") +TVM_REGISTER_GLOBAL("make.ceil") .set_body_typed(tvm::ceil); -TVM_REGISTER_API("make.round") +TVM_REGISTER_GLOBAL("make.round") .set_body_typed(tvm::round); -TVM_REGISTER_API("make.nearbyint") +TVM_REGISTER_GLOBAL("make.nearbyint") .set_body_typed(tvm::nearbyint); -TVM_REGISTER_API("make.trunc") +TVM_REGISTER_GLOBAL("make.trunc") .set_body_typed(tvm::trunc); -TVM_REGISTER_API("make._cast") +TVM_REGISTER_GLOBAL("make._cast") .set_body_typed(tvm::cast); -TVM_REGISTER_API("make._range_by_min_extent") +TVM_REGISTER_GLOBAL("make._range_by_min_extent") .set_body_typed(Range::make_by_min_extent); -TVM_REGISTER_API("make.For") +TVM_REGISTER_GLOBAL("make.For") .set_body_typed([]( VarExpr loop_var, Expr min, Expr extent, int for_type, int device_api, Stmt body) { @@ -73,7 +75,7 @@ TVM_REGISTER_API("make.For") body); }); -TVM_REGISTER_API("make.Load") +TVM_REGISTER_GLOBAL("make.Load") .set_body([](TVMArgs args, TVMRetValue *ret) { DataType t = args[0]; if (args.size() == 3) { @@ -83,7 +85,7 @@ TVM_REGISTER_API("make.Load") } }); -TVM_REGISTER_API("make.Store") +TVM_REGISTER_GLOBAL("make.Store") .set_body([](TVMArgs args, TVMRetValue *ret) { Expr value = args[1]; if (args.size() == 3) { @@ -93,10 +95,10 @@ TVM_REGISTER_API("make.Store") } }); -TVM_REGISTER_API("make.Realize") +TVM_REGISTER_GLOBAL("make.Realize") .set_body_typed(Realize::make); -TVM_REGISTER_API("make.Call") +TVM_REGISTER_GLOBAL("make.Call") .set_body_typed, int, FunctionRef, int)>([]( DataType type, std::string name, Array args, int call_type, @@ -110,12 +112,12 @@ TVM_REGISTER_API("make.Call") value_index); }); -TVM_REGISTER_API("make.CommReducer") +TVM_REGISTER_GLOBAL("make.CommReducer") .set_body_typed(CommReducerNode::make); // make from two arguments #define REGISTER_MAKE(Node) \ - TVM_REGISTER_API("make."#Node) \ + TVM_REGISTER_GLOBAL("make."#Node) \ .set_body_typed(Node::make); \ REGISTER_MAKE(Reduce); @@ -161,11 +163,11 @@ REGISTER_MAKE(IfThenElse); REGISTER_MAKE(Evaluate); // overloaded, needs special handling -TVM_REGISTER_API("make.Block") +TVM_REGISTER_GLOBAL("make.Block") .set_body_typed(static_cast(Block::make)); // has default args -TVM_REGISTER_API("make.Allocate") +TVM_REGISTER_GLOBAL("make.Allocate") .set_body_typed, Expr, Stmt)>([]( VarExpr buffer_var, DataType type, Array extents, Expr condition, Stmt body ){ @@ -174,13 +176,13 @@ TVM_REGISTER_API("make.Allocate") // operator overloading, smarter than make #define REGISTER_MAKE_BINARY_OP(Node, Func) \ - TVM_REGISTER_API("make."#Node) \ + TVM_REGISTER_GLOBAL("make."#Node) \ .set_body_typed([](Expr a, Expr b) { \ return (Func(a, b)); \ }) #define REGISTER_MAKE_BIT_OP(Node, Func) \ - TVM_REGISTER_API("make."#Node) \ + TVM_REGISTER_GLOBAL("make."#Node) \ .set_body([](TVMArgs args, TVMRetValue *ret) { \ bool lhs_is_int = args[0].type_code() == kDLInt; \ bool rhs_is_int = args[1].type_code() == kDLInt; \ @@ -221,7 +223,7 @@ REGISTER_MAKE_BIT_OP(bitwise_or, operator|); REGISTER_MAKE_BIT_OP(bitwise_xor, operator^); REGISTER_MAKE_BIT_OP(left_shift, operator<<); // NOLINT(*) REGISTER_MAKE_BIT_OP(right_shift, operator>>); -TVM_REGISTER_API("make._OpIfThenElse") +TVM_REGISTER_GLOBAL("make._OpIfThenElse") .set_body_typed([] (Expr cond, Expr true_value, Expr false_value) { return if_then_else(cond, true_value, false_value); }); diff --git a/src/api/api_lang.cc b/src/api/api_lang.cc index 00ceaf7..b8f7d0f 100644 --- a/src/api/api_lang.cc +++ b/src/api/api_lang.cc @@ -27,20 +27,22 @@ #include #include #include -#include +#include +#include + #include #include namespace tvm { -TVM_REGISTER_API("_min_value") +TVM_REGISTER_GLOBAL("_min_value") .set_body_typed(min_value); -TVM_REGISTER_API("_max_value") +TVM_REGISTER_GLOBAL("_max_value") .set_body_typed(max_value); -TVM_REGISTER_API("_const") +TVM_REGISTER_GLOBAL("_const") .set_body([](TVMArgs args, TVMRetValue* ret) { if (args[0].type_code() == kDLInt) { *ret = make_const(args[1], args[0].operator int64_t()); @@ -51,11 +53,11 @@ TVM_REGISTER_API("_const") } }); -TVM_REGISTER_API("_str") +TVM_REGISTER_GLOBAL("_str") .set_body_typed(ir::StringImm::make); -TVM_REGISTER_API("_Array") +TVM_REGISTER_GLOBAL("_Array") .set_body([](TVMArgs args, TVMRetValue* ret) { std::vector data; for (int i = 0; i < args.size(); ++i) { @@ -70,7 +72,7 @@ TVM_REGISTER_API("_Array") *ret = Array(node); }); -TVM_REGISTER_API("_ArrayGetItem") +TVM_REGISTER_GLOBAL("_ArrayGetItem") .set_body([](TVMArgs args, TVMRetValue* ret) { int64_t i = args[1]; CHECK_EQ(args[0].type_code(), kObjectHandle); @@ -82,7 +84,7 @@ TVM_REGISTER_API("_ArrayGetItem") *ret = n->data[static_cast(i)]; }); -TVM_REGISTER_API("_ArraySize") +TVM_REGISTER_GLOBAL("_ArraySize") .set_body([](TVMArgs args, TVMRetValue* ret) { CHECK_EQ(args[0].type_code(), kObjectHandle); Object* ptr = static_cast(args[0].value().v_handle); @@ -91,7 +93,7 @@ TVM_REGISTER_API("_ArraySize") static_cast(ptr)->data.size()); }); -TVM_REGISTER_API("_Map") +TVM_REGISTER_GLOBAL("_Map") .set_body([](TVMArgs args, TVMRetValue* ret) { CHECK_EQ(args.size() % 2, 0); if (args.size() != 0 && args[0].type_code() == kStr) { @@ -125,7 +127,7 @@ TVM_REGISTER_API("_Map") } }); -TVM_REGISTER_API("_MapSize") +TVM_REGISTER_GLOBAL("_MapSize") .set_body([](TVMArgs args, TVMRetValue* ret) { CHECK_EQ(args[0].type_code(), kObjectHandle); Object* ptr = static_cast(args[0].value().v_handle); @@ -139,7 +141,7 @@ TVM_REGISTER_API("_MapSize") } }); -TVM_REGISTER_API("_MapGetItem") +TVM_REGISTER_GLOBAL("_MapGetItem") .set_body([](TVMArgs args, TVMRetValue* ret) { CHECK_EQ(args[0].type_code(), kObjectHandle); Object* ptr = static_cast(args[0].value().v_handle); @@ -161,7 +163,7 @@ TVM_REGISTER_API("_MapGetItem") } }); -TVM_REGISTER_API("_MapCount") +TVM_REGISTER_GLOBAL("_MapCount") .set_body([](TVMArgs args, TVMRetValue* ret) { CHECK_EQ(args[0].type_code(), kObjectHandle); Object* ptr = static_cast(args[0].value().v_handle); @@ -179,7 +181,7 @@ TVM_REGISTER_API("_MapCount") } }); -TVM_REGISTER_API("_MapItems") +TVM_REGISTER_GLOBAL("_MapItems") .set_body([](TVMArgs args, TVMRetValue* ret) { CHECK_EQ(args[0].type_code(), kObjectHandle); Object* ptr = static_cast(args[0].value().v_handle); @@ -203,7 +205,7 @@ TVM_REGISTER_API("_MapItems") } }); -TVM_REGISTER_API("Range") +TVM_REGISTER_GLOBAL("Range") .set_body([](TVMArgs args, TVMRetValue* ret) { if (args.size() == 1) { *ret = Range(0, args[0]); @@ -212,7 +214,7 @@ TVM_REGISTER_API("Range") } }); -TVM_REGISTER_API("_Buffer") +TVM_REGISTER_GLOBAL("_Buffer") .set_body([](TVMArgs args, TVMRetValue* ret) { CHECK_EQ(args.size(), 10); auto buffer_type = args[9].operator std::string(); @@ -221,105 +223,105 @@ TVM_REGISTER_API("_Buffer") args[5], args[6], args[7], args[8], type); }); -TVM_REGISTER_API("_BufferAccessPtr") +TVM_REGISTER_GLOBAL("_BufferAccessPtr") .set_body_method(&Buffer::access_ptr); -TVM_REGISTER_API("_BufferVLoad") +TVM_REGISTER_GLOBAL("_BufferVLoad") .set_body_method(&Buffer::vload); -TVM_REGISTER_API("_BufferVStore") +TVM_REGISTER_GLOBAL("_BufferVStore") .set_body_method(&Buffer::vstore); -TVM_REGISTER_API("_Layout") +TVM_REGISTER_GLOBAL("_Layout") .set_body_typed(LayoutNode::make); -TVM_REGISTER_API("_LayoutIndexOf") +TVM_REGISTER_GLOBAL("_LayoutIndexOf") .set_body_typed([](Layout layout, std::string axis) { return layout.IndexOf(LayoutAxis::make(axis)); }); -TVM_REGISTER_API("_LayoutFactorOf") +TVM_REGISTER_GLOBAL("_LayoutFactorOf") .set_body_typed([](Layout layout, std::string axis) { return layout.FactorOf(LayoutAxis::make(axis)); }); -TVM_REGISTER_API("_LayoutNdim") +TVM_REGISTER_GLOBAL("_LayoutNdim") .set_body_typed([](Layout layout) { return layout.ndim(); }); -TVM_REGISTER_API("_LayoutGetItem") +TVM_REGISTER_GLOBAL("_LayoutGetItem") .set_body_typed([](Layout layout, int idx) { const LayoutAxis& axis = layout[idx]; return axis.name(); }); -TVM_REGISTER_API("_BijectiveLayout") +TVM_REGISTER_GLOBAL("_BijectiveLayout") .set_body_typed(BijectiveLayoutNode::make); -TVM_REGISTER_API("_BijectiveLayoutForwardIndex") +TVM_REGISTER_GLOBAL("_BijectiveLayoutForwardIndex") .set_body_method(&BijectiveLayout::ForwardIndex); -TVM_REGISTER_API("_BijectiveLayoutBackwardIndex") +TVM_REGISTER_GLOBAL("_BijectiveLayoutBackwardIndex") .set_body_method(&BijectiveLayout::BackwardIndex); -TVM_REGISTER_API("_BijectiveLayoutForwardShape") +TVM_REGISTER_GLOBAL("_BijectiveLayoutForwardShape") .set_body_method(&BijectiveLayout::ForwardShape); -TVM_REGISTER_API("_BijectiveLayoutBackwardShape") +TVM_REGISTER_GLOBAL("_BijectiveLayoutBackwardShape") .set_body_method(&BijectiveLayout::BackwardShape); -TVM_REGISTER_API("_Tensor") +TVM_REGISTER_GLOBAL("_Tensor") .set_body_typed(TensorNode::make); -TVM_REGISTER_API("_TensorIntrin") +TVM_REGISTER_GLOBAL("_TensorIntrin") .set_body_typed(TensorIntrinNode::make); -TVM_REGISTER_API("_TensorIntrinCall") +TVM_REGISTER_GLOBAL("_TensorIntrinCall") .set_body_typed(TensorIntrinCallNode::make); -TVM_REGISTER_API("_TensorEqual") +TVM_REGISTER_GLOBAL("_TensorEqual") .set_body_method(&Tensor::operator==); -TVM_REGISTER_API("_TensorHash") +TVM_REGISTER_GLOBAL("_TensorHash") .set_body_typed([](Tensor tensor) { return static_cast(std::hash()(tensor)); }); -TVM_REGISTER_API("_Placeholder") +TVM_REGISTER_GLOBAL("_Placeholder") .set_body_typed, DataType, std::string)>([]( Array shape, DataType dtype, std::string name ) { return placeholder(shape, dtype, name); }); -TVM_REGISTER_API("_ComputeOp") +TVM_REGISTER_GLOBAL("_ComputeOp") .set_body_typed(ComputeOpNode::make); -TVM_REGISTER_API("_ScanOp") +TVM_REGISTER_GLOBAL("_ScanOp") .set_body_typed(ScanOpNode::make); -TVM_REGISTER_API("_TensorComputeOp") +TVM_REGISTER_GLOBAL("_TensorComputeOp") .set_body_typed(TensorComputeOpNode::make); -TVM_REGISTER_API("_ExternOp") +TVM_REGISTER_GLOBAL("_ExternOp") .set_body_typed(ExternOpNode::make); -TVM_REGISTER_API("_HybridOp") +TVM_REGISTER_GLOBAL("_HybridOp") .set_body_typed(HybridOpNode::make); -TVM_REGISTER_API("_OpGetOutput") +TVM_REGISTER_GLOBAL("_OpGetOutput") .set_body_typed([](Operation op, int64_t output) { return op.output(static_cast(output)); }); -TVM_REGISTER_API("_OpNumOutputs") +TVM_REGISTER_GLOBAL("_OpNumOutputs") .set_body_method(&OperationNode::num_outputs); -TVM_REGISTER_API("_OpInputTensors") +TVM_REGISTER_GLOBAL("_OpInputTensors") .set_body_method(&OperationNode::InputTensors); -TVM_REGISTER_API("_IterVar") +TVM_REGISTER_GLOBAL("_IterVar") .set_body_typed([]( Range dom, Var var, int iter_type, std::string thread_tag ) { @@ -329,16 +331,16 @@ TVM_REGISTER_API("_IterVar") thread_tag); }); -TVM_REGISTER_API("_CreateSchedule") +TVM_REGISTER_GLOBAL("_CreateSchedule") .set_body_typed(create_schedule); -TVM_REGISTER_API("_StageSetScope") +TVM_REGISTER_GLOBAL("_StageSetScope") .set_body_method(&Stage::set_scope); -TVM_REGISTER_API("_StageBind") +TVM_REGISTER_GLOBAL("_StageBind") .set_body_method(&Stage::bind); -TVM_REGISTER_API("_StageSplitByFactor") +TVM_REGISTER_GLOBAL("_StageSplitByFactor") .set_body_typed(Stage, IterVar, Expr)>([]( Stage stage, IterVar parent, Expr factor ) { @@ -347,7 +349,7 @@ TVM_REGISTER_API("_StageSplitByFactor") return Array({outer, inner}); }); -TVM_REGISTER_API("_StageSplitByNParts") +TVM_REGISTER_GLOBAL("_StageSplitByNParts") .set_body_typed(Stage, IterVar, Expr)>([]( Stage stage, IterVar parent, Expr nparts ) { @@ -356,26 +358,26 @@ TVM_REGISTER_API("_StageSplitByNParts") return Array({outer, inner}); }); -TVM_REGISTER_API("_StageFuse") +TVM_REGISTER_GLOBAL("_StageFuse") .set_body_typed)>([](Stage stage, Array axes) { IterVar fused; stage.fuse(axes, &fused); return fused; }); -TVM_REGISTER_API("_StageComputeAt") +TVM_REGISTER_GLOBAL("_StageComputeAt") .set_body_method(&Stage::compute_at); -TVM_REGISTER_API("_StageComputeInline") +TVM_REGISTER_GLOBAL("_StageComputeInline") .set_body_method(&Stage::compute_inline); -TVM_REGISTER_API("_StageComputeRoot") +TVM_REGISTER_GLOBAL("_StageComputeRoot") .set_body_method(&Stage::compute_root); -TVM_REGISTER_API("_StageReorder") +TVM_REGISTER_GLOBAL("_StageReorder") .set_body_method(&Stage::reorder); -TVM_REGISTER_API("_StageTile") +TVM_REGISTER_GLOBAL("_StageTile") .set_body_typed(Stage, IterVar, IterVar, Expr, Expr)>([]( Stage stage, IterVar x_parent, IterVar y_parent, @@ -389,49 +391,49 @@ TVM_REGISTER_API("_StageTile") return Array({x_outer, y_outer, x_inner, y_inner}); }); -TVM_REGISTER_API("_StageEnvThreads") +TVM_REGISTER_GLOBAL("_StageEnvThreads") .set_body_method(&Stage::env_threads); -TVM_REGISTER_API("_StageSetStorePredicate") +TVM_REGISTER_GLOBAL("_StageSetStorePredicate") .set_body_method(&Stage::set_store_predicate); -TVM_REGISTER_API("_StageUnroll") +TVM_REGISTER_GLOBAL("_StageUnroll") .set_body_method(&Stage::unroll); -TVM_REGISTER_API("_StageVectorize") +TVM_REGISTER_GLOBAL("_StageVectorize") .set_body_method(&Stage::vectorize); -TVM_REGISTER_API("_StageTensorize") +TVM_REGISTER_GLOBAL("_StageTensorize") .set_body_method(&Stage::tensorize); -TVM_REGISTER_API("_StageParallel") +TVM_REGISTER_GLOBAL("_StageParallel") .set_body_method(&Stage::parallel); -TVM_REGISTER_API("_StagePragma") +TVM_REGISTER_GLOBAL("_StagePragma") .set_body_method(&Stage::pragma); -TVM_REGISTER_API("_StagePrefetch") +TVM_REGISTER_GLOBAL("_StagePrefetch") .set_body_method(&Stage::prefetch); -TVM_REGISTER_API("_StageStorageAlign") +TVM_REGISTER_GLOBAL("_StageStorageAlign") .set_body_method(&Stage::storage_align); -TVM_REGISTER_API("_StageDoubleBuffer") +TVM_REGISTER_GLOBAL("_StageDoubleBuffer") .set_body_method(&Stage::double_buffer); -TVM_REGISTER_API("_StageOpenGL") +TVM_REGISTER_GLOBAL("_StageOpenGL") .set_body_method(&Stage::opengl); -TVM_REGISTER_API("_ScheduleNormalize") +TVM_REGISTER_GLOBAL("_ScheduleNormalize") .set_body_method(&Schedule::normalize); -TVM_REGISTER_API("_ScheduleCreateGroup") +TVM_REGISTER_GLOBAL("_ScheduleCreateGroup") .set_body_method(&Schedule::create_group); -TVM_REGISTER_API("_ScheduleCacheRead") +TVM_REGISTER_GLOBAL("_ScheduleCacheRead") .set_body_method(&Schedule::cache_read); -TVM_REGISTER_API("_ScheduleCacheWrite") +TVM_REGISTER_GLOBAL("_ScheduleCacheWrite") .set_body([](TVMArgs args, TVMRetValue* ret) { if (args[1].IsObjectRef()) { *ret = args[0].operator Schedule() @@ -442,10 +444,10 @@ TVM_REGISTER_API("_ScheduleCacheWrite") } }); -TVM_REGISTER_API("_ScheduleRFactor") +TVM_REGISTER_GLOBAL("_ScheduleRFactor") .set_body_method(&Schedule::rfactor); -TVM_REGISTER_API("_CommReducerCombine") +TVM_REGISTER_GLOBAL("_CommReducerCombine") .set_body_method(&ir::CommReducerNode::operator()); } // namespace tvm diff --git a/src/api/api_pass.cc b/src/api/api_pass.cc index f1d97f4..7390e8b 100644 --- a/src/api/api_pass.cc +++ b/src/api/api_pass.cc @@ -26,12 +26,14 @@ #include #include #include -#include +#include +#include + namespace tvm { namespace ir { -TVM_REGISTER_API("ir_pass.Simplify") +TVM_REGISTER_GLOBAL("ir_pass.Simplify") .set_body([](TVMArgs args, TVMRetValue *ret) { if (args[0].IsObjectRef()) { if (args.size() > 1) { @@ -48,7 +50,7 @@ TVM_REGISTER_API("ir_pass.Simplify") } }); -TVM_REGISTER_API("ir_pass.CanonicalSimplify") +TVM_REGISTER_GLOBAL("ir_pass.CanonicalSimplify") .set_body([](TVMArgs args, TVMRetValue *ret) { if (args[0].IsObjectRef()) { if (args.size() > 1) { @@ -65,7 +67,7 @@ TVM_REGISTER_API("ir_pass.CanonicalSimplify") } }); -TVM_REGISTER_API("ir_pass.Substitute") +TVM_REGISTER_GLOBAL("ir_pass.Substitute") .set_body([](TVMArgs args, TVMRetValue *ret) { if (args[0].IsObjectRef()) { *ret = Substitute(args[0].operator Stmt(), args[1].operator Map()); @@ -74,7 +76,7 @@ TVM_REGISTER_API("ir_pass.Substitute") } }); -TVM_REGISTER_API("ir_pass.Equal") +TVM_REGISTER_GLOBAL("ir_pass.Equal") .set_body([](TVMArgs args, TVMRetValue *ret) { if (args[0].IsObjectRef()) { *ret = Equal(args[0].operator Stmt(), args[1].operator Stmt()); @@ -83,7 +85,7 @@ TVM_REGISTER_API("ir_pass.Equal") } }); -TVM_REGISTER_API("ir_pass.StorageFlatten") +TVM_REGISTER_GLOBAL("ir_pass.StorageFlatten") .set_body([](TVMArgs args, TVMRetValue *ret) { if (args.size() <= 3) { *ret = StorageFlatten(args[0], args[1], args[2]); @@ -92,30 +94,30 @@ TVM_REGISTER_API("ir_pass.StorageFlatten") } }); -TVM_REGISTER_API("ir_pass.RewriteForTensorCore") +TVM_REGISTER_GLOBAL("ir_pass.RewriteForTensorCore") .set_body_typed&)> ([](const Stmt& stmt, const Schedule& schedule, const Map& extern_buffer) { return RewriteForTensorCore(stmt, schedule, extern_buffer); }); -TVM_REGISTER_API("ir_pass.AttrsEqual") +TVM_REGISTER_GLOBAL("ir_pass.AttrsEqual") .set_body_typed( [](const ObjectRef& lhs, const ObjectRef& rhs) { return AttrsEqual()(lhs, rhs); }); -TVM_REGISTER_API("ir_pass.AttrsHash") +TVM_REGISTER_GLOBAL("ir_pass.AttrsHash") .set_body_typed([](const ObjectRef &node) { return AttrsHash()(node); }); -TVM_REGISTER_API("ir_pass.ExprUseVar") +TVM_REGISTER_GLOBAL("ir_pass.ExprUseVar") .set_body([](TVMArgs args, TVMRetValue *ret) { *ret = ExprUseVar(args[0].operator Expr(), args[1].operator Var()); }); -TVM_REGISTER_API("ir_pass.PostOrderVisit") +TVM_REGISTER_GLOBAL("ir_pass.PostOrderVisit") .set_body([](TVMArgs args, TVMRetValue *ret) { PackedFunc f = args[1]; ir::PostOrderVisit(args[0], [f](const ObjectRef& n) { @@ -123,7 +125,7 @@ TVM_REGISTER_API("ir_pass.PostOrderVisit") }); }); -TVM_REGISTER_API("ir_pass.LowerStorageAccess") +TVM_REGISTER_GLOBAL("ir_pass.LowerStorageAccess") .set_body([](TVMArgs args, TVMRetValue *ret) { LoweredFunc f = args[0]; auto n = make_object(*f.operator->()); @@ -133,7 +135,7 @@ TVM_REGISTER_API("ir_pass.LowerStorageAccess") // make from two arguments #define REGISTER_PASS(PassName) \ - TVM_REGISTER_API("ir_pass."#PassName) \ + TVM_REGISTER_GLOBAL("ir_pass."#PassName) \ .set_body_typed(PassName); \ diff --git a/src/api/api_schedule.cc b/src/api/api_schedule.cc index cf0e0f3..a7c27e4 100644 --- a/src/api/api_schedule.cc +++ b/src/api/api_schedule.cc @@ -25,20 +25,22 @@ #include #include #include -#include +#include +#include + #include "../schedule/graph.h" namespace tvm { namespace schedule { -TVM_REGISTER_API("schedule.AutoInlineElemWise") +TVM_REGISTER_GLOBAL("schedule.AutoInlineElemWise") .set_body_typed(AutoInlineElemWise); -TVM_REGISTER_API("schedule.AutoInlineInjective") +TVM_REGISTER_GLOBAL("schedule.AutoInlineInjective") .set_body_typed(AutoInlineInjective); -TVM_REGISTER_API("schedule.ScheduleOps") +TVM_REGISTER_GLOBAL("schedule.ScheduleOps") .set_body([](TVMArgs args, TVMRetValue* ret) { if (args.size() == 2) *ret = ScheduleOps(args[0], args[1], false); @@ -47,7 +49,7 @@ TVM_REGISTER_API("schedule.ScheduleOps") }); #define REGISTER_SCHEDULE_PASS(PassName) \ - TVM_REGISTER_API("schedule."#PassName) \ + TVM_REGISTER_GLOBAL("schedule."#PassName) \ .set_body_typed(PassName); \ diff --git a/src/api/api_test.cc b/src/api/api_test.cc index 3900b56..d57a4e9 100644 --- a/src/api/api_test.cc +++ b/src/api/api_test.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -24,7 +24,10 @@ #include #include #include -#include +#include +#include +#include + namespace tvm { // Attrs used to python API @@ -53,11 +56,11 @@ struct TestAttrs : public AttrsNode { TVM_REGISTER_NODE_TYPE(TestAttrs); -TVM_REGISTER_API("_nop") +TVM_REGISTER_GLOBAL("_nop") .set_body([](TVMArgs args, TVMRetValue *ret) { }); -TVM_REGISTER_API("_test_wrap_callback") +TVM_REGISTER_GLOBAL("_test_wrap_callback") .set_body([](TVMArgs args, TVMRetValue *ret) { PackedFunc pf = args[0]; *ret = runtime::TypedPackedFunc([pf](){ @@ -65,7 +68,7 @@ TVM_REGISTER_API("_test_wrap_callback") }); }); -TVM_REGISTER_API("_test_raise_error_callback") +TVM_REGISTER_GLOBAL("_test_raise_error_callback") .set_body([](TVMArgs args, TVMRetValue *ret) { std::string msg = args[0]; *ret = runtime::TypedPackedFunc([msg](){ @@ -73,7 +76,7 @@ TVM_REGISTER_API("_test_raise_error_callback") }); }); -TVM_REGISTER_API("_test_check_eq_callback") +TVM_REGISTER_GLOBAL("_test_check_eq_callback") .set_body([](TVMArgs args, TVMRetValue *ret) { std::string msg = args[0]; *ret = runtime::TypedPackedFunc([msg](int x, int y){ @@ -81,7 +84,7 @@ TVM_REGISTER_API("_test_check_eq_callback") }); }); -TVM_REGISTER_API("_context_test") +TVM_REGISTER_GLOBAL("_context_test") .set_body([](TVMArgs args, TVMRetValue *ret) { DLContext ctx = args[0]; int dtype = args[1]; @@ -102,11 +105,11 @@ void ErrorTest(int x, int y) { } } -TVM_REGISTER_API("_ErrorTest") +TVM_REGISTER_GLOBAL("_ErrorTest") .set_body_typed(ErrorTest); // internal function used for debug and testing purposes -TVM_REGISTER_API("_ndarray_use_count") +TVM_REGISTER_GLOBAL("_ndarray_use_count") .set_body([](TVMArgs args, TVMRetValue *ret) { runtime::NDArray nd = args[0]; // substract the current one diff --git a/src/arithmetic/bound_deducer.cc b/src/arithmetic/bound_deducer.cc index 6f98017..bb2e340 100644 --- a/src/arithmetic/bound_deducer.cc +++ b/src/arithmetic/bound_deducer.cc @@ -25,7 +25,9 @@ #include #include #include -#include +#include +#include + #include #include diff --git a/src/arithmetic/domain_touched.cc b/src/arithmetic/domain_touched.cc index bdd5daa..02f3578 100644 --- a/src/arithmetic/domain_touched.cc +++ b/src/arithmetic/domain_touched.cc @@ -25,7 +25,9 @@ #include #include #include -#include +#include +#include + #include #include diff --git a/src/arithmetic/int_set.cc b/src/arithmetic/int_set.cc index 79b3974..bdfcc1a 100644 --- a/src/arithmetic/int_set.cc +++ b/src/arithmetic/int_set.cc @@ -23,7 +23,9 @@ */ #include #include -#include +#include +#include + #include #include #include @@ -47,7 +49,7 @@ IntervalSet MakeIntervalSet(Expr min_value, Expr max_value) { return IntervalSet(min_value, max_value); } -TVM_REGISTER_API("arith._make_IntervalSet") +TVM_REGISTER_GLOBAL("arith._make_IntervalSet") .set_body_typed(MakeIntervalSet); diff --git a/src/autotvm/touch_extractor.cc b/src/autotvm/touch_extractor.cc index 31c8035..51b1354 100644 --- a/src/autotvm/touch_extractor.cc +++ b/src/autotvm/touch_extractor.cc @@ -485,7 +485,7 @@ void GetCurveSampleFeatureFlatten(Stmt stmt, int sample_n, std::vector *r // register API for front end -TVM_REGISTER_API("autotvm.feature.GetItervarFeature") +TVM_REGISTER_GLOBAL("autotvm.feature.GetItervarFeature") .set_body([](TVMArgs args, TVMRetValue *ret) { Stmt stmt = args[0]; bool take_log = args[1]; @@ -497,7 +497,7 @@ TVM_REGISTER_API("autotvm.feature.GetItervarFeature") }); -TVM_REGISTER_API("autotvm.feature.GetItervarFeatureFlatten") +TVM_REGISTER_GLOBAL("autotvm.feature.GetItervarFeatureFlatten") .set_body([](TVMArgs args, TVMRetValue *ret) { Stmt stmt = args[0]; bool take_log = args[1]; @@ -512,7 +512,7 @@ TVM_REGISTER_API("autotvm.feature.GetItervarFeatureFlatten") }); -TVM_REGISTER_API("autotvm.feature.GetCurveSampleFeatureFlatten") +TVM_REGISTER_GLOBAL("autotvm.feature.GetCurveSampleFeatureFlatten") .set_body([](TVMArgs args, TVMRetValue *ret) { Stmt stmt = args[0]; int sample_n = args[1]; diff --git a/src/autotvm/touch_extractor.h b/src/autotvm/touch_extractor.h index b456e4b..2bcf6b8 100644 --- a/src/autotvm/touch_extractor.h +++ b/src/autotvm/touch_extractor.h @@ -27,7 +27,9 @@ #include #include -#include +#include +#include + #include #include #include diff --git a/src/codegen/build_common.h b/src/codegen/build_common.h index b2c8953..47f70d9 100644 --- a/src/codegen/build_common.h +++ b/src/codegen/build_common.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -25,6 +25,7 @@ #define TVM_CODEGEN_BUILD_COMMON_H_ #include +#include #include #include #include diff --git a/src/codegen/build_module.cc b/src/codegen/build_module.cc index 3ea2cb7..eab220e 100644 --- a/src/codegen/build_module.cc +++ b/src/codegen/build_module.cc @@ -26,6 +26,7 @@ #include #include #include +#include #include #include @@ -33,6 +34,10 @@ namespace tvm { +using runtime::TVMArgs; +using runtime::TVMRetValue; +using runtime::PackedFunc; + TVM_REGISTER_NODE_TYPE(TargetNode); TVM_REGISTER_NODE_TYPE(GenericFuncNode); @@ -142,7 +147,7 @@ Target CreateTarget(const std::string& target_name, return Target(t); } -TVM_REGISTER_API("_TargetCreate") +TVM_REGISTER_GLOBAL("_TargetCreate") .set_body([](TVMArgs args, TVMRetValue* ret) { std::string target_name = args[0]; std::vector options; @@ -154,7 +159,7 @@ TVM_REGISTER_API("_TargetCreate") *ret = CreateTarget(target_name, options); }); -TVM_REGISTER_API("_TargetFromString") +TVM_REGISTER_GLOBAL("_TargetFromString") .set_body([](TVMArgs args, TVMRetValue* ret) { std::string target_str = args[0]; *ret = Target::Create(target_str); @@ -768,7 +773,7 @@ void GenericFunc::CallPacked(TVMArgs args, TVMRetValue* ret) const { func.CallPacked(args, ret); } -TVM_REGISTER_API("_GetCurrentBuildConfig") +TVM_REGISTER_GLOBAL("_GetCurrentBuildConfig") .set_body([](TVMArgs args, TVMRetValue* ret) { *ret = BuildConfig::Current(); }); @@ -783,13 +788,13 @@ class BuildConfig::Internal { } }; -TVM_REGISTER_API("_EnterBuildConfigScope") +TVM_REGISTER_GLOBAL("_EnterBuildConfigScope") .set_body_typed(BuildConfig::Internal::EnterScope); -TVM_REGISTER_API("_ExitBuildConfigScope") +TVM_REGISTER_GLOBAL("_ExitBuildConfigScope") .set_body_typed(BuildConfig::Internal::ExitScope); -TVM_REGISTER_API("_BuildConfigSetAddLowerPass") +TVM_REGISTER_GLOBAL("_BuildConfigSetAddLowerPass") .set_body([](TVMArgs args, TVMRetValue* ret) { BuildConfig cfg = args[0]; std::vector< std::pair > add_lower_pass; @@ -802,7 +807,7 @@ TVM_REGISTER_API("_BuildConfigSetAddLowerPass") cfg->add_lower_pass = add_lower_pass; }); -TVM_REGISTER_API("_BuildConfigGetAddLowerPassInfo") +TVM_REGISTER_GLOBAL("_BuildConfigGetAddLowerPassInfo") .set_body([](TVMArgs args, TVMRetValue* ret) { // Return one of the following: // * Size of add_lower_pass if num_args == 1 @@ -823,18 +828,18 @@ TVM_REGISTER_API("_BuildConfigGetAddLowerPassInfo") } }); -TVM_REGISTER_API("_GenericFuncCreate") +TVM_REGISTER_GLOBAL("_GenericFuncCreate") .set_body([](TVMArgs args, TVMRetValue* ret) { *ret = GenericFunc(make_object()); }); -TVM_REGISTER_API("_GenericFuncGetGlobal") +TVM_REGISTER_GLOBAL("_GenericFuncGetGlobal") .set_body([](TVMArgs args, TVMRetValue* ret) { std::string func_name = args[0]; *ret = GenericFunc::Get(func_name); }); -TVM_REGISTER_API("_GenericFuncSetDefault") +TVM_REGISTER_GLOBAL("_GenericFuncSetDefault") .set_body([](TVMArgs args, TVMRetValue* ret) { GenericFunc generic_func = args[0]; // Intentionally copy and not de-allocate it, to avoid free pyobject during shutdown @@ -845,7 +850,7 @@ TVM_REGISTER_API("_GenericFuncSetDefault") .set_default(*func, allow_override); }); -TVM_REGISTER_API("_GenericFuncRegisterFunc") +TVM_REGISTER_GLOBAL("_GenericFuncRegisterFunc") .set_body([](TVMArgs args, TVMRetValue* ret) { GenericFunc generic_func = args[0]; // Intentionally copy and not de-allocate it, to avoid free pyobject during shutdown @@ -862,7 +867,7 @@ TVM_REGISTER_API("_GenericFuncRegisterFunc") .register_func(tags_vector, *func, allow_override); }); -TVM_REGISTER_API("_GenericFuncCallFunc") +TVM_REGISTER_GLOBAL("_GenericFuncCallFunc") .set_body([](TVMArgs args, TVMRetValue* ret) { GenericFunc generic_func = args[0]; TVMArgs func_args(&args.values[1], &args.type_codes[1], args.num_args - 1); @@ -871,7 +876,7 @@ TVM_REGISTER_API("_GenericFuncCallFunc") .CallPacked(func_args, ret); }); -TVM_REGISTER_API("_GetCurrentTarget") +TVM_REGISTER_GLOBAL("_GetCurrentTarget") .set_body([](TVMArgs args, TVMRetValue* ret) { bool allow_not_defined = args[0]; *ret = Target::Current(allow_not_defined); @@ -887,10 +892,10 @@ class Target::Internal { } }; -TVM_REGISTER_API("_EnterTargetScope") +TVM_REGISTER_GLOBAL("_EnterTargetScope") .set_body_typed(Target::Internal::EnterScope); -TVM_REGISTER_API("_ExitTargetScope") +TVM_REGISTER_GLOBAL("_ExitTargetScope") .set_body_typed(Target::Internal::ExitScope); } // namespace tvm diff --git a/src/codegen/codegen_aocl.cc b/src/codegen/codegen_aocl.cc index 625682b..ea3677e 100644 --- a/src/codegen/codegen_aocl.cc +++ b/src/codegen/codegen_aocl.cc @@ -71,12 +71,12 @@ runtime::Module BuildAOCL(Array funcs, std::string target_str, return AOCLModuleCreate(aocxbin, "aocx", ExtractFuncInfo(funcs), code); } -TVM_REGISTER_API("codegen.build_aocl") +TVM_REGISTER_GLOBAL("codegen.build_aocl") .set_body([](TVMArgs args, TVMRetValue* rv) { *rv = BuildAOCL(args[0], args[1], false); }); -TVM_REGISTER_API("codegen.build_aocl_sw_emu") +TVM_REGISTER_GLOBAL("codegen.build_aocl_sw_emu") .set_body([](TVMArgs args, TVMRetValue* rv) { *rv = BuildAOCL(args[0], args[1], true); }); diff --git a/src/codegen/codegen_c_host.cc b/src/codegen/codegen_c_host.cc index f2c54c2..5066182 100644 --- a/src/codegen/codegen_c_host.cc +++ b/src/codegen/codegen_c_host.cc @@ -290,7 +290,7 @@ runtime::Module BuildCHost(Array funcs) { return CSourceModuleCreate(code, "c"); } -TVM_REGISTER_API("codegen.build_c") +TVM_REGISTER_GLOBAL("codegen.build_c") .set_body([](TVMArgs args, TVMRetValue* rv) { *rv = BuildCHost(args[0]); }); diff --git a/src/codegen/codegen_metal.cc b/src/codegen/codegen_metal.cc index f4ff014..b239578 100644 --- a/src/codegen/codegen_metal.cc +++ b/src/codegen/codegen_metal.cc @@ -277,7 +277,7 @@ runtime::Module BuildMetal(Array funcs) { return MetalModuleCreate(code, fmt, ExtractFuncInfo(funcs), source); } -TVM_REGISTER_API("codegen.build_metal") +TVM_REGISTER_GLOBAL("codegen.build_metal") .set_body([](TVMArgs args, TVMRetValue* rv) { *rv = BuildMetal(args[0]); }); diff --git a/src/codegen/codegen_opencl.cc b/src/codegen/codegen_opencl.cc index ae43419..e466e28 100644 --- a/src/codegen/codegen_opencl.cc +++ b/src/codegen/codegen_opencl.cc @@ -275,7 +275,7 @@ runtime::Module BuildOpenCL(Array funcs) { return OpenCLModuleCreate(code, "cl", ExtractFuncInfo(funcs), code); } -TVM_REGISTER_API("codegen.build_opencl") +TVM_REGISTER_GLOBAL("codegen.build_opencl") .set_body_typed(BuildOpenCL); } // namespace codegen } // namespace tvm diff --git a/src/codegen/codegen_opengl.cc b/src/codegen/codegen_opengl.cc index db14be3..29fcf85 100644 --- a/src/codegen/codegen_opengl.cc +++ b/src/codegen/codegen_opengl.cc @@ -300,7 +300,7 @@ runtime::Module BuildOpenGL(Array funcs) { return OpenGLModuleCreate(shaders, "gl", ExtractFuncInfo(funcs)); } -TVM_REGISTER_API("codegen.build_opengl") +TVM_REGISTER_GLOBAL("codegen.build_opengl") .set_body_typed(BuildOpenGL); } // namespace codegen diff --git a/src/codegen/codegen_vhls.cc b/src/codegen/codegen_vhls.cc index 40550d9..d12e54d 100644 --- a/src/codegen/codegen_vhls.cc +++ b/src/codegen/codegen_vhls.cc @@ -162,7 +162,7 @@ runtime::Module BuildSDAccel(Array funcs, std::string target_str) { return SDAccelModuleCreate(xclbin, "xclbin", ExtractFuncInfo(funcs), whole_code); } -TVM_REGISTER_API("codegen.build_sdaccel") +TVM_REGISTER_GLOBAL("codegen.build_sdaccel") .set_body_typed(BuildSDAccel); } // namespace codegen diff --git a/src/codegen/datatype/registry.cc b/src/codegen/datatype/registry.cc index 28cc582..62d36c4 100644 --- a/src/codegen/datatype/registry.cc +++ b/src/codegen/datatype/registry.cc @@ -18,7 +18,9 @@ */ #include "registry.h" -#include +#include +#include + namespace tvm { namespace datatype { diff --git a/src/codegen/intrin_rule.h b/src/codegen/intrin_rule.h index 581387d..f64887e 100644 --- a/src/codegen/intrin_rule.h +++ b/src/codegen/intrin_rule.h @@ -26,7 +26,9 @@ #include #include -#include +#include +#include + #include #include diff --git a/src/codegen/llvm/codegen_amdgpu.cc b/src/codegen/llvm/codegen_amdgpu.cc index f57a3ca..a2b3685 100644 --- a/src/codegen/llvm/codegen_amdgpu.cc +++ b/src/codegen/llvm/codegen_amdgpu.cc @@ -311,7 +311,7 @@ runtime::Module BuildAMDGPU(Array funcs, std::string target) { return ROCMModuleCreate(hsaco, "hsaco", ExtractFuncInfo(funcs), ll, assembly); } -TVM_REGISTER_API("codegen.build_rocm") +TVM_REGISTER_GLOBAL("codegen.build_rocm") .set_body_typed(BuildAMDGPU); } // namespace codegen diff --git a/src/codegen/llvm/codegen_nvptx.cc b/src/codegen/llvm/codegen_nvptx.cc index 372408c..a0caf65 100644 --- a/src/codegen/llvm/codegen_nvptx.cc +++ b/src/codegen/llvm/codegen_nvptx.cc @@ -253,7 +253,7 @@ runtime::Module BuildNVPTX(Array funcs, std::string target) { return CUDAModuleCreate(ptx, "ptx", ExtractFuncInfo(funcs), ll); } -TVM_REGISTER_API("codegen.build_nvptx") +TVM_REGISTER_GLOBAL("codegen.build_nvptx") .set_body_typed(BuildNVPTX); } // namespace codegen diff --git a/src/codegen/llvm/intrin_rule_llvm.h b/src/codegen/llvm/intrin_rule_llvm.h index 7863a3d..0d65576 100644 --- a/src/codegen/llvm/intrin_rule_llvm.h +++ b/src/codegen/llvm/intrin_rule_llvm.h @@ -26,7 +26,9 @@ #ifdef TVM_LLVM_VERSION #include -#include +#include +#include + #include #include #include "llvm_common.h" diff --git a/src/codegen/llvm/intrin_rule_nvptx.cc b/src/codegen/llvm/intrin_rule_nvptx.cc index 862d06b..2f0e5c5 100644 --- a/src/codegen/llvm/intrin_rule_nvptx.cc +++ b/src/codegen/llvm/intrin_rule_nvptx.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -24,7 +24,9 @@ #include #include -#include +#include +#include + #include namespace tvm { diff --git a/src/codegen/llvm/intrin_rule_rocm.cc b/src/codegen/llvm/intrin_rule_rocm.cc index 22b3245..380f9a9 100644 --- a/src/codegen/llvm/intrin_rule_rocm.cc +++ b/src/codegen/llvm/intrin_rule_rocm.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -24,7 +24,9 @@ #include #include -#include +#include +#include + #include namespace tvm { diff --git a/src/codegen/llvm/llvm_module.cc b/src/codegen/llvm/llvm_module.cc index d874b46..e042081 100644 --- a/src/codegen/llvm/llvm_module.cc +++ b/src/codegen/llvm/llvm_module.cc @@ -329,33 +329,33 @@ unsigned LookupLLVMIntrinsic(const std::string& name) { return llvm::Function::lookupIntrinsicID(name); } -TVM_REGISTER_API("codegen.llvm_lookup_intrinsic_id") +TVM_REGISTER_GLOBAL("codegen.llvm_lookup_intrinsic_id") .set_body([](TVMArgs args, TVMRetValue* rv) { *rv = static_cast(LookupLLVMIntrinsic(args[0])); }); -TVM_REGISTER_API("codegen.build_llvm") +TVM_REGISTER_GLOBAL("codegen.build_llvm") .set_body([](TVMArgs args, TVMRetValue* rv) { auto n = make_object(); n->Init(args[0], args[1]); *rv = runtime::Module(n); }); -TVM_REGISTER_API("codegen.llvm_version_major") +TVM_REGISTER_GLOBAL("codegen.llvm_version_major") .set_body([](TVMArgs args, TVMRetValue* rv) { std::ostringstream os; int major = TVM_LLVM_VERSION / 10; *rv = major; }); -TVM_REGISTER_API("module.loadfile_ll") +TVM_REGISTER_GLOBAL("module.loadfile_ll") .set_body([](TVMArgs args, TVMRetValue* rv) { auto n = make_object(); n->LoadIR(args[0]); *rv = runtime::Module(n); }); -TVM_REGISTER_API("codegen.llvm_target_enabled") +TVM_REGISTER_GLOBAL("codegen.llvm_target_enabled") .set_body([](TVMArgs args, TVMRetValue* rv) { InitializeLLVM(); *rv = (GetLLVMTargetMachine(args[0], true) != nullptr); diff --git a/src/codegen/opt/build_cuda_on.cc b/src/codegen/opt/build_cuda_on.cc index 1992ac5..b5f42bf 100644 --- a/src/codegen/opt/build_cuda_on.cc +++ b/src/codegen/opt/build_cuda_on.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -154,7 +154,7 @@ runtime::Module BuildCUDA(Array funcs) { return CUDAModuleCreate(ptx, fmt, ExtractFuncInfo(funcs), code); } -TVM_REGISTER_API("codegen.build_cuda") +TVM_REGISTER_GLOBAL("codegen.build_cuda") .set_body_typed(BuildCUDA); } // namespace codegen } // namespace tvm diff --git a/src/codegen/spirv/build_vulkan.cc b/src/codegen/spirv/build_vulkan.cc index fb66d89..6c90e1d 100644 --- a/src/codegen/spirv/build_vulkan.cc +++ b/src/codegen/spirv/build_vulkan.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -103,7 +103,7 @@ runtime::Module BuildSPIRV(Array funcs) { smap, ExtractFuncInfo(funcs), code_data.str()); } -TVM_REGISTER_API("codegen.build_vulkan") +TVM_REGISTER_GLOBAL("codegen.build_vulkan") .set_body_typed(BuildSPIRV); } // namespace codegen diff --git a/src/codegen/stackvm/codegen_stackvm.cc b/src/codegen/stackvm/codegen_stackvm.cc index 52cabaf..9482b2c 100644 --- a/src/codegen/stackvm/codegen_stackvm.cc +++ b/src/codegen/stackvm/codegen_stackvm.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -520,7 +520,7 @@ runtime::Module BuildStackVM(const Array& funcs) { return runtime::StackVMModuleCreate(fmap, funcs[0]->name); } -TVM_REGISTER_API("codegen.build_stackvm") +TVM_REGISTER_GLOBAL("codegen.build_stackvm") .set_body_typed(BuildStackVM); } // namespace codegen } // namespace tvm diff --git a/src/contrib/hybrid/codegen_hybrid.cc b/src/contrib/hybrid/codegen_hybrid.cc index c723a22..beda99d 100644 --- a/src/contrib/hybrid/codegen_hybrid.cc +++ b/src/contrib/hybrid/codegen_hybrid.cc @@ -27,6 +27,9 @@ namespace tvm { namespace contrib { +using runtime::TVMArgs; +using runtime::TVMRetValue; + using namespace ir; std::string dot_to_underscore(std::string s) { diff --git a/src/lang/attrs.cc b/src/lang/attrs.cc index 1c341d5..fd28268 100644 --- a/src/lang/attrs.cc +++ b/src/lang/attrs.cc @@ -21,7 +21,9 @@ * \file attrs.cc */ #include -#include +#include +#include + #include "attr_functor.h" namespace tvm { @@ -345,7 +347,7 @@ bool DictAttrsNode::ContentEqual(const Object* other, AttrsEqual equal) const { return equal(this->dict, static_cast(other)->dict); } -TVM_REGISTER_API("_AttrsListFieldInfo") +TVM_REGISTER_GLOBAL("_AttrsListFieldInfo") .set_body([](TVMArgs args, TVMRetValue* ret) { *ret = args[0].operator Attrs()->ListFieldInfo(); }); diff --git a/src/lang/api_registry.cc b/src/node/env_func.cc similarity index 87% rename from src/lang/api_registry.cc rename to src/node/env_func.cc index 68d42a2..52bb61d 100644 --- a/src/lang/api_registry.cc +++ b/src/node/env_func.cc @@ -18,12 +18,19 @@ */ /*! - * \file api_registry.cc + * \file env_func.cc */ -#include +#include +#include +#include namespace tvm { +using runtime::PackedFunc; +using runtime::TVMArgs; +using runtime::TVMRetValue; + + TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) .set_dispatch([](const ObjectRef& node, IRPrinter *p) { auto* op = static_cast(node.get()); @@ -43,10 +50,10 @@ EnvFunc EnvFunc::Get(const std::string& name) { return EnvFunc(CreateEnvNode(name)); } -TVM_REGISTER_API("_EnvFuncGet") +TVM_REGISTER_GLOBAL("_EnvFuncGet") .set_body_typed(EnvFunc::Get); -TVM_REGISTER_API("_EnvFuncCall") +TVM_REGISTER_GLOBAL("_EnvFuncCall") .set_body([](TVMArgs args, TVMRetValue* rv) { EnvFunc env = args[0]; CHECK_GE(args.size(), 1); @@ -55,7 +62,7 @@ TVM_REGISTER_API("_EnvFuncCall") args.size() - 1), rv); }); -TVM_REGISTER_API("_EnvFuncGetPackedFunc") +TVM_REGISTER_GLOBAL("_EnvFuncGetPackedFunc") .set_body_typed([](const EnvFunc&n) { return n->func; }); diff --git a/src/op/tensorize.cc b/src/op/tensorize.cc index dba9ca0..f6fa00d 100644 --- a/src/op/tensorize.cc +++ b/src/op/tensorize.cc @@ -24,7 +24,9 @@ #include #include #include -#include +#include +#include + #include "op_util.h" #include "compute_op.h" #include "../schedule/message_passing.h" @@ -496,7 +498,7 @@ Stmt MakeTensorize(const ComputeOpNode* self, } // Register functions for unittests -TVM_REGISTER_API("test.op.InferTensorizeRegion") +TVM_REGISTER_GLOBAL("test.op.InferTensorizeRegion") .set_body([](TVMArgs args, TVMRetValue* ret) { Stage stage = args[0]; Map dmap = args[1]; @@ -511,7 +513,7 @@ TVM_REGISTER_API("test.op.InferTensorizeRegion") Map >(in_region)}; }); -TVM_REGISTER_API("test.op.MatchTensorizeBody") +TVM_REGISTER_GLOBAL("test.op.MatchTensorizeBody") .set_body([](TVMArgs args, TVMRetValue* ret) { Stage stage = args[0]; Map out_dom = args[1]; diff --git a/src/pass/hoist_if_then_else.cc b/src/pass/hoist_if_then_else.cc index 7a50113..5748e9f 100644 --- a/src/pass/hoist_if_then_else.cc +++ b/src/pass/hoist_if_then_else.cc @@ -23,7 +23,9 @@ #include #include #include -#include +#include +#include + #include #include #include diff --git a/src/pass/lower_intrin.cc b/src/pass/lower_intrin.cc index bbd6c35..0f49710 100644 --- a/src/pass/lower_intrin.cc +++ b/src/pass/lower_intrin.cc @@ -23,7 +23,9 @@ */ #include #include -#include +#include +#include + #include #include #include "ir_util.h" @@ -292,7 +294,7 @@ LowerIntrin(LoweredFunc f, const std::string& target) { } // Register the api only for test purposes -TVM_REGISTER_API("ir_pass._LowerIntrinStmt") +TVM_REGISTER_GLOBAL("ir_pass._LowerIntrinStmt") .set_body_typed(LowerIntrinStmt); } // namespace ir diff --git a/src/pass/verify_gpu_code.cc b/src/pass/verify_gpu_code.cc index 1adc685..08ec413 100644 --- a/src/pass/verify_gpu_code.cc +++ b/src/pass/verify_gpu_code.cc @@ -24,7 +24,9 @@ * in a block exceeds the limit */ -#include +#include +#include + #include #include diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc index 68a3bed..ae993e9 100644 --- a/src/relay/backend/compile_engine.cc +++ b/src/relay/backend/compile_engine.cc @@ -76,7 +76,7 @@ bool IsDynamic(const Type& ty) { } // TODO(@jroesch): MOVE ME -TVM_REGISTER_API("relay._make.IsDynamic") +TVM_REGISTER_GLOBAL("relay._make.IsDynamic") .set_body_typed(IsDynamic); Array GetShape(const Array& shape) { diff --git a/src/relay/backend/contrib/codegen_c/codegen.cc b/src/relay/backend/contrib/codegen_c/codegen.cc index 84fada0..642dbb0 100644 --- a/src/relay/backend/contrib/codegen_c/codegen.cc +++ b/src/relay/backend/contrib/codegen_c/codegen.cc @@ -219,7 +219,7 @@ runtime::Module CCompiler(const ObjectRef& ref) { return csource.CreateCSourceModule(ref); } -TVM_REGISTER_API("relay.ext.ccompiler").set_body_typed(CCompiler); +TVM_REGISTER_GLOBAL("relay.ext.ccompiler").set_body_typed(CCompiler); } // namespace contrib } // namespace relay diff --git a/src/relay/backend/contrib/dnnl/codegen.cc b/src/relay/backend/contrib/dnnl/codegen.cc index 675198f..4c0fe34 100644 --- a/src/relay/backend/contrib/dnnl/codegen.cc +++ b/src/relay/backend/contrib/dnnl/codegen.cc @@ -303,7 +303,7 @@ runtime::Module DNNLCompiler(const ObjectRef& ref) { return dnnl.CreateCSourceModule(ref); } -TVM_REGISTER_API("relay.ext.dnnl").set_body_typed(DNNLCompiler); +TVM_REGISTER_GLOBAL("relay.ext.dnnl").set_body_typed(DNNLCompiler); } // namespace contrib } // namespace relay diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc index b477784..203fbfa 100644 --- a/src/relay/backend/interpreter.cc +++ b/src/relay/backend/interpreter.cc @@ -51,7 +51,7 @@ Closure ClosureNode::make(tvm::Map env, Function func) { return Closure(n); } -TVM_REGISTER_API("relay._make.Closure") +TVM_REGISTER_GLOBAL("relay._make.Closure") .set_body_typed(ClosureNode::make); TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) @@ -70,7 +70,7 @@ RecClosure RecClosureNode::make(Closure clos, Var bind) { return RecClosure(n); } -TVM_REGISTER_API("relay._make.RecClosure") +TVM_REGISTER_GLOBAL("relay._make.RecClosure") .set_body_typed(RecClosureNode::make); TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) @@ -85,7 +85,7 @@ TupleValue TupleValueNode::make(tvm::Array value) { return TupleValue(n); } -TVM_REGISTER_API("relay._make.TupleValue") +TVM_REGISTER_GLOBAL("relay._make.TupleValue") .set_body_typed(TupleValueNode::make); TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) @@ -108,7 +108,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) p->stream << "TensorValueNode(" << data_str << ")"; }); -TVM_REGISTER_API("relay._make.TensorValue") +TVM_REGISTER_GLOBAL("relay._make.TensorValue") .set_body_typed(TensorValueNode::make); RefValue RefValueNode::make(Value value) { @@ -117,7 +117,7 @@ RefValue RefValueNode::make(Value value) { return RefValue(n); } -TVM_REGISTER_API("relay._make.RefValue") +TVM_REGISTER_GLOBAL("relay._make.RefValue") .set_body_typed(RefValueNode::make); TVM_REGISTER_NODE_TYPE(RefValueNode); @@ -138,7 +138,7 @@ ConstructorValue ConstructorValueNode::make(int32_t tag, return ConstructorValue(n); } -TVM_REGISTER_API("relay._make.ConstructorValue") +TVM_REGISTER_GLOBAL("relay._make.ConstructorValue") .set_body_typed(ConstructorValueNode::make); TVM_REGISTER_NODE_TYPE(ConstructorValueNode); @@ -817,7 +817,7 @@ CreateInterpreter( return TypedPackedFunc(packed); } -TVM_REGISTER_API("relay.backend.CreateInterpreter") +TVM_REGISTER_GLOBAL("relay.backend.CreateInterpreter") .set_body_typed(CreateInterpreter); TVM_REGISTER_NODE_TYPE(ClosureNode); diff --git a/src/relay/backend/vm/inline_primitives.cc b/src/relay/backend/vm/inline_primitives.cc index f94f837..25b0735 100644 --- a/src/relay/backend/vm/inline_primitives.cc +++ b/src/relay/backend/vm/inline_primitives.cc @@ -142,7 +142,7 @@ Pass InlinePrimitives() { return Sequential({inline_pass, DeadCodeElimination()}, "InlinePrimitives"); } -TVM_REGISTER_API("relay._transform.InlinePrimitives") +TVM_REGISTER_GLOBAL("relay._transform.InlinePrimitives") .set_body_typed(InlinePrimitives); } // namespace transform diff --git a/src/relay/backend/vm/lambda_lift.cc b/src/relay/backend/vm/lambda_lift.cc index 7298c50..b6cb1aa 100644 --- a/src/relay/backend/vm/lambda_lift.cc +++ b/src/relay/backend/vm/lambda_lift.cc @@ -217,7 +217,7 @@ Pass LambdaLift() { return CreateModulePass(pass_func, 1, "LambdaLift", {}); } -TVM_REGISTER_API("relay._transform.LambdaLift") +TVM_REGISTER_GLOBAL("relay._transform.LambdaLift") .set_body_typed(LambdaLift); } // namespace transform diff --git a/src/relay/backend/vm/removed_unused_funcs.cc b/src/relay/backend/vm/removed_unused_funcs.cc index 546f1d3..c6fe490 100644 --- a/src/relay/backend/vm/removed_unused_funcs.cc +++ b/src/relay/backend/vm/removed_unused_funcs.cc @@ -129,7 +129,7 @@ Pass RemoveUnusedFunctions(Array entry_functions) { return CreateModulePass(pass_func, 1, "RemoveUnusedFunctions", {}); } -TVM_REGISTER_API("relay._transform.RemoveUnusedFunctions") +TVM_REGISTER_GLOBAL("relay._transform.RemoveUnusedFunctions") .set_body_typed(RemoveUnusedFunctions); } // namespace transform diff --git a/src/relay/ir/adt.cc b/src/relay/ir/adt.cc index 7317287..ff47789 100644 --- a/src/relay/ir/adt.cc +++ b/src/relay/ir/adt.cc @@ -34,7 +34,7 @@ PatternWildcard PatternWildcardNode::make() { TVM_REGISTER_NODE_TYPE(PatternWildcardNode); -TVM_REGISTER_API("relay._make.PatternWildcard") +TVM_REGISTER_GLOBAL("relay._make.PatternWildcard") .set_body_typed(PatternWildcardNode::make); TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) @@ -50,7 +50,7 @@ PatternVar PatternVarNode::make(tvm::relay::Var var) { TVM_REGISTER_NODE_TYPE(PatternVarNode); -TVM_REGISTER_API("relay._make.PatternVar") +TVM_REGISTER_GLOBAL("relay._make.PatternVar") .set_body_typed(PatternVarNode::make); TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) @@ -69,7 +69,7 @@ PatternConstructor PatternConstructorNode::make(Constructor constructor, TVM_REGISTER_NODE_TYPE(PatternConstructorNode); -TVM_REGISTER_API("relay._make.PatternConstructor") +TVM_REGISTER_GLOBAL("relay._make.PatternConstructor") .set_body_typed(PatternConstructorNode::make); TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) @@ -87,7 +87,7 @@ PatternTuple PatternTupleNode::make(tvm::Array patterns) { TVM_REGISTER_NODE_TYPE(PatternTupleNode); -TVM_REGISTER_API("relay._make.PatternTuple") +TVM_REGISTER_GLOBAL("relay._make.PatternTuple") .set_body_typed(PatternTupleNode::make); TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) @@ -108,7 +108,7 @@ Constructor ConstructorNode::make(std::string name_hint, TVM_REGISTER_NODE_TYPE(ConstructorNode); -TVM_REGISTER_API("relay._make.Constructor") +TVM_REGISTER_GLOBAL("relay._make.Constructor") .set_body_typed(ConstructorNode::make); TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) @@ -130,7 +130,7 @@ TypeData TypeDataNode::make(GlobalTypeVar header, TVM_REGISTER_NODE_TYPE(TypeDataNode); -TVM_REGISTER_API("relay._make.TypeData") +TVM_REGISTER_GLOBAL("relay._make.TypeData") .set_body_typed(TypeDataNode::make); TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) @@ -149,7 +149,7 @@ Clause ClauseNode::make(Pattern lhs, Expr rhs) { TVM_REGISTER_NODE_TYPE(ClauseNode); -TVM_REGISTER_API("relay._make.Clause") +TVM_REGISTER_GLOBAL("relay._make.Clause") .set_body_typed(ClauseNode::make); TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) @@ -169,7 +169,7 @@ Match MatchNode::make(Expr data, tvm::Array clauses, bool complete) { TVM_REGISTER_NODE_TYPE(MatchNode); -TVM_REGISTER_API("relay._make.Match") +TVM_REGISTER_GLOBAL("relay._make.Match") .set_body_typed(MatchNode::make); TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) diff --git a/src/relay/ir/alpha_equal.cc b/src/relay/ir/alpha_equal.cc index d8dcddd..7fe39db 100644 --- a/src/relay/ir/alpha_equal.cc +++ b/src/relay/ir/alpha_equal.cc @@ -594,23 +594,23 @@ bool AlphaEqual(const Expr& lhs, const Expr& rhs) { } // TODO(@jroesch): move to correct namespace? -TVM_REGISTER_API("relay._make._alpha_equal") +TVM_REGISTER_GLOBAL("relay._make._alpha_equal") .set_body_typed([](ObjectRef a, ObjectRef b) { return AlphaEqualHandler(false, false).Equal(a, b); }); -TVM_REGISTER_API("relay._make._assert_alpha_equal") +TVM_REGISTER_GLOBAL("relay._make._assert_alpha_equal") .set_body_typed([](ObjectRef a, ObjectRef b) { bool alpha_equal = AlphaEqualHandler(false, true).Equal(a, b); CHECK(alpha_equal) << AsText(a, true) << " and " << AsText(b, true) << " are not alpha equal"; }); -TVM_REGISTER_API("relay._make._graph_equal") +TVM_REGISTER_GLOBAL("relay._make._graph_equal") .set_body_typed([](ObjectRef a, ObjectRef b) { return AlphaEqualHandler(true, false).Equal(a, b); }); -TVM_REGISTER_API("relay._make._assert_graph_equal") +TVM_REGISTER_GLOBAL("relay._make._assert_graph_equal") .set_body_typed([](ObjectRef a, ObjectRef b) { bool graph_equal = AlphaEqualHandler(true, true).Equal(a, b); CHECK(graph_equal) << AsText(a, true) << " and " << AsText(b, true) << " are not graph equal"; diff --git a/src/relay/ir/base.cc b/src/relay/ir/base.cc index 3f98d87..176ee08 100644 --- a/src/relay/ir/base.cc +++ b/src/relay/ir/base.cc @@ -21,8 +21,10 @@ * \file base.cc * \brief The core base types for Relay. */ -#include + #include +#include +#include #include namespace tvm { @@ -32,7 +34,7 @@ using namespace tvm::runtime; TVM_REGISTER_NODE_TYPE(IdNode); -TVM_REGISTER_API("relay._base.set_span") +TVM_REGISTER_GLOBAL("relay._base.set_span") .set_body_typed([](ObjectRef node_ref, Span sp) { if (auto* rn = node_ref.as()) { CHECK(rn); diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc index 66e083d..11689b0 100644 --- a/src/relay/ir/expr.cc +++ b/src/relay/ir/expr.cc @@ -37,7 +37,7 @@ Constant ConstantNode::make(runtime::NDArray data) { TVM_REGISTER_NODE_TYPE(ConstantNode); -TVM_REGISTER_API("relay._make.Constant") +TVM_REGISTER_GLOBAL("relay._make.Constant") .set_body_typed(ConstantNode::make); TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) @@ -70,7 +70,7 @@ Tuple TupleNode::make(tvm::Array fields) { TVM_REGISTER_NODE_TYPE(TupleNode); -TVM_REGISTER_API("relay._make.Tuple") +TVM_REGISTER_GLOBAL("relay._make.Tuple") .set_body_typed(TupleNode::make); TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) @@ -95,7 +95,7 @@ Var VarNode::make(std::string name_hint, Type type_annotation) { TVM_REGISTER_NODE_TYPE(VarNode); -TVM_REGISTER_API("relay._make.Var") +TVM_REGISTER_GLOBAL("relay._make.Var") .set_body_typed(static_cast(VarNode::make)); TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) @@ -117,7 +117,7 @@ GlobalVar GlobalVarNode::make(std::string name_hint) { TVM_REGISTER_NODE_TYPE(GlobalVarNode); -TVM_REGISTER_API("relay._make.GlobalVar") +TVM_REGISTER_GLOBAL("relay._make.GlobalVar") .set_body_typed(GlobalVarNode::make); TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) @@ -166,7 +166,7 @@ Function FunctionNode::SetParams(const tvm::Map& parameters) cons return FunctionSetAttr(GetRef(this), attr::kParams, parameters); } -TVM_REGISTER_API("relay._expr.FunctionSetParams") +TVM_REGISTER_GLOBAL("relay._expr.FunctionSetParams") .set_body_typed&)>( [](const Function& func, const tvm::Map& parameters) { return func->SetParams(parameters); @@ -177,7 +177,7 @@ tvm::Map FunctionNode::GetParams() const { return Downcast>(node_ref); } -TVM_REGISTER_API("relay._expr.FunctionGetParams") +TVM_REGISTER_GLOBAL("relay._expr.FunctionGetParams") .set_body_typed(const Function&)>([](const Function& func) { return func->GetParams(); }); @@ -223,7 +223,7 @@ Function FunctionSetAttr(const Function& func, const std::string& key, const Obj TVM_REGISTER_NODE_TYPE(FunctionNode); -TVM_REGISTER_API("relay._make.Function") +TVM_REGISTER_GLOBAL("relay._make.Function") .set_body_typed(FunctionNode::make); TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) @@ -246,7 +246,7 @@ Call CallNode::make(Expr op, Array args, Attrs attrs, TVM_REGISTER_NODE_TYPE(CallNode); -TVM_REGISTER_API("relay._make.Call") +TVM_REGISTER_GLOBAL("relay._make.Call") .set_body_typed(CallNode::make); TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) @@ -266,7 +266,7 @@ Let LetNode::make(Var var, Expr value, Expr body) { TVM_REGISTER_NODE_TYPE(LetNode); -TVM_REGISTER_API("relay._make.Let") +TVM_REGISTER_GLOBAL("relay._make.Let") .set_body_typed(LetNode::make); TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) @@ -286,7 +286,7 @@ If IfNode::make(Expr cond, Expr true_branch, Expr false_branch) { TVM_REGISTER_NODE_TYPE(IfNode); -TVM_REGISTER_API("relay._make.If") +TVM_REGISTER_GLOBAL("relay._make.If") .set_body_typed(IfNode::make); TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) @@ -305,7 +305,7 @@ TupleGetItem TupleGetItemNode::make(Expr tuple, int index) { TVM_REGISTER_NODE_TYPE(TupleGetItemNode); -TVM_REGISTER_API("relay._make.TupleGetItem") +TVM_REGISTER_GLOBAL("relay._make.TupleGetItem") .set_body_typed(TupleGetItemNode::make); TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) @@ -322,7 +322,7 @@ RefCreate RefCreateNode::make(Expr value) { TVM_REGISTER_NODE_TYPE(RefCreateNode); -TVM_REGISTER_API("relay._make.RefCreate") +TVM_REGISTER_GLOBAL("relay._make.RefCreate") .set_body_typed(RefCreateNode::make); TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) @@ -339,7 +339,7 @@ RefRead RefReadNode::make(Expr ref) { TVM_REGISTER_NODE_TYPE(RefReadNode); -TVM_REGISTER_API("relay._make.RefRead") +TVM_REGISTER_GLOBAL("relay._make.RefRead") .set_body_typed(RefReadNode::make); TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) @@ -357,7 +357,7 @@ RefWrite RefWriteNode::make(Expr ref, Expr value) { TVM_REGISTER_NODE_TYPE(RefWriteNode); -TVM_REGISTER_API("relay._make.RefWrite") +TVM_REGISTER_GLOBAL("relay._make.RefWrite") .set_body_typed(RefWriteNode::make); TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) @@ -366,12 +366,12 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) p->stream << "RefWriteNode(" << node->ref << ", " << node->value << ")"; }); -TVM_REGISTER_API("relay._expr.TempExprRealize") +TVM_REGISTER_GLOBAL("relay._expr.TempExprRealize") .set_body_typed([](TempExpr temp) { return temp->Realize(); }); -TVM_REGISTER_API("relay._expr.FunctionSetAttr") +TVM_REGISTER_GLOBAL("relay._expr.FunctionSetAttr") .set_body_typed( [](Function func, std::string name, ObjectRef ref) { return FunctionSetAttr(func, name, ref); diff --git a/src/relay/ir/expr_functor.cc b/src/relay/ir/expr_functor.cc index e3846c9..d6e4d41 100644 --- a/src/relay/ir/expr_functor.cc +++ b/src/relay/ir/expr_functor.cc @@ -347,7 +347,7 @@ void PostOrderVisit(const Expr& e, std::function fvisit) { ExprApplyVisit(fvisit).VisitExpr(e); } -TVM_REGISTER_API("relay._analysis.post_order_visit") +TVM_REGISTER_GLOBAL("relay._analysis.post_order_visit") .set_body_typed([](Expr expr, PackedFunc f) { PostOrderVisit(expr, [f](const Expr& n) { f(n); @@ -443,7 +443,7 @@ Expr Bind(const Expr& expr, const tvm::Map& args_map) { } } -TVM_REGISTER_API("relay._expr.Bind") +TVM_REGISTER_GLOBAL("relay._expr.Bind") .set_body([](TVMArgs args, TVMRetValue* ret) { ObjectRef input = args[0]; if (input->IsInstance()) { diff --git a/src/relay/ir/hash.cc b/src/relay/ir/hash.cc index 6199c54..b940666 100644 --- a/src/relay/ir/hash.cc +++ b/src/relay/ir/hash.cc @@ -423,12 +423,12 @@ size_t StructuralHash::operator()(const Expr& expr) const { return RelayHashHandler().ExprHash(expr); } -TVM_REGISTER_API("relay._analysis._expr_hash") +TVM_REGISTER_GLOBAL("relay._analysis._expr_hash") .set_body_typed([](ObjectRef ref) { return static_cast(RelayHashHandler().Hash(ref)); }); -TVM_REGISTER_API("relay._analysis._type_hash") +TVM_REGISTER_GLOBAL("relay._analysis._type_hash") .set_body_typed([](Type type) { return static_cast(RelayHashHandler().TypeHash(type)); }); diff --git a/src/relay/ir/module.cc b/src/relay/ir/module.cc index 9f371dd..4e57258 100644 --- a/src/relay/ir/module.cc +++ b/src/relay/ir/module.cc @@ -317,13 +317,13 @@ Module FromText(const std::string& source, const std::string& source_name) { TVM_REGISTER_NODE_TYPE(ModuleNode); -TVM_REGISTER_API("relay._make.Module") +TVM_REGISTER_GLOBAL("relay._make.Module") .set_body_typed, tvm::Map)>( [](tvm::Map funcs, tvm::Map types) { return ModuleNode::make(funcs, types, {}); }); -TVM_REGISTER_API("relay._module.Module_Add") +TVM_REGISTER_GLOBAL("relay._module.Module_Add") .set_body([](TVMArgs args, TVMRetValue* ret) { Module mod = args[0]; GlobalVar var = args[1]; @@ -346,50 +346,50 @@ TVM_REGISTER_API("relay._module.Module_Add") *ret = mod; }); -TVM_REGISTER_API("relay._module.Module_AddDef") +TVM_REGISTER_GLOBAL("relay._module.Module_AddDef") .set_body_method(&ModuleNode::AddDef); -TVM_REGISTER_API("relay._module.Module_GetGlobalVar") +TVM_REGISTER_GLOBAL("relay._module.Module_GetGlobalVar") .set_body_method(&ModuleNode::GetGlobalVar); -TVM_REGISTER_API("relay._module.Module_GetGlobalVars") +TVM_REGISTER_GLOBAL("relay._module.Module_GetGlobalVars") .set_body_method(&ModuleNode::GetGlobalVars); -TVM_REGISTER_API("relay._module.Module_GetGlobalTypeVars") +TVM_REGISTER_GLOBAL("relay._module.Module_GetGlobalTypeVars") .set_body_method(&ModuleNode::GetGlobalTypeVars); -TVM_REGISTER_API("relay._module.Module_ContainGlobalVar") +TVM_REGISTER_GLOBAL("relay._module.Module_ContainGlobalVar") .set_body_method(&ModuleNode::ContainGlobalVar); -TVM_REGISTER_API("relay._module.Module_GetGlobalTypeVar") +TVM_REGISTER_GLOBAL("relay._module.Module_GetGlobalTypeVar") .set_body_method(&ModuleNode::GetGlobalTypeVar); -TVM_REGISTER_API("relay._module.Module_Lookup") +TVM_REGISTER_GLOBAL("relay._module.Module_Lookup") .set_body_typed([](Module mod, GlobalVar var) { return mod->Lookup(var); }); -TVM_REGISTER_API("relay._module.Module_Lookup_str") +TVM_REGISTER_GLOBAL("relay._module.Module_Lookup_str") .set_body_typed([](Module mod, std::string var) { return mod->Lookup(var); }); -TVM_REGISTER_API("relay._module.Module_LookupDef") +TVM_REGISTER_GLOBAL("relay._module.Module_LookupDef") .set_body_typed([](Module mod, GlobalTypeVar var) { return mod->LookupDef(var); }); -TVM_REGISTER_API("relay._module.Module_LookupDef_str") +TVM_REGISTER_GLOBAL("relay._module.Module_LookupDef_str") .set_body_typed([](Module mod, std::string var) { return mod->LookupDef(var); }); -TVM_REGISTER_API("relay._module.Module_LookupTag") +TVM_REGISTER_GLOBAL("relay._module.Module_LookupTag") .set_body_typed([](Module mod, int32_t tag) { return mod->LookupTag(tag); }); -TVM_REGISTER_API("relay._module.Module_FromExpr") +TVM_REGISTER_GLOBAL("relay._module.Module_FromExpr") .set_body_typed< Module(Expr, tvm::Map, @@ -399,17 +399,17 @@ TVM_REGISTER_API("relay._module.Module_FromExpr") return ModuleNode::FromExpr(e, funcs, type_defs); }); -TVM_REGISTER_API("relay._module.Module_Update") +TVM_REGISTER_GLOBAL("relay._module.Module_Update") .set_body_typed([](Module mod, Module from) { mod->Update(from); }); -TVM_REGISTER_API("relay._module.Module_Import") +TVM_REGISTER_GLOBAL("relay._module.Module_Import") .set_body_typed([](Module mod, std::string path) { mod->Import(path); }); -TVM_REGISTER_API("relay._module.Module_ImportFromStd") +TVM_REGISTER_GLOBAL("relay._module.Module_ImportFromStd") .set_body_typed([](Module mod, std::string path) { mod->ImportFromStd(path); });; diff --git a/src/relay/ir/op.cc b/src/relay/ir/op.cc index 05788b1..4bef724 100644 --- a/src/relay/ir/op.cc +++ b/src/relay/ir/op.cc @@ -135,7 +135,7 @@ void OpRegistry::UpdateAttr(const std::string& key, } // Frontend APIs -TVM_REGISTER_API("relay.op._ListOpNames") +TVM_REGISTER_GLOBAL("relay.op._ListOpNames") .set_body_typed()>([]() { Array ret; for (const std::string& name : @@ -145,9 +145,9 @@ TVM_REGISTER_API("relay.op._ListOpNames") return ret; }); -TVM_REGISTER_API("relay.op._GetOp").set_body_typed(Op::Get); +TVM_REGISTER_GLOBAL("relay.op._GetOp").set_body_typed(Op::Get); -TVM_REGISTER_API("relay.op._OpGetAttr") +TVM_REGISTER_GLOBAL("relay.op._OpGetAttr") .set_body([](TVMArgs args, TVMRetValue* rv) { Op op = args[0]; std::string attr_name = args[1]; @@ -157,7 +157,7 @@ TVM_REGISTER_API("relay.op._OpGetAttr") } }); -TVM_REGISTER_API("relay.op._OpSetAttr") +TVM_REGISTER_GLOBAL("relay.op._OpSetAttr") .set_body([](TVMArgs args, TVMRetValue* rv) { Op op = args[0]; std::string attr_name = args[1]; @@ -168,7 +168,7 @@ TVM_REGISTER_API("relay.op._OpSetAttr") reg.set_attr(attr_name, value, plevel); }); -TVM_REGISTER_API("relay.op._OpResetAttr") +TVM_REGISTER_GLOBAL("relay.op._OpResetAttr") .set_body([](TVMArgs args, TVMRetValue* rv) { Op op = args[0]; std::string attr_name = args[1]; @@ -177,7 +177,7 @@ TVM_REGISTER_API("relay.op._OpResetAttr") reg.reset_attr(attr_name); }); -TVM_REGISTER_API("relay.op._Register") +TVM_REGISTER_GLOBAL("relay.op._Register") .set_body([](TVMArgs args, TVMRetValue* rv) { std::string op_name = args[0]; std::string attr_key = args[1]; diff --git a/src/relay/ir/pretty_printer.cc b/src/relay/ir/pretty_printer.cc index 9926844..612d586 100644 --- a/src/relay/ir/pretty_printer.cc +++ b/src/relay/ir/pretty_printer.cc @@ -990,7 +990,7 @@ std::string AsText(const ObjectRef& node, return PrettyPrint_(node, show_meta_data, annotate); } -TVM_REGISTER_API("relay._expr.AsText") +TVM_REGISTER_GLOBAL("relay._expr.AsText") .set_body_typed)>(AsText); diff --git a/src/relay/ir/type.cc b/src/relay/ir/type.cc index f1efddf..aa9d376 100644 --- a/src/relay/ir/type.cc +++ b/src/relay/ir/type.cc @@ -54,7 +54,7 @@ IndexExpr TensorTypeNode::Size() const { TVM_REGISTER_NODE_TYPE(TensorTypeNode); -TVM_REGISTER_API("relay._make.TensorType") +TVM_REGISTER_GLOBAL("relay._make.TensorType") .set_body_typed(TensorTypeNode::make); TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) @@ -72,7 +72,7 @@ TypeCall TypeCallNode::make(Type func, tvm::Array args) { TVM_REGISTER_NODE_TYPE(TypeCallNode); -TVM_REGISTER_API("relay._make.TypeCall") +TVM_REGISTER_GLOBAL("relay._make.TypeCall") .set_body_typed(TypeCallNode::make); TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) @@ -90,7 +90,7 @@ IncompleteType IncompleteTypeNode::make(Kind kind) { TVM_REGISTER_NODE_TYPE(IncompleteTypeNode); -TVM_REGISTER_API("relay._make.IncompleteType") +TVM_REGISTER_GLOBAL("relay._make.IncompleteType") .set_body_typed([](int kind) { return IncompleteTypeNode::make(static_cast(kind)); }); @@ -115,7 +115,7 @@ TypeRelation TypeRelationNode::make(TypeRelationFn func, TVM_REGISTER_NODE_TYPE(TypeRelationNode); -TVM_REGISTER_API("relay._make.TypeRelation") +TVM_REGISTER_GLOBAL("relay._make.TypeRelation") .set_body_typed(TypeRelationNode::make); TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) @@ -134,7 +134,7 @@ TupleType TupleTypeNode::make(Array fields) { TVM_REGISTER_NODE_TYPE(TupleTypeNode); -TVM_REGISTER_API("relay._make.TupleType") +TVM_REGISTER_GLOBAL("relay._make.TupleType") .set_body_typed(TupleTypeNode::make); TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) @@ -149,7 +149,7 @@ RefType RefTypeNode::make(Type value) { return RefType(n); } -TVM_REGISTER_API("relay._make.RefType") +TVM_REGISTER_GLOBAL("relay._make.RefType") .set_body_typed(RefTypeNode::make); TVM_REGISTER_NODE_TYPE(RefTypeNode); @@ -160,7 +160,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) p->stream << "RefTypeNode(" << node->value << ")"; }); -TVM_REGISTER_API("relay._make.Any") +TVM_REGISTER_GLOBAL("relay._make.Any") .set_body_typed([]() { return Any::make(); }); diff --git a/src/relay/op/algorithm/argsort.cc b/src/relay/op/algorithm/argsort.cc index 7a58cfd..0d68b44 100644 --- a/src/relay/op/algorithm/argsort.cc +++ b/src/relay/op/algorithm/argsort.cc @@ -60,7 +60,7 @@ Expr MakeArgsort(Expr data, } -TVM_REGISTER_API("relay.op._make.argsort") +TVM_REGISTER_GLOBAL("relay.op._make.argsort") .set_body_typed(MakeArgsort); RELAY_REGISTER_OP("argsort") diff --git a/src/relay/op/algorithm/topk.cc b/src/relay/op/algorithm/topk.cc index 055d65b..161ca1c 100644 --- a/src/relay/op/algorithm/topk.cc +++ b/src/relay/op/algorithm/topk.cc @@ -83,7 +83,7 @@ Expr MakeTopK(Expr data, } -TVM_REGISTER_API("relay.op._make.topk") +TVM_REGISTER_GLOBAL("relay.op._make.topk") .set_body_typed(MakeTopK); RELAY_REGISTER_OP("topk") diff --git a/src/relay/op/annotation/annotation.cc b/src/relay/op/annotation/annotation.cc index 9234591..253af5b 100644 --- a/src/relay/op/annotation/annotation.cc +++ b/src/relay/op/annotation/annotation.cc @@ -39,7 +39,7 @@ namespace relay { // relay.annotation.on_device TVM_REGISTER_NODE_TYPE(OnDeviceAttrs); -TVM_REGISTER_API("relay.op.annotation._make.on_device") +TVM_REGISTER_GLOBAL("relay.op.annotation._make.on_device") .set_body_typed([](Expr data, int device_type) { auto attrs = make_object(); attrs->device_type = device_type; @@ -62,7 +62,7 @@ Expr StopFusion(Expr data) { return CallNode::make(op, {data}, Attrs{}, {}); } -TVM_REGISTER_API("relay.op.annotation._make.stop_fusion") +TVM_REGISTER_GLOBAL("relay.op.annotation._make.stop_fusion") .set_body_typed([](Expr data) { return StopFusion(data); }); @@ -144,7 +144,7 @@ Mark the end of bitpacking. return {topi::identity(inputs[0])}; }); -TVM_REGISTER_API("relay.op.annotation._make.checkpoint") +TVM_REGISTER_GLOBAL("relay.op.annotation._make.checkpoint") .set_body_typed([](Expr data) { static const Op& op = Op::Get("annotation.checkpoint"); return CallNode::make(op, {data}, Attrs{}, {}); diff --git a/src/relay/op/debug.cc b/src/relay/op/debug.cc index f592d3e..cdfdac0 100644 --- a/src/relay/op/debug.cc +++ b/src/relay/op/debug.cc @@ -65,7 +65,7 @@ Expr MakeDebug(Expr expr, std::string name) { return CallNode::make(op, {expr}, Attrs(dattrs), {}); } -TVM_REGISTER_API("relay.op._make.debug") +TVM_REGISTER_GLOBAL("relay.op._make.debug") .set_body_typed(MakeDebug); } // namespace relay diff --git a/src/relay/op/device_copy.cc b/src/relay/op/device_copy.cc index 290ccef..9c3f6af 100644 --- a/src/relay/op/device_copy.cc +++ b/src/relay/op/device_copy.cc @@ -41,7 +41,7 @@ namespace relay { // relay.device_copy TVM_REGISTER_NODE_TYPE(DeviceCopyAttrs); -TVM_REGISTER_API("relay.op._make.device_copy") +TVM_REGISTER_GLOBAL("relay.op._make.device_copy") .set_body_typed([](Expr data, int src_dev_type, int dst_dev_type) { auto attrs = make_object(); diff --git a/src/relay/op/image/resize.cc b/src/relay/op/image/resize.cc index baab0ea..b7169b1 100644 --- a/src/relay/op/image/resize.cc +++ b/src/relay/op/image/resize.cc @@ -84,7 +84,7 @@ Expr MakeResize(Expr data, } -TVM_REGISTER_API("relay.op.image._make.resize") +TVM_REGISTER_GLOBAL("relay.op.image._make.resize") .set_body_typed(MakeResize); diff --git a/src/relay/op/memory/memory.cc b/src/relay/op/memory/memory.cc index 72edeac..af7291d 100644 --- a/src/relay/op/memory/memory.cc +++ b/src/relay/op/memory/memory.cc @@ -41,7 +41,7 @@ TVM_REGISTER_NODE_TYPE(ShapeFuncAttrs); // The passing value in attrs and args doesn't seem super great. // We should consider a better solution, i.e the type relation // being able to see the arguments as well? -TVM_REGISTER_API("relay.op.memory._make.alloc_storage") +TVM_REGISTER_GLOBAL("relay.op.memory._make.alloc_storage") .set_body_typed([](Expr size, Expr alignment, DataType dtype) { auto attrs = make_object(); attrs->dtype = dtype; @@ -87,7 +87,7 @@ RELAY_REGISTER_OP("memory.alloc_storage") return {topi::identity(inputs[0])}; }); -TVM_REGISTER_API("relay.op.memory._make.alloc_tensor") +TVM_REGISTER_GLOBAL("relay.op.memory._make.alloc_tensor") .set_body_typed assert_shape)>( [](Expr storage, tvm::relay::Expr shape, DataType dtype, Array assert_shape) { auto attrs = make_object(); @@ -208,7 +208,7 @@ bool InvokeTVMOPRel(const Array& types, int num_inputs, const Attrs& attrs return true; } -TVM_REGISTER_API("relay.op.memory._make.invoke_tvm_op") +TVM_REGISTER_GLOBAL("relay.op.memory._make.invoke_tvm_op") .set_body_typed( [](Expr func, Expr inputs, Expr outputs) { return CallNode::make(Op::Get("memory.invoke_tvm_op"), {func, inputs, outputs}, Attrs()); @@ -256,7 +256,7 @@ RELAY_REGISTER_OP("memory.kill") return {topi::identity(inputs[0])}; }); -TVM_REGISTER_API("relay.op.memory._make.shape_func") +TVM_REGISTER_GLOBAL("relay.op.memory._make.shape_func") .set_body_typed)>( [](Expr func, Expr inputs, Expr outputs, Array is_input) { static const Op& op = Op::Get("memory.shape_func"); diff --git a/src/relay/op/nn/bitserial.cc b/src/relay/op/nn/bitserial.cc index 973ee0b..09c060d 100644 --- a/src/relay/op/nn/bitserial.cc +++ b/src/relay/op/nn/bitserial.cc @@ -96,7 +96,7 @@ Expr MakeBitPack(Expr data, int bits, int pack_axis, int bit_axis, DataType pack return CallNode::make(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_API("relay.op.nn._make.bitpack").set_body_typed(MakeBitPack); +TVM_REGISTER_GLOBAL("relay.op.nn._make.bitpack").set_body_typed(MakeBitPack); RELAY_REGISTER_OP("nn.bitpack") .describe(R"code(Bitpack layer that prepares data for bitserial operations. @@ -167,7 +167,7 @@ Expr MakeBinaryConv2D(Expr data, Expr weight, Array strides, Array& types, return true; } -TVM_REGISTER_API("relay.op.nn._make.fifo_buffer") +TVM_REGISTER_GLOBAL("relay.op.nn._make.fifo_buffer") .set_body_typed(MakeFIFOBuffer); RELAY_REGISTER_OP("nn.fifo_buffer") @@ -183,7 +183,7 @@ Expr MakeDense(Expr data, } -TVM_REGISTER_API("relay.op.nn._make.dense") +TVM_REGISTER_GLOBAL("relay.op.nn._make.dense") .set_body_typed(MakeDense); @@ -215,7 +215,7 @@ Expr MakeLeakyRelu(Expr data, } -TVM_REGISTER_API("relay.op.nn._make.leaky_relu") +TVM_REGISTER_GLOBAL("relay.op.nn._make.leaky_relu") .set_body_typed(MakeLeakyRelu); @@ -295,7 +295,7 @@ Expr MakePRelu(Expr data, } -TVM_REGISTER_API("relay.op.nn._make.prelu") +TVM_REGISTER_GLOBAL("relay.op.nn._make.prelu") .set_body_typed(MakePRelu); @@ -325,7 +325,7 @@ where :math:`*` is an channelwise multiplication for each sample in the batch. // relay.softmax TVM_REGISTER_NODE_TYPE(SoftmaxAttrs); -TVM_REGISTER_API("relay.op.nn._make.softmax") +TVM_REGISTER_GLOBAL("relay.op.nn._make.softmax") .set_body_typed([](Expr data, int axis) { auto attrs = make_object(); attrs->axis = axis; @@ -360,7 +360,7 @@ RELAY_REGISTER_OP("nn.softmax") // relay.nn.log_softmax -TVM_REGISTER_API("relay.op.nn._make.log_softmax") +TVM_REGISTER_GLOBAL("relay.op.nn._make.log_softmax") .set_body_typed([](Expr data, int axis) { auto attrs = make_object(); attrs->axis = axis; @@ -429,7 +429,7 @@ Expr MakeBatchFlatten(Expr data) { } -TVM_REGISTER_API("relay.op.nn._make.batch_flatten") +TVM_REGISTER_GLOBAL("relay.op.nn._make.batch_flatten") .set_body_typed(MakeBatchFlatten); @@ -469,7 +469,7 @@ Example:: // relu -TVM_REGISTER_API("relay.op.nn._make.relu") +TVM_REGISTER_GLOBAL("relay.op.nn._make.relu") .set_body_typed([](Expr data) { static const Op& op = Op::Get("nn.relu"); return CallNode::make(op, {data}, Attrs(), {}); @@ -514,7 +514,7 @@ Expr MakeLRN(Expr data, return CallNode::make(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_API("relay.op.nn._make.lrn") +TVM_REGISTER_GLOBAL("relay.op.nn._make.lrn") .set_body_typed(MakeLRN); RELAY_REGISTER_OP("nn.lrn") @@ -552,7 +552,7 @@ Expr MakeL2Normalize(Expr data, return CallNode::make(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_API("relay.op.nn._make.l2_normalize") +TVM_REGISTER_GLOBAL("relay.op.nn._make.l2_normalize") .set_body_typed(MakeL2Normalize); RELAY_REGISTER_OP("nn.l2_normalize") @@ -597,7 +597,7 @@ Expr MakeDropout(Expr data, double rate) { return CallNode::make(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_API("relay.op.nn._make.dropout") +TVM_REGISTER_GLOBAL("relay.op.nn._make.dropout") .set_body_typed(MakeDropout); RELAY_REGISTER_OP("nn.dropout") @@ -689,7 +689,7 @@ Expr MakeBatchNorm(Expr data, Expr gamma, Expr beta, Expr moving_mean, Expr movi return CallNode::make(op, {data, gamma, beta, moving_mean, moving_var}, Attrs(attrs), {}); } -TVM_REGISTER_API("relay.op.nn._make.batch_norm") +TVM_REGISTER_GLOBAL("relay.op.nn._make.batch_norm") .set_body_typed(MakeBatchNorm); RELAY_REGISTER_OP("nn.batch_norm") @@ -772,7 +772,7 @@ Expr MakeInstanceNorm(Expr data, Expr gamma, Expr beta, int axis, double epsilon return CallNode::make(op, {data, gamma, beta}, Attrs(attrs), {}); } -TVM_REGISTER_API("relay.op.nn._make.instance_norm") +TVM_REGISTER_GLOBAL("relay.op.nn._make.instance_norm") .set_body([](const TVMArgs& args, TVMRetValue* rv) { runtime::detail::unpack_call(MakeInstanceNorm, args, rv); }); @@ -842,7 +842,7 @@ Expr MakeLayerNorm(Expr data, Expr gamma, Expr beta, int axis, double epsilon, return CallNode::make(op, {data, gamma, beta}, Attrs(attrs), {}); } -TVM_REGISTER_API("relay.op.nn._make.layer_norm") +TVM_REGISTER_GLOBAL("relay.op.nn._make.layer_norm") .set_body([](const TVMArgs& args, TVMRetValue* rv) { runtime::detail::unpack_call(MakeLayerNorm, args, rv); }); @@ -894,7 +894,7 @@ Expr MakeBatchMatmul(Expr x, } -TVM_REGISTER_API("relay.op.nn._make.batch_matmul") +TVM_REGISTER_GLOBAL("relay.op.nn._make.batch_matmul") .set_body_typed(MakeBatchMatmul); @@ -951,7 +951,7 @@ Expr MakeCrossEntropy(Expr predictions, Expr targets) { } -TVM_REGISTER_API("relay.op.nn._make.cross_entropy") +TVM_REGISTER_GLOBAL("relay.op.nn._make.cross_entropy") .set_body_typed(MakeCrossEntropy); @@ -974,7 +974,7 @@ Expr MakeCrossEntropyWithLogits(Expr predictions, Expr targets) { } -TVM_REGISTER_API("relay.op.nn._make.cross_entropy_with_logits") +TVM_REGISTER_GLOBAL("relay.op.nn._make.cross_entropy_with_logits") .set_body_typed(MakeCrossEntropyWithLogits); @@ -1032,7 +1032,7 @@ Expr MakeDepthToSpace(Expr data, int block_size, std::string layout, std::string return CallNode::make(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_API("relay.op.nn._make.depth_to_space").set_body_typed(MakeDepthToSpace); +TVM_REGISTER_GLOBAL("relay.op.nn._make.depth_to_space").set_body_typed(MakeDepthToSpace); RELAY_REGISTER_OP("nn.depth_to_space") .describe(R"code(Rearrange input channels into spatial pixels. @@ -1089,7 +1089,7 @@ Expr MakeSpaceToDepth(Expr data, int block_size, std::string layout) { return CallNode::make(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_API("relay.op.nn._make.space_to_depth").set_body_typed(MakeSpaceToDepth); +TVM_REGISTER_GLOBAL("relay.op.nn._make.space_to_depth").set_body_typed(MakeSpaceToDepth); RELAY_REGISTER_OP("nn.space_to_depth") .describe(R"code(Rearrange spatial pixels into new output channels. diff --git a/src/relay/op/nn/pad.cc b/src/relay/op/nn/pad.cc index 5cde414..0d5810f 100644 --- a/src/relay/op/nn/pad.cc +++ b/src/relay/op/nn/pad.cc @@ -200,7 +200,7 @@ Expr MakePad(Expr data, return CallNode::make(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_API("relay.op.nn._make.pad") +TVM_REGISTER_GLOBAL("relay.op.nn._make.pad") .set_body_typed(MakePad); RELAY_REGISTER_OP("nn.pad") @@ -274,7 +274,7 @@ Expr MakeMirrorPad(Expr data, Array > pad_width, std::string mo return CallNode::make(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_API("relay.op.nn._make.mirror_pad") +TVM_REGISTER_GLOBAL("relay.op.nn._make.mirror_pad") .set_body_typed(MakeMirrorPad); RELAY_REGISTER_OP("nn.mirror_pad") diff --git a/src/relay/op/nn/pooling.cc b/src/relay/op/nn/pooling.cc index 0021690..88d4306 100644 --- a/src/relay/op/nn/pooling.cc +++ b/src/relay/op/nn/pooling.cc @@ -213,7 +213,7 @@ Array Pool2DCompute(const Attrs& attrs, } } -TVM_REGISTER_API("relay.op.nn._make.max_pool2d") +TVM_REGISTER_GLOBAL("relay.op.nn._make.max_pool2d") .set_body_typed, Array, Array, std::string, bool)>([](Expr data, Array pool_size, @@ -257,7 +257,7 @@ RELAY_REGISTER_OP("nn.max_pool2d") // AvgPool2D -TVM_REGISTER_API("relay.op.nn._make.avg_pool2d") +TVM_REGISTER_GLOBAL("relay.op.nn._make.avg_pool2d") .set_body_typed, Array, Array, std::string, bool, bool)>([](Expr data, Array pool_size, @@ -366,7 +366,7 @@ Expr MakeGlobalAvgPool2D(Expr data, } -TVM_REGISTER_API("relay.op.nn._make.global_avg_pool2d") +TVM_REGISTER_GLOBAL("relay.op.nn._make.global_avg_pool2d") .set_body_typed(MakeGlobalAvgPool2D); // GlobalAvgPool @@ -397,7 +397,7 @@ Expr MakeGlobalMaxPool2D(Expr data, return CallNode::make(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_API("relay.op.nn._make.global_max_pool2d") +TVM_REGISTER_GLOBAL("relay.op.nn._make.global_max_pool2d") .set_body_typed(MakeGlobalMaxPool2D); @@ -518,7 +518,7 @@ Expr MakeAdaptiveAvgPool2D(Expr data, return CallNode::make(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_API("relay.op.contrib._make.adaptive_avg_pool2d") +TVM_REGISTER_GLOBAL("relay.op.contrib._make.adaptive_avg_pool2d") .set_body_typed(MakeAdaptiveAvgPool2D); RELAY_REGISTER_OP("contrib.adaptive_avg_pool2d") @@ -557,7 +557,7 @@ Expr MakeAdaptiveMaxPool2D(Expr data, return CallNode::make(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_API("relay.op.contrib._make.adaptive_max_pool2d") +TVM_REGISTER_GLOBAL("relay.op.contrib._make.adaptive_max_pool2d") .set_body_typed(MakeAdaptiveMaxPool2D); RELAY_REGISTER_OP("contrib.adaptive_max_pool2d") @@ -657,7 +657,7 @@ Expr MakeMaxPool2DGrad(Expr out_grad, Expr data, Array pool_size, return CallNode::make(op, {out_grad, data}, Attrs(attrs), {}); } -TVM_REGISTER_API("relay.op.nn._make.max_pool2d_grad").set_body_typed(MakeMaxPool2DGrad); +TVM_REGISTER_GLOBAL("relay.op.nn._make.max_pool2d_grad").set_body_typed(MakeMaxPool2DGrad); RELAY_REGISTER_OP("nn.max_pool2d_grad") @@ -706,7 +706,7 @@ Expr MakeAvgPool2DGrad(Expr out_grad, Expr data, Array pool_size, return CallNode::make(op, {out_grad, data}, Attrs(attrs), {}); } -TVM_REGISTER_API("relay.op.nn._make.avg_pool2d_grad").set_body_typed(MakeAvgPool2DGrad); +TVM_REGISTER_GLOBAL("relay.op.nn._make.avg_pool2d_grad").set_body_typed(MakeAvgPool2DGrad); RELAY_REGISTER_OP("nn.avg_pool2d_grad") @@ -867,7 +867,7 @@ Array Pool3DCompute(const Attrs& attrs, } } -TVM_REGISTER_API("relay.op.nn._make.max_pool3d") +TVM_REGISTER_GLOBAL("relay.op.nn._make.max_pool3d") .set_body_typed, Array, Array, std::string, bool)>([](Expr data, Array pool_size, @@ -911,7 +911,7 @@ RELAY_REGISTER_OP("nn.max_pool3d") // AvgPool3D -TVM_REGISTER_API("relay.op.nn._make.avg_pool3d") +TVM_REGISTER_GLOBAL("relay.op.nn._make.avg_pool3d") .set_body_typed, Array, Array, std::string, bool, bool)>([](Expr data, Array pool_size, diff --git a/src/relay/op/nn/sparse.cc b/src/relay/op/nn/sparse.cc index fc22725..75aefe2 100644 --- a/src/relay/op/nn/sparse.cc +++ b/src/relay/op/nn/sparse.cc @@ -70,7 +70,7 @@ Expr MakeSparseDense(Expr data, Expr weight_data, Expr weight_indices, Expr weig return CallNode::make(op, {data, weight_data, weight_indices, weight_indptr}, Attrs(attrs), {}); } -TVM_REGISTER_API("relay.op.nn._make.sparse_dense") +TVM_REGISTER_GLOBAL("relay.op.nn._make.sparse_dense") .set_body([](const TVMArgs& args, TVMRetValue* rv) { runtime::detail::unpack_call(MakeSparseDense, args, rv); }); @@ -119,7 +119,7 @@ Expr MakeSparseTranspose(Expr sparse_data, Expr sparse_indices, Expr sparse_indp return CallNode::make(op, {sparse_data, sparse_indices, sparse_indptr}, Attrs(attrs), {}); } -TVM_REGISTER_API("relay.op.nn._make.sparse_transpose") +TVM_REGISTER_GLOBAL("relay.op.nn._make.sparse_transpose") .set_body_typed(MakeSparseTranspose); diff --git a/src/relay/op/nn/upsampling.cc b/src/relay/op/nn/upsampling.cc index 2ba7b2f..1f2a016 100644 --- a/src/relay/op/nn/upsampling.cc +++ b/src/relay/op/nn/upsampling.cc @@ -112,7 +112,7 @@ Expr MakeUpSampling(Expr data, return CallNode::make(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_API("relay.op.nn._make.upsampling") +TVM_REGISTER_GLOBAL("relay.op.nn._make.upsampling") .set_body_typed(MakeUpSampling); @@ -193,7 +193,7 @@ Expr MakeUpSampling3D(Expr data, return CallNode::make(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_API("relay.op.nn._make.upsampling3d") +TVM_REGISTER_GLOBAL("relay.op.nn._make.upsampling3d") .set_body_typed(MakeUpSampling3D); diff --git a/src/relay/op/op_common.h b/src/relay/op/op_common.h index 04f26b9..fc8978b 100644 --- a/src/relay/op/op_common.h +++ b/src/relay/op/op_common.h @@ -48,7 +48,7 @@ namespace relay { * \param OpName the name of registry. */ #define RELAY_REGISTER_UNARY_OP(OpName) \ - TVM_REGISTER_API("relay.op._make." OpName) \ + TVM_REGISTER_GLOBAL("relay.op._make." OpName) \ .set_body_typed([](Expr data) { \ static const Op& op = Op::Get(OpName); \ return CallNode::make(op, {data}, Attrs(), {}); \ @@ -74,7 +74,7 @@ namespace relay { * \param OpName the name of registry. */ #define RELAY_REGISTER_BINARY_OP(OpName) \ - TVM_REGISTER_API("relay.op._make." OpName) \ + TVM_REGISTER_GLOBAL("relay.op._make." OpName) \ .set_body_typed([](Expr lhs, Expr rhs) { \ static const Op& op = Op::Get(OpName); \ return CallNode::make(op, {lhs, rhs}, Attrs(), {}); \ @@ -91,7 +91,7 @@ namespace relay { // Comparisons #define RELAY_REGISTER_CMP_OP(OpName) \ - TVM_REGISTER_API("relay.op._make." OpName) \ + TVM_REGISTER_GLOBAL("relay.op._make." OpName) \ .set_body_typed([](Expr lhs, Expr rhs) { \ static const Op& op = Op::Get(OpName); \ return CallNode::make(op, {lhs, rhs}, Attrs(), {}); \ diff --git a/src/relay/op/tensor/reduce.cc b/src/relay/op/tensor/reduce.cc index 4e9a900..07f1f56 100644 --- a/src/relay/op/tensor/reduce.cc +++ b/src/relay/op/tensor/reduce.cc @@ -302,7 +302,7 @@ bool ReduceRel(const Array& types, } #define RELAY_REGISTER_REDUCE_OP(OpName) \ - TVM_REGISTER_API("relay.op._make." OpName) \ + TVM_REGISTER_GLOBAL("relay.op._make." OpName) \ .set_body_typed, bool, bool)>([]( \ Expr data, \ Array axis, \ @@ -633,7 +633,7 @@ Expr MakeVariance(Expr data, return CallNode::make(op, {data, mean}, Attrs(attrs), {}); } -TVM_REGISTER_API("relay.op._make._variance") +TVM_REGISTER_GLOBAL("relay.op._make._variance") .set_body([](const TVMArgs& args, TVMRetValue* rv) { runtime::detail::unpack_call(MakeVariance, args, rv); }); diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 7407f21..1d56a0f 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -83,7 +83,7 @@ Expr MakeCast(Expr data, return CallNode::make(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_API("relay._make.cast") +TVM_REGISTER_GLOBAL("relay._make.cast") .set_body_typed(MakeCast); RELAY_REGISTER_OP("cast") @@ -140,7 +140,7 @@ Expr MakeCastLike(Expr data, } -TVM_REGISTER_API("relay._make.cast_like") +TVM_REGISTER_GLOBAL("relay._make.cast_like") .set_body_typed(MakeCastLike); RELAY_REGISTER_OP("cast_like") @@ -171,7 +171,7 @@ Expr MakeReinterpret(Expr data, DataType dtype) { return CallNode::make(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_API("relay._make.reinterpret").set_body([](const TVMArgs& args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("relay._make.reinterpret").set_body([](const TVMArgs& args, TVMRetValue* rv) { runtime::detail::unpack_call(MakeReinterpret, args, rv); }); @@ -249,7 +249,7 @@ Expr MakeExpandDims(Expr data, return CallNode::make(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_API("relay.op._make.expand_dims") +TVM_REGISTER_GLOBAL("relay.op._make.expand_dims") .set_body_typed(MakeExpandDims); RELAY_REGISTER_OP("expand_dims") @@ -334,7 +334,7 @@ Expr MakeConcatenate(Expr data, return CallNode::make(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_API("relay.op._make.concatenate") +TVM_REGISTER_GLOBAL("relay.op._make.concatenate") .set_body_typed(MakeConcatenate); RELAY_REGISTER_OP("concatenate") @@ -429,7 +429,7 @@ Expr MakeStack(Expr data, return CallNode::make(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_API("relay.op._make.stack") +TVM_REGISTER_GLOBAL("relay.op._make.stack") .set_body_typed(MakeStack); RELAY_REGISTER_OP("stack") @@ -521,7 +521,7 @@ Expr MakeTranspose(Expr data, return CallNode::make(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_API("relay.op._make.transpose") +TVM_REGISTER_GLOBAL("relay.op._make.transpose") .set_body_typed(MakeTranspose); RELAY_REGISTER_OP("transpose") @@ -713,7 +713,7 @@ Expr MakeReshape(Expr data, return CallNode::make(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_API("relay.op._make.reshape") +TVM_REGISTER_GLOBAL("relay.op._make.reshape") .set_body_typed(MakeReshape); RELAY_REGISTER_OP("reshape") @@ -821,7 +821,7 @@ Expr MakeReshapeLike(Expr data, } -TVM_REGISTER_API("relay.op._make.reshape_like") +TVM_REGISTER_GLOBAL("relay.op._make.reshape_like") .set_body_typed(MakeReshapeLike); @@ -857,7 +857,7 @@ bool ArgWhereRel(const Array& types, return true; } -TVM_REGISTER_API("relay.op._make.argwhere") +TVM_REGISTER_GLOBAL("relay.op._make.argwhere") .set_body_typed([](Expr data) { static const Op& op = Op::Get("argwhere"); auto attrs = make_object(); @@ -945,7 +945,7 @@ Expr MakeTake(Expr data, return CallNode::make(op, {data, indices}, Attrs(attrs), {}); } -TVM_REGISTER_API("relay.op._make.take") +TVM_REGISTER_GLOBAL("relay.op._make.take") .set_body_typed(MakeTake); RELAY_REGISTER_OP("take") @@ -1026,7 +1026,7 @@ Expr MakeFull(Expr fill_value, return CallNode::make(op, {fill_value}, Attrs(attrs), {}); } -TVM_REGISTER_API("relay.op._make.full") +TVM_REGISTER_GLOBAL("relay.op._make.full") .set_body_typed(MakeFull); RELAY_REGISTER_OP("full") @@ -1061,7 +1061,7 @@ Expr MakeZeros(Array shape, return CallNode::make(op, {}, Attrs(attrs), {}); } -TVM_REGISTER_API("relay.op._make.zeros") +TVM_REGISTER_GLOBAL("relay.op._make.zeros") .set_body_typed(MakeZeros); RELAY_REGISTER_OP("zeros") @@ -1082,7 +1082,7 @@ Expr MakeOnes(Array shape, return CallNode::make(op, {}, Attrs(attrs), {}); } -TVM_REGISTER_API("relay.op._make.ones") +TVM_REGISTER_GLOBAL("relay.op._make.ones") .set_body_typed(MakeOnes); RELAY_REGISTER_OP("ones") @@ -1129,7 +1129,7 @@ Expr MakeFullLike(Expr data, return CallNode::make(op, {data, fill_value}, Attrs(), {}); } -TVM_REGISTER_API("relay.op._make.full_like") +TVM_REGISTER_GLOBAL("relay.op._make.full_like") .set_body_typed(MakeFullLike); RELAY_REGISTER_OP("full_like") @@ -1253,7 +1253,7 @@ Expr MakeArange(Expr start, return CallNode::make(op, {start, stop, step}, Attrs(attrs), {}); } -TVM_REGISTER_API("relay.op._make.arange") +TVM_REGISTER_GLOBAL("relay.op._make.arange") .set_body_typed(MakeArange); // An issue with the existing design is that we require dependency @@ -1342,7 +1342,7 @@ Expr MakeRepeat(Expr data, return CallNode::make(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_API("relay.op._make.repeat") +TVM_REGISTER_GLOBAL("relay.op._make.repeat") .set_body_typed(MakeRepeat); RELAY_REGISTER_OP("repeat") @@ -1451,7 +1451,7 @@ Expr MakeTile(Expr data, return CallNode::make(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_API("relay.op._make.tile") +TVM_REGISTER_GLOBAL("relay.op._make.tile") .set_body_typed(MakeTile); RELAY_REGISTER_OP("tile") @@ -1512,7 +1512,7 @@ Expr MakeReverse(Expr data, return CallNode::make(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_API("relay.op._make.reverse") +TVM_REGISTER_GLOBAL("relay.op._make.reverse") .set_body_typed(MakeReverse); RELAY_REGISTER_OP("reverse") @@ -1576,7 +1576,7 @@ Array WhereCompute(const Attrs& attrs, return { topi::where(inputs[0], inputs[1], inputs[2]) }; } -TVM_REGISTER_API("relay.op._make.where") +TVM_REGISTER_GLOBAL("relay.op._make.where") .set_body_typed(MakeWhere); RELAY_REGISTER_OP("where") @@ -1629,7 +1629,7 @@ Expr MakeSqueeze(Expr data, return CallNode::make(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_API("relay.op._make.squeeze") +TVM_REGISTER_GLOBAL("relay.op._make.squeeze") .set_body_typed(MakeSqueeze); @@ -1733,7 +1733,7 @@ Array CollapseSumLikeCompute(const Attrs& attrs, return { topi::collapse_sum(inputs[0], out_ttype->shape) }; } -TVM_REGISTER_API("relay.op._make.collapse_sum_like") +TVM_REGISTER_GLOBAL("relay.op._make.collapse_sum_like") .set_body_typed(MakeCollapseSumLike); RELAY_REGISTER_OP("collapse_sum_like") @@ -1778,7 +1778,7 @@ Array BroadCastToCompute(const Attrs& attrs, return { topi::broadcast_to(inputs[0], ioattrs->shape) }; } -TVM_REGISTER_API("relay.op._make.broadcast_to") +TVM_REGISTER_GLOBAL("relay.op._make.broadcast_to") .set_body_typed(MakeBroadCastTo); RELAY_REGISTER_OP("broadcast_to") @@ -1816,7 +1816,7 @@ Array BroadCastToLikeCompute(const Attrs& attrs, return { topi::broadcast_to(inputs[0], out_ttype->shape) }; } -TVM_REGISTER_API("relay.op._make.broadcast_to_like") +TVM_REGISTER_GLOBAL("relay.op._make.broadcast_to_like") .set_body_typed(MakeBroadCastToLike); RELAY_REGISTER_OP("broadcast_to_like") @@ -2026,7 +2026,7 @@ Array StridedSliceCompute(const Attrs& attrs, } -TVM_REGISTER_API("relay.op._make.strided_slice") +TVM_REGISTER_GLOBAL("relay.op._make.strided_slice") .set_body_typed(MakeStridedSlice); @@ -2082,7 +2082,7 @@ Expr MakeStridedSet(Expr data, return CallNode::make(op, {data, v, begin, end, strides}, {}); } -TVM_REGISTER_API("relay.op._make.strided_set") +TVM_REGISTER_GLOBAL("relay.op._make.strided_set") .set_body_typed(MakeStridedSet); @@ -2198,7 +2198,7 @@ Expr MakeSplit(Expr data, return CallNode::make(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_API("relay.op._make.split") +TVM_REGISTER_GLOBAL("relay.op._make.split") .set_body([](const TVMArgs& args, TVMRetValue* rv) { if (args.type_codes[1] == kDLInt) { *rv = MakeSplit(args[0], make_const(DataType::Int(64), int64_t(args[1])), args[2]); @@ -2347,7 +2347,7 @@ Array SliceLikeCompute(const Attrs& attrs, } -TVM_REGISTER_API("relay.op._make.slice_like") +TVM_REGISTER_GLOBAL("relay.op._make.slice_like") .set_body_typed(MakeSliceLike); @@ -2410,7 +2410,7 @@ Expr MakeLayoutTransform(Expr data, return CallNode::make(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_API("relay.op._make.layout_transform") +TVM_REGISTER_GLOBAL("relay.op._make.layout_transform") .set_body_typed(MakeLayoutTransform); RELAY_REGISTER_OP("layout_transform") @@ -2438,7 +2438,7 @@ Expr MakeReverseReshape(Expr data, return CallNode::make(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_API("relay.op._make._contrib_reverse_reshape") +TVM_REGISTER_GLOBAL("relay.op._make._contrib_reverse_reshape") .set_body_typed(MakeReverseReshape); RELAY_REGISTER_OP("_contrib_reverse_reshape") @@ -2512,7 +2512,7 @@ Expr MakeGatherND(Expr data, return CallNode::make(op, {data, indices}, {}); } -TVM_REGISTER_API("relay.op._make.gather_nd") +TVM_REGISTER_GLOBAL("relay.op._make.gather_nd") .set_body_typed(MakeGatherND); RELAY_REGISTER_OP("gather_nd") @@ -2573,7 +2573,7 @@ Expr MakeSequenceMask(Expr data, return CallNode::make(op, {data, valid_length}, Attrs(attrs), {}); } -TVM_REGISTER_API("relay.op._make.sequence_mask") +TVM_REGISTER_GLOBAL("relay.op._make.sequence_mask") .set_body_typed(MakeSequenceMask); RELAY_REGISTER_OP("sequence_mask") @@ -2695,7 +2695,7 @@ Expr MakeOneHot(Expr indices, return CallNode::make(op, {indices, on_value, off_value}, Attrs(attrs), {}); } -TVM_REGISTER_API("relay.op._make.one_hot") +TVM_REGISTER_GLOBAL("relay.op._make.one_hot") .set_body_typed(MakeOneHot); RELAY_REGISTER_OP("one_hot") diff --git a/src/relay/op/tensor/unary.cc b/src/relay/op/tensor/unary.cc index d4cd7be..cc8419c 100644 --- a/src/relay/op/tensor/unary.cc +++ b/src/relay/op/tensor/unary.cc @@ -157,7 +157,7 @@ RELAY_REGISTER_UNARY_OP("copy") // relay.clip TVM_REGISTER_NODE_TYPE(ClipAttrs); -TVM_REGISTER_API("relay.op._make.clip") +TVM_REGISTER_GLOBAL("relay.op._make.clip") .set_body_typed([](Expr a, double a_min, double a_max) { auto attrs = make_object(); attrs->a_min = a_min; @@ -300,7 +300,7 @@ Array ShapeOfCompute(const Attrs& attrs, return {topi::shape(inputs[0], param->dtype)}; } -TVM_REGISTER_API("relay.op._make.shape_of") +TVM_REGISTER_GLOBAL("relay.op._make.shape_of") .set_body_typed([](Expr data, DataType dtype) { auto attrs = make_object(); attrs->dtype = dtype; @@ -351,7 +351,7 @@ Array NdarraySizeCompute(const Attrs& attrs, return Array{topi::ndarray_size(inputs[0], param->dtype)}; } -TVM_REGISTER_API("relay.op.contrib._make.ndarray_size") +TVM_REGISTER_GLOBAL("relay.op.contrib._make.ndarray_size") .set_body_typed([](Expr data, DataType dtype) { auto attrs = make_object(); attrs->dtype = dtype; diff --git a/src/relay/op/vision/multibox_op.cc b/src/relay/op/vision/multibox_op.cc index 2dd0940..d9ec21f 100644 --- a/src/relay/op/vision/multibox_op.cc +++ b/src/relay/op/vision/multibox_op.cc @@ -71,7 +71,7 @@ Expr MakeMultiBoxPrior(Expr data, } -TVM_REGISTER_API("relay.op.vision._make.multibox_prior") +TVM_REGISTER_GLOBAL("relay.op.vision._make.multibox_prior") .set_body_typed(MakeMultiBoxPrior); @@ -143,7 +143,7 @@ Expr MakeMultiBoxTransformLoc(Expr cls_prob, return CallNode::make(op, {cls_prob, loc_pred, anchor}, Attrs(attrs), {}); } -TVM_REGISTER_API("relay.op.vision._make.multibox_transform_loc") +TVM_REGISTER_GLOBAL("relay.op.vision._make.multibox_transform_loc") .set_body_typed(MakeMultiBoxTransformLoc); RELAY_REGISTER_OP("vision.multibox_transform_loc") diff --git a/src/relay/op/vision/nms.cc b/src/relay/op/vision/nms.cc index 6759e18..307b2a4 100644 --- a/src/relay/op/vision/nms.cc +++ b/src/relay/op/vision/nms.cc @@ -61,7 +61,7 @@ Expr MakeGetValidCounts(Expr data, } -TVM_REGISTER_API("relay.op.vision._make.get_valid_counts") +TVM_REGISTER_GLOBAL("relay.op.vision._make.get_valid_counts") .set_body_typed(MakeGetValidCounts); @@ -129,7 +129,7 @@ Expr MakeNMS(Expr data, } -TVM_REGISTER_API("relay.op.vision._make.non_max_suppression") +TVM_REGISTER_GLOBAL("relay.op.vision._make.non_max_suppression") .set_body_typed(MakeNMS); diff --git a/src/relay/op/vision/rcnn_op.cc b/src/relay/op/vision/rcnn_op.cc index 24f4b98..7b3533d 100644 --- a/src/relay/op/vision/rcnn_op.cc +++ b/src/relay/op/vision/rcnn_op.cc @@ -60,7 +60,7 @@ Expr MakeROIAlign(Expr data, Expr rois, Array pooled_size, double spa return CallNode::make(op, {data, rois}, Attrs(attrs), {}); } -TVM_REGISTER_API("relay.op.vision._make.roi_align") +TVM_REGISTER_GLOBAL("relay.op.vision._make.roi_align") .set_body_typed(MakeROIAlign); RELAY_REGISTER_OP("vision.roi_align") @@ -110,7 +110,7 @@ Expr MakeROIPool(Expr data, Expr rois, Array pooled_size, double spat return CallNode::make(op, {data, rois}, Attrs(attrs), {}); } -TVM_REGISTER_API("relay.op.vision._make.roi_pool") +TVM_REGISTER_GLOBAL("relay.op.vision._make.roi_pool") .set_body_typed(MakeROIPool); RELAY_REGISTER_OP("vision.roi_pool") @@ -176,7 +176,7 @@ Expr MakeProposal(Expr cls_prob, Expr bbox_pred, Expr im_info, Array return CallNode::make(op, {cls_prob, bbox_pred, im_info}, Attrs(attrs), {}); } -TVM_REGISTER_API("relay.op.vision._make.proposal") +TVM_REGISTER_GLOBAL("relay.op.vision._make.proposal") .set_body_typed(MakeProposal); RELAY_REGISTER_OP("vision.proposal") diff --git a/src/relay/op/vision/yolo.cc b/src/relay/op/vision/yolo.cc index 74b59f6..616dc2a 100644 --- a/src/relay/op/vision/yolo.cc +++ b/src/relay/op/vision/yolo.cc @@ -69,7 +69,7 @@ Expr MakeYoloReorg(Expr data, } -TVM_REGISTER_API("relay.op.vision._make.yolo_reorg") +TVM_REGISTER_GLOBAL("relay.op.vision._make.yolo_reorg") .set_body_typed(MakeYoloReorg); diff --git a/src/relay/pass/alter_op_layout.cc b/src/relay/pass/alter_op_layout.cc index b3b08c1..630c25e 100644 --- a/src/relay/pass/alter_op_layout.cc +++ b/src/relay/pass/alter_op_layout.cc @@ -126,7 +126,7 @@ Pass AlterOpLayout() { {ir::StringImm::make("InferType")}); } -TVM_REGISTER_API("relay._transform.AlterOpLayout") +TVM_REGISTER_GLOBAL("relay._transform.AlterOpLayout") .set_body_typed(AlterOpLayout); } // namespace transform diff --git a/src/relay/pass/canonicalize_cast.cc b/src/relay/pass/canonicalize_cast.cc index c790659..861efb4 100644 --- a/src/relay/pass/canonicalize_cast.cc +++ b/src/relay/pass/canonicalize_cast.cc @@ -137,7 +137,7 @@ Pass CanonicalizeCast() { {ir::StringImm::make("InferType")}); } -TVM_REGISTER_API("relay._transform.CanonicalizeCast") +TVM_REGISTER_GLOBAL("relay._transform.CanonicalizeCast") .set_body_typed(CanonicalizeCast); } // namespace transform diff --git a/src/relay/pass/canonicalize_ops.cc b/src/relay/pass/canonicalize_ops.cc index 64b702c..78001bb 100644 --- a/src/relay/pass/canonicalize_ops.cc +++ b/src/relay/pass/canonicalize_ops.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -77,7 +77,7 @@ Pass CanonicalizeOps() { {ir::StringImm::make("InferType")}); } -TVM_REGISTER_API("relay._transform.CanonicalizeOps") +TVM_REGISTER_GLOBAL("relay._transform.CanonicalizeOps") .set_body_typed(CanonicalizeOps); } // namespace transform diff --git a/src/relay/pass/combine_parallel_conv2d.cc b/src/relay/pass/combine_parallel_conv2d.cc index e5c253e..869aa28 100644 --- a/src/relay/pass/combine_parallel_conv2d.cc +++ b/src/relay/pass/combine_parallel_conv2d.cc @@ -224,7 +224,7 @@ Pass CombineParallelConv2D(uint64_t min_num_branches) { {ir::StringImm::make("InferType")}); } -TVM_REGISTER_API("relay._transform.CombineParallelConv2D") +TVM_REGISTER_GLOBAL("relay._transform.CombineParallelConv2D") .set_body_typed(CombineParallelConv2D); } // namespace transform diff --git a/src/relay/pass/combine_parallel_dense.cc b/src/relay/pass/combine_parallel_dense.cc index e7a03da..af43225 100644 --- a/src/relay/pass/combine_parallel_dense.cc +++ b/src/relay/pass/combine_parallel_dense.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -23,8 +23,8 @@ * \brief Combine parallel dense ops into a single dense. * * This pass replaces dense ops that share the same input node, same shape, - * and don't have "units" defined with a single batch matrix multiplication. - * The inputs of the new batch_matmul is the stack of the original inputs. + * and don't have "units" defined with a single batch matrix multiplication. + * The inputs of the new batch_matmul is the stack of the original inputs. * Elemwise and broadcast ops following dense are also combined if possible. * * This prevents launching multiple kernels in networks with multiple @@ -84,7 +84,7 @@ Pass CombineParallelDense(uint64_t min_num_branches) { {ir::StringImm::make("InferType")}); } -TVM_REGISTER_API("relay._transform.CombineParallelDense") +TVM_REGISTER_GLOBAL("relay._transform.CombineParallelDense") .set_body_typed(CombineParallelDense); } // namespace transform diff --git a/src/relay/pass/combine_parallel_op_batch.cc b/src/relay/pass/combine_parallel_op_batch.cc index 75cebfa..d8152f6 100644 --- a/src/relay/pass/combine_parallel_op_batch.cc +++ b/src/relay/pass/combine_parallel_op_batch.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -21,13 +21,13 @@ * * \file combine_parallel_op_batch.cc * \brief Combine parallel ops into a single batch op. - * + * * This pass replaces ops that share the same input node and same shape * with a single op that takes in batched input. The inputs of the new * batched op are the stack of the original inputs. Elementwise and * broadcast ops following the original op are also stacked * and fused if possible. For example: - * + * * data * / \ * add (2,2) add (2,2) @@ -36,7 +36,7 @@ * | | * * Would become: - * + * * data * | * add+elemwise (2,2,2) @@ -197,7 +197,7 @@ Pass CombineParallelOpBatch(const std::string& op_name, {ir::StringImm::make("InferType")}); } -TVM_REGISTER_API("relay._transform.CombineParallelOpBatch") +TVM_REGISTER_GLOBAL("relay._transform.CombineParallelOpBatch") .set_body_typed(CombineParallelOpBatch); } // namespace transform diff --git a/src/relay/pass/convert_layout.cc b/src/relay/pass/convert_layout.cc index 8b223ee..da0c28f 100644 --- a/src/relay/pass/convert_layout.cc +++ b/src/relay/pass/convert_layout.cc @@ -138,7 +138,7 @@ Pass ConvertLayout(const std::string& desired_layout) { ir::StringImm::make("CanonicalizeOps")}); } -TVM_REGISTER_API("relay._transform.ConvertLayout").set_body_typed(ConvertLayout); +TVM_REGISTER_GLOBAL("relay._transform.ConvertLayout").set_body_typed(ConvertLayout); } // namespace transform diff --git a/src/relay/pass/de_duplicate.cc b/src/relay/pass/de_duplicate.cc index cf99dc3..3cfed1b 100644 --- a/src/relay/pass/de_duplicate.cc +++ b/src/relay/pass/de_duplicate.cc @@ -114,7 +114,7 @@ Expr DeDup(const Expr& e) { return ret; } -TVM_REGISTER_API("relay._transform.dedup") +TVM_REGISTER_GLOBAL("relay._transform.dedup") .set_body_typed(DeDup); } // namespace relay diff --git a/src/relay/pass/dead_code.cc b/src/relay/pass/dead_code.cc index 14bca58..05324af 100644 --- a/src/relay/pass/dead_code.cc +++ b/src/relay/pass/dead_code.cc @@ -147,7 +147,7 @@ Pass DeadCodeElimination(bool inline_once) { return CreateFunctionPass(pass_func, 1, "DeadCodeElimination", {}); } -TVM_REGISTER_API("relay._transform.DeadCodeElimination") +TVM_REGISTER_GLOBAL("relay._transform.DeadCodeElimination") .set_body_typed(DeadCodeElimination); } // namespace transform diff --git a/src/relay/pass/device_annotation.cc b/src/relay/pass/device_annotation.cc index 91a7fa3..1229324 100644 --- a/src/relay/pass/device_annotation.cc +++ b/src/relay/pass/device_annotation.cc @@ -560,13 +560,13 @@ Map CollectDeviceAnnotationOps(const Expr& expr) { return AnnotatationVisitor::GetAnnotations(expr); } -TVM_REGISTER_API("relay._analysis.CollectDeviceInfo") +TVM_REGISTER_GLOBAL("relay._analysis.CollectDeviceInfo") .set_body_typed(CollectDeviceInfo); -TVM_REGISTER_API("relay._analysis.RewriteDeviceAnnotation") +TVM_REGISTER_GLOBAL("relay._analysis.RewriteDeviceAnnotation") .set_body_typed(RewriteAnnotatedOps); -TVM_REGISTER_API("relay._analysis.CollectDeviceAnnotationOps") +TVM_REGISTER_GLOBAL("relay._analysis.CollectDeviceAnnotationOps") .set_body_typed(CollectDeviceAnnotationOps); namespace transform { @@ -580,7 +580,7 @@ Pass RewriteAnnotatedOps(int fallback_device) { {ir::StringImm::make("InferType")}); } -TVM_REGISTER_API("relay._transform.RewriteDeviceAnnotation") +TVM_REGISTER_GLOBAL("relay._transform.RewriteDeviceAnnotation") .set_body_typed(RewriteAnnotatedOps); } // namespace transform diff --git a/src/relay/pass/eliminate_common_subexpr.cc b/src/relay/pass/eliminate_common_subexpr.cc index d180fcc..f9a303b 100644 --- a/src/relay/pass/eliminate_common_subexpr.cc +++ b/src/relay/pass/eliminate_common_subexpr.cc @@ -95,7 +95,7 @@ Pass EliminateCommonSubexpr(PackedFunc fskip) { {ir::StringImm::make("InferType")}); } -TVM_REGISTER_API("relay._transform.EliminateCommonSubexpr") +TVM_REGISTER_GLOBAL("relay._transform.EliminateCommonSubexpr") .set_body_typed(EliminateCommonSubexpr); } // namespace transform diff --git a/src/relay/pass/eta_expand.cc b/src/relay/pass/eta_expand.cc index 888874c..672d551 100644 --- a/src/relay/pass/eta_expand.cc +++ b/src/relay/pass/eta_expand.cc @@ -160,7 +160,7 @@ Pass EtaExpand(bool expand_constructor, bool expand_global_var) { return CreateModulePass(pass_func, 1, "EtaExpand", {}); } -TVM_REGISTER_API("relay._transform.EtaExpand") +TVM_REGISTER_GLOBAL("relay._transform.EtaExpand") .set_body_typed(EtaExpand); } // namespace transform diff --git a/src/relay/pass/feature.cc b/src/relay/pass/feature.cc index 79830a7..ad0ce95 100644 --- a/src/relay/pass/feature.cc +++ b/src/relay/pass/feature.cc @@ -104,7 +104,7 @@ Array PyDetectFeature(const Expr& expr, const Module& mod) { return static_cast>(fs); } -TVM_REGISTER_API("relay._analysis.detect_feature") +TVM_REGISTER_GLOBAL("relay._analysis.detect_feature") .set_body_typed(PyDetectFeature); } // namespace relay diff --git a/src/relay/pass/fold_constant.cc b/src/relay/pass/fold_constant.cc index 4a6417b..b830de0 100644 --- a/src/relay/pass/fold_constant.cc +++ b/src/relay/pass/fold_constant.cc @@ -70,7 +70,7 @@ bool ConstantCheck(const Expr& e) { return ConstantChecker().Check(e); } -TVM_REGISTER_API("relay._analysis.check_constant") +TVM_REGISTER_GLOBAL("relay._analysis.check_constant") .set_body_typed(ConstantCheck); // TODO(tvm-team) consider combine dead-code with constant folder. @@ -297,7 +297,7 @@ Pass FoldConstant() { return CreateFunctionPass(pass_func, 2, "FoldConstant", {}); } -TVM_REGISTER_API("relay._transform.FoldConstant") +TVM_REGISTER_GLOBAL("relay._transform.FoldConstant") .set_body_typed(FoldConstant); } // namespace transform diff --git a/src/relay/pass/fold_scale_axis.cc b/src/relay/pass/fold_scale_axis.cc index 711297c..fea5cdb 100644 --- a/src/relay/pass/fold_scale_axis.cc +++ b/src/relay/pass/fold_scale_axis.cc @@ -958,7 +958,7 @@ Pass ForwardFoldScaleAxis() { {ir::StringImm::make("InferType")}); } -TVM_REGISTER_API("relay._transform.ForwardFoldScaleAxis") +TVM_REGISTER_GLOBAL("relay._transform.ForwardFoldScaleAxis") .set_body_typed(ForwardFoldScaleAxis); Pass BackwardFoldScaleAxis() { @@ -971,7 +971,7 @@ Pass BackwardFoldScaleAxis() { {ir::StringImm::make("InferType")}); } -TVM_REGISTER_API("relay._transform.BackwardFoldScaleAxis") +TVM_REGISTER_GLOBAL("relay._transform.BackwardFoldScaleAxis") .set_body_typed(BackwardFoldScaleAxis); Pass FoldScaleAxis() { @@ -983,7 +983,7 @@ Pass FoldScaleAxis() { return pass; } -TVM_REGISTER_API("relay._transform.FoldScaleAxis") +TVM_REGISTER_GLOBAL("relay._transform.FoldScaleAxis") .set_body_typed(FoldScaleAxis); } // namespace transform diff --git a/src/relay/pass/fuse_ops.cc b/src/relay/pass/fuse_ops.cc index 7b8f6de..eb050fe 100644 --- a/src/relay/pass/fuse_ops.cc +++ b/src/relay/pass/fuse_ops.cc @@ -985,7 +985,7 @@ Pass FuseOps(int fuse_opt_level) { {ir::StringImm::make("InferType")}); } -TVM_REGISTER_API("relay._transform.FuseOps") +TVM_REGISTER_GLOBAL("relay._transform.FuseOps") .set_body_typed(FuseOps); } // namespace transform diff --git a/src/relay/pass/gradient.cc b/src/relay/pass/gradient.cc index 61f7e2d..cd86aaf 100644 --- a/src/relay/pass/gradient.cc +++ b/src/relay/pass/gradient.cc @@ -254,7 +254,7 @@ Expr FirstOrderGradient(const Expr& re, const Module& mod) { return FunctionNode::make(f->params, body, GradRetType(GetRef(f)), {}); } -TVM_REGISTER_API("relay._transform.first_order_gradient") +TVM_REGISTER_GLOBAL("relay._transform.first_order_gradient") .set_body_typed(FirstOrderGradient); struct ReverseADType : TypeMutator { @@ -582,7 +582,7 @@ Expr Gradient(const Expr& re, const Module& mod) { return FunctionNode::make(f->params, body, GradRetType(GetRef(f)), {}); } -TVM_REGISTER_API("relay._transform.gradient") +TVM_REGISTER_GLOBAL("relay._transform.gradient") .set_body_typed(Gradient); } // namespace relay diff --git a/src/relay/pass/kind_check.cc b/src/relay/pass/kind_check.cc index 3bd8e87..5b7e1c0 100644 --- a/src/relay/pass/kind_check.cc +++ b/src/relay/pass/kind_check.cc @@ -182,7 +182,7 @@ Kind KindCheck(const Type& t, const Module& mod) { return kc.Check(t); } -TVM_REGISTER_API("relay._analysis.check_kind") +TVM_REGISTER_GLOBAL("relay._analysis.check_kind") .set_body([](TVMArgs args, TVMRetValue* ret) { if (args.size() == 1) { *ret = KindCheck(args[0], ModuleNode::make({}, {})); diff --git a/src/relay/pass/legalize.cc b/src/relay/pass/legalize.cc index d2554d4..8f3830e 100644 --- a/src/relay/pass/legalize.cc +++ b/src/relay/pass/legalize.cc @@ -105,7 +105,7 @@ Pass Legalize(const std::string& legalize_map_attr_name) { return CreateFunctionPass(pass_func, 1, "Legalize", {ir::StringImm::make("InferType")}); } -TVM_REGISTER_API("relay._transform.Legalize").set_body_typed(Legalize); +TVM_REGISTER_GLOBAL("relay._transform.Legalize").set_body_typed(Legalize); } // namespace transform diff --git a/src/relay/pass/mac_count.cc b/src/relay/pass/mac_count.cc index 86bf972..a5cd93a 100644 --- a/src/relay/pass/mac_count.cc +++ b/src/relay/pass/mac_count.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -20,7 +20,7 @@ /*! * * \file mac_count.cc - * \brief Pass to roughly count the number of MACs (Multiply-Accumulate) + * \brief Pass to roughly count the number of MACs (Multiply-Accumulate) * operations of a model. Only MACs in CONV and Dense ops are counted. * This pass is valid after the type infer pass is called, * otherwise the count is 0. @@ -206,7 +206,7 @@ int64_t GetTotalMacNumber(const Expr& expr) { return MacCounter::GetTotalMacNumber(expr); } -TVM_REGISTER_API("relay._analysis.GetTotalMacNumber") +TVM_REGISTER_GLOBAL("relay._analysis.GetTotalMacNumber") .set_body_typed(GetTotalMacNumber); } // namespace mac_count diff --git a/src/relay/pass/match_exhaustion.cc b/src/relay/pass/match_exhaustion.cc index 6c17529..d9e8d87 100644 --- a/src/relay/pass/match_exhaustion.cc +++ b/src/relay/pass/match_exhaustion.cc @@ -326,7 +326,7 @@ Array UnmatchedCases(const Match& match, const Module& mod) { } // expose for testing only -TVM_REGISTER_API("relay._analysis.unmatched_cases") +TVM_REGISTER_GLOBAL("relay._analysis.unmatched_cases") .set_body_typed(const Match&, const Module&)>( [](const Match& match, const Module& mod_ref) { Module call_mod = mod_ref; diff --git a/src/relay/pass/partial_eval.cc b/src/relay/pass/partial_eval.cc index 7a524ee..a6b8671 100644 --- a/src/relay/pass/partial_eval.cc +++ b/src/relay/pass/partial_eval.cc @@ -1270,7 +1270,7 @@ Pass PartialEval() { return CreateModulePass(pass_func, 1, "PartialEvaluate", {}); } -TVM_REGISTER_API("relay._transform.PartialEvaluate") +TVM_REGISTER_GLOBAL("relay._transform.PartialEvaluate") .set_body_typed(PartialEval); } // namespace transform diff --git a/src/relay/pass/pass_manager.cc b/src/relay/pass/pass_manager.cc index 909ba0b..ae02d70 100644 --- a/src/relay/pass/pass_manager.cc +++ b/src/relay/pass/pass_manager.cc @@ -444,10 +444,10 @@ Pass CreateFunctionPass( TVM_REGISTER_NODE_TYPE(PassInfoNode); -TVM_REGISTER_API("relay._transform.PassInfo") +TVM_REGISTER_GLOBAL("relay._transform.PassInfo") .set_body_typed(PassInfoNode::make); -TVM_REGISTER_API("relay._transform.Info") +TVM_REGISTER_GLOBAL("relay._transform.Info") .set_body([](TVMArgs args, TVMRetValue* ret) { Pass pass = args[0]; *ret = pass->Info(); @@ -469,10 +469,10 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) TVM_REGISTER_NODE_TYPE(ModulePassNode); -TVM_REGISTER_API("relay._transform.MakeModulePass") +TVM_REGISTER_GLOBAL("relay._transform.MakeModulePass") .set_body_typed(ModulePassNode::make); -TVM_REGISTER_API("relay._transform.RunPass") +TVM_REGISTER_GLOBAL("relay._transform.RunPass") .set_body([](TVMArgs args, TVMRetValue* ret) { Pass pass = args[0]; Module mod = args[1]; @@ -489,7 +489,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) TVM_REGISTER_NODE_TYPE(FunctionPassNode); -TVM_REGISTER_API("relay._transform.MakeFunctionPass") +TVM_REGISTER_GLOBAL("relay._transform.MakeFunctionPass") .set_body_typed(FunctionPassNode::make); TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) @@ -502,7 +502,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) TVM_REGISTER_NODE_TYPE(SequentialNode); -TVM_REGISTER_API("relay._transform.Sequential") +TVM_REGISTER_GLOBAL("relay._transform.Sequential") .set_body([](TVMArgs args, TVMRetValue* ret) { tvm::Array passes = args[0]; int opt_level = args[1]; @@ -528,7 +528,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) TVM_REGISTER_NODE_TYPE(PassContextNode); -TVM_REGISTER_API("relay._transform.PassContext") +TVM_REGISTER_GLOBAL("relay._transform.PassContext") .set_body([](TVMArgs args, TVMRetValue* ret) { auto pctx = PassContext::Create(); int opt_level = args[0]; @@ -575,13 +575,13 @@ class PassContext::Internal { } }; -TVM_REGISTER_API("relay._transform.GetCurrentPassContext") +TVM_REGISTER_GLOBAL("relay._transform.GetCurrentPassContext") .set_body_typed(PassContext::Current); -TVM_REGISTER_API("relay._transform.EnterPassContext") +TVM_REGISTER_GLOBAL("relay._transform.EnterPassContext") .set_body_typed(PassContext::Internal::EnterScope); -TVM_REGISTER_API("relay._transform.ExitPassContext") +TVM_REGISTER_GLOBAL("relay._transform.ExitPassContext") .set_body_typed(PassContext::Internal::ExitScope); } // namespace transform diff --git a/src/relay/pass/print_ir.cc b/src/relay/pass/print_ir.cc index e32865a..5191a2e 100644 --- a/src/relay/pass/print_ir.cc +++ b/src/relay/pass/print_ir.cc @@ -40,7 +40,7 @@ Pass PrintIR(bool show_meta_data) { return CreateModulePass(pass_func, 0, "PrintIR", {}); } -TVM_REGISTER_API("relay._transform.PrintIR") +TVM_REGISTER_GLOBAL("relay._transform.PrintIR") .set_body_typed(PrintIR); } // namespace transform diff --git a/src/relay/pass/quantize/annotate.cc b/src/relay/pass/quantize/annotate.cc index c3d0107..5e1083a 100644 --- a/src/relay/pass/quantize/annotate.cc +++ b/src/relay/pass/quantize/annotate.cc @@ -70,7 +70,7 @@ QAnnotateExpr QAnnotateExprNode::make(Expr expr, QAnnotateKind kind) { return QAnnotateExpr(rnode); } -TVM_REGISTER_API("relay._quantize.make_annotate_expr") +TVM_REGISTER_GLOBAL("relay._quantize.make_annotate_expr") .set_body([](TVMArgs args, TVMRetValue *ret) { *ret = QAnnotateExprNode::make(args[0], static_cast(args[1].operator int())); @@ -108,7 +108,7 @@ Pass QuantizeAnnotate() { return CreateFunctionPass(pass_func, 1, "QuantizeAnnotate", {}); } -TVM_REGISTER_API("relay._quantize.QuantizeAnnotate") +TVM_REGISTER_GLOBAL("relay._quantize.QuantizeAnnotate") .set_body_typed(QuantizeAnnotate); TVM_REGISTER_NODE_TYPE(QAnnotateExprNode); diff --git a/src/relay/pass/quantize/calibrate.cc b/src/relay/pass/quantize/calibrate.cc index bcf82c0..f6f0112 100644 --- a/src/relay/pass/quantize/calibrate.cc +++ b/src/relay/pass/quantize/calibrate.cc @@ -202,11 +202,11 @@ Expr CreateStatsCollector(const Expr& expr) { return StatsCollector().Collect(expr); } -TVM_REGISTER_API("relay._quantize.CreateStatsCollector") +TVM_REGISTER_GLOBAL("relay._quantize.CreateStatsCollector") .set_body_typed(CreateStatsCollector); -TVM_REGISTER_API("relay._quantize.FindScaleByKLMinimization") +TVM_REGISTER_GLOBAL("relay._quantize.FindScaleByKLMinimization") .set_body([](TVMArgs args, TVMRetValue *ret) { int* hist_ptr = static_cast(static_cast(args[0])); float* hist_edges_ptr = static_cast(static_cast(args[1])); diff --git a/src/relay/pass/quantize/partition.cc b/src/relay/pass/quantize/partition.cc index 710684c..6ad05e8 100644 --- a/src/relay/pass/quantize/partition.cc +++ b/src/relay/pass/quantize/partition.cc @@ -72,7 +72,7 @@ QPartitionExpr QPartitionExprNode::make(Expr expr) { return QPartitionExpr(rnode); } -TVM_REGISTER_API("relay._quantize.make_partition_expr") +TVM_REGISTER_GLOBAL("relay._quantize.make_partition_expr") .set_body([](TVMArgs args, TVMRetValue *ret) { *ret = QPartitionExprNode::make(args[0]); }); @@ -87,7 +87,7 @@ Pass QuantizePartition() { return CreateFunctionPass(pass_func, 1, "QuantizePartition", {}); } -TVM_REGISTER_API("relay._quantize.QuantizePartition") +TVM_REGISTER_GLOBAL("relay._quantize.QuantizePartition") .set_body_typed(QuantizePartition); TVM_REGISTER_NODE_TYPE(QPartitionExprNode); diff --git a/src/relay/pass/quantize/quantize.cc b/src/relay/pass/quantize/quantize.cc index ef78bf2..c995994 100644 --- a/src/relay/pass/quantize/quantize.cc +++ b/src/relay/pass/quantize/quantize.cc @@ -66,7 +66,7 @@ RELAY_REGISTER_OP("relay.op.annotation.simulated_quantize") .set_support_level(11) .add_type_rel("SimulatedQuantize", SimulatedQuantizeRel); -TVM_REGISTER_API("relay._quantize.simulated_quantize") +TVM_REGISTER_GLOBAL("relay._quantize.simulated_quantize") .set_body_typed( [](Expr data, Expr dom_scale, Expr clip_min, Expr clip_max, int kind, bool sign, std::string rounding) { @@ -134,13 +134,13 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) p->stream << ")"; }); -TVM_REGISTER_API("relay._quantize._GetCurrentQConfig") +TVM_REGISTER_GLOBAL("relay._quantize._GetCurrentQConfig") .set_body_typed(QConfig::Current); -TVM_REGISTER_API("relay._quantize._EnterQConfigScope") +TVM_REGISTER_GLOBAL("relay._quantize._EnterQConfigScope") .set_body_typed(QConfig::EnterQConfigScope); -TVM_REGISTER_API("relay._quantize._ExitQConfigScope") +TVM_REGISTER_GLOBAL("relay._quantize._ExitQConfigScope") .set_body_typed(QConfig::ExitQConfigScope); } // namespace quantize diff --git a/src/relay/pass/quantize/realize.cc b/src/relay/pass/quantize/realize.cc index bb8edf1..b3b44cc 100644 --- a/src/relay/pass/quantize/realize.cc +++ b/src/relay/pass/quantize/realize.cc @@ -530,7 +530,7 @@ Pass QuantizeRealizePass() { return CreateFunctionPass(pass_func, 1, "QuantizeRealize", {}); } -TVM_REGISTER_API("relay._quantize.QuantizeRealize") +TVM_REGISTER_GLOBAL("relay._quantize.QuantizeRealize") .set_body_typed(QuantizeRealizePass); } // namespace quantize diff --git a/src/relay/pass/simplify_inference.cc b/src/relay/pass/simplify_inference.cc index 6d6171c..5e67085 100644 --- a/src/relay/pass/simplify_inference.cc +++ b/src/relay/pass/simplify_inference.cc @@ -191,7 +191,7 @@ Pass SimplifyInference() { {ir::StringImm::make("InferType")}); } -TVM_REGISTER_API("relay._transform.SimplifyInference") +TVM_REGISTER_GLOBAL("relay._transform.SimplifyInference") .set_body_typed(SimplifyInference); } // namespace transform diff --git a/src/relay/pass/to_a_normal_form.cc b/src/relay/pass/to_a_normal_form.cc index 57894e0..c839beb 100644 --- a/src/relay/pass/to_a_normal_form.cc +++ b/src/relay/pass/to_a_normal_form.cc @@ -328,7 +328,7 @@ Pass ToANormalForm() { return CreateModulePass(pass_func, 1, "ToANormalForm", {}); } -TVM_REGISTER_API("relay._transform.ToANormalForm") +TVM_REGISTER_GLOBAL("relay._transform.ToANormalForm") .set_body_typed(ToANormalForm); } // namespace transform diff --git a/src/relay/pass/to_cps.cc b/src/relay/pass/to_cps.cc index 96e7f1a..3ca7a08 100644 --- a/src/relay/pass/to_cps.cc +++ b/src/relay/pass/to_cps.cc @@ -359,10 +359,10 @@ Function UnCPS(const Function& f) { f->attrs); } -TVM_REGISTER_API("relay._transform.to_cps") +TVM_REGISTER_GLOBAL("relay._transform.to_cps") .set_body_typed(static_cast(ToCPS)); -TVM_REGISTER_API("relay._transform.un_cps") +TVM_REGISTER_GLOBAL("relay._transform.un_cps") .set_body_typed(UnCPS); namespace transform { @@ -375,7 +375,7 @@ Pass ToCPS() { return CreateFunctionPass(pass_func, 1, "ToCPS", {}); } -TVM_REGISTER_API("relay._transform.ToCPS") +TVM_REGISTER_GLOBAL("relay._transform.ToCPS") .set_body_typed(ToCPS); @@ -387,7 +387,7 @@ Pass UnCPS() { return CreateFunctionPass(pass_func, 1, "UnCPS", {}); } -TVM_REGISTER_API("relay._transform.UnCPS") +TVM_REGISTER_GLOBAL("relay._transform.UnCPS") .set_body_typed(UnCPS); } // namespace transform diff --git a/src/relay/pass/to_graph_normal_form.cc b/src/relay/pass/to_graph_normal_form.cc index b00e0d4..c9eeefd 100644 --- a/src/relay/pass/to_graph_normal_form.cc +++ b/src/relay/pass/to_graph_normal_form.cc @@ -86,7 +86,7 @@ Pass ToGraphNormalForm() { return CreateFunctionPass(pass_func, 1, "ToGraphNormalForm", {}); } -TVM_REGISTER_API("relay._transform.ToGraphNormalForm") +TVM_REGISTER_GLOBAL("relay._transform.ToGraphNormalForm") .set_body_typed(ToGraphNormalForm); } // namespace transform diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index 6e992bb..a2944a9 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -78,7 +78,7 @@ bool TupleGetItemRel(const Array& types, } TVM_REGISTER_NODE_TYPE(TupleGetItemAttrs); -TVM_REGISTER_API("tvm.relay.type_relation.TupleGetItem") +TVM_REGISTER_GLOBAL("tvm.relay.type_relation.TupleGetItem") .set_body_typed&, int, const Attrs&, const TypeReporter&)>( TupleGetItemRel); @@ -839,7 +839,7 @@ Pass InferType() { return CreateFunctionPass(pass_func, 0, "InferType", {}); } -TVM_REGISTER_API("relay._transform.InferType") +TVM_REGISTER_GLOBAL("relay._transform.InferType") .set_body_typed([]() { return InferType(); }); diff --git a/src/relay/pass/type_solver.cc b/src/relay/pass/type_solver.cc index 86ebe0f..221f2c1 100644 --- a/src/relay/pass/type_solver.cc +++ b/src/relay/pass/type_solver.cc @@ -656,7 +656,7 @@ bool TypeSolver::Solve() { } // Expose type solver only for debugging purposes. -TVM_REGISTER_API("relay._analysis._test_type_solver") +TVM_REGISTER_GLOBAL("relay._analysis._test_type_solver") .set_body([](runtime::TVMArgs args, runtime::TVMRetValue* ret) { using runtime::PackedFunc; using runtime::TypedPackedFunc; diff --git a/src/relay/pass/util.cc b/src/relay/pass/util.cc index 2efb479..3ad5dd1 100644 --- a/src/relay/pass/util.cc +++ b/src/relay/pass/util.cc @@ -274,10 +274,10 @@ tvm::Array AllVars(const Expr& expr) { return VarVisitor().All(expr); } -TVM_REGISTER_API("relay._analysis.free_vars") +TVM_REGISTER_GLOBAL("relay._analysis.free_vars") .set_body_typed(FreeVars); -TVM_REGISTER_API("relay._analysis.bound_vars") +TVM_REGISTER_GLOBAL("relay._analysis.bound_vars") .set_body([](TVMArgs args, TVMRetValue* ret) { ObjectRef x = args[0]; if (x.as()) { @@ -287,10 +287,10 @@ TVM_REGISTER_API("relay._analysis.bound_vars") } }); -TVM_REGISTER_API("relay._analysis.all_vars") +TVM_REGISTER_GLOBAL("relay._analysis.all_vars") .set_body_typed(AllVars); -TVM_REGISTER_API("relay._analysis.free_type_vars") +TVM_REGISTER_GLOBAL("relay._analysis.free_type_vars") .set_body([](TVMArgs args, TVMRetValue* ret) { ObjectRef x = args[0]; Module mod = args[1]; @@ -301,7 +301,7 @@ TVM_REGISTER_API("relay._analysis.free_type_vars") } }); -TVM_REGISTER_API("relay._analysis.bound_type_vars") +TVM_REGISTER_GLOBAL("relay._analysis.bound_type_vars") .set_body([](TVMArgs args, TVMRetValue* ret) { ObjectRef x = args[0]; Module mod = args[1]; @@ -312,7 +312,7 @@ TVM_REGISTER_API("relay._analysis.bound_type_vars") } }); -TVM_REGISTER_API("relay._analysis.all_type_vars") +TVM_REGISTER_GLOBAL("relay._analysis.all_type_vars") .set_body([](TVMArgs args, TVMRetValue* ret) { ObjectRef x = args[0]; Module mod = args[1]; diff --git a/src/relay/pass/well_formed.cc b/src/relay/pass/well_formed.cc index 2bbf979..ed95eb0 100644 --- a/src/relay/pass/well_formed.cc +++ b/src/relay/pass/well_formed.cc @@ -125,7 +125,7 @@ bool WellFormed(const Expr& e) { return WellFormedChecker().CheckWellFormed(e); } -TVM_REGISTER_API("relay._analysis.well_formed") +TVM_REGISTER_GLOBAL("relay._analysis.well_formed") .set_body_typed(WellFormed); } // namespace relay diff --git a/src/relay/qnn/op/concatenate.cc b/src/relay/qnn/op/concatenate.cc index 7dfa63f..685fb9f 100644 --- a/src/relay/qnn/op/concatenate.cc +++ b/src/relay/qnn/op/concatenate.cc @@ -161,7 +161,7 @@ RELAY_REGISTER_OP("qnn.concatenate") .add_type_rel("QnnConcatenate", QnnConcatenateRel) .set_attr("FTVMQnnCanonicalize", ConcatenateQnnCanonicalize); -TVM_REGISTER_API("relay.qnn.op._make.concatenate") +TVM_REGISTER_GLOBAL("relay.qnn.op._make.concatenate") .set_body_typed(MakeQnnConcatenate); } // namespace qnn diff --git a/src/relay/qnn/op/convolution.cc b/src/relay/qnn/op/convolution.cc index 839fcbd..c9ce0ec 100644 --- a/src/relay/qnn/op/convolution.cc +++ b/src/relay/qnn/op/convolution.cc @@ -676,7 +676,7 @@ operator to understand how to scale back the int32 output to (u)int8. .add_type_rel("QnnConv2D", QnnConv2DRel) .set_attr("FTVMQnnCanonicalize", QnnConv2DCanonicalize); -TVM_REGISTER_API("relay.qnn.op._make.conv2d").set_body_typed(MakeQnnConv2D); +TVM_REGISTER_GLOBAL("relay.qnn.op._make.conv2d").set_body_typed(MakeQnnConv2D); } // namespace qnn } // namespace relay diff --git a/src/relay/qnn/op/dense.cc b/src/relay/qnn/op/dense.cc index c762331..b7a12e1 100644 --- a/src/relay/qnn/op/dense.cc +++ b/src/relay/qnn/op/dense.cc @@ -190,7 +190,7 @@ RELAY_REGISTER_OP("qnn.dense") .add_type_rel("QDense", QnnDenseRel) .set_attr("FTVMQnnCanonicalize", QnnDenseCanonicalize); -TVM_REGISTER_API("relay.qnn.op._make.dense") +TVM_REGISTER_GLOBAL("relay.qnn.op._make.dense") .set_body_typed(MakeQuantizedDense); } // namespace qnn diff --git a/src/relay/qnn/op/dequantize.cc b/src/relay/qnn/op/dequantize.cc index 94f2f89..6579c3d 100644 --- a/src/relay/qnn/op/dequantize.cc +++ b/src/relay/qnn/op/dequantize.cc @@ -94,7 +94,7 @@ The input is always quantized (int8, uint8) and will be converted to float32 giv .add_type_rel("Dequantize", DequantizeRel) .set_attr("FTVMQnnCanonicalize", DequantizeQnnCanonicalize); -TVM_REGISTER_API("relay.qnn.op._make.dequantize") +TVM_REGISTER_GLOBAL("relay.qnn.op._make.dequantize") .set_body_typed(MakeDequantize); } // namespace qnn diff --git a/src/relay/qnn/op/op_common.h b/src/relay/qnn/op/op_common.h index 41e8335..2a33009 100644 --- a/src/relay/qnn/op/op_common.h +++ b/src/relay/qnn/op/op_common.h @@ -64,7 +64,7 @@ static inline bool QnnBroadcastRel(const Array& types, int num_inputs, con * \param OpName the name of registry. */ #define QNN_REGISTER_BINARY_OP(OpName) \ - TVM_REGISTER_API("relay.qnn.op._make." OpName) \ + TVM_REGISTER_GLOBAL("relay.qnn.op._make." OpName) \ .set_body_typed( \ [](Expr lhs, Expr rhs, Expr lhs_scale, Expr lhs_zero_point, Expr rhs_scale, \ Expr rhs_zero_point, Expr output_scale, Expr output_zero_point) { \ diff --git a/src/relay/qnn/op/quantize.cc b/src/relay/qnn/op/quantize.cc index 9749fb8..27b3c0f 100644 --- a/src/relay/qnn/op/quantize.cc +++ b/src/relay/qnn/op/quantize.cc @@ -116,7 +116,7 @@ scale and zero point. .add_type_rel("Quantize", QuantizeRel) .set_attr("FTVMQnnCanonicalize", QuantizeQnnCanonicalize); -TVM_REGISTER_API("relay.qnn.op._make.quantize") +TVM_REGISTER_GLOBAL("relay.qnn.op._make.quantize") .set_body_typed(MakeQuantize); } // namespace qnn diff --git a/src/relay/qnn/op/requantize.cc b/src/relay/qnn/op/requantize.cc index 68b0b08..fc6a9b4 100644 --- a/src/relay/qnn/op/requantize.cc +++ b/src/relay/qnn/op/requantize.cc @@ -207,7 +207,7 @@ Q_output = zp_output + (scale_input)/(scale_output) * (Q_input - zp_input) .add_type_rel("Requantize", RequantizeRel) .set_attr("FTVMQnnCanonicalize", RequantizeQnnCanonicalize); -TVM_REGISTER_API("relay.qnn.op._make.requantize") +TVM_REGISTER_GLOBAL("relay.qnn.op._make.requantize") .set_body_typed(MakeRequantize); } // namespace qnn diff --git a/src/relay/qnn/pass/legalize.cc b/src/relay/qnn/pass/legalize.cc index 07864ad..33b9e59 100644 --- a/src/relay/qnn/pass/legalize.cc +++ b/src/relay/qnn/pass/legalize.cc @@ -38,7 +38,7 @@ Pass Legalize() { return seq; } -TVM_REGISTER_API("relay.qnn._transform.Legalize").set_body_typed(Legalize); +TVM_REGISTER_GLOBAL("relay.qnn._transform.Legalize").set_body_typed(Legalize); } // namespace transform -- 2.7.4