return shape.size() != 0 ? shape[0].dtype() : DataType::Int(32);
}
- static constexpr const char* _type_key = "Buffer";
+ static constexpr const char* _type_key = "tir.Buffer";
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
TVM_DECLARE_FINAL_OBJECT_INFO(BufferNode, Object);
void SHashReduce(SHashReducer hash_reduce) const {}
- static constexpr const char* _type_key = "DataProducer";
+ static constexpr const char* _type_key = "tir.DataProducer";
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
TVM_DECLARE_BASE_OBJECT_INFO(DataProducerNode, Object);
v->Visit("axes", &axes);
}
- static constexpr const char* _type_key = "Layout";
+ static constexpr const char* _type_key = "tir.Layout";
TVM_DECLARE_FINAL_OBJECT_INFO(LayoutNode, Object);
};
v->Visit("backward_rule", &backward_rule);
}
- static constexpr const char* _type_key = "BijectiveLayout";
+ static constexpr const char* _type_key = "tir.BijectiveLayout";
TVM_DECLARE_FINAL_OBJECT_INFO(BijectiveLayoutNode, Object);
};
void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(value); }
- static constexpr const char* _type_key = "StringImm";
+ static constexpr const char* _type_key = "tir.StringImm";
TVM_DECLARE_FINAL_OBJECT_INFO(StringImmNode, PrimExprNode);
};
hash_reduce(value);
}
- static constexpr const char* _type_key = "Cast";
+ static constexpr const char* _type_key = "tir.Cast";
TVM_DECLARE_FINAL_OBJECT_INFO(CastNode, PrimExprNode);
};
/*! \brief a + b */
class AddNode : public BinaryOpNode<AddNode> {
public:
- static constexpr const char* _type_key = "Add";
+ static constexpr const char* _type_key = "tir.Add";
};
/*!
/*! \brief a - b */
class SubNode : public BinaryOpNode<SubNode> {
public:
- static constexpr const char* _type_key = "Sub";
+ static constexpr const char* _type_key = "tir.Sub";
};
/*!
/*! \brief a * b */
class MulNode : public BinaryOpNode<MulNode> {
public:
- static constexpr const char* _type_key = "Mul";
+ static constexpr const char* _type_key = "tir.Mul";
};
/*!
*/
class DivNode : public BinaryOpNode<DivNode> {
public:
- static constexpr const char* _type_key = "Div";
+ static constexpr const char* _type_key = "tir.Div";
};
/*!
*/
class ModNode : public BinaryOpNode<ModNode> {
public:
- static constexpr const char* _type_key = "Mod";
+ static constexpr const char* _type_key = "tir.Mod";
};
/*!
/*! \brief Floor division, floor(a/b) */
class FloorDivNode : public BinaryOpNode<FloorDivNode> {
public:
- static constexpr const char* _type_key = "FloorDiv";
+ static constexpr const char* _type_key = "tir.FloorDiv";
};
/*!
/*! \brief The remainder of the floordiv */
class FloorModNode : public BinaryOpNode<FloorModNode> {
public:
- static constexpr const char* _type_key = "FloorMod";
+ static constexpr const char* _type_key = "tir.FloorMod";
};
/*!
/*! \brief min(a, b) */
class MinNode : public BinaryOpNode<MinNode> {
public:
- static constexpr const char* _type_key = "Min";
+ static constexpr const char* _type_key = "tir.Min";
};
/*!
/*! \brief max(a, b) */
class MaxNode : public BinaryOpNode<MaxNode> {
public:
- static constexpr const char* _type_key = "Max";
+ static constexpr const char* _type_key = "tir.Max";
};
/*!
/*! \brief a == b */
class EQNode : public CmpOpNode<EQNode> {
public:
- static constexpr const char* _type_key = "EQ";
+ static constexpr const char* _type_key = "tir.EQ";
};
/*!
/*! \brief a != b */
class NENode : public CmpOpNode<NENode> {
public:
- static constexpr const char* _type_key = "NE";
+ static constexpr const char* _type_key = "tir.NE";
};
/*!
/*! \brief a < b */
class LTNode : public CmpOpNode<LTNode> {
public:
- static constexpr const char* _type_key = "LT";
+ static constexpr const char* _type_key = "tir.LT";
};
/*!
/*! \brief a <= b */
struct LENode : public CmpOpNode<LENode> {
public:
- static constexpr const char* _type_key = "LE";
+ static constexpr const char* _type_key = "tir.LE";
};
/*!
/*! \brief a > b */
class GTNode : public CmpOpNode<GTNode> {
public:
- static constexpr const char* _type_key = "GT";
+ static constexpr const char* _type_key = "tir.GT";
};
/*!
/*! \brief a >= b */
class GENode : public CmpOpNode<GENode> {
public:
- static constexpr const char* _type_key = "GE";
+ static constexpr const char* _type_key = "tir.GE";
};
/*!
hash_reduce(b);
}
- static constexpr const char* _type_key = "And";
+ static constexpr const char* _type_key = "tir.And";
TVM_DECLARE_FINAL_OBJECT_INFO(AndNode, PrimExprNode);
};
hash_reduce(b);
}
- static constexpr const char* _type_key = "Or";
+ static constexpr const char* _type_key = "tir.Or";
TVM_DECLARE_FINAL_OBJECT_INFO(OrNode, PrimExprNode);
};
hash_reduce(a);
}
- static constexpr const char* _type_key = "Not";
+ static constexpr const char* _type_key = "tir.Not";
TVM_DECLARE_FINAL_OBJECT_INFO(NotNode, PrimExprNode);
};
hash_reduce(false_value);
}
- static constexpr const char* _type_key = "Select";
+ static constexpr const char* _type_key = "tir.Select";
TVM_DECLARE_FINAL_OBJECT_INFO(SelectNode, PrimExprNode);
};
hash_reduce(indices);
}
- static constexpr const char* _type_key = "BufferLoad";
+ static constexpr const char* _type_key = "tir.BufferLoad";
TVM_DECLARE_FINAL_OBJECT_INFO(BufferLoadNode, PrimExprNode);
};
hash_reduce(indices);
}
- static constexpr const char* _type_key = "ProducerLoad";
+ static constexpr const char* _type_key = "tir.ProducerLoad";
TVM_DECLARE_FINAL_OBJECT_INFO(ProducerLoadNode, PrimExprNode);
};
hash_reduce(predicate);
}
- static constexpr const char* _type_key = "Load";
+ static constexpr const char* _type_key = "tir.Load";
TVM_DECLARE_FINAL_OBJECT_INFO(LoadNode, PrimExprNode);
};
hash_reduce(lanes);
}
- static constexpr const char* _type_key = "Ramp";
+ static constexpr const char* _type_key = "tir.Ramp";
TVM_DECLARE_FINAL_OBJECT_INFO(RampNode, PrimExprNode);
};
hash_reduce(lanes);
}
- static constexpr const char* _type_key = "Broadcast";
+ static constexpr const char* _type_key = "tir.Broadcast";
TVM_DECLARE_FINAL_OBJECT_INFO(BroadcastNode, PrimExprNode);
};
hash_reduce(body);
}
- static constexpr const char* _type_key = "Let";
+ static constexpr const char* _type_key = "tir.Let";
TVM_DECLARE_FINAL_OBJECT_INFO(LetNode, PrimExprNode);
};
/*! \return Whether call node can be vectorized. */
bool is_vectorizable() const;
- static constexpr const char* _type_key = "Call";
+ static constexpr const char* _type_key = "tir.Call";
TVM_DECLARE_FINAL_OBJECT_INFO(CallNode, PrimExprNode);
// Build-in intrinsics
hash_reduce(indices);
}
- static constexpr const char* _type_key = "Shuffle";
+ static constexpr const char* _type_key = "tir.Shuffle";
TVM_DECLARE_FINAL_OBJECT_INFO(ShuffleNode, PrimExprNode);
};
hash_reduce(identity_element);
}
- static constexpr const char* _type_key = "CommReducer";
+ static constexpr const char* _type_key = "tir.CommReducer";
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
TVM_DECLARE_FINAL_OBJECT_INFO(CommReducerNode, Object);
hash_reduce(value_index);
}
- static constexpr const char* _type_key = "Reduce";
+ static constexpr const char* _type_key = "tir.Reduce";
TVM_DECLARE_FINAL_OBJECT_INFO(ReduceNode, PrimExprNode);
};
/*! \brief Convert to var. */
Var ToVar() const { return Var("any_dim", DataType::Int(32)); }
- static constexpr const char* _type_key = "Any";
+ static constexpr const char* _type_key = "tir.Any";
TVM_DECLARE_FINAL_OBJECT_INFO(AnyNode, PrimExprNode);
};
/*! \brief Base node of all statements. */
class StmtNode : public Object {
public:
- static constexpr const char* _type_key = "Stmt";
+ static constexpr const char* _type_key = "tir.Stmt";
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
static constexpr const uint32_t _type_child_slots = 15;
hash_reduce(body);
}
- static constexpr const char* _type_key = "LetStmt";
+ static constexpr const char* _type_key = "tir.LetStmt";
TVM_DECLARE_FINAL_OBJECT_INFO(LetStmtNode, StmtNode);
};
hash_reduce(body);
}
- static constexpr const char* _type_key = "AttrStmt";
+ static constexpr const char* _type_key = "tir.AttrStmt";
TVM_DECLARE_FINAL_OBJECT_INFO(AttrStmtNode, StmtNode);
};
hash_reduce(body);
}
- static constexpr const char* _type_key = "AssertStmt";
+ static constexpr const char* _type_key = "tir.AssertStmt";
TVM_DECLARE_FINAL_OBJECT_INFO(AssertStmtNode, StmtNode);
};
hash_reduce(predicate);
}
- static constexpr const char* _type_key = "Store";
+ static constexpr const char* _type_key = "tir.Store";
TVM_DECLARE_FINAL_OBJECT_INFO(StoreNode, StmtNode);
};
hash_reduce(indices);
}
- static constexpr const char* _type_key = "BufferStore";
+ static constexpr const char* _type_key = "tir.BufferStore";
TVM_DECLARE_FINAL_OBJECT_INFO(BufferStoreNode, StmtNode);
};
BufferRealizeNode(Buffer buffer, Array<Range> bounds, PrimExpr condition, Stmt body)
: buffer(buffer), bounds(bounds), condition(condition), body(body) {}
- static constexpr const char* _type_key = "BufferRealize";
+ static constexpr const char* _type_key = "tir.BufferRealize";
TVM_DECLARE_FINAL_OBJECT_INFO(BufferRealizeNode, StmtNode);
};
hash_reduce(indices);
}
- static constexpr const char* _type_key = "ProducerStore";
+ static constexpr const char* _type_key = "tir.ProducerStore";
TVM_DECLARE_FINAL_OBJECT_INFO(ProducerStoreNode, StmtNode);
};
hash_reduce(body);
}
- static constexpr const char* _type_key = "ProducerRealize";
+ static constexpr const char* _type_key = "tir.ProducerRealize";
TVM_DECLARE_FINAL_OBJECT_INFO(ProducerRealizeNode, StmtNode);
};
*/
TVM_DLL static int32_t constant_allocation_size(const Array<PrimExpr>& extents);
- static constexpr const char* _type_key = "Allocate";
+ static constexpr const char* _type_key = "tir.Allocate";
TVM_DECLARE_FINAL_OBJECT_INFO(AllocateNode, StmtNode);
};
void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(buffer_var); }
- static constexpr const char* _type_key = "Free";
+ static constexpr const char* _type_key = "tir.Free";
TVM_DECLARE_FINAL_OBJECT_INFO(FreeNode, StmtNode);
};
void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(seq); }
- static constexpr const char* _type_key = "SeqStmt";
+ static constexpr const char* _type_key = "tir.SeqStmt";
TVM_DECLARE_FINAL_OBJECT_INFO(SeqStmtNode, StmtNode);
};
hash_reduce(else_case);
}
- static constexpr const char* _type_key = "IfThenElse";
+ static constexpr const char* _type_key = "tir.IfThenElse";
TVM_DECLARE_FINAL_OBJECT_INFO(IfThenElseNode, StmtNode);
};
void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(value); }
- static constexpr const char* _type_key = "Evaluate";
+ static constexpr const char* _type_key = "tir.Evaluate";
TVM_DECLARE_FINAL_OBJECT_INFO(EvaluateNode, StmtNode);
};
hash_reduce(body);
}
- static constexpr const char* _type_key = "For";
+ static constexpr const char* _type_key = "tir.For";
TVM_DECLARE_FINAL_OBJECT_INFO(ForNode, StmtNode);
};
PrefetchNode() = default;
PrefetchNode(Buffer buffer, Array<Range> bounds) : buffer(buffer), bounds(bounds) {}
- static constexpr const char* _type_key = "Prefetch";
+ static constexpr const char* _type_key = "tir.Prefetch";
TVM_DECLARE_FINAL_OBJECT_INFO(PrefetchNode, StmtNode);
};
hash_reduce(thread_tag);
}
- static constexpr const char* _type_key = "IterVar";
+ static constexpr const char* _type_key = "tir.IterVar";
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
TVM_DECLARE_FINAL_OBJECT_INFO(IterVarNode, Object);
# TIR
"Variable": [_update_tir_var("tir.Var"), _update_from_std_str("name")],
"SizeVar": [_update_tir_var("tir.SizeVar"), _update_from_std_str("name")],
- "StringImm": [_update_from_std_str("value")],
- "Call": [_update_from_std_str("name")],
- "AttrStmt": [_update_from_std_str("attr_key")],
- "Layout": [_update_from_std_str("name")],
- "Buffer": [_update_from_std_str("name"), _update_from_std_str("scope")],
+ "StringImm": [_rename("tir.StringImm"), _update_from_std_str("value")],
+ "Cast": [_rename("tir.Cast")],
+ "Add": [_rename("tir.Add")],
+ "Sub": [_rename("tir.Sub")],
+ "Mul": [_rename("tir.Mul")],
+ "Div": [_rename("tir.Div")],
+ "Mod": [_rename("tir.Mod")],
+ "FloorDiv": [_rename("tir.FloorDiv")],
+ "FloorMod": [_rename("tir.FloorMod")],
+ "Min": [_rename("tir.Min")],
+ "Max": [_rename("tir.Max")],
+ "EQ": [_rename("tir.EQ")],
+ "NE": [_rename("tir.NE")],
+ "LT": [_rename("tir.LT")],
+ "LE": [_rename("tir.LE")],
+ "GT": [_rename("tir.GT")],
+ "GE": [_rename("tir.GE")],
+ "And": [_rename("tir.And")],
+ "Or": [_rename("tir.Or")],
+ "Not": [_rename("tir.Not")],
+ "Select": [_rename("tir.Select")],
+ "Load": [_rename("tir.Load")],
+ "BufferLoad": [_rename("tir.BufferLoad")],
+ "Ramp": [_rename("tir.Ramp")],
+ "Broadcast": [_rename("tir.Broadcast")],
+ "Shuffle": [_rename("tir.Shuffle")],
+ "Call": [_rename("tir.Call"), _update_from_std_str("name")],
+ "Let": [_rename("tir.Let")],
+ "Any": [_rename("tir.Any")],
+ "LetStmt": [_rename("tir.LetStmt")],
+ "AssertStmt": [_rename("tir.AssertStmt")],
+ "Store": [_rename("tir.Store")],
+ "BufferStore": [_rename("tir.BufferStore")],
+ "BufferRealize": [_rename("tir.BufferRealize")],
+ "Allocate": [_rename("tir.Allocate")],
+ "IfThenElse": [_rename("tir.IfThenElse")],
+ "Evaluate": [_rename("tir.Evaluate")],
+ "Prefetch": [_rename("tir.Prefetch")],
+ "AttrStmt": [_rename("tir.AttrStmt"), _update_from_std_str("attr_key")],
+ "Layout": [_rename("tir.Layout"), _update_from_std_str("name")],
+ "Buffer": [
+ _rename("tir.Buffer"), _update_from_std_str("name"), _update_from_std_str("scope")],
}
return create_updater(node_map, "0.6", "0.7")
return _expr.ProducerLoad(buf, op.indices)
return None
- return stmt_functor.ir_transform(body, None, replace, ['ProducerStore', 'ProducerLoad'])
+ return stmt_functor.ir_transform(body, None, replace, ['tir.ProducerStore', 'tir.ProducerLoad'])
def _is_tvm_arg_types(args):
from . import _ffi_api
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.Buffer")
class Buffer(Object):
"""Symbolic data buffer in TVM.
data_alignment, offset_factor, buffer_type)
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.DataProducer")
class DataProducer(Object):
pass
from tvm.runtime import Object
from . import _ffi_api
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.Layout")
class Layout(Object):
"""Layout is composed of upper cases, lower cases and numbers,
where upper case indicates a primal axis and
return _ffi_api.LayoutFactorOf(self, axis)
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.BijectiveLayout")
class BijectiveLayout(Object):
"""Bijective mapping for two layouts (src-layout and dst-layout).
It provides shape and index conversion between each other.
_ffi_api.SizeVar, name, dtype)
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.IterVar")
class IterVar(Object, ExprOp):
"""Represent iteration variable.
_ffi_api.IterVar, dom, var, iter_type, thread_tag)
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.CommReducer")
class CommReducer(Object):
"""Communicative reduce operator
_ffi_api.CommReducer, lhs, rhs, result, identity_element)
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.Reduce")
class Reduce(PrimExprWithOp):
"""Reduce node.
return self.__nonzero__()
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.StringImm")
class StringImm(ConstExpr):
"""String constant.
return self.value != other
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.Cast")
class Cast(PrimExprWithOp):
"""Cast expression.
_ffi_api.Cast, dtype, value)
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.Add")
class Add(BinaryOpExpr):
"""Add node.
_ffi_api.Add, a, b)
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.Sub")
class Sub(BinaryOpExpr):
"""Sub node.
_ffi_api.Sub, a, b)
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.Mul")
class Mul(BinaryOpExpr):
"""Mul node.
_ffi_api.Mul, a, b)
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.Div")
class Div(BinaryOpExpr):
"""Div node.
_ffi_api.Div, a, b)
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.Mod")
class Mod(BinaryOpExpr):
"""Mod node.
_ffi_api.Mod, a, b)
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.FloorDiv")
class FloorDiv(BinaryOpExpr):
"""FloorDiv node.
_ffi_api.FloorDiv, a, b)
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.FloorMod")
class FloorMod(BinaryOpExpr):
"""FloorMod node.
_ffi_api.FloorMod, a, b)
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.Min")
class Min(BinaryOpExpr):
"""Min node.
_ffi_api.Min, a, b)
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.Max")
class Max(BinaryOpExpr):
"""Max node.
_ffi_api.Max, a, b)
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.EQ")
class EQ(CmpExpr):
"""EQ node.
_ffi_api.EQ, a, b)
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.NE")
class NE(CmpExpr):
"""NE node.
_ffi_api.NE, a, b)
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.LT")
class LT(CmpExpr):
"""LT node.
_ffi_api.LT, a, b)
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.LE")
class LE(CmpExpr):
"""LE node.
_ffi_api.LE, a, b)
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.GT")
class GT(CmpExpr):
"""GT node.
_ffi_api.GT, a, b)
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.GE")
class GE(CmpExpr):
"""GE node.
_ffi_api.GE, a, b)
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.And")
class And(LogicalExpr):
"""And node.
_ffi_api.And, a, b)
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.Or")
class Or(LogicalExpr):
"""Or node.
_ffi_api.Or, a, b)
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.Not")
class Not(LogicalExpr):
"""Not node.
_ffi_api.Not, a)
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.Select")
class Select(PrimExprWithOp):
"""Select node.
_ffi_api.Select, condition, true_value, false_value)
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.Load")
class Load(PrimExprWithOp):
"""Load node.
_ffi_api.Load, dtype, buffer_var, index, *args)
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.BufferLoad")
class BufferLoad(PrimExprWithOp):
"""Buffer load node.
_ffi_api.BufferLoad, buffer, indices)
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.ProducerLoad")
class ProducerLoad(PrimExprWithOp):
"""Producer load node.
_ffi_api.ProducerLoad, producer, indices)
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.Ramp")
class Ramp(PrimExprWithOp):
"""Ramp node.
_ffi_api.Ramp, base, stride, lanes)
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.Broadcast")
class Broadcast(PrimExprWithOp):
"""Broadcast node.
_ffi_api.Broadcast, value, lanes)
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.Shuffle")
class Shuffle(PrimExprWithOp):
"""Shuffle node.
_ffi_api.Shuffle, vectors, indices)
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.Call")
class Call(PrimExprWithOp):
"""Call node.
_ffi_api.Call, dtype, name, args, call_type)
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.Let")
class Let(PrimExprWithOp):
"""Let node.
_ffi_api.Let, var, value, body)
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.Any")
class Any(PrimExpr):
"""Any node.
"""
"""Base class of all the statements."""
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.LetStmt")
class LetStmt(Stmt):
"""LetStmt node.
_ffi_api.LetStmt, var, value, body)
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.AssertStmt")
class AssertStmt(Stmt):
"""AssertStmt node.
_ffi_api.AssertStmt, condition, message, body)
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.For")
class For(Stmt):
"""For node.
for_type, device_api, body)
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.Store")
class Store(Stmt):
"""Store node.
_ffi_api.Store, buffer_var, value, index, *args)
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.BufferStore")
class BufferStore(Stmt):
"""Buffer store node.
_ffi_api.BufferStore, buffer, value, indices)
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.BufferRealize")
class BufferRealize(Stmt):
"""Buffer realize node.
_ffi_api.BufferRealize, buffer, bounds, condition, body)
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.ProducerStore")
class ProducerStore(Stmt):
"""ProducerStore node.
_ffi_api.ProducerStore, producer, value, indices)
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.Allocate")
class Allocate(Stmt):
"""Allocate node.
extents, condition, body)
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.AttrStmt")
class AttrStmt(Stmt):
"""AttrStmt node.
_ffi_api.AttrStmt, node, attr_key, value, body)
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.Free")
class Free(Stmt):
"""Free node.
_ffi_api.Free, buffer_var)
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.ProducerRealize")
class ProducerRealize(Stmt):
"""ProducerRealize node.
_ffi_api.ProducerRealize, producer, bounds, condition, body)
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.SeqStmt")
class SeqStmt(Stmt):
"""Sequence of statements.
return len(self.seq)
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.IfThenElse")
class IfThenElse(Stmt):
"""IfThenElse node.
_ffi_api.IfThenElse, condition, then_case, else_case)
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.Evaluate")
class Evaluate(Stmt):
"""Evaluate node.
_ffi_api.Evaluate, value)
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.Prefetch")
class Prefetch(Stmt):
"""Prefetch node.
}
});
- return IRTransform(parent_for_stmt, nullptr, replace_target_for, Array<String>{"For"});
+ return IRTransform(parent_for_stmt, nullptr, replace_target_for, Array<String>{"tir.For"});
}
// Remove IfThenElse node from a For node.
}
});
- then_for = IRTransform(for_stmt, nullptr, replace_then_case, Array<String>{"IfThenElse"});
+ then_for = IRTransform(for_stmt, nullptr, replace_then_case, Array<String>{"tir.IfThenElse"});
if (if_stmt.as<IfThenElseNode>()->else_case.defined()) {
- else_for = IRTransform(for_stmt, nullptr, replace_else_case, Array<String>{"IfThenElse"});
+ else_for = IRTransform(for_stmt, nullptr, replace_else_case, Array<String>{"tir.IfThenElse"});
}
return std::make_pair(then_for, else_for);
*ret = new_for;
}
});
- return IRTransform(stmt, nullptr, replace_top_for, Array<String>{"For"});
+ return IRTransform(stmt, nullptr, replace_top_for, Array<String>{"tir.For"});
}
Stmt HoistIfThenElse(Stmt stmt) { return IfThenElseHoist().VisitAndMutate(stmt); }
def _transform(f, *_):
return f.with_body(
- tvm.tir.stmt_functor.ir_transform(f.body, None, vectorizer, ['For']))
+ tvm.tir.stmt_functor.ir_transform(f.body, None, vectorizer, ['tir.For']))
return tvm.tir.transform.prim_func_pass(_transform, opt_level=0, name="MyVectorize")
with tvm.transform.PassContext(config={"tir.add_lower_pass": [(1, MyVectorize())]}):
def _transform(f, *_):
return f.with_body(
- tvm.tir.stmt_functor.ir_transform(f.body, None, vectorizer, ['For']))
+ tvm.tir.stmt_functor.ir_transform(f.body, None, vectorizer, ['tir.For']))
return tvm.tir.transform.prim_func_pass(_transform, opt_level=0, name="my_vectorize")
if isinstance(op, tvm.tir.IfThenElse):
global var_list
tvm.tir.stmt_functor.post_order_visit(op.condition, _extract_vars)
- val = [(op.then_case, op.else_case), ("IfThenElse", tuple(var_list))]
+ val = [(op.then_case, op.else_case), ("tir.IfThenElse", tuple(var_list))]
var_list.clear()
elif isinstance(op, tvm.tir.For):
- val = [(op.body,), ("For", op.loop_var.name)]
+ val = [(op.body,), ("tir.For", op.loop_var.name)]
elif isinstance(op, tvm.tir.AttrStmt):
- val = [(op.body,), ("AttrStmt", op.attr_key, int(op.value))]
+ val = [(op.body,), ("tir.AttrStmt", op.attr_key, int(op.value))]
else:
return
node_dict[key] = val
stmt = ib.get()
new_stmt = tvm.testing.HoistIfThenElse(stmt)
- expected_struct = {('For', 'k'): (None,), ('For', 'j'): (('For', 'k'),),
- ('IfThenElse', ('i',)): (('For', 'j'), ('For', 'j')),
- ('For', 'i'): (('IfThenElse', ('i',)),)}
+ expected_struct = {('tir.For', 'k'): (None,), ('tir.For', 'j'): (('tir.For', 'k'),),
+ ('tir.IfThenElse', ('i',)): (('tir.For', 'j'), ('tir.For', 'j')),
+ ('tir.For', 'i'): (('tir.IfThenElse', ('i',)),)}
verify_structure(new_stmt, expected_struct)
def test_no_else():
stmt = ib.get()
new_stmt = tvm.testing.HoistIfThenElse(stmt)
- expected_struct = {('For', 'k'): (None,), ('For', 'j'): (('For', 'k'),),
- ('IfThenElse', ('i',)): (('For', 'j'), None),
- ('For', 'i'): (('IfThenElse', ('i',)),)}
+ expected_struct = {('tir.For', 'k'): (None,), ('tir.For', 'j'): (('tir.For', 'k'),),
+ ('tir.IfThenElse', ('i',)): (('tir.For', 'j'), None),
+ ('tir.For', 'i'): (('tir.IfThenElse', ('i',)),)}
verify_structure(new_stmt, expected_struct)
def test_attr_stmt():
stmt = ib.get()
new_stmt = tvm.testing.HoistIfThenElse(stmt)
- expected_struct = {('For', 'k'): (None,), ('IfThenElse', ('i', 'j')): (('For', 'k'), ('For', 'k')),
- ('For', 'j'): (('IfThenElse', ('i', 'j')),), ('For', 'i'): (('For', 'j'),),
- ('AttrStmt', 'thread_extent', 64): (('For', 'i'),),
- ('AttrStmt', 'thread_extent', 32): (('AttrStmt', 'thread_extent', 64),)}
+ expected_struct = {('tir.For', 'k'): (None,), ('tir.IfThenElse', ('i', 'j')): (('tir.For', 'k'), ('tir.For', 'k')),
+ ('tir.For', 'j'): (('tir.IfThenElse', ('i', 'j')),), ('tir.For', 'i'): (('tir.For', 'j'),),
+ ('tir.AttrStmt', 'thread_extent', 64): (('tir.For', 'i'),),
+ ('tir.AttrStmt', 'thread_extent', 32): (('tir.AttrStmt', 'thread_extent', 64),)}
verify_structure(new_stmt, expected_struct)
def test_nested_for():
stmt = ib.get()
new_stmt = tvm.testing.HoistIfThenElse(stmt)
- expected_struct = {('IfThenElse', ('i', 'j')): (None, None), ('For', 'l'): (('IfThenElse', ('i', 'j')),),
- ('For', 'k'): (('For', 'l'),), ('For', 'j'): (None,), ('IfThenElse', ('i',)): (('For', 'j'), None),
- ('For', 'i'): (('IfThenElse', ('i',)),)}
+ expected_struct = {('tir.IfThenElse', ('i', 'j')): (None, None), ('tir.For', 'l'): (('tir.IfThenElse', ('i', 'j')),),
+ ('tir.For', 'k'): (('tir.For', 'l'),), ('tir.For', 'j'): (None,), ('tir.IfThenElse', ('i',)): (('tir.For', 'j'), None),
+ ('tir.For', 'i'): (('tir.IfThenElse', ('i',)),)}
verify_structure(new_stmt, expected_struct)
def test_if_block():
stmt = ib.get()
new_stmt = tvm.testing.HoistIfThenElse(stmt)
- expected_struct = {('IfThenElse', ('i', 'j')): (None, None), ('IfThenElse', ('j',)): (None, None),
- ('For', 'l'): (None,), ('For', 'k'): (None,), ('For', 'j'): (('For', 'j'),),
- ('IfThenElse', ('i',)): (('For', 'j'), None), ('For', 'i'): (('IfThenElse', ('i',)),),
- ('IfThenElse', ('n',)): (('For', 'j'), None)}
+ expected_struct = {('tir.IfThenElse', ('i', 'j')): (None, None), ('tir.IfThenElse', ('j',)): (None, None),
+ ('tir.For', 'l'): (None,), ('tir.For', 'k'): (None,), ('tir.For', 'j'): (('tir.For', 'j'),),
+ ('tir.IfThenElse', ('i',)): (('tir.For', 'j'), None), ('tir.For', 'i'): (('tir.IfThenElse', ('i',)),),
+ ('tir.IfThenElse', ('n',)): (('tir.For', 'j'), None)}
verify_structure(new_stmt, expected_struct)
if op.name == "TestA":
return tvm.tir.call_extern("int32", "TestB", op.args[0] + 1)
return op
- body = tvm.tir.stmt_functor.ir_transform(body, preorder, postorder, ["Call"])
+ body = tvm.tir.stmt_functor.ir_transform(body, preorder, postorder, ["tir.Call"])
stmt_list = tvm.tir.stmt_list(body.body.body)
assert stmt_list[0].value.args[0].name == "TestB"
assert stmt_list[1].value.value == 0
loops = []
def find_width8(op):
- """ Find all the 'For' nodes whose extent can be divided by 8. """
+ """ Find all the 'tir.For' nodes whose extent can be divided by 8. """
if isinstance(op, tvm.tir.For):
if isinstance(op.extent, tvm.tir.IntImm):
if op.extent.value % 8 == 0:
# The last list arugment indicates what kinds of nodes will be transformed.
# Thus, in this case only `For` nodes will call `vectorize8`
return f.with_body(
- tvm.tir.stmt_functor.ir_transform(f.body, None, vectorize8, ['For']))
+ tvm.tir.stmt_functor.ir_transform(f.body, None, vectorize8, ['tir.For']))
#####################################################################
return op
ret = tvm.tir.stmt_functor.ir_transform(
- stmt.body, None, _post_order, ["Call"])
+ stmt.body, None, _post_order, ["tir.Call"])
if not fail[0] and all(x is not None for x in gemm_offsets):
def _visit(op):
def _ftransform(f, mod, ctx):
return f.with_body(tvm.tir.stmt_functor.ir_transform(
- f.body, _do_fold, None, ["AttrStmt"]))
+ f.body, _do_fold, None, ["tir.AttrStmt"]))
return tvm.tir.transform.prim_func_pass(
_ftransform, opt_level=0, name="tir.vta.FoldUopLoop")
stmt_in = f.body
stmt = tvm.tir.stmt_functor.ir_transform(
- stmt_in, None, _post_order, ["Allocate", "Load", "Store"])
+ stmt_in, None, _post_order, ["tir.Allocate", "tir.Load", "tir.Store"])
for buffer_var, new_var in rw_info.items():
stmt = tvm.tir.LetStmt(
raise RuntimeError("not reached")
stmt_in = f.body
stmt = tvm.tir.stmt_functor.ir_transform(
- stmt_in, _pre_order, _post_order, ["Allocate", "AttrStmt", "For"])
+ stmt_in, _pre_order, _post_order, ["tir.Allocate", "tir.AttrStmt", "tir.For"])
assert len(lift_stmt) == 1
return f.with_body(_merge_block(lift_stmt[0], stmt))
def _ftransform(f, mod, ctx):
return f.with_body(tvm.tir.stmt_functor.ir_transform(
- f.body, _do_fold, None, ["AttrStmt"]))
+ f.body, _do_fold, None, ["tir.AttrStmt"]))
return tvm.tir.transform.prim_func_pass(
_ftransform, opt_level=0, name="tir.vta.InjectSkipCopy")
op.device_api, op.body)
return None
return f.with_body(tvm.tir.stmt_functor.ir_transform(
- f.body, None, _do_fold, ["AttrStmt"]))
+ f.body, None, _do_fold, ["tir.AttrStmt"]))
return tvm.transform.Sequential(
[tvm.tir.transform.prim_func_pass(_ftransform, 0, "tir.vta.InjectCoProcSync"),
tvm.tir.transform.CoProcSync()],
return None
return func.with_body(tvm.tir.stmt_functor.ir_transform(
- func.body, _do_fold, None, ["AttrStmt"]))
+ func.body, _do_fold, None, ["tir.AttrStmt"]))
return tvm.tir.transform.prim_func_pass(
_ftransform, opt_level=0, name="tir.vta.InjectConv2DTrasnposeSkip")
return stmt
return func.with_body(tvm.tir.stmt_functor.ir_transform(
- func.body, None, _do_fold, ["AttrStmt"]))
+ func.body, None, _do_fold, ["tir.AttrStmt"]))
return tvm.tir.transform.prim_func_pass(
_ftransform, opt_level=0, name="tir.vta.AnnotateALUCoProcScope")
return stmt
return func.with_body(tvm.tir.stmt_functor.ir_transform(
- func.body, None, _do_fold, ["AttrStmt"]))
+ func.body, None, _do_fold, ["tir.AttrStmt"]))
return tvm.tir.transform.prim_func_pass(
_ftransform, opt_level=0, name="tir.vta.InjectALUIntrin")