This is a follow-up refactor for tir::Call.
Now that we have switched call->name to call->op, the function effect property
can be registered through the op itself, so we no longer need the call_type in the CallNode.
- Introduce CallEffectKind to provide a more fine-grained categorization of calls.
- Introduce call_pure_extern and call_llvm_pure_intrin to
allow us to indicate pure calls in those cases.
- Migrate existing use cases to the new API.
TVM_DLL const Op& call_extern();
/*!
+ * \brief Call a pure extern C function with the given name
+ * and signature from the types of args in the runtime environment.
+ *
+ * Type call_pure_extern(name, args...) {
+ * return dlsym(name)(args...);
+ * }
+ *
+ * \note This intrinsic does not provide any type checking,
+ * and is mainly used for backward compatibility reasons.
+ * Always consider using a pre-registered and typed tvm::Op first.
+ */
+TVM_DLL const Op& call_pure_extern();
+
+/*!
* \brief Call an LLVM intrinsic with a given intrinsic id
* and signature from the types of args in the runtime environment.
*
- * Type call_llvm_intrin(intrin_id, args...) {
+ * Type call_llvm_intrin(intrin_id, args...) {
* return dlsym(name)(args...);
* }
*
TVM_DLL const Op& call_llvm_intrin();
/*!
- * \brief Call an SPIRV GLSL450 intrinsic.
+ * \brief Call an LLVM pure intrinsic with a given intrinsic id
+ * and signature from the types of args in the runtime environment.
+ *
+ * Type call_llvm_pure_intrin(intrin_id, args...) {
+ * return dlsym(name)(args...);
+ * }
+ *
+ * \note This op does not provide any type checking.
+ */
+TVM_DLL const Op& call_llvm_pure_intrin();
+
+/*!
+ * \brief Call an SPIRV pure GLSL450 intrinsic.
*
- * Type call_spirv_glsl450(intrin_id, args...) {
+ * Type call_spirv_pure_glsl450(intrin_id, args...) {
* return dlsym(name)(args...);
* }
*
* \note This op does not provide any type checking.
*/
-TVM_DLL const Op& call_spirv_glsl450();
+TVM_DLL const Op& call_spirv_pure_glsl450();
// TODO(tvm-team) revisit the builtins below
// some of them can simply become ops with special codegen attr.
*/
class CallNode : public PrimExprNode {
public:
- /*! \brief Possible types of calls. */
- enum CallType : int {
- /*! \brief Extern "C" function. */
- Extern = 0,
- /*! \brief Extern CXX function. */
- ExternCPlusPlus = 1,
- /*! \brief Extern "C" without side-effect. */
- PureExtern = 2,
- /*! \brief Intrinsic functions. */
- Intrinsic = 4,
- /*! \brief Intrinsic functions that are pure. */
- PureIntrinsic = 5
- };
/*!
* \brief The operator(function) being invoked
*
/*! \brief The arguments. */
Array<PrimExpr> args;
- /*! \brief Type of calls. */
- CallType call_type;
-
void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &dtype);
v->Visit("op", &op);
v->Visit("args", &args);
- v->Visit("call_type", &call_type);
}
bool SEqualReduce(const CallNode* other, SEqualReducer equal) const {
- return equal(dtype, other->dtype) && equal(op, other->op) && equal(args, other->args) &&
- equal(call_type, other->call_type);
+ return equal(dtype, other->dtype) && equal(op, other->op) && equal(args, other->args);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(dtype);
hash_reduce(op);
hash_reduce(args);
- hash_reduce(call_type);
}
- /*! \return Whether call node is pure. */
- bool is_pure() const { return (call_type == PureExtern || call_type == PureIntrinsic); }
-
static constexpr const char* _type_key = "tir.Call";
TVM_DECLARE_FINAL_OBJECT_INFO(CallNode, PrimExprNode);
};
*/
class Call : public PrimExpr {
public:
- using CallType = CallNode::CallType;
-
- TVM_DLL Call(DataType dtype, RelayExpr op, Array<PrimExpr> args, CallType call_type);
+ TVM_DLL Call(DataType dtype, RelayExpr op, Array<PrimExpr> args);
TVM_DEFINE_OBJECT_REF_METHODS(Call, PrimExpr, CallNode);
};
TVM_DLL PrimExpr LargeUIntImm(DataType dtype, int64_t low, int64_t high);
// Intrinsic operators
-#define TVM_DECLARE_INTRIN_UNARY(OpName) \
- inline PrimExpr OpName(PrimExpr x) { \
- static const Op& op = Op::Get("tir." #OpName); \
- return tir::Call(x.dtype(), op, {x}, tir::CallNode::PureIntrinsic); \
+#define TVM_DECLARE_INTRIN_UNARY(OpName) \
+ inline PrimExpr OpName(PrimExpr x) { \
+ static const Op& op = Op::Get("tir." #OpName); \
+ return tir::Call(x.dtype(), op, {x}); \
}
TVM_DECLARE_INTRIN_UNARY(exp);
TVM_DECLARE_INTRIN_UNARY(asinh);
TVM_DECLARE_INTRIN_UNARY(atanh);
-#define TVM_DECLARE_INTRIN_BINARY(OpName) \
- inline PrimExpr OpName(PrimExpr x, PrimExpr y) { \
- static const Op& op = Op::Get("tir." #OpName); \
- return tir::Call(x.dtype(), op, {x, y}, tir::CallNode::PureIntrinsic); \
+#define TVM_DECLARE_INTRIN_BINARY(OpName) \
+ inline PrimExpr OpName(PrimExpr x, PrimExpr y) { \
+ static const Op& op = Op::Get("tir." #OpName); \
+ return tir::Call(x.dtype(), op, {x, y}); \
}
TVM_DECLARE_INTRIN_BINARY(atan2);
*/
using TVectorizable = bool;
+/*!
+ * \brief The effect type of the call.
+ */
+enum class CallEffectKind : int {
+ /*! \brief Function that corresponds to an annotation (e.g. likely) and can translate to identity. */
+ kExprAnnotation = 0,
+ /*!
+ * \brief Pure function that does not interact
+ * with any external state.
+ */
+ kPure = 1,
+ /*!
+ * \brief Function that may read from states (e.g. RAM).
+ */
+ kReadState = 2,
+ /*!
+ * \brief Function that may read/write from states(e.g. RAM).
+ */
+ kUpdateState = 3,
+ /*!
+ * \brief Opaque function; no assumptions can be made about its effects.
+ */
+ kOpaque = kUpdateState,
+ /*!
+ * \brief Special intrinsic to annotate call arguments info
+ * only valid as a direct argument to a call.
+ */
+ kSpecialCallArg = 4,
+ /*!
+ * \brief Embed opaque information in the Expr; code cannot be generated for it.
+ */
+ kEmbedInfo = 5
+};
+
+/*! \brief Use integer to record the kind. */
+using TCallEffectKind = Integer;
+
} // namespace tir
} // namespace tvm
#endif // TVM_TIR_OP_ATTR_TYPES_H_
from tvm.ir.container import Array
from tvm import target as _tgt
from tvm.tir import expr as _expr
-from tvm.tir import call_pure_intrin
+from tvm.tir import call_intrin
from tvm.tir.stmt import For
from .util import _internal_assert
_internal_assert(args.__len__() == 1, \
"Only one expression can be likely")
_internal_assert(func_id == "likely", "This function cannot be directly invoked!")
- return call_pure_intrin(args[0].dtype, 'tir.likely', *args)
+ return call_intrin(args[0].dtype, 'tir.likely', *args)
def max_num_threads(func_id, args):
from .expr import Var, SizeVar, Reduce, FloatImm, IntImm, StringImm, Cast
from .expr import Add, Sub, Mul, Div, Mod, FloorDiv, FloorMod
from .expr import Min, Max, EQ, NE, LT, LE, GT, GE, And, Or, Not
-from .expr import Select, BufferLoad, ProducerLoad, Load, Ramp, Broadcast, Shuffle, Call, Let
-from .expr import IterVar, Any
+from .expr import Select, BufferLoad, ProducerLoad, Load, Ramp, Broadcast, Shuffle
+from .expr import Call, CallEffectKind, Let, IterVar, Any
from .stmt import Stmt, LetStmt, AssertStmt, For
from .stmt import BufferStore, BufferRealize, Store, ProducerStore, Allocate, AttrStmt
from .function import PrimFunc
-from .op import call_packed, call_pure_intrin, call_intrin, call_pure_extern, call_extern
-from .op import call_llvm_intrin, all, any, min_value, max_value, trace
+from .op import call_packed, call_intrin, call_pure_extern, call_extern
+from .op import call_llvm_intrin, call_llvm_pure_intrin, all, any, min_value, max_value, trace
from .op import exp, exp2, exp10, log, log2, log10, log1p, ldexp
from .op import sin, sinh, asin, asinh
from .op import cos, cosh, acos, acosh
return _ffi_api._OpNE(self.a, self.b)
+class IntImmEnum(ObjectGeneric):
+ """Lazily evaluate an IntImm in case
+ the constructor is not available at runtime.
+
+ Parameters
+ ----------
+ value : int
+ The enum value
+ """
+ def __init__(self, value):
+ self.value = value
+
+ def asobject(self):
+ """Convert object."""
+ return IntImm("int32", self.value)
+
+
class PrimExprWithOp(ExprOp, PrimExpr):
"""Helper base class to inherit from PrimExpr."""
# In Python3, we have to explicitly tell the interpreter to retain __hash__ if we override __eq__
_ffi_api.Shuffle, vectors, indices)
+class CallEffectKind:
+ """Possible kinds of Call effects."""
+ # only expose up to opaque
+ ExprAnnotation = IntImmEnum(0)
+ Pure = IntImmEnum(1)
+ ReadState = IntImmEnum(2)
+ UpdateState = IntImmEnum(3)
+ Opaque = UpdateState
+
+
@tvm._ffi.register_object("tir.Call")
class Call(PrimExprWithOp):
"""Call node.
args : list of Expr
The input arguments to the call
-
- call_type : int
- The type of the call
"""
- Extern = 0
- ExternCPlusPlus = 1
- PureExtern = 2
- Intrinsic = 4
- PureIntrinsic = 5
- def __init__(self, dtype, op, args, call_type):
+ def __init__(self, dtype, op, args):
if isinstance(op, str):
if not op.startswith("tir."):
raise ValueError(
"certain about the intrinsic name, pass in Op.get(name) instead") % op)
op = Op.get(op)
self.__init_handle_by_constructor__(
- _ffi_api.Call, dtype, op, args, call_type)
+ _ffi_api.Call, dtype, op, args)
@tvm._ffi.register_object("tir.Let")
expr : Expr
The expression will likely tag.
"""
- return _expr.Call(expr.dtype, "tir.likely", [expr],
- _expr.Call.PureIntrinsic)
+ return _expr.Call(expr.dtype, "tir.likely", [expr])
def get(self):
"""Return the builded IR.
"""Build intrinsics that packs the buffer.
"""
assert buf.shape
- shape = Call("handle", "tir.tvm_stack_make_shape", buf.shape,
- Call.Intrinsic)
- strides = Call("handle", "tir.tvm_stack_make_shape", buf.strides,
- Call.Intrinsic) if buf.strides else 0
+ shape = Call("handle", "tir.tvm_stack_make_shape", buf.shape)
+ strides = Call("handle", "tir.tvm_stack_make_shape", buf.strides) if buf.strides else 0
pack_args = [buf.data,
shape,
strides,
const(0, dtype=buf.dtype),
buf.elem_offset]
return Call("handle", Op.get("tir.tvm_stack_make_array"),
- pack_args, Call.Intrinsic)
+ pack_args)
def call_packed(*args):
"""Build expression by call an external packed function.
"""
call_args = [_pack_buffer(x) if isinstance(x, Buffer) else x for x in args]
return Call(
- "int32", Op.get("tir.tvm_call_packed"), call_args, Call.Intrinsic)
+ "int32", Op.get("tir.tvm_call_packed"), call_args)
-def call_pure_intrin(dtype, func_name, *args):
- """Build expression by calling a pure intrinsic function.
+def call_intrin(dtype, func_name, *args):
+ """Build expression by calling an intrinsic function.
Intrinsics can be overloaded with multiple data types via
the intrinsic translation rule.
call : PrimExpr
The call expression.
"""
- args = convert(args)
return Call(
- dtype, func_name, convert(args), Call.PureIntrinsic)
+ dtype, func_name, convert(args))
-def call_intrin(dtype, func_name, *args):
- """Build expression by calling an intrinsic function.
-
- Intrinsics can be overloaded with multiple data types via
- the intrinsic translation rule.
+def call_pure_extern(dtype, func_name, *args):
+ """Build expression by calling a pure extern function.
Parameters
----------
The data type of the result.
func_name: str
- The intrinsic function name.
+ The extern function name.
args : list
Positional arguments.
call : PrimExpr
The call expression.
"""
- args = convert(args)
return Call(
- dtype, func_name, convert(args), Call.Intrinsic)
+ dtype, Op.get("tir.call_pure_extern"), convert((StringImm(func_name),) + args))
-def call_pure_extern(dtype, func_name, *args):
- """Build expression by calling a pure extern function.
+def call_extern(dtype, func_name, *args):
+ """Build expression by calling a extern function.
Parameters
----------
The call expression.
"""
return Call(
- dtype, Op.get("tir.call_extern"), convert((StringImm(func_name),) + args), Call.PureExtern)
+ dtype, Op.get("tir.call_extern"), convert((StringImm(func_name),) + args))
-def call_extern(dtype, func_name, *args):
- """Build expression by calling a extern function.
+def call_llvm_intrin(dtype, name, *args):
+ """Build expression by calling a llvm intrinsic function
Parameters
----------
dtype : str
- The data type of the result.
+ The data type of the result.
- func_name: str
- The extern function name.
+ name : str
+ The name of the llvm intrinsic function.
args : list
- Positional arguments.
+ Positional arguments.
Returns
-------
call : PrimExpr
The call expression.
"""
- return Call(
- dtype, Op.get("tir.call_extern"), convert((StringImm(func_name),) + args), Call.Extern)
+ # pylint: disable=import-outside-toplevel
+ from tvm.target import codegen
+ llvm_id = codegen.llvm_lookup_intrinsic_id(name)
+ assert llvm_id != 0, "%s is not an LLVM intrinsic" % name
+ return call_intrin(
+ dtype, Op.get("tir.call_llvm_intrin"),
+ tvm.tir.const(llvm_id, 'uint32'), *args)
-def call_llvm_intrin(dtype, name, *args):
- """Build expression by calling an llvm intrinsic function
+def call_llvm_pure_intrin(dtype, name, *args):
+ """Build expression by calling a pure llvm intrinsic function
Parameters
----------
from tvm.target import codegen
llvm_id = codegen.llvm_lookup_intrinsic_id(name)
assert llvm_id != 0, "%s is not an LLVM intrinsic" % name
- return call_pure_intrin(dtype, Op.get("tir.call_llvm_intrin"),
- tvm.tir.const(llvm_id, 'uint32'), *args)
+ return call_intrin(
+ dtype, Op.get("tir.call_llvm_pure_intrin"),
+ tvm.tir.const(llvm_id, 'uint32'), *args)
def any(*args):
call_args = [_pack_buffer(x) if isinstance(x, Buffer) else x for x in args]
call_args.insert(0, trace_action)
return tvm.tir.Call(
- args[-1].dtype, Op.get("tir.tvm_call_trace_packed"), call_args, tvm.tir.Call.Intrinsic)
+ args[-1].dtype, Op.get("tir.tvm_call_trace_packed"), call_args)
y : PrimExpr
The result.
"""
- return call_pure_intrin(x.dtype, "tir.exp", x)
+ return call_intrin(x.dtype, "tir.exp", x)
def exp2(x):
y : PrimExpr
The result.
"""
- return call_pure_intrin(x.dtype, "tir.exp2", x)
+ return call_intrin(x.dtype, "tir.exp2", x)
def exp10(x):
y : PrimExpr
The result.
"""
- return call_pure_intrin(x.dtype, "tir.exp10", x)
+ return call_intrin(x.dtype, "tir.exp10", x)
def erf(x):
y : PrimExpr
The result.
"""
- return call_pure_intrin(x.dtype, "tir.erf", x)
+ return call_intrin(x.dtype, "tir.erf", x)
def tanh(x):
y : PrimExpr
The result.
"""
- return call_pure_intrin(x.dtype, "tir.tanh", x)
+ return call_intrin(x.dtype, "tir.tanh", x)
def sigmoid(x):
y : PrimExpr
The result.
"""
- return call_pure_intrin(x.dtype, "tir.sigmoid", x)
+ return call_intrin(x.dtype, "tir.sigmoid", x)
def log(x):
y : PrimExpr
The result.
"""
- return call_pure_intrin(x.dtype, "tir.log", x)
+ return call_intrin(x.dtype, "tir.log", x)
def log2(x):
y : PrimExpr
The result.
"""
- return call_pure_intrin(x.dtype, "tir.log2", x)
+ return call_intrin(x.dtype, "tir.log2", x)
def log10(x):
y : PrimExpr
The result.
"""
- return call_pure_intrin(x.dtype, "tir.log10", x)
+ return call_intrin(x.dtype, "tir.log10", x)
def log1p(x):
y : PrimExpr
The result.
"""
- return call_pure_intrin(x.dtype, "tir.log1p", x)
+ return call_intrin(x.dtype, "tir.log1p", x)
def tan(x):
y : PrimExpr
The result.
"""
- return call_pure_intrin(x.dtype, "tir.tan", x)
+ return call_intrin(x.dtype, "tir.tan", x)
def cos(x):
y : PrimExpr
The result.
"""
- return call_pure_intrin(x.dtype, "tir.cos", x)
+ return call_intrin(x.dtype, "tir.cos", x)
def cosh(x):
y : PrimExpr
The result.
"""
- return call_pure_intrin(x.dtype, "tir.cosh", x)
+ return call_intrin(x.dtype, "tir.cosh", x)
def acos(x):
y : PrimExpr
The result.
"""
- return call_pure_intrin(x.dtype, "tir.acos", x)
+ return call_intrin(x.dtype, "tir.acos", x)
def acosh(x):
y : PrimExpr
The result.
"""
- return call_pure_intrin(x.dtype, "tir.acosh", x)
+ return call_intrin(x.dtype, "tir.acosh", x)
def sin(x):
y : PrimExpr
The result.
"""
- return call_pure_intrin(x.dtype, "tir.sin", x)
+ return call_intrin(x.dtype, "tir.sin", x)
def sinh(x):
y : PrimExpr
The result.
"""
- return call_pure_intrin(x.dtype, "tir.sinh", x)
+ return call_intrin(x.dtype, "tir.sinh", x)
def asin(x):
y : PrimExpr
The result.
"""
- return call_pure_intrin(x.dtype, "tir.asin", x)
+ return call_intrin(x.dtype, "tir.asin", x)
def asinh(x):
y : PrimExpr
The result.
"""
- return call_pure_intrin(x.dtype, "tir.asinh", x)
+ return call_intrin(x.dtype, "tir.asinh", x)
def atan(x):
y : PrimExpr
The result.
"""
- return call_pure_intrin(x.dtype, "tir.atan", x)
+ return call_intrin(x.dtype, "tir.atan", x)
def atanh(x):
y : PrimExpr
The result.
"""
- return call_pure_intrin(x.dtype, "tir.atanh", x)
+ return call_intrin(x.dtype, "tir.atanh", x)
def atan2(x1, x2):
y : PrimExpr
The result.
"""
- return call_pure_intrin(x1.dtype, "tir.atan2", x1, x2)
+ return call_intrin(x1.dtype, "tir.atan2", x1, x2)
def sqrt(x):
y : PrimExpr
The result.
"""
- return call_pure_intrin(x.dtype, "tir.sqrt", x)
+ return call_intrin(x.dtype, "tir.sqrt", x)
def rsqrt(x):
y : PrimExpr
The result.
"""
- return call_pure_intrin(x.dtype, "tir.rsqrt", x)
+ return call_intrin(x.dtype, "tir.rsqrt", x)
def floor(x):
y : PrimExpr
The result.
"""
- return call_pure_intrin(x1.dtype, "tir.nextafter", x1, x2)
+ return call_intrin(x1.dtype, "tir.nextafter", x1, x2)
def hypot(x1, x2):
y : PrimExpr
The result.
"""
- return call_pure_intrin(x1.dtype, "tir.hypot", x1, x2)
+ return call_intrin(x1.dtype, "tir.hypot", x1, x2)
def copysign(x1, x2):
y : PrimExpr
The result.
"""
- return call_pure_intrin(x1.dtype, "tir.copysign", x1, x2)
+ return call_intrin(x1.dtype, "tir.copysign", x1, x2)
def ldexp(x1, x2):
y : PrimExpr
The result.
"""
- return call_pure_intrin(x1.dtype, "tir.ldexp", x1, x2)
+ return call_intrin(x1.dtype, "tir.ldexp", x1, x2)
def isnan(x):
y : PrimExpr
The result.
"""
- return call_pure_intrin(x.dtype, "tir.popcount", x)
+ return call_intrin(x.dtype, "tir.popcount", x)
def fmod(x, y):
"""Return the remainder of x divided by y with the same sign as x.
z : PrimExpr
The result.
"""
- return call_pure_intrin(x.dtype, "tir.fmod", x, y)
+ return call_intrin(x.dtype, "tir.fmod", x, y)
def if_then_else(cond, t, f):
false_value.same_as(op->args[2])) {
return GetRef<PrimExpr>(op);
} else {
- return Call(op->dtype, op->op, {cond, true_value, false_value}, op->call_type);
+ return Call(op->dtype, op->op, {cond, true_value, false_value});
}
}
return StmtExprMutator::VisitExpr_(op);
#define TVM_PATTERN_BINARY_INTRIN(FuncName, OpName, IntrinOpName) \
struct OpName { \
static PrimExpr Eval(Array<PrimExpr> args) { \
- return tir::Call(args[0].dtype(), GetOp(), args, tir::CallNode::PureIntrinsic); \
+ return tir::Call(args[0].dtype(), GetOp(), args); \
} \
static const Op& GetOp() { return tir::builtin::IntrinOpName(); } \
}; \
TVM_PATTERN_BINARY_INTRIN(operator^, PBitwiseXorOp, bitwise_xor);
// unary intrinsics
-#define TVM_PATTERN_UNARY_INTRIN(FuncName, OpName, IntrinOpName) \
- struct OpName { \
- static PrimExpr Eval(Array<PrimExpr> args) { \
- return tir::Call(args[0].dtype(), GetOp(), args, tir::CallNode::PureIntrinsic); \
- } \
- static const Op& GetOp() { return tir::builtin::IntrinOpName(); } \
- }; \
- template <typename TA> \
- inline PCallExpr<OpName, TA> FuncName(const Pattern<TA>& a) { \
- return PCallExpr<OpName, TA>(a.derived()); \
+#define TVM_PATTERN_UNARY_INTRIN(FuncName, OpName, IntrinOpName) \
+ struct OpName { \
+ static PrimExpr Eval(Array<PrimExpr> args) { \
+ return tir::Call(args[0].dtype(), GetOp(), args); \
+ } \
+ static const Op& GetOp() { return tir::builtin::IntrinOpName(); } \
+ }; \
+ template <typename TA> \
+ inline PCallExpr<OpName, TA> FuncName(const Pattern<TA>& a) { \
+ return PCallExpr<OpName, TA>(a.derived()); \
}
TVM_PATTERN_UNARY_INTRIN(operator~, PBitwiseNotOp, bitwise_not);
// if_then_else
struct PIfThenElseOp {
- static PrimExpr Eval(Array<PrimExpr> args) {
- return tir::Call(args[1].dtype(), GetOp(), args, tir::CallNode::PureIntrinsic);
- }
+ static PrimExpr Eval(Array<PrimExpr> args) { return tir::Call(args[1].dtype(), GetOp(), args); }
static const Op& GetOp() { return tir::builtin::if_then_else(); }
};
PrintExpr(op->args[0], os);
os << " else ";
PrintExpr(op->args[2], os);
- } else if (op->op.same_as(builtin::call_extern())) {
+ } else if (op->op.same_as(builtin::call_pure_extern()) ||
+ op->op.same_as(builtin::call_extern())) {
StringImm fname = Downcast<StringImm>(op->args[0]);
os << fname << "(";
for (size_t i = 1; i < op->args.size(); i++) {
return doc;
}
-inline const char* CallType2String(CallNode::CallType t) {
- switch (t) {
- case CallNode::Extern:
- return "extern";
- case CallNode::ExternCPlusPlus:
- return "extern_cpp";
- case CallNode::PureExtern:
- return "pure_extern";
- case CallNode::Intrinsic:
- return "intrin";
- case CallNode::PureIntrinsic:
- return "pure_intrin";
- }
- LOG(FATAL) << "Unknown CallType";
- return "Unknown";
-}
-
Doc TIRTextPrinter::VisitExpr_(const CallNode* op) {
Doc doc;
if (auto* ptr_op = op->op.as<OpNode>()) {
for (const auto& arg : op->args) {
args.push_back(Print(arg));
}
- doc << PrintSep(args, Doc::Text(", ")) << ", dtype=" << PrintDType(op->dtype)
- << ", type=" << Doc::StrLiteral(CallType2String(op->call_type)) << ")";
+ doc << PrintSep(args, Doc::Text(", ")) << ", dtype=" << PrintDType(op->dtype) << ")";
return doc;
}
namespace codegen {
namespace intrin {
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.exp").set_body(DispatchExtern<FloatSuffix>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.exp").set_body(DispatchPureExtern<FloatSuffix>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.erf").set_body(DispatchExtern<FloatSuffix>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.erf").set_body(DispatchPureExtern<FloatSuffix>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.log").set_body(DispatchExtern<FloatSuffix>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.log").set_body(DispatchPureExtern<FloatSuffix>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.log2").set_body(DispatchExtern<FloatSuffix>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.log2").set_body(DispatchPureExtern<FloatSuffix>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.log10").set_body(DispatchExtern<FloatSuffix>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.log10").set_body(DispatchPureExtern<FloatSuffix>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.log1p").set_body(DispatchExtern<FloatSuffix>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.log1p").set_body(DispatchPureExtern<FloatSuffix>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.tanh").set_body(DispatchExtern<FloatSuffix>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.tanh").set_body(DispatchPureExtern<FloatSuffix>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.tan").set_body(DispatchExtern<FloatSuffix>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.tan").set_body(DispatchPureExtern<FloatSuffix>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.atan").set_body(DispatchExtern<FloatSuffix>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.atan").set_body(DispatchPureExtern<FloatSuffix>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.atanh").set_body(DispatchExtern<FloatSuffix>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.atanh").set_body(DispatchPureExtern<FloatSuffix>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.atan2").set_body(DispatchExtern<FloatSuffix>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.atan2").set_body(DispatchPureExtern<FloatSuffix>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.cos").set_body(DispatchExtern<FloatSuffix>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.cos").set_body(DispatchPureExtern<FloatSuffix>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.acos").set_body(DispatchExtern<FloatSuffix>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.acos").set_body(DispatchPureExtern<FloatSuffix>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.cosh").set_body(DispatchExtern<FloatSuffix>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.cosh").set_body(DispatchPureExtern<FloatSuffix>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.acosh").set_body(DispatchExtern<FloatSuffix>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.acosh").set_body(DispatchPureExtern<FloatSuffix>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sin").set_body(DispatchExtern<FloatSuffix>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sin").set_body(DispatchPureExtern<FloatSuffix>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.asin").set_body(DispatchExtern<FloatSuffix>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.asin").set_body(DispatchPureExtern<FloatSuffix>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sinh").set_body(DispatchExtern<FloatSuffix>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sinh").set_body(DispatchPureExtern<FloatSuffix>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.asinh").set_body(DispatchExtern<FloatSuffix>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.asinh").set_body(DispatchPureExtern<FloatSuffix>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.hypot").set_body(DispatchExtern<FloatSuffix>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.hypot").set_body(DispatchPureExtern<FloatSuffix>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.nextafter").set_body(DispatchExtern<FloatSuffix>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.nextafter").set_body(DispatchPureExtern<FloatSuffix>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.copysign").set_body(DispatchExtern<FloatSuffix>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.copysign").set_body(DispatchPureExtern<FloatSuffix>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.ldexp").set_body(DispatchExtern<FloatSuffix>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.ldexp").set_body(DispatchPureExtern<FloatSuffix>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sqrt").set_body(DispatchExtern<FloatSuffix>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sqrt").set_body(DispatchPureExtern<FloatSuffix>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.rsqrt")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
*rv = one / sqrt(call->args[0]);
});
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.pow").set_body(DispatchExtern<FloatSuffix>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.pow").set_body(DispatchPureExtern<FloatSuffix>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sigmoid")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
// Call pure extern function.
template <typename T>
-inline void DispatchExtern(const TVMArgs& args, TVMRetValue* rv) {
+inline void DispatchPureExtern(const TVMArgs& args, TVMRetValue* rv) {
PrimExpr e = args[0];
const CallNode* call = e.as<CallNode>();
CHECK(call != nullptr);
for (auto arg : call->args) {
new_args.push_back(arg);
}
- *rv = Call(call->dtype, tir::builtin::call_extern(), new_args, CallNode::PureExtern);
+ *rv = Call(call->dtype, tir::builtin::call_pure_extern(), new_args);
} else {
*rv = e;
}
};
llvm::Value* CodeGenARM::CreateIntrinsic(const CallNode* op) {
- if (op->op.same_as(builtin_call_llvm_intrin_)) {
+ if (op->op.same_as(builtin_call_llvm_intrin_) || op->op.same_as(builtin_call_llvm_pure_intrin_)) {
llvm::Intrinsic::ID id = static_cast<llvm::Intrinsic::ID>(Downcast<IntImm>(op->args[0])->value);
if (id == ::llvm::Intrinsic::ctpop) {
PrimExpr e = ARMPopcount(op);
vcnt_args.push_back(IntImm(DataType::UInt(32), ctpop_id));
vcnt_args.push_back(IntImm(DataType::UInt(32), 1));
vcnt_args.push_back(e);
- return tir::Call(call->dtype, builtin_call_llvm_intrin_, vcnt_args, CallNode::PureIntrinsic);
+ return tir::Call(call->dtype, builtin_call_llvm_pure_intrin_, vcnt_args);
}
// Popcount lowering rule:
vcnt8_args.push_back(IntImm(DataType::UInt(32), ctpop_id));
vcnt8_args.push_back(IntImm(DataType::UInt(32), 1));
vcnt8_args.push_back(input8);
- PrimExpr vcnt8 =
- tir::Call(uint8_type, builtin_call_llvm_intrin_, vcnt8_args, CallNode::PureIntrinsic);
+ PrimExpr vcnt8 = tir::Call(uint8_type, builtin_call_llvm_pure_intrin_, vcnt8_args);
// Accumulation 8->16bit
Array<PrimExpr> vcnt16_args;
vcnt16_args.push_back(IntImm(DataType::UInt(32), vpaddlu_id));
vcnt16_args.push_back(IntImm(DataType::UInt(32), 1));
vcnt16_args.push_back(vcnt8);
- PrimExpr vcnt16 =
- tir::Call(uint16_type, builtin_call_llvm_intrin_, vcnt16_args, CallNode::PureIntrinsic);
+ PrimExpr vcnt16 = tir::Call(uint16_type, builtin_call_llvm_pure_intrin_, vcnt16_args);
if (call->dtype.bits() == 16) {
return vcnt16;
}
vcnt32_args.push_back(IntImm(DataType::UInt(32), vpaddlu_id));
vcnt32_args.push_back(IntImm(DataType::UInt(32), 1));
vcnt32_args.push_back(vcnt16);
- PrimExpr vcnt32 =
- tir::Call(uint32_type, builtin_call_llvm_intrin_, vcnt32_args, CallNode::PureIntrinsic);
+ PrimExpr vcnt32 = tir::Call(uint32_type, builtin_call_llvm_pure_intrin_, vcnt32_args);
if (call->dtype.bits() == 32) {
return vcnt32;
}
vcnt64_args.push_back(IntImm(DataType::UInt(32), vpaddlu_id));
vcnt64_args.push_back(IntImm(DataType::UInt(32), 1));
vcnt64_args.push_back(vcnt32);
- return tir::Call(call->dtype, builtin_call_llvm_intrin_, vcnt64_args, CallNode::PureIntrinsic);
+ return tir::Call(call->dtype, builtin_call_llvm_pure_intrin_, vcnt64_args);
}
TVM_REGISTER_GLOBAL("tvm.codegen.llvm.target_arm")
}
llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) {
- if (op->op.same_as(builtin_call_llvm_intrin_)) {
+ if (op->op.same_as(builtin_call_llvm_intrin_) || op->op.same_as(builtin_call_llvm_pure_intrin_)) {
CHECK_GE(op->args.size(), 2U);
llvm::Intrinsic::ID id = static_cast<llvm::Intrinsic::ID>(Downcast<IntImm>(op->args[0])->value);
int64_t num_signature = Downcast<IntImm>(op->args[1])->value;
llvm::Value* CodeGenLLVM::VisitExpr_(const CallNode* op) {
if (auto* ptr_op = op->op.as<OpNode>()) {
auto call_op = GetRef<Op>(ptr_op);
- if (op->op.same_as(builtin_call_extern_)) {
+ if (op->op.same_as(builtin_call_extern_) || op->op.same_as(builtin_call_pure_extern_)) {
// call extern intrinsic
CHECK_GE(op->args.size(), 1U);
auto global_symbol = Downcast<StringImm>(op->args[0]);
// global symbol table.
OpAttrMap<TGlobalSymbol> op_attr_global_symbol_ = Op::GetAttrMap<TGlobalSymbol>("TGlobalSymbol");
const Op& builtin_call_extern_ = builtin::call_extern();
+ const Op& builtin_call_pure_extern_ = builtin::call_pure_extern();
const Op& builtin_call_llvm_intrin_ = builtin::call_llvm_intrin();
+ const Op& builtin_call_llvm_pure_intrin_ = builtin::call_llvm_pure_intrin();
+
/*! \brief Helper struct for debug infos. */
struct DebugInfo {
std::unique_ptr<llvm::DIBuilder> di_builder_;
DTypeToLLVMType(DataType::Float(32, from.lanes())),
{
MakeValue(tir::Call(DataType::Int(16, from.lanes()), tir::builtin::reinterpret(),
- {op->value}, tir::CallNode::PureIntrinsic)),
+ {op->value})),
MakeValue(tir::Broadcast(FloatImm(DataType::Float(32), 0), from.lanes())),
/*mask=*/MakeValue(IntImm(DataType::Int(16), -1)),
/*rounding-mode=*/MakeValue(IntImm(DataType::Int(32), 4)),
const auto has_f16c = TargetHasFeature(*target_machine_, "f16c");
if (from.lanes() >= 8 && has_f16c) {
- return CallVectorIntrin(
- ::llvm::Intrinsic::x86_vcvtph2ps_256, 8,
- DTypeToLLVMType(DataType::Float(32, from.lanes())),
- {MakeValue(tir::Call(DataType::Int(16, from.lanes()), tir::builtin::reinterpret(),
- {op->value}, tir::CallNode::PureIntrinsic))});
+ return CallVectorIntrin(::llvm::Intrinsic::x86_vcvtph2ps_256, 8,
+ DTypeToLLVMType(DataType::Float(32, from.lanes())),
+ {MakeValue(tir::Call(DataType::Int(16, from.lanes()),
+ tir::builtin::reinterpret(), {op->value}))});
}
#endif
}
for (PrimExpr arg : call->args) {
cargs.push_back(arg);
}
- *rv =
- tir::Call(call->dtype, tir::builtin::call_llvm_intrin(), cargs, tir::CallNode::PureIntrinsic);
+ *rv = tir::Call(call->dtype, tir::builtin::call_llvm_pure_intrin(), cargs);
}
template <unsigned id, int num_signature>
for (PrimExpr arg : call->args) {
cargs.push_back(arg);
}
- *rv = tir::Call(call->dtype, tir::builtin::call_llvm_intrin(), cargs, tir::CallNode::Intrinsic);
+ *rv = tir::Call(call->dtype, tir::builtin::call_llvm_intrin(), cargs);
}
} // namespace codegen
namespace tvm {
namespace codegen {
-inline void DispatchExternLibDevice(const TVMArgs& args, TVMRetValue* rv) {
+inline void DispatchPureExternLibDevice(const TVMArgs& args, TVMRetValue* rv) {
PrimExpr e = args[0];
using namespace tir;
const CallNode* call = e.as<CallNode>();
for (auto arg : call->args) {
new_args.push_back(arg);
}
- *rv = Call(call->dtype, builtin::call_extern(), new_args, CallNode::PureExtern);
+ *rv = Call(call->dtype, builtin::call_pure_extern(), new_args);
}
namespace llvm {
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.floor").set_body(DispatchExternLibDevice);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.floor").set_body(DispatchPureExternLibDevice);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.ceil").set_body(DispatchExternLibDevice);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.ceil").set_body(DispatchPureExternLibDevice);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.round").set_body(DispatchExternLibDevice);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.round").set_body(DispatchPureExternLibDevice);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.trunc").set_body(DispatchExternLibDevice);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.trunc").set_body(DispatchPureExternLibDevice);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.fabs").set_body(DispatchExternLibDevice);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.fabs").set_body(DispatchPureExternLibDevice);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.exp").set_body(DispatchExternLibDevice);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.exp").set_body(DispatchPureExternLibDevice);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.exp2").set_body(DispatchExternLibDevice);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.exp2").set_body(DispatchPureExternLibDevice);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.exp10").set_body(DispatchExternLibDevice);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.exp10").set_body(DispatchPureExternLibDevice);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.erf").set_body(DispatchExternLibDevice);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.erf").set_body(DispatchPureExternLibDevice);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.fma").set_body(DispatchExternLibDevice);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.fma").set_body(DispatchPureExternLibDevice);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.log").set_body(DispatchExternLibDevice);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.log").set_body(DispatchPureExternLibDevice);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.log2").set_body(DispatchExternLibDevice);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.log2").set_body(DispatchPureExternLibDevice);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.log10").set_body(DispatchExternLibDevice);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.log10").set_body(DispatchPureExternLibDevice);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.sqrt").set_body(DispatchExternLibDevice);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.sqrt").set_body(DispatchPureExternLibDevice);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.pow").set_body(DispatchExternLibDevice);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.pow").set_body(DispatchPureExternLibDevice);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.tanh").set_body(DispatchExternLibDevice);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.tanh").set_body(DispatchPureExternLibDevice);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.tan").set_body(DispatchExternLibDevice);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.tan").set_body(DispatchPureExternLibDevice);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.cos").set_body(DispatchExternLibDevice);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.cos").set_body(DispatchPureExternLibDevice);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.cosh").set_body(DispatchExternLibDevice);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.cosh").set_body(DispatchPureExternLibDevice);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.sin").set_body(DispatchExternLibDevice);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.sin").set_body(DispatchPureExternLibDevice);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.sinh").set_body(DispatchExternLibDevice);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.sinh").set_body(DispatchPureExternLibDevice);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.atan").set_body(DispatchExternLibDevice);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.atan").set_body(DispatchPureExternLibDevice);
} // namespace llvm
} // namespace codegen
namespace tvm {
namespace codegen {
-inline void DispatchExternOCML(const TVMArgs& args, TVMRetValue* rv) {
+inline void DispatchPureExternOCML(const TVMArgs& args, TVMRetValue* rv) {
PrimExpr e = args[0];
using namespace tir;
const CallNode* call = e.as<CallNode>();
new_args.push_back(arg);
}
- *rv = Call(call->dtype, builtin::call_extern(), new_args, CallNode::PureExtern);
+ *rv = Call(call->dtype, builtin::call_pure_extern(), new_args);
}
inline void DispatchShuffle(const TVMArgs& targs, TVMRetValue* rv) {
// get own lane in self (__lane_id)
PrimExpr minus_one = tir::make_const(DataType::Int(32), -1);
PrimExpr zero = tir::make_zero(DataType::Int(32));
- PrimExpr lo = Call(DataType::Int(32), builtin::call_extern(),
- {StringImm("llvm.amdgcn.mbcnt.lo"), minus_one, zero}, CallNode::PureExtern);
- PrimExpr self = Call(DataType::Int(32), builtin::call_extern(),
- {StringImm("llvm.amdgcn.mbcnt.hi"), minus_one, lo}, CallNode::PureExtern);
+ PrimExpr lo = Call(DataType::Int(32), builtin::call_pure_extern(),
+ {StringImm("llvm.amdgcn.mbcnt.lo"), minus_one, zero});
+ PrimExpr self = Call(DataType::Int(32), builtin::call_pure_extern(),
+ {StringImm("llvm.amdgcn.mbcnt.hi"), minus_one, lo});
// compute lane to get from
PrimExpr width = call->args[3];
index = self + delta;
index = Select((self & (width - 1)) + delta >= width, self, index);
}
- PrimExpr res =
- Call(var.dtype(), builtin::call_extern(),
- {StringImm("llvm.amdgcn.ds.bpermute"), index << 2, var}, CallNode::PureExtern);
+ PrimExpr res = Call(var.dtype(), builtin::call_pure_extern(),
+ {StringImm("llvm.amdgcn.ds.bpermute"), index << 2, var});
*rv = res;
}
TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.tvm_warp_shuffle_down").set_body(DispatchShuffle);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.floor").set_body(DispatchExternOCML);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.floor").set_body(DispatchPureExternOCML);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.ceil").set_body(DispatchExternOCML);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.ceil").set_body(DispatchPureExternOCML);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.round").set_body(DispatchExternOCML);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.round").set_body(DispatchPureExternOCML);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.trunc").set_body(DispatchExternOCML);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.trunc").set_body(DispatchPureExternOCML);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.fabs").set_body(DispatchExternOCML);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.fabs").set_body(DispatchPureExternOCML);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.exp").set_body(DispatchExternOCML);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.exp").set_body(DispatchPureExternOCML);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.exp2").set_body(DispatchExternOCML);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.exp2").set_body(DispatchPureExternOCML);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.exp10").set_body(DispatchExternOCML);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.exp10").set_body(DispatchPureExternOCML);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.erf").set_body(DispatchExternOCML);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.erf").set_body(DispatchPureExternOCML);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.fma").set_body(DispatchExternOCML);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.fma").set_body(DispatchPureExternOCML);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.log").set_body(DispatchExternOCML);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.log").set_body(DispatchPureExternOCML);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.log2").set_body(DispatchExternOCML);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.log2").set_body(DispatchPureExternOCML);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.log10").set_body(DispatchExternOCML);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.log10").set_body(DispatchPureExternOCML);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.sqrt").set_body(DispatchExternOCML);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.sqrt").set_body(DispatchPureExternOCML);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.pow").set_body(DispatchExternOCML);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.pow").set_body(DispatchPureExternOCML);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.tanh").set_body(DispatchExternOCML);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.tanh").set_body(DispatchPureExternOCML);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.tan").set_body(DispatchExternOCML);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.tan").set_body(DispatchPureExternOCML);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.cos").set_body(DispatchExternOCML);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.cos").set_body(DispatchPureExternOCML);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.cosh").set_body(DispatchExternOCML);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.cosh").set_body(DispatchPureExternOCML);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.sin").set_body(DispatchExternOCML);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.sin").set_body(DispatchPureExternOCML);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.sinh").set_body(DispatchExternOCML);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.sinh").set_body(DispatchPureExternOCML);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.atan").set_body(DispatchExternOCML);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.atan").set_body(DispatchPureExternOCML);
} // namespace llvm
} // namespace codegen
if (auto* ptr_op = op->op.as<OpNode>()) {
auto call_op = GetRef<Op>(ptr_op);
- if (op->op.same_as(builtin_call_extern_)) {
+ if (op->op.same_as(builtin_call_extern_) || op->op.same_as(builtin_call_pure_extern_)) {
CHECK_GE(op->args.size(), 1U);
auto func = Downcast<StringImm>(op->args[0]);
this->PrintCallExtern(GetType(GetRef<PrimExpr>(op)), func->value, op->args, true, os);
OpAttrMap<TGlobalSymbol> op_attr_global_symbol_ = Op::GetAttrMap<TGlobalSymbol>("TGlobalSymbol");
// cache commonly used ops
const Op& builtin_call_extern_ = builtin::call_extern();
+ const Op& builtin_call_pure_extern_ = builtin::call_pure_extern();
private:
/*! \brief whether to print in SSA form */
namespace codegen {
namespace intrin {
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.floor").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.floor").set_body(DispatchPureExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.ceil").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.ceil").set_body(DispatchPureExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.trunc").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.trunc").set_body(DispatchPureExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.fabs").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.fabs").set_body(DispatchPureExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.round").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.round").set_body(DispatchPureExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.exp").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.exp").set_body(DispatchPureExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.log").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.log").set_body(DispatchPureExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.tanh").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.tanh").set_body(DispatchPureExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.sqrt").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.sqrt").set_body(DispatchPureExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.pow").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.pow").set_body(DispatchPureExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.popcount").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.popcount").set_body(DispatchPureExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.floor").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.floor").set_body(DispatchPureExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.ceil").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.ceil").set_body(DispatchPureExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.trunc").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.trunc").set_body(DispatchPureExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.fabs").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.fabs").set_body(DispatchPureExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.round").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.round").set_body(DispatchPureExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.exp").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.exp").set_body(DispatchPureExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.log").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.log").set_body(DispatchPureExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.tanh").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.tanh").set_body(DispatchPureExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.sqrt").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.sqrt").set_body(DispatchPureExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.pow").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.pow").set_body(DispatchPureExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.popcount").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.popcount").set_body(DispatchPureExtern<Direct>);
} // namespace intrin
} // namespace codegen
static void DispatchCUDAWarpActiveMask(const TVMArgs& args, TVMRetValue* rv) {
Call call = args[0];
- *rv = Call(call->dtype, Op::Get("tir.cuda.__activemask"), call->args, CallNode::PureExtern);
+ *rv = Call(call->dtype, Op::Get("tir.cuda.__activemask"), call->args);
}
template <typename T>
CHECK_EQ(call->args.size(), 5); // mask, value, warp_id, width, warp_size
Array<PrimExpr> cuda_args{{call->args[0], call->args[1], call->args[2], call->args[3]}};
- *rv =
- Call(call->dtype, T()(call->dtype, Downcast<Op>(call->op)), cuda_args, CallNode::PureExtern);
+ *rv = Call(call->dtype, T()(call->dtype, Downcast<Op>(call->op)), cuda_args);
}
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.floor").set_body(DispatchExtern<CUDAMath>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.floor").set_body(DispatchPureExtern<CUDAMath>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.ceil").set_body(DispatchExtern<CUDAMath>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.ceil").set_body(DispatchPureExtern<CUDAMath>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.trunc").set_body(DispatchExtern<CUDAMath>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.trunc").set_body(DispatchPureExtern<CUDAMath>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.fabs").set_body(DispatchExtern<CUDAMath>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.fabs").set_body(DispatchPureExtern<CUDAMath>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.round").set_body(DispatchExtern<CUDAMath>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.round").set_body(DispatchPureExtern<CUDAMath>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.exp").set_body(DispatchExtern<CUDAFastMath>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.exp").set_body(DispatchPureExtern<CUDAFastMath>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.exp2").set_body(DispatchExtern<CUDAMath>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.exp2").set_body(DispatchPureExtern<CUDAMath>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.exp10").set_body(DispatchExtern<CUDAFastMath>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.exp10").set_body(DispatchPureExtern<CUDAFastMath>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.erf").set_body(DispatchExtern<CUDAMath>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.erf").set_body(DispatchPureExtern<CUDAMath>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.log").set_body(DispatchExtern<CUDAFastMath>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.log").set_body(DispatchPureExtern<CUDAFastMath>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.log2").set_body(DispatchExtern<CUDAFastMath>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.log2").set_body(DispatchPureExtern<CUDAFastMath>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.log10").set_body(DispatchExtern<CUDAFastMath>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.log10").set_body(DispatchPureExtern<CUDAFastMath>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tan").set_body(DispatchExtern<CUDAFastMathTan>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tan").set_body(DispatchPureExtern<CUDAFastMathTan>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.cos").set_body(DispatchExtern<CUDAFastMath>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.cos").set_body(DispatchPureExtern<CUDAFastMath>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.cosh").set_body(DispatchExtern<CUDAMath>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.cosh").set_body(DispatchPureExtern<CUDAMath>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.sin").set_body(DispatchExtern<CUDAFastMath>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.sin").set_body(DispatchPureExtern<CUDAFastMath>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.sinh").set_body(DispatchExtern<CUDAMath>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.sinh").set_body(DispatchPureExtern<CUDAMath>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.atan").set_body(DispatchExtern<CUDAMath>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.atan").set_body(DispatchPureExtern<CUDAMath>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tanh").set_body(DispatchExtern<CUDAMath>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tanh").set_body(DispatchPureExtern<CUDAMath>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.sqrt").set_body(DispatchExtern<CUDAMath>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.sqrt").set_body(DispatchPureExtern<CUDAMath>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.pow").set_body(DispatchExtern<CUDAMath>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.pow").set_body(DispatchPureExtern<CUDAMath>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.popcount").set_body(DispatchExtern<CUDAPopcount>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.popcount").set_body(DispatchPureExtern<CUDAPopcount>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tvm_warp_shuffle")
.set_body(DispatchCUDAShuffle<CUDAWarpIntrinsic>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tvm_warp_activemask")
.set_body(DispatchCUDAWarpActiveMask);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.fmod").set_body(DispatchExtern<CUDAMath>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.fmod").set_body(DispatchPureExtern<CUDAMath>);
// Register low-level builtin ops.
// TODO(tvm-team): consider making CUDA its own subfolder and creating a file for low-level builtins.
TVM_REGISTER_OP("tir.cuda.__shfl_sync")
.set_num_inputs(4)
.set_attr<TGlobalSymbol>("TGlobalSymbol", "__shfl_sync")
+ .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque))
.set_attr<bool>("cuda.need_warp_shuffle", true);
TVM_REGISTER_OP("tir.cuda.__shfl_up_sync")
.set_num_inputs(4)
.set_attr<TGlobalSymbol>("TGlobalSymbol", "__shfl_up_sync")
+ .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque))
.set_attr<bool>("cuda.need_warp_shuffle", true);
TVM_REGISTER_OP("tir.cuda.__shfl_down_sync")
.set_num_inputs(4)
.set_attr<TGlobalSymbol>("TGlobalSymbol", "__shfl_down_sync")
+ .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque))
.set_attr<bool>("cuda.need_warp_shuffle", true);
TVM_REGISTER_OP("tir.cuda.__activemask")
.set_num_inputs(0)
.set_attr<TGlobalSymbol>("TGlobalSymbol", "__activemask")
+ .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure))
.set_attr<bool>("cuda.need_warp_shuffle", true);
} // namespace intrin
namespace codegen {
namespace intrin {
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.floor").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.floor").set_body(DispatchPureExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.ceil").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.ceil").set_body(DispatchPureExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.trunc").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.trunc").set_body(DispatchPureExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.fabs").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.fabs").set_body(DispatchPureExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.round").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.round").set_body(DispatchPureExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.exp").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.exp").set_body(DispatchPureExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.exp2").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.exp2").set_body(DispatchPureExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.exp10").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.exp10").set_body(DispatchPureExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.log").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.log").set_body(DispatchPureExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.log2").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.log2").set_body(DispatchPureExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.log10").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.log10").set_body(DispatchPureExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.tanh").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.tanh").set_body(DispatchPureExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.sqrt").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.sqrt").set_body(DispatchPureExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.pow").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.pow").set_body(DispatchPureExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.popcount").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.popcount").set_body(DispatchPureExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.fmod").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.fmod").set_body(DispatchPureExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.sin").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.sin").set_body(DispatchPureExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.sinh").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.sinh").set_body(DispatchPureExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.cos").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.cos").set_body(DispatchPureExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.cosh").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.cosh").set_body(DispatchPureExtern<Direct>);
} // namespace intrin
} // namespace codegen
namespace codegen {
namespace intrin {
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.floor").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.floor").set_body(DispatchPureExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.ceil").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.ceil").set_body(DispatchPureExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.trunc").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.trunc").set_body(DispatchPureExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.fabs").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.fabs").set_body(DispatchPureExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.round").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.round").set_body(DispatchPureExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.exp").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.exp").set_body(DispatchPureExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.exp2").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.exp2").set_body(DispatchPureExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.exp10").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.exp10").set_body(DispatchPureExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.log").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.log").set_body(DispatchPureExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.log2").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.log2").set_body(DispatchPureExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.log10").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.log10").set_body(DispatchPureExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.tanh").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.tanh").set_body(DispatchPureExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.sqrt").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.sqrt").set_body(DispatchPureExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.pow").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.pow").set_body(DispatchPureExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.popcount").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.popcount").set_body(DispatchPureExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.fmod").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.fmod").set_body(DispatchPureExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.sin").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.sin").set_body(DispatchPureExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.sinh").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.sinh").set_body(DispatchPureExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.cos").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.cos").set_body(DispatchPureExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.cosh").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.cosh").set_body(DispatchPureExtern<Direct>);
// There is no warp shuffle instruction in standard OpenCL.
// When shuffle is used, we assume it is Intel's shuffle extension.
CHECK(analyzer.CanProve(call->args[3] == call->args[4]))
<< "Intel warp shuffle dose not support width != warp_size";
Array<PrimExpr> opencl_args{{StringImm("intel_sub_group_shuffle"), call->args[1], call->args[2]}};
- *rv = Call(call->dtype, builtin::call_extern(), opencl_args, CallNode::PureExtern);
+ *rv = Call(call->dtype, builtin::call_pure_extern(), opencl_args);
}
TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.tvm_warp_shuffle").set_body(DispatchIntelShuffle);
namespace codegen {
namespace intrin {
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.floor").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.floor").set_body(DispatchPureExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.ceil").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.ceil").set_body(DispatchPureExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.trunc").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.trunc").set_body(DispatchPureExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.fabs").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.fabs").set_body(DispatchPureExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.round").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.round").set_body(DispatchPureExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.exp").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.exp").set_body(DispatchPureExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.exp2").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.exp2").set_body(DispatchPureExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.exp10").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.exp10").set_body(DispatchPureExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.log").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.log").set_body(DispatchPureExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.log2").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.log2").set_body(DispatchPureExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.log10").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.log10").set_body(DispatchPureExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.tanh").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.tanh").set_body(DispatchPureExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.sqrt").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.sqrt").set_body(DispatchPureExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.pow").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.pow").set_body(DispatchPureExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.popcount").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.popcount").set_body(DispatchPureExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.sin").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.sin").set_body(DispatchPureExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.sinh").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.sinh").set_body(DispatchPureExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.cos").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.cos").set_body(DispatchPureExtern<Direct>);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.cosh").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.cosh").set_body(DispatchPureExtern<Direct>);
} // namespace intrin
} // namespace codegen
}
spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) {
- if (op->op.same_as(builtin::call_spirv_glsl450())) {
+ if (op->op.same_as(builtin::call_spirv_pure_glsl450())) {
CHECK_GE(op->args.size(), 2U);
uint32_t inst_id = static_cast<uint32_t>(op->args[0].as<IntImmNode>()->value);
std::vector<spirv::Value> values;
return builder_->MakeValue(spv::OpBitCount, builder_->GetSType(op->dtype),
MakeValue(op->args[0]));
} else {
- if (op->call_type == CallNode::Intrinsic || op->call_type == CallNode::PureIntrinsic) {
- LOG(FATAL) << "Unresolved intrinsic " << op->op << " with return type " << op->dtype;
- } else if (op->call_type == CallNode::Extern || op->call_type == CallNode::PureExtern) {
- LOG(FATAL) << "Unresolved extern " << op->op << " with return type " << op->dtype;
- } else {
- LOG(FATAL) << "Unresolved call type " << op->call_type;
- }
+ LOG(FATAL) << "Unresolved call " << op->op;
return spirv::Value();
}
}
for (PrimExpr arg : call->args) {
cargs.push_back(arg);
}
- *rv = tir::Call(call->dtype, tir::builtin::call_spirv_glsl450(), cargs,
- tir::CallNode::PureIntrinsic);
+ *rv = tir::Call(call->dtype, tir::builtin::call_spirv_pure_glsl450(), cargs);
}
TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.floor")
// Builds the derivative expression for a call node by dispatching on its op.
// The ops stored in op_exp_, op_log_, op_sigmoid_, op_sqrt_, op_tanh_, op_pow_,
// op_fabs_ and op_if_then_else_ each get a dedicated formula; any op found in
// piecewise_const yields a zero FloatImm; every other op aborts via LOG(FATAL).
// NOTE(review): Mutate(arg) is presumably the derivative of `arg` -- confirm
// against the enclosing mutator class, which is not visible here.
PrimExpr VisitExpr_(const CallNode* op) {
PrimExpr expr = GetRef<PrimExpr>(op);
- if (op->call_type == CallNode::CallType::PureIntrinsic) {
- if (op->op.same_as(op_exp_)) {
- return Mul(Mutate(op->args[0]), expr);
- } else if (op->op.same_as(op_log_)) {
- return Div(Mutate(op->args[0]), op->args[0]);
- } else if (op->op.same_as(op_sigmoid_)) {
- return Mul(Mutate(op->args[0]), Mul(expr, Sub(FloatImm(expr.dtype(), 1.0), expr)));
- } else if (op->op.same_as(op_sqrt_)) {
- return Div(Mutate(op->args[0]), Mul(expr, FloatImm(expr.dtype(), 2.0)));
- } else if (op->op.same_as(op_tanh_)) {
- return Mul(Mutate(op->args[0]), Sub(FloatImm(expr.dtype(), 1.0), Mul(expr, expr)));
- } else if (op->op.same_as(op_pow_)) {
- auto x = op->args[0], y = op->args[1];
- return expr * (Mutate(y) * log(x) + Mutate(x) * y / x);
- } else if (op->op.same_as(op_fabs_)) {
- auto type = op->args[0].dtype();
- return Mul(Mutate(op->args[0]), Select(GE(op->args[0], make_zero(type)),
- FloatImm(type, 1.0), FloatImm(type, -1.0)));
- } else if (op->op.same_as(op_if_then_else_)) {
- Array<PrimExpr> new_args = {op->args[0], Mutate(op->args[1]), Mutate(op->args[2])};
- return Call(op->dtype, op->op, new_args, op->call_type);
- } else if (piecewise_const.count(op->op)) {
- return FloatImm(expr.dtype(), 0.0);
- } else {
- LOG(FATAL) << "Derivative of this intrinsic is not implemented: " << op->op;
- }
+ if (op->op.same_as(op_exp_)) {
+ return Mul(Mutate(op->args[0]), expr);
+ } else if (op->op.same_as(op_log_)) {
+ return Div(Mutate(op->args[0]), op->args[0]);
+ } else if (op->op.same_as(op_sigmoid_)) {
+ return Mul(Mutate(op->args[0]), Mul(expr, Sub(FloatImm(expr.dtype(), 1.0), expr)));
+ } else if (op->op.same_as(op_sqrt_)) {
+ return Div(Mutate(op->args[0]), Mul(expr, FloatImm(expr.dtype(), 2.0)));
+ } else if (op->op.same_as(op_tanh_)) {
+ return Mul(Mutate(op->args[0]), Sub(FloatImm(expr.dtype(), 1.0), Mul(expr, expr)));
+ } else if (op->op.same_as(op_pow_)) {
+ auto x = op->args[0], y = op->args[1];
+ return expr * (Mutate(y) * log(x) + Mutate(x) * y / x);
+ } else if (op->op.same_as(op_fabs_)) {
+ auto type = op->args[0].dtype();
+ return Mul(Mutate(op->args[0]), Select(GE(op->args[0], make_zero(type)), FloatImm(type, 1.0),
+ FloatImm(type, -1.0)));
+ } else if (op->op.same_as(op_if_then_else_)) {
+ Array<PrimExpr> new_args = {op->args[0], Mutate(op->args[1]), Mutate(op->args[2])};
+ return Call(op->dtype, op->op, new_args);
+ } else if (piecewise_const.count(op->op)) {
+ return FloatImm(expr.dtype(), 0.0);
+ } else {
+ LOG(FATAL) << "Derivative of this intrinsic is not implemented: " << op->op;
+ return PrimExpr();
}
- NOT_IMPLEMENTED;
}
PrimExpr VisitExpr_(const AddNode* op) { return Add(Mutate(op->a), Mutate(op->b)); }
if (attr->dim_align_factor != 0) {
Array<PrimExpr> tuple = {static_cast<int>(i), attr->dim_align_factor,
attr->dim_align_offset};
- realize = tir::AttrStmt(
- t, tir::attr::buffer_dim_align,
- Call(DataType::Handle(), tir::builtin::tvm_tuple(), tuple, CallNode::Intrinsic),
- realize);
+ realize =
+ tir::AttrStmt(t, tir::attr::buffer_dim_align,
+ Call(DataType::Handle(), tir::builtin::tvm_tuple(), tuple), realize);
}
}
}
// Apply the existing input predicate if any.
output_preds.push_back(input_pred);
- Stmt reduce_body = Evaluate(Call(DataType::Handle(), tir::builtin::tvm_thread_allreduce(),
- freduce_args, CallNode::Intrinsic));
+ Stmt reduce_body =
+ Evaluate(Call(DataType::Handle(), tir::builtin::tvm_thread_allreduce(), freduce_args));
reduce_body = AttrStmt(reduces[0]->combiner, tir::attr::reduce_scope,
make_zero(DataType::Handle()), reduce_body);
tuple.push_back(buffer->shape[k]);
}
ret = AttrStmt(bind_spec, tir::attr::buffer_bind_scope,
- Call(DataType::Handle(), builtin::tvm_tuple(), tuple, CallNode::Intrinsic), ret);
+ Call(DataType::Handle(), builtin::tvm_tuple(), tuple), ret);
};
for (size_t i = output_placeholders.size(); i != 0; --i) {
f_push_bind(output_placeholders[i - 1], stage->op.output(i - 1));
tuple.push_back(region[i]->min);
tuple.push_back(region[i]->extent);
}
- input_bind_nest.emplace_back(AttrStmt(
- bind_spec, tir::attr::buffer_bind_scope,
- Call(DataType::Handle(), tir::builtin::tvm_tuple(), tuple, CallNode::Intrinsic), nop));
+ input_bind_nest.emplace_back(
+ AttrStmt(bind_spec, tir::attr::buffer_bind_scope,
+ Call(DataType::Handle(), tir::builtin::tvm_tuple(), tuple), nop));
}
// output binding
}
}
- output_bind_nest.emplace_back(AttrStmt(
- bind_spec, tir::attr::buffer_bind_scope,
- Call(DataType::Handle(), tir::builtin::tvm_tuple(), tuple, CallNode::Intrinsic), nop));
+ output_bind_nest.emplace_back(
+ AttrStmt(bind_spec, tir::attr::buffer_bind_scope,
+ Call(DataType::Handle(), tir::builtin::tvm_tuple(), tuple), nop));
}
// Check variable remap
tuple.push_back(r->min);
tuple.push_back(r->extent);
}
- input_bind_nest.emplace_back(AttrStmt(
- bind_spec, tir::attr::buffer_bind_scope,
- Call(DataType::Handle(), tir::builtin::tvm_tuple(), tuple, CallNode::Intrinsic), nop));
+ input_bind_nest.emplace_back(
+ AttrStmt(bind_spec, tir::attr::buffer_bind_scope,
+ Call(DataType::Handle(), tir::builtin::tvm_tuple(), tuple), nop));
}
// output binding
const ComputeOpNode* intrin_compute = intrin->op.as<ComputeOpNode>();
Tensor tensor = stage->op.output(i - intrin->inputs.size());
Buffer buffer = intrin->buffers[i];
Array<ObjectRef> bind_spec{buffer, tensor};
- output_bind_nest.emplace_back(AttrStmt(
- bind_spec, tir::attr::buffer_bind_scope,
- Call(DataType::Handle(), tir::builtin::tvm_tuple(), tuple, CallNode::Intrinsic), nop));
+ output_bind_nest.emplace_back(
+ AttrStmt(bind_spec, tir::attr::buffer_bind_scope,
+ Call(DataType::Handle(), tir::builtin::tvm_tuple(), tuple), nop));
}
// Check variable remap
std::unordered_map<const VarNode*, PrimExpr> vmap;
return Evaluate(
Call(DataType::Handle(), builtin::tvm_bmma_sync(),
{buffer->data, buffer->elem_offset, buffer_a->data, buffer_a->elem_offset,
- buffer_b->data, buffer_b->elem_offset, buffer->data, buffer->elem_offset},
- CallNode::Intrinsic));
+ buffer_b->data, buffer_b->elem_offset, buffer->data, buffer->elem_offset}));
} else {
return Evaluate(
Call(DataType::Handle(), builtin::tvm_mma_sync(),
{buffer->data, buffer->elem_offset, buffer_a->data, buffer_a->elem_offset,
- buffer_b->data, buffer_b->elem_offset, buffer->data, buffer->elem_offset},
- CallNode::Intrinsic));
+ buffer_b->data, buffer_b->elem_offset, buffer->data, buffer->elem_offset}));
}
};
auto fill_fragment_call = [this, &op](const Buffer& buffer) {
return Evaluate(Call(DataType::Handle(), builtin::tvm_fill_fragment(),
{buffer->data, warp_tile_.m, warp_tile_.n, warp_tile_.k,
- buffer->elem_offset, op->value},
- CallNode::Intrinsic));
+ buffer->elem_offset, op->value}));
};
ObjectPtr<BufferNode> buffer_node = make_object<BufferNode>();
ThreadIdxMutator thread_idx_mutator(warp_y);
PrimExpr mutated_value = thread_idx_mutator(op->value);
// TODO(tvm-team) The extern function name seems to be a hack.
- PrimExpr src = Call(value->dtype, builtin::call_extern(), {StringImm("&"), mutated_value},
- CallNode::Extern);
+ PrimExpr src = Call(value->dtype, builtin::call_extern(), {StringImm("&"), mutated_value});
auto pload = dst.as<ProducerLoadNode>();
PrimExpr matrix_major;
auto load_matrix_call = [this, &src, &stride, &matrix_major](const Buffer& buffer) {
return Evaluate(Call(DataType::Handle(), builtin::tvm_load_matrix_sync(),
{buffer->data, warp_tile_.m, warp_tile_.n, warp_tile_.k,
- buffer->elem_offset, src, stride, matrix_major},
- CallNode::Intrinsic));
+ buffer->elem_offset, src, stride, matrix_major}));
};
ObjectPtr<BufferNode> buffer_node = make_object<BufferNode>();
PrimExpr warp_y = IntImm(DataType::Int(32), warp_threads_y_);
ThreadIdxMutator thread_idx_mutator(warp_y);
dst = thread_idx_mutator(dst);
- dst =
- Call(DataType::Handle(), builtin::call_extern(), {StringImm("&"), dst}, CallNode::Extern);
+ dst = Call(DataType::Handle(), builtin::call_extern(), {StringImm("&"), dst});
auto pload = op->value.as<ProducerLoadNode>();
auto store_matrix_call = [this, &dst, &stride](const Buffer& buffer) {
return Evaluate(Call(DataType::Handle(), builtin::tvm_store_matrix_sync(),
{buffer->data, warp_tile_.m, warp_tile_.n, warp_tile_.k,
- buffer->elem_offset, dst, stride, StringImm("col_major")},
- CallNode::Intrinsic));
+ buffer->elem_offset, dst, stride, StringImm("col_major")}));
};
ObjectPtr<BufferNode> buffer_node = make_object<BufferNode>();
args.push_back(pload->indices[i]);
args.push_back(shape[i]);
}
- auto tuple = Call(DataType::Handle(), builtin::tvm_tuple(), args, CallNode::Intrinsic);
+ auto tuple = Call(DataType::Handle(), builtin::tvm_tuple(), args);
Array<ObjectRef> node = {buffer, tensor};
return AttrStmt(node, "buffer_bind_scope", tuple, call_back(buffer));
}
* \file side_effect.cc
* \brief side effect analysis
*/
+#include <tvm/ir/op.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/expr_functor.h>
+#include <tvm/tir/op_attr_types.h>
namespace tvm {
namespace tir {
}
// Decides whether a call contributes a side effect.
// In the updated version: when the callee is an OpNode, its "TCallEffectKind"
// attribute is looked up; only kPure and kExprAnnotation calls are considered
// side-effect free, in which case the call's children are visited through
// ExprVisitor. Any other effect kind, or a callee that is not an OpNode, sets
// has_side_effect_ and returns without visiting the arguments.
// NOTE(review): the attribute lookup presumably requires every Op reaching this
// visitor to have "TCallEffectKind" registered -- confirm.
void VisitExpr_(const CallNode* op) final {
- if (!op->is_pure()) {
+ static auto op_call_effect = Op::GetAttrMap<TCallEffectKind>("TCallEffectKind");
+
+ if (auto* ptr_op = op->op.as<OpNode>()) {
+ auto effect_kind = op_call_effect[GetRef<Op>(ptr_op)];
+ if (effect_kind != CallEffectKind::kPure && effect_kind != CallEffectKind::kExprAnnotation) {
+ has_side_effect_ = true;
+ return;
+ } else {
+ ExprVisitor::VisitExpr_(op);
+ }
+ } else {
has_side_effect_ = true;
return;
- } else {
- ExprVisitor::VisitExpr_(op);
}
}
}
Array<PrimExpr> acc_args{e_dtype, self->data, elem_offset, extent,
make_const(DataType::Int(32), access_mask)};
- return tir::Call(ptr_type, tir::builtin::tvm_access_ptr(), acc_args, tir::CallNode::Intrinsic);
+ return tir::Call(ptr_type, tir::builtin::tvm_access_ptr(), acc_args);
}
Buffer::Buffer(Var data, DataType dtype, Array<PrimExpr> shape, Array<PrimExpr> strides,
});
// Call
-Call::Call(DataType dtype, RelayExpr op, Array<PrimExpr> args, CallType call_type) {
+Call::Call(DataType dtype, RelayExpr op, Array<PrimExpr> args) {
for (size_t i = 0; i < args.size(); ++i) {
CHECK(args[i].defined());
}
node->dtype = dtype;
node->op = std::move(op);
node->args = std::move(args);
- node->call_type = call_type;
data_ = std::move(node);
}
TVM_REGISTER_GLOBAL("tir.Call")
- .set_body_typed([](DataType type, RelayExpr op, Array<ObjectRef> args, int call_type) {
+ .set_body_typed([](DataType type, RelayExpr op, Array<ObjectRef> args) {
Array<PrimExpr> prim_expr_args;
for (const auto& it : args) {
CHECK(it->IsInstance<runtime::StringObj>() || it->IsInstance<PrimExprNode>());
prim_expr_args.push_back(Downcast<PrimExpr>(it));
}
}
- return Call(type, op, prim_expr_args, static_cast<CallNode::CallType>(call_type));
+ return Call(type, op, prim_expr_args);
});
TVM_REGISTER_NODE_TYPE(CallNode);
if (args.same_as(op->args)) {
return GetRef<PrimExpr>(op);
} else {
- return Call(op->dtype, op->op, args, op->call_type);
+ return Call(op->dtype, op->op, args);
}
}
*/
#include <tvm/runtime/registry.h>
#include <tvm/tir/op.h>
+#include <tvm/tir/op_attr_types.h>
#include <tvm/tir/stmt.h>
namespace tvm {
// Creates an argument-less call to the "tir.type_annotation" op whose result
// dtype is `dtype`. The op handle is fetched once and cached in a function-local static.
PrimExpr TypeAnnotation(DataType dtype) {
static auto op = Op::Get("tir.type_annotation");
- return tir::Call(dtype, op, {}, tir::CallNode::PureIntrinsic);
+ return tir::Call(dtype, op, {});
}
-TVM_REGISTER_OP("tir.type_annotation");
+TVM_REGISTER_OP("tir.type_annotation")
+ .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure));
} // namespace tir
} // namespace tvm
} \
TVM_REGISTER_OP("tir." #OpName)
-TIR_DEFINE_BUILTIN_FUNC(reinterpret).set_num_inputs(1);
+TIR_DEFINE_BUILTIN_FUNC(reinterpret)
+ .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure))
+ .set_num_inputs(1);
-TIR_DEFINE_BUILTIN_FUNC(likely).set_num_inputs(1).set_attr<TVectorizable>("TVectorizable", true);
+TIR_DEFINE_BUILTIN_FUNC(likely)
+ .set_num_inputs(1)
+ .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kExprAnnotation))
+ .set_attr<TVectorizable>("TVectorizable", true);
TIR_DEFINE_BUILTIN_FUNC(bitwise_and)
.set_num_inputs(2)
+ .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure))
.set_attr<TVectorizable>("TVectorizable", true);
TIR_DEFINE_BUILTIN_FUNC(bitwise_or)
.set_num_inputs(2)
+ .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure))
.set_attr<TVectorizable>("TVectorizable", true);
TIR_DEFINE_BUILTIN_FUNC(bitwise_xor)
.set_num_inputs(2)
+ .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure))
.set_attr<TVectorizable>("TVectorizable", true);
TIR_DEFINE_BUILTIN_FUNC(bitwise_not)
.set_num_inputs(1)
+ .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure))
.set_attr<TVectorizable>("TVectorizable", true);
TIR_DEFINE_BUILTIN_FUNC(shift_left)
.set_num_inputs(2)
+ .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure))
.set_attr<TVectorizable>("TVectorizable", true);
TIR_DEFINE_BUILTIN_FUNC(shift_right)
.set_num_inputs(2)
+ .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure))
.set_attr<TVectorizable>("TVectorizable", true);
-TIR_DEFINE_BUILTIN_FUNC(large_uint_imm).set_num_inputs(2);
+TIR_DEFINE_BUILTIN_FUNC(large_uint_imm)
+ .set_num_inputs(2)
+ .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure));
+
+TIR_DEFINE_BUILTIN_FUNC(address_of)
+ .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure))
+ .set_num_inputs(1);
+
+TIR_DEFINE_BUILTIN_FUNC(if_then_else)
+ .set_num_inputs(3)
+ .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure));
-TIR_DEFINE_BUILTIN_FUNC(address_of).set_num_inputs(1);
+TIR_DEFINE_BUILTIN_FUNC(isnullptr).set_num_inputs(1).set_attr<TCallEffectKind>(
+ "TCallEffectKind", Integer(CallEffectKind::kPure));
-TIR_DEFINE_BUILTIN_FUNC(if_then_else).set_num_inputs(3);
+TIR_DEFINE_BUILTIN_FUNC(isnan).set_num_inputs(1).set_attr<TCallEffectKind>(
+ "TCallEffectKind", Integer(CallEffectKind::kPure));
-TIR_DEFINE_BUILTIN_FUNC(isnullptr).set_num_inputs(1);
+TIR_DEFINE_BUILTIN_FUNC(popcount)
+ .set_num_inputs(1)
+ .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure))
+ .set_attr<TVectorizable>("TVectorizable", true);
-TIR_DEFINE_BUILTIN_FUNC(isnan).set_num_inputs(1);
+TIR_DEFINE_BUILTIN_FUNC(fma)
+ .set_num_inputs(3)
+ .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure))
+ .set_attr<TVectorizable>("TVectorizable", true);
-TIR_DEFINE_BUILTIN_FUNC(popcount).set_num_inputs(1);
+TIR_DEFINE_BUILTIN_FUNC(call_extern)
+ .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
-TIR_DEFINE_BUILTIN_FUNC(fma).set_num_inputs(3).set_attr<TVectorizable>("TVectorizable", true);
+TIR_DEFINE_BUILTIN_FUNC(call_pure_extern)
+ .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure));
-TIR_DEFINE_BUILTIN_FUNC(call_extern);
+TIR_DEFINE_BUILTIN_FUNC(call_llvm_intrin)
+ .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
-TIR_DEFINE_BUILTIN_FUNC(call_llvm_intrin);
+TIR_DEFINE_BUILTIN_FUNC(call_llvm_pure_intrin)
+ .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure));
-TIR_DEFINE_BUILTIN_FUNC(call_spirv_glsl450);
+TIR_DEFINE_BUILTIN_FUNC(call_spirv_pure_glsl450)
+ .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure));
-TIR_DEFINE_BUILTIN_FUNC(prefetch);
+TIR_DEFINE_BUILTIN_FUNC(prefetch).set_attr<TCallEffectKind>("TCallEffectKind",
+ Integer(CallEffectKind::kOpaque));
-TIR_DEFINE_BUILTIN_FUNC(tvm_access_ptr).set_num_inputs(5);
+TIR_DEFINE_BUILTIN_FUNC(tvm_access_ptr)
+ .set_num_inputs(5)
+ .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kSpecialCallArg));
-TIR_DEFINE_BUILTIN_FUNC(tvm_static_handle).set_num_inputs(0);
+TIR_DEFINE_BUILTIN_FUNC(tvm_static_handle)
+ .set_num_inputs(0)
+ .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kSpecialCallArg));
-TIR_DEFINE_BUILTIN_FUNC(tvm_context_id).set_num_inputs(0);
+TIR_DEFINE_BUILTIN_FUNC(tvm_context_id)
+ .set_num_inputs(0)
+ .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kReadState));
-TIR_DEFINE_BUILTIN_FUNC(tvm_tuple);
+TIR_DEFINE_BUILTIN_FUNC(tvm_tuple).set_attr<TCallEffectKind>("TCallEffectKind",
+ Integer(CallEffectKind::kEmbedInfo));
-TIR_DEFINE_BUILTIN_FUNC(tvm_struct_get).set_num_inputs(3);
+TIR_DEFINE_BUILTIN_FUNC(tvm_struct_get)
+ .set_num_inputs(3)
+ .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kReadState));
-TIR_DEFINE_BUILTIN_FUNC(tvm_struct_set).set_num_inputs(4);
+TIR_DEFINE_BUILTIN_FUNC(tvm_struct_set)
+ .set_num_inputs(4)
+ .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kUpdateState));
-TIR_DEFINE_BUILTIN_FUNC(tvm_throw_last_error).set_num_inputs(0);
+TIR_DEFINE_BUILTIN_FUNC(tvm_throw_last_error)
+ .set_num_inputs(0)
+ .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
-TIR_DEFINE_BUILTIN_FUNC(tvm_stack_alloca).set_num_inputs(2);
+TIR_DEFINE_BUILTIN_FUNC(tvm_stack_alloca)
+ .set_num_inputs(2)
+ .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
-TIR_DEFINE_BUILTIN_FUNC(tvm_stack_make_shape);
+TIR_DEFINE_BUILTIN_FUNC(tvm_stack_make_shape)
+ .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
-TIR_DEFINE_BUILTIN_FUNC(tvm_stack_make_array).set_num_inputs(6);
+TIR_DEFINE_BUILTIN_FUNC(tvm_stack_make_array)
+ .set_num_inputs(6)
+ .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
// When num_inputs is not set, the function is assumed to be variable length.
-TIR_DEFINE_BUILTIN_FUNC(tvm_call_packed);
+TIR_DEFINE_BUILTIN_FUNC(tvm_call_packed)
+ .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
-TIR_DEFINE_BUILTIN_FUNC(tvm_call_trace_packed);
+TIR_DEFINE_BUILTIN_FUNC(tvm_call_trace_packed)
+ .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
-TIR_DEFINE_BUILTIN_FUNC(tvm_thread_context).set_num_inputs(1);
+TIR_DEFINE_BUILTIN_FUNC(tvm_thread_context)
+ .set_num_inputs(1)
+ .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
-TIR_DEFINE_BUILTIN_FUNC(tvm_call_packed_lowered);
+TIR_DEFINE_BUILTIN_FUNC(tvm_call_packed_lowered)
+ .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
-TIR_DEFINE_BUILTIN_FUNC(tvm_call_trace_packed_lowered);
+TIR_DEFINE_BUILTIN_FUNC(tvm_call_trace_packed_lowered)
+ .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
// TODO(tvm-team) revisit storage sync once we have a good memory hierarchy structure.
-TIR_DEFINE_BUILTIN_FUNC(tvm_storage_sync);
+TIR_DEFINE_BUILTIN_FUNC(tvm_storage_sync)
+ .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
-TIR_DEFINE_BUILTIN_FUNC(tvm_warp_shuffle);
+TIR_DEFINE_BUILTIN_FUNC(tvm_warp_shuffle)
+ .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
-TIR_DEFINE_BUILTIN_FUNC(tvm_warp_shuffle_up);
+TIR_DEFINE_BUILTIN_FUNC(tvm_warp_shuffle_up)
+ .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
-TIR_DEFINE_BUILTIN_FUNC(tvm_warp_shuffle_down);
+TIR_DEFINE_BUILTIN_FUNC(tvm_warp_shuffle_down)
+ .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
-TIR_DEFINE_BUILTIN_FUNC(tvm_warp_activemask);
+TIR_DEFINE_BUILTIN_FUNC(tvm_warp_activemask)
+ .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
-TIR_DEFINE_BUILTIN_FUNC(tvm_global_barrier_kinit);
+TIR_DEFINE_BUILTIN_FUNC(tvm_global_barrier_kinit)
+ .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
-TIR_DEFINE_BUILTIN_FUNC(tvm_thread_allreduce);
+TIR_DEFINE_BUILTIN_FUNC(tvm_thread_allreduce)
+ .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
-TIR_DEFINE_BUILTIN_FUNC(tvm_load_matrix_sync);
+TIR_DEFINE_BUILTIN_FUNC(tvm_load_matrix_sync)
+ .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kReadState));
-TIR_DEFINE_BUILTIN_FUNC(tvm_mma_sync);
+TIR_DEFINE_BUILTIN_FUNC(tvm_mma_sync)
+ .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
-TIR_DEFINE_BUILTIN_FUNC(tvm_bmma_sync);
+TIR_DEFINE_BUILTIN_FUNC(tvm_bmma_sync)
+ .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
-TIR_DEFINE_BUILTIN_FUNC(tvm_fill_fragment);
+TIR_DEFINE_BUILTIN_FUNC(tvm_fill_fragment)
+ .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
-TIR_DEFINE_BUILTIN_FUNC(tvm_store_matrix_sync);
+TIR_DEFINE_BUILTIN_FUNC(tvm_store_matrix_sync)
+ .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
-TIR_DEFINE_BUILTIN_FUNC(vectorhigh);
+TIR_DEFINE_BUILTIN_FUNC(vectorhigh)
+ .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure));
-TIR_DEFINE_BUILTIN_FUNC(vectorlow);
+TIR_DEFINE_BUILTIN_FUNC(vectorlow).set_attr<TCallEffectKind>("TCallEffectKind",
+ Integer(CallEffectKind::kPure));
-TIR_DEFINE_BUILTIN_FUNC(vectorcombine);
+TIR_DEFINE_BUILTIN_FUNC(vectorcombine)
+ .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure));
} // namespace builtin
} // namespace tir
using namespace tir;
// macro to register a unary op
-#define TIR_REGISTER_PURE_UNARY_OP(OpName) TVM_REGISTER_OP(OpName).set_num_inputs(1)
+#define TIR_REGISTER_PURE_UNARY_OP(OpName) \
+ TVM_REGISTER_OP(OpName).set_num_inputs(1).set_attr<TCallEffectKind>( \
+ "TCallEffectKind", Integer(CallEffectKind::kPure))
// macro to register a binary op
-#define TIR_REGISTER_PURE_BINARY_OP(OpName) TVM_REGISTER_OP(OpName).set_num_inputs(2)
+#define TIR_REGISTER_PURE_BINARY_OP(OpName) \
+ TVM_REGISTER_OP(OpName).set_num_inputs(2).set_attr<TCallEffectKind>( \
+ "TCallEffectKind", Integer(CallEffectKind::kPure))
runtime::DataType GetRuntimeDataType(const Type& type) {
if (auto* n = type.as<PrimTypeNode>()) {
// LargeUIntImm
// Builds a builtin::large_uint_imm call of dtype `t` whose two arguments are
// `low` and `high`, each converted to a UInt(32) constant via make_const.
PrimExpr LargeUIntImm(DataType t, int64_t low, int64_t high) {
return tir::Call(t, tir::builtin::large_uint_imm(),
- {make_const(DataType::UInt(32), low), make_const(DataType::UInt(32), high)},
- tir::CallNode::PureIntrinsic);
+ {make_const(DataType::UInt(32), low), make_const(DataType::UInt(32), high)});
}
// The public function with a quick checking path.
// reinterpret
// Returns `value` untouched when its dtype already equals `t`; otherwise wraps
// it in a builtin::reinterpret call whose result dtype is `t`.
PrimExpr reinterpret(const DataType& t, PrimExpr value) {
if (value.dtype() == t) return value;
- return tir::Call(t, tir::builtin::reinterpret(), {value}, tir::CallNode::PureIntrinsic);
+ return tir::Call(t, tir::builtin::reinterpret(), {value});
}
// operator+
}
return tir::Call(true_value.dtype(), tir::builtin::if_then_else(),
- {cond, true_value, false_value}, tir::CallNode::PureIntrinsic);
+ {cond, true_value, false_value});
}
// likely
// Wraps `cond` in a builtin::likely call that keeps the condition's dtype.
// A condition for which is_const() holds is returned as-is, without the wrapper.
PrimExpr likely(PrimExpr cond) {
if (is_const(cond)) return cond;
- return tir::Call(cond.dtype(), tir::builtin::likely(), {cond}, tir::CallNode::PureIntrinsic);
+ return tir::Call(cond.dtype(), tir::builtin::likely(), {cond});
}
-TVM_REGISTER_OP("tir.likely").set_num_inputs(1);
-
// operator>
PrimExpr operator>(PrimExpr a, PrimExpr b) {
BinaryOpMatchTypes(a, b);
}
});
- return tir::Call(a.dtype(), tir::builtin::shift_right(), {a, b}, tir::CallNode::PureIntrinsic);
+ return tir::Call(a.dtype(), tir::builtin::shift_right(), {a, b});
}
// shift left
if (pb->value == 0) return a;
}
});
- return tir::Call(a.dtype(), tir::builtin::shift_left(), {a, b}, tir::CallNode::PureIntrinsic);
+ return tir::Call(a.dtype(), tir::builtin::shift_left(), {a, b});
}
// bitwise and
const DataType& rtype = a.dtype();
if (pa && pb) return IntImm(rtype, (pa->value & pb->value));
});
- return tir::Call(a.dtype(), tir::builtin::bitwise_and(), {a, b}, tir::CallNode::PureIntrinsic);
+ return tir::Call(a.dtype(), tir::builtin::bitwise_and(), {a, b});
}
// bitwise_or
const DataType& rtype = a.dtype();
if (pa && pb) return IntImm(rtype, (pa->value | pb->value));
});
- return tir::Call(a.dtype(), tir::builtin::bitwise_or(), {a, b}, tir::CallNode::PureIntrinsic);
+ return tir::Call(a.dtype(), tir::builtin::bitwise_or(), {a, b});
}
// bitwise_xor
const DataType& rtype = a.dtype();
if (pa && pb) return IntImm(rtype, (pa->value ^ pb->value));
});
- return tir::Call(a.dtype(), tir::builtin::bitwise_xor(), {a, b}, tir::CallNode::PureIntrinsic);
+ return tir::Call(a.dtype(), tir::builtin::bitwise_xor(), {a, b});
}
// bitwise_not
// Emits a builtin::bitwise_not call with the operand's dtype.
// The CHECK rejects operands whose dtype is neither int nor uint.
PrimExpr operator~(PrimExpr a) {
CHECK(a.dtype().is_int() || a.dtype().is_uint());
- return tir::Call(a.dtype(), tir::builtin::bitwise_not(), {a}, tir::CallNode::PureIntrinsic);
+ return tir::Call(a.dtype(), tir::builtin::bitwise_not(), {a});
}
-TVM_REGISTER_OP("tir.bitwise_not");
-
TVM_REGISTER_GLOBAL("tir.bitwise_not").set_body_typed([](PrimExpr a) { return ~a; });
// pow
BinaryOpMatchTypes(x, y);
CHECK(x.dtype().is_float()) << "power only applies to float";
static auto op = Op::Get("tir.pow");
- return tir::Call(x.dtype(), op, {x, y}, tir::CallNode::PureIntrinsic);
+ return tir::Call(x.dtype(), op, {x, y});
}
-TVM_REGISTER_OP("tir.pow").set_num_inputs(2).set_attr<TVectorizable>("TVectorizable", true);
+TIR_REGISTER_PURE_BINARY_OP("tir.pow").set_attr<TVectorizable>("TVectorizable", true);
// abs
PrimExpr abs(PrimExpr x) {
return FloatImm(x.dtype(), std::fabs(fx->value));
}
static auto op = Op::Get("tir.fabs");
- return tir::Call(x.dtype(), op, {x}, tir::CallNode::PureIntrinsic);
+ return tir::Call(x.dtype(), op, {x});
} else if (x.dtype().is_uint()) {
return x;
} else {
}
static auto op = Op::Get("tir.isnan");
if (x.dtype().bits() == 16) {
- return tir::Call(t, op, {cast(DataType::Float(32, t.lanes()), std::move(x))},
- tir::CallNode::PureIntrinsic);
+ return tir::Call(t, op, {cast(DataType::Float(32, t.lanes()), std::move(x))});
} else {
- return tir::Call(t, op, {x}, tir::CallNode::PureIntrinsic);
+ return tir::Call(t, op, {x});
}
} else {
LOG(FATAL) << "Data type " << x.dtype() << " not supported for isnan op. Skipping isnan op...";
}
}
-TIR_REGISTER_PURE_UNARY_OP("tir.isnan");
-
// isinf
PrimExpr isinf(PrimExpr x) {
DataType t = DataType::Bool(x.dtype().lanes());
BinaryOpMatchTypes(x, y);
CHECK(x.dtype().is_float()) << "fmod only applies to float";
static auto op = Op::Get("tir.fmod");
- return tir::Call(x.dtype(), op, {x, y}, tir::CallNode::PureIntrinsic);
+ return tir::Call(x.dtype(), op, {x, y});
}
// fmod is constructed with two operands (x, y), so it must be registered with
// num_inputs == 2 through the binary-op macro rather than the unary one.
TIR_REGISTER_PURE_BINARY_OP("tir.fmod");
const FloatImmNode* fx = x.as<FloatImmNode>();
if (fx) return FloatImm(x.dtype(), std::floor(fx->value));
static auto op = Op::Get("tir.floor");
- return tir::Call(x.dtype(), op, {x}, tir::CallNode::PureIntrinsic);
+ return tir::Call(x.dtype(), op, {x});
}
TIR_REGISTER_PURE_UNARY_OP("tir.floor").set_attr<TVectorizable>("TVectorizable", true);
const FloatImmNode* fx = x.as<FloatImmNode>();
if (fx) return FloatImm(x.dtype(), std::ceil(fx->value));
static auto op = Op::Get("tir.ceil");
- return tir::Call(x.dtype(), op, {x}, tir::CallNode::PureIntrinsic);
+ return tir::Call(x.dtype(), op, {x});
}
TIR_REGISTER_PURE_UNARY_OP("tir.ceil").set_attr<TVectorizable>("TVectorizable", true);
const FloatImmNode* fx = x.as<FloatImmNode>();
if (fx) return FloatImm(x.dtype(), std::nearbyint(fx->value));
static auto op = Op::Get("tir.round");
- return tir::Call(x.dtype(), op, {x}, tir::CallNode::PureIntrinsic);
+ return tir::Call(x.dtype(), op, {x});
}
TIR_REGISTER_PURE_UNARY_OP("tir.round").set_attr<TVectorizable>("TVectorizable", true);
const FloatImmNode* fx = x.as<FloatImmNode>();
if (fx) return FloatImm(x.dtype(), std::nearbyint(fx->value));
static auto op = Op::Get("tir.nearbyint");
- return tir::Call(x.dtype(), op, {x}, tir::CallNode::PureIntrinsic);
+ return tir::Call(x.dtype(), op, {x});
}
TIR_REGISTER_PURE_UNARY_OP("tir.nearbyint");
return FloatImm(x.dtype(), (fx->value < 0 ? std::ceil(fx->value) : std::floor(fx->value)));
}
static auto op = Op::Get("tir.trunc");
- return tir::Call(x.dtype(), op, {x}, tir::CallNode::PureIntrinsic);
+ return tir::Call(x.dtype(), op, {x});
}
TIR_REGISTER_PURE_UNARY_OP("tir.trunc").set_attr<TVectorizable>("TVectorizable", true);
TIR_REGISTER_PURE_UNARY_OP("tir.log10").set_attr<TVectorizable>("TVectorizable", true);
-TIR_REGISTER_PURE_UNARY_OP("tir.popcount").set_attr<TVectorizable>("TVectorizable", true);
-
TIR_REGISTER_PURE_UNARY_OP("tir.tan").set_attr<TVectorizable>("TVectorizable", true);
TIR_REGISTER_PURE_UNARY_OP("tir.cos").set_attr<TVectorizable>("TVectorizable", true);
TVM_REGISTER_OP("tir.TVMBackendAllocWorkspace")
.set_num_inputs(5)
- .set_attr<TGlobalSymbol>("TGlobalSymbol", "TVMBackendAllocWorkspace");
+ .set_attr<TGlobalSymbol>("TGlobalSymbol", "TVMBackendAllocWorkspace")
+ .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
TVM_REGISTER_OP("tir.TVMBackendFreeWorkspace")
.set_num_inputs(3)
- .set_attr<TGlobalSymbol>("TGlobalSymbol", "TVMBackendFreeWorkspace");
+ .set_attr<TGlobalSymbol>("TGlobalSymbol", "TVMBackendFreeWorkspace")
+ .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
} // namespace tir
} // namespace tvm
def_handle_dtype_.Set(v_strides, tir::TypeAnnotation(tvm_shape_type));
init_nest_.emplace_back(
LetStmt(v_strides, TVMArrayGet(DataType::Handle(), handle, builtin::kArrStrides), nop));
- PrimExpr is_null =
- Call(DataType::Bool(1), builtin::isnullptr(), {v_strides}, CallNode::PureIntrinsic);
+ PrimExpr is_null = Call(DataType::Bool(1), builtin::isnullptr(), {v_strides});
if (buffer->strides.size() == 0) {
// Assert the buffer is compact
DataType stype = buffer->DefaultIndexType();
auto uint32_dtype = DataType(kDLUInt, 32, op_val->dtype.lanes());
auto uint32_v = Cast(uint32_dtype, op_val);
// to be endian invariant.
- return Call(op->dtype, builtin::reinterpret(), {uint32_v << 16}, CallNode::PureIntrinsic);
+ return Call(op->dtype, builtin::reinterpret(), {uint32_v << 16});
} else if (op->dtype.is_bfloat16()) {
// if is cast_to_bf16, check if op->value is fp32
CHECK(op->value->dtype.is_float() && op->value->dtype.bits() == 32);
auto uint32_dtype = DataType(kDLUInt, 32, op_val->dtype.lanes());
- auto uint32_v = Call(uint32_dtype, builtin::reinterpret(), {op_val}, CallNode::PureIntrinsic);
+ auto uint32_v = Call(uint32_dtype, builtin::reinterpret(), {op_val});
auto uint16_dtype = DataType(kDLUInt, 16, op_val->dtype.lanes());
/* the following TIR is equivalent to the C++ code below:
uint32_t rounding_bias = ((U32 >> 16) & 1) + UINT32_C(0x7FFF);
}
// Builds a one-element statement list: an Evaluate wrapping a zero-argument
// int32 Call to the op looked up under the name "tir." + sync_name.
std::vector<Stmt> GetSync(std::string sync_name) {
- return {
- Evaluate(Call(DataType::Int(32), Op::Get("tir." + sync_name), {}, CallNode::Intrinsic))};
+ return {Evaluate(Call(DataType::Int(32), Op::Get("tir." + sync_name), {}))};
}
const std::unordered_set<const VarNode*>& touched_;
PrimExpr min = r->min;
PrimExpr extent = r->extent;
return Evaluate(Call(DataType::Int(32), Op::Get(func),
- {wvec[0].buffer, wvec[0].dtype.bits(), r->min, r->extent},
- CallNode::Intrinsic));
+ {wvec[0].buffer, wvec[0].dtype.bits(), r->min, r->extent}));
}
// Write barrier name
bool read_barrier_{false};
// Returns an Evaluate of an int32 Call to sync_push_op_ whose two arguments
// are `from` and `to`, each materialized as an int32 constant.
Stmt MakePush(int from, int to) {
return Evaluate(Call(DataType::Int(32), sync_push_op_,
- {make_const(DataType::Int(32), from), make_const(DataType::Int(32), to)},
- CallNode::Intrinsic));
+ {make_const(DataType::Int(32), from), make_const(DataType::Int(32), to)}));
}
// Returns an Evaluate of an int32 Call to sync_pop_op_ whose two arguments
// are `from` and `to`, each materialized as an int32 constant.
Stmt MakePop(int from, int to) {
return Evaluate(Call(DataType::Int(32), sync_pop_op_,
- {make_const(DataType::Int(32), from), make_const(DataType::Int(32), to)},
- CallNode::Intrinsic));
+ {make_const(DataType::Int(32), from), make_const(DataType::Int(32), to)}));
}
// sync states.
SyncState first_state_, last_state_, curr_state_;
PrimExpr extent = this->VisitExpr(op->args[3]);
PrimExpr stride = it->second / make_const(offset.dtype(), dtype.lanes());
offset = stride * var_ + offset;
- return Call(op->dtype, op->op, {op->args[0], op->args[1], offset, extent, op->args[4]},
- op->call_type);
+ return Call(op->dtype, op->op, {op->args[0], op->args[1], offset, extent, op->args[4]});
} else if (op->op.same_as(builtin::tvm_context_id())) {
return allow_share_ ? GetRef<PrimExpr>(op) : var_;
} else {
builtin::TVMStructFieldKind kind) {
Array<PrimExpr> args = {handle, make_const(DataType::Int(32), index),
make_const(DataType::Int(32), static_cast<int>(kind))};
- return Call(dtype, builtin::tvm_struct_get(), args, CallNode::PureIntrinsic);
+ return Call(dtype, builtin::tvm_struct_get(), args);
}
/*!
inline PrimExpr AddressOffset(Var handle, DataType dtype, int offset) {
// Handle-typed address_of call applied to a Load of `dtype` from `handle`
// at index offset * dtype.lanes(), using an all-true predicate.
return Call(DataType::Handle(), builtin::address_of(),
{Load(dtype, handle, make_const(DataType::Int(32), offset * dtype.lanes()),
- const_true(dtype.lanes()))},
- CallNode::PureIntrinsic);
+ const_true(dtype.lanes()))});
}
/*!
offset = Ramp(offset, make_const(offset.dtype(), 1), dtype.lanes());
}
return Call(DataType::Handle(), builtin::address_of(),
- {Load(dtype, handle, offset, const_true(dtype.lanes()))}, CallNode::PureIntrinsic);
+ {Load(dtype, handle, offset, const_true(dtype.lanes()))});
}
/*!
inline Stmt TVMStructSet(Var handle, int index, builtin::TVMStructFieldKind kind, PrimExpr value) {
// Argument order for tvm_struct_set: handle, int32 index, field kind as int32, value.
Array<PrimExpr> args = {handle, make_const(DataType::Int(32), index),
make_const(DataType::Int(32), static_cast<int>(kind)), value};
// The int32-typed call is wrapped in Evaluate so it can be used as a statement.
- return Evaluate(Call(DataType::Int(32), builtin::tvm_struct_set(), args, CallNode::Intrinsic));
+ return Evaluate(Call(DataType::Int(32), builtin::tvm_struct_set(), args));
}
/*!
}
// Rewrites a call node: when the callee is an OpNode, ApplyPattern is tried
// with the op's name and a defined result is returned directly; otherwise the
// call is handed to the base IRMutatorWithAnalyzer visitor.
PrimExpr VisitExpr_(const CallNode* op) final {
- // NOTE: call_type will eventually be deprecated and the information
- // will be folded into Op's attr
- if (op->call_type == CallNode::Intrinsic || op->call_type == CallNode::PureIntrinsic) {
- if (auto* ptr_op = op->op.as<OpNode>()) {
- // Still use legacy string based rewriting
- // TODO(tvm-team): migrate the pattern application from global function look up
- // to an OpAttrMap<PackedFunc>
- std::string name = ptr_op->name;
- PrimExpr r = ApplyPattern(name, GetRef<PrimExpr>(op));
- if (r.defined()) return r;
- }
+ if (auto* ptr_op = op->op.as<OpNode>()) {
+ // Still use legacy string based rewriting
+ // TODO(tvm-team): migrate the pattern application from global function look up
+ // to an OpAttrMap<PackedFunc>
+ std::string name = ptr_op->name;
+ PrimExpr r = ApplyPattern(name, GetRef<PrimExpr>(op));
+ if (r.defined()) return r;
}
// No pattern produced a replacement: fall back to the default mutation.
return IRMutatorWithAnalyzer::VisitExpr_(op);
}
PrimExpr rhs = SwapBroadcastCast(b);
if (fma_ != nullptr && op->dtype.is_float()) {
- PrimExpr r = (*fma_)(Call(op->dtype, builtin::fma(), {lhs, rhs, c}, CallNode::PureIntrinsic));
+ PrimExpr r = (*fma_)(Call(op->dtype, builtin::fma(), {lhs, rhs, c}));
if (r.defined()) return this->VisitExpr(r);
} else {
if (!lhs.same_as(a) || !rhs.same_as(b)) {
Var mask_var("mask", DataType::UInt(32));
{
PrimExpr pred = const_true(1);
- PrimExpr mask =
- Call(DataType::UInt(32), builtin::tvm_warp_activemask(), {}, CallNode::Intrinsic);
+ PrimExpr mask = Call(DataType::UInt(32), builtin::tvm_warp_activemask(), {});
seq.emplace_back(Store(mask_var, mask, index, pred));
// Push allocation with an empty body. Later this will be fixed
// when the entire body is ready.
}
// Sync thread op: an Evaluate of an int32 Call to builtin::tvm_storage_sync()
// whose only argument is `sync` wrapped in a StringImm.
static Stmt SyncThread(const std::string& sync) {
- return Evaluate(Call(DataType::Int(32), builtin::tvm_storage_sync(), {StringImm(sync)},
- CallNode::Intrinsic));
+ return Evaluate(Call(DataType::Int(32), builtin::tvm_storage_sync(), {StringImm(sync)}));
}
// Emit warp shuffle calls.
PrimExpr mask = Load(DataType::UInt(32), mask_var, index, pred);
PrimExpr width = IntImm(DataType::Int(32), warp_size_);
Array<PrimExpr> args{mask, val, IntImm(DataType::Int(32), delta_or_lane), width, width};
- return Call(val.dtype(), op, args, CallNode::Intrinsic);
+ return Call(val.dtype(), op, args);
}
// Check if this is a reduction on threadIdx.x and its extent matches
// Creates a handle-typed Call to builtin::tvm_stack_alloca() with the
// arguments (StringImm(type), ConstInt32(num)).
inline PrimExpr StackAlloca(std::string type, size_t num) {
Array<PrimExpr> args = {StringImm(type), ConstInt32(num)};
- return Call(DataType::Handle(), builtin::tvm_stack_alloca(), args, CallNode::Intrinsic);
+ return Call(DataType::Handle(), builtin::tvm_stack_alloca(), args);
}
// Calculate the statistics of packed function.
}
CHECK(device_type_.defined()) << "Unknown device type in current IR";
CHECK(device_id_.defined()) << "Unknown device id in current IR";
- Stmt throw_last_error =
- Evaluate(Call(DataType::Int(32), builtin::tvm_throw_last_error(), {}, CallNode::Intrinsic));
+ Stmt throw_last_error = Evaluate(Call(DataType::Int(32), builtin::tvm_throw_last_error(), {}));
- Stmt body = SeqStmt({IfThenElse(Call(DataType::Bool(1), builtin::isnullptr(), {op->buffer_var},
- CallNode::PureIntrinsic),
+ Stmt body = SeqStmt({IfThenElse(Call(DataType::Bool(1), builtin::isnullptr(), {op->buffer_var}),
throw_last_error),
op->body});
Stmt alloca = LetStmt(
Call(op->buffer_var.dtype(), Op::Get("tir.TVMBackendAllocWorkspace"),
{cast(DataType::Int(32), device_type_), cast(DataType::Int(32), device_id_),
cast(DataType::UInt(64), total_bytes), IntImm(DataType::Int(32), op->dtype.code()),
- IntImm(DataType::Int(32), op->dtype.bits())},
- CallNode::Extern),
+ IntImm(DataType::Int(32), op->dtype.bits())}),
body);
PrimExpr free_op = Call(DataType::Int(32), Op::Get("tir.TVMBackendFreeWorkspace"),
{cast(DataType::Int(32), device_type_),
- cast(DataType::Int(32), device_id_), op->buffer_var},
- CallNode::Extern);
+ cast(DataType::Int(32), device_id_), op->buffer_var});
Stmt free_stmt = IfThenElse(free_op != make_zero(DataType::Int(32)), throw_last_error);
body = SeqStmt({alloca, free_stmt});
body = AttrStmt(op->buffer_var, attr::storage_alignment,
Array<PrimExpr> packed_args = {op->args[0], stack_value_, stack_tcode_,
ConstInt32(arg_stack_begin),
ConstInt32(arg_stack_begin + op->args.size() - 1)};
- return Call(DataType::Int(32), builtin::tvm_call_packed_lowered(), packed_args,
- CallNode::Intrinsic);
+ return Call(DataType::Int(32), builtin::tvm_call_packed_lowered(), packed_args);
}
PrimExpr MakeCallTracePacked(const CallNode* op) {
ConstInt32(arg_stack_begin + op->args.size() - 1),
// Pass traced value.
op->args[args_size - 1]};
- return Call(op->dtype, builtin::tvm_call_trace_packed_lowered(), packed_args,
- CallNode::Intrinsic);
+ return Call(op->dtype, builtin::tvm_call_trace_packed_lowered(), packed_args);
}
private:
<< "LowerWarpMemory failed to rewrite load to shuffle for index " << op->index
<< " local_index=" << local_index;
PrimExpr load_value = Load(op->dtype, op->buffer_var, local_index, op->predicate);
- PrimExpr mask =
- Call(DataType::UInt(32), builtin::tvm_warp_activemask(), {}, CallNode::Intrinsic);
+ PrimExpr mask = Call(DataType::UInt(32), builtin::tvm_warp_activemask(), {});
return Call(load_value.dtype(), builtin::tvm_warp_shuffle(),
- {mask, load_value, group, width_, warp_size_}, CallNode::Intrinsic);
+ {mask, load_value, group, width_, warp_size_});
} else {
return StmtExprMutator::VisitExpr_(op);
}
IntImm(DataType::Int(32), builtin::kTVMValueContent)};
// load 64 bit version
DataType api_type = APIType(t);
- PrimExpr res = Call(api_type, builtin::tvm_struct_get(), call_args, CallNode::PureIntrinsic);
+ PrimExpr res = Call(api_type, builtin::tvm_struct_get(), call_args);
// cast to the target version.
if (api_type != t) {
res = Cast(t, res);
if (runtime::DeviceAPI::NeedSetDeviceContext(target_device_type)) {
Stmt set_device =
Evaluate(Call(DataType::Int(32), builtin::tvm_call_packed(),
- {StringImm(runtime::symbol::tvm_set_device), device_type, device_id},
- CallNode::Intrinsic));
+ {StringImm(runtime::symbol::tvm_set_device), device_type, device_id}));
body = SeqStmt({set_device, body});
}
}
#include <tvm/runtime/registry.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
+#include <tvm/tir/op_attr_types.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
} else if (op->op.same_as(builtin::address_of())) {
const LoadNode* l = op->args[0].as<LoadNode>();
return this->VisitExpr(l->index);
- } else if (op->is_pure()) {
- for (PrimExpr e : op->args) {
- if (VisitExpr(e)) return true;
+ } else if (auto* ptr_op = op->op.as<OpNode>()) {
+ auto effect_kind = op_call_effect_[GetRef<Op>(ptr_op)];
+ if (effect_kind == CallEffectKind::kPure || effect_kind == CallEffectKind::kExprAnnotation) {
+ for (PrimExpr e : op->args) {
+ if (VisitExpr(e)) return true;
+ }
+ return false;
+ } else {
+ return true;
}
- return false;
} else {
return true;
}
bool BinaryOp(const T* op) {
return VisitExpr(op->a) || VisitExpr(op->b);
}
+
+ OpAttrMap<TCallEffectKind> op_call_effect_ = Op::GetAttrMap<TCallEffectKind>("TCallEffectKind");
};
class UnsafeSelectRewriter : public StmtExprMutator {
if ((unsafe.VisitExpr(op->true_value) || unsafe.VisitExpr(op->false_value)) &&
cond_is_scalar_bool) {
return Call(op->dtype, builtin::if_then_else(),
- {op->condition, op->true_value, op->false_value}, CallNode::Intrinsic);
+ {op->condition, op->true_value, op->false_value});
} else {
return expr;
}
for (PrimExpr ext : m.thread_extent_) {
call_args.push_back(ext);
}
- return Evaluate(
- Call(DataType::Int(32), builtin::tvm_call_packed(), call_args, CallNode::Intrinsic));
+ return Evaluate(Call(DataType::Int(32), builtin::tvm_call_packed(), call_args));
}
// target ir module
stmt = For(vars[i], 0, op->bounds[i]->extent, ForType::Serial, DeviceAPI::None, stmt);
} else {
PrimExpr load = e.buffer.vload(e.RelIndex(args), e.buffer->dtype);
- PrimExpr address =
- Call(DataType::Handle(), builtin::address_of(), {load}, CallNode::PureIntrinsic);
- PrimExpr prefetch =
- Call(op->buffer->dtype, builtin::prefetch(), {address, 0, 3, 1}, CallNode::Intrinsic);
+ PrimExpr address = Call(DataType::Handle(), builtin::address_of(), {load});
+ PrimExpr prefetch = Call(op->buffer->dtype, builtin::prefetch(), {address, 0, 3, 1});
stmt = Evaluate(prefetch);
PrimExpr extent = (op->bounds[i]->extent - 1) / stride + 1;
stmt = For(vars[i], 0, extent, ForType::Serial, DeviceAPI::None, stmt);
if (se->bits_offset != 0) {
offset = make_const(offset.dtype(), se->bits_offset / elem_bits) + offset;
}
- return Call(op->dtype, op->op, {op->args[0], se->alloc_var, offset, extent, op->args[4]},
- op->call_type);
+ return Call(op->dtype, op->op, {op->args[0], se->alloc_var, offset, extent, op->args[4]});
} else {
return StmtExprMutator::VisitExpr_(op);
}
barrier = MakeGlobalBarrier();
} else {
barrier = Evaluate(Call(DataType::Int(32), builtin::tvm_storage_sync(),
- {StringImm(sync_scope_.to_string())}, CallNode::Intrinsic));
+ {StringImm(sync_scope_.to_string())}));
}
// Mutate after query, to avoid stmt change.
auto ret = StmtExprMutator::VisitStmt(stmt);
Stmt InitGlobalBarrier(const AttrStmtNode* op) {
CHECK(op != nullptr);
Array<PrimExpr> pargs = {StringImm(runtime::symbol::tvm_prepare_global_barrier)};
- Stmt prep =
- Evaluate(Call(DataType::Int(32), builtin::tvm_call_packed(), pargs, CallNode::Intrinsic));
+ Stmt prep = Evaluate(Call(DataType::Int(32), builtin::tvm_call_packed(), pargs));
Stmt body = op->body;
for (const auto& kv : rw_stats_) {
const auto& e = kv.second;
}
}
rw_stats_.clear();
- Stmt kinit = Evaluate(
- Call(DataType::Int(32), builtin::tvm_global_barrier_kinit(), {}, CallNode::Intrinsic));
+ Stmt kinit = Evaluate(Call(DataType::Int(32), builtin::tvm_global_barrier_kinit(), {}));
body = SeqStmt({kinit, body});
body = AttrStmt(op->node, op->attr_key, op->value, body);
return SeqStmt({prep, body});
CHECK_EQ(num_work_dim_, thread_extents_.size());
}
return Evaluate(Call(DataType::Int(32), builtin::tvm_storage_sync(),
- {StringImm(sync_scope_.to_string()), is_lead_, num_blocks_},
- CallNode::Intrinsic));
+ {StringImm(sync_scope_.to_string()), is_lead_, num_blocks_}));
}
// data structure.
StorageScope sync_scope_;
int lanes = std::max(t.dtype().lanes(), f.dtype().lanes());
t = BroadcastTo(t, lanes);
f = BroadcastTo(f, lanes);
- return Call(op->dtype.with_lanes(lanes), op->op, {cond, t, f}, op->call_type);
+ return Call(op->dtype.with_lanes(lanes), op->op, {cond, t, f});
}
}
// Call
if (op->args.same_as(new_args)) {
return GetRef<PrimExpr>(op);
} else {
- return Call(op->dtype, op->op, new_args, op->call_type);
+ return Call(op->dtype, op->op, new_args);
}
} else {
int lane = 0;
if (op->args.same_as(new_args)) {
return GetRef<PrimExpr>(op);
} else {
- return Call(op->dtype.with_lanes(lane), op->op, new_args, op->call_type);
+ return Call(op->dtype.with_lanes(lane), op->op, new_args);
}
}
}
}
{
- auto body = Evaluate(Call(DataType::Int(32), builtin::call_extern(), {StringImm("xyz"), x + 1},
- CallNode::Extern));
+ auto body =
+ Evaluate(Call(DataType::Int(32), builtin::call_extern(), {StringImm("xyz"), x + 1}));
auto res = v(std::move(body));
CHECK(res.as<EvaluateNode>()->value.as<CallNode>()->args[1].same_as(x));
}
# Test that components with side effects are not removed
dummy = tvm.ir.GlobalVar("dummy")
- side_effect = lambda *xs: tvm.tir.Call("int32", dummy, xs, tvm.tir.Call.Intrinsic)
+ side_effect = lambda *xs: tvm.tir.Call("int32", dummy, xs)
ck.verify(sum_and_prod((A[k], side_effect(A[10-k])), k)[0],
sum_and_prod((A[k], side_effect(A[10-k])), k)[0])
ck.verify(sum_and_prod((side_effect(A[k]), A[10-k]), k)[0],
nn = 1024
n = tvm.runtime.convert(nn)
A = te.placeholder((n,), name='A', dtype="int32")
- B = te.compute(A.shape, lambda *i: tvm.tir.call_pure_intrin("float32", "tir.reinterpret", A(*i)), name='B')
+ B = te.compute(A.shape, lambda *i: tvm.tir.call_intrin("float32", "tir.reinterpret", A(*i)), name='B')
s = te.create_schedule(B.op)
def check_c():
n = tvm.runtime.convert(4)
A = ib.pointer("float32", name="A")
args = [
- tvm.tir.call_pure_intrin("handle", "tir.address_of", A[0]),
+ tvm.tir.call_intrin("handle", "tir.address_of", A[0]),
0, 3, 1
]
ib.emit(tvm.tir.Evaluate(
tvm.tir.Call(
- "int32", "tir.prefetch", args, tvm.tir.Call.Intrinsic)))
+ "int32", "tir.prefetch", args)))
body = ib.get()
mod = tvm.IRModule.from_expr(
def use_llvm_intrinsic(A, C):
ib = tvm.tir.ir_builder.create()
L = A.vload((0,0))
- I = tvm.tir.call_llvm_intrin('int32', 'llvm.ctlz',
+ I = tvm.tir.call_llvm_pure_intrin('int32', 'llvm.ctlz',
tvm.tir.const(2, 'uint32'), L, tvm.tir.const(0, 'int1'))
S = C.vstore((0,0), I)
ib.emit(S)
ib = tvm.tir.ir_builder.create()
A = ib.pointer("uint8x8", name="A")
z = tvm.tir.const(0, 'int32')
- x = tvm.tir.call_llvm_intrin("uint8x8", "llvm.ctpop.v8i8", tvm.tir.const(1, 'uint32'), A[z])
+ x = tvm.tir.call_llvm_pure_intrin("uint8x8", "llvm.ctpop.v8i8", tvm.tir.const(1, 'uint32'), A[z])
ib.emit(x)
body = ib.get()
mod = tvm.IRModule.from_expr(
assert x.vectors[0] == a
assert x.indices[0].value == 0
- x = tvm.tir.Call("float32", "tir.call_extern", [tvm.tir.StringImm("xyz"), a], tvm.tir.Call.Extern)
+ x = tvm.tir.Call("float32", "tir.call_extern", [tvm.tir.StringImm("xyz"), a])
assert isinstance(x, tvm.tir.Call)
assert x.dtype == "float32"
assert x.op.name == "tir.call_extern"
assert x.args[1] == a
- assert x.call_type == tvm.tir.Call.Extern
v = te.var("aa")
x = tvm.tir.Let(v, 1, v)
def test_bitwise():
x = te.var('x')
y = te.var('y')
- assert str(x << y) == '@tir.shift_left(x: int32, y: int32, dtype=int32, type="pure_intrin")'
- assert str(x >> y) == '@tir.shift_right(x: int32, y: int32, dtype=int32, type="pure_intrin")'
- assert str(x & y) == '@tir.bitwise_and(x: int32, y: int32, dtype=int32, type="pure_intrin")'
- assert str(x | y) == '@tir.bitwise_or(x: int32, y: int32, dtype=int32, type="pure_intrin")'
- assert str(x ^ y) == '@tir.bitwise_xor(x: int32, y: int32, dtype=int32, type="pure_intrin")'
- assert str(10 & x) == '@tir.bitwise_and(10, x: int32, dtype=int32, type="pure_intrin")'
- assert str(10 | x) == '@tir.bitwise_or(10, x: int32, dtype=int32, type="pure_intrin")'
- assert str(10 ^ x) == '@tir.bitwise_xor(10, x: int32, dtype=int32, type="pure_intrin")'
- assert str(10 >> x) == '@tir.shift_right(10, x: int32, dtype=int32, type="pure_intrin")'
- assert str(10 << x) == '@tir.shift_left(10, x: int32, dtype=int32, type="pure_intrin")'
+ assert str(x << y) == '@tir.shift_left(x: int32, y: int32, dtype=int32)'
+ assert str(x >> y) == '@tir.shift_right(x: int32, y: int32, dtype=int32)'
+ assert str(x & y) == '@tir.bitwise_and(x: int32, y: int32, dtype=int32)'
+ assert str(x | y) == '@tir.bitwise_or(x: int32, y: int32, dtype=int32)'
+ assert str(x ^ y) == '@tir.bitwise_xor(x: int32, y: int32, dtype=int32)'
+ assert str(10 & x) == '@tir.bitwise_and(10, x: int32, dtype=int32)'
+ assert str(10 | x) == '@tir.bitwise_or(10, x: int32, dtype=int32)'
+ assert str(10 ^ x) == '@tir.bitwise_xor(10, x: int32, dtype=int32)'
+ assert str(10 >> x) == '@tir.shift_right(10, x: int32, dtype=int32)'
+ assert str(10 << x) == '@tir.shift_left(10, x: int32, dtype=int32)'
assert str(10 % x) == 'floormod(10, x: int32)'
- assert str(~x) == '@tir.bitwise_not(x: int32, dtype=int32, type="pure_intrin")'
+ assert str(~x) == '@tir.bitwise_not(x: int32, dtype=int32)'
assert(tvm.tir.const(1, "int8x2") >> 1).dtype == "int8x2"
assert(x >> tvm.tir.const(1, "int32x2")).dtype == "int32x2"
assert(te.var("z", "int8x2") << tvm.tir.const(1, "int8x2")).dtype == "int8x2"
def test_isnan():
x = te.var('x', 'float32')
- assert str(tvm.tir.isnan(x)) == '@tir.isnan(x: float32, dtype=bool, type="pure_intrin")'
+ assert str(tvm.tir.isnan(x)) == '@tir.isnan(x: float32, dtype=bool)'
assert str(tvm.tir.isnan(x).dtype) == 'bool'
y = te.var('y', 'float16')
- assert str(tvm.tir.isnan(y)) == '@tir.isnan(cast(float32, y: float16), dtype=bool, type="pure_intrin")'
+ assert str(tvm.tir.isnan(y)) == '@tir.isnan(cast(float32, y: float16), dtype=bool)'
z = te.var('z', 'int32')
assert str(tvm.tir.isnan(z)) == 'False'
k = te.var('k', 'int8x2')
def test_legalize():
def to32(v):
uint32_v = topi.cast(v, "uint32")
- uint32_v = tvm.tir.call_pure_intrin(
+ uint32_v = tvm.tir.call_intrin(
"uint32", "tir.shift_left", uint32_v, tvm.tir.const(16, "uint32"))
- return tvm.tir.call_pure_intrin("float32", "tir.reinterpret", uint32_v)
+ return tvm.tir.call_intrin("float32", "tir.reinterpret", uint32_v)
def to16(v):
- uint32_v = tvm.tir.call_pure_intrin("uint32", "tir.reinterpret", v)
- rounding_bias = tvm.tir.call_pure_intrin(
+ uint32_v = tvm.tir.call_intrin("uint32", "tir.reinterpret", v)
+ rounding_bias = tvm.tir.call_intrin(
"uint32", "tir.shift_right", uint32_v, tvm.tir.const(16, "uint32"))
- rounding_bias = tvm.tir.call_pure_intrin(
+ rounding_bias = tvm.tir.call_intrin(
"uint32", "tir.bitwise_and", rounding_bias, tvm.tir.const(1, "uint32"))
rounding_bias = rounding_bias + tvm.tir.const(0x7FFF, "uint16")
uint32_v = uint32_v + rounding_bias
- uint32_v = tvm.tir.call_pure_intrin(
+ uint32_v = tvm.tir.call_intrin(
"uint32", "tir.shift_right", uint32_v, tvm.tir.const(16, "uint32"))
return topi.cast(uint32_v, 'uint16')
def device_context(dev_id):
ctx = tvm.tir.call_extern("handle", "device_context", dev_type, dev_id)
return tvm.tir.Call(
- "handle", "tir.tvm_thread_context", [ctx], tvm.tir.Call.Intrinsic)
+ "handle", "tir.tvm_thread_context", [ctx])
ib = tvm.tir.ir_builder.create()
n = te.var("n")
bbuffer = tvm.tir.decl_buffer((m,), dtype=B.dtype, data=B.asobject())
ib.emit(tvm.tir.call_extern("int32", "Run",
bbuffer.access_ptr("r"),
- tvm.tir.call_pure_intrin("int32", "tir.tvm_context_id")))
+ tvm.tir.call_intrin("int32", "tir.tvm_context_id")))
C[i * nthread + tx] = B[i] + 1
return ib.get()
*/
inline PrimExpr pack_buffer(Buffer buf) {
CHECK_GT(buf->shape.size(), 0) << "buf shape must have at least one element";
- auto shape = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_stack_make_shape(),
- buf->shape, tvm::tir::CallNode::CallType::Intrinsic);
+ auto shape =
+ tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_stack_make_shape(), buf->shape);
PrimExpr strides;
if (buf->strides.size() > 0) {
- strides = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_stack_make_shape(),
- buf->shape, tvm::tir::CallNode::CallType::Intrinsic);
+ strides =
+ tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_stack_make_shape(), buf->shape);
} else {
strides = 0;
}
make_const(DataType::Int(32), static_cast<int64_t>(buf->shape.size())),
make_const(buf->dtype, 0),
buf->elem_offset};
- return tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_stack_make_array(), pack_args,
- tvm::tir::CallNode::CallType::Intrinsic);
+ return tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_stack_make_array(), pack_args);
}
/*!
* \return An expression representing the invocation
*/
inline PrimExpr call_packed(Array<PrimExpr> args) {
// Forward `args` unchanged to an int32-typed Call of the tvm_call_packed builtin.
- return tvm::tir::Call(DataType::Int(32), tvm::tir::builtin::tvm_call_packed(), args,
- tvm::tir::CallNode::CallType::Intrinsic);
+ return tvm::tir::Call(DataType::Int(32), tvm::tir::builtin::tvm_call_packed(), args);
}
} // namespace detail
return compute(
x->shape,
[&](const Array<Var>& i) {
- return tvm::tir::Call(type, tvm::tir::builtin::reinterpret(), {x(i)},
- tvm::tir::CallNode::PureIntrinsic);
+ return tvm::tir::Call(type, tvm::tir::builtin::reinterpret(), {x(i)});
},
name, tag);
}
cnts = tvm.tir.popcount(w_ & x_) - tvm.tir.popcount(~w_ & x_)
else:
cnts = tvm.tir.popcount(w_ & x_)
- upper_half = tvm.tir.call_pure_intrin(
+ upper_half = tvm.tir.call_intrin(
half_dtype, 'tir.vectorhigh', cnts)
- lower_half = tvm.tir.call_pure_intrin(
+ lower_half = tvm.tir.call_intrin(
half_dtype, 'tir.vectorlow', cnts)
cnts8[i] = upper_half + lower_half
for i in range(m//2):
- cnts4[i] = tvm.tir.call_llvm_intrin(half_dtype, vpadd,
- args_2, cnts8[i*2], cnts8[i*2+1])
+ cnts4[i] = tvm.tir.call_llvm_pure_intrin(
+ half_dtype, vpadd, args_2, cnts8[i*2], cnts8[i*2+1])
for i in range(m//4):
- cnts2[i] = tvm.tir.call_llvm_intrin(half_dtype, vpadd,
- args_2, cnts4[i*2], cnts4[i*2+1])
- cnts = tvm.tir.call_pure_intrin(
+ cnts2[i] = tvm.tir.call_llvm_pure_intrin(
+ half_dtype, vpadd, args_2, cnts4[i*2], cnts4[i*2+1])
+ cnts = tvm.tir.call_intrin(
full_dtype, 'tir.vectorcombine', cnts2[0], cnts2[1])
shifted_cnts = cnts << tvm.tir.const(bw+bx, pack_dtype)
- out = tvm.tir.call_llvm_intrin(
+ out = tvm.tir.call_llvm_pure_intrin(
return_dtype, vpadalu,
args_2, zz.vload(0, return_dtype), shifted_cnts)
else: # ki == 8
else:
cnts8[i] = tvm.tir.popcount(w_ & x_)
for i in range(m//2):
- cnts4[i] = tvm.tir.call_llvm_intrin(half_dtype, vpadd,
- args_2, cnts8[i*2], cnts8[i*2+1])
+ cnts4[i] = tvm.tir.call_llvm_pure_intrin(
+ half_dtype, vpadd, args_2, cnts8[i*2], cnts8[i*2+1])
for i in range(m//4):
- cnts2[i] = tvm.tir.call_llvm_intrin(half_dtype, vpadd,
- args_2, cnts4[i*2], cnts4[i*2+1])
- cnts = tvm.tir.call_pure_intrin(
+ cnts2[i] = tvm.tir.call_llvm_pure_intrin(
+ half_dtype, vpadd, args_2, cnts4[i*2], cnts4[i*2+1])
+ cnts = tvm.tir.call_intrin(
full_dtype, 'tir.vectorcombine', cnts2[0], cnts2[1])
shifted_cnts = cnts << tvm.tir.const(bw+bx, pack_dtype)
- out = tvm.tir.call_llvm_intrin(
+ out = tvm.tir.call_llvm_pure_intrin(
return_dtype, vpadalu,
args_2, zz.vload(0, return_dtype), shifted_cnts)
irb.emit(zz.vstore(0, out))
dtype_c = '%s32x%d' % (dtype, int32_lanes)
a_int8 = ins[0].vload([0], dtype_a)
- re_int32 = tvm.tir.call_pure_intrin('%s32' % dtype, 'tir.reinterpret', a_int8)
+ re_int32 = tvm.tir.call_intrin('%s32' % dtype, 'tir.reinterpret', a_int8)
# broadcast a
vec_ai32 = re_int32.astype(dtype_c)
- vec_a = tvm.tir.call_pure_intrin(dtype_b, 'tir.reinterpret', vec_ai32)
+ vec_a = tvm.tir.call_intrin(dtype_b, 'tir.reinterpret', vec_ai32)
vec_b = ins[1].vload([0, 0], dtype_b)
vec_c = outs[0].vload([0], dtype_c)
inst = 'udot' if dtype == 'uint' else 'sdot'
inst = 'llvm.aarch64.neon.%s.v%di32.v%di8' % (
inst, int32_lanes, int32_lanes * num_int8_elements)
- vdot = tvm.tir.call_llvm_intrin(dtype_c,
- inst,
- tvm.tir.const(2, 'uint32'),
- vec_c, vec_a, vec_b)
+ vdot = tvm.tir.call_llvm_pure_intrin(
+ dtype_c,
+ inst,
+ tvm.tir.const(2, 'uint32'),
+ vec_c, vec_a, vec_b)
ib.emit(outs[0].vstore(0, vdot))
return ib.get()
tvm.target.intrin.register_intrin_rule(
"cuda", "atomic_add", cuda_atomic_add_rule, override=True)
-tvm.ir.register_op_attr("tir.atomic_add", "TVectorizable", False)
+tvm.ir.register_op_attr("tir.atomic_add", "TCallEffectKind", tvm.tir.CallEffectKind.Opaque)
# Build a call to the "tir.atomic_add" intrinsic with arguments (x, y);
# the result dtype is taken from y.
def atomic_add(x, y):
- return tvm.tir.call_pure_intrin(y.dtype, "tir.atomic_add", x, y)
+ return tvm.tir.call_intrin(y.dtype, "tir.atomic_add", x, y)
def get_valid_counts_ir(data, valid_count, out, out_indices,
with ib.if_scope(
tvm.tir.all(data[tid * elem_length + score_index] > score_threshold,
tvm.tir.any(id_index < 0, data[tid * elem_length + id_index] >= 0))):
- atomic_add_return[0] = atomic_add(tvm.tir.call_pure_intrin("handle", "tir.address_of",
+ atomic_add_return[0] = atomic_add(tvm.tir.call_intrin("handle", "tir.address_of",
valid_count[i]), one_count)
with ib.for_range(0, elem_length) as k:
out[tid * elem_length + k] = data[tid * elem_length + k]
index_out[offset] = index_out[offset + 1]
index_out[offset + 1] = temp_index[0]
ib.emit(tvm.tir.Call(None, 'tir.tvm_storage_sync',
- tvm.runtime.convert(['shared']),
- tvm.tir.Call.Intrinsic))
+ tvm.runtime.convert(['shared'])))
return ib.get()
with ib.if_scope(iou > nms_threshold):
p_out[base_idx + i] = True
ib.emit(tvm.tir.Call(None, 'tir.tvm_storage_sync',
- tvm.runtime.convert(['shared']),
- tvm.tir.Call.Intrinsic))
+ tvm.runtime.convert(['shared'])))
return ib.get()
indices_out[base_idx + tid * axis_mul_after] = \
tvm.tir.generic.cast(tid, indices_out.dtype)
ib.emit(tvm.tir.Call(None, 'tir.tvm_storage_sync',
- tvm.runtime.convert(['shared']),
- tvm.tir.Call.Intrinsic))
+ tvm.runtime.convert(['shared'])))
idxd = tvm.tir.indexdiv
idxm = tvm.tir.indexmod
indices_out[offset] = indices_out[offset + axis_mul_after]
indices_out[offset + axis_mul_after] = temp_index[0]
ib.emit(tvm.tir.Call(None, 'tir.tvm_storage_sync',
- tvm.runtime.convert(['shared']),
- tvm.tir.Call.Intrinsic))
+ tvm.runtime.convert(['shared'])))
return ib.get()
output[offset] = output[offset + axis_mul_after]
output[offset + axis_mul_after] = temp_index[0]
ib.emit(tvm.tir.Call(None, 'tir.tvm_storage_sync',
- tvm.runtime.convert(['shared']),
- tvm.tir.Call.Intrinsic))
+ tvm.runtime.convert(['shared'])))
return ib.get()
return ib.get()
a_int8 = ins[0].vload([0], "uint8x4")
- re_int32 = tvm.tir.call_pure_intrin('int32', 'tir.reinterpret', a_int8)
+ re_int32 = tvm.tir.call_intrin('int32', 'tir.reinterpret', a_int8)
vec_ai32 = re_int32.astype('int32x16')
- vec_a = tvm.tir.call_pure_intrin('int8x64', 'tir.reinterpret', vec_ai32)
+ vec_a = tvm.tir.call_intrin('int8x64', 'tir.reinterpret', vec_ai32)
vec_b = ins[1].vload([0, 0], "int8x64")
vec_one = tvm.tir.const(1, "int16x32")
- pair_reduction = tvm.tir.call_llvm_intrin('int16x32',
- 'llvm.x86.avx512.pmaddubs.w.512',
- tvm.tir.const(0, 'uint32'),
- vec_a, vec_b)
- quad_reduction = tvm.tir.call_llvm_intrin('int32x16',
- 'llvm.x86.avx512.pmaddw.d.512',
- tvm.tir.const(0, 'uint32'),
- pair_reduction, vec_one)
+ pair_reduction = tvm.tir.call_llvm_pure_intrin(
+ 'int16x32',
+ 'llvm.x86.avx512.pmaddubs.w.512',
+ tvm.tir.const(0, 'uint32'),
+ vec_a, vec_b)
+ quad_reduction = tvm.tir.call_llvm_pure_intrin(
+ 'int32x16',
+ 'llvm.x86.avx512.pmaddw.d.512',
+ tvm.tir.const(0, 'uint32'),
+ pair_reduction, vec_one)
if index == 0:
ib.emit(outs[0].vstore(0, quad_reduction))
else:
return ib.get()
a_int8 = ins[0].vload([0], "uint8x2")
- re_int16 = tvm.tir.call_pure_intrin('int16', 'tir.reinterpret', a_int8)
+ re_int16 = tvm.tir.call_intrin('int16', 'tir.reinterpret', a_int8)
vec_ai16 = re_int16.astype('int16x32')
- vec_a = tvm.tir.call_pure_intrin('int8x64', 'tir.reinterpret', vec_ai16)
+ vec_a = tvm.tir.call_intrin('int8x64', 'tir.reinterpret', vec_ai16)
for i in range(4):
vec_b = ins[1].vload([i*32, 0], "int8x64")
- pair_reduction = tvm.tir.call_llvm_intrin('int16x32',
- 'llvm.x86.avx512.pmaddubs.w.512',
- tvm.tir.const(0, 'uint32'),
- vec_a, vec_b)
+ pair_reduction = tvm.tir.call_llvm_pure_intrin(
+ 'int16x32',
+ 'llvm.x86.avx512.pmaddubs.w.512',
+ tvm.tir.const(0, 'uint32'),
+ vec_a, vec_b)
if index == 0:
ib.emit(outs[0].vstore([i*32], pair_reduction))
else:
return ib.get()
a_int8 = ins[0].vload([0], "uint8x4")
- re_int32 = tvm.tir.call_pure_intrin('int32', 'tir.reinterpret', a_int8)
+ re_int32 = tvm.tir.call_intrin('int32', 'tir.reinterpret', a_int8)
vec_ai32 = re_int32.astype('int32x16')
vec_b = ins[1].vload([0, 0], "int8x64")
llvm_id = tvm.target.codegen.llvm_lookup_intrinsic_id(vnni_inst_name)
if llvm_id != 0: # VNNI is available for current LLVM version
- vec_bi32 = tvm.tir.call_pure_intrin('int32x16', 'tir.reinterpret', vec_b)
+ vec_bi32 = tvm.tir.call_intrin('int32x16', 'tir.reinterpret', vec_b)
vec_zero = tvm.tir.const(0, "int32x16")
- quad_reduction = tvm.tir.call_llvm_intrin('int32x16',
- 'llvm.x86.avx512.vpdpbusd.512',
- tvm.tir.const(0, 'uint32'),
- vec_zero,
- vec_ai32, vec_bi32)
+ quad_reduction = tvm.tir.call_llvm_pure_intrin(
+ 'int32x16',
+ 'llvm.x86.avx512.vpdpbusd.512',
+ tvm.tir.const(0, 'uint32'),
+ vec_zero,
+ vec_ai32, vec_bi32)
else: # Fall back to the normal AVX512
- vec_a = tvm.tir.call_pure_intrin('int8x64', 'tir.reinterpret', vec_ai32)
+ vec_a = tvm.tir.call_intrin('int8x64', 'tir.reinterpret', vec_ai32)
vec_one = tvm.tir.const(1, "int16x32")
- pair_reduction = tvm.tir.call_llvm_intrin('int16x32',
- 'llvm.x86.avx512.pmaddubs.w.512',
- tvm.tir.const(0, 'uint32'),
- vec_a, vec_b)
- quad_reduction = tvm.tir.call_llvm_intrin('int32x16',
- 'llvm.x86.avx512.pmaddw.d.512',
- tvm.tir.const(0, 'uint32'),
- pair_reduction, vec_one)
+ pair_reduction = tvm.tir.call_llvm_pure_intrin(
+ 'int16x32',
+ 'llvm.x86.avx512.pmaddubs.w.512',
+ tvm.tir.const(0, 'uint32'),
+ vec_a, vec_b)
+ quad_reduction = tvm.tir.call_llvm_pure_intrin(
+ 'int32x16',
+ 'llvm.x86.avx512.pmaddw.d.512',
+ tvm.tir.const(0, 'uint32'),
+ pair_reduction, vec_one)
if index == 0:
ib.emit(outs[0].vstore(0, quad_reduction))
def mylog(x):
"""customized log intrinsic function"""
- return tvm.tir.call_pure_intrin(x.dtype, "tir.mylog", x)
+ return tvm.tir.call_intrin(x.dtype, "tir.mylog", x)
def my_cuda_mylog_rule(op):
return op
# new op registration is triggered by registering an attribute of the op
-tvm.ir.register_op_attr("tir.mylog", "TVectorizable", True)
+tvm.ir.register_op_attr("tir.mylog", "TCallEffectKind", tvm.tir.CallEffectKind.Pure)
tvm.target.register_intrin_rule("cuda", "mylog", my_cuda_mylog_rule, override=True)
n = te.var("n")
self.vta_push_uop = tvm.tir.StringImm("VTAPushGEMMOp")
ctx = tvm.tir.call_intrin("handle", "tir.vta.command_handle")
self.command_handle = tvm.tir.Call(
- "handle", "tir.tvm_thread_context", [ctx],
- tvm.tir.Call.Intrinsic)
+ "handle", "tir.tvm_thread_context", [ctx])
self.DEBUG_NO_SYNC = False
env._dev_ctx = self
self.gemm = intrin.gemm(env, env.mock_mode)
# register a dummy info to trigger registration of the ops
# change the info to lowering rule later.
-tvm.ir.register_op_attr("tir.vta.coproc_sync", "TVectorizable", False)
-tvm.ir.register_op_attr("tir.vta.coproc_dep_push", "TVectorizable", False)
-tvm.ir.register_op_attr("tir.vta.coproc_dep_pop", "TVectorizable", False)
+tvm.ir.register_op_attr("tir.vta.coproc_sync", "TCallEffectKind", tvm.tir.CallEffectKind.Opaque)
+tvm.ir.register_op_attr("tir.vta.coproc_dep_push", "TCallEffectKind", tvm.tir.CallEffectKind.Opaque)
+tvm.ir.register_op_attr("tir.vta.coproc_dep_pop", "TCallEffectKind", tvm.tir.CallEffectKind.Opaque)
+tvm.ir.register_op_attr("tir.vta.uop_push", "TCallEffectKind", tvm.tir.CallEffectKind.Opaque)
tvm.ir.register_op_attr("tir.vta.uop_push", "TGlobalSymbol", "VTAUopPush")
+
tvm.ir.register_op_attr("tir.vta.command_handle", "TGlobalSymbol", "VTATLSCommandHandle")
+tvm.ir.register_op_attr("tir.vta.command_handle", "TCallEffectKind", tvm.tir.CallEffectKind.Opaque)
def _init_env():
if _match_pragma(stmt, "coproc_sync"):
success[0] = True
sync = tvm.tir.Call(
- "int32", "vta.coproc_sync", [], tvm.tir.Call.Intrinsic)
+ "int32", "vta.coproc_sync", [])
return tvm.tir.SeqStmt([stmt.body, tvm.tir.Evaluate(sync)])
if _match_pragma(stmt, "trim_loop"):
op = stmt.body