[REFACTOR][IR] Move to runtime::String (#5276)
author     Zhi <5145158+zhiics@users.noreply.github.com>
Fri, 10 Apr 2020 14:46:23 +0000 (07:46 -0700)
committer  GitHub <noreply@github.com>
Fri, 10 Apr 2020 14:46:23 +0000 (07:46 -0700)
* Use runtime::String

* Move String into the tvm namespace

* Add a const char* constructor

* Allow implicit conversion from std::string
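
Taken together, the patch retires the old trick of storing strings in the IR as
tir::StringImm / PrimExpr and uses a first-class tvm::runtime::String instead.
A minimal sketch of the String surface this commit relies on (names taken from
the diff below; a sketch, not the full class):

    #include <tvm/runtime/container.h>

    using tvm::runtime::String;

    void Sketch() {
      String a = "main";                  // new: const char* constructor
      String b = std::string("main");     // new: implicit from std::string
      std::string s = a;                  // implicit cast back to std::string
      bool equal = (a.compare(b) == 0);   // backs runtime.CompareString
      size_t h = std::hash<String>()(a);  // backs runtime.StringHash
      (void)s; (void)equal; (void)h;
    }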

85 files changed:
include/tvm/ir/expr.h
include/tvm/ir/transform.h
include/tvm/node/container.h
include/tvm/node/node.h
include/tvm/relay/transform.h
include/tvm/runtime/container.h
include/tvm/target/target.h
include/tvm/tir/stmt_functor.h
include/tvm/tir/transform.h
python/tvm/autotvm/task/task.py
python/tvm/relay/backend/graph_runtime_codegen.py
python/tvm/runtime/container.py
python/tvm/runtime/object_generic.py
python/tvm/target/target.py
src/autotvm/touch_extractor.cc
src/ir/attrs.cc
src/ir/expr.cc
src/ir/op.cc
src/ir/transform.cc
src/node/container.cc
src/relay/backend/build_module.cc
src/relay/backend/compile_engine.cc
src/relay/backend/contrib/codegen_c/codegen_c.h
src/relay/backend/graph_runtime_codegen.cc
src/relay/backend/vm/compiler.cc
src/relay/backend/vm/inline_primitives.cc
src/relay/backend/vm/lambda_lift.cc
src/relay/backend/vm/removed_unused_funcs.cc
src/relay/ir/transform.cc
src/relay/op/tensor/transform.cc
src/relay/transforms/alter_op_layout.cc
src/relay/transforms/annotate_target.cc
src/relay/transforms/canonicalize_cast.cc
src/relay/transforms/canonicalize_ops.cc
src/relay/transforms/combine_parallel_conv2d.cc
src/relay/transforms/combine_parallel_dense.cc
src/relay/transforms/combine_parallel_op_batch.cc
src/relay/transforms/convert_layout.cc
src/relay/transforms/device_annotation.cc
src/relay/transforms/eliminate_common_subexpr.cc
src/relay/transforms/fast_math.cc
src/relay/transforms/fold_scale_axis.cc
src/relay/transforms/fuse_ops.cc
src/relay/transforms/inline.cc
src/relay/transforms/legalize.cc
src/relay/transforms/merge_composite.cc
src/relay/transforms/partition_graph.cc
src/relay/transforms/simplify_inference.cc
src/relay/transforms/to_a_normal_form.cc
src/runtime/container.cc
src/target/build_common.h
src/target/generic_func.cc
src/target/llvm/codegen_cpu.cc
src/target/llvm/codegen_llvm.cc
src/target/llvm/llvm_module.cc
src/target/source/codegen_c.cc
src/target/source/codegen_metal.cc
src/target/source/codegen_opengl.cc
src/target/source/codegen_vhls.cc
src/target/spirv/build_vulkan.cc
src/target/spirv/codegen_spirv.cc
src/target/stackvm/codegen_stackvm.cc
src/target/target.cc
src/tir/ir/expr.cc
src/tir/ir/stmt_functor.cc
src/tir/ir/transform.cc
src/tir/pass/arg_binder.cc
src/tir/pass/hoist_if_then_else.cc
src/tir/pass/tensor_core.cc
src/tir/transforms/bind_device_type.cc
src/tir/transforms/make_packed_api.cc
src/tir/transforms/remap_thread_axis.cc
src/tir/transforms/split_host_device.cc
tests/cpp/container_test.cc
tests/python/relay/test_annotate_target.py
tests/python/relay/test_call_graph.py
tests/python/relay/test_external_codegen.py
tests/python/relay/test_ir_nodes.py
tests/python/relay/test_ir_structural_equal_hash.py
tests/python/relay/test_pass_inline.py
tests/python/relay/test_pass_merge_composite.py
tests/python/relay/test_pass_partition_graph.py
tests/python/unittest/test_ir_attrs.py
topi/include/topi/contrib/cublas.h
topi/include/topi/contrib/rocblas.h

diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h
index a683fd6..6822159 100644
@@ -107,11 +107,12 @@ class PrimExpr : public BaseExpr {
    * \param value The value to be constructed.
    */
   TVM_DLL PrimExpr(float value);  // NOLINT(*)
+
   /*!
-   * \brief construct from string.
-   * \param str The value to be constructed.
+   * \brief construct from runtime String.
+   * \param value The value to be constructed.
    */
-  TVM_DLL PrimExpr(std::string str);  // NOLINT(*)
+  TVM_DLL PrimExpr(runtime::String value);  // NOLINT(*)
 
   /*! \return the data type of this expression. */
   DataType dtype() const {
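
With the constructor swap above, a PrimExpr built from string data now routes
through runtime::String (and from there into tir::StringImm). A hedged sketch:

    tvm::PrimExpr e = tvm::runtime::String("float32");  // wrapped as StringImm
    // A bare literal still needs the explicit String(...) step: C++ applies at
    // most one user-defined implicit conversion.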
diff --git a/include/tvm/ir/transform.h b/include/tvm/ir/transform.h
index ecd234a..3a9913f 100644
@@ -57,6 +57,7 @@
 #define TVM_IR_TRANSFORM_H_
 
 #include <tvm/support/with.h>
+#include <tvm/runtime/container.h>
 #include <tvm/node/container.h>
 #include <tvm/ir/error.h>
 #include <tvm/ir/module.h>
@@ -95,9 +96,9 @@ class PassContextNode : public Object {
   int fallback_device{static_cast<int>(kDLCPU)};
 
   /*! \brief The list of required passes. */
-  Array<PrimExpr> required_pass;
+  Array<runtime::String> required_pass;
   /*! \brief The list of disabled passes. */
-  Array<PrimExpr> disabled_pass;
+  Array<runtime::String> disabled_pass;
 
   TraceFunc trace_func;
 
@@ -197,7 +198,7 @@ class PassInfoNode : public Object {
   std::string name;
 
   /*! \brief The passes that are required to perform the current pass. */
-  Array<PrimExpr> required;
+  Array<runtime::String> required;
 
   PassInfoNode() = default;
 
@@ -226,7 +227,7 @@ class PassInfo : public ObjectRef {
    */
   TVM_DLL PassInfo(int opt_level,
                    std::string name,
-                   Array<PrimExpr> required);
+                   Array<runtime::String> required);
 
   TVM_DEFINE_OBJECT_REF_METHODS(PassInfo, ObjectRef, PassInfoNode);
 };
@@ -346,7 +347,7 @@ Pass CreateModulePass(
     const runtime::TypedPackedFunc<IRModule(IRModule, PassContext)>& pass_func,
     int opt_level,
     const std::string& name,
-    const Array<PrimExpr>& required);
+    const Array<runtime::String>& required);
 
 }  // namespace transform
 }  // namespace tvm
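
Pass registration now takes the required-pass list as strings rather than
PrimExprs. A hypothetical no-op module pass under the new signature
("MyNoopPass" is illustrative):

    tvm::transform::Pass pass = tvm::transform::CreateModulePass(
        [](tvm::IRModule m, tvm::transform::PassContext ctx) { return m; },
        /*opt_level=*/0,
        /*name=*/"MyNoopPass",
        /*required=*/{tvm::runtime::String("InferType")});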
diff --git a/include/tvm/node/container.h b/include/tvm/node/container.h
index 461fa11..cf2ac26 100644
@@ -36,6 +36,8 @@
 
 namespace tvm {
 
+using runtime::String;
+using runtime::StringObj;
 using runtime::Object;
 using runtime::ObjectPtr;
 using runtime::ObjectRef;
diff --git a/include/tvm/node/node.h b/include/tvm/node/node.h
index 04f477b..b39e3b4 100644
@@ -35,6 +35,7 @@
 #define TVM_NODE_NODE_H_
 
 #include <tvm/runtime/c_runtime_api.h>
+#include <tvm/runtime/container.h>
 #include <tvm/runtime/object.h>
 #include <tvm/runtime/memory.h>
 #include <tvm/node/reflection.h>
@@ -62,6 +63,7 @@ using runtime::make_object;
 using runtime::PackedFunc;
 using runtime::TVMArgs;
 using runtime::TVMRetValue;
+using runtime::String;
 
 }  // namespace tvm
 #endif  // TVM_NODE_NODE_H_
diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h
index deb084c..2dcf7f3 100644
@@ -24,6 +24,7 @@
 #ifndef TVM_RELAY_TRANSFORM_H_
 #define TVM_RELAY_TRANSFORM_H_
 
+#include <tvm/runtime/container.h>
 #include <tvm/relay/attrs/transform.h>
 #include <tvm/ir/transform.h>
 #include <tvm/relay/expr.h>
@@ -59,7 +60,7 @@ TVM_DLL Pass CreateFunctionPass(const runtime::TypedPackedFunc<
                                 Function(Function, IRModule, PassContext)>& pass_func,
                                 int opt_level,
                                 const std::string& name,
-                                const tvm::Array<tvm::PrimExpr>& required);
+                                const tvm::Array<runtime::String>& required);
 
 /*! \brief Remove expressions that do not affect the program result.
  *
@@ -355,7 +356,7 @@ TVM_DLL Pass Inline();
  *
  * \return The pass.
  */
-TVM_DLL Pass RemoveUnusedFunctions(Array<tvm::PrimExpr> entry_functions);
+TVM_DLL Pass RemoveUnusedFunctions(Array<runtime::String> entry_functions);
 
 }  // namespace transform
 
diff --git a/include/tvm/runtime/container.h b/include/tvm/runtime/container.h
index 50b406b..083f87f 100644
@@ -360,7 +360,15 @@ class String : public ObjectRef {
   * \note If the user passes a const reference, it will trigger a copy. If it is
   * an rvalue, it will be moved into other.
    */
-  explicit String(std::string other);
+  String(std::string other);  // NOLINT(*)
+
+  /*!
+   * \brief Construct a new String object
+   *
+   * \param other a char array.
+   */
+  String(const char* other)  // NOLINT(*)
+      : String(std::string(other)) {}
 
   /*!
    * \brief Change the value the reference object points to.
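
Because both constructors are implicit, string literals can now be dropped
straight into TVM containers, which is what lets the pass lists later in this
diff shrink to brace-initialized arrays. Sketch (assuming the usual headers):

    tvm::Array<tvm::runtime::String> required{"InferType", "CanonicalizeOps"};
    required.push_back(std::string("FoldConstant"));  // std::string converts too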
diff --git a/include/tvm/target/target.h b/include/tvm/target/target.h
index f6fd3c4..59aa955 100644
@@ -52,11 +52,11 @@ class TargetNode : public Object {
   /*! \brief The warp size that should be used by the LowerThreadAllreduce pass */
   int thread_warp_size = 1;
   /*! \brief Keys for this target */
-  Array<PrimExpr> keys_array;
+  Array<runtime::String> keys_array;
   /*! \brief Options for this target */
-  Array<PrimExpr> options_array;
+  Array<runtime::String> options_array;
   /*! \brief Collection of imported libs */
-  Array<PrimExpr> libs_array;
+  Array<runtime::String> libs_array;
 
   /*! \return the full device string to pass to codegen::Build */
   TVM_DLL const std::string& str() const;
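
On the consumer side the keys/options/libs arrays no longer need StringImm
casts; elements convert straight to std::string. A sketch of the typical read
pattern (assuming a defined Target target):

    for (auto k : target->keys_array) {
      std::string key = k;  // was k.as<tir::StringImmNode>()->value
      if (key == "gpu") { /* pick a GPU schedule, etc. */ }
    }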
diff --git a/include/tvm/tir/stmt_functor.h b/include/tvm/tir/stmt_functor.h
index 6824022..ad5c5cd 100644
@@ -326,7 +326,7 @@ class StmtExprMutator :
  *          won't do further recursion.
  * \param postorder The function called after recursive mutation.
  *          The recursive mutation result is passed to postorder for further mutation.
- * \param only_enable List of StringImm.
+ * \param only_enable List of runtime::String.
  *          If it is empty, all IRNode will call preorder/postorder
  *          If it is not empty, preorder/postorder will only be called
  *          when the IRNode's type key is in the list.
@@ -334,7 +334,7 @@ class StmtExprMutator :
 TVM_DLL Stmt IRTransform(Stmt node,
                          const runtime::PackedFunc& preorder,
                          const runtime::PackedFunc& postorder,
-                         const Array<PrimExpr>& only_enable = {});
+                         const Array<runtime::String>& only_enable = {});
 
 /*!
  * \brief recursively visit the ir in post DFS order node, apply fvisit
diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h
index 860014d..5ad40a3 100644
@@ -56,7 +56,7 @@ TVM_DLL Pass CreatePrimFuncPass(const runtime::TypedPackedFunc<
                                 PrimFunc(PrimFunc, IRModule, PassContext)>& pass_func,
                                 int opt_level,
                                 const std::string& name,
-                                const tvm::Array<tvm::PrimExpr>& required);
+                                const tvm::Array<runtime::String>& required);
 
 /*!
  * \brief Transform the high-level PrimFunc to a low-level version
@@ -100,7 +100,7 @@ TVM_DLL Pass MakePackedAPI(int num_unpacked_args);
  *
  * \return The pass.
  */
-TVM_DLL Pass RemapThreadAxis(Map<PrimExpr, IterVar> axis_map);
+TVM_DLL Pass RemapThreadAxis(Map<runtime::String, IterVar> axis_map);
 
 
 /*!
diff --git a/python/tvm/autotvm/task/task.py b/python/tvm/autotvm/task/task.py
index ddee149..00b6676 100644
@@ -24,6 +24,7 @@ registers the standard task.
 import numpy as np
 
 from tvm import target as _target
+from tvm import runtime
 from tvm.ir import container
 from tvm.tir import expr
 from tvm.te import tensor, placeholder
@@ -55,6 +56,8 @@ def serialize_args(args):
             return x
         if isinstance(x, (expr.StringImm, expr.IntImm, expr.FloatImm)):
             return x.value
+        if isinstance(x, runtime.container.String):
+            return str(x)
         if x is None:
             return None
         raise RuntimeError('Do not support type "%s" in argument. Consider to use'
diff --git a/python/tvm/relay/backend/graph_runtime_codegen.py b/python/tvm/relay/backend/graph_runtime_codegen.py
index 3e5f015..8210f27 100644
@@ -84,8 +84,7 @@ class GraphRuntimeCodegen(object):
         lowered_func = self._get_irmodule()
         param_names = self._list_params_name()
         params = {}
-        for name in param_names:
-            key = name.value
+        for key in param_names:
             arr = self._get_param_by_name(key)
             param = empty(arr.shape, dtype=arr.dtype, ctx=arr.ctx)
             arr.copyto(param)
diff --git a/python/tvm/runtime/container.py b/python/tvm/runtime/container.py
index dd59011..a719dcd 100644
@@ -16,8 +16,9 @@
 # under the License.
 """Runtime container structures."""
 import tvm._ffi
-
+from tvm._ffi.base import string_types
 from tvm.runtime import Object, ObjectTypes
+from tvm.runtime import _ffi_api
 
 def getitem_helper(obj, elem_getter, length, idx):
     """Helper function to implement a pythonic getitem function.
@@ -75,18 +76,19 @@ class ADT(Object):
         for f in fields:
             assert isinstance(f, ObjectTypes), "Expect object or " \
             "tvm NDArray type, but received : {0}".format(type(f))
-        self.__init_handle_by_constructor__(_ADT, tag, *fields)
+        self.__init_handle_by_constructor__(_ffi_api.ADT, tag,
+                                            *fields)
 
     @property
     def tag(self):
-        return _GetADTTag(self)
+        return _ffi_api.GetADTTag(self)
 
     def __getitem__(self, idx):
         return getitem_helper(
-            self, _GetADTFields, len(self), idx)
+            self, _ffi_api.GetADTFields, len(self), idx)
 
     def __len__(self):
-        return _GetADTSize(self)
+        return _ffi_api.GetADTSize(self)
 
 
 def tuple_object(fields=None):
@@ -106,7 +108,7 @@ def tuple_object(fields=None):
     for f in fields:
         assert isinstance(f, ObjectTypes), "Expect object or tvm " \
         "NDArray type, but received : {0}".format(type(f))
-    return _Tuple(*fields)
+    return _ffi_api.Tuple(*fields)
 
 
 @tvm._ffi.register_object("runtime.String")
@@ -115,7 +117,7 @@ class String(Object):
 
     Parameters
     ----------
-    string : Str
+    string : str
         The string used to construct a runtime String object
 
     Returns
@@ -124,7 +126,50 @@ class String(Object):
         The created object.
     """
     def __init__(self, string):
-        self.__init_handle_by_constructor__(_String, string)
+        self.__init_handle_by_constructor__(_ffi_api.String, string)
+
+    def __str__(self):
+        return _ffi_api.GetStdString(self)
+
+    def __len__(self):
+        return _ffi_api.GetStringSize(self)
+
+    def __hash__(self):
+        return _ffi_api.StringHash(self)
+
+    def __eq__(self, other):
+        if isinstance(other, string_types):
+            return self.__str__() == other
+
+        if not isinstance(other, String):
+            return False
+
+        return _ffi_api.CompareString(self, other) == 0
+
+    def __ne__(self, other):
+        return not self.__eq__(other)
+
+    def __gt__(self, other):
+        return _ffi_api.CompareString(self, other) > 0
+
+    def __lt__(self, other):
+        return _ffi_api.CompareString(self, other) < 0
+
+    def __getitem__(self, key):
+        return self.__str__()[key]
+
+    def startswith(self, string):
+        """Check if the runtime string starts with a given string
 
+        Parameters
+        ----------
+        string : str
+            The provided string
 
-tvm._ffi._init_api("tvm.runtime.container")
+        Returns
+        -------
+        ret : boolean
+            Return true if the runtime string starts with the given string,
+            otherwise, false.
+        """
+        return self.__str__().startswith(string)
diff --git a/python/tvm/runtime/object_generic.py b/python/tvm/runtime/object_generic.py
index 22354db..a7716df 100644
@@ -19,7 +19,7 @@
 from numbers import Number, Integral
 from tvm._ffi.base import string_types
 
-from . import _ffi_node_api
+from . import _ffi_node_api, _ffi_api
 from .object import ObjectBase, _set_class_object_generic
 from .ndarray import NDArrayBase
 from .packed_func import PackedFuncBase, convert_to_tvm_func
@@ -56,7 +56,7 @@ def convert_to_object(value):
     if isinstance(value, Number):
         return const(value)
     if isinstance(value, string_types):
-        return _ffi_node_api.String(value)
+        return _ffi_api.String(value)
     if isinstance(value, (list, tuple)):
         value = [convert_to_object(x) for x in value]
         return _ffi_node_api.Array(*value)
diff --git a/python/tvm/target/target.py b/python/tvm/target/target.py
index a83ea0c..fd15ff9 100644
@@ -48,26 +48,26 @@ class Target(Object):
     @property
     def keys(self):
         if not self._keys:
-            self._keys = [k.value for k in self.keys_array]
+            self._keys = [str(k) for k in self.keys_array]
         return self._keys
 
     @property
     def options(self):
         if not self._options:
-            self._options = [o.value for o in self.options_array]
+            self._options = [str(o) for o in self.options_array]
         return self._options
 
     @property
     def libs(self):
         if not self._libs:
-            self._libs = [l.value for l in self.libs_array]
+            self._libs = [str(l) for l in self.libs_array]
         return self._libs
 
     @property
     def model(self):
         for opt in self.options_array:
-            if opt.value.startswith('-model='):
-                return opt.value[7:]
+            if opt.startswith('-model='):
+                return opt[7:]
         return 'unknown'
 
     @property
diff --git a/src/autotvm/touch_extractor.cc b/src/autotvm/touch_extractor.cc
index b5bf2ed..fbd0829 100644
@@ -252,9 +252,9 @@ void GetItervarFeature(Stmt stmt, bool take_log, Array<Array<Array<PrimExpr> > >
   for (auto var : vars) {
     Array<Array<PrimExpr> > feature_row;
     ItervarFeature &fea = touch_analyzer.itervar_map[var];
-    feature_row.push_back(Array<PrimExpr>{std::string("_itervar_"), var});
+    feature_row.push_back(Array<PrimExpr>{tvm::tir::StringImmNode::make("_itervar_"), var});
 
-    Array<PrimExpr> attr{std::string("_attr_"),
+    Array<PrimExpr> attr{tvm::tir::StringImmNode::make("_attr_"),
                      FloatImm(DataType::Float(32), trans(fea.length)),
                      IntImm(DataType::Int(32), fea.nest_level),
                      FloatImm(DataType::Float(32), trans(fea.topdown_product)),
@@ -267,7 +267,7 @@ void GetItervarFeature(Stmt stmt, bool take_log, Array<Array<Array<PrimExpr> > >
     feature_row.push_back(attr);
 
     // arithmetic
-    feature_row.push_back(Array<PrimExpr>{std::string("_arith_"),
+    feature_row.push_back(Array<PrimExpr>{tvm::tir::StringImmNode::make("_arith_"),
             FloatImm(DataType::Float(32), trans(fea.add_ct)),
             FloatImm(DataType::Float(32), trans(fea.mul_ct)),
             FloatImm(DataType::Float(32), trans(fea.div_ct)),
@@ -282,7 +282,7 @@ void GetItervarFeature(Stmt stmt, bool take_log, Array<Array<Array<PrimExpr> > >
     for (auto k : bufs) {
       TouchPattern &v = fea.touch_feature[k];
       feature_row.push_back(
-          Array<PrimExpr>{k,
+          Array<PrimExpr>{tvm::tir::StringImmNode::make(k),
                 FloatImm(DataType::Float(32), trans(v.stride)),
                 FloatImm(DataType::Float(32), trans(v.mod)),
                 FloatImm(DataType::Float(32), trans(v.count)),
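
Note this hunk moves in the opposite direction from the rest of the patch:
since the implicit PrimExpr(std::string) constructor is gone, heterogeneous
Array<PrimExpr> feature rows must wrap their string labels explicitly, e.g.:

    // before: Array<PrimExpr>{std::string("_itervar_"), var} leaned on the
    // removed PrimExpr(std::string); the wrap is now spelled out:
    feature_row.push_back(Array<PrimExpr>{tvm::tir::StringImmNode::make("_itervar_"), var});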
diff --git a/src/ir/attrs.cc b/src/ir/attrs.cc
index 066b8f9..bee103d 100644
@@ -42,7 +42,7 @@ void DictAttrsNode::InitByPackedArgs(
     if (val.IsObjectRef<ObjectRef>()) {
       dict.Set(key, val.operator ObjectRef());
     } else if (val.type_code() == kTVMStr) {
-      dict.Set(key, PrimExpr(val.operator std::string()));
+      dict.Set(key, val.operator String());
     } else {
       dict.Set(key, val.operator PrimExpr());
     }
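
This works because TVMArgValue/TVMRetValue convert to runtime::String, so
string-typed FFI values can be stored as String instead of being re-wrapped in
StringImm. The same conversion lets packed functions take and return strings
directly; a sketch (the global name "example.Echo" is made up):

    TVM_REGISTER_GLOBAL("example.Echo")
    .set_body_typed([](tvm::runtime::String s) {
      return s;  // strings cross the FFI boundary as runtime::String
    });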
diff --git a/src/ir/expr.cc b/src/ir/expr.cc
index b07f04a..1f0337e 100644
@@ -40,8 +40,8 @@ PrimExpr::PrimExpr(int32_t value)
 PrimExpr::PrimExpr(float value)
     : PrimExpr(FloatImm(DataType::Float(32), value)) {}
 
-PrimExpr::PrimExpr(std::string str)
-    : PrimExpr(tir::StringImmNode::make(str)) {}
+PrimExpr::PrimExpr(runtime::String value)
+    : PrimExpr(tir::StringImmNode::make(value)) {}
 
 PrimExpr PrimExpr::FromObject_(ObjectPtr<Object> ptr) {
   using runtime::ObjectTypeChecker;
@@ -51,6 +51,9 @@ PrimExpr PrimExpr::FromObject_(ObjectPtr<Object> ptr) {
   if (ptr->IsInstance<te::TensorNode>()) {
     return te::Tensor(ptr)();
   }
+  if (ptr->IsInstance<runtime::StringObj>()) {
+    return tir::StringImmNode::make(runtime::String(ptr));
+  }
   CHECK(ObjectTypeChecker<PrimExpr>::Check(ptr.get()))
       << "Expect type " << ObjectTypeChecker<PrimExpr>::TypeName()
       << " but get " << ptr->GetTypeKey();
diff --git a/src/ir/op.cc b/src/ir/op.cc
index 6a50240..b024165 100644
@@ -24,6 +24,7 @@
 #include <tvm/ir/op.h>
 #include <tvm/ir/type.h>
 #include <tvm/runtime/module.h>
+#include <tvm/runtime/container.h>
 #include <tvm/runtime/packed_func.h>
 
 #include <memory>
@@ -140,10 +141,9 @@ void OpRegistry::UpdateAttr(const std::string& key,
 // Frontend APIs
 TVM_REGISTER_GLOBAL("relay.op._ListOpNames")
 .set_body_typed([]() {
-    Array<tvm::PrimExpr> ret;
-    for (const std::string& name :
-             dmlc::Registry<OpRegistry>::ListAllNames()) {
-      ret.push_back(tvm::PrimExpr(name));
+    Array<runtime::String> ret;
+    for (const std::string& name : dmlc::Registry<OpRegistry>::ListAllNames()) {
+      ret.push_back(name);
     }
     return ret;
   });
diff --git a/src/ir/transform.cc b/src/ir/transform.cc
index 61c1fc2..6e38aac 100644
@@ -23,6 +23,7 @@
  */
 #include <dmlc/thread_local.h>
 #include <tvm/runtime/registry.h>
+#include <tvm/runtime/container.h>
 #include <tvm/runtime/device_api.h>
 #include <tvm/node/repr_printer.h>
 #include <tvm/ir/transform.h>
@@ -212,7 +213,7 @@ class SequentialNode : public PassNode {
 
 PassInfo::PassInfo(int opt_level,
                    std::string name,
-                   tvm::Array<tvm::PrimExpr> required) {
+                   tvm::Array<runtime::String> required) {
   auto pass_info = make_object<PassInfoNode>();
   pass_info->opt_level = opt_level;
   pass_info->name = std::move(name);
@@ -274,12 +275,10 @@ void SequentialNode::ResolveDependency(const IRModule& mod) {
 }
 
 // linearly scan the pass array to match pass_name
-inline bool PassArrayContains(const Array<tvm::PrimExpr>& pass_array,
+inline bool PassArrayContains(const Array<runtime::String>& pass_array,
                               const std::string& pass_name) {
   for (auto x : pass_array) {
-    auto* str_name = x.as<tir::StringImmNode>();
-    CHECK(str_name) << "pass name must be str";
-    if (str_name->value == pass_name) return true;
+    if (x == pass_name) return true;
   }
   return false;
 }
@@ -324,9 +323,7 @@ IRModule SequentialNode::operator()(const IRModule& module,
     if (!PassEnabled(pass_info))  continue;
     // resolve dependencies
     for (const auto& it : pass_info->required) {
-      const auto* name = it.as<tvm::tir::StringImmNode>();
-      CHECK(name);
-      mod = GetPass(name->value)(mod, pass_ctx);
+      mod = GetPass(it)(mod, pass_ctx);
     }
     mod = pass(mod, pass_ctx);
   }
@@ -337,7 +334,7 @@ Pass CreateModulePass(
     const runtime::TypedPackedFunc<IRModule(IRModule, PassContext)>& pass_func,
     int opt_level,
     const std::string& name,
-    const tvm::Array<tvm::PrimExpr>& required) {
+    const tvm::Array<runtime::String>& required) {
   PassInfo pass_info = PassInfo(opt_level, name, required);
   return ModulePass(pass_func, pass_info);
 }
@@ -345,7 +342,7 @@ Pass CreateModulePass(
 TVM_REGISTER_NODE_TYPE(PassInfoNode);
 
 TVM_REGISTER_GLOBAL("transform.PassInfo")
-.set_body_typed([](int opt_level, std::string name, tvm::Array<PrimExpr> required) {
+.set_body_typed([](int opt_level, std::string name, tvm::Array<runtime::String> required) {
   return PassInfo(opt_level, name, required);
 });
 
@@ -363,8 +360,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
   p->stream << "opt_level: " << node->opt_level;
   p->stream << "required passes: [" << "\n";
   for (const auto& it : node->required) {
-    const auto* str = it.as<tvm::tir::StringImmNode>();
-    p->stream << str->value << ", ";
+    p->stream << it << ", ";
   }
   p->stream << "]\n";
 });
@@ -401,7 +397,7 @@ TVM_REGISTER_GLOBAL("transform.Sequential")
   tvm::Array<Pass> passes = args[0];
   int opt_level = args[1];
   std::string name = args[2];
-  tvm::Array<tvm::PrimExpr> required = args[3];
+  tvm::Array<runtime::String> required = args[3];
   PassInfo pass_info = PassInfo(opt_level, name, required);
   *ret = Sequential(passes, pass_info);
 });
@@ -427,8 +423,8 @@ TVM_REGISTER_GLOBAL("transform.PassContext")
   auto pctx = PassContext::Create();
   int opt_level = args[0];
   int fallback_device = args[1];
-  tvm::Array<tvm::PrimExpr> required = args[2];
-  tvm::Array<tvm::PrimExpr> disabled = args[3];
+  tvm::Array<runtime::String> required = args[2];
+  tvm::Array<runtime::String> disabled = args[3];
   TraceFunc trace_func = args[4];
   pctx->opt_level = opt_level;
   pctx->fallback_device = fallback_device;
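
The simplifications in PassArrayContains and SequentialNode::operator() lean on
String's value comparison against std::string, so the StringImmNode casts and
CHECKs fall away:

    tvm::runtime::String x("InferType");
    std::string pass_name = "InferType";
    bool contains = (x == pass_name);  // content equality, true here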
diff --git a/src/node/container.cc b/src/node/container.cc
index e7e4979..bce2eee 100644
@@ -63,7 +63,6 @@ TVM_REGISTER_REFLECTION_VTABLE(runtime::StringObj, StringObjTrait)
       static_cast<const runtime::StringObj*>(n)).operator std::string();
 });
 
-
 struct ADTObjTrait {
   static constexpr const std::nullptr_t VisitAttrs = nullptr;
 
diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc
index eaf78bc..e2d5e93 100644
@@ -86,9 +86,10 @@ struct GraphCodegen {
 
   std::unordered_map<std::string, tvm::runtime::NDArray> GetParams() {
     std::unordered_map<std::string, tvm::runtime::NDArray> ret;
-    auto names = CallFunc<Array<tvm::PrimExpr>>("list_params_name", nullptr);
-    for (auto expr : names) {
-      auto key = expr.as<tir::StringImmNode>()->value;
+    auto names = CallFunc<Array<runtime::String>>("list_params_name", nullptr);
+    for (const auto& expr : names) {
+      // Implicit cast from runtime::String to std::string
+      std::string key = expr;
       ret[key] = CallFunc<runtime::NDArray>("get_param_by_name", key);
     }
     return ret;
@@ -191,12 +192,12 @@ class RelayBuildModule : public runtime::ModuleNode {
   /*!
   * \brief List all parameter names
    *
-   * \return Array<StringImm> names of params
+   * \return Array<runtime::String> names of params
    */
-  Array<tvm::PrimExpr> ListParamNames() {
-    Array<tvm::PrimExpr> ret;
+  Array<runtime::String> ListParamNames() {
+    Array<runtime::String> ret;
     for (const auto& kv : params_) {
-      ret.push_back(tir::StringImmNode::make(kv.first));
+      ret.push_back(kv.first);
     }
     return ret;
   }
@@ -272,7 +273,7 @@ class RelayBuildModule : public runtime::ModuleNode {
     }
 
     Array<Pass> pass_seqs;
-    Array<tvm::PrimExpr> entry_functions{tvm::PrimExpr{"main"}};
+    Array<runtime::String> entry_functions{"main"};
     pass_seqs.push_back(transform::RemoveUnusedFunctions(entry_functions));
 
     // Run all dialect legalization passes.
diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc
index f75da07..9cb6b2e 100644
@@ -617,17 +617,18 @@ class CompileEngineImpl : public CompileEngineNode {
     for (const auto& it : cache_) {
       auto src_func = it.first->source_func;
       CHECK(src_func.defined());
-      if (src_func->GetAttr<tir::StringImm>(attr::kCompiler).defined()) {
-        auto code_gen = src_func->GetAttr<tir::StringImm>(attr::kCompiler);
+      if (src_func->GetAttr<String>(attr::kCompiler).defined()) {
+        auto code_gen = src_func->GetAttr<String>(attr::kCompiler);
         CHECK(code_gen.defined()) << "No external codegen is set";
-        if (ext_mods.find(code_gen->value) == ext_mods.end()) {
-          ext_mods[code_gen->value] = IRModule({}, {});
+        std::string code_gen_name = code_gen;
+        if (ext_mods.find(code_gen_name) == ext_mods.end()) {
+          ext_mods[code_gen_name] = IRModule({}, {});
         }
-        auto symbol_name = src_func->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
+        auto symbol_name = src_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
         CHECK(symbol_name.defined()) << "No external symbol is set for:\n"
                                      << AsText(src_func, false);
         auto gv = GlobalVar(std::string(symbol_name));
-        ext_mods[code_gen->value]->Add(gv, src_func);
+        ext_mods[code_gen_name]->Add(gv, src_func);
         cached_ext_funcs.push_back(it.first);
       }
     }
@@ -691,10 +692,10 @@ class CompileEngineImpl : public CompileEngineNode {
     }
     // No need to lower external functions for now. We will invoke the external
     // codegen tool once and lower all functions together.
-    if (key->source_func->GetAttr<tir::StringImm>(attr::kCompiler).defined()) {
+    if (key->source_func->GetAttr<String>(attr::kCompiler).defined()) {
       auto cache_node = make_object<CachedFuncNode>();
       const auto name_node =
-          key->source_func->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
+          key->source_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
       CHECK(name_node.defined())
           << "External function has not been attached a name yet.";
       cache_node->func_name = std::string(name_node);
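
Function attributes follow the same pattern throughout the backends below:
GetAttr<String> replaces GetAttr<tir::StringImm>, and the result converts to
std::string as needed. Typical lookup (assuming a relay::Function func):

    auto compiler = func->GetAttr<tvm::runtime::String>(tvm::relay::attr::kCompiler);
    if (compiler.defined()) {
      std::string name = compiler;  // e.g. "dnnl" for an external codegen
    }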
diff --git a/src/relay/backend/contrib/codegen_c/codegen_c.h b/src/relay/backend/contrib/codegen_c/codegen_c.h
index 79d4d3f..1db3f20 100644
@@ -70,7 +70,7 @@ class CSourceModuleCodegenBase {
    */
   std::string GetExtSymbol(const Function& func) const {
     const auto name_node =
-        func->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
+        func->GetAttr<String>(tvm::attr::kGlobalSymbol);
     CHECK(name_node.defined()) << "Fail to retrieve external symbol.";
     return std::string(name_node);
   }
diff --git a/src/relay/backend/graph_runtime_codegen.cc b/src/relay/backend/graph_runtime_codegen.cc
index c7f1be8..4279db0 100644
@@ -419,7 +419,7 @@ class GraphRuntimeCodegen
     auto pf1 = GetPackedFunc("relay.backend._CompileEngineLower");
     Target target;
     // Handle external function
-    if (func->GetAttr<tir::StringImm>(attr::kCompiler).defined()) {
+    if (func->GetAttr<String>(attr::kCompiler).defined()) {
       target = tvm::target::ext_dev();
       CCacheKey key = (*pf0)(func, target);
       CachedFunc ext_func = (*pf1)(compile_engine_, key);
@@ -482,7 +482,7 @@ class GraphRuntimeCodegen
     return {};
   }
   std::vector<GraphNodeRef> VisitExpr_(const FunctionNode* op) override {
-    CHECK(op->GetAttr<tir::StringImm>(attr::kCompiler).defined())
+    CHECK(op->GetAttr<String>(attr::kCompiler).defined())
         << "Only functions supported by custom codegen";
     return {};
   }
@@ -633,10 +633,9 @@ class GraphRuntimeCodegenModule : public runtime::ModuleNode {
       });
     } else if (name == "list_params_name") {
       return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
-        Array<tvm::PrimExpr> ret;
+        Array<runtime::String> ret;
         for (const auto &kv : this->output_.params) {
-          tvm::PrimExpr name = tir::StringImmNode::make(kv.first);
-          ret.push_back(name);
+          ret.push_back(kv.first);
         }
         *rv = ret;
       });
diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc
index d68bff6..e2b0fff 100644
@@ -475,7 +475,7 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
 
     Target target;
 
-    if (func->GetAttr<tir::StringImm>(attr::kCompiler).defined()) {
+    if (func->GetAttr<String>(attr::kCompiler).defined()) {
       target = tvm::target::ext_dev();
     } else {
       // Next generate the invoke instruction.
@@ -493,7 +493,7 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
     auto cfunc = engine_->Lower(key);
 
     auto op_index = -1;
-    if (func->GetAttr<tir::StringImm>(attr::kCompiler).defined()) {
+    if (func->GetAttr<String>(attr::kCompiler).defined()) {
       op_index = context_->cached_funcs.size();
       context_->cached_funcs.push_back(cfunc);
     } else {
@@ -873,7 +873,7 @@ void VMCompiler::Lower(IRModule mod,
 
 IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targets) {
   Array<Pass> pass_seqs;
-  Array<tvm::PrimExpr> entry_functions{tvm::PrimExpr{"main"}};
+  Array<runtime::String> entry_functions{"main"};
   pass_seqs.push_back(transform::RemoveUnusedFunctions(entry_functions));
   // Run all dialect legalization passes.
   pass_seqs.push_back(relay::qnn::transform::Legalize());
diff --git a/src/relay/backend/vm/inline_primitives.cc b/src/relay/backend/vm/inline_primitives.cc
index 74b2a47..12113b0 100644
@@ -122,7 +122,7 @@ struct PrimitiveInliner : ExprMutator {
       auto global = pair.first;
       auto base_func = pair.second;
       if (auto* n = base_func.as<FunctionNode>()) {
-        if (n->GetAttr<tir::StringImm>(attr::kCompiler).defined()) continue;
+        if (n->GetAttr<String>(attr::kCompiler).defined()) continue;
         auto func = GetRef<Function>(n);
 
         DLOG(INFO) << "Before inlining primitives: " << global
diff --git a/src/relay/backend/vm/lambda_lift.cc b/src/relay/backend/vm/lambda_lift.cc
index 80745e1..59c549c 100644
@@ -190,7 +190,7 @@ class LambdaLifter : public ExprMutator {
     auto glob_funcs = module_->functions;
     for (auto pair : glob_funcs) {
       if (auto* n = pair.second.as<FunctionNode>()) {
-        if (n->GetAttr<tir::StringImm>(attr::kCompiler).defined()) continue;
+        if (n->GetAttr<String>(attr::kCompiler).defined()) continue;
         auto func = GetRef<Function>(n);
         func = Function(func->params,
                         VisitExpr(func->body),
diff --git a/src/relay/backend/vm/removed_unused_funcs.cc b/src/relay/backend/vm/removed_unused_funcs.cc
index dd11fce..c2fe37f 100644
@@ -87,11 +87,10 @@ struct CallTracer : ExprVisitor {
  * \return The module with dead functions removed.
  */
 IRModule RemoveUnusedFunctions(const IRModule& module,
-                             Array<tvm::PrimExpr> entry_funcs) {
+                               Array<runtime::String> entry_funcs) {
   std::unordered_set<std::string> called_funcs{};
   for (auto entry : entry_funcs) {
-    auto* str_name = entry.as<tir::StringImmNode>();
-    auto funcs = CallTracer(module).Trace(str_name->value);
+    auto funcs = CallTracer(module).Trace(entry);
     called_funcs.insert(funcs.cbegin(), funcs.cend());
   }
   auto existing_functions = module->functions;
@@ -108,7 +107,7 @@ IRModule RemoveUnusedFunctions(const IRModule& module,
 
 namespace transform {
 
-Pass RemoveUnusedFunctions(Array<tvm::PrimExpr> entry_functions) {
+Pass RemoveUnusedFunctions(Array<runtime::String> entry_functions) {
   runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
     [=](IRModule m, PassContext pc) {
     return relay::vm::RemoveUnusedFunctions(m, entry_functions);
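
Callers can now build the entry list from plain literals, as the two
entry_functions call sites above do:

    auto pass = tvm::relay::transform::RemoveUnusedFunctions({"main"});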
diff --git a/src/relay/ir/transform.cc b/src/relay/ir/transform.cc
index a4bab36..fa709eb 100644
@@ -145,14 +145,14 @@ IRModule FunctionPassNode::operator()(const IRModule& mod,
 
 bool FunctionPassNode::SkipFunction(const Function& func) const {
   return func->GetAttr<Integer>(attr::kSkipOptimization, 0)->value != 0 ||
-    (func->GetAttr<tir::StringImm>(attr::kCompiler).defined());
+    (func->GetAttr<String>(attr::kCompiler).defined());
 }
 
 Pass CreateFunctionPass(
     const runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)>& pass_func,
     int opt_level,
     const std::string& name,
-    const tvm::Array<tvm::PrimExpr>& required) {
+    const tvm::Array<runtime::String>& required) {
   PassInfo pass_info = PassInfo(opt_level, name, required);
   return FunctionPass(pass_func, pass_info);
 }
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index 87b4602..7aa8bf1 100644
@@ -1177,7 +1177,6 @@ Array<te::Tensor> ArangeCompute(const Attrs& attrs,
   te::Tensor start = inputs[0];
   te::Tensor stop =  inputs[1];
   te::Tensor step = inputs[2];
-  Array<tvm::PrimExpr> empty = {0};
   return { DynamicArange(start, stop, step, param->dtype) };
 }
 
diff --git a/src/relay/transforms/alter_op_layout.cc b/src/relay/transforms/alter_op_layout.cc
index 63c1cb9..aab0b3a 100644
@@ -125,8 +125,7 @@ Pass AlterOpLayout() {
     [=](Function f, IRModule m, PassContext pc) {
       return Downcast<Function>(relay::alter_op_layout::AlterOpLayout(f));
   };
-  return CreateFunctionPass(pass_func, 3, "AlterOpLayout",
-                            {tir::StringImmNode::make("InferType")});
+  return CreateFunctionPass(pass_func, 3, "AlterOpLayout", {"InferType"});
 }
 
 TVM_REGISTER_GLOBAL("relay._transform.AlterOpLayout")
diff --git a/src/relay/transforms/annotate_target.cc b/src/relay/transforms/annotate_target.cc
index c3d34cb..44ef35a 100644
@@ -59,11 +59,12 @@ class AnnotateTargetWrapper : public ExprMutator {
         // handle composite functions
         Function func = Downcast<Function>(call->op);
         CHECK(func.defined());
-        auto comp_name = func->GetAttr<tir::StringImm>(attr::kComposite);
+        auto comp_name = func->GetAttr<String>(attr::kComposite);
         if (comp_name.defined()) {
-          size_t i = comp_name->value.find('.');
+          std::string comp_name_str = comp_name;
+          size_t i = comp_name_str.find('.');
           if (i != std::string::npos) {
-            std::string target = comp_name->value.substr(0, i);
+            std::string target = comp_name_str.substr(0, i);
             if (target == target_) return true;
           }
         }
@@ -147,7 +148,7 @@ class AnnotateTargetWrapper : public ExprMutator {
     Function func;
     Expr new_body;
     // don't step into composite functions
-    if (fn->GetAttr<tir::StringImm>(attr::kComposite).defined()) {
+    if (fn->GetAttr<String>(attr::kComposite).defined()) {
       func = GetRef<Function>(fn);
       new_body = func->body;
     } else {
@@ -225,7 +226,7 @@ Pass AnnotateTarget(const std::string& target) {
         return Downcast<Function>(relay::annotate_target::AnnotateTarget(f, target));
       };
   auto func_pass = CreateFunctionPass(pass_func, 0, "AnnotateTargetFunc",
-                                      {tir::StringImmNode::make("InferType")});
+                                      {"InferType"});
   return transform::Sequential({func_pass, InferType()}, "AnnotateTarget");
 }
 
diff --git a/src/relay/transforms/canonicalize_cast.cc b/src/relay/transforms/canonicalize_cast.cc
index 759a4ae..ebcbd57 100644
@@ -133,8 +133,7 @@ Pass CanonicalizeCast() {
     [=](Function f, IRModule m, PassContext pc) {
     return Downcast<Function>(CanonicalizeCast(f));
   };
-  return CreateFunctionPass(pass_func, 3, "CanonicalizeCast",
-                            {tir::StringImmNode::make("InferType")});
+  return CreateFunctionPass(pass_func, 3, "CanonicalizeCast", {"InferType"});
 }
 
 TVM_REGISTER_GLOBAL("relay._transform.CanonicalizeCast")
diff --git a/src/relay/transforms/canonicalize_ops.cc b/src/relay/transforms/canonicalize_ops.cc
index 97a128d..1d3111b 100644
@@ -74,8 +74,7 @@ Pass CanonicalizeOps() {
     [=](Function f, IRModule m, PassContext pc) {
     return Downcast<Function>(CanonicalizeOps(f));
   };
-  return CreateFunctionPass(pass_func, 3, "CanonicalizeOps",
-                            {tir::StringImmNode::make("InferType")});
+  return CreateFunctionPass(pass_func, 3, "CanonicalizeOps", {"InferType"});
 }
 
 TVM_REGISTER_GLOBAL("relay._transform.CanonicalizeOps")
diff --git a/src/relay/transforms/combine_parallel_conv2d.cc b/src/relay/transforms/combine_parallel_conv2d.cc
index 3884dac..af6b135 100644
@@ -220,8 +220,7 @@ Pass CombineParallelConv2D(uint64_t min_num_branches) {
     [=](Function f, IRModule m, PassContext pc) {
       return Downcast<Function>(CombineParallelConv2D(f, min_num_branches));
   };
-  return CreateFunctionPass(pass_func, 4, "CombineParallelConv2d",
-                            {tir::StringImmNode::make("InferType")});
+  return CreateFunctionPass(pass_func, 4, "CombineParallelConv2d", {"InferType"});
 }
 
 TVM_REGISTER_GLOBAL("relay._transform.CombineParallelConv2D")
diff --git a/src/relay/transforms/combine_parallel_dense.cc b/src/relay/transforms/combine_parallel_dense.cc
index 612dae5..1278020 100644
@@ -80,8 +80,7 @@ Pass CombineParallelDense(uint64_t min_num_branches) {
     [=](Function f, IRModule m, PassContext pc) {
       return Downcast<Function>(CombineParallelDense(f, min_num_branches));
   };
-  return CreateFunctionPass(pass_func, 4, "CombineParallelDense",
-                            {tir::StringImmNode::make("InferType")});
+  return CreateFunctionPass(pass_func, 4, "CombineParallelDense", {"InferType"});
 }
 
 TVM_REGISTER_GLOBAL("relay._transform.CombineParallelDense")
diff --git a/src/relay/transforms/combine_parallel_op_batch.cc b/src/relay/transforms/combine_parallel_op_batch.cc
index 55ca3f6..361565e 100644
@@ -193,8 +193,7 @@ Pass CombineParallelOpBatch(const std::string& op_name,
                                                        batch_op_name,
                                                        min_num_branches));
   };
-  return CreateFunctionPass(pass_func, 4, "CombineParallelOpBatch",
-                            {tir::StringImmNode::make("InferType")});
+  return CreateFunctionPass(pass_func, 4, "CombineParallelOpBatch", {"InferType"});
 }
 
 TVM_REGISTER_GLOBAL("relay._transform.CombineParallelOpBatch")
diff --git a/src/relay/transforms/convert_layout.cc b/src/relay/transforms/convert_layout.cc
index 871969d..dbb2c38 100644
@@ -133,9 +133,7 @@ Pass ConvertLayout(const std::string& desired_layout) {
         return Downcast<Function>(relay::convert_op_layout::ConvertLayout(f, desired_layout));
       };
   return CreateFunctionPass(
-      pass_func, 3, "ConvertLayout",
-      {tir::StringImmNode::make("InferType"),
-       tir::StringImmNode::make("CanonicalizeOps")});
+      pass_func, 3, "ConvertLayout", {"InferType", "CanonicalizeOps"});
 }
 
 TVM_REGISTER_GLOBAL("relay._transform.ConvertLayout").set_body_typed(ConvertLayout);
diff --git a/src/relay/transforms/device_annotation.cc b/src/relay/transforms/device_annotation.cc
index b4d61f1..908ba87 100644
@@ -573,8 +573,7 @@ Pass RewriteAnnotatedOps(int fallback_device) {
     [=](Function f, IRModule m, PassContext pc) {
     return Downcast<Function>(relay::RewriteAnnotatedOps(f, fallback_device));
   };
-  return CreateFunctionPass(pass_func, 1, "RewriteAnnotatedOps",
-                            {tir::StringImmNode::make("InferType")});
+  return CreateFunctionPass(pass_func, 1, "RewriteAnnotatedOps", {"InferType"});
 }
 
 TVM_REGISTER_GLOBAL("relay._transform.RewriteDeviceAnnotation")
diff --git a/src/relay/transforms/eliminate_common_subexpr.cc b/src/relay/transforms/eliminate_common_subexpr.cc
index f905ba5..68c59f5 100644
@@ -91,8 +91,7 @@ Pass EliminateCommonSubexpr(PackedFunc fskip) {
     [=](Function f, IRModule m, PassContext pc) {
       return Downcast<Function>(EliminateCommonSubexpr(f, fskip));
   };
-  return CreateFunctionPass(pass_func, 3, "EliminateCommonSubexpr",
-                            {tir::StringImmNode::make("InferType")});
+  return CreateFunctionPass(pass_func, 3, "EliminateCommonSubexpr", {"InferType"});
 }
 
 TVM_REGISTER_GLOBAL("relay._transform.EliminateCommonSubexpr")
diff --git a/src/relay/transforms/fast_math.cc b/src/relay/transforms/fast_math.cc
index cf00a89..8234dea 100644
@@ -70,8 +70,7 @@ Pass FastMath() {
     [=](Function f, IRModule m, PassContext pc) {
     return Downcast<Function>(FastMath(f));
   };
-  return CreateFunctionPass(pass_func, 4, "FastMath",
-                            {tir::StringImmNode::make("InferType")});
+  return CreateFunctionPass(pass_func, 4, "FastMath", {"InferType"});
 }
 
 TVM_REGISTER_GLOBAL("relay._transform.FastMath")
diff --git a/src/relay/transforms/fold_scale_axis.cc b/src/relay/transforms/fold_scale_axis.cc
index 49f6e3f..cfe74bf 100644
@@ -960,8 +960,7 @@ Pass ForwardFoldScaleAxis() {
       return Downcast<Function>(
           relay::fold_scale_axis::ForwardFoldScaleAxis(f));
   };
-  return CreateFunctionPass(pass_func, 3, "ForwardFoldScaleAxis",
-                            {tir::StringImmNode::make("InferType")});
+  return CreateFunctionPass(pass_func, 3, "ForwardFoldScaleAxis", {"InferType"});
 }
 
 TVM_REGISTER_GLOBAL("relay._transform.ForwardFoldScaleAxis")
@@ -973,8 +972,7 @@ Pass BackwardFoldScaleAxis() {
       return Downcast<Function>(
           relay::fold_scale_axis::BackwardFoldScaleAxis(f));
     };
-  return CreateFunctionPass(pass_func, 3, "BackwardFoldScaleAxis",
-                            {tir::StringImmNode::make("InferType")});
+  return CreateFunctionPass(pass_func, 3, "BackwardFoldScaleAxis", {"InferType"});
 }
 
 TVM_REGISTER_GLOBAL("relay._transform.BackwardFoldScaleAxis")
diff --git a/src/relay/transforms/fuse_ops.cc b/src/relay/transforms/fuse_ops.cc
index 9168898..f646042 100644
@@ -980,8 +980,7 @@ Pass FuseOps(int fuse_opt_level) {
     int opt_level = fuse_opt_level == -1 ? pc->opt_level : fuse_opt_level;
     return Downcast<Function>(FuseOps(f, opt_level, m));
   };
-  return CreateFunctionPass(pass_func, 1, "FuseOps",
-                            {tir::StringImmNode::make("InferType")});
+  return CreateFunctionPass(pass_func, 1, "FuseOps", {"InferType"});
 }
 
 TVM_REGISTER_GLOBAL("relay._transform.FuseOps")
diff --git a/src/relay/transforms/inline.cc b/src/relay/transforms/inline.cc
index ef3c51f..ba0f568 100644
@@ -131,7 +131,7 @@ class Inliner : ExprMutator {
                          fn->attrs);
     // Inline the function body to the caller if this function uses default
     // compiler, i.e. no external codegen is needed.
-    if (!func->GetAttr<tir::StringImm>(attr::kCompiler).defined()) {
+    if (!func->GetAttr<String>(attr::kCompiler).defined()) {
       CHECK_EQ(func->params.size(), args.size())
           << "Mismatch found in the number of parameters and call args";
       // Bind the parameters with call args.
diff --git a/src/relay/transforms/legalize.cc b/src/relay/transforms/legalize.cc
index 01411a6..0b5c671 100644
@@ -101,7 +101,7 @@ Pass Legalize(const std::string& legalize_map_attr_name) {
       [=](Function f, IRModule m, PassContext pc) {
         return Downcast<Function>(relay::legalize::Legalize(f, legalize_map_attr_name));
       };
-  return CreateFunctionPass(pass_func, 1, "Legalize", {tir::StringImmNode::make("InferType")});
+  return CreateFunctionPass(pass_func, 1, "Legalize", {"InferType"});
 }
 
 TVM_REGISTER_GLOBAL("relay._transform.Legalize").set_body_typed(Legalize);
diff --git a/src/relay/transforms/merge_composite.cc b/src/relay/transforms/merge_composite.cc
index 35b93dc..75d95f0 100644
@@ -159,9 +159,9 @@ class MergeCompositeWrapper : public ExprMutator {
     if (call->op->IsInstance<FunctionNode>()) {
       Function func = Downcast<Function>(call->op);
       CHECK(func.defined());
-      const auto name_node = func->GetAttr<tir::StringImm>(attr::kComposite);
+      auto name_node = func->GetAttr<String>(attr::kComposite);
       // don't step into existing composite functions
-      if (name_node.defined() && name_node->value != "") {
+      if (name_node.defined() && name_node != "") {
         tvm::Array<tvm::relay::Expr> new_args;
         for (const auto& arg : call->args) {
           auto new_e = this->Mutate(arg);
@@ -185,7 +185,7 @@ class MergeCompositeWrapper : public ExprMutator {
       auto free_vars = FreeVars(extract);
       // make the composite function
       auto f = Function(free_vars, extract, call->checked_type_, {}, DictAttrs());
-      f = WithAttr(std::move(f), attr::kComposite, tir::StringImmNode::make(pattern_name_));
+      f = WithAttr(std::move(f), attr::kComposite, runtime::String(pattern_name_));
       // find the expressions associated with the free vars using the args_map
       // this tells us which expressions should be given as inputs to the composite function
       Array<Expr> args;
@@ -207,16 +207,14 @@ class MergeCompositeWrapper : public ExprMutator {
   PackedFunc check_;
 };
 
-Expr MergeComposite(const Expr& expr, const Array<tir::StringImm>& pattern_names,
+Expr MergeComposite(const Expr& expr, const Array<runtime::String>& pattern_names,
                     const Array<Expr>& patterns, const std::vector<PackedFunc>& checks) {
   CHECK_EQ(pattern_names.size(), patterns.size());
   Expr merged_expr = expr;
   // merge the patterns one-by-one in order
   for (size_t i = 0; i < patterns.size(); i++) {
-    std::string pattern_name = pattern_names[i]->value;
-    Expr pattern = patterns[i];
-    PackedFunc check = checks[i];
-    merged_expr = MergeCompositeWrapper(pattern_name, pattern, check).Mutate(merged_expr);
+    merged_expr =
+        MergeCompositeWrapper(pattern_names[i], patterns[i], checks[i]).Mutate(merged_expr);
   }
   return merged_expr;
 }
@@ -225,7 +223,7 @@ Expr MergeComposite(const Expr& expr, const Array<tir::StringImm>& pattern_names
 
 namespace transform {
 
-Pass MergeComposite(const tvm::Array<tir::StringImm>& pattern_names,
+Pass MergeComposite(const tvm::Array<runtime::String>& pattern_names,
                     const tvm::Array<Expr>& patterns, const std::vector<PackedFunc>& checks) {
   runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
       [=](Function f, IRModule m, PassContext pc) {
@@ -236,8 +234,9 @@ Pass MergeComposite(const tvm::Array<tir::StringImm>& pattern_names,
   return func_pass;
 }
 
-TVM_REGISTER_GLOBAL("relay._transform.MergeComposite").set_body([](TVMArgs args, TVMRetValue* rv) {
-  tvm::Array<tir::StringImm> pattern_names = args[0];
+TVM_REGISTER_GLOBAL("relay._transform.MergeComposite")
+.set_body([](TVMArgs args, TVMRetValue* rv) {
+  tvm::Array<runtime::String> pattern_names = args[0];
   tvm::Array<Expr> patterns = args[1];
   std::vector<PackedFunc> checks;
   for (int i = 2; i < args.size(); i++) {
diff --git a/src/relay/transforms/partition_graph.cc b/src/relay/transforms/partition_graph.cc
index a4e3863..8eeac17 100644
@@ -245,7 +245,7 @@ class Partitioner : public ExprMutator {
         global_region_func =
             WithAttr(std::move(global_region_func), attr::kPrimitive, tvm::Integer(1));
         global_region_func = WithAttr(std::move(global_region_func), attr::kCompiler,
-                                      tvm::tir::StringImmNode::make(target));
+                                      tvm::runtime::String(target));
         global_region_func =
             WithAttr(std::move(global_region_func), attr::kInline, tvm::Integer(1));
 
diff --git a/src/relay/transforms/simplify_inference.cc b/src/relay/transforms/simplify_inference.cc
index bc7c15e..d349fdd 100644
@@ -204,8 +204,7 @@ Pass SimplifyInference() {
     [=](Function f, IRModule m, PassContext pc) {
     return Downcast<Function>(SimplifyInference(f));
   };
-  return CreateFunctionPass(pass_func, 0, "SimplifyInference",
-                            {tir::StringImmNode::make("InferType")});
+  return CreateFunctionPass(pass_func, 0, "SimplifyInference", {"InferType"});
 }
 
 TVM_REGISTER_GLOBAL("relay._transform.SimplifyInference")
diff --git a/src/relay/transforms/to_a_normal_form.cc b/src/relay/transforms/to_a_normal_form.cc
index 6e35dfb..21c5162 100644
@@ -299,7 +299,7 @@ IRModule ToANormalForm(const IRModule& m) {
   for (const auto& it : funcs) {
     CHECK_EQ(FreeVars(it.second).size(), 0);
     if (const auto* n = it.second.as<FunctionNode>()) {
-      if (n->GetAttr<tir::StringImm>(attr::kCompiler).defined()) continue;
+      if (n->GetAttr<String>(attr::kCompiler).defined()) continue;
     }
     Expr ret =
       TransformF([&](const Expr& e) {
diff --git a/src/runtime/container.cc b/src/runtime/container.cc
index 400f646..81dfd3d 100644
@@ -32,14 +32,14 @@ namespace runtime {
 
 using namespace vm;
 
-TVM_REGISTER_GLOBAL("runtime.container._GetADTTag")
+TVM_REGISTER_GLOBAL("runtime.GetADTTag")
 .set_body([](TVMArgs args, TVMRetValue* rv) {
   ObjectRef obj = args[0];
   const auto& adt = Downcast<ADT>(obj);
   *rv = static_cast<int64_t>(adt.tag());
 });
 
-TVM_REGISTER_GLOBAL("runtime.container._GetADTSize")
+TVM_REGISTER_GLOBAL("runtime.GetADTSize")
 .set_body([](TVMArgs args, TVMRetValue* rv) {
   ObjectRef obj = args[0];
   const auto& adt = Downcast<ADT>(obj);
@@ -47,7 +47,7 @@ TVM_REGISTER_GLOBAL("runtime.container._GetADTSize")
 });
 
 
-TVM_REGISTER_GLOBAL("runtime.container._GetADTFields")
+TVM_REGISTER_GLOBAL("runtime.GetADTFields")
 .set_body([](TVMArgs args, TVMRetValue* rv) {
   ObjectRef obj = args[0];
   int idx = args[1];
@@ -56,7 +56,7 @@ TVM_REGISTER_GLOBAL("runtime.container._GetADTFields")
   *rv = adt[idx];
 });
 
-TVM_REGISTER_GLOBAL("runtime.container._Tuple")
+TVM_REGISTER_GLOBAL("runtime.Tuple")
 .set_body([](TVMArgs args, TVMRetValue* rv) {
   std::vector<ObjectRef> fields;
   for (auto i = 0; i < args.size(); ++i) {
@@ -65,7 +65,7 @@ TVM_REGISTER_GLOBAL("runtime.container._Tuple")
   *rv = ADT::Tuple(fields);
 });
 
-TVM_REGISTER_GLOBAL("runtime.container._ADT")
+TVM_REGISTER_GLOBAL("runtime.ADT")
 .set_body([](TVMArgs args, TVMRetValue* rv) {
   int itag = args[0];
   size_t tag = static_cast<size_t>(itag);
@@ -76,11 +76,31 @@ TVM_REGISTER_GLOBAL("runtime.container._ADT")
   *rv = ADT(tag, fields);
 });
 
-TVM_REGISTER_GLOBAL("runtime.container._String")
+TVM_REGISTER_GLOBAL("runtime.String")
 .set_body_typed([](std::string str) {
   return String(std::move(str));
 });
 
+TVM_REGISTER_GLOBAL("runtime.GetStringSize")
+.set_body_typed([](String str) {
+  return static_cast<int64_t>(str.size());
+});
+
+TVM_REGISTER_GLOBAL("runtime.GetStdString")
+.set_body_typed([](String str) {
+  return std::string(str);
+});
+
+TVM_REGISTER_GLOBAL("runtime.CompareString")
+.set_body_typed([](String lhs, String rhs) {
+  return lhs.compare(rhs);
+});
+
+TVM_REGISTER_GLOBAL("runtime.StringHash")
+.set_body_typed([](String str) {
+  return static_cast<int64_t>(std::hash<String>()(str));
+});
+
 TVM_REGISTER_OBJECT_TYPE(ADTObj);
 TVM_REGISTER_OBJECT_TYPE(StringObj);
 TVM_REGISTER_OBJECT_TYPE(ClosureObj);
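
The FFI globals move from the "runtime.container._X" prefix to plain
"runtime.X", matching the _ffi_api rewiring on the Python side. From C++ the
renamed entry points stay reachable through the registry; a sketch (null
checks omitted):

    const tvm::runtime::PackedFunc* make = tvm::runtime::Registry::Get("runtime.String");
    tvm::runtime::String s = (*make)("hello");
    const tvm::runtime::PackedFunc* size = tvm::runtime::Registry::Get("runtime.GetStringSize");
    int64_t n = (*size)(s);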
diff --git a/src/target/build_common.h b/src/target/build_common.h
index fc45cef..5ba51da 100644
@@ -57,7 +57,7 @@ ExtractFuncInfo(const IRModule& mod) {
         info.thread_axis_tags.push_back(thread_axis[i]->thread_tag);
       }
     }
-    auto global_symbol = f->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
+    auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
     fmap[static_cast<std::string>(global_symbol)] = info;
   }
   return fmap;
diff --git a/src/target/generic_func.cc b/src/target/generic_func.cc
index 8eef4b7..44d017f 100644
@@ -22,6 +22,7 @@
 #include <dmlc/thread_local.h>
 
 #include <tvm/runtime/registry.h>
+#include <tvm/runtime/container.h>
 #include <tvm/node/node.h>
 #include <tvm/node/repr_printer.h>
 #include <tvm/target/target.h>
@@ -150,12 +151,12 @@ TVM_REGISTER_GLOBAL("target.GenericFuncRegisterFunc")
   GenericFunc generic_func = args[0];
   // Intentionally copy and never deallocate it, to avoid freeing the PyObject during shutdown
   PackedFunc* func = new PackedFunc(args[1].operator PackedFunc());
-  Array<PrimExpr> tags = args[2];
+  Array<runtime::String> tags = args[2];
   bool allow_override = args[3];
 
   std::vector<std::string> tags_vector;
   for (auto& tag : tags) {
-    tags_vector.push_back(tag.as<tvm::tir::StringImmNode>()->value);
+    tags_vector.push_back(tag);
   }
 
   generic_func
diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc
index f0b0a4b..a863056 100644
@@ -126,7 +126,7 @@ void CodeGenCPU::Init(const std::string& module_name,
 void CodeGenCPU::AddFunction(const PrimFunc& f) {
   CodeGenLLVM::AddFunction(f);
   if (f_tvm_register_system_symbol_ != nullptr) {
-    auto global_symbol = f->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
+    auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
     CHECK(global_symbol.defined())
         << "CodeGenLLVM: Expect PrimFunc to have the global_symbol attribute";
     export_system_symbols_.emplace_back(
diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc
index 28f4efd..bb0b7e4 100644
@@ -128,7 +128,7 @@ void CodeGenLLVM::AddFunctionInternal(const PrimFunc& f, bool ret_void) {
   llvm::FunctionType* ftype = llvm::FunctionType::get(
       ret_void ? t_void_ : t_int_, param_types, false);
 
-  auto global_symbol = f->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
+  auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
   CHECK(global_symbol.defined())
       << "CodeGenLLVM: Expect PrimFunc to have the global_symbol attribute";
   CHECK(module_->getFunction(static_cast<std::string>(global_symbol)) == nullptr)
diff --git a/src/target/llvm/llvm_module.cc b/src/target/llvm/llvm_module.cc
index 9ea77ac..52dccba 100644
@@ -214,7 +214,7 @@ class LLVMModuleNode final : public runtime::ModuleNode {
           << "Can only lower IR Module with PrimFuncs";
       auto f = Downcast<PrimFunc>(kv.second);
       if (f->HasNonzeroAttr(tir::attr::kIsEntryFunc)) {
-        auto global_symbol = f->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
+        auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
         CHECK(global_symbol.defined());
         entry_func = global_symbol;
       }
diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc
index 0cb4742..a0e18a6 100644
@@ -78,7 +78,7 @@ void CodeGenC::AddFunction(const PrimFunc& f) {
   // reserve keywords
   ReserveKeywordsAsUnique();
 
-  auto global_symbol = f->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
+  auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
   CHECK(global_symbol.defined())
       << "CodeGenC: Expect PrimFunc to have the global_symbol attribute";
   bool no_alias = f->HasNonzeroAttr(tir::attr::kNoAlias);
diff --git a/src/target/source/codegen_metal.cc b/src/target/source/codegen_metal.cc
index 2f31a3e..715c0ae 100644
@@ -56,7 +56,7 @@ void CodeGenMetal::AddFunction(const PrimFunc& f) {
   GetUniqueName("_");
 
   // add to alloc buffer type.
-  auto global_symbol = f->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
+  auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
   CHECK(global_symbol.defined())
       << "CodeGenC: Expect PrimFunc to have the global_symbol attribute";
 
diff --git a/src/target/source/codegen_opengl.cc b/src/target/source/codegen_opengl.cc
index 4748599..13d87d2 100644
@@ -156,7 +156,7 @@ void CodeGenOpenGL::AddFunction(const PrimFunc& f) {
     arg_kinds.push_back(kind);
   }
 
-  auto global_symbol = f->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
+  auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
   CHECK(global_symbol.defined())
       << "CodeGenOpenGL: Expect PrimFunc to have the global_symbol attribute";
 
diff --git a/src/target/source/codegen_vhls.cc b/src/target/source/codegen_vhls.cc
index 6c1c3b9..7486164 100644
@@ -147,7 +147,7 @@ runtime::Module BuildSDAccel(IRModule mod, std::string target_str) {
   std::string whole_code = cg.Finish();
 
   // Generate source code for compilation.
-  Array<Array<PrimExpr> > kernel_info;
+  Array<Array<runtime::String> > kernel_info;
 
   for (auto kv :  mod->functions) {
     CHECK(kv.second->IsInstance<PrimFuncNode>())
@@ -161,11 +161,10 @@ runtime::Module BuildSDAccel(IRModule mod, std::string target_str) {
       code = (*f)(code).operator std::string();
     }
 
-    auto global_symbol = f->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
+    auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
     CHECK(global_symbol.defined())
         << "CodeGenC: Expect PrimFunc to have the global_symbol attribute";
-    std::string func_name = global_symbol;
-    kernel_info.push_back(Array<PrimExpr>({func_name, code}));
+    kernel_info.push_back({global_symbol, code});
   }
 
   std::string xclbin;
diff --git a/src/target/spirv/build_vulkan.cc b/src/target/spirv/build_vulkan.cc
index b6f9b86..5872141 100644
@@ -90,7 +90,7 @@ runtime::Module BuildSPIRV(IRModule mod) {
     CHECK(calling_conv.defined() &&
           calling_conv->value == static_cast<int>(CallingConv::kDeviceKernelLaunch))
         << "CodeGenSPIRV: expect calling_conv equals CallingConv::kDeviceKernelLaunch";
-    auto global_symbol = f->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
+    auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
     CHECK(global_symbol.defined())
         << "CodeGenSPIRV: Expect PrimFunc to have the global_symbol attribute";
 
diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc
index 0241e22..db2a2f3 100644
@@ -78,7 +78,7 @@ std::vector<uint32_t> CodeGenSPIRV::BuildFunction(const PrimFunc& f) {
   builder_->MakeInst(spv::OpReturn);
   builder_->MakeInst(spv::OpFunctionEnd);
 
-  auto global_symbol = f->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
+  auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
   CHECK(global_symbol.defined())
       << "CodeGenSPIRV: Expect PrimFunc to have the global_symbol attribute";
 
diff --git a/src/target/stackvm/codegen_stackvm.cc b/src/target/stackvm/codegen_stackvm.cc
index af8b341..da75a70 100644
@@ -536,7 +536,7 @@ runtime::Module BuildStackVM(const IRModule& mod) {
     CHECK(kv.second->IsInstance<PrimFuncNode>())
         << "CodeGenStackVM: Can only take PrimFunc";
     auto f = Downcast<PrimFunc>(kv.second);
-    auto global_symbol = f->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
+    auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
     CHECK(global_symbol.defined())
         << "CodeGenStackVM: Expect PrimFunc to have the global_symbol attribute";
     std::string f_name = global_symbol;
diff --git a/src/target/target.cc b/src/target/target.cc
index 8fb9cb6..306fba4 100644
@@ -62,39 +62,39 @@ Target CreateTarget(const std::string& target_name,
   std::string device_flag = "-device=";
   std::string keys_flag = "-keys=";
   for (auto& item : options) {
-    t->options_array.push_back(tir::StringImmNode::make(item));
+    t->options_array.push_back(item);
 
     if (item.find(libs_flag) == 0) {
       std::stringstream ss(item.substr(libs_flag.length()));
       std::string lib_item;
       while (std::getline(ss, lib_item, ',')) {
-        t->libs_array.push_back(tir::StringImmNode::make(lib_item));
+        t->libs_array.push_back(lib_item);
       }
     } else if (item.find(device_flag) == 0) {
       t->device_name = item.substr(device_flag.length());
-      t->keys_array.push_back(tir::StringImmNode::make(t->device_name));
+      t->keys_array.push_back(t->device_name);
     } else if (item.find(keys_flag) == 0) {
       std::stringstream ss(item.substr(keys_flag.length()));
       std::string key_item;
       while (std::getline(ss, key_item, ',')) {
-        t->keys_array.push_back(tir::StringImmNode::make(key_item));
+        t->keys_array.push_back(key_item);
       }
     }
   }
 
   if (t->device_name.length() > 0) {
-    t->keys_array.push_back(tir::StringImmNode::make(t->device_name));
+    t->keys_array.push_back(t->device_name);
   }
   t->device_type = kDLCPU;
   t->thread_warp_size = 1;
   if (target_name == "c" && t->device_name == "micro_dev") {
     t->device_type = kDLMicroDev;
   } else if (target_name == "c" || target_name == "llvm") {
-    t->keys_array.push_back(tir::StringImmNode::make("cpu"));
+    t->keys_array.push_back("cpu");
   } else if (target_name == "cuda" || target_name == "nvptx") {
     t->device_type = kDLGPU;
-    t->keys_array.push_back(tir::StringImmNode::make("cuda"));
-    t->keys_array.push_back(tir::StringImmNode::make("gpu"));
+    t->keys_array.push_back("cuda");
+    t->keys_array.push_back("gpu");
     t->max_num_threads = 1024;
     t->thread_warp_size = 32;
   } else if (target_name == "rocm" || target_name == "opencl") {
@@ -104,8 +104,8 @@ Target CreateTarget(const std::string& target_name,
     } else {
       t->device_type = kDLROCM;
     }
-    t->keys_array.push_back(tir::StringImmNode::make(target_name));
-    t->keys_array.push_back(tir::StringImmNode::make("gpu"));
+    t->keys_array.push_back(target_name);
+    t->keys_array.push_back("gpu");
     t->max_num_threads = 256;
     if (t->device_name == "intel_graphics") {
       t->thread_warp_size = 16;
@@ -116,20 +116,20 @@ Target CreateTarget(const std::string& target_name,
     } else {
       t->device_type = kDLVulkan;
     }
-    t->keys_array.push_back(tir::StringImmNode::make(target_name));
-    t->keys_array.push_back(tir::StringImmNode::make("gpu"));
+    t->keys_array.push_back(target_name);
+    t->keys_array.push_back("gpu");
     t->max_num_threads = 256;
   } else if (target_name == "sdaccel") {
     t->device_type = kDLOpenCL;
-    t->keys_array.push_back(tir::StringImmNode::make("sdaccel"));
-    t->keys_array.push_back(tir::StringImmNode::make("hls"));
+    t->keys_array.push_back("sdaccel");
+    t->keys_array.push_back("hls");
   } else if (target_name == "aocl" || target_name == "aocl_sw_emu") {
     t->device_type = kDLAOCL;
-    t->keys_array.push_back(tir::StringImmNode::make("aocl"));
-    t->keys_array.push_back(tir::StringImmNode::make("hls"));
+    t->keys_array.push_back("aocl");
+    t->keys_array.push_back("hls");
   } else if (target_name == "opengl") {
     t->device_type = kOpenGL;
-    t->keys_array.push_back(tir::StringImmNode::make("opengl"));
+    t->keys_array.push_back("opengl");
   } else if (target_name == "stackvm") {
     t->device_type = kDLCPU;
   } else if (target_name == "ext_dev") {
@@ -168,7 +168,7 @@ TVM_REGISTER_GLOBAL("target.TargetFromString")
 std::vector<std::string> TargetNode::keys() const {
   std::vector<std::string> result;
   for (auto& expr : keys_array) {
-    result.push_back(expr.as<tir::StringImmNode>()->value);
+    result.push_back(expr);
   }
   return result;
 }
@@ -176,7 +176,7 @@ std::vector<std::string> TargetNode::keys() const {
 std::vector<std::string> TargetNode::options() const {
   std::vector<std::string> result;
   for (auto& expr : options_array) {
-    result.push_back(expr.as<tir::StringImmNode>()->value);
+    result.push_back(expr);
   }
   return result;
 }
@@ -184,7 +184,7 @@ std::vector<std::string> TargetNode::options() const {
 std::unordered_set<std::string> TargetNode::libs() const {
   std::unordered_set<std::string> result;
   for (auto& expr : libs_array) {
-    result.insert(expr.as<tir::StringImmNode>()->value);
+    result.insert(expr);
   }
   return result;
 }
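
With keys_array, options_array, and libs_array now Array<runtime::String>, both directions of the old StringImmNode round trip disappear: push_back accepts const char* and std::string through the new constructors, and reading back is an implicit conversion. Condensed, with t being the TargetNode under construction above:

    t->keys_array.push_back("cpu");           // const char* constructor
    t->keys_array.push_back(t->device_name);  // std::string constructor
    for (String key : t->keys_array) {
      std::string s = key;                    // no StringImmNode unwrap
    }
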
index 891d137..0efa33a 100644 (file)
@@ -47,7 +47,6 @@ Var::Var(std::string name_hint, Type type_annotation) {
   data_ = std::move(n);
 }
 
-
 Var Var::copy_with_suffix(const std::string& suffix) const {
   const VarNode* node = get();
   ObjectPtr<VarNode> new_ptr;
@@ -826,20 +825,28 @@ TVM_REGISTER_GLOBAL("tir.Load")
     }
   });
 
-
-
 TVM_REGISTER_GLOBAL("tir.Call")
 .set_body_typed([](
   DataType type, std::string name,
-  Array<PrimExpr> args, int call_type,
+  Array<ObjectRef> args, int call_type,
   FunctionRef func, int value_index
 ) {
+  Array<PrimExpr> prim_expr_args;
+  for (const auto& it : args) {
+    CHECK(it->IsInstance<runtime::StringObj>() ||
+          it->IsInstance<PrimExprNode>());
+    if (const auto* str = it.as<runtime::StringObj>()) {
+      prim_expr_args.push_back(StringImmNode::make(str->data));
+    } else {
+      prim_expr_args.push_back(Downcast<PrimExpr>(it));
+    }
+  }
   return CallNode::make(type,
-                    name,
-                    args,
-                    static_cast<CallNode::CallType>(call_type),
-                    func,
-                    value_index);
+                        name,
+                        prim_expr_args,
+                        static_cast<CallNode::CallType>(call_type),
+                        func,
+                        value_index);
 });
 
 }  // namespace tir
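
The tir.Call registration now accepts Array<ObjectRef>, so FFI callers can mix strings with expressions; string objects are promoted to StringImm before reaching CallNode::make. A hypothetical caller-side argument array:

    Array<ObjectRef> args;
    args.push_back(runtime::String("some.packed.func"));  // promoted to StringImm
    args.push_back(PrimExpr(0));                          // passed through unchanged
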
index ea19982..96fc435 100644 (file)
@@ -120,10 +120,10 @@ class IRTransformer final :
 Stmt IRTransform(Stmt ir_node,
                  const runtime::PackedFunc& f_preorder,
                  const runtime::PackedFunc& f_postorder,
-                 const Array<PrimExpr>& only_enable) {
+                 const Array<runtime::String>& only_enable) {
   std::unordered_set<uint32_t> only_type_index;
-  for (PrimExpr s : only_enable) {
-    only_type_index.insert(Object::TypeKey2Index(s.as<StringImmNode>()->value.c_str()));
+  for (auto s : only_enable) {
+    only_type_index.insert(Object::TypeKey2Index(s.c_str()));
   }
   IRTransformer transform(f_preorder, f_postorder, only_type_index);
   return transform(std::move(ir_node));
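
Since only_enable is now an Array of runtime::String, the type keys are read with c_str() directly instead of unwrapping a StringImmNode, and call sites (see the hoist_if_then_else.cc hunks below) shrink to plain literals, roughly:

    Stmt out = IRTransform(stmt, nullptr, replace_target_for, {"For"});
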
index 773c67d..001c7cf 100644 (file)
@@ -124,7 +124,7 @@ Pass CreatePrimFuncPass(
     const runtime::TypedPackedFunc<PrimFunc(PrimFunc, IRModule, PassContext)>& pass_func,
     int opt_level,
     const std::string& name,
-    const tvm::Array<tvm::PrimExpr>& required) {
+    const tvm::Array<runtime::String>& required) {
   PassInfo pass_info = PassInfo(opt_level, name, required);
   return PrimFuncPass(pass_func, pass_info);
 }
index 30542ea..c684b9e 100644 (file)
@@ -42,7 +42,8 @@ void BinderAddAssert(PrimExpr cond,
   if (!is_one(scond)) {
     std::ostringstream os;
     os << "Argument " << arg_name << " has an unsatisfied constraint";
-    asserts->emplace_back(AssertStmtNode::make(scond, os.str(), EvaluateNode::make(0)));
+    asserts->emplace_back(AssertStmtNode::make(scond, tvm::tir::StringImmNode::make(os.str()),
+                                               EvaluateNode::make(0)));
   }
 }
 
@@ -173,7 +174,8 @@ void ArgBinder::BindDLTensor(const Buffer& buffer,
   ndim_err_msg << arg_name
                << ".ndim is expected to equal "
                << buffer->shape.size();
-  asserts_.emplace_back(AssertStmtNode::make(a_ndim == v_ndim, ndim_err_msg.str(), nop));
+  auto msg = tvm::tir::StringImmNode::make(ndim_err_msg.str());
+  asserts_.emplace_back(AssertStmtNode::make(a_ndim == v_ndim, msg, nop));
   // type checks
   DataType dtype = buffer->dtype;
   std::ostringstream type_err_msg;
@@ -187,7 +189,8 @@ void ArgBinder::BindDLTensor(const Buffer& buffer,
   if (!(dtype == DataType::Int(4) ||
         dtype == DataType::UInt(4) ||
         dtype == DataType::Int(1))) {
-    asserts_.emplace_back(AssertStmtNode::make(cond, type_err_msg.str(), nop));
+    auto type_msg = tvm::tir::StringImmNode::make(type_err_msg.str());
+    asserts_.emplace_back(AssertStmtNode::make(cond, type_msg, nop));
   }
   // data field
   if (Bind_(buffer->data, TVMArrayGet(DataType::Handle(), handle, intrinsic::kArrData),
@@ -245,9 +249,10 @@ void ArgBinder::BindDLTensor(const Buffer& buffer,
     stride_err_msg << arg_name << ".strides:"
                    << " expected to be compact array";
     if (conds.size() != 0) {
+      auto stride_msg = tvm::tir::StringImmNode::make(stride_err_msg.str());
       Stmt check =
           AssertStmtNode::make(arith::ComputeReduce<tir::AndNode>(conds, PrimExpr()),
-                           stride_err_msg.str(), EvaluateNode::make(0));
+                           stride_msg, EvaluateNode::make(0));
       check = IfThenElseNode::make(NotNode::make(is_null), check, Stmt());
       asserts_.emplace_back(SeqStmt({check, EvaluateNode::make(0)}));
     }
@@ -269,9 +274,8 @@ void ArgBinder::BindDLTensor(const Buffer& buffer,
   } else {
     std::ostringstream stride_null_err_msg;
     stride_null_err_msg << arg_name << ".strides: expected non-null strides.";
-    asserts_.emplace_back(
-        AssertStmtNode::make(
-            NotNode::make(is_null), stride_null_err_msg.str(), nop));
+    asserts_.emplace_back(AssertStmtNode::make(
+        NotNode::make(is_null), tvm::tir::StringImmNode::make(stride_null_err_msg.str()), nop));
 
     for (size_t k = 0; k < buffer->strides.size(); ++k) {
       std::ostringstream field_name;
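
The explicit StringImmNode::make wrapping in this file (and in the tir transforms below) follows from PrimExpr losing its implicit std::string constructor in this commit, presumably replaced by a runtime::String one per the include/tvm/ir/expr.h hunk. The recurring pattern, condensed:

    std::ostringstream os;
    os << "Argument " << arg_name << " has an unsatisfied constraint";
    PrimExpr msg = tvm::tir::StringImmNode::make(os.str());
    asserts->emplace_back(AssertStmtNode::make(cond, msg, EvaluateNode::make(0)));
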
index 1fd43ff..8bc4620 100644 (file)
@@ -159,8 +159,7 @@ Stmt update_for(const Stmt& parent_for_stmt, const Stmt& new_if_stmt) {
       }
     });
 
-  return IRTransform(parent_for_stmt, nullptr, replace_target_for,
-                     {PrimExpr("For")});
+  return IRTransform(parent_for_stmt, nullptr, replace_target_for, {"For"});
 }
 
 // Remove IfThenElse node from a For node.
@@ -186,11 +185,9 @@ std::pair<Stmt, Stmt> RemoveIf(const Stmt& for_stmt, const Stmt& if_stmt) {
       }
     });
 
-  then_for = IRTransform(for_stmt, nullptr, replace_then_case,
-                         {PrimExpr("IfThenElse")});
+  then_for = IRTransform(for_stmt, nullptr, replace_then_case, {"IfThenElse"});
   if (if_stmt.as<IfThenElseNode>()->else_case.defined()) {
-    else_for = IRTransform(for_stmt, nullptr, replace_else_case,
-                           {PrimExpr("IfThenElse")});
+    else_for = IRTransform(for_stmt, nullptr, replace_else_case, {"IfThenElse"});
   }
 
   return std::make_pair(then_for, else_for);
@@ -411,7 +408,7 @@ Stmt IfThenElseHoist::PostOrderMutate(const Stmt& stmt) {
         *ret = new_for;
       }
     });
-  return IRTransform(stmt, nullptr, replace_top_for, {PrimExpr("For")});
+  return IRTransform(stmt, nullptr, replace_top_for, {"For"});
 }
 
 Stmt HoistIfThenElse(Stmt stmt) {
index 88f7496..dc2df98 100644 (file)
@@ -860,7 +860,7 @@ class TensorCoreIRMutator : public StmtExprMutator {
         auto it = matrix_abc_.find(simplify_name(node->name));
         CHECK(it != matrix_abc_.end())
               << "Cannot find matrix info for " << node->name;
-        auto matrix_abc = "wmma." + it->second;
+        auto matrix_abc = tvm::tir::StringImmNode::make("wmma." + it->second);
         Stmt body = this->VisitStmt(op->body);
         return AttrStmtNode::make(op->node,
                               op->attr_key,
index 486f21c..952d663 100644 (file)
@@ -47,7 +47,8 @@ class DeviceTypeBinder: public StmtExprMutator {
         var_ = nullptr;
         std::ostringstream os;
         os << "device_type need to be " << device_type_;
-        return AssertStmtNode::make(op->value == value, os.str(), body);
+        return AssertStmtNode::make(op->value == value, tvm::tir::StringImmNode::make(os.str()),
+                                    body);
       }
     }
     return StmtExprMutator::VisitStmt_(op);
index c49b044..b1dd235 100644 (file)
@@ -41,12 +41,13 @@ namespace tvm {
 namespace tir {
 
 inline Stmt MakeAssertEQ(PrimExpr lhs, PrimExpr rhs, std::string msg) {
-  return AssertStmtNode::make(lhs == rhs, msg, EvaluateNode::make(0));
+  return AssertStmtNode::make(lhs == rhs, tvm::tir::StringImmNode::make(msg),
+                              EvaluateNode::make(0));
 }
 
 PrimFunc MakePackedAPI(PrimFunc&& func,
                        int num_unpacked_args) {
-  auto global_symbol = func->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
+  auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
   CHECK(global_symbol.defined())
       << "MakePackedAPI: Expect PrimFunc to have the global_symbol attribute";
   std::string name_hint = global_symbol;
@@ -140,17 +141,19 @@ PrimFunc MakePackedAPI(PrimFunc&& func,
             AssertStmtNode::make(tcode == kTVMOpaqueHandle ||
                                  tcode == kTVMNDArrayHandle ||
                                  tcode == kTVMDLTensorHandle ||
-                                 tcode == kTVMNullptr, msg.str(), nop));
+                                 tcode == kTVMNullptr,
+                                 tvm::tir::StringImmNode::make(msg.str()), nop));
       } else if (t.is_int() || t.is_uint()) {
         std::ostringstream msg;
         msg << name_hint << ": Expect arg[" << i << "] to be int";
-        seq_check.emplace_back(AssertStmtNode::make(tcode == kDLInt, msg.str(), nop));
+        seq_check.emplace_back(
+            AssertStmtNode::make(tcode == kDLInt, tvm::tir::StringImmNode::make(msg.str()), nop));
       } else {
         CHECK(t.is_float());
         std::ostringstream msg;
         msg << name_hint << ": Expect arg[" << i << "] to be float";
         seq_check.emplace_back(
-            AssertStmtNode::make(tcode == kDLFloat, msg.str(), nop));
+            AssertStmtNode::make(tcode == kDLFloat, tvm::tir::StringImmNode::make(msg.str()), nop));
       }
     } else {
       args.push_back(v_arg);
index f695b3c..f366353 100644 (file)
@@ -76,12 +76,10 @@ class ThreadAxisRewriter : private StmtExprMutator {
 };
 
 
-PrimFunc RemapThreadAxis(PrimFunc&& f, Map<PrimExpr, IterVar> thread_map) {
+PrimFunc RemapThreadAxis(PrimFunc&& f, Map<runtime::String, IterVar> thread_map) {
   std::unordered_map<std::string, IterVar> tmap;
   for (const auto& kv : thread_map) {
-    const StringImmNode* str = kv.first.as<StringImmNode>();
-    CHECK(str != nullptr);
-    tmap[str->value] = kv.second;
+    tmap[kv.first] = kv.second;
   }
 
   auto thread_axis = f->GetAttr<Array<IterVar> >(tir::attr::kDeviceThreadAxis);
@@ -101,7 +99,7 @@ PrimFunc RemapThreadAxis(PrimFunc&& f, Map<PrimExpr, IterVar> thread_map) {
 
 namespace transform {
 
-Pass RemapThreadAxis(Map<PrimExpr, IterVar> thread_map) {
+Pass RemapThreadAxis(Map<runtime::String, IterVar> thread_map) {
   auto pass_func = [thread_map](PrimFunc f, IRModule m, PassContext ctx) {
     return RemapThreadAxis(std::move(f), thread_map);
   };
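
runtime::String hashes and compares like a string, so it can key a Map directly; copying into a std::string-keyed unordered_map is again an implicit conversion, as the hunk above shows:

    std::unordered_map<std::string, IterVar> tmap;
    for (const auto& kv : thread_map) {
      tmap[kv.first] = kv.second;  // String key converts to std::string
    }
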
index ae32bdc..5149d28 100644 (file)
@@ -272,7 +272,7 @@ PrimFunc SplitHostDevice(PrimFunc&& func, IRModuleNode* device_mod) {
   auto target = func->GetAttr<Target>(tvm::attr::kTarget);
   CHECK(target.defined())
       << "SplitHostDevice: Require the target attribute";
-  auto global_symbol = func->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
+  auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
   CHECK(global_symbol.defined())
       << "SplitHostDevice: Expect PrimFunc to have the global_symbol attribute";
 
index f1198e7..063247d 100644 (file)
@@ -261,7 +261,7 @@ TEST(String, empty) {
   using namespace std;
   String s{"hello"};
   CHECK_EQ(s.empty(), false);
-  s = "";
+  s = std::string("");
   CHECK_EQ(s.empty(), true);
 }
 
index 7301ef7..dd00d7e 100644 (file)
@@ -231,7 +231,7 @@ def test_composite_function():
         add_node = relay.add(in_1, in_2)
         relu_node = relay.nn.relu(add_node)
         add_relu = relay.Function([in_1, in_2], relu_node)
-        add_relu = add_relu.with_attr("Composite", tvm.tir.StringImm("test.add_relu"))
+        add_relu = add_relu.with_attr("Composite", "test.add_relu")
 
         # merged function
         r = relay.Call(add_relu, [a, b])
@@ -249,7 +249,7 @@ def test_composite_function():
         add_node = relay.add(in_1, in_2)
         relu_node = relay.nn.relu(add_node)
         add_relu = relay.Function([in_1, in_2], relu_node)
-        add_relu = add_relu.with_attr("Composite", tvm.tir.StringImm("test.add_relu"))
+        add_relu = add_relu.with_attr("Composite", "test.add_relu")
 
         # merged function
         cb_1 = relay.annotation.compiler_begin(a, "test")
index 0af55d2..bae077c 100644 (file)
@@ -134,7 +134,7 @@ def test_recursive_func():
     func = relay.Function([i],
                           sb.get(),
                           ret_type=relay.TensorType([], 'int32'))
-    func = func.with_attr("Compiler", tvm.tir.StringImm("a"))
+    func = func.with_attr("Compiler", "a")
     mod[sum_up] = func
     iarg = relay.var('i', shape=[], dtype='int32')
     mod["main"] = relay.Function([iarg], sum_up(iarg))
index 724e81d..b4496bb 100644 (file)
@@ -79,9 +79,8 @@ def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm",
 
 def set_external_func_attr(func, compiler, ext_symbol):
     func = func.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
-    func = func.with_attr("Compiler", tvm.tir.StringImm(compiler))
-    func = func.with_attr("global_symbol",
-                          runtime.container.String(ext_symbol))
+    func = func.with_attr("Compiler", compiler)
+    func = func.with_attr("global_symbol", ext_symbol)
     return func
 
 
index dbd5934..5a71023 100644 (file)
@@ -96,12 +96,14 @@ def test_function():
     body = relay.Tuple(tvm.runtime.convert([]))
     type_params = tvm.runtime.convert([])
     fn = relay.Function(params, body, ret_type, type_params)
-    fn = fn.with_attr("test_attribute", tvm.tir.StringImm("value"))
+    fn = fn.with_attr("test_attribute", "value")
+    fn = fn.with_attr("test_attribute1", "value1")
     assert fn.params == params
     assert fn.body == body
     assert fn.type_params == type_params
     assert fn.span == None
     assert fn.attrs["test_attribute"] == "value"
+    assert fn.attrs["test_attribute1"] == "value1"
     str(fn)
     check_json_roundtrip(fn)
 
index 271960e..e1a0a01 100644 (file)
@@ -356,7 +356,7 @@ def test_function_attr():
     p00 = relay.subtract(z00, w01)
     q00 = relay.multiply(p00, w02)
     func0 = relay.Function([x0, w00, w01, w02], q00)
-    func0 = func0.with_attr("FuncName", tvm.runtime.container.String("a"))
+    func0 = func0.with_attr("FuncName", "a")
 
     x1 = relay.var('x1', shape=(10, 10))
     w10 = relay.var('w10', shape=(10, 10))
@@ -366,7 +366,7 @@ def test_function_attr():
     p10 = relay.subtract(z10, w11)
     q10 = relay.multiply(p10, w12)
     func1 = relay.Function([x1, w10, w11, w12], q10)
-    func1 = func1.with_attr("FuncName", tvm.runtime.container.String("b"))
+    func1 = func1.with_attr("FuncName", "b")
     assert not consistent_equal(func0, func1)
 
 
@@ -698,7 +698,7 @@ def test_fn_attribute():
     d = relay.var('d', shape=(10, 10))
     add_1 = relay.add(c, d)
     add_1_fn = relay.Function([c, d], add_1)
-    add_1_fn = add_1_fn.with_attr("TestAttribute", tvm.runtime.container.String("test"))
+    add_1_fn = add_1_fn.with_attr("TestAttribute", "test")
     add_1_fn = run_opt_pass(add_1_fn, relay.transform.InferType())
 
     assert not consistent_equal(add_1_fn, add_fn)
index 0f6d539..3b41f07 100644 (file)
@@ -209,7 +209,7 @@ def test_call_chain_inline_multiple_levels_extern_compiler():
         g11 = relay.GlobalVar("g11")
         fn11 = relay.Function([x11], x11)
         fn11 = fn11.with_attr("Inline", tvm.tir.IntImm("int32", 1))
-        fn11 = fn11.with_attr("Compiler", tvm.tir.StringImm("a"))
+        fn11 = fn11.with_attr("Compiler", "a")
         mod[g11] = fn11
 
         x1 = relay.var("x1", shape=(3, 5))
@@ -244,7 +244,7 @@ def test_call_chain_inline_multiple_levels_extern_compiler():
         x11 = relay.var("x11", shape=(3, 5))
         fn11 = relay.Function([x11], x11)
         fn11 = fn11.with_attr("Inline", tvm.tir.IntImm("int32", 1))
-        fn11 = fn11.with_attr("Compiler", tvm.tir.StringImm("a"))
+        fn11 = fn11.with_attr("Compiler", "a")
 
         x2 = relay.var("x2", shape=(3, 5))
         y2 = relay.var("y2", shape=(3, 5))
@@ -367,7 +367,7 @@ def test_recursive_not_called_extern_compiler():
         x1 = relay.var("x1", shape=(2, 2))
         fn1 = relay.Function([x1], x1)
         fn1 = fn1.with_attr("Inline", tvm.tir.IntImm("int32", 1))
-        fn1 = fn1.with_attr("Compiler", tvm.tir.StringImm("a"))
+        fn1 = fn1.with_attr("Compiler", "a")
         g1 = relay.GlobalVar("g1")
         mod[g1] = fn1
         mod["main"] = relay.Function([x, y], x + y + g1(x))
@@ -380,7 +380,7 @@ def test_recursive_not_called_extern_compiler():
         x1 = relay.var("x1", shape=(2, 2))
         fn1 = relay.Function([x1], x1)
         fn1 = fn1.with_attr("Inline", tvm.tir.IntImm("int32", 1))
-        fn1 = fn1.with_attr("Compiler", tvm.tir.StringImm("a"))
+        fn1 = fn1.with_attr("Compiler", "a")
         mod["main"] = relay.Function([x, y], x + y + fn1(x))
         return mod
 
@@ -446,7 +446,7 @@ def test_globalvar_as_call_arg_extern_compiler():
         sb.ret(x1 + y1)
         fn1 = relay.Function([x1, y1], sb.get())
         fn1 = fn1.with_attr("Inline", tvm.tir.IntImm("int32", 1))
-        fn1 = fn1.with_attr("Compiler", tvm.tir.StringImm("a"))
+        fn1 = fn1.with_attr("Compiler", "a")
         g1 = relay.GlobalVar("g1")
         mod[g1] = fn1
 
@@ -456,7 +456,7 @@ def test_globalvar_as_call_arg_extern_compiler():
         sb1.ret(x2 - y2)
         fn2 = relay.Function([x2, y2], sb1.get())
         fn2 = fn2.with_attr("Inline", tvm.tir.IntImm("int32", 1))
-        fn2 = fn2.with_attr("Compiler", tvm.tir.StringImm("b"))
+        fn2 = fn2.with_attr("Compiler", "b")
         g2 = relay.GlobalVar("g2")
         mod[g2] = fn2
 
@@ -478,7 +478,7 @@ def test_globalvar_as_call_arg_extern_compiler():
         sb.ret(x1 + y1)
         fn1 = relay.Function([x1, y1], sb.get())
         fn1 = fn1.with_attr("Inline", tvm.tir.IntImm("int32", 1))
-        fn1 = fn1.with_attr("Compiler", tvm.tir.StringImm("a"))
+        fn1 = fn1.with_attr("Compiler", "a")
 
         x2 = relay.var("x2", shape=(3, 5))
         y2 = relay.var("y2", shape=(3, 5))
@@ -486,7 +486,7 @@ def test_globalvar_as_call_arg_extern_compiler():
         sb1.ret(x2 - y2)
         fn2 = relay.Function([x2, y2], sb1.get())
         fn2 = fn2.with_attr("Inline", tvm.tir.IntImm("int32", 1))
-        fn2 = fn2.with_attr("Compiler", tvm.tir.StringImm("b"))
+        fn2 = fn2.with_attr("Compiler", "b")
 
         p0 = relay.var("p0", shape=(3, 5))
         p1 = relay.var("p1", shape=(3, 5))
@@ -539,10 +539,10 @@ def test_inline_globalvar_without_args_extern_compiler():
         mod = tvm.IRModule({})
         fn1 = relay.Function([], relay.const(1))
         fn1 = fn1.with_attr("Inline", tvm.tir.IntImm("int32", 1))
-        fn1 = fn1.with_attr("Compiler", tvm.tir.StringImm("a"))
+        fn1 = fn1.with_attr("Compiler", "a")
         fn2 = relay.Function([], relay.const(2))
         fn2 = fn2.with_attr("Inline", tvm.tir.IntImm("int32", 1))
-        fn2 = fn2.with_attr("Compiler", tvm.tir.StringImm("b"))
+        fn2 = fn2.with_attr("Compiler", "b")
         g1 = relay.GlobalVar('g1')
         g2 = relay.GlobalVar('g2')
         mod[g1] = fn1
@@ -555,10 +555,10 @@ def test_inline_globalvar_without_args_extern_compiler():
         mod = tvm.IRModule({})
         fn1 = relay.Function([], relay.const(1))
         fn1 = fn1.with_attr("Inline", tvm.tir.IntImm("int32", 1))
-        fn1 = fn1.with_attr("Compiler", tvm.tir.StringImm("a"))
+        fn1 = fn1.with_attr("Compiler", "a")
         fn2 = relay.Function([], relay.const(2))
         fn2 = fn2.with_attr("Inline", tvm.tir.IntImm("int32", 1))
-        fn2 = fn2.with_attr("Compiler", tvm.tir.StringImm("b"))
+        fn2 = fn2.with_attr("Compiler", "b")
         p = relay.var('p', 'bool')
         mod['main'] = relay.Function([p], relay.Call(
             relay.If(p, fn1, fn2), []))
@@ -787,7 +787,7 @@ def test_callee_not_inline_leaf_inline_extern_compiler():
         y0 = relay.var("y0", shape=(3, 5))
         fn0 = relay.Function([x0, y0], x0 * y0)
         fn0 = fn0.with_attr("Inline", tvm.tir.IntImm("int32", 1))
-        fn0 = fn0.with_attr("Compiler", tvm.tir.StringImm("aa"))
+        fn0 = fn0.with_attr("Compiler", "aa")
         g0 = relay.GlobalVar("g0")
         mod[g0] = fn0
 
@@ -811,7 +811,7 @@ def test_callee_not_inline_leaf_inline_extern_compiler():
         y0 = relay.var("y0", shape=(3, 5))
         fn0 = relay.Function([x0, y0], x0 * y0)
         fn0 = fn0.with_attr("Inline", tvm.tir.IntImm("int32", 1))
-        fn0 = fn0.with_attr("Compiler", tvm.tir.StringImm("aa"))
+        fn0 = fn0.with_attr("Compiler", "aa")
 
         x1 = relay.var("x1", shape=(3, 5))
         y1 = relay.var("y1", shape=(3, 5))
index 110d855..e3c8991 100644 (file)
@@ -184,7 +184,7 @@ def test_simple_merge():
         add_node = relay.add(in_1, in_2)
         relu_node = relay.nn.relu(add_node)
         add_relu = relay.Function([in_1, in_2], relu_node)
-        add_relu = add_relu.with_attr("Composite", tir.StringImm("add_relu"))
+        add_relu = add_relu.with_attr("Composite", "add_relu")
 
         # merged function
         r = relay.Call(add_relu, [a, b])
@@ -249,8 +249,7 @@ def test_branch_merge():
         sub_node = relay.subtract(in_1, in_2)
         mul_node = relay.multiply(add_node, sub_node)
         add_sub_mul = relay.Function([in_1, in_2], mul_node)
-        add_sub_mul = add_sub_mul.with_attr("Composite",
-                                                tir.StringImm("add_sub_mul"))
+        add_sub_mul = add_sub_mul.with_attr("Composite", "add_sub_mul")
 
         # add_sub_mul1 function
         in_3 = relay.var('in_3', shape=(10, 10))
@@ -259,8 +258,7 @@ def test_branch_merge():
         sub_node_1 = relay.subtract(in_3, in_4)
         mul_node_1 = relay.multiply(add_node_1, sub_node_1)
         add_sub_mul_1 = relay.Function([in_3, in_4], mul_node_1)
-        add_sub_mul_1 = add_sub_mul_1.with_attr("Composite",
-                                                    tir.StringImm("add_sub_mul"))
+        add_sub_mul_1 = add_sub_mul_1.with_attr("Composite", "add_sub_mul")
 
         # merged function
         m_add_sub_mul_1 = relay.Call(add_sub_mul, [a, b])
@@ -319,8 +317,7 @@ def test_reuse_call_merge():
         add_node_1 = relay.add(in_1, add_node)
         add_node_2 = relay.add(add_node_1, add_node)
         add_add_add = relay.Function([in_1, in_2], add_node_2)
-        add_add_add = add_add_add.with_attr("Composite",
-                                                tir.StringImm("add_add_add"))
+        add_add_add = add_add_add.with_attr("Composite", "add_add_add")
 
         # merged function
         sub_node = relay.subtract(a, b)
@@ -404,7 +401,7 @@ def test_multiple_patterns():
         r = relay.nn.relu(bias_node)
         conv_bias_add_relu = relay.Function([in_1, in_2, in_3], r)
         conv_bias_add_relu = conv_bias_add_relu.with_attr("Composite",
-                                                              tir.StringImm("conv2d_bias_relu"))
+                                                          "conv2d_bias_relu")
 
         # add_relu function
         in_4 = relay.var('in_4', shape=(1, 256, 28, 28))
@@ -412,7 +409,7 @@ def test_multiple_patterns():
         add_node = relay.add(in_4, in_5)
         r = relay.nn.relu(add_node)
         add_relu = relay.Function([in_4, in_5], r)
-        add_relu = add_relu.with_attr("Composite", tir.StringImm("add_relu"))
+        add_relu = add_relu.with_attr("Composite", "add_relu")
 
         # merged function
         conv_bias_add_relu_1 = relay.Call(conv_bias_add_relu, [data, kernel, bias])
@@ -481,8 +478,7 @@ def test_merge_order():
         out = relay.abs(out)
         out = relay.nn.relu(out)
         merged_func = relay.Function([x, y], out)
-        merged_func = merged_func.with_attr('Composite',
-                                                tir.StringImm(composite_name))
+        merged_func = merged_func.with_attr('Composite', composite_name)
         ret = relay.Call(merged_func, [input_1, input_2])
         return relay.Function([input_1, input_2], ret)
 
@@ -547,13 +543,13 @@ def test_parallel_merge():
         y = relay.var('y')
         branch_1 = relay.multiply(relay.add(x, y), relay.subtract(x, y))
         func_1 = relay.Function([x, y], branch_1)
-        func_1 = func_1.with_attr('Composite', tir.StringImm("add_sub_mul"))
+        func_1 = func_1.with_attr('Composite', "add_sub_mul")
         call_1 = relay.Call(func_1, [input_1, input_2])
         x1 = relay.var('x1')
         y1 = relay.var('y1')
         branch_2 = relay.multiply(relay.add(x1, y1), relay.subtract(x1, y1))
         func_2 = relay.Function([x1, y1], branch_2)
-        func_2 = func_2.with_attr('Composite', tir.StringImm("add_sub_mul"))
+        func_2 = func_2.with_attr('Composite', "add_sub_mul")
         call_2 = relay.Call(func_2, [input_1, input_2])
         out = relay.multiply(call_1, call_2)
         return relay.Function([input_1, input_2], out)
@@ -632,14 +628,14 @@ def test_multiple_input_subgraphs():
         add_relu_1 = relay.add(x, y)
         add_relu_1 = relay.nn.relu(add_relu_1)
         add_relu_1 = relay.Function([x, y], add_relu_1)
-        add_relu_1 = add_relu_1.with_attr('Composite', tir.StringImm('add_relu'))
+        add_relu_1 = add_relu_1.with_attr('Composite', 'add_relu')
         add_relu_call_1 = relay.Call(add_relu_1, [inputs[0], inputs[1]])
         x1 = relay.var('x1')
         y1 = relay.var('y1')
         add_relu_2 = relay.add(x1, y1)
         add_relu_2 = relay.nn.relu(add_relu_2)
         add_relu_2 = relay.Function([x1, y1], add_relu_2)
-        add_relu_2 = add_relu_2.with_attr('Composite', tir.StringImm('add_relu'))
+        add_relu_2 = add_relu_2.with_attr('Composite', 'add_relu')
         add_relu_call_2 = relay.Call(add_relu_2, [inputs[2], inputs[3]])
         x2 = relay.var('x2')
         y2 = relay.var('y2')
@@ -647,7 +643,7 @@ def test_multiple_input_subgraphs():
         sub = relay.subtract(x2, y2)
         add_sub_mul = relay.multiply(add, sub)
         add_sub_mul = relay.Function([x2, y2], add_sub_mul)
-        add_sub_mul = add_sub_mul.with_attr('Composite', tir.StringImm('add_sub_mul'))
+        add_sub_mul = add_sub_mul.with_attr('Composite', 'add_sub_mul')
         add_sub_mul_call = relay.Call(add_sub_mul, [add_relu_call_1, add_relu_call_2])
         return relay.Function(inputs, add_sub_mul_call)
 
@@ -660,7 +656,7 @@ def test_multiple_input_subgraphs():
             add_relu = relay.add(x, y)
             add_relu = relay.nn.relu(add_relu)
             add_relu = relay.Function([x, y], add_relu)
-            add_relu = add_relu.with_attr('Composite', tir.StringImm('add_relu'))
+            add_relu = add_relu.with_attr('Composite', 'add_relu')
             add_relu_call = relay.Call(add_relu, [inputs[i*2], inputs[i*2+1]])
             add_relu_calls.append(add_relu_call)
 
@@ -720,7 +716,7 @@ def test_tuple_get_item_merge():
         tuple_get_item_node = bn_node[0]
         relu_node = relay.nn.relu(tuple_get_item_node)
         bn_relu = relay.Function([in_1, in_2, in_3, in_4, in_5], relu_node)
-        bn_relu = bn_relu.with_attr("Composite", tir.StringImm("bn_relu"))
+        bn_relu = bn_relu.with_attr("Composite", "bn_relu")
 
         # merged function
         r = relay.Call(bn_relu, [x, gamma, beta, moving_mean, moving_var])
index 3959613..1968f34 100644 (file)
@@ -24,7 +24,6 @@ import tvm
 import tvm.relay.testing
 from tvm import relay
 from tvm import runtime
-from tvm.runtime import container
 from tvm.relay import transform
 from tvm.contrib import util
 from tvm.relay.op.annotation import compiler_begin, compiler_end
@@ -307,8 +306,8 @@ def test_extern_ccompiler_default_ops():
         func = relay.Function([x0, y0], add)
         func = func.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
         func = func.with_attr("Inline", tvm.tir.IntImm("int32", 1))
-        func = func.with_attr("Compiler", tvm.tir.StringImm("ccompiler"))
-        func = func.with_attr("global_symbol", container.String("ccompiler_0"))
+        func = func.with_attr("Compiler", "ccompiler")
+        func = func.with_attr("global_symbol", "ccompiler_0")
         glb_0 = relay.GlobalVar("ccompiler_0")
         mod[glb_0] = func
         add_call = relay.Call(glb_0, [x, y])
@@ -392,8 +391,8 @@ def test_extern_dnnl():
         func = relay.Function([data0, input0, input1], out)
         func = func.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
         func = func.with_attr("Inline", tvm.tir.IntImm("int32", 1))
-        func = func.with_attr("Compiler", tvm.tir.StringImm("dnnl"))
-        func = func.with_attr("global_symbol", container.String("dnnl_0"))
+        func = func.with_attr("Compiler", "dnnl")
+        func = func.with_attr("global_symbol", "dnnl_0")
         glb_var = relay.GlobalVar("dnnl_0")
         mod = tvm.IRModule()
         mod[glb_var] = func
@@ -518,10 +517,8 @@ def test_function_lifting():
                                bn.astuple())
         func0 = func0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
         func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1))
-        func0 = func0.with_attr("Compiler",
-                                    tvm.tir.StringImm("test_compiler"))
-        func0 = func0.with_attr("global_symbol",
-                                container.String("test_compiler_0"))
+        func0 = func0.with_attr("Compiler", "test_compiler")
+        func0 = func0.with_attr("global_symbol", "test_compiler_0")
         gv0 = relay.GlobalVar("test_compiler_0")
         mod[gv0] = func0
 
@@ -537,10 +534,8 @@ def test_function_lifting():
         func1 = relay.Function([data1, weight1], conv)
         func1 = func1.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
         func1 = func1.with_attr("Inline", tvm.tir.IntImm("int32", 1))
-        func1 = func1.with_attr("Compiler",
-                                    tvm.tir.StringImm("test_compiler"))
-        func1 = func1.with_attr("global_symbol",
-                                container.String("test_compiler_1"))
+        func1 = func1.with_attr("Compiler", "test_compiler")
+        func1 = func1.with_attr("global_symbol", "test_compiler_1")
         gv1 = relay.GlobalVar("test_compiler_1")
         mod[gv1] = func1
 
@@ -611,10 +606,8 @@ def test_function_lifting_inline():
                                bn.astuple())
         func0 = func0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
         func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1))
-        func0 = func0.with_attr("Compiler",
-                                    tvm.tir.StringImm("test_compiler"))
-        func0 = func0.with_attr("global_symbol",
-                                container.String("test_compiler_0"))
+        func0 = func0.with_attr("Compiler", "test_compiler")
+        func0 = func0.with_attr("global_symbol", "test_compiler_0")
 
         # main function
         data = relay.var("data", relay.TensorType((1, 16, 224, 224), "float32"))
@@ -648,8 +641,8 @@ def test_constant_propagation():
         func = relay.Function([y0], add)
         func = func.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
         func = func.with_attr("Inline", tvm.tir.IntImm("int32", 1))
-        func = func.with_attr("Compiler", tvm.tir.StringImm("ccompiler"))
-        func = func.with_attr("global_symbol", container.String("ccompiler_0"))
+        func = func.with_attr("Compiler", "ccompiler")
+        func = func.with_attr("global_symbol", "ccompiler_0")
         glb_0 = relay.GlobalVar("ccompiler_0")
         mod[glb_0] = func
         add_call = relay.Call(glb_0, [y])
@@ -748,10 +741,8 @@ def test_multiple_outputs():
                                 bn_mean, bn_var], tuple_o)
         func0 = func0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
         func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1))
-        func0 = func0.with_attr("Compiler",
-                                tvm.tir.StringImm("test_target"))
-        func0 = func0.with_attr("global_symbol",
-                                container.String("test_target_2"))
+        func0 = func0.with_attr("Compiler", "test_target")
+        func0 = func0.with_attr("global_symbol", "test_target_2")
         gv0 = relay.GlobalVar("test_target_2")
         mod[gv0] = func0
 
@@ -816,10 +807,8 @@ def test_mixed_single_multiple_outputs():
 
         func1 = func1.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
         func1 = func1.with_attr("Inline", tvm.tir.IntImm("int32", 1))
-        func1 = func1.with_attr("Compiler",
-                                tvm.tir.StringImm("test_target"))
-        func1 = func1.with_attr("global_symbol",
-                                container.String("test_target_1"))
+        func1 = func1.with_attr("Compiler", "test_target")
+        func1 = func1.with_attr("global_symbol", "test_target_1")
         gv1 = relay.GlobalVar("test_target_1")
         mod[gv1] = func1
 
@@ -831,10 +820,8 @@ def test_mixed_single_multiple_outputs():
 
         func0 = func0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
         func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1))
-        func0 = func0.with_attr("Compiler",
-                                tvm.tir.StringImm("test_target"))
-        func0 = func0.with_attr("global_symbol",
-                                container.String("test_target_0"))
+        func0 = func0.with_attr("Compiler", "test_target")
+        func0 = func0.with_attr("global_symbol", "test_target_0")
         gv0 = relay.GlobalVar("test_target_0")
         mod[gv0] = func0
 
index 8f2e9bb..48495f4 100644 (file)
@@ -41,7 +41,7 @@ def test_dict_attrs():
     dattr = tvm.ir.make_node("DictAttrs", x=1, y=10, name="xyz", padding=(0,0))
     assert dattr.x.value == 1
     datrr = tvm.ir.load_json(tvm.ir.save_json(dattr))
-    assert dattr.name.value == "xyz"
+    assert dattr.name == "xyz"
     assert isinstance(dattr, tvm.ir.DictAttrs)
     assert "name" in dattr
     assert dattr["x"].value == 1
index 66b8a10..ee18dea 100644 (file)
@@ -53,7 +53,7 @@ inline Tensor cublas_matmul(const Tensor& lhs,
     { { n, m } }, { lhs->dtype }, { lhs, rhs },
     [&](Array<Buffer> ins, Array<Buffer> outs) {
       return call_packed({
-        PrimExpr("tvm.contrib.cublas.matmul"),
+        runtime::String("tvm.contrib.cublas.matmul"),
         pack_buffer(ins[0]),
         pack_buffer(ins[1]),
         pack_buffer(outs[0]),
@@ -85,7 +85,7 @@ inline Tensor cublas_batch_matmul(const Tensor& lhs,
     { { b, n, m } }, { lhs->dtype }, { lhs, rhs },
     [&](Array<Buffer> ins, Array<Buffer> outs) {
       return call_packed({
-        PrimExpr("tvm.contrib.cublas.batch_matmul"),
+        runtime::String("tvm.contrib.cublas.batch_matmul"),
         pack_buffer(ins[0]),
         pack_buffer(ins[1]),
         pack_buffer(outs[0]),
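
In topi the packed-call name becomes a runtime::String as well; it still lands in the Array<PrimExpr> argument list because PrimExpr gains a converting constructor from runtime::String (an inference from the include/tvm/ir/expr.h hunk, not spelled out here). First argument only:

    Array<PrimExpr> call_args;
    call_args.push_back(runtime::String("tvm.contrib.cublas.matmul"));  // -> StringImm
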
index 2fcafc7..9fe1825 100644 (file)
@@ -52,7 +52,7 @@ inline Tensor rocblas_matmul(const Tensor& lhs,
     { { n, m } }, { lhs->dtype }, { lhs, rhs },
     [&](Array<Buffer> ins, Array<Buffer> outs) {
       return call_packed({
-        PrimExpr("tvm.contrib.rocblas.matmul"),
+        runtime::String("tvm.contrib.rocblas.matmul"),
         pack_buffer(ins[0]),
         pack_buffer(ins[1]),
         pack_buffer(outs[0]),