[TIR][OP][API-CHANGE] Remove CallNode.call_type in favor of attribute. (#5937)
authorTianqi Chen <tqchen@users.noreply.github.com>
Sat, 27 Jun 2020 17:54:26 +0000 (10:54 -0700)
committerGitHub <noreply@github.com>
Sat, 27 Jun 2020 17:54:26 +0000 (10:54 -0700)
This is a followup refactor for tir::Call.
Now that we have switched call->name to call->op, the function effect property
can be registered through the op itself, so we no longer need the call_type in the CallNode.

- Introduce CallEffectKind to provide a more fine grained categorization of calls.
- Introduce call_pure_extern and call_llvm_pure_intrin to
  allow us to indicate pure calls in those cases.
- Migrate existing usecases to the new API.

82 files changed:
include/tvm/tir/builtin.h
include/tvm/tir/expr.h
include/tvm/tir/op.h
include/tvm/tir/op_attr_types.h
python/tvm/te/hybrid/calls.py
python/tvm/tir/__init__.py
python/tvm/tir/expr.py
python/tvm/tir/ir_builder.py
python/tvm/tir/op.py
src/arith/ir_mutator_with_analyzer.cc
src/arith/pattern_match.h
src/contrib/hybrid/codegen_hybrid.cc
src/printer/tir_text_printer.cc
src/target/intrin_rule.cc
src/target/intrin_rule.h
src/target/llvm/codegen_arm.cc
src/target/llvm/codegen_llvm.cc
src/target/llvm/codegen_llvm.h
src/target/llvm/codegen_x86_64.cc
src/target/llvm/intrin_rule_llvm.h
src/target/llvm/intrin_rule_nvptx.cc
src/target/llvm/intrin_rule_rocm.cc
src/target/source/codegen_c.cc
src/target/source/codegen_c.h
src/target/source/intrin_rule_aocl.cc
src/target/source/intrin_rule_cuda.cc
src/target/source/intrin_rule_metal.cc
src/target/source/intrin_rule_opencl.cc
src/target/source/intrin_rule_vhls.cc
src/target/spirv/codegen_spirv.cc
src/target/spirv/intrin_rule_spirv.cc
src/te/autodiff/jacobian.cc
src/te/operation/compute_op.cc
src/te/operation/cross_thread_reduction.cc
src/te/operation/extern_op.cc
src/te/operation/tensor_compute_op.cc
src/te/operation/tensorize.cc
src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc
src/tir/analysis/side_effect.cc
src/tir/ir/buffer.cc
src/tir/ir/expr.cc
src/tir/ir/expr_functor.cc
src/tir/ir/stmt.cc
src/tir/op/builtin.cc
src/tir/op/op.cc
src/tir/op/runtime.cc
src/tir/transforms/arg_binder.cc
src/tir/transforms/bf16_legalize.cc
src/tir/transforms/coproc_sync.cc
src/tir/transforms/inject_virtual_thread.cc
src/tir/transforms/ir_util.h
src/tir/transforms/lower_intrin.cc
src/tir/transforms/lower_thread_allreduce.cc
src/tir/transforms/lower_tvm_builtin.cc
src/tir/transforms/lower_warp_memory.cc
src/tir/transforms/make_packed_api.cc
src/tir/transforms/rewrite_unsafe_select.cc
src/tir/transforms/split_host_device.cc
src/tir/transforms/storage_flatten.cc
src/tir/transforms/storage_rewrite.cc
src/tir/transforms/thread_storage_sync.cc
src/tir/transforms/vectorize_loop.cc
tests/cpp/ir_functor_test.cc
tests/python/unittest/test_arith_canonical_simplify.py
tests/python/unittest/test_target_codegen_c_host.py
tests/python/unittest/test_target_codegen_llvm.py
tests/python/unittest/test_tir_constructor.py
tests/python/unittest/test_tir_nodes.py
tests/python/unittest/test_tir_transform_bf16_legalize.py
tests/python/unittest/test_tir_transform_combine_context_call.py
tests/python/unittest/test_tir_transform_inject_virtual_thread.py
topi/include/topi/detail/extern.h
topi/include/topi/elemwise.h
topi/python/topi/arm_cpu/bitserial_conv2d.py
topi/python/topi/arm_cpu/tensor_intrin.py
topi/python/topi/cuda/nms.py
topi/python/topi/cuda/rcnn/proposal.py
topi/python/topi/cuda/sort.py
topi/python/topi/x86/tensor_intrin.py
tutorials/language/intrin_math.py
vta/python/vta/environment.py
vta/python/vta/transform.py

index 96526cc..464ce6c 100644 (file)
@@ -153,10 +153,24 @@ TVM_DLL const Op& fma();
 TVM_DLL const Op& call_extern();
 
 /*!
+ * \brief Call an pure extern C function with given name
+ *        and signature from the types of args in the runtime environment.
+ *
+ *  Type call_pure_extern(name, args...) {
+ *     return dlsym(name)(args...);
+ *  }
+ *
+ * \note This intrinsic does not provide any type checking,
+ *       and is main used for backward compatibility reasons.
+ *       Always consider use pre-registered and typed tvm::Op first.
+ */
+TVM_DLL const Op& call_pure_extern();
+
+/*!
  * \brief Call an LLVM intrinsic with a given intrinsic id
  *        and signature from the types of args in the runtime environment.
  *
- *  Type call_llvm_intrin(intrin_id, args...) {
+ *  Type call_llvm_pure_intrin(intrin_id, args...) {
  *     return dlsym(name)(args...);
  *  }
  *
@@ -165,15 +179,27 @@ TVM_DLL const Op& call_extern();
 TVM_DLL const Op& call_llvm_intrin();
 
 /*!
- * \brief Call an SPIRV GLSL450 intrinsic.
+ * \brief Call an LLVM pure intrinsic with a given intrinsic id
+ *        and signature from the types of args in the runtime environment.
+ *
+ *  Type call_llvm_pure_intrin(intrin_id, args...) {
+ *     return dlsym(name)(args...);
+ *  }
+ *
+ * \note This op does not provide any type checking.
+ */
+TVM_DLL const Op& call_llvm_pure_intrin();
+
+/*!
+ * \brief Call an SPIRV pure GLSL450 intrinsic.
  *
- *  Type call_spirv_glsl450(intrin_id, args...) {
+ *  Type call_spirv_pure_glsl450(intrin_id, args...) {
  *     return dlsym(name)(args...);
  *  }
  *
  * \note This op does not provide any type checking.
  */
-TVM_DLL const Op& call_spirv_glsl450();
+TVM_DLL const Op& call_spirv_pure_glsl450();
 
 // TODO(tvm-team) revisit the builtins below
 // some of them can simply become ops with special codegen attr.
index f0e6d89..100d163 100644 (file)
@@ -875,19 +875,6 @@ class Let : public PrimExpr {
  */
 class CallNode : public PrimExprNode {
  public:
-  /*! \brief Possible types of calls. */
-  enum CallType : int {
-    /*! \brief Extern "C" function. */
-    Extern = 0,
-    /*! \brief Extern CXX function. */
-    ExternCPlusPlus = 1,
-    /*! \brief Extern "C" without side-effect. */
-    PureExtern = 2,
-    /*! \brief Intrinsic functions. */
-    Intrinsic = 4,
-    /*! \brief Intrinsic functions that are pure. */
-    PureIntrinsic = 5
-  };
   /*!
    * \brief The operator(function) being invoked
    *
@@ -898,31 +885,22 @@ class CallNode : public PrimExprNode {
 
   /*! \brief The arguments. */
   Array<PrimExpr> args;
-  /*! \brief Type of calls. */
-  CallType call_type;
-
   void VisitAttrs(AttrVisitor* v) {
     v->Visit("dtype", &dtype);
     v->Visit("op", &op);
     v->Visit("args", &args);
-    v->Visit("call_type", &call_type);
   }
 
   bool SEqualReduce(const CallNode* other, SEqualReducer equal) const {
-    return equal(dtype, other->dtype) && equal(op, other->op) && equal(args, other->args) &&
-           equal(call_type, other->call_type);
+    return equal(dtype, other->dtype) && equal(op, other->op) && equal(args, other->args);
   }
 
   void SHashReduce(SHashReducer hash_reduce) const {
     hash_reduce(dtype);
     hash_reduce(op);
     hash_reduce(args);
-    hash_reduce(call_type);
   }
 
-  /*! \return Whether call node is pure. */
-  bool is_pure() const { return (call_type == PureExtern || call_type == PureIntrinsic); }
-
   static constexpr const char* _type_key = "tir.Call";
   TVM_DECLARE_FINAL_OBJECT_INFO(CallNode, PrimExprNode);
 };
@@ -933,9 +911,7 @@ class CallNode : public PrimExprNode {
  */
 class Call : public PrimExpr {
  public:
-  using CallType = CallNode::CallType;
-
-  TVM_DLL Call(DataType dtype, RelayExpr op, Array<PrimExpr> args, CallType call_type);
+  TVM_DLL Call(DataType dtype, RelayExpr op, Array<PrimExpr> args);
   TVM_DEFINE_OBJECT_REF_METHODS(Call, PrimExpr, CallNode);
 };
 
index 09eb33c..34cb52f 100644 (file)
@@ -553,10 +553,10 @@ TVM_DLL PrimExpr trunc(PrimExpr x);
 TVM_DLL PrimExpr LargeUIntImm(DataType dtype, int64_t low, int64_t high);
 
 // Intrinsic operators
-#define TVM_DECLARE_INTRIN_UNARY(OpName)                                \
-  inline PrimExpr OpName(PrimExpr x) {                                  \
-    static const Op& op = Op::Get("tir." #OpName);                      \
-    return tir::Call(x.dtype(), op, {x}, tir::CallNode::PureIntrinsic); \
+#define TVM_DECLARE_INTRIN_UNARY(OpName)           \
+  inline PrimExpr OpName(PrimExpr x) {             \
+    static const Op& op = Op::Get("tir." #OpName); \
+    return tir::Call(x.dtype(), op, {x});          \
   }
 
 TVM_DECLARE_INTRIN_UNARY(exp);
@@ -583,10 +583,10 @@ TVM_DECLARE_INTRIN_UNARY(acosh);
 TVM_DECLARE_INTRIN_UNARY(asinh);
 TVM_DECLARE_INTRIN_UNARY(atanh);
 
-#define TVM_DECLARE_INTRIN_BINARY(OpName)                                  \
-  inline PrimExpr OpName(PrimExpr x, PrimExpr y) {                         \
-    static const Op& op = Op::Get("tir." #OpName);                         \
-    return tir::Call(x.dtype(), op, {x, y}, tir::CallNode::PureIntrinsic); \
+#define TVM_DECLARE_INTRIN_BINARY(OpName)          \
+  inline PrimExpr OpName(PrimExpr x, PrimExpr y) { \
+    static const Op& op = Op::Get("tir." #OpName); \
+    return tir::Call(x.dtype(), op, {x, y});       \
   }
 
 TVM_DECLARE_INTRIN_BINARY(atan2);
index d7c1350..ec7fc17 100644 (file)
@@ -43,6 +43,43 @@ using TGlobalSymbol = String;
  */
 using TVectorizable = bool;
 
+/*!
+ * \brief The effect type of the call.
+ */
+enum class CallEffectKind : int {
+  /*! \brief Function corresponds to an annotation(e.g. likely) and can translate to identity. */
+  kExprAnnotation = 0,
+  /*!
+   * \brief Pure function that do not interacts
+   *        with any external state.
+   */
+  kPure = 1,
+  /*!
+   * \brief Function's that may read from states(e.g. RAM)
+   */
+  kReadState = 2,
+  /*!
+   * \brief Function that may read/write from states(e.g. RAM).
+   */
+  kUpdateState = 3,
+  /*!
+   * \brief Opaque function, cannot make any assumption
+   */
+  kOpaque = kUpdateState,
+  /*!
+   * \brief Special intrinsic to annotate call arguments info
+   *        only valid as a direct argument to a call.
+   */
+  kSpecialCallArg = 4,
+  /*!
+   * \brief Embed opaque information in the Expr, cannot be codegen.
+   */
+  kEmbedInfo = 5
+};
+
+/*! \brief Use integer to record the kind. */
+using TCallEffectKind = Integer;
+
 }  // namespace tir
 }  // namespace tvm
 #endif  // TVM_TIR_OP_ATTR_TYPES_H_
index a119c20..78ed1dc 100644 (file)
@@ -22,7 +22,7 @@ import tvm.te
 from tvm.ir.container import Array
 from tvm import target as _tgt
 from tvm.tir import expr as _expr
-from tvm.tir import call_pure_intrin
+from tvm.tir import call_intrin
 from tvm.tir.stmt import For
 
 from .util import _internal_assert
@@ -148,7 +148,7 @@ def likely(func_id, args):
     _internal_assert(args.__len__() == 1, \
                      "Only one expression can be likely")
     _internal_assert(func_id == "likely", "This function cannot be directly invoked!")
-    return call_pure_intrin(args[0].dtype, 'tir.likely', *args)
+    return call_intrin(args[0].dtype, 'tir.likely', *args)
 
 
 def max_num_threads(func_id, args):
index 90ccde4..9dbdc07 100644 (file)
@@ -24,8 +24,8 @@ from .data_layout import Layout, BijectiveLayout, bijective_layout, layout
 from .expr import Var, SizeVar, Reduce, FloatImm, IntImm, StringImm, Cast
 from .expr import Add, Sub, Mul, Div, Mod, FloorDiv, FloorMod
 from .expr import Min, Max, EQ, NE, LT, LE, GT, GE, And, Or, Not
-from .expr import Select, BufferLoad, ProducerLoad, Load, Ramp, Broadcast, Shuffle, Call, Let
-from .expr import IterVar, Any
+from .expr import Select, BufferLoad, ProducerLoad, Load, Ramp, Broadcast, Shuffle
+from .expr import Call, CallEffectKind, Let, IterVar, Any
 
 from .stmt import Stmt, LetStmt, AssertStmt, For
 from .stmt import BufferStore, BufferRealize, Store, ProducerStore, Allocate, AttrStmt
@@ -34,8 +34,8 @@ from .stmt import IfThenElse, Evaluate, Prefetch, stmt_seq, stmt_list
 
 from .function import PrimFunc
 
-from .op import call_packed, call_pure_intrin, call_intrin, call_pure_extern, call_extern
-from .op import call_llvm_intrin, all, any, min_value, max_value, trace
+from .op import call_packed, call_intrin, call_pure_extern, call_extern
+from .op import call_llvm_intrin, call_llvm_pure_intrin, all, any, min_value, max_value, trace
 from .op import exp, exp2, exp10, log, log2, log10, log1p, ldexp
 from .op import sin, sinh, asin, asinh
 from .op import cos, cosh, acos, acosh
index 386badf..c8f151e 100644 (file)
@@ -266,6 +266,23 @@ class NotEqualOp(ObjectGeneric, ExprOp):
         return _ffi_api._OpNE(self.a, self.b)
 
 
+class IntImmEnum(ObjectGeneric):
+    """Lazily evaluate an IntImm in case
+    the constructor is not available in runtime.
+
+    Parameters
+    ----------
+    value : int
+        The enum value
+    """
+    def __init__(self, value):
+        self.value = value
+
+    def asobject(self):
+        """Convert object."""
+        return IntImm("int32", self.value)
+
+
 class PrimExprWithOp(ExprOp, PrimExpr):
     """Helper base class to inherit from PrimExpr."""
     # In Python3, We have to explicitly tell interpreter to retain __hash__ if we overide __eq__
@@ -959,6 +976,16 @@ class Shuffle(PrimExprWithOp):
             _ffi_api.Shuffle, vectors, indices)
 
 
+class CallEffectKind:
+    """Possible kinds of Call effects."""
+    # only expose up to opaque
+    ExprAnnotation = IntImmEnum(0)
+    Pure = IntImmEnum(1)
+    ReadState = IntImmEnum(2)
+    UpdateState = IntImmEnum(3)
+    Opaque = UpdateState
+
+
 @tvm._ffi.register_object("tir.Call")
 class Call(PrimExprWithOp):
     """Call node.
@@ -974,16 +1001,8 @@ class Call(PrimExprWithOp):
 
     args : list of Expr
         The input arguments to the call
-
-    call_type : int
-        The type of the call
     """
-    Extern = 0
-    ExternCPlusPlus = 1
-    PureExtern = 2
-    Intrinsic = 4
-    PureIntrinsic = 5
-    def __init__(self, dtype, op, args, call_type):
+    def __init__(self, dtype, op, args):
         if isinstance(op, str):
             if not op.startswith("tir."):
                 raise ValueError(
@@ -992,7 +1011,7 @@ class Call(PrimExprWithOp):
                      "certain about the intrinsic name, pass in Op.get(name) instead") % op)
             op = Op.get(op)
         self.__init_handle_by_constructor__(
-            _ffi_api.Call, dtype, op, args, call_type)
+            _ffi_api.Call, dtype, op, args)
 
 
 @tvm._ffi.register_object("tir.Let")
index 089127c..20180d1 100644 (file)
@@ -379,8 +379,7 @@ class IRBuilder(object):
         expr : Expr
             The expression will likely tag.
         """
-        return _expr.Call(expr.dtype, "tir.likely", [expr],
-                          _expr.Call.PureIntrinsic)
+        return _expr.Call(expr.dtype, "tir.likely", [expr])
 
     def get(self):
         """Return the builded IR.
index 6826241..cbbd59f 100644 (file)
@@ -29,10 +29,8 @@ def _pack_buffer(buf):
     """Build intrinsics that packs the buffer.
     """
     assert buf.shape
-    shape = Call("handle", "tir.tvm_stack_make_shape", buf.shape,
-                 Call.Intrinsic)
-    strides = Call("handle", "tir.tvm_stack_make_shape", buf.strides,
-                   Call.Intrinsic) if buf.strides else 0
+    shape = Call("handle", "tir.tvm_stack_make_shape", buf.shape)
+    strides = Call("handle", "tir.tvm_stack_make_shape", buf.strides) if buf.strides else 0
     pack_args = [buf.data,
                  shape,
                  strides,
@@ -40,7 +38,7 @@ def _pack_buffer(buf):
                  const(0, dtype=buf.dtype),
                  buf.elem_offset]
     return Call("handle", Op.get("tir.tvm_stack_make_array"),
-                pack_args, Call.Intrinsic)
+                pack_args)
 
 def call_packed(*args):
     """Build expression by call an external packed function.
@@ -68,11 +66,11 @@ def call_packed(*args):
     """
     call_args = [_pack_buffer(x) if isinstance(x, Buffer) else x for x in args]
     return Call(
-        "int32", Op.get("tir.tvm_call_packed"), call_args, Call.Intrinsic)
+        "int32", Op.get("tir.tvm_call_packed"), call_args)
 
 
-def call_pure_intrin(dtype, func_name, *args):
-    """Build expression by calling a pure intrinsic function.
+def call_intrin(dtype, func_name, *args):
+    """Build expression by calling an intrinsic function.
 
     Intrinsics can be overloaded with multiple data types via
     the intrinsic translation rule.
@@ -93,16 +91,12 @@ def call_pure_intrin(dtype, func_name, *args):
     call : PrimExpr
         The call expression.
     """
-    args = convert(args)
     return Call(
-        dtype, func_name, convert(args), Call.PureIntrinsic)
+        dtype, func_name, convert(args))
 
 
-def call_intrin(dtype, func_name, *args):
-    """Build expression by calling an intrinsic function.
-
-    Intrinsics can be overloaded with multiple data types via
-    the intrinsic translation rule.
+def call_pure_extern(dtype, func_name, *args):
+    """Build expression by calling a pure extern function.
 
     Parameters
     ----------
@@ -110,7 +104,7 @@ def call_intrin(dtype, func_name, *args):
         The data type of the result.
 
     func_name: str
-        The intrinsic function name.
+        The extern function name.
 
     args : list
         Positional arguments.
@@ -120,13 +114,12 @@ def call_intrin(dtype, func_name, *args):
     call : PrimExpr
         The call expression.
     """
-    args = convert(args)
     return Call(
-        dtype, func_name, convert(args), Call.Intrinsic)
+        dtype, Op.get("tir.call_pure_extern"), convert((StringImm(func_name),) + args))
 
 
-def call_pure_extern(dtype, func_name, *args):
-    """Build expression by calling a pure extern function.
+def call_extern(dtype, func_name, *args):
+    """Build expression by calling a extern function.
 
     Parameters
     ----------
@@ -145,34 +138,39 @@ def call_pure_extern(dtype, func_name, *args):
         The call expression.
     """
     return Call(
-        dtype, Op.get("tir.call_extern"), convert((StringImm(func_name),) + args), Call.PureExtern)
+        dtype, Op.get("tir.call_extern"), convert((StringImm(func_name),) + args))
 
 
-def call_extern(dtype, func_name, *args):
-    """Build expression by calling a extern function.
+def call_llvm_intrin(dtype, name, *args):
+    """Build expression by calling a llvm intrinsic function
 
     Parameters
     ----------
     dtype : str
-        The data type of the result.
+       The data type of the result.
 
-    func_name: str
-        The extern function name.
+    name : str
+       The name of the llvm intrinsic function.
 
     args : list
-        Positional arguments.
+       Poistional arguments.
 
     Returns
     -------
     call : PrimExpr
         The call expression.
     """
-    return Call(
-        dtype, Op.get("tir.call_extern"), convert((StringImm(func_name),) + args), Call.Extern)
+    # pylint: disable=import-outside-toplevel
+    from tvm.target import codegen
+    llvm_id = codegen.llvm_lookup_intrinsic_id(name)
+    assert llvm_id != 0, "%s is not an LLVM intrinsic" % name
+    return call_intrin(
+        dtype, Op.get("tir.call_llvm_intrin"),
+        tvm.tir.const(llvm_id, 'uint32'), *args)
 
 
-def call_llvm_intrin(dtype, name, *args):
-    """Build expression by calling an llvm intrinsic function
+def call_llvm_pure_intrin(dtype, name, *args):
+    """Build expression by calling a pure llvm intrinsic function
 
     Parameters
     ----------
@@ -194,8 +192,9 @@ def call_llvm_intrin(dtype, name, *args):
     from tvm.target import codegen
     llvm_id = codegen.llvm_lookup_intrinsic_id(name)
     assert llvm_id != 0, "%s is not an LLVM intrinsic" % name
-    return call_pure_intrin(dtype, Op.get("tir.call_llvm_intrin"),
-                            tvm.tir.const(llvm_id, 'uint32'), *args)
+    return call_intrin(
+        dtype, Op.get("tir.call_llvm_pure_intrin"),
+        tvm.tir.const(llvm_id, 'uint32'), *args)
 
 
 def any(*args):
@@ -279,7 +278,7 @@ def trace(args, trace_action="tvm.default_trace_action"):
     call_args = [_pack_buffer(x) if isinstance(x, Buffer) else x for x in args]
     call_args.insert(0, trace_action)
     return tvm.tir.Call(
-        args[-1].dtype, Op.get("tir.tvm_call_trace_packed"), call_args, tvm.tir.Call.Intrinsic)
+        args[-1].dtype, Op.get("tir.tvm_call_trace_packed"), call_args)
 
 
 
@@ -328,7 +327,7 @@ def exp(x):
     y : PrimExpr
         The result.
     """
-    return call_pure_intrin(x.dtype, "tir.exp", x)
+    return call_intrin(x.dtype, "tir.exp", x)
 
 
 def exp2(x):
@@ -344,7 +343,7 @@ def exp2(x):
     y : PrimExpr
         The result.
     """
-    return call_pure_intrin(x.dtype, "tir.exp2", x)
+    return call_intrin(x.dtype, "tir.exp2", x)
 
 
 def exp10(x):
@@ -360,7 +359,7 @@ def exp10(x):
     y : PrimExpr
         The result.
     """
-    return call_pure_intrin(x.dtype, "tir.exp10", x)
+    return call_intrin(x.dtype, "tir.exp10", x)
 
 
 def erf(x):
@@ -376,7 +375,7 @@ def erf(x):
     y : PrimExpr
         The result.
     """
-    return call_pure_intrin(x.dtype, "tir.erf", x)
+    return call_intrin(x.dtype, "tir.erf", x)
 
 
 def tanh(x):
@@ -392,7 +391,7 @@ def tanh(x):
     y : PrimExpr
         The result.
     """
-    return call_pure_intrin(x.dtype, "tir.tanh", x)
+    return call_intrin(x.dtype, "tir.tanh", x)
 
 
 def sigmoid(x):
@@ -408,7 +407,7 @@ def sigmoid(x):
     y : PrimExpr
         The result.
     """
-    return call_pure_intrin(x.dtype, "tir.sigmoid", x)
+    return call_intrin(x.dtype, "tir.sigmoid", x)
 
 
 def log(x):
@@ -424,7 +423,7 @@ def log(x):
     y : PrimExpr
         The result.
     """
-    return call_pure_intrin(x.dtype, "tir.log", x)
+    return call_intrin(x.dtype, "tir.log", x)
 
 
 def log2(x):
@@ -440,7 +439,7 @@ def log2(x):
     y : PrimExpr
         The result.
     """
-    return call_pure_intrin(x.dtype, "tir.log2", x)
+    return call_intrin(x.dtype, "tir.log2", x)
 
 
 def log10(x):
@@ -456,7 +455,7 @@ def log10(x):
     y : PrimExpr
         The result.
     """
-    return call_pure_intrin(x.dtype, "tir.log10", x)
+    return call_intrin(x.dtype, "tir.log10", x)
 
 
 def log1p(x):
@@ -472,7 +471,7 @@ def log1p(x):
     y : PrimExpr
         The result.
     """
-    return call_pure_intrin(x.dtype, "tir.log1p", x)
+    return call_intrin(x.dtype, "tir.log1p", x)
 
 
 def tan(x):
@@ -488,7 +487,7 @@ def tan(x):
     y : PrimExpr
         The result.
     """
-    return call_pure_intrin(x.dtype, "tir.tan", x)
+    return call_intrin(x.dtype, "tir.tan", x)
 
 
 def cos(x):
@@ -504,7 +503,7 @@ def cos(x):
     y : PrimExpr
         The result.
     """
-    return call_pure_intrin(x.dtype, "tir.cos", x)
+    return call_intrin(x.dtype, "tir.cos", x)
 
 
 def cosh(x):
@@ -520,7 +519,7 @@ def cosh(x):
     y : PrimExpr
         The result.
     """
-    return call_pure_intrin(x.dtype, "tir.cosh", x)
+    return call_intrin(x.dtype, "tir.cosh", x)
 
 
 def acos(x):
@@ -536,7 +535,7 @@ def acos(x):
     y : PrimExpr
         The result.
     """
-    return call_pure_intrin(x.dtype, "tir.acos", x)
+    return call_intrin(x.dtype, "tir.acos", x)
 
 
 def acosh(x):
@@ -552,7 +551,7 @@ def acosh(x):
     y : PrimExpr
         The result.
     """
-    return call_pure_intrin(x.dtype, "tir.acosh", x)
+    return call_intrin(x.dtype, "tir.acosh", x)
 
 
 def sin(x):
@@ -568,7 +567,7 @@ def sin(x):
     y : PrimExpr
         The result.
     """
-    return call_pure_intrin(x.dtype, "tir.sin", x)
+    return call_intrin(x.dtype, "tir.sin", x)
 
 
 def sinh(x):
@@ -584,7 +583,7 @@ def sinh(x):
     y : PrimExpr
         The result.
     """
-    return call_pure_intrin(x.dtype, "tir.sinh", x)
+    return call_intrin(x.dtype, "tir.sinh", x)
 
 
 def asin(x):
@@ -600,7 +599,7 @@ def asin(x):
     y : PrimExpr
         The result.
     """
-    return call_pure_intrin(x.dtype, "tir.asin", x)
+    return call_intrin(x.dtype, "tir.asin", x)
 
 
 def asinh(x):
@@ -616,7 +615,7 @@ def asinh(x):
     y : PrimExpr
         The result.
     """
-    return call_pure_intrin(x.dtype, "tir.asinh", x)
+    return call_intrin(x.dtype, "tir.asinh", x)
 
 
 def atan(x):
@@ -632,7 +631,7 @@ def atan(x):
     y : PrimExpr
         The result.
     """
-    return call_pure_intrin(x.dtype, "tir.atan", x)
+    return call_intrin(x.dtype, "tir.atan", x)
 
 
 def atanh(x):
@@ -648,7 +647,7 @@ def atanh(x):
     y : PrimExpr
         The result.
     """
-    return call_pure_intrin(x.dtype, "tir.atanh", x)
+    return call_intrin(x.dtype, "tir.atanh", x)
 
 
 def atan2(x1, x2):
@@ -667,7 +666,7 @@ def atan2(x1, x2):
     y : PrimExpr
         The result.
     """
-    return call_pure_intrin(x1.dtype, "tir.atan2", x1, x2)
+    return call_intrin(x1.dtype, "tir.atan2", x1, x2)
 
 
 def sqrt(x):
@@ -683,7 +682,7 @@ def sqrt(x):
     y : PrimExpr
         The result.
     """
-    return call_pure_intrin(x.dtype, "tir.sqrt", x)
+    return call_intrin(x.dtype, "tir.sqrt", x)
 
 
 def rsqrt(x):
@@ -699,7 +698,7 @@ def rsqrt(x):
     y : PrimExpr
         The result.
     """
-    return call_pure_intrin(x.dtype, "tir.rsqrt", x)
+    return call_intrin(x.dtype, "tir.rsqrt", x)
 
 
 def floor(x):
@@ -824,7 +823,7 @@ def nextafter(x1, x2):
     y : PrimExpr
         The result.
     """
-    return call_pure_intrin(x1.dtype, "tir.nextafter", x1, x2)
+    return call_intrin(x1.dtype, "tir.nextafter", x1, x2)
 
 
 def hypot(x1, x2):
@@ -843,7 +842,7 @@ def hypot(x1, x2):
     y : PrimExpr
         The result.
     """
-    return call_pure_intrin(x1.dtype, "tir.hypot", x1, x2)
+    return call_intrin(x1.dtype, "tir.hypot", x1, x2)
 
 
 def copysign(x1, x2):
@@ -862,7 +861,7 @@ def copysign(x1, x2):
     y : PrimExpr
         The result.
     """
-    return call_pure_intrin(x1.dtype, "tir.copysign", x1, x2)
+    return call_intrin(x1.dtype, "tir.copysign", x1, x2)
 
 
 def ldexp(x1, x2):
@@ -881,7 +880,7 @@ def ldexp(x1, x2):
     y : PrimExpr
         The result.
     """
-    return call_pure_intrin(x1.dtype, "tir.ldexp", x1, x2)
+    return call_intrin(x1.dtype, "tir.ldexp", x1, x2)
 
 
 def isnan(x):
@@ -964,7 +963,7 @@ def popcount(x):
     y : PrimExpr
         The result.
     """
-    return call_pure_intrin(x.dtype, "tir.popcount", x)
+    return call_intrin(x.dtype, "tir.popcount", x)
 
 def fmod(x, y):
     """Return the remainder of x divided by y with the same sign as x.
@@ -981,7 +980,7 @@ def fmod(x, y):
     z : PrimExpr
         The result.
     """
-    return call_pure_intrin(x.dtype, "tir.fmod", x, y)
+    return call_intrin(x.dtype, "tir.fmod", x, y)
 
 
 def if_then_else(cond, t, f):
index c367d0c..259fcd9 100644 (file)
@@ -146,7 +146,7 @@ PrimExpr IRMutatorWithAnalyzer::VisitExpr_(const CallNode* op) {
         false_value.same_as(op->args[2])) {
       return GetRef<PrimExpr>(op);
     } else {
-      return Call(op->dtype, op->op, {cond, true_value, false_value}, op->call_type);
+      return Call(op->dtype, op->op, {cond, true_value, false_value});
     }
   }
   return StmtExprMutator::VisitExpr_(op);
index de84251..81a4d61 100644 (file)
@@ -679,7 +679,7 @@ class PCallExpr : public Pattern<PCallExpr<Op, TArgs...>> {
 #define TVM_PATTERN_BINARY_INTRIN(FuncName, OpName, IntrinOpName)                         \
   struct OpName {                                                                         \
     static PrimExpr Eval(Array<PrimExpr> args) {                                          \
-      return tir::Call(args[0].dtype(), GetOp(), args, tir::CallNode::PureIntrinsic);     \
+      return tir::Call(args[0].dtype(), GetOp(), args);                                   \
     }                                                                                     \
     static const Op& GetOp() { return tir::builtin::IntrinOpName(); }                     \
   };                                                                                      \
@@ -695,25 +695,23 @@ TVM_PATTERN_BINARY_INTRIN(operator|, PBitwiseOrOp, bitwise_or);
 TVM_PATTERN_BINARY_INTRIN(operator^, PBitwiseXorOp, bitwise_xor);
 
 // unary intrinsics
-#define TVM_PATTERN_UNARY_INTRIN(FuncName, OpName, IntrinOpName)                      \
-  struct OpName {                                                                     \
-    static PrimExpr Eval(Array<PrimExpr> args) {                                      \
-      return tir::Call(args[0].dtype(), GetOp(), args, tir::CallNode::PureIntrinsic); \
-    }                                                                                 \
-    static const Op& GetOp() { return tir::builtin::IntrinOpName(); }                 \
-  };                                                                                  \
-  template <typename TA>                                                              \
-  inline PCallExpr<OpName, TA> FuncName(const Pattern<TA>& a) {                       \
-    return PCallExpr<OpName, TA>(a.derived());                                        \
+#define TVM_PATTERN_UNARY_INTRIN(FuncName, OpName, IntrinOpName)      \
+  struct OpName {                                                     \
+    static PrimExpr Eval(Array<PrimExpr> args) {                      \
+      return tir::Call(args[0].dtype(), GetOp(), args);               \
+    }                                                                 \
+    static const Op& GetOp() { return tir::builtin::IntrinOpName(); } \
+  };                                                                  \
+  template <typename TA>                                              \
+  inline PCallExpr<OpName, TA> FuncName(const Pattern<TA>& a) {       \
+    return PCallExpr<OpName, TA>(a.derived());                        \
   }
 
 TVM_PATTERN_UNARY_INTRIN(operator~, PBitwiseNotOp, bitwise_not);
 
 // if_then_else
 struct PIfThenElseOp {
-  static PrimExpr Eval(Array<PrimExpr> args) {
-    return tir::Call(args[1].dtype(), GetOp(), args, tir::CallNode::PureIntrinsic);
-  }
+  static PrimExpr Eval(Array<PrimExpr> args) { return tir::Call(args[1].dtype(), GetOp(), args); }
   static const Op& GetOp() { return tir::builtin::if_then_else(); }
 };
 
index 0d5d654..b65ae91 100644 (file)
@@ -238,7 +238,8 @@ void CodeGenHybrid::VisitExpr_(const CallNode* op, std::ostream& os) {  // NOLIN
     PrintExpr(op->args[0], os);
     os << " else ";
     PrintExpr(op->args[2], os);
-  } else if (op->op.same_as(builtin::call_extern())) {
+  } else if (op->op.same_as(builtin::call_pure_extern()) ||
+             op->op.same_as(builtin::call_extern())) {
     StringImm fname = Downcast<StringImm>(op->args[0]);
     os << fname << "(";
     for (size_t i = 1; i < op->args.size(); i++) {
index a11de01..7ab26fa 100644 (file)
@@ -326,23 +326,6 @@ Doc TIRTextPrinter::VisitExpr_(const LetNode* op) {
   return doc;
 }
 
-inline const char* CallType2String(CallNode::CallType t) {
-  switch (t) {
-    case CallNode::Extern:
-      return "extern";
-    case CallNode::ExternCPlusPlus:
-      return "extern_cpp";
-    case CallNode::PureExtern:
-      return "pure_extern";
-    case CallNode::Intrinsic:
-      return "intrin";
-    case CallNode::PureIntrinsic:
-      return "pure_intrin";
-  }
-  LOG(FATAL) << "Unknown CallType";
-  return "Unknown";
-}
-
 Doc TIRTextPrinter::VisitExpr_(const CallNode* op) {
   Doc doc;
   if (auto* ptr_op = op->op.as<OpNode>()) {
@@ -357,8 +340,7 @@ Doc TIRTextPrinter::VisitExpr_(const CallNode* op) {
   for (const auto& arg : op->args) {
     args.push_back(Print(arg));
   }
-  doc << PrintSep(args, Doc::Text(", ")) << ", dtype=" << PrintDType(op->dtype)
-      << ", type=" << Doc::StrLiteral(CallType2String(op->call_type)) << ")";
+  doc << PrintSep(args, Doc::Text(", ")) << ", dtype=" << PrintDType(op->dtype) << ")";
   return doc;
 }
 
index 37855fb..31fadf1 100644 (file)
@@ -29,53 +29,53 @@ namespace tvm {
 namespace codegen {
 namespace intrin {
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.exp").set_body(DispatchExtern<FloatSuffix>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.exp").set_body(DispatchPureExtern<FloatSuffix>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.erf").set_body(DispatchExtern<FloatSuffix>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.erf").set_body(DispatchPureExtern<FloatSuffix>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.log").set_body(DispatchExtern<FloatSuffix>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.log").set_body(DispatchPureExtern<FloatSuffix>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.log2").set_body(DispatchExtern<FloatSuffix>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.log2").set_body(DispatchPureExtern<FloatSuffix>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.log10").set_body(DispatchExtern<FloatSuffix>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.log10").set_body(DispatchPureExtern<FloatSuffix>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.log1p").set_body(DispatchExtern<FloatSuffix>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.log1p").set_body(DispatchPureExtern<FloatSuffix>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.tanh").set_body(DispatchExtern<FloatSuffix>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.tanh").set_body(DispatchPureExtern<FloatSuffix>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.tan").set_body(DispatchExtern<FloatSuffix>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.tan").set_body(DispatchPureExtern<FloatSuffix>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.atan").set_body(DispatchExtern<FloatSuffix>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.atan").set_body(DispatchPureExtern<FloatSuffix>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.atanh").set_body(DispatchExtern<FloatSuffix>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.atanh").set_body(DispatchPureExtern<FloatSuffix>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.atan2").set_body(DispatchExtern<FloatSuffix>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.atan2").set_body(DispatchPureExtern<FloatSuffix>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.cos").set_body(DispatchExtern<FloatSuffix>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.cos").set_body(DispatchPureExtern<FloatSuffix>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.acos").set_body(DispatchExtern<FloatSuffix>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.acos").set_body(DispatchPureExtern<FloatSuffix>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.cosh").set_body(DispatchExtern<FloatSuffix>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.cosh").set_body(DispatchPureExtern<FloatSuffix>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.acosh").set_body(DispatchExtern<FloatSuffix>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.acosh").set_body(DispatchPureExtern<FloatSuffix>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sin").set_body(DispatchExtern<FloatSuffix>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sin").set_body(DispatchPureExtern<FloatSuffix>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.asin").set_body(DispatchExtern<FloatSuffix>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.asin").set_body(DispatchPureExtern<FloatSuffix>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sinh").set_body(DispatchExtern<FloatSuffix>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sinh").set_body(DispatchPureExtern<FloatSuffix>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.asinh").set_body(DispatchExtern<FloatSuffix>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.asinh").set_body(DispatchPureExtern<FloatSuffix>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.hypot").set_body(DispatchExtern<FloatSuffix>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.hypot").set_body(DispatchPureExtern<FloatSuffix>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.nextafter").set_body(DispatchExtern<FloatSuffix>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.nextafter").set_body(DispatchPureExtern<FloatSuffix>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.copysign").set_body(DispatchExtern<FloatSuffix>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.copysign").set_body(DispatchPureExtern<FloatSuffix>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.ldexp").set_body(DispatchExtern<FloatSuffix>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.ldexp").set_body(DispatchPureExtern<FloatSuffix>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sqrt").set_body(DispatchExtern<FloatSuffix>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sqrt").set_body(DispatchPureExtern<FloatSuffix>);
 
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.rsqrt")
     .set_body([](const TVMArgs& args, TVMRetValue* rv) {
@@ -87,7 +87,7 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.rsqrt")
       *rv = one / sqrt(call->args[0]);
     });
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.pow").set_body(DispatchExtern<FloatSuffix>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.pow").set_body(DispatchPureExtern<FloatSuffix>);
 
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sigmoid")
     .set_body([](const TVMArgs& args, TVMRetValue* rv) {
index 36e5539..359c5b9 100644 (file)
@@ -55,7 +55,7 @@ struct Direct {
 
 // Call pure extern function.
 template <typename T>
-inline void DispatchExtern(const TVMArgs& args, TVMRetValue* rv) {
+inline void DispatchPureExtern(const TVMArgs& args, TVMRetValue* rv) {
   PrimExpr e = args[0];
   const CallNode* call = e.as<CallNode>();
   CHECK(call != nullptr);
@@ -72,7 +72,7 @@ inline void DispatchExtern(const TVMArgs& args, TVMRetValue* rv) {
     for (auto arg : call->args) {
       new_args.push_back(arg);
     }
-    *rv = Call(call->dtype, tir::builtin::call_extern(), new_args, CallNode::PureExtern);
+    *rv = Call(call->dtype, tir::builtin::call_pure_extern(), new_args);
   } else {
     *rv = e;
   }
index 13ce59d..5e5a94b 100644 (file)
@@ -46,7 +46,7 @@ class CodeGenARM final : public CodeGenCPU {
 };
 
 llvm::Value* CodeGenARM::CreateIntrinsic(const CallNode* op) {
-  if (op->op.same_as(builtin_call_llvm_intrin_)) {
+  if (op->op.same_as(builtin_call_llvm_intrin_) || op->op.same_as(builtin_call_llvm_pure_intrin_)) {
     llvm::Intrinsic::ID id = static_cast<llvm::Intrinsic::ID>(Downcast<IntImm>(op->args[0])->value);
     if (id == ::llvm::Intrinsic::ctpop) {
       PrimExpr e = ARMPopcount(op);
@@ -70,7 +70,7 @@ PrimExpr CodeGenARM::ARMPopcount(const CallNode* call) {
     vcnt_args.push_back(IntImm(DataType::UInt(32), ctpop_id));
     vcnt_args.push_back(IntImm(DataType::UInt(32), 1));
     vcnt_args.push_back(e);
-    return tir::Call(call->dtype, builtin_call_llvm_intrin_, vcnt_args, CallNode::PureIntrinsic);
+    return tir::Call(call->dtype, builtin_call_llvm_pure_intrin_, vcnt_args);
   }
 
   // Popcount lowering rule:
@@ -94,16 +94,14 @@ PrimExpr CodeGenARM::ARMPopcount(const CallNode* call) {
   vcnt8_args.push_back(IntImm(DataType::UInt(32), ctpop_id));
   vcnt8_args.push_back(IntImm(DataType::UInt(32), 1));
   vcnt8_args.push_back(input8);
-  PrimExpr vcnt8 =
-      tir::Call(uint8_type, builtin_call_llvm_intrin_, vcnt8_args, CallNode::PureIntrinsic);
+  PrimExpr vcnt8 = tir::Call(uint8_type, builtin_call_llvm_pure_intrin_, vcnt8_args);
 
   // Accumulation 8->16bit
   Array<PrimExpr> vcnt16_args;
   vcnt16_args.push_back(IntImm(DataType::UInt(32), vpaddlu_id));
   vcnt16_args.push_back(IntImm(DataType::UInt(32), 1));
   vcnt16_args.push_back(vcnt8);
-  PrimExpr vcnt16 =
-      tir::Call(uint16_type, builtin_call_llvm_intrin_, vcnt16_args, CallNode::PureIntrinsic);
+  PrimExpr vcnt16 = tir::Call(uint16_type, builtin_call_llvm_pure_intrin_, vcnt16_args);
   if (call->dtype.bits() == 16) {
     return vcnt16;
   }
@@ -113,8 +111,7 @@ PrimExpr CodeGenARM::ARMPopcount(const CallNode* call) {
   vcnt32_args.push_back(IntImm(DataType::UInt(32), vpaddlu_id));
   vcnt32_args.push_back(IntImm(DataType::UInt(32), 1));
   vcnt32_args.push_back(vcnt16);
-  PrimExpr vcnt32 =
-      tir::Call(uint32_type, builtin_call_llvm_intrin_, vcnt32_args, CallNode::PureIntrinsic);
+  PrimExpr vcnt32 = tir::Call(uint32_type, builtin_call_llvm_pure_intrin_, vcnt32_args);
   if (call->dtype.bits() == 32) {
     return vcnt32;
   }
@@ -124,7 +121,7 @@ PrimExpr CodeGenARM::ARMPopcount(const CallNode* call) {
   vcnt64_args.push_back(IntImm(DataType::UInt(32), vpaddlu_id));
   vcnt64_args.push_back(IntImm(DataType::UInt(32), 1));
   vcnt64_args.push_back(vcnt32);
-  return tir::Call(call->dtype, builtin_call_llvm_intrin_, vcnt64_args, CallNode::PureIntrinsic);
+  return tir::Call(call->dtype, builtin_call_llvm_pure_intrin_, vcnt64_args);
 }
 
 TVM_REGISTER_GLOBAL("tvm.codegen.llvm.target_arm")
index 49f14c3..99a23c6 100644 (file)
@@ -738,7 +738,7 @@ llvm::Function* CodeGenLLVM::GetIntrinsicDecl(llvm::Intrinsic::ID id, llvm::Type
 }
 
 llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) {
-  if (op->op.same_as(builtin_call_llvm_intrin_)) {
+  if (op->op.same_as(builtin_call_llvm_intrin_) || op->op.same_as(builtin_call_llvm_pure_intrin_)) {
     CHECK_GE(op->args.size(), 2U);
     llvm::Intrinsic::ID id = static_cast<llvm::Intrinsic::ID>(Downcast<IntImm>(op->args[0])->value);
     int64_t num_signature = Downcast<IntImm>(op->args[1])->value;
@@ -1077,7 +1077,7 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const LoadNode* op) {
 llvm::Value* CodeGenLLVM::VisitExpr_(const CallNode* op) {
   if (auto* ptr_op = op->op.as<OpNode>()) {
     auto call_op = GetRef<Op>(ptr_op);
-    if (op->op.same_as(builtin_call_extern_)) {
+    if (op->op.same_as(builtin_call_extern_) || op->op.same_as(builtin_call_pure_extern_)) {
       // call extern intrinsic
       CHECK_GE(op->args.size(), 1U);
       auto global_symbol = Downcast<StringImm>(op->args[0]);
index 2bfe047..9e7b56a 100644 (file)
@@ -326,7 +326,10 @@ class CodeGenLLVM : public ExprFunctor<llvm::Value*(const PrimExpr&)>,
   // global symbol table.
   OpAttrMap<TGlobalSymbol> op_attr_global_symbol_ = Op::GetAttrMap<TGlobalSymbol>("TGlobalSymbol");
   const Op& builtin_call_extern_ = builtin::call_extern();
+  const Op& builtin_call_pure_extern_ = builtin::call_pure_extern();
   const Op& builtin_call_llvm_intrin_ = builtin::call_llvm_intrin();
+  const Op& builtin_call_llvm_pure_intrin_ = builtin::call_llvm_pure_intrin();
+
   /*! \brief Helper struct for debug infos. */
   struct DebugInfo {
     std::unique_ptr<llvm::DIBuilder> di_builder_;
index 5d269fa..6f3d4f7 100644 (file)
@@ -90,7 +90,7 @@ llvm::Value* CodeGenX86_64::VisitExpr_(const CastNode* op) {
           DTypeToLLVMType(DataType::Float(32, from.lanes())),
           {
               MakeValue(tir::Call(DataType::Int(16, from.lanes()), tir::builtin::reinterpret(),
-                                  {op->value}, tir::CallNode::PureIntrinsic)),
+                                  {op->value})),
               MakeValue(tir::Broadcast(FloatImm(DataType::Float(32), 0), from.lanes())),
               /*mask=*/MakeValue(IntImm(DataType::Int(16), -1)),
               /*rounding-mode=*/MakeValue(IntImm(DataType::Int(32), 4)),
@@ -102,11 +102,10 @@ llvm::Value* CodeGenX86_64::VisitExpr_(const CastNode* op) {
     const auto has_f16c = TargetHasFeature(*target_machine_, "f16c");
 
     if (from.lanes() >= 8 && has_f16c) {
-      return CallVectorIntrin(
-          ::llvm::Intrinsic::x86_vcvtph2ps_256, 8,
-          DTypeToLLVMType(DataType::Float(32, from.lanes())),
-          {MakeValue(tir::Call(DataType::Int(16, from.lanes()), tir::builtin::reinterpret(),
-                               {op->value}, tir::CallNode::PureIntrinsic))});
+      return CallVectorIntrin(::llvm::Intrinsic::x86_vcvtph2ps_256, 8,
+                              DTypeToLLVMType(DataType::Float(32, from.lanes())),
+                              {MakeValue(tir::Call(DataType::Int(16, from.lanes()),
+                                                   tir::builtin::reinterpret(), {op->value}))});
     }
 #endif
   }
index cc9437d..1a6775e 100644 (file)
@@ -50,8 +50,7 @@ inline void DispatchLLVMPureIntrin(const TVMArgs& targs, TVMRetValue* rv) {
   for (PrimExpr arg : call->args) {
     cargs.push_back(arg);
   }
-  *rv =
-      tir::Call(call->dtype, tir::builtin::call_llvm_intrin(), cargs, tir::CallNode::PureIntrinsic);
+  *rv = tir::Call(call->dtype, tir::builtin::call_llvm_pure_intrin(), cargs);
 }
 
 template <unsigned id, int num_signature>
@@ -66,7 +65,7 @@ inline void DispatchLLVMIntrin(const TVMArgs& targs, TVMRetValue* rv) {
   for (PrimExpr arg : call->args) {
     cargs.push_back(arg);
   }
-  *rv = tir::Call(call->dtype, tir::builtin::call_llvm_intrin(), cargs, tir::CallNode::Intrinsic);
+  *rv = tir::Call(call->dtype, tir::builtin::call_llvm_intrin(), cargs);
 }
 
 }  // namespace codegen
index a0ffe11..0e33294 100644 (file)
@@ -32,7 +32,7 @@
 namespace tvm {
 namespace codegen {
 
-inline void DispatchExternLibDevice(const TVMArgs& args, TVMRetValue* rv) {
+inline void DispatchPureExternLibDevice(const TVMArgs& args, TVMRetValue* rv) {
   PrimExpr e = args[0];
   using namespace tir;
   const CallNode* call = e.as<CallNode>();
@@ -52,54 +52,54 @@ inline void DispatchExternLibDevice(const TVMArgs& args, TVMRetValue* rv) {
   for (auto arg : call->args) {
     new_args.push_back(arg);
   }
-  *rv = Call(call->dtype, builtin::call_extern(), new_args, CallNode::PureExtern);
+  *rv = Call(call->dtype, builtin::call_pure_extern(), new_args);
 }
 
 namespace llvm {
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.floor").set_body(DispatchExternLibDevice);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.floor").set_body(DispatchPureExternLibDevice);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.ceil").set_body(DispatchExternLibDevice);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.ceil").set_body(DispatchPureExternLibDevice);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.round").set_body(DispatchExternLibDevice);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.round").set_body(DispatchPureExternLibDevice);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.trunc").set_body(DispatchExternLibDevice);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.trunc").set_body(DispatchPureExternLibDevice);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.fabs").set_body(DispatchExternLibDevice);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.fabs").set_body(DispatchPureExternLibDevice);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.exp").set_body(DispatchExternLibDevice);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.exp").set_body(DispatchPureExternLibDevice);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.exp2").set_body(DispatchExternLibDevice);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.exp2").set_body(DispatchPureExternLibDevice);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.exp10").set_body(DispatchExternLibDevice);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.exp10").set_body(DispatchPureExternLibDevice);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.erf").set_body(DispatchExternLibDevice);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.erf").set_body(DispatchPureExternLibDevice);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.fma").set_body(DispatchExternLibDevice);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.fma").set_body(DispatchPureExternLibDevice);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.log").set_body(DispatchExternLibDevice);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.log").set_body(DispatchPureExternLibDevice);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.log2").set_body(DispatchExternLibDevice);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.log2").set_body(DispatchPureExternLibDevice);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.log10").set_body(DispatchExternLibDevice);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.log10").set_body(DispatchPureExternLibDevice);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.sqrt").set_body(DispatchExternLibDevice);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.sqrt").set_body(DispatchPureExternLibDevice);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.pow").set_body(DispatchExternLibDevice);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.pow").set_body(DispatchPureExternLibDevice);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.tanh").set_body(DispatchExternLibDevice);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.tanh").set_body(DispatchPureExternLibDevice);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.tan").set_body(DispatchExternLibDevice);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.tan").set_body(DispatchPureExternLibDevice);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.cos").set_body(DispatchExternLibDevice);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.cos").set_body(DispatchPureExternLibDevice);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.cosh").set_body(DispatchExternLibDevice);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.cosh").set_body(DispatchPureExternLibDevice);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.sin").set_body(DispatchExternLibDevice);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.sin").set_body(DispatchPureExternLibDevice);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.sinh").set_body(DispatchExternLibDevice);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.sinh").set_body(DispatchPureExternLibDevice);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.atan").set_body(DispatchExternLibDevice);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.atan").set_body(DispatchPureExternLibDevice);
 
 }  // namespace llvm
 }  // namespace codegen
index 07520ae..22ebf9b 100644 (file)
@@ -32,7 +32,7 @@
 namespace tvm {
 namespace codegen {
 
-inline void DispatchExternOCML(const TVMArgs& args, TVMRetValue* rv) {
+inline void DispatchPureExternOCML(const TVMArgs& args, TVMRetValue* rv) {
   PrimExpr e = args[0];
   using namespace tir;
   const CallNode* call = e.as<CallNode>();
@@ -51,7 +51,7 @@ inline void DispatchExternOCML(const TVMArgs& args, TVMRetValue* rv) {
     new_args.push_back(arg);
   }
 
-  *rv = Call(call->dtype, builtin::call_extern(), new_args, CallNode::PureExtern);
+  *rv = Call(call->dtype, builtin::call_pure_extern(), new_args);
 }
 
 inline void DispatchShuffle(const TVMArgs& targs, TVMRetValue* rv) {
@@ -66,10 +66,10 @@ inline void DispatchShuffle(const TVMArgs& targs, TVMRetValue* rv) {
   // get own lane in self (__lane_id)
   PrimExpr minus_one = tir::make_const(DataType::Int(32), -1);
   PrimExpr zero = tir::make_zero(DataType::Int(32));
-  PrimExpr lo = Call(DataType::Int(32), builtin::call_extern(),
-                     {StringImm("llvm.amdgcn.mbcnt.lo"), minus_one, zero}, CallNode::PureExtern);
-  PrimExpr self = Call(DataType::Int(32), builtin::call_extern(),
-                       {StringImm("llvm.amdgcn.mbcnt.hi"), minus_one, lo}, CallNode::PureExtern);
+  PrimExpr lo = Call(DataType::Int(32), builtin::call_pure_extern(),
+                     {StringImm("llvm.amdgcn.mbcnt.lo"), minus_one, zero});
+  PrimExpr self = Call(DataType::Int(32), builtin::call_pure_extern(),
+                       {StringImm("llvm.amdgcn.mbcnt.hi"), minus_one, lo});
 
   // compute lane to get from
   PrimExpr width = call->args[3];
@@ -87,9 +87,8 @@ inline void DispatchShuffle(const TVMArgs& targs, TVMRetValue* rv) {
     index = self + delta;
     index = Select((self & (width - 1)) + delta >= width, self, index);
   }
-  PrimExpr res =
-      Call(var.dtype(), builtin::call_extern(),
-           {StringImm("llvm.amdgcn.ds.bpermute"), index << 2, var}, CallNode::PureExtern);
+  PrimExpr res = Call(var.dtype(), builtin::call_pure_extern(),
+                      {StringImm("llvm.amdgcn.ds.bpermute"), index << 2, var});
   *rv = res;
 }
 
@@ -108,49 +107,49 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.tvm_warp_shuffle_up").set_body(Dispatc
 
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.tvm_warp_shuffle_down").set_body(DispatchShuffle);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.floor").set_body(DispatchExternOCML);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.floor").set_body(DispatchPureExternOCML);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.ceil").set_body(DispatchExternOCML);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.ceil").set_body(DispatchPureExternOCML);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.round").set_body(DispatchExternOCML);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.round").set_body(DispatchPureExternOCML);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.trunc").set_body(DispatchExternOCML);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.trunc").set_body(DispatchPureExternOCML);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.fabs").set_body(DispatchExternOCML);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.fabs").set_body(DispatchPureExternOCML);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.exp").set_body(DispatchExternOCML);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.exp").set_body(DispatchPureExternOCML);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.exp2").set_body(DispatchExternOCML);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.exp2").set_body(DispatchPureExternOCML);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.exp10").set_body(DispatchExternOCML);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.exp10").set_body(DispatchPureExternOCML);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.erf").set_body(DispatchExternOCML);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.erf").set_body(DispatchPureExternOCML);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.fma").set_body(DispatchExternOCML);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.fma").set_body(DispatchPureExternOCML);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.log").set_body(DispatchExternOCML);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.log").set_body(DispatchPureExternOCML);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.log2").set_body(DispatchExternOCML);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.log2").set_body(DispatchPureExternOCML);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.log10").set_body(DispatchExternOCML);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.log10").set_body(DispatchPureExternOCML);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.sqrt").set_body(DispatchExternOCML);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.sqrt").set_body(DispatchPureExternOCML);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.pow").set_body(DispatchExternOCML);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.pow").set_body(DispatchPureExternOCML);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.tanh").set_body(DispatchExternOCML);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.tanh").set_body(DispatchPureExternOCML);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.tan").set_body(DispatchExternOCML);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.tan").set_body(DispatchPureExternOCML);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.cos").set_body(DispatchExternOCML);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.cos").set_body(DispatchPureExternOCML);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.cosh").set_body(DispatchExternOCML);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.cosh").set_body(DispatchPureExternOCML);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.sin").set_body(DispatchExternOCML);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.sin").set_body(DispatchPureExternOCML);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.sinh").set_body(DispatchExternOCML);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.sinh").set_body(DispatchPureExternOCML);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.atan").set_body(DispatchExternOCML);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.atan").set_body(DispatchPureExternOCML);
 
 }  // namespace llvm
 }  // namespace codegen
index ffeaba0..05582fb 100644 (file)
@@ -575,7 +575,7 @@ void CodeGenC::VisitExpr_(const CallNode* op, std::ostream& os) {  // NOLINT(*)
   if (auto* ptr_op = op->op.as<OpNode>()) {
     auto call_op = GetRef<Op>(ptr_op);
 
-    if (op->op.same_as(builtin_call_extern_)) {
+    if (op->op.same_as(builtin_call_extern_) || op->op.same_as(builtin_call_pure_extern_)) {
       CHECK_GE(op->args.size(), 1U);
       auto func = Downcast<StringImm>(op->args[0]);
       this->PrintCallExtern(GetType(GetRef<PrimExpr>(op)), func->value, op->args, true, os);
index 9346f87..87a4a29 100644 (file)
@@ -262,6 +262,7 @@ class CodeGenC : public ExprFunctor<void(const PrimExpr&, std::ostream&)>,
   OpAttrMap<TGlobalSymbol> op_attr_global_symbol_ = Op::GetAttrMap<TGlobalSymbol>("TGlobalSymbol");
   // cache commonly used ops
   const Op& builtin_call_extern_ = builtin::call_extern();
+  const Op& builtin_call_pure_extern_ = builtin::call_pure_extern();
 
  private:
   /*! \brief whether to print in SSA form */
index 0cafd02..69279a0 100644 (file)
@@ -27,49 +27,49 @@ namespace tvm {
 namespace codegen {
 namespace intrin {
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.floor").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.floor").set_body(DispatchPureExtern<Direct>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.ceil").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.ceil").set_body(DispatchPureExtern<Direct>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.trunc").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.trunc").set_body(DispatchPureExtern<Direct>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.fabs").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.fabs").set_body(DispatchPureExtern<Direct>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.round").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.round").set_body(DispatchPureExtern<Direct>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.exp").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.exp").set_body(DispatchPureExtern<Direct>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.log").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.log").set_body(DispatchPureExtern<Direct>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.tanh").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.tanh").set_body(DispatchPureExtern<Direct>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.sqrt").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.sqrt").set_body(DispatchPureExtern<Direct>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.pow").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.pow").set_body(DispatchPureExtern<Direct>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.popcount").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.popcount").set_body(DispatchPureExtern<Direct>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.floor").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.floor").set_body(DispatchPureExtern<Direct>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.ceil").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.ceil").set_body(DispatchPureExtern<Direct>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.trunc").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.trunc").set_body(DispatchPureExtern<Direct>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.fabs").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.fabs").set_body(DispatchPureExtern<Direct>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.round").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.round").set_body(DispatchPureExtern<Direct>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.exp").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.exp").set_body(DispatchPureExtern<Direct>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.log").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.log").set_body(DispatchPureExtern<Direct>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.tanh").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.tanh").set_body(DispatchPureExtern<Direct>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.sqrt").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.sqrt").set_body(DispatchPureExtern<Direct>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.pow").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.pow").set_body(DispatchPureExtern<Direct>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.popcount").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.popcount").set_body(DispatchPureExtern<Direct>);
 
 }  // namespace intrin
 }  // namespace codegen
index 53a2799..9ffceb6 100644 (file)
@@ -110,7 +110,7 @@ struct CUDAWarpIntrinsic {
 
 static void DispatchCUDAWarpActiveMask(const TVMArgs& args, TVMRetValue* rv) {
   Call call = args[0];
-  *rv = Call(call->dtype, Op::Get("tir.cuda.__activemask"), call->args, CallNode::PureExtern);
+  *rv = Call(call->dtype, Op::Get("tir.cuda.__activemask"), call->args);
 }
 
 template <typename T>
@@ -121,53 +121,52 @@ static void DispatchCUDAShuffle(const TVMArgs& args, TVMRetValue* rv) {
   CHECK_EQ(call->args.size(), 5);  // mask, value, warp_id, width, warp_size
   Array<PrimExpr> cuda_args{{call->args[0], call->args[1], call->args[2], call->args[3]}};
 
-  *rv =
-      Call(call->dtype, T()(call->dtype, Downcast<Op>(call->op)), cuda_args, CallNode::PureExtern);
+  *rv = Call(call->dtype, T()(call->dtype, Downcast<Op>(call->op)), cuda_args);
 }
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.floor").set_body(DispatchExtern<CUDAMath>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.floor").set_body(DispatchPureExtern<CUDAMath>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.ceil").set_body(DispatchExtern<CUDAMath>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.ceil").set_body(DispatchPureExtern<CUDAMath>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.trunc").set_body(DispatchExtern<CUDAMath>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.trunc").set_body(DispatchPureExtern<CUDAMath>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.fabs").set_body(DispatchExtern<CUDAMath>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.fabs").set_body(DispatchPureExtern<CUDAMath>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.round").set_body(DispatchExtern<CUDAMath>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.round").set_body(DispatchPureExtern<CUDAMath>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.exp").set_body(DispatchExtern<CUDAFastMath>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.exp").set_body(DispatchPureExtern<CUDAFastMath>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.exp2").set_body(DispatchExtern<CUDAMath>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.exp2").set_body(DispatchPureExtern<CUDAMath>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.exp10").set_body(DispatchExtern<CUDAFastMath>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.exp10").set_body(DispatchPureExtern<CUDAFastMath>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.erf").set_body(DispatchExtern<CUDAMath>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.erf").set_body(DispatchPureExtern<CUDAMath>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.log").set_body(DispatchExtern<CUDAFastMath>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.log").set_body(DispatchPureExtern<CUDAFastMath>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.log2").set_body(DispatchExtern<CUDAFastMath>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.log2").set_body(DispatchPureExtern<CUDAFastMath>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.log10").set_body(DispatchExtern<CUDAFastMath>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.log10").set_body(DispatchPureExtern<CUDAFastMath>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tan").set_body(DispatchExtern<CUDAFastMathTan>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tan").set_body(DispatchPureExtern<CUDAFastMathTan>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.cos").set_body(DispatchExtern<CUDAFastMath>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.cos").set_body(DispatchPureExtern<CUDAFastMath>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.cosh").set_body(DispatchExtern<CUDAMath>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.cosh").set_body(DispatchPureExtern<CUDAMath>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.sin").set_body(DispatchExtern<CUDAFastMath>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.sin").set_body(DispatchPureExtern<CUDAFastMath>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.sinh").set_body(DispatchExtern<CUDAMath>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.sinh").set_body(DispatchPureExtern<CUDAMath>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.atan").set_body(DispatchExtern<CUDAMath>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.atan").set_body(DispatchPureExtern<CUDAMath>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tanh").set_body(DispatchExtern<CUDAMath>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tanh").set_body(DispatchPureExtern<CUDAMath>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.sqrt").set_body(DispatchExtern<CUDAMath>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.sqrt").set_body(DispatchPureExtern<CUDAMath>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.pow").set_body(DispatchExtern<CUDAMath>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.pow").set_body(DispatchPureExtern<CUDAMath>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.popcount").set_body(DispatchExtern<CUDAPopcount>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.popcount").set_body(DispatchPureExtern<CUDAPopcount>);
 
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tvm_warp_shuffle")
     .set_body(DispatchCUDAShuffle<CUDAWarpIntrinsic>);
@@ -181,28 +180,32 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tvm_warp_shuffle_down")
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tvm_warp_activemask")
     .set_body(DispatchCUDAWarpActiveMask);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.fmod").set_body(DispatchExtern<CUDAMath>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.fmod").set_body(DispatchPureExtern<CUDAMath>);
 
 // Register low-level builtin ops.
 // TODO(tvm-team): consider make CUDA its own subfolder and create a file for low-level builtins.
 TVM_REGISTER_OP("tir.cuda.__shfl_sync")
     .set_num_inputs(4)
     .set_attr<TGlobalSymbol>("TGlobalSymbol", "__shfl_sync")
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque))
     .set_attr<bool>("cuda.need_warp_shuffle", true);
 
 TVM_REGISTER_OP("tir.cuda.__shfl_up_sync")
     .set_num_inputs(4)
     .set_attr<TGlobalSymbol>("TGlobalSymbol", "__shfl_up_sync")
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque))
     .set_attr<bool>("cuda.need_warp_shuffle", true);
 
 TVM_REGISTER_OP("tir.cuda.__shfl_down_sync")
     .set_num_inputs(4)
     .set_attr<TGlobalSymbol>("TGlobalSymbol", "__shfl_down_sync")
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque))
     .set_attr<bool>("cuda.need_warp_shuffle", true);
 
 TVM_REGISTER_OP("tir.cuda.__activemask")
     .set_num_inputs(0)
     .set_attr<TGlobalSymbol>("TGlobalSymbol", "__activemask")
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure))
     .set_attr<bool>("cuda.need_warp_shuffle", true);
 
 }  // namespace intrin
index 00fb9f9..80a1031 100644 (file)
@@ -27,45 +27,45 @@ namespace tvm {
 namespace codegen {
 namespace intrin {
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.floor").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.floor").set_body(DispatchPureExtern<Direct>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.ceil").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.ceil").set_body(DispatchPureExtern<Direct>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.trunc").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.trunc").set_body(DispatchPureExtern<Direct>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.fabs").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.fabs").set_body(DispatchPureExtern<Direct>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.round").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.round").set_body(DispatchPureExtern<Direct>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.exp").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.exp").set_body(DispatchPureExtern<Direct>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.exp2").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.exp2").set_body(DispatchPureExtern<Direct>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.exp10").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.exp10").set_body(DispatchPureExtern<Direct>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.log").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.log").set_body(DispatchPureExtern<Direct>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.log2").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.log2").set_body(DispatchPureExtern<Direct>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.log10").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.log10").set_body(DispatchPureExtern<Direct>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.tanh").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.tanh").set_body(DispatchPureExtern<Direct>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.sqrt").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.sqrt").set_body(DispatchPureExtern<Direct>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.pow").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.pow").set_body(DispatchPureExtern<Direct>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.popcount").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.popcount").set_body(DispatchPureExtern<Direct>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.fmod").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.fmod").set_body(DispatchPureExtern<Direct>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.sin").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.sin").set_body(DispatchPureExtern<Direct>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.sinh").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.sinh").set_body(DispatchPureExtern<Direct>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.cos").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.cos").set_body(DispatchPureExtern<Direct>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.cosh").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.cosh").set_body(DispatchPureExtern<Direct>);
 
 }  // namespace intrin
 }  // namespace codegen
index 82eabdd..7f81e33 100644 (file)
@@ -29,45 +29,45 @@ namespace tvm {
 namespace codegen {
 namespace intrin {
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.floor").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.floor").set_body(DispatchPureExtern<Direct>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.ceil").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.ceil").set_body(DispatchPureExtern<Direct>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.trunc").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.trunc").set_body(DispatchPureExtern<Direct>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.fabs").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.fabs").set_body(DispatchPureExtern<Direct>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.round").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.round").set_body(DispatchPureExtern<Direct>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.exp").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.exp").set_body(DispatchPureExtern<Direct>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.exp2").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.exp2").set_body(DispatchPureExtern<Direct>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.exp10").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.exp10").set_body(DispatchPureExtern<Direct>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.log").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.log").set_body(DispatchPureExtern<Direct>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.log2").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.log2").set_body(DispatchPureExtern<Direct>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.log10").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.log10").set_body(DispatchPureExtern<Direct>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.tanh").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.tanh").set_body(DispatchPureExtern<Direct>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.sqrt").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.sqrt").set_body(DispatchPureExtern<Direct>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.pow").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.pow").set_body(DispatchPureExtern<Direct>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.popcount").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.popcount").set_body(DispatchPureExtern<Direct>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.fmod").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.fmod").set_body(DispatchPureExtern<Direct>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.sin").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.sin").set_body(DispatchPureExtern<Direct>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.sinh").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.sinh").set_body(DispatchPureExtern<Direct>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.cos").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.cos").set_body(DispatchPureExtern<Direct>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.cosh").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.cosh").set_body(DispatchPureExtern<Direct>);
 
 // There is no warp shuffle instruction in standard OpenCL
 // When shuffle is used, we assume it is intel's shuffle extension
@@ -80,7 +80,7 @@ static void DispatchIntelShuffle(const TVMArgs& args, TVMRetValue* rv) {
   CHECK(analyzer.CanProve(call->args[3] == call->args[4]))
       << "Intel warp shuffle dose not support width != warp_size";
   Array<PrimExpr> opencl_args{{StringImm("intel_sub_group_shuffle"), call->args[1], call->args[2]}};
-  *rv = Call(call->dtype, builtin::call_extern(), opencl_args, CallNode::PureExtern);
+  *rv = Call(call->dtype, builtin::call_pure_extern(), opencl_args);
 }
 
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.tvm_warp_shuffle").set_body(DispatchIntelShuffle);
index fb01d65..da9bc79 100644 (file)
@@ -27,43 +27,43 @@ namespace tvm {
 namespace codegen {
 namespace intrin {
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.floor").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.floor").set_body(DispatchPureExtern<Direct>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.ceil").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.ceil").set_body(DispatchPureExtern<Direct>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.trunc").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.trunc").set_body(DispatchPureExtern<Direct>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.fabs").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.fabs").set_body(DispatchPureExtern<Direct>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.round").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.round").set_body(DispatchPureExtern<Direct>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.exp").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.exp").set_body(DispatchPureExtern<Direct>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.exp2").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.exp2").set_body(DispatchPureExtern<Direct>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.exp10").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.exp10").set_body(DispatchPureExtern<Direct>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.log").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.log").set_body(DispatchPureExtern<Direct>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.log2").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.log2").set_body(DispatchPureExtern<Direct>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.log10").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.log10").set_body(DispatchPureExtern<Direct>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.tanh").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.tanh").set_body(DispatchPureExtern<Direct>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.sqrt").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.sqrt").set_body(DispatchPureExtern<Direct>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.pow").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.pow").set_body(DispatchPureExtern<Direct>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.popcount").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.popcount").set_body(DispatchPureExtern<Direct>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.sin").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.sin").set_body(DispatchPureExtern<Direct>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.sinh").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.sinh").set_body(DispatchPureExtern<Direct>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.cos").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.cos").set_body(DispatchPureExtern<Direct>);
 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.cosh").set_body(DispatchExtern<Direct>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.cosh").set_body(DispatchPureExtern<Direct>);
 
 }  // namespace intrin
 }  // namespace codegen
index 6c12343..ff3bc7d 100644 (file)
@@ -237,7 +237,7 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const LetNode* op) {
 }
 
 spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) {
-  if (op->op.same_as(builtin::call_spirv_glsl450())) {
+  if (op->op.same_as(builtin::call_spirv_pure_glsl450())) {
     CHECK_GE(op->args.size(), 2U);
     uint32_t inst_id = static_cast<uint32_t>(op->args[0].as<IntImmNode>()->value);
     std::vector<spirv::Value> values;
@@ -317,13 +317,7 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) {
     return builder_->MakeValue(spv::OpBitCount, builder_->GetSType(op->dtype),
                                MakeValue(op->args[0]));
   } else {
-    if (op->call_type == CallNode::Intrinsic || op->call_type == CallNode::PureIntrinsic) {
-      LOG(FATAL) << "Unresolved intrinsic " << op->op << " with return type " << op->dtype;
-    } else if (op->call_type == CallNode::Extern || op->call_type == CallNode::PureExtern) {
-      LOG(FATAL) << "Unresolved extern " << op->op << " with return type " << op->dtype;
-    } else {
-      LOG(FATAL) << "Unresolved call type " << op->call_type;
-    }
+    LOG(FATAL) << "Unresolved call  " << op->op;
     return spirv::Value();
   }
 }
index 1b9d2e4..ea575ca 100644 (file)
@@ -44,8 +44,7 @@ inline void DispatchGLSLPureIntrin(const TVMArgs& targs, TVMRetValue* rv) {
   for (PrimExpr arg : call->args) {
     cargs.push_back(arg);
   }
-  *rv = tir::Call(call->dtype, tir::builtin::call_spirv_glsl450(), cargs,
-                  tir::CallNode::PureIntrinsic);
+  *rv = tir::Call(call->dtype, tir::builtin::call_spirv_pure_glsl450(), cargs);
 }
 
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.floor")
index f625412..e2479d8 100644 (file)
@@ -95,34 +95,32 @@ class JacobianMutator : public ExprMutator {
 
   PrimExpr VisitExpr_(const CallNode* op) {
     PrimExpr expr = GetRef<PrimExpr>(op);
-    if (op->call_type == CallNode::CallType::PureIntrinsic) {
-      if (op->op.same_as(op_exp_)) {
-        return Mul(Mutate(op->args[0]), expr);
-      } else if (op->op.same_as(op_log_)) {
-        return Div(Mutate(op->args[0]), op->args[0]);
-      } else if (op->op.same_as(op_sigmoid_)) {
-        return Mul(Mutate(op->args[0]), Mul(expr, Sub(FloatImm(expr.dtype(), 1.0), expr)));
-      } else if (op->op.same_as(op_sqrt_)) {
-        return Div(Mutate(op->args[0]), Mul(expr, FloatImm(expr.dtype(), 2.0)));
-      } else if (op->op.same_as(op_tanh_)) {
-        return Mul(Mutate(op->args[0]), Sub(FloatImm(expr.dtype(), 1.0), Mul(expr, expr)));
-      } else if (op->op.same_as(op_pow_)) {
-        auto x = op->args[0], y = op->args[1];
-        return expr * (Mutate(y) * log(x) + Mutate(x) * y / x);
-      } else if (op->op.same_as(op_fabs_)) {
-        auto type = op->args[0].dtype();
-        return Mul(Mutate(op->args[0]), Select(GE(op->args[0], make_zero(type)),
-                                               FloatImm(type, 1.0), FloatImm(type, -1.0)));
-      } else if (op->op.same_as(op_if_then_else_)) {
-        Array<PrimExpr> new_args = {op->args[0], Mutate(op->args[1]), Mutate(op->args[2])};
-        return Call(op->dtype, op->op, new_args, op->call_type);
-      } else if (piecewise_const.count(op->op)) {
-        return FloatImm(expr.dtype(), 0.0);
-      } else {
-        LOG(FATAL) << "Derivative of this intrinsic is not implemented: " << op->op;
-      }
+    if (op->op.same_as(op_exp_)) {
+      return Mul(Mutate(op->args[0]), expr);
+    } else if (op->op.same_as(op_log_)) {
+      return Div(Mutate(op->args[0]), op->args[0]);
+    } else if (op->op.same_as(op_sigmoid_)) {
+      return Mul(Mutate(op->args[0]), Mul(expr, Sub(FloatImm(expr.dtype(), 1.0), expr)));
+    } else if (op->op.same_as(op_sqrt_)) {
+      return Div(Mutate(op->args[0]), Mul(expr, FloatImm(expr.dtype(), 2.0)));
+    } else if (op->op.same_as(op_tanh_)) {
+      return Mul(Mutate(op->args[0]), Sub(FloatImm(expr.dtype(), 1.0), Mul(expr, expr)));
+    } else if (op->op.same_as(op_pow_)) {
+      auto x = op->args[0], y = op->args[1];
+      return expr * (Mutate(y) * log(x) + Mutate(x) * y / x);
+    } else if (op->op.same_as(op_fabs_)) {
+      auto type = op->args[0].dtype();
+      return Mul(Mutate(op->args[0]), Select(GE(op->args[0], make_zero(type)), FloatImm(type, 1.0),
+                                             FloatImm(type, -1.0)));
+    } else if (op->op.same_as(op_if_then_else_)) {
+      Array<PrimExpr> new_args = {op->args[0], Mutate(op->args[1]), Mutate(op->args[2])};
+      return Call(op->dtype, op->op, new_args);
+    } else if (piecewise_const.count(op->op)) {
+      return FloatImm(expr.dtype(), 0.0);
+    } else {
+      LOG(FATAL) << "Derivative of this intrinsic is not implemented: " << op->op;
+      return PrimExpr();
     }
-    NOT_IMPLEMENTED;
   }
 
   PrimExpr VisitExpr_(const AddNode* op) { return Add(Mutate(op->a), Mutate(op->b)); }
index b4725c5..21343ec 100644 (file)
@@ -277,10 +277,9 @@ Stmt BaseComputeOpNode::BuildRealize(const Stage& stage,
         if (attr->dim_align_factor != 0) {
           Array<PrimExpr> tuple = {static_cast<int>(i), attr->dim_align_factor,
                                    attr->dim_align_offset};
-          realize = tir::AttrStmt(
-              t, tir::attr::buffer_dim_align,
-              Call(DataType::Handle(), tir::builtin::tvm_tuple(), tuple, CallNode::Intrinsic),
-              realize);
+          realize =
+              tir::AttrStmt(t, tir::attr::buffer_dim_align,
+                            Call(DataType::Handle(), tir::builtin::tvm_tuple(), tuple), realize);
         }
       }
     }
index eeaab30..427be32 100644 (file)
@@ -196,8 +196,8 @@ Stmt MakeCrossThreadReduction(const ComputeOpNode* self, const Stage& stage,
   // Apply the existing input predicate if any.
   output_preds.push_back(input_pred);
 
-  Stmt reduce_body = Evaluate(Call(DataType::Handle(), tir::builtin::tvm_thread_allreduce(),
-                                   freduce_args, CallNode::Intrinsic));
+  Stmt reduce_body =
+      Evaluate(Call(DataType::Handle(), tir::builtin::tvm_thread_allreduce(), freduce_args));
   reduce_body = AttrStmt(reduces[0]->combiner, tir::attr::reduce_scope,
                          make_zero(DataType::Handle()), reduce_body);
 
index 01019e4..d789938 100644 (file)
@@ -153,7 +153,7 @@ Stmt ExternOpNode::BuildProvide(const Stage& stage,
       tuple.push_back(buffer->shape[k]);
     }
     ret = AttrStmt(bind_spec, tir::attr::buffer_bind_scope,
-                   Call(DataType::Handle(), builtin::tvm_tuple(), tuple, CallNode::Intrinsic), ret);
+                   Call(DataType::Handle(), builtin::tvm_tuple(), tuple), ret);
   };
   for (size_t i = output_placeholders.size(); i != 0; --i) {
     f_push_bind(output_placeholders[i - 1], stage->op.output(i - 1));
index 714e885..f6f0058 100644 (file)
@@ -152,9 +152,9 @@ Stmt TensorComputeOpNode::BuildProvide(const Stage& stage,
       tuple.push_back(region[i]->min);
       tuple.push_back(region[i]->extent);
     }
-    input_bind_nest.emplace_back(AttrStmt(
-        bind_spec, tir::attr::buffer_bind_scope,
-        Call(DataType::Handle(), tir::builtin::tvm_tuple(), tuple, CallNode::Intrinsic), nop));
+    input_bind_nest.emplace_back(
+        AttrStmt(bind_spec, tir::attr::buffer_bind_scope,
+                 Call(DataType::Handle(), tir::builtin::tvm_tuple(), tuple), nop));
   }
 
   // output binding
@@ -176,9 +176,9 @@ Stmt TensorComputeOpNode::BuildProvide(const Stage& stage,
       }
     }
 
-    output_bind_nest.emplace_back(AttrStmt(
-        bind_spec, tir::attr::buffer_bind_scope,
-        Call(DataType::Handle(), tir::builtin::tvm_tuple(), tuple, CallNode::Intrinsic), nop));
+    output_bind_nest.emplace_back(
+        AttrStmt(bind_spec, tir::attr::buffer_bind_scope,
+                 Call(DataType::Handle(), tir::builtin::tvm_tuple(), tuple), nop));
   }
 
   // Check variable remap
index dd978a4..d48bf78 100644 (file)
@@ -368,9 +368,9 @@ Stmt MakeTensorize(const ComputeOpNode* self, const Stage& stage,
       tuple.push_back(r->min);
       tuple.push_back(r->extent);
     }
-    input_bind_nest.emplace_back(AttrStmt(
-        bind_spec, tir::attr::buffer_bind_scope,
-        Call(DataType::Handle(), tir::builtin::tvm_tuple(), tuple, CallNode::Intrinsic), nop));
+    input_bind_nest.emplace_back(
+        AttrStmt(bind_spec, tir::attr::buffer_bind_scope,
+                 Call(DataType::Handle(), tir::builtin::tvm_tuple(), tuple), nop));
   }
   // output binding
   const ComputeOpNode* intrin_compute = intrin->op.as<ComputeOpNode>();
@@ -388,9 +388,9 @@ Stmt MakeTensorize(const ComputeOpNode* self, const Stage& stage,
     Tensor tensor = stage->op.output(i - intrin->inputs.size());
     Buffer buffer = intrin->buffers[i];
     Array<ObjectRef> bind_spec{buffer, tensor};
-    output_bind_nest.emplace_back(AttrStmt(
-        bind_spec, tir::attr::buffer_bind_scope,
-        Call(DataType::Handle(), tir::builtin::tvm_tuple(), tuple, CallNode::Intrinsic), nop));
+    output_bind_nest.emplace_back(
+        AttrStmt(bind_spec, tir::attr::buffer_bind_scope,
+                 Call(DataType::Handle(), tir::builtin::tvm_tuple(), tuple), nop));
   }
   // Check variable remap
   std::unordered_map<const VarNode*, PrimExpr> vmap;
index 67121b8..be1bdd9 100644 (file)
@@ -850,14 +850,12 @@ class TensorCoreIRMutator : public StmtExprMutator {
           return Evaluate(
               Call(DataType::Handle(), builtin::tvm_bmma_sync(),
                    {buffer->data, buffer->elem_offset, buffer_a->data, buffer_a->elem_offset,
-                    buffer_b->data, buffer_b->elem_offset, buffer->data, buffer->elem_offset},
-                   CallNode::Intrinsic));
+                    buffer_b->data, buffer_b->elem_offset, buffer->data, buffer->elem_offset}));
         } else {
           return Evaluate(
               Call(DataType::Handle(), builtin::tvm_mma_sync(),
                    {buffer->data, buffer->elem_offset, buffer_a->data, buffer_a->elem_offset,
-                    buffer_b->data, buffer_b->elem_offset, buffer->data, buffer->elem_offset},
-                   CallNode::Intrinsic));
+                    buffer_b->data, buffer_b->elem_offset, buffer->data, buffer->elem_offset}));
         }
       };
 
@@ -881,8 +879,7 @@ class TensorCoreIRMutator : public StmtExprMutator {
         auto fill_fragment_call = [this, &op](const Buffer& buffer) {
           return Evaluate(Call(DataType::Handle(), builtin::tvm_fill_fragment(),
                                {buffer->data, warp_tile_.m, warp_tile_.n, warp_tile_.k,
-                                buffer->elem_offset, op->value},
-                               CallNode::Intrinsic));
+                                buffer->elem_offset, op->value}));
         };
 
         ObjectPtr<BufferNode> buffer_node = make_object<BufferNode>();
@@ -903,8 +900,7 @@ class TensorCoreIRMutator : public StmtExprMutator {
       ThreadIdxMutator thread_idx_mutator(warp_y);
       PrimExpr mutated_value = thread_idx_mutator(op->value);
       // TODO(tvm-team) The extern function name seems to be a hack.
-      PrimExpr src = Call(value->dtype, builtin::call_extern(), {StringImm("&"), mutated_value},
-                          CallNode::Extern);
+      PrimExpr src = Call(value->dtype, builtin::call_extern(), {StringImm("&"), mutated_value});
 
       auto pload = dst.as<ProducerLoadNode>();
       PrimExpr matrix_major;
@@ -922,8 +918,7 @@ class TensorCoreIRMutator : public StmtExprMutator {
       auto load_matrix_call = [this, &src, &stride, &matrix_major](const Buffer& buffer) {
         return Evaluate(Call(DataType::Handle(), builtin::tvm_load_matrix_sync(),
                              {buffer->data, warp_tile_.m, warp_tile_.n, warp_tile_.k,
-                              buffer->elem_offset, src, stride, matrix_major},
-                             CallNode::Intrinsic));
+                              buffer->elem_offset, src, stride, matrix_major}));
       };
 
       ObjectPtr<BufferNode> buffer_node = make_object<BufferNode>();
@@ -943,16 +938,14 @@ class TensorCoreIRMutator : public StmtExprMutator {
       PrimExpr warp_y = IntImm(DataType::Int(32), warp_threads_y_);
       ThreadIdxMutator thread_idx_mutator(warp_y);
       dst = thread_idx_mutator(dst);
-      dst =
-          Call(DataType::Handle(), builtin::call_extern(), {StringImm("&"), dst}, CallNode::Extern);
+      dst = Call(DataType::Handle(), builtin::call_extern(), {StringImm("&"), dst});
 
       auto pload = op->value.as<ProducerLoadNode>();
 
       auto store_matrix_call = [this, &dst, &stride](const Buffer& buffer) {
         return Evaluate(Call(DataType::Handle(), builtin::tvm_store_matrix_sync(),
                              {buffer->data, warp_tile_.m, warp_tile_.n, warp_tile_.k,
-                              buffer->elem_offset, dst, stride, StringImm("col_major")},
-                             CallNode::Intrinsic));
+                              buffer->elem_offset, dst, stride, StringImm("col_major")}));
       };
 
       ObjectPtr<BufferNode> buffer_node = make_object<BufferNode>();
@@ -1067,7 +1060,7 @@ class TensorCoreIRMutator : public StmtExprMutator {
       args.push_back(pload->indices[i]);
       args.push_back(shape[i]);
     }
-    auto tuple = Call(DataType::Handle(), builtin::tvm_tuple(), args, CallNode::Intrinsic);
+    auto tuple = Call(DataType::Handle(), builtin::tvm_tuple(), args);
     Array<ObjectRef> node = {buffer, tensor};
     return AttrStmt(node, "buffer_bind_scope", tuple, call_back(buffer));
   }
index b5fb328..923cda3 100644 (file)
  * \file side_effect.cc
  * \brief side effect analysis
  */
+#include <tvm/ir/op.h>
 #include <tvm/tir/analysis.h>
 #include <tvm/tir/expr.h>
 #include <tvm/tir/expr_functor.h>
+#include <tvm/tir/op_attr_types.h>
 
 namespace tvm {
 namespace tir {
@@ -36,11 +38,19 @@ class ExprSideEffect : public ExprVisitor {
   }
 
   void VisitExpr_(const CallNode* op) final {
-    if (!op->is_pure()) {
+    static auto op_call_effect = Op::GetAttrMap<TCallEffectKind>("TCallEffectKind");
+
+    if (auto* ptr_op = op->op.as<OpNode>()) {
+      auto effect_kind = op_call_effect[GetRef<Op>(ptr_op)];
+      if (effect_kind != CallEffectKind::kPure && effect_kind != CallEffectKind::kExprAnnotation) {
+        has_side_effect_ = true;
+        return;
+      } else {
+        ExprVisitor::VisitExpr_(op);
+      }
+    } else {
       has_side_effect_ = true;
       return;
-    } else {
-      ExprVisitor::VisitExpr_(op);
     }
   }
 
index 6cccfa0..e9f65ee 100644 (file)
@@ -377,7 +377,7 @@ PrimExpr Buffer::access_ptr(int access_mask, DataType ptr_type, int content_lane
   }
   Array<PrimExpr> acc_args{e_dtype, self->data, elem_offset, extent,
                            make_const(DataType::Int(32), access_mask)};
-  return tir::Call(ptr_type, tir::builtin::tvm_access_ptr(), acc_args, tir::CallNode::Intrinsic);
+  return tir::Call(ptr_type, tir::builtin::tvm_access_ptr(), acc_args);
 }
 
 Buffer::Buffer(Var data, DataType dtype, Array<PrimExpr> shape, Array<PrimExpr> strides,
index 4b20351..b4bb984 100644 (file)
@@ -698,7 +698,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
     });
 
 // Call
-Call::Call(DataType dtype, RelayExpr op, Array<PrimExpr> args, CallType call_type) {
+Call::Call(DataType dtype, RelayExpr op, Array<PrimExpr> args) {
   for (size_t i = 0; i < args.size(); ++i) {
     CHECK(args[i].defined());
   }
@@ -707,12 +707,11 @@ Call::Call(DataType dtype, RelayExpr op, Array<PrimExpr> args, CallType call_typ
   node->dtype = dtype;
   node->op = std::move(op);
   node->args = std::move(args);
-  node->call_type = call_type;
   data_ = std::move(node);
 }
 
 TVM_REGISTER_GLOBAL("tir.Call")
-    .set_body_typed([](DataType type, RelayExpr op, Array<ObjectRef> args, int call_type) {
+    .set_body_typed([](DataType type, RelayExpr op, Array<ObjectRef> args) {
       Array<PrimExpr> prim_expr_args;
       for (const auto& it : args) {
         CHECK(it->IsInstance<runtime::StringObj>() || it->IsInstance<PrimExprNode>());
@@ -722,7 +721,7 @@ TVM_REGISTER_GLOBAL("tir.Call")
           prim_expr_args.push_back(Downcast<PrimExpr>(it));
         }
       }
-      return Call(type, op, prim_expr_args, static_cast<CallNode::CallType>(call_type));
+      return Call(type, op, prim_expr_args);
     });
 
 TVM_REGISTER_NODE_TYPE(CallNode);
index 98b9fd0..afc128b 100644 (file)
@@ -166,7 +166,7 @@ PrimExpr ExprMutator::VisitExpr_(const CallNode* op) {
   if (args.same_as(op->args)) {
     return GetRef<PrimExpr>(op);
   } else {
-    return Call(op->dtype, op->op, args, op->call_type);
+    return Call(op->dtype, op->op, args);
   }
 }
 
index 7b4ac7e..296f492 100644 (file)
@@ -22,6 +22,7 @@
  */
 #include <tvm/runtime/registry.h>
 #include <tvm/tir/op.h>
+#include <tvm/tir/op_attr_types.h>
 #include <tvm/tir/stmt.h>
 
 namespace tvm {
@@ -566,10 +567,11 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
 
 PrimExpr TypeAnnotation(DataType dtype) {
   static auto op = Op::Get("tir.type_annotation");
-  return tir::Call(dtype, op, {}, tir::CallNode::PureIntrinsic);
+  return tir::Call(dtype, op, {});
 }
 
-TVM_REGISTER_OP("tir.type_annotation");
+TVM_REGISTER_OP("tir.type_annotation")
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure));
 
 }  // namespace tir
 }  // namespace tvm
index 8efcf3f..d23662c 100644 (file)
@@ -38,117 +38,191 @@ namespace builtin {
   }                                                \
   TVM_REGISTER_OP("tir." #OpName)
 
-TIR_DEFINE_BUILTIN_FUNC(reinterpret).set_num_inputs(1);
+TIR_DEFINE_BUILTIN_FUNC(reinterpret)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure))
+    .set_num_inputs(1);
 
-TIR_DEFINE_BUILTIN_FUNC(likely).set_num_inputs(1).set_attr<TVectorizable>("TVectorizable", true);
+TIR_DEFINE_BUILTIN_FUNC(likely)
+    .set_num_inputs(1)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kExprAnnotation))
+    .set_attr<TVectorizable>("TVectorizable", true);
 
 TIR_DEFINE_BUILTIN_FUNC(bitwise_and)
     .set_num_inputs(2)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure))
     .set_attr<TVectorizable>("TVectorizable", true);
 
 TIR_DEFINE_BUILTIN_FUNC(bitwise_or)
     .set_num_inputs(2)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure))
     .set_attr<TVectorizable>("TVectorizable", true);
 
 TIR_DEFINE_BUILTIN_FUNC(bitwise_xor)
     .set_num_inputs(2)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure))
     .set_attr<TVectorizable>("TVectorizable", true);
 
 TIR_DEFINE_BUILTIN_FUNC(bitwise_not)
     .set_num_inputs(1)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure))
     .set_attr<TVectorizable>("TVectorizable", true);
 
 TIR_DEFINE_BUILTIN_FUNC(shift_left)
     .set_num_inputs(2)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure))
     .set_attr<TVectorizable>("TVectorizable", true);
 
 TIR_DEFINE_BUILTIN_FUNC(shift_right)
     .set_num_inputs(2)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure))
     .set_attr<TVectorizable>("TVectorizable", true);
 
-TIR_DEFINE_BUILTIN_FUNC(large_uint_imm).set_num_inputs(2);
+TIR_DEFINE_BUILTIN_FUNC(large_uint_imm)
+    .set_num_inputs(2)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure));
+
+TIR_DEFINE_BUILTIN_FUNC(address_of)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure))
+    .set_num_inputs(1);
+
+TIR_DEFINE_BUILTIN_FUNC(if_then_else)
+    .set_num_inputs(3)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure));
 
-TIR_DEFINE_BUILTIN_FUNC(address_of).set_num_inputs(1);
+TIR_DEFINE_BUILTIN_FUNC(isnullptr).set_num_inputs(1).set_attr<TCallEffectKind>(
+    "TCallEffectKind", Integer(CallEffectKind::kPure));
 
-TIR_DEFINE_BUILTIN_FUNC(if_then_else).set_num_inputs(3);
+TIR_DEFINE_BUILTIN_FUNC(isnan).set_num_inputs(1).set_attr<TCallEffectKind>(
+    "TCallEffectKind", Integer(CallEffectKind::kPure));
 
-TIR_DEFINE_BUILTIN_FUNC(isnullptr).set_num_inputs(1);
+TIR_DEFINE_BUILTIN_FUNC(popcount)
+    .set_num_inputs(1)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure))
+    .set_attr<TVectorizable>("TVectorizable", true);
 
-TIR_DEFINE_BUILTIN_FUNC(isnan).set_num_inputs(1);
+TIR_DEFINE_BUILTIN_FUNC(fma)
+    .set_num_inputs(3)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure))
+    .set_attr<TVectorizable>("TVectorizable", true);
 
-TIR_DEFINE_BUILTIN_FUNC(popcount).set_num_inputs(1);
+TIR_DEFINE_BUILTIN_FUNC(call_extern)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
 
-TIR_DEFINE_BUILTIN_FUNC(fma).set_num_inputs(3).set_attr<TVectorizable>("TVectorizable", true);
+TIR_DEFINE_BUILTIN_FUNC(call_pure_extern)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure));
 
-TIR_DEFINE_BUILTIN_FUNC(call_extern);
+TIR_DEFINE_BUILTIN_FUNC(call_llvm_intrin)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
 
-TIR_DEFINE_BUILTIN_FUNC(call_llvm_intrin);
+TIR_DEFINE_BUILTIN_FUNC(call_llvm_pure_intrin)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure));
 
-TIR_DEFINE_BUILTIN_FUNC(call_spirv_glsl450);
+TIR_DEFINE_BUILTIN_FUNC(call_spirv_pure_glsl450)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure));
 
-TIR_DEFINE_BUILTIN_FUNC(prefetch);
+TIR_DEFINE_BUILTIN_FUNC(prefetch).set_attr<TCallEffectKind>("TCallEffectKind",
+                                                            Integer(CallEffectKind::kOpaque));
 
-TIR_DEFINE_BUILTIN_FUNC(tvm_access_ptr).set_num_inputs(5);
+TIR_DEFINE_BUILTIN_FUNC(tvm_access_ptr)
+    .set_num_inputs(5)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kSpecialCallArg));
 
-TIR_DEFINE_BUILTIN_FUNC(tvm_static_handle).set_num_inputs(0);
+TIR_DEFINE_BUILTIN_FUNC(tvm_static_handle)
+    .set_num_inputs(0)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kSpecialCallArg));
 
-TIR_DEFINE_BUILTIN_FUNC(tvm_context_id).set_num_inputs(0);
+TIR_DEFINE_BUILTIN_FUNC(tvm_context_id)
+    .set_num_inputs(0)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kReadState));
 
-TIR_DEFINE_BUILTIN_FUNC(tvm_tuple);
+TIR_DEFINE_BUILTIN_FUNC(tvm_tuple).set_attr<TCallEffectKind>("TCallEffectKind",
+                                                             Integer(CallEffectKind::kEmbedInfo));
 
-TIR_DEFINE_BUILTIN_FUNC(tvm_struct_get).set_num_inputs(3);
+TIR_DEFINE_BUILTIN_FUNC(tvm_struct_get)
+    .set_num_inputs(3)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kReadState));
 
-TIR_DEFINE_BUILTIN_FUNC(tvm_struct_set).set_num_inputs(4);
+TIR_DEFINE_BUILTIN_FUNC(tvm_struct_set)
+    .set_num_inputs(4)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kUpdateState));
 
-TIR_DEFINE_BUILTIN_FUNC(tvm_throw_last_error).set_num_inputs(0);
+TIR_DEFINE_BUILTIN_FUNC(tvm_throw_last_error)
+    .set_num_inputs(0)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
 
-TIR_DEFINE_BUILTIN_FUNC(tvm_stack_alloca).set_num_inputs(2);
+TIR_DEFINE_BUILTIN_FUNC(tvm_stack_alloca)
+    .set_num_inputs(2)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
 
-TIR_DEFINE_BUILTIN_FUNC(tvm_stack_make_shape);
+TIR_DEFINE_BUILTIN_FUNC(tvm_stack_make_shape)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
 
-TIR_DEFINE_BUILTIN_FUNC(tvm_stack_make_array).set_num_inputs(6);
+TIR_DEFINE_BUILTIN_FUNC(tvm_stack_make_array)
+    .set_num_inputs(6)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
 
 // When num_inputs are not set, the function is assumed to be variable length.
-TIR_DEFINE_BUILTIN_FUNC(tvm_call_packed);
+TIR_DEFINE_BUILTIN_FUNC(tvm_call_packed)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
 
-TIR_DEFINE_BUILTIN_FUNC(tvm_call_trace_packed);
+TIR_DEFINE_BUILTIN_FUNC(tvm_call_trace_packed)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
 
-TIR_DEFINE_BUILTIN_FUNC(tvm_thread_context).set_num_inputs(1);
+TIR_DEFINE_BUILTIN_FUNC(tvm_thread_context)
+    .set_num_inputs(1)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
 
-TIR_DEFINE_BUILTIN_FUNC(tvm_call_packed_lowered);
+TIR_DEFINE_BUILTIN_FUNC(tvm_call_packed_lowered)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
 
-TIR_DEFINE_BUILTIN_FUNC(tvm_call_trace_packed_lowered);
+TIR_DEFINE_BUILTIN_FUNC(tvm_call_trace_packed_lowered)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
 
 // TODO(tvm-team) revisit storage sync once we have a good memory hierachy structure.
-TIR_DEFINE_BUILTIN_FUNC(tvm_storage_sync);
+TIR_DEFINE_BUILTIN_FUNC(tvm_storage_sync)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
 
-TIR_DEFINE_BUILTIN_FUNC(tvm_warp_shuffle);
+TIR_DEFINE_BUILTIN_FUNC(tvm_warp_shuffle)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
 
-TIR_DEFINE_BUILTIN_FUNC(tvm_warp_shuffle_up);
+TIR_DEFINE_BUILTIN_FUNC(tvm_warp_shuffle_up)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
 
-TIR_DEFINE_BUILTIN_FUNC(tvm_warp_shuffle_down);
+TIR_DEFINE_BUILTIN_FUNC(tvm_warp_shuffle_down)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
 
-TIR_DEFINE_BUILTIN_FUNC(tvm_warp_activemask);
+TIR_DEFINE_BUILTIN_FUNC(tvm_warp_activemask)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
 
-TIR_DEFINE_BUILTIN_FUNC(tvm_global_barrier_kinit);
+TIR_DEFINE_BUILTIN_FUNC(tvm_global_barrier_kinit)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
 
-TIR_DEFINE_BUILTIN_FUNC(tvm_thread_allreduce);
+TIR_DEFINE_BUILTIN_FUNC(tvm_thread_allreduce)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
 
-TIR_DEFINE_BUILTIN_FUNC(tvm_load_matrix_sync);
+TIR_DEFINE_BUILTIN_FUNC(tvm_load_matrix_sync)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kReadState));
 
-TIR_DEFINE_BUILTIN_FUNC(tvm_mma_sync);
+TIR_DEFINE_BUILTIN_FUNC(tvm_mma_sync)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
 
-TIR_DEFINE_BUILTIN_FUNC(tvm_bmma_sync);
+TIR_DEFINE_BUILTIN_FUNC(tvm_bmma_sync)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
 
-TIR_DEFINE_BUILTIN_FUNC(tvm_fill_fragment);
+TIR_DEFINE_BUILTIN_FUNC(tvm_fill_fragment)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
 
-TIR_DEFINE_BUILTIN_FUNC(tvm_store_matrix_sync);
+TIR_DEFINE_BUILTIN_FUNC(tvm_store_matrix_sync)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
 
-TIR_DEFINE_BUILTIN_FUNC(vectorhigh);
+TIR_DEFINE_BUILTIN_FUNC(vectorhigh)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure));
 
-TIR_DEFINE_BUILTIN_FUNC(vectorlow);
+TIR_DEFINE_BUILTIN_FUNC(vectorlow).set_attr<TCallEffectKind>("TCallEffectKind",
+                                                             Integer(CallEffectKind::kPure));
 
-TIR_DEFINE_BUILTIN_FUNC(vectorcombine);
+TIR_DEFINE_BUILTIN_FUNC(vectorcombine)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure));
 
 }  // namespace builtin
 }  // namespace tir
index f8049ea..0f67126 100644 (file)
@@ -38,10 +38,14 @@ namespace tvm {
 using namespace tir;
 
 // macro to register an unary op
-#define TIR_REGISTER_PURE_UNARY_OP(OpName) TVM_REGISTER_OP(OpName).set_num_inputs(1)
+#define TIR_REGISTER_PURE_UNARY_OP(OpName)                             \
+  TVM_REGISTER_OP(OpName).set_num_inputs(1).set_attr<TCallEffectKind>( \
+      "TCallEffectKind", Integer(CallEffectKind::kPure))
 
 // macro to register an binary op
-#define TIR_REGISTER_PURE_BINARY_OP(OpName) TVM_REGISTER_OP(OpName).set_num_inputs(2)
+#define TIR_REGISTER_PURE_BINARY_OP(OpName)                            \
+  TVM_REGISTER_OP(OpName).set_num_inputs(2).set_attr<TCallEffectKind>( \
+      "TCallEffectKind", Integer(CallEffectKind::kPure))
 
 runtime::DataType GetRuntimeDataType(const Type& type) {
   if (auto* n = type.as<PrimTypeNode>()) {
@@ -83,8 +87,7 @@ inline PrimExpr SimpleCast(const DataType& t, PrimExpr value) {
 // LargeUIntImm
 PrimExpr LargeUIntImm(DataType t, int64_t low, int64_t high) {
   return tir::Call(t, tir::builtin::large_uint_imm(),
-                   {make_const(DataType::UInt(32), low), make_const(DataType::UInt(32), high)},
-                   tir::CallNode::PureIntrinsic);
+                   {make_const(DataType::UInt(32), low), make_const(DataType::UInt(32), high)});
 }
 
 // The public function with a quick checking path.
@@ -262,7 +265,7 @@ PrimExpr cast(const DataType& t, PrimExpr value) {
 // reinterpret
 PrimExpr reinterpret(const DataType& t, PrimExpr value) {
   if (value.dtype() == t) return value;
-  return tir::Call(t, tir::builtin::reinterpret(), {value}, tir::CallNode::PureIntrinsic);
+  return tir::Call(t, tir::builtin::reinterpret(), {value});
 }
 
 // operator+
@@ -387,17 +390,15 @@ PrimExpr if_then_else(PrimExpr cond, PrimExpr true_value, PrimExpr false_value)
   }
 
   return tir::Call(true_value.dtype(), tir::builtin::if_then_else(),
-                   {cond, true_value, false_value}, tir::CallNode::PureIntrinsic);
+                   {cond, true_value, false_value});
 }
 
 // likely
 PrimExpr likely(PrimExpr cond) {
   if (is_const(cond)) return cond;
-  return tir::Call(cond.dtype(), tir::builtin::likely(), {cond}, tir::CallNode::PureIntrinsic);
+  return tir::Call(cond.dtype(), tir::builtin::likely(), {cond});
 }
 
-TVM_REGISTER_OP("tir.likely").set_num_inputs(1);
-
 // operator>
 PrimExpr operator>(PrimExpr a, PrimExpr b) {
   BinaryOpMatchTypes(a, b);
@@ -481,7 +482,7 @@ PrimExpr operator>>(PrimExpr a, PrimExpr b) {
     }
   });
 
-  return tir::Call(a.dtype(), tir::builtin::shift_right(), {a, b}, tir::CallNode::PureIntrinsic);
+  return tir::Call(a.dtype(), tir::builtin::shift_right(), {a, b});
 }
 
 // shift left
@@ -500,7 +501,7 @@ PrimExpr operator<<(PrimExpr a, PrimExpr b) {
       if (pb->value == 0) return a;
     }
   });
-  return tir::Call(a.dtype(), tir::builtin::shift_left(), {a, b}, tir::CallNode::PureIntrinsic);
+  return tir::Call(a.dtype(), tir::builtin::shift_left(), {a, b});
 }
 
 // bitwise and
@@ -512,7 +513,7 @@ PrimExpr operator&(PrimExpr a, PrimExpr b) {
     const DataType& rtype = a.dtype();
     if (pa && pb) return IntImm(rtype, (pa->value & pb->value));
   });
-  return tir::Call(a.dtype(), tir::builtin::bitwise_and(), {a, b}, tir::CallNode::PureIntrinsic);
+  return tir::Call(a.dtype(), tir::builtin::bitwise_and(), {a, b});
 }
 
 // bitwise_or
@@ -524,7 +525,7 @@ PrimExpr operator|(PrimExpr a, PrimExpr b) {
     const DataType& rtype = a.dtype();
     if (pa && pb) return IntImm(rtype, (pa->value | pb->value));
   });
-  return tir::Call(a.dtype(), tir::builtin::bitwise_or(), {a, b}, tir::CallNode::PureIntrinsic);
+  return tir::Call(a.dtype(), tir::builtin::bitwise_or(), {a, b});
 }
 
 // bitwise_xor
@@ -536,17 +537,15 @@ PrimExpr operator^(PrimExpr a, PrimExpr b) {
     const DataType& rtype = a.dtype();
     if (pa && pb) return IntImm(rtype, (pa->value ^ pb->value));
   });
-  return tir::Call(a.dtype(), tir::builtin::bitwise_xor(), {a, b}, tir::CallNode::PureIntrinsic);
+  return tir::Call(a.dtype(), tir::builtin::bitwise_xor(), {a, b});
 }
 
 // bitwie_not
 PrimExpr operator~(PrimExpr a) {
   CHECK(a.dtype().is_int() || a.dtype().is_uint());
-  return tir::Call(a.dtype(), tir::builtin::bitwise_not(), {a}, tir::CallNode::PureIntrinsic);
+  return tir::Call(a.dtype(), tir::builtin::bitwise_not(), {a});
 }
 
-TVM_REGISTER_OP("tir.bitwise_not");
-
 TVM_REGISTER_GLOBAL("tir.bitwise_not").set_body_typed([](PrimExpr a) { return ~a; });
 
 // pow
@@ -554,10 +553,10 @@ PrimExpr pow(PrimExpr x, PrimExpr y) {
   BinaryOpMatchTypes(x, y);
   CHECK(x.dtype().is_float()) << "power only applies to float";
   static auto op = Op::Get("tir.pow");
-  return tir::Call(x.dtype(), op, {x, y}, tir::CallNode::PureIntrinsic);
+  return tir::Call(x.dtype(), op, {x, y});
 }
 
-TVM_REGISTER_OP("tir.pow").set_num_inputs(2).set_attr<TVectorizable>("TVectorizable", true);
+TIR_REGISTER_PURE_BINARY_OP("tir.pow").set_attr<TVectorizable>("TVectorizable", true);
 
 // abs
 PrimExpr abs(PrimExpr x) {
@@ -575,7 +574,7 @@ PrimExpr abs(PrimExpr x) {
       return FloatImm(x.dtype(), std::fabs(fx->value));
     }
     static auto op = Op::Get("tir.fabs");
-    return tir::Call(x.dtype(), op, {x}, tir::CallNode::PureIntrinsic);
+    return tir::Call(x.dtype(), op, {x});
   } else if (x.dtype().is_uint()) {
     return x;
   } else {
@@ -600,10 +599,9 @@ PrimExpr isnan(PrimExpr x) {
     }
     static auto op = Op::Get("tir.isnan");
     if (x.dtype().bits() == 16) {
-      return tir::Call(t, op, {cast(DataType::Float(32, t.lanes()), std::move(x))},
-                       tir::CallNode::PureIntrinsic);
+      return tir::Call(t, op, {cast(DataType::Float(32, t.lanes()), std::move(x))});
     } else {
-      return tir::Call(t, op, {x}, tir::CallNode::PureIntrinsic);
+      return tir::Call(t, op, {x});
     }
   } else {
     LOG(FATAL) << "Data type " << x.dtype() << " not supported for isnan op. Skipping isnan op...";
@@ -611,8 +609,6 @@ PrimExpr isnan(PrimExpr x) {
   }
 }
 
-TIR_REGISTER_PURE_UNARY_OP("tir.isnan");
-
 // isinf
 PrimExpr isinf(PrimExpr x) {
   DataType t = DataType::Bool(x.dtype().lanes());
@@ -685,7 +681,7 @@ PrimExpr fmod(PrimExpr x, PrimExpr y) {
   BinaryOpMatchTypes(x, y);
   CHECK(x.dtype().is_float()) << "fmod only applies to float";
   static auto op = Op::Get("tir.fmod");
-  return tir::Call(x.dtype(), op, {x, y}, tir::CallNode::PureIntrinsic);
+  return tir::Call(x.dtype(), op, {x, y});
 }
 
 TIR_REGISTER_PURE_UNARY_OP("tir.fmod");
@@ -699,7 +695,7 @@ PrimExpr floor(PrimExpr x) {
   const FloatImmNode* fx = x.as<FloatImmNode>();
   if (fx) return FloatImm(x.dtype(), std::floor(fx->value));
   static auto op = Op::Get("tir.floor");
-  return tir::Call(x.dtype(), op, {x}, tir::CallNode::PureIntrinsic);
+  return tir::Call(x.dtype(), op, {x});
 }
 
 TIR_REGISTER_PURE_UNARY_OP("tir.floor").set_attr<TVectorizable>("TVectorizable", true);
@@ -713,7 +709,7 @@ PrimExpr ceil(PrimExpr x) {
   const FloatImmNode* fx = x.as<FloatImmNode>();
   if (fx) return FloatImm(x.dtype(), std::ceil(fx->value));
   static auto op = Op::Get("tir.ceil");
-  return tir::Call(x.dtype(), op, {x}, tir::CallNode::PureIntrinsic);
+  return tir::Call(x.dtype(), op, {x});
 }
 
 TIR_REGISTER_PURE_UNARY_OP("tir.ceil").set_attr<TVectorizable>("TVectorizable", true);
@@ -727,7 +723,7 @@ PrimExpr round(PrimExpr x) {
   const FloatImmNode* fx = x.as<FloatImmNode>();
   if (fx) return FloatImm(x.dtype(), std::nearbyint(fx->value));
   static auto op = Op::Get("tir.round");
-  return tir::Call(x.dtype(), op, {x}, tir::CallNode::PureIntrinsic);
+  return tir::Call(x.dtype(), op, {x});
 }
 
 TIR_REGISTER_PURE_UNARY_OP("tir.round").set_attr<TVectorizable>("TVectorizable", true);
@@ -741,7 +737,7 @@ PrimExpr nearbyint(PrimExpr x) {
   const FloatImmNode* fx = x.as<FloatImmNode>();
   if (fx) return FloatImm(x.dtype(), std::nearbyint(fx->value));
   static auto op = Op::Get("tir.nearbyint");
-  return tir::Call(x.dtype(), op, {x}, tir::CallNode::PureIntrinsic);
+  return tir::Call(x.dtype(), op, {x});
 }
 
 TIR_REGISTER_PURE_UNARY_OP("tir.nearbyint");
@@ -757,7 +753,7 @@ PrimExpr trunc(PrimExpr x) {
     return FloatImm(x.dtype(), (fx->value < 0 ? std::ceil(fx->value) : std::floor(fx->value)));
   }
   static auto op = Op::Get("tir.trunc");
-  return tir::Call(x.dtype(), op, {x}, tir::CallNode::PureIntrinsic);
+  return tir::Call(x.dtype(), op, {x});
 }
 
 TIR_REGISTER_PURE_UNARY_OP("tir.trunc").set_attr<TVectorizable>("TVectorizable", true);
@@ -787,8 +783,6 @@ TIR_REGISTER_PURE_UNARY_OP("tir.log1p");
 
 TIR_REGISTER_PURE_UNARY_OP("tir.log10").set_attr<TVectorizable>("TVectorizable", true);
 
-TIR_REGISTER_PURE_UNARY_OP("tir.popcount").set_attr<TVectorizable>("TVectorizable", true);
-
 TIR_REGISTER_PURE_UNARY_OP("tir.tan").set_attr<TVectorizable>("TVectorizable", true);
 
 TIR_REGISTER_PURE_UNARY_OP("tir.cos").set_attr<TVectorizable>("TVectorizable", true);
index 1c540e3..adabae9 100644 (file)
@@ -29,11 +29,13 @@ namespace tir {
 
 TVM_REGISTER_OP("tir.TVMBackendAllocWorkspace")
     .set_num_inputs(5)
-    .set_attr<TGlobalSymbol>("TGlobalSymbol", "TVMBackendAllocWorkspace");
+    .set_attr<TGlobalSymbol>("TGlobalSymbol", "TVMBackendAllocWorkspace")
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
 
 TVM_REGISTER_OP("tir.TVMBackendFreeWorkspace")
     .set_num_inputs(3)
-    .set_attr<TGlobalSymbol>("TGlobalSymbol", "TVMBackendFreeWorkspace");
+    .set_attr<TGlobalSymbol>("TGlobalSymbol", "TVMBackendFreeWorkspace")
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
 
 }  // namespace tir
 }  // namespace tvm
index 80c5268..b88d298 100644 (file)
@@ -204,8 +204,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type,
   def_handle_dtype_.Set(v_strides, tir::TypeAnnotation(tvm_shape_type));
   init_nest_.emplace_back(
       LetStmt(v_strides, TVMArrayGet(DataType::Handle(), handle, builtin::kArrStrides), nop));
-  PrimExpr is_null =
-      Call(DataType::Bool(1), builtin::isnullptr(), {v_strides}, CallNode::PureIntrinsic);
+  PrimExpr is_null = Call(DataType::Bool(1), builtin::isnullptr(), {v_strides});
   if (buffer->strides.size() == 0) {
     // Assert the buffer is compact
     DataType stype = buffer->DefaultIndexType();
index 9722d11..4a44b85 100644 (file)
@@ -189,13 +189,13 @@ class BF16LowerRewriter : StmtExprMutator {
       auto uint32_dtype = DataType(kDLUInt, 32, op_val->dtype.lanes());
       auto uint32_v = Cast(uint32_dtype, op_val);
       // to be endian invariant.
-      return Call(op->dtype, builtin::reinterpret(), {uint32_v << 16}, CallNode::PureIntrinsic);
+      return Call(op->dtype, builtin::reinterpret(), {uint32_v << 16});
 
     } else if (op->dtype.is_bfloat16()) {
       // if is cast_to_bf16, check if op->value is fp32
       CHECK(op->value->dtype.is_float() && op->value->dtype.bits() == 32);
       auto uint32_dtype = DataType(kDLUInt, 32, op_val->dtype.lanes());
-      auto uint32_v = Call(uint32_dtype, builtin::reinterpret(), {op_val}, CallNode::PureIntrinsic);
+      auto uint32_v = Call(uint32_dtype, builtin::reinterpret(), {op_val});
       auto uint16_dtype = DataType(kDLUInt, 16, op_val->dtype.lanes());
       /* the following TIR is equivalent to the C++ code below:
       uint32_t rounding_bias = ((U32 >> 16) & 1) + UINT32_C(0x7FFF);
index 092a7cd..eb9ef32 100644 (file)
@@ -196,8 +196,7 @@ class CoProcSyncPlanner : public StorageAccessVisitor {
   }
 
   std::vector<Stmt> GetSync(std::string sync_name) {
-    return {
-        Evaluate(Call(DataType::Int(32), Op::Get("tir." + sync_name), {}, CallNode::Intrinsic))};
+    return {Evaluate(Call(DataType::Int(32), Op::Get("tir." + sync_name), {}))};
   }
 
   const std::unordered_set<const VarNode*>& touched_;
@@ -334,8 +333,7 @@ class CoProcBarrierDetector : public StorageAccessVisitor {
     PrimExpr min = r->min;
     PrimExpr extent = r->extent;
     return Evaluate(Call(DataType::Int(32), Op::Get(func),
-                         {wvec[0].buffer, wvec[0].dtype.bits(), r->min, r->extent},
-                         CallNode::Intrinsic));
+                         {wvec[0].buffer, wvec[0].dtype.bits(), r->min, r->extent}));
   }
   // Write barrier name
   bool read_barrier_{false};
@@ -558,13 +556,11 @@ class CoProcInstDepDetector : public StmtVisitor {
 
   Stmt MakePush(int from, int to) {
     return Evaluate(Call(DataType::Int(32), sync_push_op_,
-                         {make_const(DataType::Int(32), from), make_const(DataType::Int(32), to)},
-                         CallNode::Intrinsic));
+                         {make_const(DataType::Int(32), from), make_const(DataType::Int(32), to)}));
   }
   Stmt MakePop(int from, int to) {
     return Evaluate(Call(DataType::Int(32), sync_pop_op_,
-                         {make_const(DataType::Int(32), from), make_const(DataType::Int(32), to)},
-                         CallNode::Intrinsic));
+                         {make_const(DataType::Int(32), from), make_const(DataType::Int(32), to)}));
   }
   // sync states.
   SyncState first_state_, last_state_, curr_state_;
index 7180dd2..d540579 100644 (file)
@@ -231,8 +231,7 @@ class VTInjector : public StmtExprMutator {
       PrimExpr extent = this->VisitExpr(op->args[3]);
       PrimExpr stride = it->second / make_const(offset.dtype(), dtype.lanes());
       offset = stride * var_ + offset;
-      return Call(op->dtype, op->op, {op->args[0], op->args[1], offset, extent, op->args[4]},
-                  op->call_type);
+      return Call(op->dtype, op->op, {op->args[0], op->args[1], offset, extent, op->args[4]});
     } else if (op->op.same_as(builtin::tvm_context_id())) {
       return allow_share_ ? GetRef<PrimExpr>(op) : var_;
     } else {
index 758923b..2f9d706 100644 (file)
@@ -87,7 +87,7 @@ inline PrimExpr TVMStructGet(DataType dtype, Var handle, int index,
                              builtin::TVMStructFieldKind kind) {
   Array<PrimExpr> args = {handle, make_const(DataType::Int(32), index),
                           make_const(DataType::Int(32), static_cast<int>(kind))};
-  return Call(dtype, builtin::tvm_struct_get(), args, CallNode::PureIntrinsic);
+  return Call(dtype, builtin::tvm_struct_get(), args);
 }
 
 /*!
@@ -99,8 +99,7 @@ inline PrimExpr TVMStructGet(DataType dtype, Var handle, int index,
 inline PrimExpr AddressOffset(Var handle, DataType dtype, int offset) {
   return Call(DataType::Handle(), builtin::address_of(),
               {Load(dtype, handle, make_const(DataType::Int(32), offset * dtype.lanes()),
-                    const_true(dtype.lanes()))},
-              CallNode::PureIntrinsic);
+                    const_true(dtype.lanes()))});
 }
 
 /*!
@@ -115,7 +114,7 @@ inline PrimExpr AddressOffset(Var handle, DataType dtype, PrimExpr offset) {
     offset = Ramp(offset, make_const(offset.dtype(), 1), dtype.lanes());
   }
   return Call(DataType::Handle(), builtin::address_of(),
-              {Load(dtype, handle, offset, const_true(dtype.lanes()))}, CallNode::PureIntrinsic);
+              {Load(dtype, handle, offset, const_true(dtype.lanes()))});
 }
 
 /*!
@@ -129,7 +128,7 @@ inline PrimExpr AddressOffset(Var handle, DataType dtype, PrimExpr offset) {
 inline Stmt TVMStructSet(Var handle, int index, builtin::TVMStructFieldKind kind, PrimExpr value) {
   Array<PrimExpr> args = {handle, make_const(DataType::Int(32), index),
                           make_const(DataType::Int(32), static_cast<int>(kind)), value};
-  return Evaluate(Call(DataType::Int(32), builtin::tvm_struct_set(), args, CallNode::Intrinsic));
+  return Evaluate(Call(DataType::Int(32), builtin::tvm_struct_set(), args));
 }
 
 /*!
index d38cb7b..5ec4fe3 100644 (file)
@@ -51,17 +51,13 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
   }
 
   PrimExpr VisitExpr_(const CallNode* op) final {
-    // NOTE: call_type will eventually be deprecated and the information
-    // will be folded into Op's attr
-    if (op->call_type == CallNode::Intrinsic || op->call_type == CallNode::PureIntrinsic) {
-      if (auto* ptr_op = op->op.as<OpNode>()) {
-        // Still use legacy string based rewriting
-        // TODO(tvm-team): migrate the pattern application from global function look up
-        // to an OpAttrMap<PackedFunc>
-        std::string name = ptr_op->name;
-        PrimExpr r = ApplyPattern(name, GetRef<PrimExpr>(op));
-        if (r.defined()) return r;
-      }
+    if (auto* ptr_op = op->op.as<OpNode>()) {
+      // Still use legacy string based rewriting
+      // TODO(tvm-team): migrate the pattern application from global function look up
+      // to an OpAttrMap<PackedFunc>
+      std::string name = ptr_op->name;
+      PrimExpr r = ApplyPattern(name, GetRef<PrimExpr>(op));
+      if (r.defined()) return r;
     }
     return IRMutatorWithAnalyzer::VisitExpr_(op);
   }
@@ -238,7 +234,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
     PrimExpr rhs = SwapBroadcastCast(b);
 
     if (fma_ != nullptr && op->dtype.is_float()) {
-      PrimExpr r = (*fma_)(Call(op->dtype, builtin::fma(), {lhs, rhs, c}, CallNode::PureIntrinsic));
+      PrimExpr r = (*fma_)(Call(op->dtype, builtin::fma(), {lhs, rhs, c}));
       if (r.defined()) return this->VisitExpr(r);
     } else {
       if (!lhs.same_as(a) || !rhs.same_as(b)) {
index dab8d5a..04b8953 100644 (file)
@@ -242,8 +242,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
       Var mask_var("mask", DataType::UInt(32));
       {
         PrimExpr pred = const_true(1);
-        PrimExpr mask =
-            Call(DataType::UInt(32), builtin::tvm_warp_activemask(), {}, CallNode::Intrinsic);
+        PrimExpr mask = Call(DataType::UInt(32), builtin::tvm_warp_activemask(), {});
         seq.emplace_back(Store(mask_var, mask, index, pred));
         // Push allocation with an empty body. Later this will be fixed
         // when the entire body is ready.
@@ -464,8 +463,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
   }
   // sync thread op.
   static Stmt SyncThread(const std::string& sync) {
-    return Evaluate(Call(DataType::Int(32), builtin::tvm_storage_sync(), {StringImm(sync)},
-                         CallNode::Intrinsic));
+    return Evaluate(Call(DataType::Int(32), builtin::tvm_storage_sync(), {StringImm(sync)}));
   }
 
   // Emit warp shuffle  calls.
@@ -475,7 +473,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
     PrimExpr mask = Load(DataType::UInt(32), mask_var, index, pred);
     PrimExpr width = IntImm(DataType::Int(32), warp_size_);
     Array<PrimExpr> args{mask, val, IntImm(DataType::Int(32), delta_or_lane), width, width};
-    return Call(val.dtype(), op, args, CallNode::Intrinsic);
+    return Call(val.dtype(), op, args);
   }
 
   // Check if this is a reduction on threadIdx.x and its extent matches
index e618230..f071704 100644 (file)
@@ -41,7 +41,7 @@ inline PrimExpr ConstInt32(size_t index) {
 
 inline PrimExpr StackAlloca(std::string type, size_t num) {
   Array<PrimExpr> args = {StringImm(type), ConstInt32(num)};
-  return Call(DataType::Handle(), builtin::tvm_stack_alloca(), args, CallNode::Intrinsic);
+  return Call(DataType::Handle(), builtin::tvm_stack_alloca(), args);
 }
 
 // Calculate the statistics of packed function.
@@ -103,11 +103,9 @@ class BuiltinLower : public StmtExprMutator {
     }
     CHECK(device_type_.defined()) << "Unknown device type in current IR";
     CHECK(device_id_.defined()) << "Unknown device id in current IR";
-    Stmt throw_last_error =
-        Evaluate(Call(DataType::Int(32), builtin::tvm_throw_last_error(), {}, CallNode::Intrinsic));
+    Stmt throw_last_error = Evaluate(Call(DataType::Int(32), builtin::tvm_throw_last_error(), {}));
 
-    Stmt body = SeqStmt({IfThenElse(Call(DataType::Bool(1), builtin::isnullptr(), {op->buffer_var},
-                                         CallNode::PureIntrinsic),
+    Stmt body = SeqStmt({IfThenElse(Call(DataType::Bool(1), builtin::isnullptr(), {op->buffer_var}),
                                     throw_last_error),
                          op->body});
     Stmt alloca = LetStmt(
@@ -115,14 +113,12 @@ class BuiltinLower : public StmtExprMutator {
         Call(op->buffer_var.dtype(), Op::Get("tir.TVMBackendAllocWorkspace"),
              {cast(DataType::Int(32), device_type_), cast(DataType::Int(32), device_id_),
               cast(DataType::UInt(64), total_bytes), IntImm(DataType::Int(32), op->dtype.code()),
-              IntImm(DataType::Int(32), op->dtype.bits())},
-             CallNode::Extern),
+              IntImm(DataType::Int(32), op->dtype.bits())}),
         body);
 
     PrimExpr free_op = Call(DataType::Int(32), Op::Get("tir.TVMBackendFreeWorkspace"),
                             {cast(DataType::Int(32), device_type_),
-                             cast(DataType::Int(32), device_id_), op->buffer_var},
-                            CallNode::Extern);
+                             cast(DataType::Int(32), device_id_), op->buffer_var});
     Stmt free_stmt = IfThenElse(free_op != make_zero(DataType::Int(32)), throw_last_error);
     body = SeqStmt({alloca, free_stmt});
     body = AttrStmt(op->buffer_var, attr::storage_alignment,
@@ -245,8 +241,7 @@ class BuiltinLower : public StmtExprMutator {
     Array<PrimExpr> packed_args = {op->args[0], stack_value_, stack_tcode_,
                                    ConstInt32(arg_stack_begin),
                                    ConstInt32(arg_stack_begin + op->args.size() - 1)};
-    return Call(DataType::Int(32), builtin::tvm_call_packed_lowered(), packed_args,
-                CallNode::Intrinsic);
+    return Call(DataType::Int(32), builtin::tvm_call_packed_lowered(), packed_args);
   }
 
   PrimExpr MakeCallTracePacked(const CallNode* op) {
@@ -287,8 +282,7 @@ class BuiltinLower : public StmtExprMutator {
                                    ConstInt32(arg_stack_begin + op->args.size() - 1),
                                    // Pass traced value.
                                    op->args[args_size - 1]};
-    return Call(op->dtype, builtin::tvm_call_trace_packed_lowered(), packed_args,
-                CallNode::Intrinsic);
+    return Call(op->dtype, builtin::tvm_call_trace_packed_lowered(), packed_args);
   }
 
  private:
index 3e7d13b..72423e0 100644 (file)
@@ -250,10 +250,9 @@ class WarpAccessRewriter : protected StmtExprMutator {
           << "LowerWarpMemory failed to rewrite load to shuffle for index " << op->index
           << " local_index=" << local_index;
       PrimExpr load_value = Load(op->dtype, op->buffer_var, local_index, op->predicate);
-      PrimExpr mask =
-          Call(DataType::UInt(32), builtin::tvm_warp_activemask(), {}, CallNode::Intrinsic);
+      PrimExpr mask = Call(DataType::UInt(32), builtin::tvm_warp_activemask(), {});
       return Call(load_value.dtype(), builtin::tvm_warp_shuffle(),
-                  {mask, load_value, group, width_, warp_size_}, CallNode::Intrinsic);
+                  {mask, load_value, group, width_, warp_size_});
     } else {
       return StmtExprMutator::VisitExpr_(op);
     }
index 9bb5fc6..bfcf0b7 100644 (file)
@@ -86,7 +86,7 @@ PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) {
                               IntImm(DataType::Int(32), builtin::kTVMValueContent)};
     // load 64 bit version
     DataType api_type = APIType(t);
-    PrimExpr res = Call(api_type, builtin::tvm_struct_get(), call_args, CallNode::PureIntrinsic);
+    PrimExpr res = Call(api_type, builtin::tvm_struct_get(), call_args);
     // cast to the target version.
     if (api_type != t) {
       res = Cast(t, res);
@@ -191,8 +191,7 @@ PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) {
     if (runtime::DeviceAPI::NeedSetDeviceContext(target_device_type)) {
       Stmt set_device =
           Evaluate(Call(DataType::Int(32), builtin::tvm_call_packed(),
-                        {StringImm(runtime::symbol::tvm_set_device), device_type, device_id},
-                        CallNode::Intrinsic));
+                        {StringImm(runtime::symbol::tvm_set_device), device_type, device_id}));
       body = SeqStmt({set_device, body});
     }
   }
index e553536..f1286d7 100644 (file)
@@ -24,6 +24,7 @@
 #include <tvm/runtime/registry.h>
 #include <tvm/tir/builtin.h>
 #include <tvm/tir/expr.h>
+#include <tvm/tir/op_attr_types.h>
 #include <tvm/tir/stmt_functor.h>
 #include <tvm/tir/transform.h>
 
@@ -43,11 +44,16 @@ class UnsafeExprDetector : public ExprFunctor<bool(const PrimExpr& n)> {
     } else if (op->op.same_as(builtin::address_of())) {
       const LoadNode* l = op->args[0].as<LoadNode>();
       return this->VisitExpr(l->index);
-    } else if (op->is_pure()) {
-      for (PrimExpr e : op->args) {
-        if (VisitExpr(e)) return true;
+    } else if (auto* ptr_op = op->op.as<OpNode>()) {
+      auto effect_kind = op_call_effect_[GetRef<Op>(ptr_op)];
+      if (effect_kind == CallEffectKind::kPure || effect_kind == CallEffectKind::kExprAnnotation) {
+        for (PrimExpr e : op->args) {
+          if (VisitExpr(e)) return true;
+        }
+        return false;
+      } else {
+        return true;
       }
-      return false;
     } else {
       return true;
     }
@@ -94,6 +100,8 @@ class UnsafeExprDetector : public ExprFunctor<bool(const PrimExpr& n)> {
   bool BinaryOp(const T* op) {
     return VisitExpr(op->a) || VisitExpr(op->b);
   }
+
+  OpAttrMap<TCallEffectKind> op_call_effect_ = Op::GetAttrMap<TCallEffectKind>("TCallEffectKind");
 };
 
 class UnsafeSelectRewriter : public StmtExprMutator {
@@ -106,7 +114,7 @@ class UnsafeSelectRewriter : public StmtExprMutator {
     if ((unsafe.VisitExpr(op->true_value) || unsafe.VisitExpr(op->false_value)) &&
         cond_is_scalar_bool) {
       return Call(op->dtype, builtin::if_then_else(),
-                  {op->condition, op->true_value, op->false_value}, CallNode::Intrinsic);
+                  {op->condition, op->true_value, op->false_value});
     } else {
       return expr;
     }
index c35caf5..f339c56 100644 (file)
@@ -238,8 +238,7 @@ class HostDeviceSplitter : public StmtMutator {
     for (PrimExpr ext : m.thread_extent_) {
       call_args.push_back(ext);
     }
-    return Evaluate(
-        Call(DataType::Int(32), builtin::tvm_call_packed(), call_args, CallNode::Intrinsic));
+    return Evaluate(Call(DataType::Int(32), builtin::tvm_call_packed(), call_args));
   }
 
   // target ir module
index 3080550..8eb43f8 100644 (file)
@@ -321,10 +321,8 @@ class StorageFlattener : public StmtExprMutator {
         stmt = For(vars[i], 0, op->bounds[i]->extent, ForType::Serial, DeviceAPI::None, stmt);
       } else {
         PrimExpr load = e.buffer.vload(e.RelIndex(args), e.buffer->dtype);
-        PrimExpr address =
-            Call(DataType::Handle(), builtin::address_of(), {load}, CallNode::PureIntrinsic);
-        PrimExpr prefetch =
-            Call(op->buffer->dtype, builtin::prefetch(), {address, 0, 3, 1}, CallNode::Intrinsic);
+        PrimExpr address = Call(DataType::Handle(), builtin::address_of(), {load});
+        PrimExpr prefetch = Call(op->buffer->dtype, builtin::prefetch(), {address, 0, 3, 1});
         stmt = Evaluate(prefetch);
         PrimExpr extent = (op->bounds[i]->extent - 1) / stride + 1;
         stmt = For(vars[i], 0, extent, ForType::Serial, DeviceAPI::None, stmt);
index d7a258c..09d9651 100644 (file)
@@ -404,8 +404,7 @@ class StoragePlanRewriter : public StmtExprMutator {
       if (se->bits_offset != 0) {
         offset = make_const(offset.dtype(), se->bits_offset / elem_bits) + offset;
       }
-      return Call(op->dtype, op->op, {op->args[0], se->alloc_var, offset, extent, op->args[4]},
-                  op->call_type);
+      return Call(op->dtype, op->op, {op->args[0], se->alloc_var, offset, extent, op->args[4]});
     } else {
       return StmtExprMutator::VisitExpr_(op);
     }
index cdd9377..a38be3c 100644 (file)
@@ -211,7 +211,7 @@ class ThreadSyncInserter : public StmtExprMutator {
         barrier = MakeGlobalBarrier();
       } else {
         barrier = Evaluate(Call(DataType::Int(32), builtin::tvm_storage_sync(),
-                                {StringImm(sync_scope_.to_string())}, CallNode::Intrinsic));
+                                {StringImm(sync_scope_.to_string())}));
       }
       // Mutate after query, to avoid stmt change.
       auto ret = StmtExprMutator::VisitStmt(stmt);
@@ -299,8 +299,7 @@ class ThreadSyncInserter : public StmtExprMutator {
   Stmt InitGlobalBarrier(const AttrStmtNode* op) {
     CHECK(op != nullptr);
     Array<PrimExpr> pargs = {StringImm(runtime::symbol::tvm_prepare_global_barrier)};
-    Stmt prep =
-        Evaluate(Call(DataType::Int(32), builtin::tvm_call_packed(), pargs, CallNode::Intrinsic));
+    Stmt prep = Evaluate(Call(DataType::Int(32), builtin::tvm_call_packed(), pargs));
     Stmt body = op->body;
     for (const auto& kv : rw_stats_) {
       const auto& e = kv.second;
@@ -309,8 +308,7 @@ class ThreadSyncInserter : public StmtExprMutator {
       }
     }
     rw_stats_.clear();
-    Stmt kinit = Evaluate(
-        Call(DataType::Int(32), builtin::tvm_global_barrier_kinit(), {}, CallNode::Intrinsic));
+    Stmt kinit = Evaluate(Call(DataType::Int(32), builtin::tvm_global_barrier_kinit(), {}));
     body = SeqStmt({kinit, body});
     body = AttrStmt(op->node, op->attr_key, op->value, body);
     return SeqStmt({prep, body});
@@ -334,8 +332,7 @@ class ThreadSyncInserter : public StmtExprMutator {
       CHECK_EQ(num_work_dim_, thread_extents_.size());
     }
     return Evaluate(Call(DataType::Int(32), builtin::tvm_storage_sync(),
-                         {StringImm(sync_scope_.to_string()), is_lead_, num_blocks_},
-                         CallNode::Intrinsic));
+                         {StringImm(sync_scope_.to_string()), is_lead_, num_blocks_}));
   }
   // data structure.
   StorageScope sync_scope_;
index 1a2ec50..e015990 100644 (file)
@@ -214,7 +214,7 @@ class Vectorizer : public StmtExprMutator {
       int lanes = std::max(t.dtype().lanes(), f.dtype().lanes());
       t = BroadcastTo(t, lanes);
       f = BroadcastTo(f, lanes);
-      return Call(op->dtype.with_lanes(lanes), op->op, {cond, t, f}, op->call_type);
+      return Call(op->dtype.with_lanes(lanes), op->op, {cond, t, f});
     }
   }
   // Call
@@ -239,7 +239,7 @@ class Vectorizer : public StmtExprMutator {
       if (op->args.same_as(new_args)) {
         return GetRef<PrimExpr>(op);
       } else {
-        return Call(op->dtype, op->op, new_args, op->call_type);
+        return Call(op->dtype, op->op, new_args);
       }
     } else {
       int lane = 0;
@@ -248,7 +248,7 @@ class Vectorizer : public StmtExprMutator {
       if (op->args.same_as(new_args)) {
         return GetRef<PrimExpr>(op);
       } else {
-        return Call(op->dtype.with_lanes(lane), op->op, new_args, op->call_type);
+        return Call(op->dtype.with_lanes(lane), op->op, new_args);
       }
     }
   }
index ce50ed0..de06a0e 100644 (file)
@@ -193,8 +193,8 @@ TEST(IRF, StmtMutator) {
   }
 
   {
-    auto body = Evaluate(Call(DataType::Int(32), builtin::call_extern(), {StringImm("xyz"), x + 1},
-                              CallNode::Extern));
+    auto body =
+        Evaluate(Call(DataType::Int(32), builtin::call_extern(), {StringImm("xyz"), x + 1}));
     auto res = v(std::move(body));
     CHECK(res.as<EvaluateNode>()->value.as<CallNode>()->args[1].same_as(x));
   }
index 9882a3b..e12f970 100644 (file)
@@ -204,7 +204,7 @@ def test_reduce_combiner_simplify():
 
     # Test that components with side effects are not removed
     dummy = tvm.ir.GlobalVar("dummy")
-    side_effect = lambda *xs: tvm.tir.Call("int32", dummy, xs, tvm.tir.Call.Intrinsic)
+    side_effect = lambda *xs: tvm.tir.Call("int32", dummy, xs)
     ck.verify(sum_and_prod((A[k], side_effect(A[10-k])), k)[0],
              sum_and_prod((A[k], side_effect(A[10-k])), k)[0])
     ck.verify(sum_and_prod((side_effect(A[k]), A[10-k]), k)[0],
index 18a98ee..698dd74 100644 (file)
@@ -98,7 +98,7 @@ def test_reinterpret():
     nn = 1024
     n = tvm.runtime.convert(nn)
     A = te.placeholder((n,), name='A', dtype="int32")
-    B = te.compute(A.shape, lambda *i: tvm.tir.call_pure_intrin("float32", "tir.reinterpret", A(*i)), name='B')
+    B = te.compute(A.shape, lambda *i: tvm.tir.call_intrin("float32", "tir.reinterpret", A(*i)), name='B')
     s = te.create_schedule(B.op)
 
     def check_c():
index a6a2315..911ffb4 100644 (file)
@@ -29,12 +29,12 @@ def test_llvm_intrin():
     n = tvm.runtime.convert(4)
     A = ib.pointer("float32", name="A")
     args = [
-        tvm.tir.call_pure_intrin("handle", "tir.address_of", A[0]),
+        tvm.tir.call_intrin("handle", "tir.address_of", A[0]),
         0, 3, 1
     ]
     ib.emit(tvm.tir.Evaluate(
         tvm.tir.Call(
-            "int32", "tir.prefetch", args, tvm.tir.Call.Intrinsic)))
+            "int32", "tir.prefetch", args)))
     body = ib.get()
 
     mod = tvm.IRModule.from_expr(
@@ -65,7 +65,7 @@ def test_llvm_overloaded_intrin():
     def use_llvm_intrinsic(A, C):
         ib = tvm.tir.ir_builder.create()
         L = A.vload((0,0))
-        I = tvm.tir.call_llvm_intrin('int32', 'llvm.ctlz',
+        I = tvm.tir.call_llvm_pure_intrin('int32', 'llvm.ctlz',
             tvm.tir.const(2, 'uint32'), L, tvm.tir.const(0, 'int1'))
         S = C.vstore((0,0), I)
         ib.emit(S)
@@ -124,7 +124,7 @@ def test_llvm_lookup_intrin():
     ib = tvm.tir.ir_builder.create()
     A = ib.pointer("uint8x8", name="A")
     z = tvm.tir.const(0, 'int32')
-    x = tvm.tir.call_llvm_intrin("uint8x8", "llvm.ctpop.v8i8", tvm.tir.const(1, 'uint32'), A[z])
+    x = tvm.tir.call_llvm_pure_intrin("uint8x8", "llvm.ctpop.v8i8", tvm.tir.const(1, 'uint32'), A[z])
     ib.emit(x)
     body = ib.get()
     mod = tvm.IRModule.from_expr(
index d2c504b..578e32f 100644 (file)
@@ -112,12 +112,11 @@ def test_expr_constructor():
     assert x.vectors[0] == a
     assert x.indices[0].value == 0
 
-    x = tvm.tir.Call("float32", "tir.call_extern", [tvm.tir.StringImm("xyz"), a], tvm.tir.Call.Extern)
+    x = tvm.tir.Call("float32", "tir.call_extern", [tvm.tir.StringImm("xyz"), a])
     assert isinstance(x, tvm.tir.Call)
     assert x.dtype == "float32"
     assert x.op.name == "tir.call_extern"
     assert x.args[1] == a
-    assert x.call_type == tvm.tir.Call.Extern
 
     v = te.var("aa")
     x = tvm.tir.Let(v, 1, v)
index 39acb3a..ab730cd 100644 (file)
@@ -171,19 +171,19 @@ def test_all():
 def test_bitwise():
     x = te.var('x')
     y = te.var('y')
-    assert str(x << y) == '@tir.shift_left(x: int32, y: int32, dtype=int32, type="pure_intrin")'
-    assert str(x >> y) == '@tir.shift_right(x: int32, y: int32, dtype=int32, type="pure_intrin")'
-    assert str(x & y) == '@tir.bitwise_and(x: int32, y: int32, dtype=int32, type="pure_intrin")'
-    assert str(x | y) == '@tir.bitwise_or(x: int32, y: int32, dtype=int32, type="pure_intrin")'
-    assert str(x ^ y) == '@tir.bitwise_xor(x: int32, y: int32, dtype=int32, type="pure_intrin")'
-    assert str(10 & x) == '@tir.bitwise_and(10, x: int32, dtype=int32, type="pure_intrin")'
-    assert str(10 | x) == '@tir.bitwise_or(10, x: int32, dtype=int32, type="pure_intrin")'
-    assert str(10 ^ x) == '@tir.bitwise_xor(10, x: int32, dtype=int32, type="pure_intrin")'
-    assert str(10 >> x) == '@tir.shift_right(10, x: int32, dtype=int32, type="pure_intrin")'
-    assert str(10 << x) == '@tir.shift_left(10, x: int32, dtype=int32, type="pure_intrin")'
+    assert str(x << y) == '@tir.shift_left(x: int32, y: int32, dtype=int32)'
+    assert str(x >> y) == '@tir.shift_right(x: int32, y: int32, dtype=int32)'
+    assert str(x & y) == '@tir.bitwise_and(x: int32, y: int32, dtype=int32)'
+    assert str(x | y) == '@tir.bitwise_or(x: int32, y: int32, dtype=int32)'
+    assert str(x ^ y) == '@tir.bitwise_xor(x: int32, y: int32, dtype=int32)'
+    assert str(10 & x) == '@tir.bitwise_and(10, x: int32, dtype=int32)'
+    assert str(10 | x) == '@tir.bitwise_or(10, x: int32, dtype=int32)'
+    assert str(10 ^ x) == '@tir.bitwise_xor(10, x: int32, dtype=int32)'
+    assert str(10 >> x) == '@tir.shift_right(10, x: int32, dtype=int32)'
+    assert str(10 << x) == '@tir.shift_left(10, x: int32, dtype=int32)'
     assert str(10 % x) == 'floormod(10, x: int32)'
 
-    assert str(~x) == '@tir.bitwise_not(x: int32, dtype=int32, type="pure_intrin")'
+    assert str(~x) == '@tir.bitwise_not(x: int32, dtype=int32)'
     assert(tvm.tir.const(1, "int8x2") >> 1).dtype == "int8x2"
     assert(x >> tvm.tir.const(1, "int32x2")).dtype == "int32x2"
     assert(te.var("z", "int8x2") << tvm.tir.const(1, "int8x2")).dtype == "int8x2"
@@ -240,10 +240,10 @@ def test_divide_by_zero():
 
 def test_isnan():
     x = te.var('x', 'float32')
-    assert str(tvm.tir.isnan(x)) == '@tir.isnan(x: float32, dtype=bool, type="pure_intrin")'
+    assert str(tvm.tir.isnan(x)) == '@tir.isnan(x: float32, dtype=bool)'
     assert str(tvm.tir.isnan(x).dtype) == 'bool'
     y = te.var('y', 'float16')
-    assert str(tvm.tir.isnan(y)) == '@tir.isnan(cast(float32, y: float16), dtype=bool, type="pure_intrin")'
+    assert str(tvm.tir.isnan(y)) == '@tir.isnan(cast(float32, y: float16), dtype=bool)'
     z = te.var('z', 'int32')
     assert str(tvm.tir.isnan(z)) == 'False'
     k = te.var('k', 'int8x2')
index 55a6819..599ddba 100644 (file)
@@ -115,19 +115,19 @@ def test_eliminate():
 def test_legalize():
     def to32(v):
         uint32_v = topi.cast(v, "uint32")
-        uint32_v = tvm.tir.call_pure_intrin(
+        uint32_v = tvm.tir.call_intrin(
             "uint32", "tir.shift_left", uint32_v, tvm.tir.const(16, "uint32"))
-        return tvm.tir.call_pure_intrin("float32", "tir.reinterpret", uint32_v)
+        return tvm.tir.call_intrin("float32", "tir.reinterpret", uint32_v)
 
     def to16(v):
-        uint32_v = tvm.tir.call_pure_intrin("uint32", "tir.reinterpret", v)
-        rounding_bias = tvm.tir.call_pure_intrin(
+        uint32_v = tvm.tir.call_intrin("uint32", "tir.reinterpret", v)
+        rounding_bias = tvm.tir.call_intrin(
             "uint32", "tir.shift_right", uint32_v, tvm.tir.const(16, "uint32"))
-        rounding_bias = tvm.tir.call_pure_intrin(
+        rounding_bias = tvm.tir.call_intrin(
             "uint32", "tir.bitwise_and", rounding_bias, tvm.tir.const(1, "uint32"))
         rounding_bias = rounding_bias + tvm.tir.const(0x7FFF, "uint16")
         uint32_v = uint32_v + rounding_bias
-        uint32_v = tvm.tir.call_pure_intrin(
+        uint32_v = tvm.tir.call_intrin(
             "uint32", "tir.shift_right", uint32_v, tvm.tir.const(16, "uint32"))
         return topi.cast(uint32_v, 'uint16')
 
index d7a25ca..2886958 100644 (file)
@@ -22,7 +22,7 @@ def test_for():
     def device_context(dev_id):
         ctx = tvm.tir.call_extern("handle", "device_context", dev_type, dev_id)
         return tvm.tir.Call(
-            "handle", "tir.tvm_thread_context", [ctx], tvm.tir.Call.Intrinsic)
+            "handle", "tir.tvm_thread_context", [ctx])
 
     ib = tvm.tir.ir_builder.create()
     n = te.var("n")
index 4964039..be725d6 100644 (file)
@@ -36,7 +36,7 @@ def test_vthread():
             bbuffer = tvm.tir.decl_buffer((m,), dtype=B.dtype, data=B.asobject())
             ib.emit(tvm.tir.call_extern("int32", "Run",
                                     bbuffer.access_ptr("r"),
-                                    tvm.tir.call_pure_intrin("int32", "tir.tvm_context_id")))
+                                    tvm.tir.call_intrin("int32", "tir.tvm_context_id")))
             C[i * nthread + tx] = B[i] + 1
         return ib.get()
 
index 7068b95..5349818 100644 (file)
@@ -112,12 +112,12 @@ inline Array<Tensor> make_extern(const Array<Array<PrimExpr> >& out_shapes,
  */
 inline PrimExpr pack_buffer(Buffer buf) {
   CHECK_GT(buf->shape.size(), 0) << "buf shape must have at least one element";
-  auto shape = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_stack_make_shape(),
-                              buf->shape, tvm::tir::CallNode::CallType::Intrinsic);
+  auto shape =
+      tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_stack_make_shape(), buf->shape);
   PrimExpr strides;
   if (buf->strides.size() > 0) {
-    strides = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_stack_make_shape(),
-                             buf->shape, tvm::tir::CallNode::CallType::Intrinsic);
+    strides =
+        tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_stack_make_shape(), buf->shape);
   } else {
     strides = 0;
   }
@@ -127,8 +127,7 @@ inline PrimExpr pack_buffer(Buffer buf) {
                             make_const(DataType::Int(32), static_cast<int64_t>(buf->shape.size())),
                             make_const(buf->dtype, 0),
                             buf->elem_offset};
-  return tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_stack_make_array(), pack_args,
-                        tvm::tir::CallNode::CallType::Intrinsic);
+  return tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_stack_make_array(), pack_args);
 }
 
 /*!
@@ -141,8 +140,7 @@ inline PrimExpr pack_buffer(Buffer buf) {
  * \return An expression representing the invocation
  */
 inline PrimExpr call_packed(Array<PrimExpr> args) {
-  return tvm::tir::Call(DataType::Int(32), tvm::tir::builtin::tvm_call_packed(), args,
-                        tvm::tir::CallNode::CallType::Intrinsic);
+  return tvm::tir::Call(DataType::Int(32), tvm::tir::builtin::tvm_call_packed(), args);
 }
 
 }  // namespace detail
index 0ec7e4d..9b418d0 100644 (file)
@@ -310,8 +310,7 @@ inline Tensor reinterpret(const Tensor& x, DataType type, std::string name = "te
   return compute(
       x->shape,
       [&](const Array<Var>& i) {
-        return tvm::tir::Call(type, tvm::tir::builtin::reinterpret(), {x(i)},
-                              tvm::tir::CallNode::PureIntrinsic);
+        return tvm::tir::Call(type, tvm::tir::builtin::reinterpret(), {x(i)});
       },
       name, tag);
 }
index f035251..e76b374 100644 (file)
@@ -231,21 +231,21 @@ def _intrin_popcount(m, k_i, w_b, x_b, unipolar):
                                 cnts = tvm.tir.popcount(w_ & x_) - tvm.tir.popcount(~w_ & x_)
                             else:
                                 cnts = tvm.tir.popcount(w_ & x_)
-                            upper_half = tvm.tir.call_pure_intrin(
+                            upper_half = tvm.tir.call_intrin(
                                 half_dtype, 'tir.vectorhigh', cnts)
-                            lower_half = tvm.tir.call_pure_intrin(
+                            lower_half = tvm.tir.call_intrin(
                                 half_dtype, 'tir.vectorlow', cnts)
                             cnts8[i] = upper_half + lower_half
                         for i in range(m//2):
-                            cnts4[i] = tvm.tir.call_llvm_intrin(half_dtype, vpadd,
-                                                                args_2, cnts8[i*2], cnts8[i*2+1])
+                            cnts4[i] = tvm.tir.call_llvm_pure_intrin(
+                                half_dtype, vpadd, args_2, cnts8[i*2], cnts8[i*2+1])
                         for i in range(m//4):
-                            cnts2[i] = tvm.tir.call_llvm_intrin(half_dtype, vpadd,
-                                                                args_2, cnts4[i*2], cnts4[i*2+1])
-                        cnts = tvm.tir.call_pure_intrin(
+                            cnts2[i] = tvm.tir.call_llvm_pure_intrin(
+                                half_dtype, vpadd, args_2, cnts4[i*2], cnts4[i*2+1])
+                        cnts = tvm.tir.call_intrin(
                             full_dtype, 'tir.vectorcombine', cnts2[0], cnts2[1])
                         shifted_cnts = cnts << tvm.tir.const(bw+bx, pack_dtype)
-                        out = tvm.tir.call_llvm_intrin(
+                        out = tvm.tir.call_llvm_pure_intrin(
                             return_dtype, vpadalu,
                             args_2, zz.vload(0, return_dtype), shifted_cnts)
                     else: # ki == 8
@@ -257,15 +257,15 @@ def _intrin_popcount(m, k_i, w_b, x_b, unipolar):
                             else:
                                 cnts8[i] = tvm.tir.popcount(w_ & x_)
                         for i in range(m//2):
-                            cnts4[i] = tvm.tir.call_llvm_intrin(half_dtype, vpadd,
-                                                                args_2, cnts8[i*2], cnts8[i*2+1])
+                            cnts4[i] = tvm.tir.call_llvm_pure_intrin(
+                                half_dtype, vpadd, args_2, cnts8[i*2], cnts8[i*2+1])
                         for i in range(m//4):
-                            cnts2[i] = tvm.tir.call_llvm_intrin(half_dtype, vpadd,
-                                                                args_2, cnts4[i*2], cnts4[i*2+1])
-                        cnts = tvm.tir.call_pure_intrin(
+                            cnts2[i] = tvm.tir.call_llvm_pure_intrin(
+                                half_dtype, vpadd, args_2, cnts4[i*2], cnts4[i*2+1])
+                        cnts = tvm.tir.call_intrin(
                             full_dtype, 'tir.vectorcombine', cnts2[0], cnts2[1])
                         shifted_cnts = cnts << tvm.tir.const(bw+bx, pack_dtype)
-                        out = tvm.tir.call_llvm_intrin(
+                        out = tvm.tir.call_llvm_pure_intrin(
                             return_dtype, vpadalu,
                             args_2, zz.vload(0, return_dtype), shifted_cnts)
                     irb.emit(zz.vstore(0, out))
index 6ef2548..dfa2f05 100644 (file)
@@ -425,21 +425,22 @@ def dot_int8_int8_int32(int32_lanes, dtype='uint'):
             dtype_c = '%s32x%d' % (dtype, int32_lanes)
 
             a_int8 = ins[0].vload([0], dtype_a)
-            re_int32 = tvm.tir.call_pure_intrin('%s32' % dtype, 'tir.reinterpret', a_int8)
+            re_int32 = tvm.tir.call_intrin('%s32' % dtype, 'tir.reinterpret', a_int8)
             # broadcast a
             vec_ai32 = re_int32.astype(dtype_c)
 
-            vec_a = tvm.tir.call_pure_intrin(dtype_b, 'tir.reinterpret', vec_ai32)
+            vec_a = tvm.tir.call_intrin(dtype_b, 'tir.reinterpret', vec_ai32)
             vec_b = ins[1].vload([0, 0], dtype_b)
             vec_c = outs[0].vload([0], dtype_c)
 
             inst = 'udot' if dtype == 'uint' else 'sdot'
             inst = 'llvm.aarch64.neon.%s.v%di32.v%di8' % (
                 inst, int32_lanes, int32_lanes * num_int8_elements)
-            vdot = tvm.tir.call_llvm_intrin(dtype_c,
-                                            inst,
-                                            tvm.tir.const(2, 'uint32'),
-                                            vec_c, vec_a, vec_b)
+            vdot = tvm.tir.call_llvm_pure_intrin(
+                dtype_c,
+                inst,
+                tvm.tir.const(2, 'uint32'),
+                vec_c, vec_a, vec_b)
             ib.emit(outs[0].vstore(0, vdot))
             return ib.get()
 
index c98d7e9..9e3200a 100644 (file)
@@ -38,10 +38,10 @@ def cuda_atomic_add_rule(op):
 tvm.target.intrin.register_intrin_rule(
     "cuda", "atomic_add", cuda_atomic_add_rule, override=True)
 
-tvm.ir.register_op_attr("tir.atomic_add", "TVectorizable", False)
+tvm.ir.register_op_attr("tir.atomic_add", "TCallEffectKind", tvm.tir.CallEffectKind.Opaque)
 
 def atomic_add(x, y):
-    return tvm.tir.call_pure_intrin(y.dtype, "tir.atomic_add", x, y)
+    return tvm.tir.call_intrin(y.dtype, "tir.atomic_add", x, y)
 
 
 def get_valid_counts_ir(data, valid_count, out, out_indices,
@@ -114,7 +114,7 @@ def get_valid_counts_ir(data, valid_count, out, out_indices,
         with ib.if_scope(
                 tvm.tir.all(data[tid * elem_length + score_index] > score_threshold,
                             tvm.tir.any(id_index < 0, data[tid * elem_length + id_index] >= 0))):
-            atomic_add_return[0] = atomic_add(tvm.tir.call_pure_intrin("handle", "tir.address_of",
+            atomic_add_return[0] = atomic_add(tvm.tir.call_intrin("handle", "tir.address_of",
                                                                        valid_count[i]), one_count)
             with ib.for_range(0, elem_length) as k:
                 out[tid * elem_length + k] = data[tid * elem_length + k]
index 5b7e090..1414384 100644 (file)
@@ -186,8 +186,7 @@ def argsort_ir(data_buf, out_index_buf):
                 index_out[offset] = index_out[offset + 1]
                 index_out[offset + 1] = temp_index[0]
             ib.emit(tvm.tir.Call(None, 'tir.tvm_storage_sync',
-                                 tvm.runtime.convert(['shared']),
-                                 tvm.tir.Call.Intrinsic))
+                                 tvm.runtime.convert(['shared'])))
     return ib.get()
 
 
@@ -247,8 +246,7 @@ def nms_ir(sorted_bbox_buf, out_buf, nms_threshold):
                 with ib.if_scope(iou > nms_threshold):
                     p_out[base_idx + i] = True
         ib.emit(tvm.tir.Call(None, 'tir.tvm_storage_sync',
-                             tvm.runtime.convert(['shared']),
-                             tvm.tir.Call.Intrinsic))
+                             tvm.runtime.convert(['shared'])))
     return ib.get()
 
 
index 7181d57..a8d1572 100644 (file)
@@ -116,8 +116,7 @@ def sort_ir(data, values_out, axis, is_ascend, indices_out=None):
                     indices_out[base_idx + tid * axis_mul_after] = \
                         tvm.tir.generic.cast(tid, indices_out.dtype)
     ib.emit(tvm.tir.Call(None, 'tir.tvm_storage_sync',
-                         tvm.runtime.convert(['shared']),
-                         tvm.tir.Call.Intrinsic))
+                         tvm.runtime.convert(['shared'])))
     idxd = tvm.tir.indexdiv
     idxm = tvm.tir.indexmod
 
@@ -144,8 +143,7 @@ def sort_ir(data, values_out, axis, is_ascend, indices_out=None):
                             indices_out[offset] = indices_out[offset + axis_mul_after]
                             indices_out[offset + axis_mul_after] = temp_index[0]
                 ib.emit(tvm.tir.Call(None, 'tir.tvm_storage_sync',
-                                     tvm.runtime.convert(['shared']),
-                                     tvm.tir.Call.Intrinsic))
+                                     tvm.runtime.convert(['shared'])))
 
     return ib.get()
 
@@ -236,8 +234,7 @@ def sort_nms_ir(data, valid_count, output, axis, is_ascend):
                         output[offset] = output[offset + axis_mul_after]
                         output[offset + axis_mul_after] = temp_index[0]
                 ib.emit(tvm.tir.Call(None, 'tir.tvm_storage_sync',
-                                     tvm.runtime.convert(['shared']),
-                                     tvm.tir.Call.Intrinsic))
+                                     tvm.runtime.convert(['shared'])))
 
     return ib.get()
 
index 31de70e..17c0b36 100644 (file)
@@ -88,19 +88,21 @@ def dot_16x1x16_uint8_int8_int32_skylake():
                 return ib.get()
 
             a_int8 = ins[0].vload([0], "uint8x4")
-            re_int32 = tvm.tir.call_pure_intrin('int32', 'tir.reinterpret', a_int8)
+            re_int32 = tvm.tir.call_intrin('int32', 'tir.reinterpret', a_int8)
             vec_ai32 = re_int32.astype('int32x16')
-            vec_a = tvm.tir.call_pure_intrin('int8x64', 'tir.reinterpret', vec_ai32)
+            vec_a = tvm.tir.call_intrin('int8x64', 'tir.reinterpret', vec_ai32)
             vec_b = ins[1].vload([0, 0], "int8x64")
             vec_one = tvm.tir.const(1, "int16x32")
-            pair_reduction = tvm.tir.call_llvm_intrin('int16x32',
-                                                      'llvm.x86.avx512.pmaddubs.w.512',
-                                                      tvm.tir.const(0, 'uint32'),
-                                                      vec_a, vec_b)
-            quad_reduction = tvm.tir.call_llvm_intrin('int32x16',
-                                                      'llvm.x86.avx512.pmaddw.d.512',
-                                                      tvm.tir.const(0, 'uint32'),
-                                                      pair_reduction, vec_one)
+            pair_reduction = tvm.tir.call_llvm_pure_intrin(
+                'int16x32',
+                'llvm.x86.avx512.pmaddubs.w.512',
+                tvm.tir.const(0, 'uint32'),
+                vec_a, vec_b)
+            quad_reduction = tvm.tir.call_llvm_pure_intrin(
+                'int32x16',
+                'llvm.x86.avx512.pmaddw.d.512',
+                tvm.tir.const(0, 'uint32'),
+                pair_reduction, vec_one)
             if index == 0:
                 ib.emit(outs[0].vstore(0, quad_reduction))
             else:
@@ -174,16 +176,17 @@ def dot_16x1x16_uint8_int8_int16():
                 return ib.get()
 
             a_int8 = ins[0].vload([0], "uint8x2")
-            re_int16 = tvm.tir.call_pure_intrin('int16', 'tir.reinterpret', a_int8)
+            re_int16 = tvm.tir.call_intrin('int16', 'tir.reinterpret', a_int8)
             vec_ai16 = re_int16.astype('int16x32')
-            vec_a = tvm.tir.call_pure_intrin('int8x64', 'tir.reinterpret', vec_ai16)
+            vec_a = tvm.tir.call_intrin('int8x64', 'tir.reinterpret', vec_ai16)
 
             for i in range(4):
                 vec_b = ins[1].vload([i*32, 0], "int8x64")
-                pair_reduction = tvm.tir.call_llvm_intrin('int16x32',
-                                                          'llvm.x86.avx512.pmaddubs.w.512',
-                                                          tvm.tir.const(0, 'uint32'),
-                                                          vec_a, vec_b)
+                pair_reduction = tvm.tir.call_llvm_pure_intrin(
+                    'int16x32',
+                    'llvm.x86.avx512.pmaddubs.w.512',
+                    tvm.tir.const(0, 'uint32'),
+                    vec_a, vec_b)
                 if index == 0:
                     ib.emit(outs[0].vstore([i*32], pair_reduction))
                 else:
@@ -254,7 +257,7 @@ def dot_16x1x16_uint8_int8_int32_cascadelake():
                 return ib.get()
 
             a_int8 = ins[0].vload([0], "uint8x4")
-            re_int32 = tvm.tir.call_pure_intrin('int32', 'tir.reinterpret', a_int8)
+            re_int32 = tvm.tir.call_intrin('int32', 'tir.reinterpret', a_int8)
             vec_ai32 = re_int32.astype('int32x16')
             vec_b = ins[1].vload([0, 0], "int8x64")
 
@@ -262,24 +265,27 @@ def dot_16x1x16_uint8_int8_int32_cascadelake():
             llvm_id = tvm.target.codegen.llvm_lookup_intrinsic_id(vnni_inst_name)
 
             if llvm_id != 0: # VNNI is available for current LLVM version
-                vec_bi32 = tvm.tir.call_pure_intrin('int32x16', 'tir.reinterpret', vec_b)
+                vec_bi32 = tvm.tir.call_intrin('int32x16', 'tir.reinterpret', vec_b)
                 vec_zero = tvm.tir.const(0, "int32x16")
-                quad_reduction = tvm.tir.call_llvm_intrin('int32x16',
-                                                          'llvm.x86.avx512.vpdpbusd.512',
-                                                          tvm.tir.const(0, 'uint32'),
-                                                          vec_zero,
-                                                          vec_ai32, vec_bi32)
+                quad_reduction = tvm.tir.call_llvm_pure_intrin(
+                    'int32x16',
+                    'llvm.x86.avx512.vpdpbusd.512',
+                    tvm.tir.const(0, 'uint32'),
+                    vec_zero,
+                    vec_ai32, vec_bi32)
             else: # Fall back to the normal AVX512
-                vec_a = tvm.tir.call_pure_intrin('int8x64', 'tir.reinterpret', vec_ai32)
+                vec_a = tvm.tir.call_intrin('int8x64', 'tir.reinterpret', vec_ai32)
                 vec_one = tvm.tir.const(1, "int16x32")
-                pair_reduction = tvm.tir.call_llvm_intrin('int16x32',
-                                                          'llvm.x86.avx512.pmaddubs.w.512',
-                                                          tvm.tir.const(0, 'uint32'),
-                                                          vec_a, vec_b)
-                quad_reduction = tvm.tir.call_llvm_intrin('int32x16',
-                                                          'llvm.x86.avx512.pmaddw.d.512',
-                                                          tvm.tir.const(0, 'uint32'),
-                                                          pair_reduction, vec_one)
+                pair_reduction = tvm.tir.call_llvm_pure_intrin(
+                    'int16x32',
+                    'llvm.x86.avx512.pmaddubs.w.512',
+                    tvm.tir.const(0, 'uint32'),
+                    vec_a, vec_b)
+                quad_reduction = tvm.tir.call_llvm_pure_intrin(
+                    'int32x16',
+                    'llvm.x86.avx512.pmaddw.d.512',
+                    tvm.tir.const(0, 'uint32'),
+                    pair_reduction, vec_one)
 
             if index == 0:
                 ib.emit(outs[0].vstore(0, quad_reduction))
index 65bfd4c..4a4ff96 100644 (file)
@@ -135,7 +135,7 @@ print(fcuda.imported_modules[0].get_source())
 
 def mylog(x):
     """customized log intrinsic function"""
-    return tvm.tir.call_pure_intrin(x.dtype, "tir.mylog", x)
+    return tvm.tir.call_intrin(x.dtype, "tir.mylog", x)
 
 
 def my_cuda_mylog_rule(op):
@@ -148,7 +148,7 @@ def my_cuda_mylog_rule(op):
         return op
 
 # new op registration is triggered by registering an attribute of the op
-tvm.ir.register_op_attr("tir.mylog", "TVectorizable", True)
+tvm.ir.register_op_attr("tir.mylog", "TCallEffectKind", tvm.tir.CallEffectKind.Pure)
 tvm.target.register_intrin_rule("cuda", "mylog", my_cuda_mylog_rule, override=True)
 
 n = te.var("n")
index 947c583..3e6a0c5 100644 (file)
@@ -79,8 +79,7 @@ class DevContext(object):
         self.vta_push_uop = tvm.tir.StringImm("VTAPushGEMMOp")
         ctx = tvm.tir.call_intrin("handle", "tir.vta.command_handle")
         self.command_handle = tvm.tir.Call(
-            "handle", "tir.tvm_thread_context", [ctx],
-            tvm.tir.Call.Intrinsic)
+            "handle", "tir.tvm_thread_context", [ctx])
         self.DEBUG_NO_SYNC = False
         env._dev_ctx = self
         self.gemm = intrin.gemm(env, env.mock_mode)
@@ -316,12 +315,15 @@ def coproc_dep_pop(op):
 
 # register a dummy into to trigger registration of the ops
 # change the info to lowering rule later.
-tvm.ir.register_op_attr("tir.vta.coproc_sync", "TVectorizable", False)
-tvm.ir.register_op_attr("tir.vta.coproc_dep_push", "TVectorizable", False)
-tvm.ir.register_op_attr("tir.vta.coproc_dep_pop", "TVectorizable", False)
+tvm.ir.register_op_attr("tir.vta.coproc_sync", "TCallEffectKind", tvm.tir.CallEffectKind.Opaque)
+tvm.ir.register_op_attr("tir.vta.coproc_dep_push", "TCallEffectKind", tvm.tir.CallEffectKind.Opaque)
+tvm.ir.register_op_attr("tir.vta.coproc_dep_pop", "TCallEffectKind", tvm.tir.CallEffectKind.Opaque)
 
+tvm.ir.register_op_attr("tir.vta.uop_push", "TCallEffectKind", tvm.tir.CallEffectKind.Opaque)
 tvm.ir.register_op_attr("tir.vta.uop_push", "TGlobalSymbol", "VTAUopPush")
+
 tvm.ir.register_op_attr("tir.vta.command_handle", "TGlobalSymbol", "VTATLSCommandHandle")
+tvm.ir.register_op_attr("tir.vta.command_handle", "TCallEffectKind", tvm.tir.CallEffectKind.Opaque)
 
 
 def _init_env():
index e92b178..d9f47f1 100644 (file)
@@ -298,7 +298,7 @@ def InjectCoProcSync():
             if _match_pragma(stmt, "coproc_sync"):
                 success[0] = True
                 sync = tvm.tir.Call(
-                    "int32", "vta.coproc_sync", [], tvm.tir.Call.Intrinsic)
+                    "int32", "vta.coproc_sync", [])
                 return tvm.tir.SeqStmt([stmt.body, tvm.tir.Evaluate(sync)])
             if _match_pragma(stmt, "trim_loop"):
                 op = stmt.body