[TIR][REFACTOR] Add tir prefix to type keys (#5802)
authorTianqi Chen <tqchen@users.noreply.github.com>
Sun, 14 Jun 2020 16:45:46 +0000 (09:45 -0700)
committerGitHub <noreply@github.com>
Sun, 14 Jun 2020 16:45:46 +0000 (09:45 -0700)
18 files changed:
include/tvm/tir/buffer.h
include/tvm/tir/data_layout.h
include/tvm/tir/expr.h
include/tvm/tir/stmt.h
include/tvm/tir/var.h
python/tvm/ir/json_compact.py
python/tvm/te/hybrid/util.py
python/tvm/tir/buffer.py
python/tvm/tir/data_layout.py
python/tvm/tir/expr.py
python/tvm/tir/stmt.py
src/tir/pass/hoist_if_then_else.cc
tests/python/unittest/test_target_codegen_cuda.py
tests/python/unittest/test_target_codegen_llvm.py
tests/python/unittest/test_tir_pass_hoist_if.py
tests/python/unittest/test_tir_stmt_functor_ir_transform.py
tutorials/dev/low_level_custom_pass.py
vta/python/vta/transform.py

index 34b0155..e150ff3 100644 (file)
@@ -118,7 +118,7 @@ class BufferNode : public Object {
     return shape.size() != 0 ? shape[0].dtype() : DataType::Int(32);
   }
 
-  static constexpr const char* _type_key = "Buffer";
+  static constexpr const char* _type_key = "tir.Buffer";
   static constexpr const bool _type_has_method_sequal_reduce = true;
   static constexpr const bool _type_has_method_shash_reduce = true;
   TVM_DECLARE_FINAL_OBJECT_INFO(BufferNode, Object);
@@ -228,7 +228,7 @@ class DataProducerNode : public Object {
 
   void SHashReduce(SHashReducer hash_reduce) const {}
 
-  static constexpr const char* _type_key = "DataProducer";
+  static constexpr const char* _type_key = "tir.DataProducer";
   static constexpr const bool _type_has_method_sequal_reduce = true;
   static constexpr const bool _type_has_method_shash_reduce = true;
   TVM_DECLARE_BASE_OBJECT_INFO(DataProducerNode, Object);
index b7cb686..d3a77cc 100644 (file)
@@ -112,7 +112,7 @@ class LayoutNode : public Object {
     v->Visit("axes", &axes);
   }
 
-  static constexpr const char* _type_key = "Layout";
+  static constexpr const char* _type_key = "tir.Layout";
   TVM_DECLARE_FINAL_OBJECT_INFO(LayoutNode, Object);
 };
 
@@ -308,7 +308,7 @@ class BijectiveLayoutNode : public Object {
     v->Visit("backward_rule", &backward_rule);
   }
 
-  static constexpr const char* _type_key = "BijectiveLayout";
+  static constexpr const char* _type_key = "tir.BijectiveLayout";
   TVM_DECLARE_FINAL_OBJECT_INFO(BijectiveLayoutNode, Object);
 };
 
index cfb7f1e..1518d1f 100644 (file)
@@ -64,7 +64,7 @@ class StringImmNode : public PrimExprNode {
 
   void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(value); }
 
-  static constexpr const char* _type_key = "StringImm";
+  static constexpr const char* _type_key = "tir.StringImm";
   TVM_DECLARE_FINAL_OBJECT_INFO(StringImmNode, PrimExprNode);
 };
 
@@ -101,7 +101,7 @@ class CastNode : public PrimExprNode {
     hash_reduce(value);
   }
 
-  static constexpr const char* _type_key = "Cast";
+  static constexpr const char* _type_key = "tir.Cast";
   TVM_DECLARE_FINAL_OBJECT_INFO(CastNode, PrimExprNode);
 };
 
@@ -149,7 +149,7 @@ class BinaryOpNode : public PrimExprNode {
 /*! \brief a + b */
 class AddNode : public BinaryOpNode<AddNode> {
  public:
-  static constexpr const char* _type_key = "Add";
+  static constexpr const char* _type_key = "tir.Add";
 };
 
 /*!
@@ -165,7 +165,7 @@ class Add : public PrimExpr {
 /*! \brief a - b */
 class SubNode : public BinaryOpNode<SubNode> {
  public:
-  static constexpr const char* _type_key = "Sub";
+  static constexpr const char* _type_key = "tir.Sub";
 };
 
 /*!
@@ -181,7 +181,7 @@ class Sub : public PrimExpr {
 /*! \brief a * b */
 class MulNode : public BinaryOpNode<MulNode> {
  public:
-  static constexpr const char* _type_key = "Mul";
+  static constexpr const char* _type_key = "tir.Mul";
 };
 
 /*!
@@ -200,7 +200,7 @@ class Mul : public PrimExpr {
  */
 class DivNode : public BinaryOpNode<DivNode> {
  public:
-  static constexpr const char* _type_key = "Div";
+  static constexpr const char* _type_key = "tir.Div";
 };
 
 /*!
@@ -219,7 +219,7 @@ class Div : public PrimExpr {
  */
 class ModNode : public BinaryOpNode<ModNode> {
  public:
-  static constexpr const char* _type_key = "Mod";
+  static constexpr const char* _type_key = "tir.Mod";
 };
 
 /*!
@@ -235,7 +235,7 @@ class Mod : public PrimExpr {
 /*! \brief Floor division, floor(a/b) */
 class FloorDivNode : public BinaryOpNode<FloorDivNode> {
  public:
-  static constexpr const char* _type_key = "FloorDiv";
+  static constexpr const char* _type_key = "tir.FloorDiv";
 };
 
 /*!
@@ -251,7 +251,7 @@ class FloorDiv : public PrimExpr {
 /*! \brief The remainder of the floordiv */
 class FloorModNode : public BinaryOpNode<FloorModNode> {
  public:
-  static constexpr const char* _type_key = "FloorMod";
+  static constexpr const char* _type_key = "tir.FloorMod";
 };
 
 /*!
@@ -267,7 +267,7 @@ class FloorMod : public PrimExpr {
 /*! \brief min(a, b) */
 class MinNode : public BinaryOpNode<MinNode> {
  public:
-  static constexpr const char* _type_key = "Min";
+  static constexpr const char* _type_key = "tir.Min";
 };
 
 /*!
@@ -283,7 +283,7 @@ class Min : public PrimExpr {
 /*! \brief max(a, b) */
 class MaxNode : public BinaryOpNode<MaxNode> {
  public:
-  static constexpr const char* _type_key = "Max";
+  static constexpr const char* _type_key = "tir.Max";
 };
 
 /*!
@@ -330,7 +330,7 @@ class CmpOpNode : public PrimExprNode {
 /*! \brief a == b */
 class EQNode : public CmpOpNode<EQNode> {
  public:
-  static constexpr const char* _type_key = "EQ";
+  static constexpr const char* _type_key = "tir.EQ";
 };
 
 /*!
@@ -346,7 +346,7 @@ class EQ : public PrimExpr {
 /*! \brief a != b */
 class NENode : public CmpOpNode<NENode> {
  public:
-  static constexpr const char* _type_key = "NE";
+  static constexpr const char* _type_key = "tir.NE";
 };
 
 /*!
@@ -362,7 +362,7 @@ class NE : public PrimExpr {
 /*! \brief a < b */
 class LTNode : public CmpOpNode<LTNode> {
  public:
-  static constexpr const char* _type_key = "LT";
+  static constexpr const char* _type_key = "tir.LT";
 };
 
 /*!
@@ -378,7 +378,7 @@ class LT : public PrimExpr {
 /*! \brief a <= b */
 struct LENode : public CmpOpNode<LENode> {
  public:
-  static constexpr const char* _type_key = "LE";
+  static constexpr const char* _type_key = "tir.LE";
 };
 
 /*!
@@ -394,7 +394,7 @@ class LE : public PrimExpr {
 /*! \brief a > b */
 class GTNode : public CmpOpNode<GTNode> {
  public:
-  static constexpr const char* _type_key = "GT";
+  static constexpr const char* _type_key = "tir.GT";
 };
 
 /*!
@@ -410,7 +410,7 @@ class GT : public PrimExpr {
 /*! \brief a >= b */
 class GENode : public CmpOpNode<GENode> {
  public:
-  static constexpr const char* _type_key = "GE";
+  static constexpr const char* _type_key = "tir.GE";
 };
 
 /*!
@@ -447,7 +447,7 @@ class AndNode : public PrimExprNode {
     hash_reduce(b);
   }
 
-  static constexpr const char* _type_key = "And";
+  static constexpr const char* _type_key = "tir.And";
   TVM_DECLARE_FINAL_OBJECT_INFO(AndNode, PrimExprNode);
 };
 
@@ -485,7 +485,7 @@ class OrNode : public PrimExprNode {
     hash_reduce(b);
   }
 
-  static constexpr const char* _type_key = "Or";
+  static constexpr const char* _type_key = "tir.Or";
   TVM_DECLARE_FINAL_OBJECT_INFO(OrNode, PrimExprNode);
 };
 
@@ -519,7 +519,7 @@ class NotNode : public PrimExprNode {
     hash_reduce(a);
   }
 
-  static constexpr const char* _type_key = "Not";
+  static constexpr const char* _type_key = "tir.Not";
   TVM_DECLARE_FINAL_OBJECT_INFO(NotNode, PrimExprNode);
 };
 
@@ -568,7 +568,7 @@ class SelectNode : public PrimExprNode {
     hash_reduce(false_value);
   }
 
-  static constexpr const char* _type_key = "Select";
+  static constexpr const char* _type_key = "tir.Select";
   TVM_DECLARE_FINAL_OBJECT_INFO(SelectNode, PrimExprNode);
 };
 
@@ -617,7 +617,7 @@ class BufferLoadNode : public PrimExprNode {
     hash_reduce(indices);
   }
 
-  static constexpr const char* _type_key = "BufferLoad";
+  static constexpr const char* _type_key = "tir.BufferLoad";
   TVM_DECLARE_FINAL_OBJECT_INFO(BufferLoadNode, PrimExprNode);
 };
 
@@ -664,7 +664,7 @@ class ProducerLoadNode : public PrimExprNode {
     hash_reduce(indices);
   }
 
-  static constexpr const char* _type_key = "ProducerLoad";
+  static constexpr const char* _type_key = "tir.ProducerLoad";
   TVM_DECLARE_FINAL_OBJECT_INFO(ProducerLoadNode, PrimExprNode);
 };
 
@@ -722,7 +722,7 @@ class LoadNode : public PrimExprNode {
     hash_reduce(predicate);
   }
 
-  static constexpr const char* _type_key = "Load";
+  static constexpr const char* _type_key = "tir.Load";
   TVM_DECLARE_FINAL_OBJECT_INFO(LoadNode, PrimExprNode);
 };
 
@@ -773,7 +773,7 @@ class RampNode : public PrimExprNode {
     hash_reduce(lanes);
   }
 
-  static constexpr const char* _type_key = "Ramp";
+  static constexpr const char* _type_key = "tir.Ramp";
   TVM_DECLARE_FINAL_OBJECT_INFO(RampNode, PrimExprNode);
 };
 
@@ -811,7 +811,7 @@ class BroadcastNode : public PrimExprNode {
     hash_reduce(lanes);
   }
 
-  static constexpr const char* _type_key = "Broadcast";
+  static constexpr const char* _type_key = "tir.Broadcast";
   TVM_DECLARE_FINAL_OBJECT_INFO(BroadcastNode, PrimExprNode);
 };
 
@@ -856,7 +856,7 @@ class LetNode : public PrimExprNode {
     hash_reduce(body);
   }
 
-  static constexpr const char* _type_key = "Let";
+  static constexpr const char* _type_key = "tir.Let";
   TVM_DECLARE_FINAL_OBJECT_INFO(LetNode, PrimExprNode);
 };
 
@@ -928,7 +928,7 @@ class CallNode : public PrimExprNode {
   /*! \return Whether call node can be vectorized. */
   bool is_vectorizable() const;
 
-  static constexpr const char* _type_key = "Call";
+  static constexpr const char* _type_key = "tir.Call";
   TVM_DECLARE_FINAL_OBJECT_INFO(CallNode, PrimExprNode);
 
   // Build-in intrinsics
@@ -990,7 +990,7 @@ class ShuffleNode : public PrimExprNode {
     hash_reduce(indices);
   }
 
-  static constexpr const char* _type_key = "Shuffle";
+  static constexpr const char* _type_key = "tir.Shuffle";
   TVM_DECLARE_FINAL_OBJECT_INFO(ShuffleNode, PrimExprNode);
 };
 
@@ -1048,7 +1048,7 @@ class CommReducerNode : public Object {
     hash_reduce(identity_element);
   }
 
-  static constexpr const char* _type_key = "CommReducer";
+  static constexpr const char* _type_key = "tir.CommReducer";
   static constexpr const bool _type_has_method_sequal_reduce = true;
   static constexpr const bool _type_has_method_shash_reduce = true;
   TVM_DECLARE_FINAL_OBJECT_INFO(CommReducerNode, Object);
@@ -1108,7 +1108,7 @@ class ReduceNode : public PrimExprNode {
     hash_reduce(value_index);
   }
 
-  static constexpr const char* _type_key = "Reduce";
+  static constexpr const char* _type_key = "tir.Reduce";
   TVM_DECLARE_FINAL_OBJECT_INFO(ReduceNode, PrimExprNode);
 };
 
@@ -1136,7 +1136,7 @@ class AnyNode : public PrimExprNode {
   /*! \brief Convert to var. */
   Var ToVar() const { return Var("any_dim", DataType::Int(32)); }
 
-  static constexpr const char* _type_key = "Any";
+  static constexpr const char* _type_key = "tir.Any";
   TVM_DECLARE_FINAL_OBJECT_INFO(AnyNode, PrimExprNode);
 };
 
index ee8e1eb..be1c567 100644 (file)
@@ -37,7 +37,7 @@ namespace tir {
 /*! \brief Base node of all statements. */
 class StmtNode : public Object {
  public:
-  static constexpr const char* _type_key = "Stmt";
+  static constexpr const char* _type_key = "tir.Stmt";
   static constexpr const bool _type_has_method_sequal_reduce = true;
   static constexpr const bool _type_has_method_shash_reduce = true;
   static constexpr const uint32_t _type_child_slots = 15;
@@ -79,7 +79,7 @@ class LetStmtNode : public StmtNode {
     hash_reduce(body);
   }
 
-  static constexpr const char* _type_key = "LetStmt";
+  static constexpr const char* _type_key = "tir.LetStmt";
   TVM_DECLARE_FINAL_OBJECT_INFO(LetStmtNode, StmtNode);
 };
 
@@ -134,7 +134,7 @@ class AttrStmtNode : public StmtNode {
     hash_reduce(body);
   }
 
-  static constexpr const char* _type_key = "AttrStmt";
+  static constexpr const char* _type_key = "tir.AttrStmt";
   TVM_DECLARE_FINAL_OBJECT_INFO(AttrStmtNode, StmtNode);
 };
 
@@ -181,7 +181,7 @@ class AssertStmtNode : public StmtNode {
     hash_reduce(body);
   }
 
-  static constexpr const char* _type_key = "AssertStmt";
+  static constexpr const char* _type_key = "tir.AssertStmt";
   TVM_DECLARE_FINAL_OBJECT_INFO(AssertStmtNode, StmtNode);
 };
 
@@ -244,7 +244,7 @@ class StoreNode : public StmtNode {
     hash_reduce(predicate);
   }
 
-  static constexpr const char* _type_key = "Store";
+  static constexpr const char* _type_key = "tir.Store";
   TVM_DECLARE_FINAL_OBJECT_INFO(StoreNode, StmtNode);
 };
 
@@ -295,7 +295,7 @@ class BufferStoreNode : public StmtNode {
     hash_reduce(indices);
   }
 
-  static constexpr const char* _type_key = "BufferStore";
+  static constexpr const char* _type_key = "tir.BufferStore";
   TVM_DECLARE_FINAL_OBJECT_INFO(BufferStoreNode, StmtNode);
 };
 
@@ -355,7 +355,7 @@ class BufferRealizeNode : public StmtNode {
   BufferRealizeNode(Buffer buffer, Array<Range> bounds, PrimExpr condition, Stmt body)
       : buffer(buffer), bounds(bounds), condition(condition), body(body) {}
 
-  static constexpr const char* _type_key = "BufferRealize";
+  static constexpr const char* _type_key = "tir.BufferRealize";
   TVM_DECLARE_FINAL_OBJECT_INFO(BufferRealizeNode, StmtNode);
 };
 
@@ -406,7 +406,7 @@ class ProducerStoreNode : public StmtNode {
     hash_reduce(indices);
   }
 
-  static constexpr const char* _type_key = "ProducerStore";
+  static constexpr const char* _type_key = "tir.ProducerStore";
   TVM_DECLARE_FINAL_OBJECT_INFO(ProducerStoreNode, StmtNode);
 };
 
@@ -462,7 +462,7 @@ class ProducerRealizeNode : public StmtNode {
     hash_reduce(body);
   }
 
-  static constexpr const char* _type_key = "ProducerRealize";
+  static constexpr const char* _type_key = "tir.ProducerRealize";
   TVM_DECLARE_FINAL_OBJECT_INFO(ProducerRealizeNode, StmtNode);
 };
 
@@ -529,7 +529,7 @@ class AllocateNode : public StmtNode {
    */
   TVM_DLL static int32_t constant_allocation_size(const Array<PrimExpr>& extents);
 
-  static constexpr const char* _type_key = "Allocate";
+  static constexpr const char* _type_key = "tir.Allocate";
   TVM_DECLARE_FINAL_OBJECT_INFO(AllocateNode, StmtNode);
 };
 
@@ -559,7 +559,7 @@ class FreeNode : public StmtNode {
 
   void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(buffer_var); }
 
-  static constexpr const char* _type_key = "Free";
+  static constexpr const char* _type_key = "tir.Free";
   TVM_DECLARE_FINAL_OBJECT_INFO(FreeNode, StmtNode);
 };
 
@@ -598,7 +598,7 @@ class SeqStmtNode : public StmtNode {
 
   void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(seq); }
 
-  static constexpr const char* _type_key = "SeqStmt";
+  static constexpr const char* _type_key = "tir.SeqStmt";
   TVM_DECLARE_FINAL_OBJECT_INFO(SeqStmtNode, StmtNode);
 };
 
@@ -697,7 +697,7 @@ class IfThenElseNode : public StmtNode {
     hash_reduce(else_case);
   }
 
-  static constexpr const char* _type_key = "IfThenElse";
+  static constexpr const char* _type_key = "tir.IfThenElse";
   TVM_DECLARE_FINAL_OBJECT_INFO(IfThenElseNode, StmtNode);
 };
 
@@ -731,7 +731,7 @@ class EvaluateNode : public StmtNode {
 
   void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(value); }
 
-  static constexpr const char* _type_key = "Evaluate";
+  static constexpr const char* _type_key = "tir.Evaluate";
   TVM_DECLARE_FINAL_OBJECT_INFO(EvaluateNode, StmtNode);
 };
 
@@ -817,7 +817,7 @@ class ForNode : public StmtNode {
     hash_reduce(body);
   }
 
-  static constexpr const char* _type_key = "For";
+  static constexpr const char* _type_key = "tir.For";
   TVM_DECLARE_FINAL_OBJECT_INFO(ForNode, StmtNode);
 };
 
@@ -860,7 +860,7 @@ class PrefetchNode : public StmtNode {
   PrefetchNode() = default;
   PrefetchNode(Buffer buffer, Array<Range> bounds) : buffer(buffer), bounds(bounds) {}
 
-  static constexpr const char* _type_key = "Prefetch";
+  static constexpr const char* _type_key = "tir.Prefetch";
   TVM_DECLARE_FINAL_OBJECT_INFO(PrefetchNode, StmtNode);
 };
 
index 2a44909..f1651c1 100644 (file)
@@ -266,7 +266,7 @@ class IterVarNode : public Object {
     hash_reduce(thread_tag);
   }
 
-  static constexpr const char* _type_key = "IterVar";
+  static constexpr const char* _type_key = "tir.IterVar";
   static constexpr const bool _type_has_method_sequal_reduce = true;
   static constexpr const bool _type_has_method_shash_reduce = true;
   TVM_DECLARE_FINAL_OBJECT_INFO(IterVarNode, Object);
index 94e9cf3..8b75685 100644 (file)
@@ -138,11 +138,48 @@ def create_updater_06_to_07():
         # TIR
         "Variable": [_update_tir_var("tir.Var"), _update_from_std_str("name")],
         "SizeVar": [_update_tir_var("tir.SizeVar"), _update_from_std_str("name")],
-        "StringImm": [_update_from_std_str("value")],
-        "Call": [_update_from_std_str("name")],
-        "AttrStmt": [_update_from_std_str("attr_key")],
-        "Layout": [_update_from_std_str("name")],
-        "Buffer": [_update_from_std_str("name"), _update_from_std_str("scope")],
+        "StringImm": [_rename("tir.StringImm"), _update_from_std_str("value")],
+        "Cast": [_rename("tir.Cast")],
+        "Add": [_rename("tir.Add")],
+        "Sub": [_rename("tir.Sub")],
+        "Mul": [_rename("tir.Mul")],
+        "Div": [_rename("tir.Div")],
+        "Mod": [_rename("tir.Mod")],
+        "FloorDiv": [_rename("tir.FloorDiv")],
+        "FloorMod": [_rename("tir.FloorMod")],
+        "Min": [_rename("tir.Min")],
+        "Max": [_rename("tir.Max")],
+        "EQ": [_rename("tir.EQ")],
+        "NE": [_rename("tir.NE")],
+        "LT": [_rename("tir.LT")],
+        "LE": [_rename("tir.LE")],
+        "GT": [_rename("tir.GT")],
+        "GE": [_rename("tir.GE")],
+        "And": [_rename("tir.And")],
+        "Or": [_rename("tir.Or")],
+        "Not": [_rename("tir.Not")],
+        "Select": [_rename("tir.Select")],
+        "Load": [_rename("tir.Load")],
+        "BufferLoad": [_rename("tir.BufferLoad")],
+        "Ramp": [_rename("tir.Ramp")],
+        "Broadcast": [_rename("tir.Broadcast")],
+        "Shuffle": [_rename("tir.Shuffle")],
+        "Call": [_rename("tir.Call"), _update_from_std_str("name")],
+        "Let": [_rename("tir.Let")],
+        "Any": [_rename("tir.Any")],
+        "LetStmt": [_rename("tir.LetStmt")],
+        "AssertStmt": [_rename("tir.AssertStmt")],
+        "Store": [_rename("tir.Store")],
+        "BufferStore": [_rename("tir.BufferStore")],
+        "BufferRealize": [_rename("tir.BufferRealize")],
+        "Allocate": [_rename("tir.Allocate")],
+        "IfThenElse": [_rename("tir.IfThenElse")],
+        "Evaluate": [_rename("tir.Evaluate")],
+        "Prefetch": [_rename("tir.Prefetch")],
+        "AttrStmt": [_rename("tir.AttrStmt"), _update_from_std_str("attr_key")],
+        "Layout": [_rename("tir.Layout"), _update_from_std_str("name")],
+        "Buffer": [
+            _rename("tir.Buffer"), _update_from_std_str("name"), _update_from_std_str("scope")],
     }
     return create_updater(node_map, "0.6", "0.7")
 
index 810509b..891d7ba 100644 (file)
@@ -83,7 +83,7 @@ def replace_io(body, rmap):
             return _expr.ProducerLoad(buf, op.indices)
         return None
 
-    return stmt_functor.ir_transform(body, None, replace, ['ProducerStore', 'ProducerLoad'])
+    return stmt_functor.ir_transform(body, None, replace, ['tir.ProducerStore', 'tir.ProducerLoad'])
 
 
 def _is_tvm_arg_types(args):
index e4dec5f..11bfb4c 100644 (file)
@@ -24,7 +24,7 @@ from tvm.ir import PrimExpr
 from . import _ffi_api
 
 
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.Buffer")
 class Buffer(Object):
     """Symbolic data buffer in TVM.
 
@@ -247,6 +247,6 @@ def decl_buffer(shape,
         data_alignment, offset_factor, buffer_type)
 
 
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.DataProducer")
 class DataProducer(Object):
     pass
index fd8c7a9..1616473 100644 (file)
@@ -20,7 +20,7 @@ import tvm._ffi
 from tvm.runtime import Object
 from . import _ffi_api
 
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.Layout")
 class Layout(Object):
     """Layout is composed of upper cases, lower cases and numbers,
     where upper case indicates a primal axis and
@@ -77,7 +77,7 @@ class Layout(Object):
         return _ffi_api.LayoutFactorOf(self, axis)
 
 
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.BijectiveLayout")
 class BijectiveLayout(Object):
     """Bijective mapping for two layouts (src-layout and dst-layout).
     It provides shape and index conversion between each other.
index d55370e..f8cb054 100644 (file)
@@ -321,7 +321,7 @@ class SizeVar(Var):
             _ffi_api.SizeVar, name, dtype)
 
 
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.IterVar")
 class IterVar(Object, ExprOp):
     """Represent iteration variable.
 
@@ -373,7 +373,7 @@ class IterVar(Object, ExprOp):
             _ffi_api.IterVar, dom, var, iter_type, thread_tag)
 
 
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.CommReducer")
 class CommReducer(Object):
     """Communicative reduce operator
 
@@ -396,7 +396,7 @@ class CommReducer(Object):
             _ffi_api.CommReducer, lhs, rhs, result, identity_element)
 
 
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.Reduce")
 class Reduce(PrimExprWithOp):
     """Reduce node.
 
@@ -475,7 +475,7 @@ class IntImm(ConstExpr):
         return self.__nonzero__()
 
 
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.StringImm")
 class StringImm(ConstExpr):
     """String constant.
 
@@ -499,7 +499,7 @@ class StringImm(ConstExpr):
         return self.value != other
 
 
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.Cast")
 class Cast(PrimExprWithOp):
     """Cast expression.
 
@@ -516,7 +516,7 @@ class Cast(PrimExprWithOp):
             _ffi_api.Cast, dtype, value)
 
 
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.Add")
 class Add(BinaryOpExpr):
     """Add node.
 
@@ -533,7 +533,7 @@ class Add(BinaryOpExpr):
             _ffi_api.Add, a, b)
 
 
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.Sub")
 class Sub(BinaryOpExpr):
     """Sub node.
 
@@ -550,7 +550,7 @@ class Sub(BinaryOpExpr):
             _ffi_api.Sub, a, b)
 
 
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.Mul")
 class Mul(BinaryOpExpr):
     """Mul node.
 
@@ -567,7 +567,7 @@ class Mul(BinaryOpExpr):
             _ffi_api.Mul, a, b)
 
 
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.Div")
 class Div(BinaryOpExpr):
     """Div node.
 
@@ -584,7 +584,7 @@ class Div(BinaryOpExpr):
             _ffi_api.Div, a, b)
 
 
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.Mod")
 class Mod(BinaryOpExpr):
     """Mod node.
 
@@ -601,7 +601,7 @@ class Mod(BinaryOpExpr):
             _ffi_api.Mod, a, b)
 
 
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.FloorDiv")
 class FloorDiv(BinaryOpExpr):
     """FloorDiv node.
 
@@ -618,7 +618,7 @@ class FloorDiv(BinaryOpExpr):
             _ffi_api.FloorDiv, a, b)
 
 
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.FloorMod")
 class FloorMod(BinaryOpExpr):
     """FloorMod node.
 
@@ -635,7 +635,7 @@ class FloorMod(BinaryOpExpr):
             _ffi_api.FloorMod, a, b)
 
 
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.Min")
 class Min(BinaryOpExpr):
     """Min node.
 
@@ -652,7 +652,7 @@ class Min(BinaryOpExpr):
             _ffi_api.Min, a, b)
 
 
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.Max")
 class Max(BinaryOpExpr):
     """Max node.
 
@@ -669,7 +669,7 @@ class Max(BinaryOpExpr):
             _ffi_api.Max, a, b)
 
 
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.EQ")
 class EQ(CmpExpr):
     """EQ node.
 
@@ -686,7 +686,7 @@ class EQ(CmpExpr):
             _ffi_api.EQ, a, b)
 
 
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.NE")
 class NE(CmpExpr):
     """NE node.
 
@@ -703,7 +703,7 @@ class NE(CmpExpr):
             _ffi_api.NE, a, b)
 
 
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.LT")
 class LT(CmpExpr):
     """LT node.
 
@@ -720,7 +720,7 @@ class LT(CmpExpr):
             _ffi_api.LT, a, b)
 
 
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.LE")
 class LE(CmpExpr):
     """LE node.
 
@@ -737,7 +737,7 @@ class LE(CmpExpr):
             _ffi_api.LE, a, b)
 
 
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.GT")
 class GT(CmpExpr):
     """GT node.
 
@@ -754,7 +754,7 @@ class GT(CmpExpr):
             _ffi_api.GT, a, b)
 
 
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.GE")
 class GE(CmpExpr):
     """GE node.
 
@@ -771,7 +771,7 @@ class GE(CmpExpr):
             _ffi_api.GE, a, b)
 
 
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.And")
 class And(LogicalExpr):
     """And node.
 
@@ -788,7 +788,7 @@ class And(LogicalExpr):
             _ffi_api.And, a, b)
 
 
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.Or")
 class Or(LogicalExpr):
     """Or node.
 
@@ -805,7 +805,7 @@ class Or(LogicalExpr):
             _ffi_api.Or, a, b)
 
 
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.Not")
 class Not(LogicalExpr):
     """Not node.
 
@@ -819,7 +819,7 @@ class Not(LogicalExpr):
             _ffi_api.Not, a)
 
 
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.Select")
 class Select(PrimExprWithOp):
     """Select node.
 
@@ -847,7 +847,7 @@ class Select(PrimExprWithOp):
             _ffi_api.Select, condition, true_value, false_value)
 
 
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.Load")
 class Load(PrimExprWithOp):
     """Load node.
 
@@ -871,7 +871,7 @@ class Load(PrimExprWithOp):
             _ffi_api.Load, dtype, buffer_var, index, *args)
 
 
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.BufferLoad")
 class BufferLoad(PrimExprWithOp):
     """Buffer load node.
 
@@ -888,7 +888,7 @@ class BufferLoad(PrimExprWithOp):
             _ffi_api.BufferLoad, buffer, indices)
 
 
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.ProducerLoad")
 class ProducerLoad(PrimExprWithOp):
     """Producer load node.
 
@@ -905,7 +905,7 @@ class ProducerLoad(PrimExprWithOp):
             _ffi_api.ProducerLoad, producer, indices)
 
 
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.Ramp")
 class Ramp(PrimExprWithOp):
     """Ramp node.
 
@@ -925,7 +925,7 @@ class Ramp(PrimExprWithOp):
             _ffi_api.Ramp, base, stride, lanes)
 
 
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.Broadcast")
 class Broadcast(PrimExprWithOp):
     """Broadcast node.
 
@@ -942,7 +942,7 @@ class Broadcast(PrimExprWithOp):
             _ffi_api.Broadcast, value, lanes)
 
 
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.Shuffle")
 class Shuffle(PrimExprWithOp):
     """Shuffle node.
 
@@ -959,7 +959,7 @@ class Shuffle(PrimExprWithOp):
             _ffi_api.Shuffle, vectors, indices)
 
 
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.Call")
 class Call(PrimExprWithOp):
     """Call node.
 
@@ -987,7 +987,7 @@ class Call(PrimExprWithOp):
             _ffi_api.Call, dtype, name, args, call_type)
 
 
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.Let")
 class Let(PrimExprWithOp):
     """Let node.
 
@@ -1007,7 +1007,7 @@ class Let(PrimExprWithOp):
             _ffi_api.Let, var, value, body)
 
 
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.Any")
 class Any(PrimExpr):
     """Any node.
     """
index f4d8471..4536580 100644 (file)
@@ -36,7 +36,7 @@ class Stmt(Object):
     """Base class of all the statements."""
 
 
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.LetStmt")
 class LetStmt(Stmt):
     """LetStmt node.
 
@@ -56,7 +56,7 @@ class LetStmt(Stmt):
             _ffi_api.LetStmt, var, value, body)
 
 
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.AssertStmt")
 class AssertStmt(Stmt):
     """AssertStmt node.
 
@@ -76,7 +76,7 @@ class AssertStmt(Stmt):
             _ffi_api.AssertStmt, condition, message, body)
 
 
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.For")
 class For(Stmt):
     """For node.
 
@@ -116,7 +116,7 @@ class For(Stmt):
             for_type, device_api, body)
 
 
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.Store")
 class Store(Stmt):
     """Store node.
 
@@ -140,7 +140,7 @@ class Store(Stmt):
             _ffi_api.Store, buffer_var, value, index, *args)
 
 
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.BufferStore")
 class BufferStore(Stmt):
     """Buffer store node.
 
@@ -160,7 +160,7 @@ class BufferStore(Stmt):
             _ffi_api.BufferStore, buffer, value, indices)
 
 
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.BufferRealize")
 class BufferRealize(Stmt):
     """Buffer realize node.
 
@@ -183,7 +183,7 @@ class BufferRealize(Stmt):
             _ffi_api.BufferRealize, buffer, bounds, condition, body)
 
 
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.ProducerStore")
 class ProducerStore(Stmt):
     """ProducerStore node.
 
@@ -203,7 +203,7 @@ class ProducerStore(Stmt):
             _ffi_api.ProducerStore, producer, value, indices)
 
 
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.Allocate")
 class Allocate(Stmt):
     """Allocate node.
 
@@ -235,7 +235,7 @@ class Allocate(Stmt):
             extents, condition, body)
 
 
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.AttrStmt")
 class AttrStmt(Stmt):
     """AttrStmt node.
 
@@ -258,7 +258,7 @@ class AttrStmt(Stmt):
             _ffi_api.AttrStmt, node, attr_key, value, body)
 
 
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.Free")
 class Free(Stmt):
     """Free node.
 
@@ -272,7 +272,7 @@ class Free(Stmt):
             _ffi_api.Free, buffer_var)
 
 
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.ProducerRealize")
 class ProducerRealize(Stmt):
     """ProducerRealize node.
 
@@ -299,7 +299,7 @@ class ProducerRealize(Stmt):
             _ffi_api.ProducerRealize, producer, bounds, condition, body)
 
 
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.SeqStmt")
 class SeqStmt(Stmt):
     """Sequence of statements.
 
@@ -319,7 +319,7 @@ class SeqStmt(Stmt):
         return len(self.seq)
 
 
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.IfThenElse")
 class IfThenElse(Stmt):
     """IfThenElse node.
 
@@ -339,7 +339,7 @@ class IfThenElse(Stmt):
             _ffi_api.IfThenElse, condition, then_case, else_case)
 
 
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.Evaluate")
 class Evaluate(Stmt):
     """Evaluate node.
 
@@ -353,7 +353,7 @@ class Evaluate(Stmt):
             _ffi_api.Evaluate, value)
 
 
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.Prefetch")
 class Prefetch(Stmt):
     """Prefetch node.
 
index 868845f..d1e24b9 100644 (file)
@@ -159,7 +159,7 @@ Stmt update_for(const Stmt& parent_for_stmt, const Stmt& new_if_stmt) {
     }
   });
 
-  return IRTransform(parent_for_stmt, nullptr, replace_target_for, Array<String>{"For"});
+  return IRTransform(parent_for_stmt, nullptr, replace_target_for, Array<String>{"tir.For"});
 }
 
 // Remove IfThenElse node from a For node.
@@ -183,9 +183,9 @@ std::pair<Stmt, Stmt> RemoveIf(const Stmt& for_stmt, const Stmt& if_stmt) {
     }
   });
 
-  then_for = IRTransform(for_stmt, nullptr, replace_then_case, Array<String>{"IfThenElse"});
+  then_for = IRTransform(for_stmt, nullptr, replace_then_case, Array<String>{"tir.IfThenElse"});
   if (if_stmt.as<IfThenElseNode>()->else_case.defined()) {
-    else_for = IRTransform(for_stmt, nullptr, replace_else_case, Array<String>{"IfThenElse"});
+    else_for = IRTransform(for_stmt, nullptr, replace_else_case, Array<String>{"tir.IfThenElse"});
   }
 
   return std::make_pair(then_for, else_for);
@@ -393,7 +393,7 @@ Stmt IfThenElseHoist::PostOrderMutate(const Stmt& stmt) {
       *ret = new_for;
     }
   });
-  return IRTransform(stmt, nullptr, replace_top_for, Array<String>{"For"});
+  return IRTransform(stmt, nullptr, replace_top_for, Array<String>{"tir.For"});
 }
 
 Stmt HoistIfThenElse(Stmt stmt) { return IfThenElseHoist().VisitAndMutate(stmt); }
index bafa957..1a7163f 100644 (file)
@@ -214,7 +214,7 @@ def test_cuda_shuffle():
 
         def _transform(f, *_):
             return f.with_body(
-                tvm.tir.stmt_functor.ir_transform(f.body, None, vectorizer, ['For']))
+                tvm.tir.stmt_functor.ir_transform(f.body, None, vectorizer, ['tir.For']))
         return tvm.tir.transform.prim_func_pass(_transform, opt_level=0, name="MyVectorize")
 
     with tvm.transform.PassContext(config={"tir.add_lower_pass": [(1, MyVectorize())]}):
index 34db08f..1173b71 100644 (file)
@@ -724,7 +724,7 @@ def test_llvm_shuffle():
 
         def _transform(f, *_):
             return f.with_body(
-                tvm.tir.stmt_functor.ir_transform(f.body, None, vectorizer, ['For']))
+                tvm.tir.stmt_functor.ir_transform(f.body, None, vectorizer, ['tir.For']))
 
         return tvm.tir.transform.prim_func_pass(_transform, opt_level=0, name="my_vectorize")
 
index 346239d..80e93a7 100644 (file)
@@ -33,12 +33,12 @@ def verify_structure(stmt, expected_struct):
         if isinstance(op, tvm.tir.IfThenElse):
             global var_list
             tvm.tir.stmt_functor.post_order_visit(op.condition, _extract_vars)
-            val = [(op.then_case, op.else_case), ("IfThenElse", tuple(var_list))]
+            val = [(op.then_case, op.else_case), ("tir.IfThenElse", tuple(var_list))]
             var_list.clear()
         elif isinstance(op, tvm.tir.For):
-            val = [(op.body,), ("For", op.loop_var.name)]
+            val = [(op.body,), ("tir.For", op.loop_var.name)]
         elif isinstance(op, tvm.tir.AttrStmt):
-            val = [(op.body,), ("AttrStmt", op.attr_key, int(op.value))]
+            val = [(op.body,), ("tir.AttrStmt", op.attr_key, int(op.value))]
         else:
             return
         node_dict[key] = val
@@ -68,9 +68,9 @@ def test_basic():
 
     stmt = ib.get()
     new_stmt = tvm.testing.HoistIfThenElse(stmt)
-    expected_struct = {('For', 'k'): (None,), ('For', 'j'): (('For', 'k'),),
-                       ('IfThenElse', ('i',)): (('For', 'j'), ('For', 'j')),
-                       ('For', 'i'): (('IfThenElse', ('i',)),)}
+    expected_struct = {('tir.For', 'k'): (None,), ('tir.For', 'j'): (('tir.For', 'k'),),
+                       ('tir.IfThenElse', ('i',)): (('tir.For', 'j'), ('tir.For', 'j')),
+                       ('tir.For', 'i'): (('tir.IfThenElse', ('i',)),)}
     verify_structure(new_stmt, expected_struct)
 
 def test_no_else():
@@ -87,9 +87,9 @@ def test_no_else():
 
     stmt = ib.get()
     new_stmt = tvm.testing.HoistIfThenElse(stmt)
-    expected_struct = {('For', 'k'): (None,), ('For', 'j'): (('For', 'k'),),
-                       ('IfThenElse', ('i',)): (('For', 'j'), None),
-                       ('For', 'i'): (('IfThenElse', ('i',)),)}
+    expected_struct = {('tir.For', 'k'): (None,), ('tir.For', 'j'): (('tir.For', 'k'),),
+                       ('tir.IfThenElse', ('i',)): (('tir.For', 'j'), None),
+                       ('tir.For', 'i'): (('tir.IfThenElse', ('i',)),)}
     verify_structure(new_stmt, expected_struct)
 
 def test_attr_stmt():
@@ -114,10 +114,10 @@ def test_attr_stmt():
 
     stmt = ib.get()
     new_stmt = tvm.testing.HoistIfThenElse(stmt)
-    expected_struct = {('For', 'k'): (None,), ('IfThenElse', ('i', 'j')): (('For', 'k'), ('For', 'k')),
-                       ('For', 'j'): (('IfThenElse', ('i', 'j')),), ('For', 'i'): (('For', 'j'),),
-                       ('AttrStmt', 'thread_extent', 64): (('For', 'i'),),
-                       ('AttrStmt', 'thread_extent', 32): (('AttrStmt', 'thread_extent', 64),)}
+    expected_struct = {('tir.For', 'k'): (None,), ('tir.IfThenElse', ('i', 'j')): (('tir.For', 'k'), ('tir.For', 'k')),
+                       ('tir.For', 'j'): (('tir.IfThenElse', ('i', 'j')),), ('tir.For', 'i'): (('tir.For', 'j'),),
+                       ('tir.AttrStmt', 'thread_extent', 64): (('tir.For', 'i'),),
+                       ('tir.AttrStmt', 'thread_extent', 32): (('tir.AttrStmt', 'thread_extent', 64),)}
     verify_structure(new_stmt, expected_struct)
 
 def test_nested_for():
@@ -138,9 +138,9 @@ def test_nested_for():
 
     stmt = ib.get()
     new_stmt = tvm.testing.HoistIfThenElse(stmt)
-    expected_struct = {('IfThenElse', ('i', 'j')): (None, None), ('For', 'l'): (('IfThenElse', ('i', 'j')),),
-                       ('For', 'k'): (('For', 'l'),), ('For', 'j'): (None,), ('IfThenElse', ('i',)): (('For', 'j'), None),
-                       ('For', 'i'): (('IfThenElse', ('i',)),)}
+    expected_struct = {('tir.IfThenElse', ('i', 'j')): (None, None), ('tir.For', 'l'): (('tir.IfThenElse', ('i', 'j')),),
+                       ('tir.For', 'k'): (('tir.For', 'l'),), ('tir.For', 'j'): (None,), ('tir.IfThenElse', ('i',)): (('tir.For', 'j'), None),
+                       ('tir.For', 'i'): (('tir.IfThenElse', ('i',)),)}
     verify_structure(new_stmt, expected_struct)
 
 def test_if_block():
@@ -171,10 +171,10 @@ def test_if_block():
 
     stmt = ib.get()
     new_stmt = tvm.testing.HoistIfThenElse(stmt)
-    expected_struct = {('IfThenElse', ('i', 'j')): (None, None), ('IfThenElse', ('j',)): (None, None),
-                       ('For', 'l'): (None,), ('For', 'k'): (None,), ('For', 'j'): (('For', 'j'),),
-                       ('IfThenElse', ('i',)): (('For', 'j'), None), ('For', 'i'): (('IfThenElse', ('i',)),),
-                       ('IfThenElse', ('n',)): (('For', 'j'), None)}
+    expected_struct = {('tir.IfThenElse', ('i', 'j')): (None, None), ('tir.IfThenElse', ('j',)): (None, None),
+                       ('tir.For', 'l'): (None,), ('tir.For', 'k'): (None,), ('tir.For', 'j'): (('tir.For', 'j'),),
+                       ('tir.IfThenElse', ('i',)): (('tir.For', 'j'), None), ('tir.For', 'i'): (('tir.IfThenElse', ('i',)),),
+                       ('tir.IfThenElse', ('n',)): (('tir.For', 'j'), None)}
     verify_structure(new_stmt, expected_struct)
 
 
index 7bf7011..38529e9 100644 (file)
@@ -37,7 +37,7 @@ def test_ir_transform():
         if op.name == "TestA":
             return tvm.tir.call_extern("int32", "TestB", op.args[0] + 1)
         return op
-    body = tvm.tir.stmt_functor.ir_transform(body, preorder, postorder, ["Call"])
+    body = tvm.tir.stmt_functor.ir_transform(body, preorder, postorder, ["tir.Call"])
     stmt_list = tvm.tir.stmt_list(body.body.body)
     assert stmt_list[0].value.args[0].name == "TestB"
     assert stmt_list[1].value.value == 0
index db50572..17f864f 100644 (file)
@@ -84,7 +84,7 @@ print(ir)
 
 loops = []
 def find_width8(op):
-    """ Find all the 'For' nodes whose extent can be divided by 8. """
+    """ Find all the 'tir.For' nodes whose extent can be divided by 8. """
     if isinstance(op, tvm.tir.For):
         if isinstance(op.extent, tvm.tir.IntImm):
             if op.extent.value % 8 == 0:
@@ -129,7 +129,7 @@ def vectorize(f, mod, ctx):
     # The last list arugment indicates what kinds of nodes will be transformed.
     # Thus, in this case only `For` nodes will call `vectorize8`
     return f.with_body(
-        tvm.tir.stmt_functor.ir_transform(f.body, None, vectorize8, ['For']))
+        tvm.tir.stmt_functor.ir_transform(f.body, None, vectorize8, ['tir.For']))
 
 
 #####################################################################
index 37b4e0e..207f784 100644 (file)
@@ -87,7 +87,7 @@ def FoldUopLoop():
             return op
 
         ret = tvm.tir.stmt_functor.ir_transform(
-            stmt.body, None, _post_order, ["Call"])
+            stmt.body, None, _post_order, ["tir.Call"])
 
         if not fail[0] and all(x is not None for x in gemm_offsets):
             def _visit(op):
@@ -132,7 +132,7 @@ def FoldUopLoop():
 
     def _ftransform(f, mod, ctx):
         return f.with_body(tvm.tir.stmt_functor.ir_transform(
-            f.body, _do_fold, None, ["AttrStmt"]))
+            f.body, _do_fold, None, ["tir.AttrStmt"]))
 
     return tvm.tir.transform.prim_func_pass(
         _ftransform, opt_level=0, name="tir.vta.FoldUopLoop")
@@ -188,7 +188,7 @@ def CPUAccessRewrite():
 
         stmt_in = f.body
         stmt = tvm.tir.stmt_functor.ir_transform(
-            stmt_in, None, _post_order, ["Allocate", "Load", "Store"])
+            stmt_in, None, _post_order, ["tir.Allocate", "tir.Load", "tir.Store"])
 
         for buffer_var, new_var in rw_info.items():
             stmt = tvm.tir.LetStmt(
@@ -254,7 +254,7 @@ def LiftAllocToScopeBegin():
             raise RuntimeError("not reached")
         stmt_in = f.body
         stmt = tvm.tir.stmt_functor.ir_transform(
-            stmt_in, _pre_order, _post_order, ["Allocate", "AttrStmt", "For"])
+            stmt_in, _pre_order, _post_order, ["tir.Allocate", "tir.AttrStmt", "tir.For"])
         assert len(lift_stmt) == 1
         return f.with_body(_merge_block(lift_stmt[0], stmt))
 
@@ -277,7 +277,7 @@ def InjectSkipCopy():
 
     def _ftransform(f, mod, ctx):
         return f.with_body(tvm.tir.stmt_functor.ir_transform(
-            f.body, _do_fold, None, ["AttrStmt"]))
+            f.body, _do_fold, None, ["tir.AttrStmt"]))
 
     return tvm.tir.transform.prim_func_pass(
         _ftransform, opt_level=0, name="tir.vta.InjectSkipCopy")
@@ -307,7 +307,7 @@ def InjectCoProcSync():
                     op.device_api, op.body)
             return None
         return f.with_body(tvm.tir.stmt_functor.ir_transform(
-            f.body, None, _do_fold, ["AttrStmt"]))
+            f.body, None, _do_fold, ["tir.AttrStmt"]))
     return tvm.transform.Sequential(
         [tvm.tir.transform.prim_func_pass(_ftransform, 0, "tir.vta.InjectCoProcSync"),
          tvm.tir.transform.CoProcSync()],
@@ -708,7 +708,7 @@ def InjectConv2DTransposeSkip():
             return None
 
         return func.with_body(tvm.tir.stmt_functor.ir_transform(
-            func.body, _do_fold, None, ["AttrStmt"]))
+            func.body, _do_fold, None, ["tir.AttrStmt"]))
     return tvm.tir.transform.prim_func_pass(
         _ftransform, opt_level=0, name="tir.vta.InjectConv2DTrasnposeSkip")
 
@@ -737,7 +737,7 @@ def AnnotateALUCoProcScope():
             return stmt
 
         return func.with_body(tvm.tir.stmt_functor.ir_transform(
-            func.body, None, _do_fold, ["AttrStmt"]))
+            func.body, None, _do_fold, ["tir.AttrStmt"]))
     return tvm.tir.transform.prim_func_pass(
         _ftransform, opt_level=0, name="tir.vta.AnnotateALUCoProcScope")
 
@@ -956,7 +956,7 @@ def InjectALUIntrin():
             return stmt
 
         return func.with_body(tvm.tir.stmt_functor.ir_transform(
-            func.body, None, _do_fold, ["AttrStmt"]))
+            func.body, None, _do_fold, ["tir.AttrStmt"]))
 
     return tvm.tir.transform.prim_func_pass(
         _ftransform, opt_level=0, name="tir.vta.InjectALUIntrin")