[REFACTOR] Get rid of packed_func_ext. (#4735)
authorTianqi Chen <tqchen@users.noreply.github.com>
Fri, 17 Jan 2020 23:11:55 +0000 (15:11 -0800)
committerGitHub <noreply@github.com>
Fri, 17 Jan 2020 23:11:55 +0000 (15:11 -0800)
Move the conversion extensions to the specific class definitions
so that we longer need to include packed_func_ext.

71 files changed:
apps/extension/src/tvm_ext.cc
include/tvm/expr.h
include/tvm/ir/expr.h
include/tvm/node/container.h
include/tvm/node/node.h
include/tvm/packed_func_ext.h [deleted file]
include/tvm/relay/transform.h
include/tvm/relay/type.h
src/api/api_arith.cc
src/api/api_base.cc
src/api/api_codegen.cc
src/api/api_ir.cc
src/api/api_lang.cc
src/api/api_pass.cc
src/api/api_schedule.cc
src/api/api_test.cc
src/arith/bound_deducer.cc
src/arith/domain_touched.cc
src/arith/int_set.cc
src/autotvm/touch_extractor.h
src/codegen/codegen_c_host.cc
src/codegen/codegen_c_host.h
src/codegen/codegen_cuda.cc
src/codegen/codegen_cuda.h
src/codegen/codegen_metal.cc
src/codegen/codegen_metal.h
src/codegen/codegen_opencl.cc
src/codegen/codegen_opencl.h
src/codegen/codegen_opengl.cc
src/codegen/codegen_opengl.h
src/codegen/codegen_vhls.h
src/codegen/datatype/registry.cc
src/codegen/intrin_rule.cc
src/codegen/intrin_rule.h
src/codegen/llvm/intrin_rule_llvm.h
src/codegen/llvm/intrin_rule_nvptx.cc
src/codegen/llvm/intrin_rule_rocm.cc
src/codegen/spirv/intrin_rule_spirv.cc
src/codegen/stackvm/codegen_stackvm.cc
src/ir/attr_functor.h
src/ir/attrs.cc
src/ir/error.cc
src/ir/expr.cc
src/ir/module.cc
src/ir/span.cc
src/ir/type.cc
src/ir/type_relation.cc
src/pass/hoist_if_then_else.cc
src/pass/inject_copy_intrin.cc
src/pass/ir_functor.cc
src/pass/lower_custom_datatypes.cc
src/pass/lower_intrin.cc
src/pass/verify_gpu_code.cc
src/relay/backend/compile_engine.cc
src/relay/backend/graph_plan_memory.cc
src/relay/backend/interpreter.cc
src/relay/backend/param_dict.h
src/relay/ir/base.cc
src/relay/ir/type.cc
src/relay/op/vision/multibox_op.cc
src/relay/pass/type_solver.cc
src/relay/qnn/util.h
src/top/operation/tensorize.cc
tests/cpp/attrs_test.cc
tests/cpp/build_module_test.cc
tests/cpp/container_test.cc
tests/cpp/packed_func_test.cc
tests/cpp/relay_build_module_test.cc
tests/cpp/relay_transform_sequential.cc
tests/cpp/utvm_runtime_standalone_test.cc
topi/src/topi.cc

index 7a685bf..b439deb 100644 (file)
@@ -26,8 +26,8 @@
 #include <tvm/runtime/module.h>
 #include <tvm/runtime/registry.h>
 #include <tvm/runtime/ndarray.h>
-#include <tvm/packed_func_ext.h>
 #include <tvm/runtime/device_api.h>
+#include <tvm/expr_operator.h>
 
 using namespace tvm;
 using namespace tvm::runtime;
index bdd0b8f..3f154da 100644 (file)
@@ -29,6 +29,7 @@
 #include <algorithm>
 #include <unordered_map>
 #include <iostream>
+#include <limits>
 #include "node/node.h"
 #include "node/container.h"
 #include "node/functor.h"
@@ -460,6 +461,26 @@ inline std::unordered_map<K, V> as_unordered_map(const Map<K, V>& dmap) {
 }
 }  // namespace tvm
 
+namespace tvm {
+namespace runtime {
+// Additional implementattion overloads for PackedFunc.
+inline TVMPODValue_::operator tvm::Integer() const {
+  if (type_code_ == kTVMNullptr) return Integer();
+  if (type_code_ == kDLInt) {
+    CHECK_LE(value_.v_int64, std::numeric_limits<int>::max());
+    CHECK_GE(value_.v_int64, std::numeric_limits<int>::min());
+    return Integer(static_cast<int>(value_.v_int64));
+  }
+  TVM_CHECK_TYPE_CODE(type_code_, kTVMObjectHandle);
+  Object* ptr = static_cast<Object*>(value_.v_handle);
+  CHECK(ObjectTypeChecker<Integer>::Check(ptr))
+      << "Expect type " << ObjectTypeChecker<PrimExpr>::TypeName()
+      << " but get " << ptr->GetTypeKey();
+  return Integer(ObjectPtr<Object>(ptr));
+}
+}  // namespace runtime
+}  // namespace tvm
+
 namespace std {
 template <>
 struct hash<::tvm::IterVar> : public ::tvm::ObjectHash {
index 87122e8..e8e4597 100644 (file)
@@ -30,6 +30,7 @@
 #include <tvm/ir/span.h>
 #include <tvm/ir/type.h>
 #include <string>
+#include <limits>
 
 namespace tvm {
 
@@ -114,6 +115,11 @@ class PrimExpr : public BaseExpr {
   }
 
   TVM_DEFINE_OBJECT_REF_METHODS(PrimExpr, BaseExpr, PrimExprNode);
+
+ private:
+  // Internal function for conversion.
+  friend class runtime::TVMPODValue_;
+  TVM_DLL static PrimExpr FromObject_(ObjectPtr<Object> ptr);
 };
 
 /*!
@@ -322,4 +328,25 @@ inline const TTypeNode* RelayExprNode::type_as() const {
 }
 
 }  // namespace tvm
+
+namespace tvm {
+namespace runtime {
+// Additional implementattion overloads for PackedFunc.
+inline TVMPODValue_::operator tvm::PrimExpr() const {
+  if (type_code_ == kTVMNullptr) return PrimExpr();
+  if (type_code_ == kDLInt) {
+    CHECK_LE(value_.v_int64, std::numeric_limits<int>::max());
+    CHECK_GE(value_.v_int64, std::numeric_limits<int>::min());
+    return PrimExpr(static_cast<int>(value_.v_int64));
+  }
+  if (type_code_ == kDLFloat) {
+    return PrimExpr(static_cast<float>(value_.v_float64));
+  }
+
+  TVM_CHECK_TYPE_CODE(type_code_, kTVMObjectHandle);
+  Object* ptr = static_cast<Object*>(value_.v_handle);
+  return PrimExpr::FromObject_(ObjectPtr<Object>(ptr));
+}
+}  // namespace runtime
+}  // namespace tvm
 #endif  // TVM_IR_EXPR_H_
index 7686a96..f5c7198 100644 (file)
@@ -655,6 +655,65 @@ class Map<std::string, V, T1, T2> : public ObjectRef {
     return iterator(static_cast<const StrMapNode*>(data_.get())->data.find(key));
   }
 };
+}  // namespace tvm
 
+namespace tvm {
+namespace runtime {
+// Additional overloads for PackedFunc checking.
+template<typename T>
+struct ObjectTypeChecker<Array<T> > {
+  static bool Check(const Object* ptr) {
+    if (ptr == nullptr) return true;
+    if (!ptr->IsInstance<ArrayNode>()) return false;
+    const ArrayNode* n = static_cast<const ArrayNode*>(ptr);
+    for (const auto& p : n->data) {
+      if (!ObjectTypeChecker<T>::Check(p.get())) {
+        return false;
+      }
+    }
+    return true;
+  }
+  static std::string TypeName() {
+    return "List[" + ObjectTypeChecker<T>::TypeName() + "]";
+  }
+};
+
+template<typename V>
+struct ObjectTypeChecker<Map<std::string, V> > {
+  static bool Check(const Object* ptr) {
+    if (ptr == nullptr) return true;
+    if (!ptr->IsInstance<StrMapNode>()) return false;
+    const StrMapNode* n = static_cast<const StrMapNode*>(ptr);
+    for (const auto& kv : n->data) {
+      if (!ObjectTypeChecker<V>::Check(kv.second.get())) return false;
+    }
+    return true;
+  }
+  static std::string TypeName() {
+    return "Map[str, " +
+        ObjectTypeChecker<V>::TypeName()+ ']';
+  }
+};
+
+template<typename K, typename V>
+struct ObjectTypeChecker<Map<K, V> > {
+  static bool Check(const Object* ptr) {
+    if (ptr == nullptr) return true;
+    if (!ptr->IsInstance<MapNode>()) return false;
+    const MapNode* n = static_cast<const MapNode*>(ptr);
+    for (const auto& kv : n->data) {
+      if (!ObjectTypeChecker<K>::Check(kv.first.get())) return false;
+      if (!ObjectTypeChecker<V>::Check(kv.second.get())) return false;
+    }
+    return true;
+  }
+  static std::string TypeName() {
+    return "Map[" +
+        ObjectTypeChecker<K>::TypeName() +
+        ", " +
+        ObjectTypeChecker<V>::TypeName()+ ']';
+  }
+};
+}  // namespace runtime
 }  // namespace tvm
 #endif  // TVM_NODE_CONTAINER_H_
index 54eb436..10c577a 100644 (file)
@@ -56,6 +56,9 @@ using runtime::Downcast;
 using runtime::ObjectHash;
 using runtime::ObjectEqual;
 using runtime::make_object;
+using runtime::PackedFunc;
+using runtime::TVMArgs;
+using runtime::TVMRetValue;
 
 }  // namespace tvm
 #endif  // TVM_NODE_NODE_H_
diff --git a/include/tvm/packed_func_ext.h b/include/tvm/packed_func_ext.h
deleted file mode 100644 (file)
index f7b0d08..0000000
+++ /dev/null
@@ -1,145 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file tvm/packed_func_ext.h
- * \brief Extension package to PackedFunc
- *   This enales pass ObjectRef types into/from PackedFunc.
- */
-#ifndef TVM_PACKED_FUNC_EXT_H_
-#define TVM_PACKED_FUNC_EXT_H_
-
-#include <tvm/top/tensor.h>
-
-#include <string>
-#include <memory>
-#include <limits>
-#include <type_traits>
-
-#include "expr.h"
-#include "runtime/packed_func.h"
-
-namespace tvm {
-
-using runtime::TVMArgs;
-using runtime::TVMRetValue;
-using runtime::PackedFunc;
-
-namespace runtime {
-
-
-template<typename T>
-struct ObjectTypeChecker<Array<T> > {
-  static bool Check(const Object* ptr) {
-    if (ptr == nullptr) return true;
-    if (!ptr->IsInstance<ArrayNode>()) return false;
-    const ArrayNode* n = static_cast<const ArrayNode*>(ptr);
-    for (const auto& p : n->data) {
-      if (!ObjectTypeChecker<T>::Check(p.get())) {
-        return false;
-      }
-    }
-    return true;
-  }
-  static std::string TypeName() {
-    return "List[" + ObjectTypeChecker<T>::TypeName() + "]";
-  }
-};
-
-template<typename V>
-struct ObjectTypeChecker<Map<std::string, V> > {
-  static bool Check(const Object* ptr) {
-    if (ptr == nullptr) return true;
-    if (!ptr->IsInstance<StrMapNode>()) return false;
-    const StrMapNode* n = static_cast<const StrMapNode*>(ptr);
-    for (const auto& kv : n->data) {
-      if (!ObjectTypeChecker<V>::Check(kv.second.get())) return false;
-    }
-    return true;
-  }
-  static std::string TypeName() {
-    return "Map[str, " +
-        ObjectTypeChecker<V>::TypeName()+ ']';
-  }
-};
-
-template<typename K, typename V>
-struct ObjectTypeChecker<Map<K, V> > {
-  static bool Check(const Object* ptr) {
-    if (ptr == nullptr) return true;
-    if (!ptr->IsInstance<MapNode>()) return false;
-    const MapNode* n = static_cast<const MapNode*>(ptr);
-    for (const auto& kv : n->data) {
-      if (!ObjectTypeChecker<K>::Check(kv.first.get())) return false;
-      if (!ObjectTypeChecker<V>::Check(kv.second.get())) return false;
-    }
-    return true;
-  }
-  static std::string TypeName() {
-    return "Map[" +
-        ObjectTypeChecker<K>::TypeName() +
-        ", " +
-        ObjectTypeChecker<V>::TypeName()+ ']';
-  }
-};
-
-// extensions for tvm arg value
-inline TVMPODValue_::operator tvm::PrimExpr() const {
-  if (type_code_ == kTVMNullptr) return PrimExpr();
-  if (type_code_ == kDLInt) {
-    CHECK_LE(value_.v_int64, std::numeric_limits<int>::max());
-    CHECK_GE(value_.v_int64, std::numeric_limits<int>::min());
-    return PrimExpr(static_cast<int>(value_.v_int64));
-  }
-  if (type_code_ == kDLFloat) {
-    return PrimExpr(static_cast<float>(value_.v_float64));
-  }
-
-  TVM_CHECK_TYPE_CODE(type_code_, kTVMObjectHandle);
-  Object* ptr = static_cast<Object*>(value_.v_handle);
-
-  if (ptr->IsInstance<IterVarNode>()) {
-    return IterVar(ObjectPtr<Object>(ptr))->var;
-  }
-  if (ptr->IsInstance<top::TensorNode>()) {
-    return top::Tensor(ObjectPtr<Object>(ptr))();
-  }
-  CHECK(ObjectTypeChecker<PrimExpr>::Check(ptr))
-      << "Expect type " << ObjectTypeChecker<PrimExpr>::TypeName()
-      << " but get " << ptr->GetTypeKey();
-  return PrimExpr(ObjectPtr<Object>(ptr));
-}
-
-inline TVMPODValue_::operator tvm::Integer() const {
-  if (type_code_ == kTVMNullptr) return Integer();
-  if (type_code_ == kDLInt) {
-    CHECK_LE(value_.v_int64, std::numeric_limits<int>::max());
-    CHECK_GE(value_.v_int64, std::numeric_limits<int>::min());
-    return Integer(static_cast<int>(value_.v_int64));
-  }
-  TVM_CHECK_TYPE_CODE(type_code_, kTVMObjectHandle);
-  Object* ptr = static_cast<Object*>(value_.v_handle);
-  CHECK(ObjectTypeChecker<Integer>::Check(ptr))
-      << "Expect type " << ObjectTypeChecker<PrimExpr>::TypeName()
-      << " but get " << ptr->GetTypeKey();
-  return Integer(ObjectPtr<Object>(ptr));
-}
-}  // namespace runtime
-}  // namespace tvm
-#endif  // TVM_PACKED_FUNC_EXT_H_
index 58cfbfc..8d886aa 100644 (file)
@@ -24,7 +24,6 @@
 #ifndef TVM_RELAY_TRANSFORM_H_
 #define TVM_RELAY_TRANSFORM_H_
 
-#include <tvm/packed_func_ext.h>
 #include <tvm/relay/attrs/transform.h>
 #include <tvm/ir/transform.h>
 #include <tvm/relay/expr.h>
@@ -184,7 +183,7 @@ TVM_DLL Pass InferType();
  *
  * \return The pass.
  */
-TVM_DLL Pass EliminateCommonSubexpr(PackedFunc fskip = nullptr);
+TVM_DLL Pass EliminateCommonSubexpr(runtime::PackedFunc fskip = nullptr);
 
 /*!
  * \brief Combine parallel 2d convolutions into a single convolution if the
index d4243d8..0f81a1b 100644 (file)
@@ -28,9 +28,7 @@
 #include <tvm/ir/type_relation.h>
 #include <tvm/ir/attrs.h>
 #include <tvm/runtime/registry.h>
-#include <tvm/packed_func_ext.h>
 #include <tvm/ir/env_func.h>
-
 #include <tvm/ir.h>
 #include <string>
 
index 0c28d08..0062379 100644 (file)
@@ -29,7 +29,6 @@
 #include <tvm/expr.h>
 #include <tvm/ir.h>
 #include <tvm/runtime/registry.h>
-#include <tvm/packed_func_ext.h>
 
 #include <tvm/top/tensor.h>
 
index 4b74d02..9078507 100644 (file)
@@ -25,8 +25,6 @@
 #include <tvm/expr.h>
 #include <tvm/top/tensor.h>
 #include <tvm/runtime/registry.h>
-#include <tvm/packed_func_ext.h>
-
 #include <tvm/node/serialization.h>
 
 namespace tvm {
index 1d997a2..6c1d193 100644 (file)
@@ -26,8 +26,6 @@
 #include <tvm/codegen.h>
 #include <tvm/lowered_func.h>
 #include <tvm/runtime/registry.h>
-#include <tvm/packed_func_ext.h>
-
 
 namespace tvm {
 namespace codegen {
index 3b29ee4..45f7790 100644 (file)
@@ -24,7 +24,6 @@
 #include <tvm/expr.h>
 #include <tvm/ir.h>
 #include <tvm/runtime/registry.h>
-#include <tvm/packed_func_ext.h>
 
 #include <tvm/expr_operator.h>
 
index 89c2c53..cf8e2c3 100644 (file)
@@ -28,7 +28,6 @@
 #include <tvm/buffer.h>
 #include <tvm/top/schedule.h>
 #include <tvm/runtime/registry.h>
-#include <tvm/packed_func_ext.h>
 
 #include <tvm/build_module.h>
 #include <tvm/data_layout.h>
index a822cc1..2154ec5 100644 (file)
@@ -27,8 +27,6 @@
 #include <tvm/ir_pass.h>
 #include <tvm/ir_functor_ext.h>
 #include <tvm/runtime/registry.h>
-#include <tvm/packed_func_ext.h>
-
 
 namespace tvm {
 namespace ir {
index 7aa305f..976f5a4 100644 (file)
@@ -26,7 +26,6 @@
 #include <tvm/top/schedule.h>
 #include <tvm/top/schedule_pass.h>
 #include <tvm/runtime/registry.h>
-#include <tvm/packed_func_ext.h>
 
 #include "../top/schedule/graph.h"
 
index 957a034..f63adb1 100644 (file)
@@ -26,8 +26,6 @@
 #include <tvm/ir/attrs.h>
 #include <tvm/runtime/registry.h>
 #include <tvm/ir/env_func.h>
-#include <tvm/packed_func_ext.h>
-
 
 namespace tvm {
 // Attrs used to python API
index f1da23b..d6cd47b 100644 (file)
@@ -26,8 +26,6 @@
 #include <tvm/ir_functor_ext.h>
 #include <tvm/arith/analyzer.h>
 #include <tvm/runtime/registry.h>
-#include <tvm/packed_func_ext.h>
-
 
 #include <unordered_set>
 #include <unordered_map>
index 6e665c8..7db03c2 100644 (file)
@@ -26,8 +26,6 @@
 #include <tvm/ir_functor_ext.h>
 #include <tvm/top/tensor.h>
 #include <tvm/runtime/registry.h>
-#include <tvm/packed_func_ext.h>
-
 
 #include <unordered_set>
 #include <unordered_map>
index 82f0d2b..2d56596 100644 (file)
@@ -25,7 +25,6 @@
 #include <tvm/ir.h>
 #include <tvm/ir_functor_ext.h>
 #include <tvm/runtime/registry.h>
-#include <tvm/packed_func_ext.h>
 
 #include <utility>
 #include <algorithm>
index 3af368d..360d761 100644 (file)
@@ -28,7 +28,6 @@
 #include <tvm/ir.h>
 #include <tvm/ir_functor_ext.h>
 #include <tvm/runtime/registry.h>
-#include <tvm/packed_func_ext.h>
 
 #include <stack>
 #include <vector>
index 083b25b..b7cd13c 100644 (file)
@@ -20,7 +20,7 @@
 /*!
  * \file codegen_c_host.cc
  */
-#include <tvm/packed_func_ext.h>
+#include <tvm/codegen.h>
 #include <vector>
 #include <string>
 #include "codegen_c_host.h"
index 43fe98d..94544a8 100644 (file)
@@ -25,7 +25,7 @@
 #define TVM_CODEGEN_CODEGEN_C_HOST_H_
 
 #include <tvm/codegen.h>
-#include <tvm/packed_func_ext.h>
+#include <tvm/ir.h>
 #include <string>
 #include "codegen_c.h"
 
index b39965e..6f394a1 100644 (file)
@@ -22,7 +22,6 @@
  */
 
 #include <tvm/runtime/registry.h>
-#include <tvm/packed_func_ext.h>
 #include <cmath>
 #include <vector>
 #include <string>
index 23fbf7f..b7b7f84 100644 (file)
@@ -25,7 +25,7 @@
 #define TVM_CODEGEN_CODEGEN_CUDA_H_
 
 #include <tvm/codegen.h>
-#include <tvm/packed_func_ext.h>
+#include <tvm/ir.h>
 #include <string>
 #include <unordered_map>
 #include "codegen_c.h"
index 4e92fcb..234e628 100644 (file)
@@ -20,7 +20,6 @@
 /*!
  * \file codegen_metal.cc
  */
-#include <tvm/packed_func_ext.h>
 #include <vector>
 #include <string>
 #include <algorithm>
index d9c5e95..6620789 100644 (file)
@@ -25,7 +25,6 @@
 #define TVM_CODEGEN_CODEGEN_METAL_H_
 
 #include <tvm/codegen.h>
-#include <tvm/packed_func_ext.h>
 #include <string>
 #include "codegen_c.h"
 
index ef90cfc..1a5107c 100644 (file)
@@ -20,7 +20,6 @@
 /*!
  * \file codegen_opencl.cc
  */
-#include <tvm/packed_func_ext.h>
 #include <cmath>
 #include <vector>
 #include <string>
index 07b28fd..5a8bf12 100644 (file)
@@ -25,7 +25,6 @@
 #define TVM_CODEGEN_CODEGEN_OPENCL_H_
 
 #include <tvm/codegen.h>
-#include <tvm/packed_func_ext.h>
 #include <string>
 #include "codegen_c.h"
 
index cea276d..373bf38 100644 (file)
@@ -23,7 +23,6 @@
  * We are targeting OpenGL 3.3. The reason of not targeting a recent version
  * of OpenGL is to have better compatibility of WebGL 2.
  */
-#include <tvm/packed_func_ext.h>
 #include <vector>
 #include <string>
 #include <utility>
index 19ca2ee..e7cebe0 100644 (file)
@@ -25,7 +25,6 @@
 #define TVM_CODEGEN_CODEGEN_OPENGL_H_
 
 #include <tvm/codegen.h>
-#include <tvm/packed_func_ext.h>
 #include <string>
 #include <unordered_set>
 #include <unordered_map>
index e406cb5..0651089 100644 (file)
@@ -25,7 +25,7 @@
 #define TVM_CODEGEN_CODEGEN_VHLS_H_
 
 #include <tvm/codegen.h>
-#include <tvm/packed_func_ext.h>
+#include <tvm/ir.h>
 #include <string>
 #include "codegen_c.h"
 
index 62a7550..b49b395 100644 (file)
  * specific language governing permissions and limitations
  * under the License.
  */
-
-#include "registry.h"
 #include <tvm/runtime/registry.h>
-#include <tvm/packed_func_ext.h>
-
+#include "registry.h"
 
 namespace tvm {
 namespace datatype {
 
+using runtime::TVMArgs;
+using runtime::TVMRetValue;
+
 TVM_REGISTER_GLOBAL("_datatype_register")
 .set_body([](TVMArgs args, TVMRetValue* ret) {
   datatype::Registry::Global()->Register(args[0], static_cast<uint8_t>(args[1].operator int()));
index 0609989..699abd8 100644 (file)
@@ -21,6 +21,7 @@
  * \file intrin_rule_default.cc
  * \brief Default intrinsic rules.
  */
+#include <tvm/expr_operator.h>
 #include "intrin_rule.h"
 
 namespace tvm {
index 56ba225..b6332f1 100644 (file)
@@ -27,9 +27,6 @@
 #include <tvm/ir.h>
 #include <tvm/expr.h>
 #include <tvm/runtime/registry.h>
-#include <tvm/packed_func_ext.h>
-
-#include <tvm/runtime/registry.h>
 #include <string>
 
 namespace tvm {
index 1f839f3..d81c33b 100644 (file)
@@ -27,7 +27,6 @@
 
 #include <tvm/ir.h>
 #include <tvm/runtime/registry.h>
-#include <tvm/packed_func_ext.h>
 
 #include <tvm/codegen.h>
 #include <string>
index fcd8a1a..68475f0 100644 (file)
@@ -25,8 +25,6 @@
 #include <tvm/ir.h>
 #include <tvm/expr.h>
 #include <tvm/runtime/registry.h>
-#include <tvm/packed_func_ext.h>
-
 #include <sstream>
 
 namespace tvm {
index 41035af..477bcb4 100644 (file)
@@ -25,7 +25,6 @@
 #include <tvm/ir.h>
 #include <tvm/expr.h>
 #include <tvm/runtime/registry.h>
-#include <tvm/packed_func_ext.h>
 
 #include <sstream>
 
index d96883e..aa69ebf 100644 (file)
@@ -21,7 +21,6 @@
  * \file intrin_rule_spirv.cc
  */
 #include <tvm/runtime/registry.h>
-#include <tvm/packed_func_ext.h>
 #include <tvm/ir.h>
 #include <GLSL.std.450.h>
 
index 8253007..c12f66f 100644 (file)
@@ -21,7 +21,6 @@
  * \file codegen_stackvm.cc
  */
 #include <tvm/runtime/registry.h>
-#include <tvm/packed_func_ext.h>
 #include <limits>
 #include <utility>
 #include "codegen_stackvm.h"
index 6872568..378e8a8 100644 (file)
@@ -31,6 +31,7 @@
 #define TVM_IR_ATTR_FUNCTOR_H_
 
 #include <tvm/node/functor.h>
+#include <tvm/ir.h>
 #include <utility>
 
 namespace tvm {
index d9063fb..54f5ee2 100644 (file)
@@ -22,8 +22,6 @@
  */
 #include <tvm/ir/attrs.h>
 #include <tvm/runtime/registry.h>
-#include <tvm/packed_func_ext.h>
-
 #include "attr_functor.h"
 
 namespace tvm {
index 99db14e..62faf50 100644 (file)
 
 #include <tvm/ir/module.h>
 #include <tvm/ir/error.h>
-// NOTE on dependencies on relay AsText.
-// We calls into relay's printing module for better rendering.
-// These dependency does not happen at the interface-level.
-// And is only used to enhance developer experiences when relay
-// functions are presented.
+// NOTE: reverse dependency on relay.
+// These dependencies do not happen at the interface-level,
+// and are only used in minimum cases where they are clearly marked.
+//
+// Rationale: use relay's printer for astext.
 #include <tvm/relay/expr.h>
 
 #include <string>
index 0cf91c2..b173f4f 100644 (file)
  */
 #include <tvm/runtime/registry.h>
 #include <tvm/ir/expr.h>
+// NOTE: reverse dependency on top/tir.
+// These dependencies do not happen at the interface-level,
+// and are only used in minimum cases where they are clearly marked.
+//
+// Rationale: convert from IterVar and top::Tensor
+#include <tvm/top/tensor.h>
+#include <tvm/expr.h>
 
 namespace tvm {
 
+PrimExpr PrimExpr::FromObject_(ObjectPtr<Object> ptr) {
+  using runtime::ObjectTypeChecker;
+  if (ptr->IsInstance<IterVarNode>()) {
+    return IterVar(ptr)->var;
+  }
+  if (ptr->IsInstance<top::TensorNode>()) {
+    return top::Tensor(ptr)();
+  }
+  CHECK(ObjectTypeChecker<PrimExpr>::Check(ptr.get()))
+      << "Expect type " << ObjectTypeChecker<PrimExpr>::TypeName()
+      << " but get " << ptr->GetTypeKey();
+  return PrimExpr(ptr);
+}
+
 IntImm::IntImm(DataType dtype, int64_t value) {
   CHECK(dtype.is_scalar())
       << "ValueError: IntImm can only take scalar.";
index 09abac7..01a8baa 100644 (file)
  */
 #include <tvm/runtime/registry.h>
 #include <tvm/ir/module.h>
-// NOTE on dependencies on relay analysis.
-// We calls into relay's analysis module to verify correctness
-// when a relay function is presented.
-// These dependency does not happen at the interface-level.
-// And is only used to enhance developer experiences when relay
-// functions are presented.
+// NOTE: reverse dependency on relay.
+// These dependencies do not happen at the interface-level,
+// and are only used in minimum cases where they are clearly marked.
+//
+// Rationale: We calls into relay's analysis module to verify correctness.
 #include <tvm/relay/analysis.h>
 #include <tvm/relay/transform.h>
 
index 1be4e32..2519321 100644 (file)
@@ -22,7 +22,6 @@
  */
 #include <tvm/ir/span.h>
 #include <tvm/runtime/registry.h>
-#include <tvm/packed_func_ext.h>
 
 namespace tvm {
 
index 4ba1607..9e250db 100644 (file)
@@ -23,8 +23,6 @@
  */
 #include <tvm/ir/type.h>
 #include <tvm/runtime/registry.h>
-#include <tvm/packed_func_ext.h>
-
 namespace tvm {
 
 PrimType::PrimType(runtime::DataType dtype) {
index 06d665e..361525c 100644 (file)
@@ -24,8 +24,6 @@
 #include <tvm/ir/type.h>
 #include <tvm/ir/type_relation.h>
 #include <tvm/runtime/registry.h>
-#include <tvm/packed_func_ext.h>
-
 namespace tvm {
 
 TypeCall::TypeCall(Type func, tvm::Array<Type> args) {
index 78d743a..302abea 100644 (file)
@@ -24,7 +24,6 @@
 #include <tvm/ir_functor_ext.h>
 #include <tvm/arith/analyzer.h>
 #include <tvm/runtime/registry.h>
-#include <tvm/packed_func_ext.h>
 
 #include <unordered_map>
 #include <unordered_set>
index af83f47..29bb5b4 100644 (file)
@@ -23,7 +23,6 @@
  */
 #include <tvm/arith/pattern.h>
 #include <tvm/ir.h>
-#include <tvm/packed_func_ext.h>
 #include <tvm/ir_functor_ext.h>
 #include <tvm/ir_pass.h>
 #include "../arith/pattern_match.h"
index 2c99674..7292df6 100644 (file)
@@ -20,7 +20,6 @@
  * \file ir_functor.cc
  */
 #include <tvm/ir_functor_ext.h>
-#include <tvm/packed_func_ext.h>
 
 namespace tvm {
 namespace ir {
index 98eaf8c..b494328 100644 (file)
@@ -23,7 +23,6 @@
 
 #include <tvm/ir_functor_ext.h>
 #include <tvm/ir_pass.h>
-#include <tvm/packed_func_ext.h>
 #include "../codegen/datatype/registry.h"
 
 namespace tvm {
index 7172a4a..4e1ea8d 100644 (file)
@@ -24,7 +24,6 @@
 #include <tvm/ir.h>
 #include <tvm/ir_pass.h>
 #include <tvm/runtime/registry.h>
-#include <tvm/packed_func_ext.h>
 
 #include <tvm/expr_operator.h>
 #include <unordered_set>
index 24f3e19..f9c183e 100644 (file)
@@ -25,7 +25,6 @@
  */
 
 #include <tvm/runtime/registry.h>
-#include <tvm/packed_func_ext.h>
 
 #include <tvm/ir.h>
 #include <tvm/ir_functor_ext.h>
index 9336782..d32a6e2 100644 (file)
@@ -24,7 +24,6 @@
 #include "compile_engine.h"
 
 #include <tvm/top/schedule.h>
-#include <tvm/packed_func_ext.h>
 #include <tvm/top/operation.h>
 #include <tvm/runtime/registry.h>
 #include <tvm/relay/attrs/device_copy.h>
index de03d97..fd41655 100644 (file)
@@ -22,6 +22,7 @@
  * \brief Memory index assignment pass for executing
  *   the program in the graph runtime.
  */
+#include <tvm/expr_operator.h>
 #include <tvm/relay/expr.h>
 #include <tvm/relay/expr_functor.h>
 #include <tvm/relay/analysis.h>
index ff9dbba..7fdfdbb 100644 (file)
@@ -21,7 +21,6 @@
  * \file src/tvm/relay/interpreter.cc
  * \brief An interpreter for the Relay IR.
  */
-#include <tvm/packed_func_ext.h>
 #include <tvm/runtime/device_api.h>
 #include <tvm/relay/expr_functor.h>
 #include <tvm/relay/pattern_functor.h>
index e2d225a..dd42eb3 100644 (file)
@@ -25,7 +25,7 @@
 #define TVM_RELAY_BACKEND_PARAM_DICT_H_
 
 #include <tvm/node/node.h>
-#include <tvm/packed_func_ext.h>
+#include <tvm/ir.h>
 #include <tvm/runtime/ndarray.h>
 #include <tvm/runtime/packed_func.h>
 
index 82b3513..85b17b5 100644 (file)
@@ -24,7 +24,6 @@
 
 #include <tvm/ir/type.h>
 #include <tvm/runtime/registry.h>
-#include <tvm/packed_func_ext.h>
 #include <tvm/relay/base.h>
 
 namespace tvm {
index 5680a78..a3b4668 100644 (file)
@@ -22,6 +22,7 @@
  * \brief The type system AST nodes of Relay.
  */
 #include <tvm/relay/type.h>
+#include <tvm/expr_operator.h>
 
 namespace tvm {
 namespace relay {
index 6a1b34d..d837e99 100644 (file)
@@ -21,6 +21,7 @@
  * \file multibox_op.cc
  * \brief Multibox related operators
  */
+#include <tvm/expr_operator.h>
 #include <tvm/relay/op.h>
 #include <tvm/relay/attrs/vision.h>
 
index 30a9a5c..be5ac51 100644 (file)
@@ -21,6 +21,7 @@
  * \file type_solver.cc
  * \brief Type solver implementations.
  */
+#include <tvm/expr_operator.h>
 #include <string>
 #include <memory>
 #include <tuple>
index 2316bed..6c99ae1 100644 (file)
@@ -26,6 +26,7 @@
 #define TVM_RELAY_QNN_UTIL_H_
 
 #include <tvm/expr.h>
+#include <tvm/expr_operator.h>
 #include <tvm/relay/expr.h>
 #include <tvm/relay/qnn/attrs.h>
 #include <limits>
index 413bb42..e7f6b33 100644 (file)
@@ -25,7 +25,6 @@
 #include <tvm/ir_functor_ext.h>
 #include <tvm/ir_pass.h>
 #include <tvm/runtime/registry.h>
-#include <tvm/packed_func_ext.h>
 
 #include "op_util.h"
 #include "compute_op.h"
index 730d204..b87576e 100644 (file)
@@ -21,7 +21,6 @@
 #include <gtest/gtest.h>
 #include <tvm/ir/attrs.h>
 #include <tvm/expr_operator.h>
-#include <tvm/packed_func_ext.h>
 #include <tvm/ir.h>
 
 namespace tvm {
index 30834c5..31d82f0 100644 (file)
@@ -22,7 +22,6 @@
 #include <topi/cuda/injective.h>
 #include <tvm/top/operation.h>
 #include <tvm/runtime/registry.h>
-#include <tvm/packed_func_ext.h>
 #include <tvm/build_module.h>
 
 #include <string>
index d5d8aae..5988b2a 100644 (file)
@@ -19,7 +19,7 @@
 
 #include <dmlc/logging.h>
 #include <gtest/gtest.h>
-#include <tvm/packed_func_ext.h>
+#include <tvm/expr_operator.h>
 #include <tvm/runtime/container.h>
 #include <new>
 #include <unordered_map>
index f4f9601..550c93e 100644 (file)
@@ -21,8 +21,6 @@
 #include <gtest/gtest.h>
 #include <tvm/runtime/packed_func.h>
 #include <tvm/runtime/registry.h>
-#include <tvm/packed_func_ext.h>
-#include <tvm/runtime/registry.h>
 #include <tvm/ir.h>
 
 TEST(PackedFunc, Basic) {
index 462d0fe..bf0e338 100644 (file)
@@ -28,8 +28,6 @@
 #include <tvm/runtime/packed_func.h>
 #include <tvm/runtime/module.h>
 #include <tvm/runtime/registry.h>
-#include <tvm/packed_func_ext.h>
-
 
 TVM_REGISTER_GLOBAL("test.sch")
 .set_body([](tvm::TVMArgs args, tvm::TVMRetValue *rv) {
index 4c383b5..4171f9d 100644 (file)
@@ -20,7 +20,6 @@
 #include <gtest/gtest.h>
 #include <topi/generic/injective.h>
 #include <tvm/build_module.h>
-#include <tvm/packed_func_ext.h>
 #include <tvm/relay/expr.h>
 #include <tvm/ir/module.h>
 #include <tvm/relay/analysis.h>
index 73a6245..bde3245 100644 (file)
@@ -33,7 +33,6 @@
 #include <topi/generic/injective.h>
 #include <tvm/build_module.h>
 #include <tvm/top/operation.h>
-#include <tvm/packed_func_ext.h>
 #include <tvm/relay/analysis.h>
 #include <tvm/relay/expr.h>
 #include <tvm/relay/transform.h>
index 8197e89..7ae4d88 100644 (file)
@@ -26,7 +26,7 @@
 #include <tvm/runtime/packed_func.h>
 #include <tvm/runtime/module.h>
 #include <tvm/runtime/registry.h>
-#include <tvm/packed_func_ext.h>
+#include <tvm/ir/expr.h>
 #include <tvm/build_module.h>
 
 #include <topi/broadcast.h>