From 2f8a01f7071deae4503e9b730304a0e4551c9210 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Fri, 17 Jan 2020 15:11:55 -0800 Subject: [PATCH] [REFACTOR] Get rid of packed_func_ext. (#4735) Move the conversion extensions to the specific class definitions so that we longer need to include packed_func_ext. --- apps/extension/src/tvm_ext.cc | 2 +- include/tvm/expr.h | 21 +++++ include/tvm/ir/expr.h | 27 ++++++ include/tvm/node/container.h | 59 ++++++++++++ include/tvm/node/node.h | 3 + include/tvm/packed_func_ext.h | 145 ------------------------------ include/tvm/relay/transform.h | 3 +- include/tvm/relay/type.h | 2 - src/api/api_arith.cc | 1 - src/api/api_base.cc | 2 - src/api/api_codegen.cc | 2 - src/api/api_ir.cc | 1 - src/api/api_lang.cc | 1 - src/api/api_pass.cc | 2 - src/api/api_schedule.cc | 1 - src/api/api_test.cc | 2 - src/arith/bound_deducer.cc | 2 - src/arith/domain_touched.cc | 2 - src/arith/int_set.cc | 1 - src/autotvm/touch_extractor.h | 1 - src/codegen/codegen_c_host.cc | 2 +- src/codegen/codegen_c_host.h | 2 +- src/codegen/codegen_cuda.cc | 1 - src/codegen/codegen_cuda.h | 2 +- src/codegen/codegen_metal.cc | 1 - src/codegen/codegen_metal.h | 1 - src/codegen/codegen_opencl.cc | 1 - src/codegen/codegen_opencl.h | 1 - src/codegen/codegen_opengl.cc | 1 - src/codegen/codegen_opengl.h | 1 - src/codegen/codegen_vhls.h | 2 +- src/codegen/datatype/registry.cc | 8 +- src/codegen/intrin_rule.cc | 1 + src/codegen/intrin_rule.h | 3 - src/codegen/llvm/intrin_rule_llvm.h | 1 - src/codegen/llvm/intrin_rule_nvptx.cc | 2 - src/codegen/llvm/intrin_rule_rocm.cc | 1 - src/codegen/spirv/intrin_rule_spirv.cc | 1 - src/codegen/stackvm/codegen_stackvm.cc | 1 - src/ir/attr_functor.h | 1 + src/ir/attrs.cc | 2 - src/ir/error.cc | 10 +-- src/ir/expr.cc | 21 +++++ src/ir/module.cc | 11 ++- src/ir/span.cc | 1 - src/ir/type.cc | 2 - src/ir/type_relation.cc | 2 - src/pass/hoist_if_then_else.cc | 1 - src/pass/inject_copy_intrin.cc | 1 - src/pass/ir_functor.cc | 1 - src/pass/lower_custom_datatypes.cc | 1 - src/pass/lower_intrin.cc | 1 - src/pass/verify_gpu_code.cc | 1 - src/relay/backend/compile_engine.cc | 1 - src/relay/backend/graph_plan_memory.cc | 1 + src/relay/backend/interpreter.cc | 1 - src/relay/backend/param_dict.h | 2 +- src/relay/ir/base.cc | 1 - src/relay/ir/type.cc | 1 + src/relay/op/vision/multibox_op.cc | 1 + src/relay/pass/type_solver.cc | 1 + src/relay/qnn/util.h | 1 + src/top/operation/tensorize.cc | 1 - tests/cpp/attrs_test.cc | 1 - tests/cpp/build_module_test.cc | 1 - tests/cpp/container_test.cc | 2 +- tests/cpp/packed_func_test.cc | 2 - tests/cpp/relay_build_module_test.cc | 2 - tests/cpp/relay_transform_sequential.cc | 1 - tests/cpp/utvm_runtime_standalone_test.cc | 1 - topi/src/topi.cc | 2 +- 71 files changed, 161 insertions(+), 231 deletions(-) delete mode 100644 include/tvm/packed_func_ext.h diff --git a/apps/extension/src/tvm_ext.cc b/apps/extension/src/tvm_ext.cc index 7a685bf..b439deb 100644 --- a/apps/extension/src/tvm_ext.cc +++ b/apps/extension/src/tvm_ext.cc @@ -26,8 +26,8 @@ #include #include #include -#include #include +#include using namespace tvm; using namespace tvm::runtime; diff --git a/include/tvm/expr.h b/include/tvm/expr.h index bdd0b8f..3f154da 100644 --- a/include/tvm/expr.h +++ b/include/tvm/expr.h @@ -29,6 +29,7 @@ #include #include #include +#include #include "node/node.h" #include "node/container.h" #include "node/functor.h" @@ -460,6 +461,26 @@ inline std::unordered_map as_unordered_map(const Map& dmap) { } } // namespace tvm +namespace tvm { +namespace runtime { +// Additional implementattion overloads for PackedFunc. +inline TVMPODValue_::operator tvm::Integer() const { + if (type_code_ == kTVMNullptr) return Integer(); + if (type_code_ == kDLInt) { + CHECK_LE(value_.v_int64, std::numeric_limits::max()); + CHECK_GE(value_.v_int64, std::numeric_limits::min()); + return Integer(static_cast(value_.v_int64)); + } + TVM_CHECK_TYPE_CODE(type_code_, kTVMObjectHandle); + Object* ptr = static_cast(value_.v_handle); + CHECK(ObjectTypeChecker::Check(ptr)) + << "Expect type " << ObjectTypeChecker::TypeName() + << " but get " << ptr->GetTypeKey(); + return Integer(ObjectPtr(ptr)); +} +} // namespace runtime +} // namespace tvm + namespace std { template <> struct hash<::tvm::IterVar> : public ::tvm::ObjectHash { diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index 87122e8..e8e4597 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -30,6 +30,7 @@ #include #include #include +#include namespace tvm { @@ -114,6 +115,11 @@ class PrimExpr : public BaseExpr { } TVM_DEFINE_OBJECT_REF_METHODS(PrimExpr, BaseExpr, PrimExprNode); + + private: + // Internal function for conversion. + friend class runtime::TVMPODValue_; + TVM_DLL static PrimExpr FromObject_(ObjectPtr ptr); }; /*! @@ -322,4 +328,25 @@ inline const TTypeNode* RelayExprNode::type_as() const { } } // namespace tvm + +namespace tvm { +namespace runtime { +// Additional implementattion overloads for PackedFunc. +inline TVMPODValue_::operator tvm::PrimExpr() const { + if (type_code_ == kTVMNullptr) return PrimExpr(); + if (type_code_ == kDLInt) { + CHECK_LE(value_.v_int64, std::numeric_limits::max()); + CHECK_GE(value_.v_int64, std::numeric_limits::min()); + return PrimExpr(static_cast(value_.v_int64)); + } + if (type_code_ == kDLFloat) { + return PrimExpr(static_cast(value_.v_float64)); + } + + TVM_CHECK_TYPE_CODE(type_code_, kTVMObjectHandle); + Object* ptr = static_cast(value_.v_handle); + return PrimExpr::FromObject_(ObjectPtr(ptr)); +} +} // namespace runtime +} // namespace tvm #endif // TVM_IR_EXPR_H_ diff --git a/include/tvm/node/container.h b/include/tvm/node/container.h index 7686a96..f5c7198 100644 --- a/include/tvm/node/container.h +++ b/include/tvm/node/container.h @@ -655,6 +655,65 @@ class Map : public ObjectRef { return iterator(static_cast(data_.get())->data.find(key)); } }; +} // namespace tvm +namespace tvm { +namespace runtime { +// Additional overloads for PackedFunc checking. +template +struct ObjectTypeChecker > { + static bool Check(const Object* ptr) { + if (ptr == nullptr) return true; + if (!ptr->IsInstance()) return false; + const ArrayNode* n = static_cast(ptr); + for (const auto& p : n->data) { + if (!ObjectTypeChecker::Check(p.get())) { + return false; + } + } + return true; + } + static std::string TypeName() { + return "List[" + ObjectTypeChecker::TypeName() + "]"; + } +}; + +template +struct ObjectTypeChecker > { + static bool Check(const Object* ptr) { + if (ptr == nullptr) return true; + if (!ptr->IsInstance()) return false; + const StrMapNode* n = static_cast(ptr); + for (const auto& kv : n->data) { + if (!ObjectTypeChecker::Check(kv.second.get())) return false; + } + return true; + } + static std::string TypeName() { + return "Map[str, " + + ObjectTypeChecker::TypeName()+ ']'; + } +}; + +template +struct ObjectTypeChecker > { + static bool Check(const Object* ptr) { + if (ptr == nullptr) return true; + if (!ptr->IsInstance()) return false; + const MapNode* n = static_cast(ptr); + for (const auto& kv : n->data) { + if (!ObjectTypeChecker::Check(kv.first.get())) return false; + if (!ObjectTypeChecker::Check(kv.second.get())) return false; + } + return true; + } + static std::string TypeName() { + return "Map[" + + ObjectTypeChecker::TypeName() + + ", " + + ObjectTypeChecker::TypeName()+ ']'; + } +}; +} // namespace runtime } // namespace tvm #endif // TVM_NODE_CONTAINER_H_ diff --git a/include/tvm/node/node.h b/include/tvm/node/node.h index 54eb436..10c577a 100644 --- a/include/tvm/node/node.h +++ b/include/tvm/node/node.h @@ -56,6 +56,9 @@ using runtime::Downcast; using runtime::ObjectHash; using runtime::ObjectEqual; using runtime::make_object; +using runtime::PackedFunc; +using runtime::TVMArgs; +using runtime::TVMRetValue; } // namespace tvm #endif // TVM_NODE_NODE_H_ diff --git a/include/tvm/packed_func_ext.h b/include/tvm/packed_func_ext.h deleted file mode 100644 index f7b0d08..0000000 --- a/include/tvm/packed_func_ext.h +++ /dev/null @@ -1,145 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file tvm/packed_func_ext.h - * \brief Extension package to PackedFunc - * This enales pass ObjectRef types into/from PackedFunc. - */ -#ifndef TVM_PACKED_FUNC_EXT_H_ -#define TVM_PACKED_FUNC_EXT_H_ - -#include - -#include -#include -#include -#include - -#include "expr.h" -#include "runtime/packed_func.h" - -namespace tvm { - -using runtime::TVMArgs; -using runtime::TVMRetValue; -using runtime::PackedFunc; - -namespace runtime { - - -template -struct ObjectTypeChecker > { - static bool Check(const Object* ptr) { - if (ptr == nullptr) return true; - if (!ptr->IsInstance()) return false; - const ArrayNode* n = static_cast(ptr); - for (const auto& p : n->data) { - if (!ObjectTypeChecker::Check(p.get())) { - return false; - } - } - return true; - } - static std::string TypeName() { - return "List[" + ObjectTypeChecker::TypeName() + "]"; - } -}; - -template -struct ObjectTypeChecker > { - static bool Check(const Object* ptr) { - if (ptr == nullptr) return true; - if (!ptr->IsInstance()) return false; - const StrMapNode* n = static_cast(ptr); - for (const auto& kv : n->data) { - if (!ObjectTypeChecker::Check(kv.second.get())) return false; - } - return true; - } - static std::string TypeName() { - return "Map[str, " + - ObjectTypeChecker::TypeName()+ ']'; - } -}; - -template -struct ObjectTypeChecker > { - static bool Check(const Object* ptr) { - if (ptr == nullptr) return true; - if (!ptr->IsInstance()) return false; - const MapNode* n = static_cast(ptr); - for (const auto& kv : n->data) { - if (!ObjectTypeChecker::Check(kv.first.get())) return false; - if (!ObjectTypeChecker::Check(kv.second.get())) return false; - } - return true; - } - static std::string TypeName() { - return "Map[" + - ObjectTypeChecker::TypeName() + - ", " + - ObjectTypeChecker::TypeName()+ ']'; - } -}; - -// extensions for tvm arg value -inline TVMPODValue_::operator tvm::PrimExpr() const { - if (type_code_ == kTVMNullptr) return PrimExpr(); - if (type_code_ == kDLInt) { - CHECK_LE(value_.v_int64, std::numeric_limits::max()); - CHECK_GE(value_.v_int64, std::numeric_limits::min()); - return PrimExpr(static_cast(value_.v_int64)); - } - if (type_code_ == kDLFloat) { - return PrimExpr(static_cast(value_.v_float64)); - } - - TVM_CHECK_TYPE_CODE(type_code_, kTVMObjectHandle); - Object* ptr = static_cast(value_.v_handle); - - if (ptr->IsInstance()) { - return IterVar(ObjectPtr(ptr))->var; - } - if (ptr->IsInstance()) { - return top::Tensor(ObjectPtr(ptr))(); - } - CHECK(ObjectTypeChecker::Check(ptr)) - << "Expect type " << ObjectTypeChecker::TypeName() - << " but get " << ptr->GetTypeKey(); - return PrimExpr(ObjectPtr(ptr)); -} - -inline TVMPODValue_::operator tvm::Integer() const { - if (type_code_ == kTVMNullptr) return Integer(); - if (type_code_ == kDLInt) { - CHECK_LE(value_.v_int64, std::numeric_limits::max()); - CHECK_GE(value_.v_int64, std::numeric_limits::min()); - return Integer(static_cast(value_.v_int64)); - } - TVM_CHECK_TYPE_CODE(type_code_, kTVMObjectHandle); - Object* ptr = static_cast(value_.v_handle); - CHECK(ObjectTypeChecker::Check(ptr)) - << "Expect type " << ObjectTypeChecker::TypeName() - << " but get " << ptr->GetTypeKey(); - return Integer(ObjectPtr(ptr)); -} -} // namespace runtime -} // namespace tvm -#endif // TVM_PACKED_FUNC_EXT_H_ diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index 58cfbfc..8d886aa 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -24,7 +24,6 @@ #ifndef TVM_RELAY_TRANSFORM_H_ #define TVM_RELAY_TRANSFORM_H_ -#include #include #include #include @@ -184,7 +183,7 @@ TVM_DLL Pass InferType(); * * \return The pass. */ -TVM_DLL Pass EliminateCommonSubexpr(PackedFunc fskip = nullptr); +TVM_DLL Pass EliminateCommonSubexpr(runtime::PackedFunc fskip = nullptr); /*! * \brief Combine parallel 2d convolutions into a single convolution if the diff --git a/include/tvm/relay/type.h b/include/tvm/relay/type.h index d4243d8..0f81a1b 100644 --- a/include/tvm/relay/type.h +++ b/include/tvm/relay/type.h @@ -28,9 +28,7 @@ #include #include #include -#include #include - #include #include diff --git a/src/api/api_arith.cc b/src/api/api_arith.cc index 0c28d08..0062379 100644 --- a/src/api/api_arith.cc +++ b/src/api/api_arith.cc @@ -29,7 +29,6 @@ #include #include #include -#include #include diff --git a/src/api/api_base.cc b/src/api/api_base.cc index 4b74d02..9078507 100644 --- a/src/api/api_base.cc +++ b/src/api/api_base.cc @@ -25,8 +25,6 @@ #include #include #include -#include - #include namespace tvm { diff --git a/src/api/api_codegen.cc b/src/api/api_codegen.cc index 1d997a2..6c1d193 100644 --- a/src/api/api_codegen.cc +++ b/src/api/api_codegen.cc @@ -26,8 +26,6 @@ #include #include #include -#include - namespace tvm { namespace codegen { diff --git a/src/api/api_ir.cc b/src/api/api_ir.cc index 3b29ee4..45f7790 100644 --- a/src/api/api_ir.cc +++ b/src/api/api_ir.cc @@ -24,7 +24,6 @@ #include #include #include -#include #include diff --git a/src/api/api_lang.cc b/src/api/api_lang.cc index 89c2c53..cf8e2c3 100644 --- a/src/api/api_lang.cc +++ b/src/api/api_lang.cc @@ -28,7 +28,6 @@ #include #include #include -#include #include #include diff --git a/src/api/api_pass.cc b/src/api/api_pass.cc index a822cc1..2154ec5 100644 --- a/src/api/api_pass.cc +++ b/src/api/api_pass.cc @@ -27,8 +27,6 @@ #include #include #include -#include - namespace tvm { namespace ir { diff --git a/src/api/api_schedule.cc b/src/api/api_schedule.cc index 7aa305f..976f5a4 100644 --- a/src/api/api_schedule.cc +++ b/src/api/api_schedule.cc @@ -26,7 +26,6 @@ #include #include #include -#include #include "../top/schedule/graph.h" diff --git a/src/api/api_test.cc b/src/api/api_test.cc index 957a034..f63adb1 100644 --- a/src/api/api_test.cc +++ b/src/api/api_test.cc @@ -26,8 +26,6 @@ #include #include #include -#include - namespace tvm { // Attrs used to python API diff --git a/src/arith/bound_deducer.cc b/src/arith/bound_deducer.cc index f1da23b..d6cd47b 100644 --- a/src/arith/bound_deducer.cc +++ b/src/arith/bound_deducer.cc @@ -26,8 +26,6 @@ #include #include #include -#include - #include #include diff --git a/src/arith/domain_touched.cc b/src/arith/domain_touched.cc index 6e665c8..7db03c2 100644 --- a/src/arith/domain_touched.cc +++ b/src/arith/domain_touched.cc @@ -26,8 +26,6 @@ #include #include #include -#include - #include #include diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc index 82f0d2b..2d56596 100644 --- a/src/arith/int_set.cc +++ b/src/arith/int_set.cc @@ -25,7 +25,6 @@ #include #include #include -#include #include #include diff --git a/src/autotvm/touch_extractor.h b/src/autotvm/touch_extractor.h index 3af368d..360d761 100644 --- a/src/autotvm/touch_extractor.h +++ b/src/autotvm/touch_extractor.h @@ -28,7 +28,6 @@ #include #include #include -#include #include #include diff --git a/src/codegen/codegen_c_host.cc b/src/codegen/codegen_c_host.cc index 083b25b..b7cd13c 100644 --- a/src/codegen/codegen_c_host.cc +++ b/src/codegen/codegen_c_host.cc @@ -20,7 +20,7 @@ /*! * \file codegen_c_host.cc */ -#include +#include #include #include #include "codegen_c_host.h" diff --git a/src/codegen/codegen_c_host.h b/src/codegen/codegen_c_host.h index 43fe98d..94544a8 100644 --- a/src/codegen/codegen_c_host.h +++ b/src/codegen/codegen_c_host.h @@ -25,7 +25,7 @@ #define TVM_CODEGEN_CODEGEN_C_HOST_H_ #include -#include +#include #include #include "codegen_c.h" diff --git a/src/codegen/codegen_cuda.cc b/src/codegen/codegen_cuda.cc index b39965e..6f394a1 100644 --- a/src/codegen/codegen_cuda.cc +++ b/src/codegen/codegen_cuda.cc @@ -22,7 +22,6 @@ */ #include -#include #include #include #include diff --git a/src/codegen/codegen_cuda.h b/src/codegen/codegen_cuda.h index 23fbf7f..b7b7f84 100644 --- a/src/codegen/codegen_cuda.h +++ b/src/codegen/codegen_cuda.h @@ -25,7 +25,7 @@ #define TVM_CODEGEN_CODEGEN_CUDA_H_ #include -#include +#include #include #include #include "codegen_c.h" diff --git a/src/codegen/codegen_metal.cc b/src/codegen/codegen_metal.cc index 4e92fcb..234e628 100644 --- a/src/codegen/codegen_metal.cc +++ b/src/codegen/codegen_metal.cc @@ -20,7 +20,6 @@ /*! * \file codegen_metal.cc */ -#include #include #include #include diff --git a/src/codegen/codegen_metal.h b/src/codegen/codegen_metal.h index d9c5e95..6620789 100644 --- a/src/codegen/codegen_metal.h +++ b/src/codegen/codegen_metal.h @@ -25,7 +25,6 @@ #define TVM_CODEGEN_CODEGEN_METAL_H_ #include -#include #include #include "codegen_c.h" diff --git a/src/codegen/codegen_opencl.cc b/src/codegen/codegen_opencl.cc index ef90cfc..1a5107c 100644 --- a/src/codegen/codegen_opencl.cc +++ b/src/codegen/codegen_opencl.cc @@ -20,7 +20,6 @@ /*! * \file codegen_opencl.cc */ -#include #include #include #include diff --git a/src/codegen/codegen_opencl.h b/src/codegen/codegen_opencl.h index 07b28fd..5a8bf12 100644 --- a/src/codegen/codegen_opencl.h +++ b/src/codegen/codegen_opencl.h @@ -25,7 +25,6 @@ #define TVM_CODEGEN_CODEGEN_OPENCL_H_ #include -#include #include #include "codegen_c.h" diff --git a/src/codegen/codegen_opengl.cc b/src/codegen/codegen_opengl.cc index cea276d..373bf38 100644 --- a/src/codegen/codegen_opengl.cc +++ b/src/codegen/codegen_opengl.cc @@ -23,7 +23,6 @@ * We are targeting OpenGL 3.3. The reason of not targeting a recent version * of OpenGL is to have better compatibility of WebGL 2. */ -#include #include #include #include diff --git a/src/codegen/codegen_opengl.h b/src/codegen/codegen_opengl.h index 19ca2ee..e7cebe0 100644 --- a/src/codegen/codegen_opengl.h +++ b/src/codegen/codegen_opengl.h @@ -25,7 +25,6 @@ #define TVM_CODEGEN_CODEGEN_OPENGL_H_ #include -#include #include #include #include diff --git a/src/codegen/codegen_vhls.h b/src/codegen/codegen_vhls.h index e406cb5..0651089 100644 --- a/src/codegen/codegen_vhls.h +++ b/src/codegen/codegen_vhls.h @@ -25,7 +25,7 @@ #define TVM_CODEGEN_CODEGEN_VHLS_H_ #include -#include +#include #include #include "codegen_c.h" diff --git a/src/codegen/datatype/registry.cc b/src/codegen/datatype/registry.cc index 62a7550..b49b395 100644 --- a/src/codegen/datatype/registry.cc +++ b/src/codegen/datatype/registry.cc @@ -16,15 +16,15 @@ * specific language governing permissions and limitations * under the License. */ - -#include "registry.h" #include -#include - +#include "registry.h" namespace tvm { namespace datatype { +using runtime::TVMArgs; +using runtime::TVMRetValue; + TVM_REGISTER_GLOBAL("_datatype_register") .set_body([](TVMArgs args, TVMRetValue* ret) { datatype::Registry::Global()->Register(args[0], static_cast(args[1].operator int())); diff --git a/src/codegen/intrin_rule.cc b/src/codegen/intrin_rule.cc index 0609989..699abd8 100644 --- a/src/codegen/intrin_rule.cc +++ b/src/codegen/intrin_rule.cc @@ -21,6 +21,7 @@ * \file intrin_rule_default.cc * \brief Default intrinsic rules. */ +#include #include "intrin_rule.h" namespace tvm { diff --git a/src/codegen/intrin_rule.h b/src/codegen/intrin_rule.h index 56ba225..b6332f1 100644 --- a/src/codegen/intrin_rule.h +++ b/src/codegen/intrin_rule.h @@ -27,9 +27,6 @@ #include #include #include -#include - -#include #include namespace tvm { diff --git a/src/codegen/llvm/intrin_rule_llvm.h b/src/codegen/llvm/intrin_rule_llvm.h index 1f839f3..d81c33b 100644 --- a/src/codegen/llvm/intrin_rule_llvm.h +++ b/src/codegen/llvm/intrin_rule_llvm.h @@ -27,7 +27,6 @@ #include #include -#include #include #include diff --git a/src/codegen/llvm/intrin_rule_nvptx.cc b/src/codegen/llvm/intrin_rule_nvptx.cc index fcd8a1a..68475f0 100644 --- a/src/codegen/llvm/intrin_rule_nvptx.cc +++ b/src/codegen/llvm/intrin_rule_nvptx.cc @@ -25,8 +25,6 @@ #include #include #include -#include - #include namespace tvm { diff --git a/src/codegen/llvm/intrin_rule_rocm.cc b/src/codegen/llvm/intrin_rule_rocm.cc index 41035af..477bcb4 100644 --- a/src/codegen/llvm/intrin_rule_rocm.cc +++ b/src/codegen/llvm/intrin_rule_rocm.cc @@ -25,7 +25,6 @@ #include #include #include -#include #include diff --git a/src/codegen/spirv/intrin_rule_spirv.cc b/src/codegen/spirv/intrin_rule_spirv.cc index d96883e..aa69ebf 100644 --- a/src/codegen/spirv/intrin_rule_spirv.cc +++ b/src/codegen/spirv/intrin_rule_spirv.cc @@ -21,7 +21,6 @@ * \file intrin_rule_spirv.cc */ #include -#include #include #include diff --git a/src/codegen/stackvm/codegen_stackvm.cc b/src/codegen/stackvm/codegen_stackvm.cc index 8253007..c12f66f 100644 --- a/src/codegen/stackvm/codegen_stackvm.cc +++ b/src/codegen/stackvm/codegen_stackvm.cc @@ -21,7 +21,6 @@ * \file codegen_stackvm.cc */ #include -#include #include #include #include "codegen_stackvm.h" diff --git a/src/ir/attr_functor.h b/src/ir/attr_functor.h index 6872568..378e8a8 100644 --- a/src/ir/attr_functor.h +++ b/src/ir/attr_functor.h @@ -31,6 +31,7 @@ #define TVM_IR_ATTR_FUNCTOR_H_ #include +#include #include namespace tvm { diff --git a/src/ir/attrs.cc b/src/ir/attrs.cc index d9063fb..54f5ee2 100644 --- a/src/ir/attrs.cc +++ b/src/ir/attrs.cc @@ -22,8 +22,6 @@ */ #include #include -#include - #include "attr_functor.h" namespace tvm { diff --git a/src/ir/error.cc b/src/ir/error.cc index 99db14e..62faf50 100644 --- a/src/ir/error.cc +++ b/src/ir/error.cc @@ -24,11 +24,11 @@ #include #include -// NOTE on dependencies on relay AsText. -// We calls into relay's printing module for better rendering. -// These dependency does not happen at the interface-level. -// And is only used to enhance developer experiences when relay -// functions are presented. +// NOTE: reverse dependency on relay. +// These dependencies do not happen at the interface-level, +// and are only used in minimum cases where they are clearly marked. +// +// Rationale: use relay's printer for astext. #include #include diff --git a/src/ir/expr.cc b/src/ir/expr.cc index 0cf91c2..b173f4f 100644 --- a/src/ir/expr.cc +++ b/src/ir/expr.cc @@ -23,9 +23,30 @@ */ #include #include +// NOTE: reverse dependency on top/tir. +// These dependencies do not happen at the interface-level, +// and are only used in minimum cases where they are clearly marked. +// +// Rationale: convert from IterVar and top::Tensor +#include +#include namespace tvm { +PrimExpr PrimExpr::FromObject_(ObjectPtr ptr) { + using runtime::ObjectTypeChecker; + if (ptr->IsInstance()) { + return IterVar(ptr)->var; + } + if (ptr->IsInstance()) { + return top::Tensor(ptr)(); + } + CHECK(ObjectTypeChecker::Check(ptr.get())) + << "Expect type " << ObjectTypeChecker::TypeName() + << " but get " << ptr->GetTypeKey(); + return PrimExpr(ptr); +} + IntImm::IntImm(DataType dtype, int64_t value) { CHECK(dtype.is_scalar()) << "ValueError: IntImm can only take scalar."; diff --git a/src/ir/module.cc b/src/ir/module.cc index 09abac7..01a8baa 100644 --- a/src/ir/module.cc +++ b/src/ir/module.cc @@ -23,12 +23,11 @@ */ #include #include -// NOTE on dependencies on relay analysis. -// We calls into relay's analysis module to verify correctness -// when a relay function is presented. -// These dependency does not happen at the interface-level. -// And is only used to enhance developer experiences when relay -// functions are presented. +// NOTE: reverse dependency on relay. +// These dependencies do not happen at the interface-level, +// and are only used in minimum cases where they are clearly marked. +// +// Rationale: We calls into relay's analysis module to verify correctness. #include #include diff --git a/src/ir/span.cc b/src/ir/span.cc index 1be4e32..2519321 100644 --- a/src/ir/span.cc +++ b/src/ir/span.cc @@ -22,7 +22,6 @@ */ #include #include -#include namespace tvm { diff --git a/src/ir/type.cc b/src/ir/type.cc index 4ba1607..9e250db 100644 --- a/src/ir/type.cc +++ b/src/ir/type.cc @@ -23,8 +23,6 @@ */ #include #include -#include - namespace tvm { PrimType::PrimType(runtime::DataType dtype) { diff --git a/src/ir/type_relation.cc b/src/ir/type_relation.cc index 06d665e..361525c 100644 --- a/src/ir/type_relation.cc +++ b/src/ir/type_relation.cc @@ -24,8 +24,6 @@ #include #include #include -#include - namespace tvm { TypeCall::TypeCall(Type func, tvm::Array args) { diff --git a/src/pass/hoist_if_then_else.cc b/src/pass/hoist_if_then_else.cc index 78d743a..302abea 100644 --- a/src/pass/hoist_if_then_else.cc +++ b/src/pass/hoist_if_then_else.cc @@ -24,7 +24,6 @@ #include #include #include -#include #include #include diff --git a/src/pass/inject_copy_intrin.cc b/src/pass/inject_copy_intrin.cc index af83f47..29bb5b4 100644 --- a/src/pass/inject_copy_intrin.cc +++ b/src/pass/inject_copy_intrin.cc @@ -23,7 +23,6 @@ */ #include #include -#include #include #include #include "../arith/pattern_match.h" diff --git a/src/pass/ir_functor.cc b/src/pass/ir_functor.cc index 2c99674..7292df6 100644 --- a/src/pass/ir_functor.cc +++ b/src/pass/ir_functor.cc @@ -20,7 +20,6 @@ * \file ir_functor.cc */ #include -#include namespace tvm { namespace ir { diff --git a/src/pass/lower_custom_datatypes.cc b/src/pass/lower_custom_datatypes.cc index 98eaf8c..b494328 100644 --- a/src/pass/lower_custom_datatypes.cc +++ b/src/pass/lower_custom_datatypes.cc @@ -23,7 +23,6 @@ #include #include -#include #include "../codegen/datatype/registry.h" namespace tvm { diff --git a/src/pass/lower_intrin.cc b/src/pass/lower_intrin.cc index 7172a4a..4e1ea8d 100644 --- a/src/pass/lower_intrin.cc +++ b/src/pass/lower_intrin.cc @@ -24,7 +24,6 @@ #include #include #include -#include #include #include diff --git a/src/pass/verify_gpu_code.cc b/src/pass/verify_gpu_code.cc index 24f3e19..f9c183e 100644 --- a/src/pass/verify_gpu_code.cc +++ b/src/pass/verify_gpu_code.cc @@ -25,7 +25,6 @@ */ #include -#include #include #include diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc index 9336782..d32a6e2 100644 --- a/src/relay/backend/compile_engine.cc +++ b/src/relay/backend/compile_engine.cc @@ -24,7 +24,6 @@ #include "compile_engine.h" #include -#include #include #include #include diff --git a/src/relay/backend/graph_plan_memory.cc b/src/relay/backend/graph_plan_memory.cc index de03d97..fd41655 100644 --- a/src/relay/backend/graph_plan_memory.cc +++ b/src/relay/backend/graph_plan_memory.cc @@ -22,6 +22,7 @@ * \brief Memory index assignment pass for executing * the program in the graph runtime. */ +#include #include #include #include diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc index ff9dbba..7fdfdbb 100644 --- a/src/relay/backend/interpreter.cc +++ b/src/relay/backend/interpreter.cc @@ -21,7 +21,6 @@ * \file src/tvm/relay/interpreter.cc * \brief An interpreter for the Relay IR. */ -#include #include #include #include diff --git a/src/relay/backend/param_dict.h b/src/relay/backend/param_dict.h index e2d225a..dd42eb3 100644 --- a/src/relay/backend/param_dict.h +++ b/src/relay/backend/param_dict.h @@ -25,7 +25,7 @@ #define TVM_RELAY_BACKEND_PARAM_DICT_H_ #include -#include +#include #include #include diff --git a/src/relay/ir/base.cc b/src/relay/ir/base.cc index 82b3513..85b17b5 100644 --- a/src/relay/ir/base.cc +++ b/src/relay/ir/base.cc @@ -24,7 +24,6 @@ #include #include -#include #include namespace tvm { diff --git a/src/relay/ir/type.cc b/src/relay/ir/type.cc index 5680a78..a3b4668 100644 --- a/src/relay/ir/type.cc +++ b/src/relay/ir/type.cc @@ -22,6 +22,7 @@ * \brief The type system AST nodes of Relay. */ #include +#include namespace tvm { namespace relay { diff --git a/src/relay/op/vision/multibox_op.cc b/src/relay/op/vision/multibox_op.cc index 6a1b34d..d837e99 100644 --- a/src/relay/op/vision/multibox_op.cc +++ b/src/relay/op/vision/multibox_op.cc @@ -21,6 +21,7 @@ * \file multibox_op.cc * \brief Multibox related operators */ +#include #include #include diff --git a/src/relay/pass/type_solver.cc b/src/relay/pass/type_solver.cc index 30a9a5c..be5ac51 100644 --- a/src/relay/pass/type_solver.cc +++ b/src/relay/pass/type_solver.cc @@ -21,6 +21,7 @@ * \file type_solver.cc * \brief Type solver implementations. */ +#include #include #include #include diff --git a/src/relay/qnn/util.h b/src/relay/qnn/util.h index 2316bed..6c99ae1 100644 --- a/src/relay/qnn/util.h +++ b/src/relay/qnn/util.h @@ -26,6 +26,7 @@ #define TVM_RELAY_QNN_UTIL_H_ #include +#include #include #include #include diff --git a/src/top/operation/tensorize.cc b/src/top/operation/tensorize.cc index 413bb42..e7f6b33 100644 --- a/src/top/operation/tensorize.cc +++ b/src/top/operation/tensorize.cc @@ -25,7 +25,6 @@ #include #include #include -#include #include "op_util.h" #include "compute_op.h" diff --git a/tests/cpp/attrs_test.cc b/tests/cpp/attrs_test.cc index 730d204..b87576e 100644 --- a/tests/cpp/attrs_test.cc +++ b/tests/cpp/attrs_test.cc @@ -21,7 +21,6 @@ #include #include #include -#include #include namespace tvm { diff --git a/tests/cpp/build_module_test.cc b/tests/cpp/build_module_test.cc index 30834c5..31d82f0 100644 --- a/tests/cpp/build_module_test.cc +++ b/tests/cpp/build_module_test.cc @@ -22,7 +22,6 @@ #include #include #include -#include #include #include diff --git a/tests/cpp/container_test.cc b/tests/cpp/container_test.cc index d5d8aae..5988b2a 100644 --- a/tests/cpp/container_test.cc +++ b/tests/cpp/container_test.cc @@ -19,7 +19,7 @@ #include #include -#include +#include #include #include #include diff --git a/tests/cpp/packed_func_test.cc b/tests/cpp/packed_func_test.cc index f4f9601..550c93e 100644 --- a/tests/cpp/packed_func_test.cc +++ b/tests/cpp/packed_func_test.cc @@ -21,8 +21,6 @@ #include #include #include -#include -#include #include TEST(PackedFunc, Basic) { diff --git a/tests/cpp/relay_build_module_test.cc b/tests/cpp/relay_build_module_test.cc index 462d0fe..bf0e338 100644 --- a/tests/cpp/relay_build_module_test.cc +++ b/tests/cpp/relay_build_module_test.cc @@ -28,8 +28,6 @@ #include #include #include -#include - TVM_REGISTER_GLOBAL("test.sch") .set_body([](tvm::TVMArgs args, tvm::TVMRetValue *rv) { diff --git a/tests/cpp/relay_transform_sequential.cc b/tests/cpp/relay_transform_sequential.cc index 4c383b5..4171f9d 100644 --- a/tests/cpp/relay_transform_sequential.cc +++ b/tests/cpp/relay_transform_sequential.cc @@ -20,7 +20,6 @@ #include #include #include -#include #include #include #include diff --git a/tests/cpp/utvm_runtime_standalone_test.cc b/tests/cpp/utvm_runtime_standalone_test.cc index 73a6245..bde3245 100644 --- a/tests/cpp/utvm_runtime_standalone_test.cc +++ b/tests/cpp/utvm_runtime_standalone_test.cc @@ -33,7 +33,6 @@ #include #include #include -#include #include #include #include diff --git a/topi/src/topi.cc b/topi/src/topi.cc index 8197e89..7ae4d88 100644 --- a/topi/src/topi.cc +++ b/topi/src/topi.cc @@ -26,7 +26,7 @@ #include #include #include -#include +#include #include #include -- 2.7.4