[REFACTOR] Polish runtime (#4729)
authorTianqi Chen <tqchen@users.noreply.github.com>
Fri, 17 Jan 2020 04:18:57 +0000 (20:18 -0800)
committerGitHub <noreply@github.com>
Fri, 17 Jan 2020 04:18:57 +0000 (20:18 -0800)
- Remove operator bool from base object ref macro
  - Raitionale: operator bool can be dangerous for sub-classes
    that also overloads other operators(e.g. ==).
  - If bool is still needed, use explicit operator bool.
- Use absolute include when necessary
- Move type related util to data_type
- Isolate stackvm code from compiler

27 files changed:
include/tvm/ir.h
include/tvm/ir/expr.h
include/tvm/runtime/c_backend_api.h
include/tvm/runtime/data_type.h
include/tvm/runtime/device_api.h
include/tvm/runtime/memory.h
include/tvm/runtime/object.h
include/tvm/runtime/serializer.h
include/tvm/runtime/util.h [deleted file]
src/codegen/stackvm/codegen_stackvm.cc
src/pass/hoist_if_then_else.cc
src/relay/op/type_relations.cc
src/relay/pass/partial_eval.cc
src/runtime/contrib/cblas/cblas.cc
src/runtime/contrib/cblas/gemm_common.h
src/runtime/contrib/cublas/cublas.cc
src/runtime/contrib/cudnn/conv_forward.cc
src/runtime/contrib/miopen/conv_forward.cc
src/runtime/contrib/mps/mps_utils.h
src/runtime/contrib/nnpack/convolution.cc
src/runtime/contrib/nnpack/fully_connected.cc
src/runtime/contrib/nnpack/nnpack_utils.h
src/runtime/contrib/random/random.cc
src/runtime/contrib/rocblas/rocblas.cc
src/runtime/contrib/sort/sort.cc
src/runtime/stackvm/stackvm.cc
src/runtime/stackvm/stackvm.h

index 4e36332..ff4b47f 100644 (file)
@@ -30,7 +30,6 @@
 #include <vector>
 #include <utility>
 #include "expr.h"
-#include "runtime/util.h"
 
 namespace tvm {
 namespace ir {
@@ -1677,6 +1676,25 @@ constexpr const char* tvm_fill_fragment = "tvm_fill_fragment";
  */
 constexpr const char* tvm_store_matrix_sync = "tvm_store_matrix_sync";
 
+/*! \brief The kind of structure field info used in intrinsic */
+enum TVMStructFieldKind : int {
+  // array head address
+  kArrAddr,
+  kArrData,
+  kArrShape,
+  kArrStrides,
+  kArrNDim,
+  kArrTypeCode,
+  kArrTypeBits,
+  kArrTypeLanes,
+  kArrByteOffset,
+  kArrDeviceId,
+  kArrDeviceType,
+  kArrKindBound_,
+  // TVMValue field
+  kTVMValueContent,
+  kTVMValueKindBound_
+};
 }   // namespace intrinsic
 
 /*!
index ddb5f80..87122e8 100644 (file)
@@ -49,15 +49,7 @@ class BaseExprNode : public Object {
  */
 class BaseExpr : public ObjectRef {
  public:
-  /*! \brief Cosntructor */
-  BaseExpr() {}
-  /*!
-   * \brief Cosntructor from object ptr.
-   * \param ptr The object pointer.
-   */
-  explicit BaseExpr(ObjectPtr<Object> ptr) : ObjectRef(ptr) {}
-  /*! \brief The container type. */
-  using ContainerType = BaseExprNode;
+  TVM_DEFINE_OBJECT_REF_METHODS(BaseExpr, ObjectRef, BaseExprNode);
 };
 
 /*!
@@ -100,13 +92,6 @@ class PrimExprNode : public BaseExprNode {
  */
 class PrimExpr : public BaseExpr {
  public:
-    /*! \brief Cosntructor */
-  PrimExpr() {}
-  /*!
-   * \brief Cosntructor from object ptr.
-   * \param ptr The object pointer.
-   */
-  explicit PrimExpr(ObjectPtr<Object> ptr) : BaseExpr(ptr) {}
   /*!
    * \brief construct from integer.
    * \param value The value to be constructed.
@@ -127,8 +112,8 @@ class PrimExpr : public BaseExpr {
   DataType dtype() const {
     return static_cast<const PrimExprNode*>(get())->dtype;
   }
-  /*! \brief The container type. */
-  using ContainerType = PrimExprNode;
+
+  TVM_DEFINE_OBJECT_REF_METHODS(PrimExpr, BaseExpr, PrimExprNode);
 };
 
 /*!
@@ -157,28 +142,13 @@ class IntImmNode : public PrimExprNode {
 class IntImm : public PrimExpr {
  public:
   /*!
-   * \brief Constructor
-   */
-  IntImm() {}
-  /*!
-   * \brief constructor from node.
-   */
-  explicit IntImm(ObjectPtr<Object> node) : PrimExpr(node) {}
-  /*!
    * \brief Constructor.
    * \param dtype The data type of the value.
    * \param value The internal value.
    */
   TVM_DLL IntImm(DataType dtype, int64_t value);
-  /*!
-   * \brief Get pointer to the internal value.
-   * \return the content of the integer.
-   */
-  const IntImmNode* operator->() const {
-    return static_cast<const IntImmNode*>(get());
-  }
-  /*! \brief type indicate the container type */
-  using ContainerType = IntImmNode;
+
+  TVM_DEFINE_OBJECT_REF_METHODS(IntImm, PrimExpr, IntImmNode);
 };
 
 /*!
@@ -207,28 +177,13 @@ class FloatImmNode : public PrimExprNode {
 class FloatImm : public PrimExpr {
  public:
   /*!
-   * \brief Constructor
-   */
-  FloatImm() {}
-  /*!
-   * \brief constructor from node.
-   */
-  explicit FloatImm(ObjectPtr<Object> node) : PrimExpr(node) {}
-  /*!
    * \brief Constructor.
    * \param dtype The data type of the value.
    * \param value The internal value.
    */
   TVM_DLL FloatImm(DataType dtype, double value);
-  /*!
-   * \brief Get pointer to the container.
-   * \return The pointer.
-   */
-  const FloatImmNode* operator->() const {
-    return static_cast<const FloatImmNode*>(get());
-  }
-  /*! \brief type indicate the container type */
-  using ContainerType = FloatImmNode;
+
+  TVM_DEFINE_OBJECT_REF_METHODS(FloatImm, PrimExpr, FloatImmNode);
 };
 
 /*!
index ffd13ec..abfc792 100644 (file)
@@ -28,7 +28,7 @@
 #ifndef TVM_RUNTIME_C_BACKEND_API_H_
 #define TVM_RUNTIME_C_BACKEND_API_H_
 
-#include "c_runtime_api.h"
+#include <tvm/runtime/c_runtime_api.h>
 
 #ifdef __cplusplus
 extern "C" {
index c91c2cf..cb58e97 100644 (file)
@@ -28,7 +28,6 @@
 #include <dmlc/logging.h>
 #include <type_traits>
 
-
 namespace tvm {
 namespace runtime {
 /*!
@@ -233,6 +232,24 @@ inline int GetVectorBytes(DataType dtype) {
   return data_bits / 8;
 }
 
+/*!
+ * \brief Check whether type matches the given spec.
+ * \param t The type
+ * \param code The type code.
+ * \param bits The number of bits to be matched.
+ * \param lanes The number of lanes in the type.
+ */
+inline bool TypeMatch(DLDataType t, int code, int bits, int lanes = 1) {
+  return t.code == code && t.bits == bits && t.lanes == lanes;
+}
+/*!
+ * \brief Check whether two types are equal .
+ * \param lhs The left operand.
+ * \param rhs The right operand.
+ */
+inline bool TypeEqual(DLDataType lhs, DLDataType rhs) {
+  return lhs.code == rhs.code && lhs.bits == rhs.bits && lhs.lanes == rhs.lanes;
+}
 }  // namespace runtime
 
 using DataType = runtime::DataType;
index 7212aad..00508a1 100644 (file)
@@ -24,9 +24,9 @@
 #ifndef TVM_RUNTIME_DEVICE_API_H_
 #define TVM_RUNTIME_DEVICE_API_H_
 
+#include <tvm/runtime/c_runtime_api.h>
+#include <tvm/runtime/packed_func.h>
 #include <string>
-#include "packed_func.h"
-#include "c_runtime_api.h"
 
 namespace tvm {
 namespace runtime {
index 10e8be3..121dbdd 100644 (file)
 #ifndef TVM_RUNTIME_MEMORY_H_
 #define TVM_RUNTIME_MEMORY_H_
 
+#include <tvm/runtime/object.h>
 #include <cstdlib>
 #include <utility>
 #include <type_traits>
-#include "object.h"
 
 namespace tvm {
 namespace runtime {
index a2e9188..8ef9cb4 100644 (file)
@@ -29,7 +29,6 @@
 #include <string>
 #include <utility>
 
-
 /*!
  * \brief Whether or not use atomic reference counter.
  *  If the reference counter is not atomic,
@@ -715,7 +714,6 @@ struct ObjectEqual {
   const ObjectName* operator->() const {                                \
     return static_cast<const ObjectName*>(data_.get());                 \
   }                                                                     \
-  operator bool() const { return data_ != nullptr; }                    \
   using ContainerType = ObjectName;
 
 /*
@@ -734,7 +732,6 @@ struct ObjectEqual {
   ObjectName* operator->() const {                                      \
     return static_cast<ObjectName*>(data_.get());                       \
   }                                                                     \
-  operator bool() const { return data_ != nullptr; }                    \
   using ContainerType = ObjectName;
 
 /*!
index ca968c4..37bb95f 100644 (file)
@@ -27,8 +27,8 @@
 
 #include <dmlc/io.h>
 #include <dmlc/serializer.h>
-#include "c_runtime_api.h"
-#include "ndarray.h"
+#include <tvm/runtime/c_runtime_api.h>
+#include <tvm/runtime/ndarray.h>
 
 namespace dmlc {
 namespace serializer {
diff --git a/include/tvm/runtime/util.h b/include/tvm/runtime/util.h
deleted file mode 100644 (file)
index 8e213dd..0000000
+++ /dev/null
@@ -1,79 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file tvm/runtime/util.h
- * \brief Useful runtime util.
- */
-#ifndef TVM_RUNTIME_UTIL_H_
-#define TVM_RUNTIME_UTIL_H_
-
-#include "c_runtime_api.h"
-
-namespace tvm {
-namespace runtime {
-
-/*!
- * \brief Check whether type matches the given spec.
- * \param t The type
- * \param code The type code.
- * \param bits The number of bits to be matched.
- * \param lanes The number of lanes in the type.
- */
-inline bool TypeMatch(DLDataType t, int code, int bits, int lanes = 1) {
-  return t.code == code && t.bits == bits && t.lanes == lanes;
-}
-/*!
- * \brief Check whether two types are equal .
- * \param lhs The left operand.
- * \param rhs The right operand.
- */
-inline bool TypeEqual(DLDataType lhs, DLDataType rhs) {
-  return lhs.code == rhs.code && lhs.bits == rhs.bits && lhs.lanes == rhs.lanes;
-}
-}  // namespace runtime
-}  // namespace tvm
-// Forward declare the intrinsic id we need
-// in structure fetch to enable stackvm in runtime
-namespace tvm {
-namespace ir {
-namespace intrinsic {
-/*! \brief The kind of structure field info used in intrinsic */
-enum TVMStructFieldKind : int {
-  // array head address
-  kArrAddr,
-  kArrData,
-  kArrShape,
-  kArrStrides,
-  kArrNDim,
-  kArrTypeCode,
-  kArrTypeBits,
-  kArrTypeLanes,
-  kArrByteOffset,
-  kArrDeviceId,
-  kArrDeviceType,
-  kArrKindBound_,
-  // TVMValue field
-  kTVMValueContent,
-  kTVMValueKindBound_
-};
-}  // namespace intrinsic
-}  // namespace ir
-}  // namespace tvm
-#endif  // TVM_RUNTIME_UTIL_H_
index f4b8fbe..8253007 100644 (file)
@@ -32,6 +32,28 @@ namespace codegen {
 
 using namespace ir;
 
+// map struct field kind to runtime variants
+// We keep two separate enums to ensure runtime/compiler isolation.
+StackVM::StructFieldKind MapFieldKind(int64_t kind) {
+  auto val = static_cast<intrinsic::TVMStructFieldKind>(kind);
+  switch (val) {
+    case intrinsic::kArrData: return StackVM::kArrData;
+    case intrinsic::kArrShape: return StackVM::kArrShape;
+    case intrinsic::kArrAddr: return StackVM::kArrAddr;
+    case intrinsic::kArrStrides: return StackVM::kArrStrides;
+    case intrinsic::kArrNDim: return StackVM::kArrNDim;
+    case intrinsic::kArrTypeCode: return StackVM::kArrTypeCode;
+    case intrinsic::kArrTypeBits: return StackVM::kArrTypeBits;
+    case intrinsic::kArrTypeLanes: return StackVM::kArrTypeLanes;
+    case intrinsic::kArrByteOffset: return StackVM::kArrByteOffset;
+    case intrinsic::kArrDeviceId: return StackVM::kArrDeviceId;
+    case intrinsic::kArrDeviceType: return StackVM::kArrDeviceType;
+    case intrinsic::kTVMValueContent: return StackVM::kTVMValueContent;
+    default: LOG(FATAL) << "Do not know how to map field " << kind;
+  }
+  return StackVM::kArrData;
+}
+
 StackVM CodeGenStackVM::Compile(LoweredFunc f) {
   for (size_t i = 0; i < f->args.size(); ++i) {
     Var v = f->args[i];
@@ -163,7 +185,7 @@ void CodeGenStackVM::VisitExpr_(const CallNode* op) {
     vm_.code.push_back(code);
     code.v_int = index->value;
     vm_.code.push_back(code);
-    code.v_int = kind;
+    code.v_int = MapFieldKind(kind);
     vm_.code.push_back(code);
   } else if (op->is_intrinsic(intrinsic::tvm_call_packed_lowered)) {
     CHECK_GE(op->args.size(), 5U);
@@ -431,7 +453,7 @@ void CodeGenStackVM::VisitStmt_(const EvaluateNode *ev) {
     vm_.code.push_back(code);
     code.v_int = index->value;
     vm_.code.push_back(code);
-    code.v_int = op->args[2].as<IntImmNode>()->value;
+    code.v_int = MapFieldKind(op->args[2].as<IntImmNode>()->value);
     vm_.code.push_back(code);
   } else {
     this->Push(ev->value);
index 6d4df47..78d743a 100644 (file)
@@ -189,7 +189,7 @@ std::pair<Stmt, Stmt> RemoveIf(const Stmt& for_stmt, const Stmt& if_stmt) {
 
   then_for = IRTransform(for_stmt, nullptr, replace_then_case,
                          {PrimExpr("IfThenElse")});
-  if (if_stmt.as<IfThenElseNode>()->else_case) {
+  if (if_stmt.as<IfThenElseNode>()->else_case.defined()) {
     else_for = IRTransform(for_stmt, nullptr, replace_else_case,
                            {PrimExpr("IfThenElse")});
   }
@@ -221,7 +221,7 @@ void IfThenElseHoist::SelectCandidates(const Stmt& stmt) {
         for2if_map_[for_stmt.get()].push_back(head);
         const IfThenElseNode* if_node = head.as<IfThenElseNode>();
         tracker.push(if_node->then_case);
-        if (if_node->else_case) {
+        if (if_node->else_case.defined()) {
           tracker.push(if_node->else_case);
         }
 
index fbaf665..a1e2a66 100644 (file)
 namespace tvm {
 namespace relay {
 
-TensorType ToTensorType(const Type& t) {
-  if (const auto* tt_node = t.as<TensorTypeNode>()) {
-    return GetRef<TensorType>(tt_node);
-  } else {
-    return TensorType(nullptr);
-  }
-}
-
 bool IdentityRel(const Array<Type>& types,
                  int num_inputs,
                  const Attrs& attrs,
@@ -115,11 +107,11 @@ bool BroadcastRel(const Array<Type>& types,
   CHECK_EQ(types.size(), 3);
   // DLOG(INFO) << "In1:" << types[0] << ",In2:" << types[1]
   //                 << ",Out:" << types[2] << std::endl;
-  if (auto t0 = ToTensorType(types[0])) {
-    if (auto t1 = ToTensorType(types[1])) {
+  if (auto* t0 = types[0].as<TensorTypeNode>()) {
+    if (auto* t1 = types[1].as<TensorTypeNode>()) {
       CHECK_EQ(t0->dtype, t1->dtype);
       reporter->Assign(types[2],
-        ConcreteBroadcast(t0, t1, t0->dtype));
+        ConcreteBroadcast(GetRef<TensorType>(t0), GetRef<TensorType>(t1), t0->dtype));
       return true;
     }
   }
@@ -133,10 +125,11 @@ bool BroadcastCompRel(const Array<Type>& types,
   CHECK_EQ(types.size(), 3);
   // DLOG(INFO) << "In1:" << types[0] << ",In2:" << types[1]
   //                 << ",Out:" << types[2] << std::endl;
-  if (auto t0 = ToTensorType(types[0])) {
-    if (auto t1 = ToTensorType(types[1])) {
+  if (auto* t0 = types[0].as<TensorTypeNode>()) {
+    if (auto* t1 = types[1].as<TensorTypeNode>()) {
       CHECK_EQ(t0->dtype, t1->dtype);
-      reporter->Assign(types[2], ConcreteBroadcast(t0, t1, ::tvm::DataType::Bool()));
+      reporter->Assign(types[2],
+        ConcreteBroadcast(GetRef<TensorType>(t0), GetRef<TensorType>(t1), DataType::Bool()));
       return true;
     }
   }
index 4c343bd..c7935c4 100644 (file)
@@ -749,7 +749,7 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
     PStatic r = VisitExpr(op->ref, ll);
     if (r->pstatic.defined()) {
       PStatic ret = store_.Lookup(r->pstatic.as<SRefNode>());
-      if (ret) {
+      if (ret.defined()) {
         return ret;
       }
     }
index ef9f5d6..d4959be 100644 (file)
@@ -22,7 +22,7 @@
  */
 #include <dmlc/logging.h>
 #include <tvm/runtime/registry.h>
-#include <tvm/runtime/util.h>
+#include <tvm/runtime/data_type.h>
 #include "gemm_common.h"
 
 extern "C" {
index e35e431..b73abab 100644 (file)
@@ -24,7 +24,7 @@
 #pragma once
 
 #include <tvm/runtime/registry.h>
-#include <tvm/runtime/util.h>
+#include <tvm/runtime/data_type.h>
 #include <algorithm>
 
 namespace tvm {
index 2cb6777..5424f4c 100644 (file)
@@ -21,7 +21,7 @@
  * \file Use external cblas library call.
  */
 #include <tvm/runtime/registry.h>
-#include <tvm/runtime/util.h>
+#include <tvm/runtime/data_type.h>
 #include <dmlc/logging.h>
 #include "../cblas/gemm_common.h"
 #include "cublas_utils.h"
index b9609b9..9581133 100644 (file)
@@ -21,7 +21,7 @@
  * \file Use external cudnn utils function
  */
 #include <tvm/runtime/registry.h>
-#include <tvm/runtime/util.h>
+#include <tvm/runtime/data_type.h>
 #include <tvm/runtime/device_api.h>
 #include "cudnn_utils.h"
 
index 5094cef..d457548 100644 (file)
@@ -21,7 +21,7 @@
  * \file Use external miopen utils function
  */
 #include <tvm/runtime/registry.h>
-#include <tvm/runtime/util.h>
+#include <tvm/runtime/data_type.h>
 #include <tvm/runtime/device_api.h>
 #include "miopen_utils.h"
 
index 728646c..f1fff95 100644 (file)
@@ -29,7 +29,7 @@
 #include <dmlc/thread_local.h>
 #include <tvm/runtime/device_api.h>
 #include <tvm/runtime/registry.h>
-#include <tvm/runtime/util.h>
+#include <tvm/runtime/data_type.h>
 #include <vector>
 #include "../../metal/metal_common.h"
 
index 8934693..79ea191 100644 (file)
@@ -22,7 +22,7 @@
  */
 #include <tvm/runtime/device_api.h>
 #include <tvm/runtime/registry.h>
-#include <tvm/runtime/util.h>
+#include <tvm/runtime/data_type.h>
 #include <dmlc/logging.h>
 #include <nnpack.h>
 #include "nnpack_utils.h"
index b0d72fe..5f111ef 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -21,7 +21,7 @@
  * \file Use external nnpack library call.
  */
 #include <tvm/runtime/registry.h>
-#include <tvm/runtime/util.h>
+#include <tvm/runtime/data_type.h>
 #include <dmlc/logging.h>
 #include <nnpack.h>
 #include "nnpack_utils.h"
index 551cff2..4ba586f 100644 (file)
@@ -23,7 +23,7 @@
 #ifndef TVM_RUNTIME_CONTRIB_NNPACK_NNPACK_UTILS_H_
 #define TVM_RUNTIME_CONTRIB_NNPACK_NNPACK_UTILS_H_
 #include <tvm/runtime/registry.h>
-#include <tvm/runtime/util.h>
+#include <tvm/runtime/data_type.h>
 #include <dmlc/thread_local.h>
 #include <dmlc/logging.h>
 #include <nnpack.h>
index 3da2e16..46a14e6 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -21,7 +21,7 @@
  * \file External random functions for tensor.
  */
 #include <tvm/runtime/registry.h>
-#include <tvm/runtime/util.h>
+#include <tvm/runtime/data_type.h>
 #include <dmlc/logging.h>
 #include <dmlc/thread_local.h>
 #include <algorithm>
index 813f4c6..dda4ee3 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -21,7 +21,7 @@
  * \file Use external rocblas library call.
  */
 #include <tvm/runtime/registry.h>
-#include <tvm/runtime/util.h>
+#include <tvm/runtime/data_type.h>
 #include <dmlc/logging.h>
 #include "rocblas.h"
 
index 68f70c1..0c9c575 100644 (file)
@@ -22,7 +22,6 @@
  */
 
 #include <tvm/runtime/registry.h>
-#include <tvm/runtime/util.h>
 #include <dlpack/dlpack.h>
 #include <algorithm>
 #include <vector>
index 06b154e..0f17f9e 100644 (file)
@@ -22,7 +22,6 @@
  * \file stackvm.cc
  */
 #include <dmlc/thread_local.h>
-#include <tvm/runtime/util.h>
 #include <tvm/runtime/c_backend_api.h>
 #include <algorithm>
 #include "stackvm.h"
@@ -392,50 +391,49 @@ void StackVM::Run(State* s) const {
       }
       // intrinsics
       case TVM_STRUCT_GET: {
-        using namespace ir;
         int index = code[pc + 1].v_int;
         int kind = code[pc + 2].v_int;
         DLTensor* arr = static_cast<DLTensor*>(stack[sp].v_handle);
         switch (kind) {
-          case intrinsic::kArrData: {
+          case StackVM::kArrData: {
             stack[sp].v_handle = arr[index].data; break;
           }
-          case intrinsic::kArrShape: {
+          case StackVM::kArrShape: {
             stack[sp].v_handle = arr[index].shape; break;
           }
-          case intrinsic::kArrStrides: {
+          case StackVM::kArrStrides: {
             stack[sp].v_handle = arr[index].strides; break;
           }
-          case intrinsic::kArrNDim: {
+          case StackVM::kArrNDim: {
             stack[sp].v_int64 = arr[index].ndim; break;
           }
-          case intrinsic::kArrTypeCode: {
+          case StackVM::kArrTypeCode: {
             stack[sp].v_int64 = static_cast<int64_t>(
                 arr[index].dtype.code); break;
           }
-          case intrinsic::kArrTypeBits: {
+          case StackVM::kArrTypeBits: {
             stack[sp].v_int64 = static_cast<int64_t>(
                 arr[index].dtype.bits); break;
           }
-          case intrinsic::kArrTypeLanes: {
+          case StackVM::kArrTypeLanes: {
             stack[sp].v_int64 = static_cast<int64_t>(
                 arr[index].dtype.lanes); break;
           }
-          case intrinsic::kArrByteOffset: {
+          case StackVM::kArrByteOffset: {
             stack[sp].v_int64 = static_cast<int64_t>(
                 arr[index].byte_offset); break;
           }
-          case intrinsic::kArrDeviceId: {
+          case StackVM::kArrDeviceId: {
             stack[sp].v_int64 = arr[index].ctx.device_id; break;
           }
-          case intrinsic::kArrDeviceType: {
+          case StackVM::kArrDeviceType: {
             stack[sp].v_int64 = static_cast<int64_t>(
                 arr[index].ctx.device_type); break;
           }
-          case intrinsic::kArrAddr: {
+          case StackVM::kArrAddr: {
             stack[sp].v_handle = arr + index; break;
           }
-          case intrinsic::kTVMValueContent: {
+          case StackVM::kTVMValueContent: {
             stack[sp] = static_cast<TVMValue*>(stack[sp].v_handle)[index]; break;
           }
           default: LOG(FATAL) << "unhandled get " << kind;
@@ -444,51 +442,50 @@ void StackVM::Run(State* s) const {
         break;
       }
       case TVM_STRUCT_SET: {
-        using namespace ir;
         int index = code[pc + 1].v_int;
         int kind = code[pc + 2].v_int;
         DLTensor* arr = static_cast<DLTensor*>(stack[sp - 1].v_handle);
         switch (kind) {
-          case intrinsic::kArrData: {
+          case StackVM::kArrData: {
             arr[index].data = stack[sp].v_handle; break;
           }
-          case intrinsic::kArrShape: {
+          case StackVM::kArrShape: {
             arr[index].shape = static_cast<int64_t*>(stack[sp].v_handle);
             break;
           }
-          case intrinsic::kArrStrides: {
+          case StackVM::kArrStrides: {
             arr[index].strides = static_cast<int64_t*>(stack[sp].v_handle);
             break;
           }
-          case intrinsic::kArrNDim: {
+          case StackVM::kArrNDim: {
             arr[index].ndim = static_cast<int>(stack[sp].v_int64);
             break;
           }
-          case intrinsic::kArrTypeCode: {
+          case StackVM::kArrTypeCode: {
             arr[index].dtype.code = static_cast<uint8_t>(stack[sp].v_int64);
             break;
           }
-          case intrinsic::kArrTypeBits: {
+          case StackVM::kArrTypeBits: {
             arr[index].dtype.bits = static_cast<uint8_t>(stack[sp].v_int64);
             break;
           }
-          case intrinsic::kArrTypeLanes: {
+          case StackVM::kArrTypeLanes: {
             arr[index].dtype.lanes = static_cast<uint16_t>(stack[sp].v_int64);
             break;
           }
-          case intrinsic::kArrByteOffset: {
+          case StackVM::kArrByteOffset: {
             arr[index].byte_offset = static_cast<uint64_t>(stack[sp].v_int64);
             break;
           }
-          case intrinsic::kArrDeviceId: {
+          case StackVM::kArrDeviceId: {
             arr[index].ctx.device_id = static_cast<int>(stack[sp].v_int64);
             break;
           }
-          case intrinsic::kArrDeviceType: {
+          case StackVM::kArrDeviceType: {
             arr[index].ctx.device_type = static_cast<DLDeviceType>(stack[sp].v_int64);
             break;
           }
-          case intrinsic::kTVMValueContent: {
+          case StackVM::kTVMValueContent: {
             static_cast<TVMValue*>(stack[sp - 1].v_handle)[index] = stack[sp]; break;
           }
           default: LOG(FATAL) << "unhandled tvm_struct_set " << kind;
index 6ed9647..f36e171 100644 (file)
@@ -38,6 +38,7 @@ namespace tvm {
 namespace runtime {
 
 using runtime::operator<<;
+
 /*!
  * \brief A simple stack-based virtual machine program.
  */
@@ -283,6 +284,25 @@ class StackVM {
      */
     TVM_STRUCT_SET
   };
+  /*! \brief The kind of structure field info */
+  enum StructFieldKind : int {
+    // array head address
+    kArrAddr,
+    kArrData,
+    kArrShape,
+    kArrStrides,
+    kArrNDim,
+    kArrTypeCode,
+    kArrTypeBits,
+    kArrTypeLanes,
+    kArrByteOffset,
+    kArrDeviceId,
+    kArrDeviceType,
+    kArrKindBound_,
+    // TVMValue field
+    kTVMValueContent,
+    kTVMValueKindBound_
+  };
   /*! \brief The code structure */
   union Code {
     OpCode op_code;