* [TIR][REFACTOR][API-CHANGE] Change Call.name(string) to Call.op(tvm::Op/RelayExpr)
This PR brings a major refactor to the tir::Call structure.
The current Call structure uses a string field (name) to identify the
function/intrinsic being called. This approach becomes limiting as we
expand TIR to be more structured. In particular, we are interested in
the following aspects:
- Type a function and perform better compile-time type checking so that we
  can find errors early.
- Register additional properties about an operator, such as:
  - Whether an intrinsic can be vectorized.
  - What the adjoint function of the intrinsic is (for tensor expression AD).
  - Whether the operator has side effects.
- Perform intrinsic-specific codegen when necessary.
- Call into another function in the same module.
The refactor changes the Call.name field to Call.op.
The Call.op field has the RelayExpr type, and we can pass either of the
following (as sketched below):
- A tvm::Op, which represents the corresponding intrinsic.
- A tvm::GlobalVar, for calling into another function in the IRModule.
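A minimal sketch of the two cases, assuming the ops are already registered; the argument values, the GlobalVar name, and the call_type chosen for the GlobalVar case are illustrative only:

```c++
#include <tvm/ir/expr.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>

using namespace tvm;
using namespace tvm::tir;

PrimExpr MakeExampleCalls() {
  Var x("x", DataType::Int(32));

  // Case 1: op is a tvm::Op looked up from the registry (a builtin intrinsic).
  PrimExpr by_op =
      Call(DataType::Int(32), Op::Get("tir.popcount"), {x}, CallNode::PureIntrinsic);

  // Case 2: op is a GlobalVar that refers to another function in the same IRModule.
  GlobalVar callee("other_func");
  PrimExpr by_gvar = Call(DataType::Int(32), callee, {x}, CallNode::Extern);

  return by_op + by_gvar;
}
```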
All the current intrinsics are migrated by registering a tvm::Op.
Because the unified IR shares a single Op registry, we use the "tir"
namespace for TIR-related intrinsics; for example, bitwise and is now registered
under `tir.bitwise_and`.
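For instance, the following two lookups should resolve to the same registry entry (a small sketch; `bitwise_and` is one of the accessors declared in the new tvm/tir/builtin.h):

```c++
#include <tvm/ir/op.h>
#include <tvm/tir/builtin.h>

using namespace tvm;

void RegistryLookupSketch() {
  // The unified registry is shared across IR variants, so TIR intrinsics carry the "tir." prefix.
  const Op& by_name = Op::Get("tir.bitwise_and");
  const Op& by_accessor = tir::builtin::bitwise_and();
  // Both resolve to the same registered op object.
  CHECK(by_name.same_as(by_accessor));
}
```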
To simplify the upgrade, we introduce a `tir.call_extern` intrinsic
that allows us to call an arbitrary external function without type checking.
However, we should move towards more type-checked variants in the system.
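Continuing the sketch above, a `tir.call_extern` call carries the callee symbol as its first argument (a StringImm), mirroring the Python helpers changed below; the symbol name here is a made-up placeholder:

```c++
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>

using namespace tvm;
using namespace tvm::tir;

// Sketch: only the symbol name is recorded as the first argument;
// the callee signature is not checked against the argument types.
PrimExpr CallExternSketch(PrimExpr x) {
  return Call(x.dtype(), builtin::call_extern(),
              {StringImm("my_extern_func"), x}, CallNode::Extern);
}
```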
Under the new op design, we should no longer try to pattern match all the
specific intrinsics. Instead, we should rely on the attributes of each Op to drive
transformations. For example, the vectorization pass depends on the TVectorizable property
of the op, which can be registered independently.
In this way, we can still grow the number of intrinsics when necessary
without having to change all the passes.
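As an illustration only (the op chosen, the attribute value, and the lookup calls are assumptions, not code from this PR), an op-level property can be attached at registration time and later consulted by a pass:

```c++
#include <tvm/ir/op.h>
#include <tvm/tir/op_attr_types.h>

// Hypothetical registration: mark an intrinsic as vectorizable via OpRegEntry::set_attr.
TVM_REGISTER_OP("tir.popcount")
    .set_num_inputs(1)
    .set_attr<tvm::tir::TVectorizable>("TVectorizable", true);

// A pass would then consult the registry instead of matching intrinsic names,
// e.g. through an Op attribute map (lookup API sketched as an assumption):
//   static auto fvectorizable = Op::GetAttr<TVectorizable>("TVectorizable");
//   if (fvectorizable.get(call_op, false)) { /* emit the vector form */ }
```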
The same rule applies to tensor expression AD. Currently we perform
AD by pattern matching on operators like exp, sin, and cos. We should instead
switch to an adjoint registration mechanism like the one in Relay.
Follow-up refactors need to be performed, including:
- Fold Call.call_type into the operator's attributes.
- Enrich the operator registry information.
- Refactor passes (e.g. AD, intrinsic lowering) to use attribute-based transformations.
* Fix nms
* Fix remaining testcase
* Address review comment
/*!
* \brief The operator(function) being invoked
*
- * - It can be relay::Op which corresponds to the primitive operators.
+ * - It can be tvm::Op which corresponds to the primitive operators.
* - It can also be user defined functions (Function, GlobalVar, Var).
*/
Expr op;
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/tir/builtin.h
+ * \brief TIR builtin intrinsics.
+ *
+ * TIR builtin intrinsics are stored as tvm::Op.
+ * They are processed in the same way as we process Ops.
+ *
+ * It is not necessary to create a function for every Op,
+ * as we can obtain them through Op::Get.
+ *
+ * This file contains the most commonly used intrinsics or
+ * those that have special semantics and need compiler support.
+ */
+#ifndef TVM_TIR_BUILTIN_H_
+#define TVM_TIR_BUILTIN_H_
+
+#include <tvm/ir/op.h>
+#include <tvm/tir/expr.h>
+
+namespace tvm {
+namespace tir {
+
+/*! \brief Collection of builtin intrinsics as ops */
+namespace builtin {
+/*!
+ * \brief Reinterpret the value using the target type.
+ */
+TVM_DLL const Op& reinterpret();
+
+/*!
+ * \brief Marks that a condition is likely going to happen.
+ */
+TVM_DLL const Op& likely();
+
+/*!
+ * \brief Bitwise and operator.
+ */
+TVM_DLL const Op& bitwise_and();
+
+/*!
+ * \brief Bitwise or operator.
+ */
+TVM_DLL const Op& bitwise_or();
+
+/*!
+ * \brief Bitwise xor operator.
+ */
+TVM_DLL const Op& bitwise_xor();
+
+/*!
+ * \brief Bitwise not operator.
+ */
+TVM_DLL const Op& bitwise_not();
+
+/*!
+ * \brief Left shift
+ */
+TVM_DLL const Op& shift_left();
+
+/*!
+ * \brief Right shift
+ */
+TVM_DLL const Op& shift_right();
+
+/*!
+ * \brief See pseudo code
+ *
+ * Construct a big uint that may not be representable by int64
+ *
+ *  Expr large_uint_imm(uint32_t v0, uint32_t v1) {
+ * return (v1 << 32) | v0;
+ * }
+ */
+TVM_DLL const Op& large_uint_imm();
+
+/*!
+ * \brief See pseudo code
+ *
+ * Handle address_of(Load *op) {
+ * return &op->buffer_var[index];
+ * }
+ */
+TVM_DLL const Op& address_of();
+
+/*!
+ * \brief Same as select, used for unsafe memory access.
+ *
+ * Type tvm_if_then_else(cond, a, b) {
+ * return cond ? a : b;
+ * }
+ */
+TVM_DLL const Op& if_then_else();
+
+/*!
+ * \brief See pseudo code
+ *
+ * bool isnullptr(void* handle) {
+ * return handle == nullptr
+ * }
+ */
+TVM_DLL const Op& isnullptr();
+
+/*!
+ * \brief Check if value is nan
+ */
+TVM_DLL const Op& isnan();
+
+/*!
+ * \brief Popcount
+ */
+TVM_DLL const Op& popcount();
+
+/*!
+ * \brief Fused multiply add
+ *
+ * Type fma(a, b, c) {
+ * return a * b + c;
+ * }
+ */
+TVM_DLL const Op& fma();
+
+/*!
+ * \brief Call an extern C function with given name
+ * and signature from the types of args in the runtime environment.
+ *
+ * Type call_extern(name, args...) {
+ * return dlsym(name)(args...);
+ * }
+ *
+ * \note This intrinsic does not provide any type checking,
+ *       and is mainly used for backward compatibility reasons.
+ *       Always consider using a pre-registered and typed tvm::Op first.
+ */
+TVM_DLL const Op& call_extern();
+
+/*!
+ * \brief Call an LLVM intrinsic with a given intrinsic id
+ * and signature from the types of args in the runtime environment.
+ *
+ * Type call_llvm_intrin(intrin_id, args...) {
+ * return dlsym(name)(args...);
+ * }
+ *
+ * \note This op does not provide any type checking.
+ */
+TVM_DLL const Op& call_llvm_intrin();
+
+/*!
+ * \brief Call an SPIRV GLSL450 intrinsic.
+ *
+ * Type call_spirv_glsl450(intrin_id, args...) {
+ * return dlsym(name)(args...);
+ * }
+ *
+ * \note This op does not provide any type checking.
+ */
+TVM_DLL const Op& call_spirv_glsl450();
+
+// TODO(tvm-team) revisit the builtins below
+// some of them can simply become ops with special codegen attr.
+/*!
+ * \brief Prefetch a cacheline
+ */
+TVM_DLL const Op& prefetch();
+
+/*!
+ * \brief Get head access address with memory access pattern info.
+ *
+ * This operator also marks the range of the memory access.
+ * The offset and extent are in units of the DType (including the vectorization factor).
+ * rw_mask is a bit mask indicating whether the access is a read (1) or a write (2).
+ * The access is assumed to happen in the current expression.
+ *
+ * PtrType tvm_access_ptr(Expr dtype, DType* data,
+ * int offset, int extent,
+ * int rw_mask) {
+ * // DType == dtype.type();
+ * return &data[offset];
+ * }
+ */
+TVM_DLL const Op& tvm_access_ptr();
+
+/*!
+ * \brief Create a function-local static handle that initializes to nullptr.
+ *  It can be used to cache function-local static resources.
+ */
+TVM_DLL const Op& tvm_static_handle();
+
+/*!
+ * \brief Return a unique context id, used as a hint for workspace separation.
+ *  Different context ids guarantee non-overlapping workspaces.
+ */
+TVM_DLL const Op& tvm_context_id();
+
+/*!
+ * \brief tvm_tuple is not an actual function and cannot be lowered in codegen.
+ * It is used to represent a tuple structure in the value field of AttrStmt,
+ * as a hint to optimization.
+ *
+ * Handle tvm_tuple(value0, value1, ..., value_n);
+ */
+TVM_DLL const Op& tvm_tuple();
+
+/*!
+ * \brief See pseudo code
+ *
+ * Type tvm_struct_get(StructType* arr, int index, int field_id) {
+ * return arr[index]->field;
+ * }
+ * \sa TVMStructFieldKind
+ */
+TVM_DLL const Op& tvm_struct_get();
+
+/*!
+ * \brief See pseudo code
+ *
+ * Handle tvm_struct_set(StructType* arr, int index, int field_id, value) {
+ * arr[index]->field = value;
+ * }
+ * \sa TVMStructFieldKind
+ */
+TVM_DLL const Op& tvm_struct_set();
+
+/*!
+ * \brief See pseudo code
+ *
+ * void tvm_throw_last_error() {
+ * throw TVMGetLastError();
+ * }
+ */
+TVM_DLL const Op& tvm_throw_last_error();
+
+/*!
+ * \brief See pseudo code
+ *
+ * dtype in {shape, array, arg_value, arg_tcode}
+ *
+ * Handle tvm_stack_alloca(string dtype, int num) {
+ * return new on stack dtype[num];
+ * }
+ */
+TVM_DLL const Op& tvm_stack_alloca();
+
+/*!
+ * \brief Allocate a shape tuple on stack, return the handle.
+ *
+ * Handle tvm_stack_make_shape(list args) {
+ * ret = alloca stack int64_t[len(args)];
+ * for i in range(len(args)):
+ * ret[i] = args[i]
+ * return &ret[0];
+ * }
+ */
+TVM_DLL const Op& tvm_stack_make_shape();
+
+/*!
+ * \brief Allocate a NDArray(DLTensor) on stack, return the handle.
+ *
+ * Type tvm_stack_make_array(Expr data,
+ * Expr shape,
+ * Expr strides,
+ * Expr ndim,
+ * Expr dtype,
+ * Expr elem_offset) {
+ * ret = alloca stack DLTensor();
+ * ret->data = data;
+ * ret->shape = shape;
+ * ret->strides = strides != 0 ? strides : nullptr;
+ * ret->ndim = ndim;
+ * ret->dtype = dtype.type();
+ * ret->byte_offset = elem_offset * sizeof(dtype);
+ * return ret;
+ * }
+ */
+TVM_DLL const Op& tvm_stack_make_array();
+
+/*!
+ * \brief See pseudo code
+ *
+ * int tvm_call_packed(name, TVMValue* args) {
+ * ModuleNode* env = GetCurrentEnv();
+ * const PackedFunc* f = env->GetFuncFromEnv(name);
+ * (*f)(args, type_code_of(args), len(args));
+ * return 0;
+ * }
+ */
+TVM_DLL const Op& tvm_call_packed();
+
+/*!
+ * \brief See pseudo code
+ *
+ * int tvm_call_trace_packed(name, TVMValue* args) {
+ * ModuleNode* env = GetCurrentEnv();
+ * const PackedFunc* f = env->GetFuncFromEnv(name);
+ * (*f)(args, type_code_of(args), len(args));
+ * return 0;
+ * }
+ */
+TVM_DLL const Op& tvm_call_trace_packed();
+
+/*!
+ * \brief See pseudo code
+ * Mark the content as a thread-local context; it can be optimized
+ * by issuing the call only once at thread start.
+ *
+ * Do not allow nesting (getting a thread context from another).
+ *
+ * Handle tvm_thread_context(Expr call) {
+ * return call;
+ * }
+ */
+TVM_DLL const Op& tvm_thread_context();
+
+/*!
+ * \brief Lowered version of call packed, the space of value and
+ * type codes are explicitly allocated.
+ *
+ * int tvm_call_packed_lowered(name,
+ * TVMValue* value_stack,
+ * int* tcode_stack,
+ * int begin,
+ * int end) {
+ * ModuleNode* env = GetCurrentEnv();
+ * const PackedFunc* f = env->GetFuncFromEnv(name);
+ * f->CallPacked(TVMArgs(value_stack[begin:end],
+ * tcode_stack[begin:end]),
+ * TVMRetValue(value_stack + end, tcode_stack + end));
+ * }
+ */
+TVM_DLL const Op& tvm_call_packed_lowered();
+
+/*!
+ * \brief Lowered version of trace intrinsic, the space of value and
+ * type codes are explicitly allocated. The return value is the
+ * (end - 1) value on the stack.
+ *
+ * int tvm_call_trace_packed_lowered(name,
+ * TVMValue* value_stack,
+ * int* tcode_stack,
+ * int begin,
+ * int end) {
+ * ModuleNode* env = GetCurrentEnv();
+ * const PackedFunc* f = env->GetFuncFromEnv(name);
+ * f->CallPacked(TVMArgs(value_stack[begin:end],
+ * tcode_stack[begin:end]),
+ * TVMRetValue(value_stack + end, tcode_stack + end));
+ * }
+ */
+TVM_DLL const Op& tvm_call_trace_packed_lowered();
+
+/*!
+ * \brief See pseudo code
+ *
+ * int tvm_storage_sync(std::string storage_scope) {
+ * __sync(storage_scope);
+ * return 0;
+ * }
+ */
+TVM_DLL const Op& tvm_storage_sync();
+
+/*!
+ * \brief See pseudo code
+ *
+ * Type tvm_warp_shuffle(mask, Type value, warp_id, width, warp_size) {
+ * return (value passed in by warp indicated by this_warp_id);
+ * }
+ *
+ * Type tvm_warp_shuffle_up(mask, Type value, offset, width, warp_size) {
+ * return (value passed in by warp indicated by this_warp_id - offset);
+ * }
+ *
+ * Type tvm_warp_shuffle_down(mask, Type value, offset, width, warp_size) {
+ * return (value passed in by warp indicated by this_warp_id + offset);
+ * }
+ *
+ * unsigned tvm_warp_activemask() {
+ * return (32-bit mask of currently active threads in the calling warp);
+ * }
+ *
+ * Parameter warp_id indicates the source thread ID in a warp.
+ *
+ * Parameter offset indicates the relative distance to this_warp_id.
+ *
+ * Parameter width indicates the number of threads involved in one
+ * shuffle. See CUDA document for __shfl_sync, __shfl_up_sync,
+ * __shfl_down_sync and __activemask.
+ *
+ * Parameter warp_size is the size of a warp, which helps a backend
+ * to determine whether the width parameter is legal.
+ *
+ */
+TVM_DLL const Op& tvm_warp_shuffle();
+TVM_DLL const Op& tvm_warp_shuffle_up();
+TVM_DLL const Op& tvm_warp_shuffle_down();
+TVM_DLL const Op& tvm_warp_activemask();
+
+/*!
+ * \brief Initialize the global barrier.
+ * Call this at the beginning of kernels that need a global barrier.
+ */
+TVM_DLL const Op& tvm_global_barrier_kinit();
+
+/*!
+ * \brief See pseudo code
+ *
+ * void tvm_thread_allreduce(UIntImm size, Expr source0, ..., Expr cond,
+ * Var reduce_temp0, .., Var thread_idx1, ...) {
+ * // constraint by the other thread_idx remain the same.
+ * // reduce_temp is used to save intermediate result.
+ * reduce_temp0, ... = reduce(combiner, source0, ..., cond
+ * over [thread_idx1, thread_idx2] passed by any caller)
+ * }
+ */
+TVM_DLL const Op& tvm_thread_allreduce();
+
+// TODO(tvm-team) TensorCore specific intrinsics should be directly registered under
+// the cuda. namespace and used through ops.
+/*!
+ * \brief tvm intrinsic for tensor core load operators.
+ *
+ * void tvm_load_matrix_sync(Var fragment, UIntImm m, UIntImm n, UIntImm k,
+ * Expr index, Expr buffer_ptr, Expr stride,
+ * StringImm layout) {
+ * // m, n, k are the shape of wmma fragment.
+ *    // Determine fragment layout (column-major or row-major) by layout.
+ * // fragments must be in 'wmma.matrix_a' or 'wmma.matrix_b' scope.
+ * nvcuda::wmma::load_matrix_sync(fragment[index], buffer_ptr, stride);
+ * }
+ */
+TVM_DLL const Op& tvm_load_matrix_sync();
+
+/*!
+ * \brief tvm intrinsic for tensor core mma_sync operators.
+ *
+ * void tvm_mma_sync(Var fragment_d, Expr index_d,
+ * Var fragment_a, Expr index_a,
+ * Var fragment_b, Expr index_b,
+ * Var fragment_c, Expr index_c) {
+ * nvcuda::wmma::mma_sync(fragment_d[index_d], fragment_a[index_a],
+ * fragment_b[index_b], fragment_c[index_c]);
+ * }
+ */
+TVM_DLL const Op& tvm_mma_sync();
+
+/*!
+ * \brief tvm intrinsic for tensor core bmma_sync operators.
+ *
+ * void tvm_bmma_sync(Var fragment_d, Expr index_d,
+ * Var fragment_a, Expr index_a,
+ * Var fragment_b, Expr index_b,
+ * Var fragment_c, Expr index_c) {
+ * nvcuda::wmma::bmma_sync(fragment_d[index_d], fragment_a[index_a],
+ * fragment_b[index_b], fragment_c[index_c]);
+ * }
+ */
+TVM_DLL const Op& tvm_bmma_sync();
+
+/*!
+ * \brief tvm intrinsic for tensor core fill_fragment operators.
+ *
+ * void tvm_fill_fragment(Var fragment, UIntImm m, UIntImm n, UIntImm k,
+ * Expr index, Expr value) {
+ * // m, n, k are the shape of wmma fragment
+ * // fragments must be in 'wmma.accumulator' scope.
+ * nvcuda::wmma::fill_fragment(fragment[index], value);
+ * }
+ */
+TVM_DLL const Op& tvm_fill_fragment();
+
+/*!
+ * \brief tvm intrinsic for tensor core store operators.
+ *
+ * void tvm_store_matrix_sync(Var fragment, UIntImm m, UIntImm n, UIntImm k,
+ * Expr index, Expr buffer_ptr, Expr stride,
+ * StringImm layout) {
+ * // m, n, k are the shape of wmma fragment
+ * // fragments must be in 'wmma.accumulator' scope.
+ * nvcuda::wmma::store_matrix_sync(fragment[index], buffer_ptr, stride, layout);
+ * }
+ */
+TVM_DLL const Op& tvm_store_matrix_sync();
+
+// TODO(tvm-team) replace the usage of the vector operations by Shuffle.
+/*!
+ * \brief Get the high half of the vector
+ */
+TVM_DLL const Op& vectorhigh();
+
+/*!
+ * \brief Get the low half of the vector
+ */
+TVM_DLL const Op& vectorlow();
+
+/*!
+ * \brief Concat two vectors.
+ */
+TVM_DLL const Op& vectorcombine();
+
+/*! \brief The kind of structure field info used in intrinsic */
+enum TVMStructFieldKind : int {
+ // array head address
+ kArrAddr,
+ kArrData,
+ kArrShape,
+ kArrStrides,
+ kArrNDim,
+ kArrTypeCode,
+ kArrTypeBits,
+ kArrTypeLanes,
+ kArrByteOffset,
+ kArrDeviceId,
+ kArrDeviceType,
+ kArrKindBound_,
+ // TVMValue field
+ kTVMValueContent,
+ kTVMValueKindBound_
+};
+} // namespace builtin
+} // namespace tir
+} // namespace tvm
+#endif // TVM_TIR_BUILTIN_H_
/*! \brief Intrinsic functions that are pure. */
PureIntrinsic = 5
};
- /*! \brief The name of the function/intrinsic. */
- String name;
+ /*!
+ * \brief The operator(function) being invoked
+ *
+ * - It can be tvm::Op which corresponds to the primitive operators(intrinsics).
+ * - It can also be another function in the IRModule (GlobalVar).
+ */
+ RelayExpr op;
+
/*! \brief The arguments. */
Array<PrimExpr> args;
/*! \brief Type of calls. */
void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &dtype);
- v->Visit("name", &name);
+ v->Visit("op", &op);
v->Visit("args", &args);
v->Visit("call_type", &call_type);
}
bool SEqualReduce(const CallNode* other, SEqualReducer equal) const {
- return equal(dtype, other->dtype) && equal(name, other->name) && equal(args, other->args) &&
+ return equal(dtype, other->dtype) && equal(op, other->op) && equal(args, other->args) &&
equal(call_type, other->call_type);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(dtype);
- hash_reduce(name);
+ hash_reduce(op);
hash_reduce(args);
hash_reduce(call_type);
}
/*! \return Whether call node is pure. */
bool is_pure() const { return (call_type == PureExtern || call_type == PureIntrinsic); }
- /*!
- * \return Whether call node corresponds to a defined intrinsic.
- * \param intrin_name The name of the intrinsic.
- */
- bool is_intrinsic(const char* intrin_name) const {
- return ((call_type == Intrinsic || call_type == PureIntrinsic) && name == intrin_name);
- }
-
- /*! \return Whether call node can be vectorized. */
- bool is_vectorizable() const;
-
static constexpr const char* _type_key = "tir.Call";
TVM_DECLARE_FINAL_OBJECT_INFO(CallNode, PrimExprNode);
-
- // Build-in intrinsics
- static constexpr const char* reinterpret = "reinterpret";
- static constexpr const char* bitwise_and = "bitwise_and";
- static constexpr const char* bitwise_not = "bitwise_not";
- static constexpr const char* bitwise_xor = "bitwise_xor";
- static constexpr const char* bitwise_or = "bitwise_or";
- static constexpr const char* shift_left = "shift_left";
- static constexpr const char* shift_right = "shift_right";
- static constexpr const char* popcount = "popcount";
- static constexpr const char* likely = "likely";
- static constexpr const char* prefetch = "prefetch";
- static constexpr const char* isnan = "isnan";
- static constexpr const char* isfinite = "isfinite";
- static constexpr const char* isinf = "isinf";
-
- /*! \brief Vectorizable intrinsic list. */
- static const char* vectorizable_intrinsics[];
};
/*!
public:
using CallType = CallNode::CallType;
- TVM_DLL Call(DataType dtype, String name, Array<PrimExpr> args, CallType call_type);
+ TVM_DLL Call(DataType dtype, RelayExpr op, Array<PrimExpr> args, CallType call_type);
TVM_DEFINE_OBJECT_REF_METHODS(Call, PrimExpr, CallNode);
};
}
return ret;
}
-
-/*! \brief namespace of TVM Intrinsic functions */
-namespace intrinsic {
-/*!
- * \brief See pesudo code
- *
- * Construct a big uint that may not be representable by int64
- *
- * Expr tvm_large_uint_imm(uint32_t v0, uin32_t v1) {
- * return (v1 << 32) | v0;
- * }
- */
-constexpr const char* tvm_large_uint_imm = "tvm_large_uint_imm";
-/*!
- * \brief See pesudo code
- *
- * Handle tvm_address_of(Load *op) {
- * return &op->buffer_var[index];
- * }
- */
-constexpr const char* tvm_address_of = "tvm_address_of";
-/*!
- * \brief Same as select, used for unsafe memory access.
- *
- * Type tvm_if_then_else(cond, a, b) {
- * return cond ? a : b;
- * }
- */
-constexpr const char* tvm_if_then_else = "tvm_if_then_else";
-/*!
- * \brief Get head access address with memory access pattern info.
- *
- * This operator also marks range of the memory access
- * The offset and extent are in unit of the DType(including vectorization factor).
- * rw_mask is a bit_mask setting whether the access is a read(1) or write(2).
- * The access is assume to happen in the current expression.
- *
- * PtrType tvm_access_ptr(Expr dtype, DType* data,
- * int offset, int extent,
- * int rw_mask) {
- * // DType == dtype.type();
- * return &data[offset];
- * }
- */
-constexpr const char* tvm_access_ptr = "tvm_access_ptr";
-/*!
- * \brief Create a function local static handle that iniitalizes to nullptr.
- * can be used to cache function local static resources.
- */
-constexpr const char* tvm_static_handle = "tvm_static_handle";
-/*!
- * \brief Return a unique context id, used for hint of workspace separation.
- * Different context id ganrantees not having overlapping workspace.
- */
-constexpr const char* tvm_context_id = "tvm_context_id";
-/*!
- * \brief tvm_tuple is not an actual function and cannot codegen.
- * It is used to represent tuple structure in value field of AttrStmt,
- * for the sake of giving hint to optimization.
- *
- * Handle tvm_tuple(value0, value1, ..., value_n);
- */
-constexpr const char* tvm_tuple = "tvm_tuple";
-/*!
- * \brief See pesudo code
- *
- * Type tvm_struct_get(StructType* arr, int index, int field_id) {
- * return arr[index]->field;
- * }
- * \sa TVMStructFieldKind
- */
-constexpr const char* tvm_struct_get = "tvm_struct_get";
-/*!
- * \brief See pesudo code
- *
- * Handle tvm_struct_set(StructType* arr, int index, int field_id, value) {
- * arr[index]->field = value;
- * }
- * \sa TVMStructFieldKind
- */
-constexpr const char* tvm_struct_set = "tvm_struct_set";
-/*!
- * \brief See pesudo code
- *
- * bool tvm_handle_is_null(void* handle) {
- * return handle == nullptr
- * }
- */
-constexpr const char* tvm_handle_is_null = "tvm_handle_is_null";
-/*!
- * \brief See pesudo code
- *
- * void tvm_throw_last_error() {
- * throw TVMGetLastError();
- * }
- */
-constexpr const char* tvm_throw_last_error = "tvm_throw_last_error";
-/*!
- * \brief See pesudo code
- *
- * dtype in {shape, array, arg_value, arg_tcode}
- *
- * Handle tvm_stack_alloca(string dtype, int num) {
- * return new on stack dtype[num];
- * }
- */
-constexpr const char* tvm_stack_alloca = "tvm_stack_alloca";
-/*!
- * \brief Allocate a shape tuple on stack, return the handle.
- *
- * Handle tvm_stack_make_shape(list args) {
- * ret = alloca stack int64_t[len(args)];
- * for i in range(len(args)):
- * ret[i] = args[i]
- * return &ret[0];
- * }
- */
-constexpr const char* tvm_stack_make_shape = "tvm_stack_make_shape";
-/*!
- * \brief Allocate a NDArray(DLTensor) on stack, return the handle.
- *
- * Type tvm_stack_make_array(Expr data,
- * Expr shape,
- * Expr strides,
- * Expr ndim,
- * Expr dtype,
- * Expr elem_offset) {
- * ret = alloca stack DLTensor();
- * ret->data = data;
- * ret->shape = shape;
- * ret->strides = strides != 0 ? strides : nullptr;
- * ret->ndim = ndim;
- * ret->dtype = dtype.type();
- * ret->byte_offset = elem_offset * sizeof(dtype);
- * return ret;
- * }
- */
-constexpr const char* tvm_stack_make_array = "tvm_stack_make_array";
-/*!
- * \brief See pesudo code
- *
- * int tvm_call_packed(name, TVMValue* args) {
- * ModuleNode* env = GetCurrentEnv();
- * const PackedFunc* f = env->GetFuncFromEnv(name);
- * (*f)(args, type_code_of(args), len(args));
- * return 0;
- * }
- */
-constexpr const char* tvm_call_packed = "tvm_call_packed";
-/*!
- * \brief See pesudo code
- *
- * int tvm_call_trace_packed(name, TVMValue* args) {
- * ModuleNode* env = GetCurrentEnv();
- * const PackedFunc* f = env->GetFuncFromEnv(name);
- * (*f)(args, type_code_of(args), len(args));
- * return 0;
- * }
- */
-constexpr const char* tvm_call_trace_packed = "tvm_call_trace_packed";
-/*!
- * \brief See pesudo code
- * Mark the content as thread local context, can get optimized
- * by only call the call once at thread start.
- *
- * Do not allow nesting(getting a thread context from another).
- *
- * Handle tvm_thread_context(Expr call) {
- * return call;
- * }
- */
-constexpr const char* tvm_thread_context = "tvm_thread_context";
-/*!
- * \brief Lowered version of call packed, the space of value and
- * type codes are explicitly allocated.
- *
- * int tvm_call_packed_lowered(name,
- * TVMValue* value_stack,
- * int* tcode_stack,
- * int begin,
- * int end) {
- * ModuleNode* env = GetCurrentEnv();
- * const PackedFunc* f = env->GetFuncFromEnv(name);
- * f->CallPacked(TVMArgs(value_stack[begin:end],
- * tcode_stack[begin:end]),
- * TVMRetValue(value_stack + end, tcode_stack + end));
- * }
- */
-constexpr const char* tvm_call_packed_lowered = "tvm_call_packed_lowered";
-/*!
- * \brief Lowered version of trace intrinsic, the space of value and
- * type codes are explicitly allocated. The return value is the
- * (end - 1) value on the stack.
- *
- * int tvm_call_trace_packed_lowered(name,
- * TVMValue* value_stack,
- * int* tcode_stack,
- * int begin,
- * int end) {
- * ModuleNode* env = GetCurrentEnv();
- * const PackedFunc* f = env->GetFuncFromEnv(name);
- * f->CallPacked(TVMArgs(value_stack[begin:end],
- * tcode_stack[begin:end]),
- * TVMRetValue(value_stack + end, tcode_stack + end));
- * }
- */
-constexpr const char* tvm_call_trace_packed_lowered = "tvm_call_trace_packed_lowered";
-/*!
- * \brief See pseudo code
- *
- * int tvm_storage_sync(std::string storage_scope) {
- * __sync(storage_scope);
- * return 0;
- * }
- */
-constexpr const char* tvm_storage_sync = "tvm_storage_sync";
-
-/*!
- * \brief See pseudo code
- *
- * Type tvm_warp_shuffle(mask, Type value, warp_id, width, warp_size) {
- * return (value passed in by warp indicated by this_warp_id);
- * }
- *
- * Type tvm_warp_shuffle_up(mask, Type value, offset, width, warp_size) {
- * return (value passed in by warp indicated by this_warp_id - offset);
- * }
- *
- * Type tvm_warp_shuffle_down(mask, Type value, offset, width, warp_size) {
- * return (value passed in by warp indicated by this_warp_id + offset);
- * }
- *
- * unsigned tvm_warp_activemask() {
- * return (32-bit mask of currently active threads in the calling warp);
- * }
- *
- * Parameter warp_id indicates the source thread ID in a warp.
- *
- * Parameter offset indicates the relative distance to this_warp_id.
- *
- * Parameter width indicates the number of threads involved in one
- * shuffle. See CUDA document for __shfl_sync, __shfl_up_sync,
- * __shfl_down_sync and __activemask.
- *
- * Parameter warp_size is the size of a warp, which helps a backend
- * to determine wheter the width paramter is legal.
- *
- */
-constexpr const char* tvm_warp_shuffle = "tvm_warp_shuffle";
-constexpr const char* tvm_warp_shuffle_up = "tvm_warp_shuffle_up";
-constexpr const char* tvm_warp_shuffle_down = "tvm_warp_shuffle_down";
-constexpr const char* tvm_warp_activemask = "tvm_warp_activemask";
-
-/*!
- * \brief Initialize the global barrier.
- * Call this at beginning of kernel that need global barrier.
- */
-constexpr const char* tvm_global_barrier_kinit = "tvm_global_barrier_kinit";
-/*!
- * \brief See pesudo code
- *
- * void tvm_thread_allreduce(UIntImm size, Expr source0, ..., Expr cond,
- * Var reduce_temp0, .., Var thread_idx1, ...) {
- * // constraint by the other thread_idx remain the same.
- * // reduce_temp is used to save intermediate result.
- * reduce_temp0, ... = reduce(combiner, source0, ..., cond
- * over [thread_idx1, thread_idx2] passed by any caller)
- * }
- */
-constexpr const char* tvm_thread_allreduce = "tvm_thread_allreduce";
-/*!
- * \brief tvm intrinsic for tensor core load operators.
- *
- * void tvm_load_matrix_sync(Var fragment, UIntImm m, UIntImm, n, UIntImm k,
- * Expr index, Expr buffer_ptr, Expr stride,
- * StringImm layout) {
- * // m, n, k are the shape of wmma fragment.
- * // Determine fragment layout(column-major or row major) by layout.
- * // fragments must be in 'wmma.matrix_a' or 'wmma.matrix_b' scope.
- * nvcuda::wmma::load_matrix_sync(fragment[index], buffer_ptr, stride);
- * }
- */
-constexpr const char* tvm_load_matrix_sync = "tvm_load_matrix_sync";
-/*!
- * \brief tvm intrinsic for tensor core mma_sync operators.
- *
- * void tvm_mma_sync(Var fragment_d, Expr index_d,
- * Var fragment_a, Expr index_a,
- * Var fragment_b, Expr index_b,
- * Var fragment_c, Expr index_c) {
- * nvcuda::wmma::mma_sync(fragment_d[index_d], fragment_a[index_a],
- * fragment_b[index_b], fragment_c[index_c]);
- * }
- */
-constexpr const char* tvm_mma_sync = "tvm_mma_sync";
-/*!
- * \brief tvm intrinsic for tensor core bmma_sync operators.
- *
- * void tvm_bmma_sync(Var fragment_d, Expr index_d,
- * Var fragment_a, Expr index_a,
- * Var fragment_b, Expr index_b,
- * Var fragment_c, Expr index_c) {
- * nvcuda::wmma::bmma_sync(fragment_d[index_d], fragment_a[index_a],
- * fragment_b[index_b], fragment_c[index_c]);
- * }
- */
-constexpr const char* tvm_bmma_sync = "tvm_bmma_sync";
-/*!
- * \brief tvm intrinsic for tensor core fill_fragment operators.
- *
- * void tvm_fill_fragment(Var fragment, UIntImm m, UIntImm, n, UIntImm k,
- * Expr index, Expr value) {
- * // m, n, k are the shape of wmma fragment
- * // fragments must be in 'wmma.accumulator' scope.
- * nvcuda::wmma::fill_fragment(fragment[index], value);
- * }
- */
-constexpr const char* tvm_fill_fragment = "tvm_fill_fragment";
-/*!
- * \brief tvm intrinsic for tensor core store operators.
- *
- * void tvm_store_matrix_sync(Var fragment, UIntImm m, UIntImm, n, UIntImm k,
- * Expr index, Expr buffer_ptr, Expr stride,
- * StringImm layout) {
- * // m, n, k are the shape of wmma fragment
- * // fragments must be in 'wmma.accumulator' scope.
- * nvcuda::wmma::store_matrix_sync(fragment[index], buffer_ptr, stride, layout);
- * }
- */
-constexpr const char* tvm_store_matrix_sync = "tvm_store_matrix_sync";
-
-/*! \brief The kind of structure field info used in intrinsic */
-enum TVMStructFieldKind : int {
- // array head address
- kArrAddr,
- kArrData,
- kArrShape,
- kArrStrides,
- kArrNDim,
- kArrTypeCode,
- kArrTypeBits,
- kArrTypeLanes,
- kArrByteOffset,
- kArrDeviceId,
- kArrDeviceType,
- kArrKindBound_,
- // TVMValue field
- kTVMValueContent,
- kTVMValueKindBound_
-};
-} // namespace intrinsic
-
} // namespace tir
} // namespace tvm
* While we could have express parameter unpacking and constraint using
* normal statements, making buffer_map as first class citizen of PrimFunc
* will make program analysis much easier.
- *
- * \note This field can be nullptr
*/
Map<tir::Var, Buffer> buffer_map;
* \param attrs Additional function attributes.
*/
TVM_DLL PrimFunc(Array<tir::Var> params, Stmt body, Type ret_type = VoidType(),
- Map<tir::Var, Buffer> buffer_map = NullValue<Map<tir::Var, Buffer>>(),
+ Map<tir::Var, Buffer> buffer_map = Map<tir::Var, Buffer>(),
DictAttrs attrs = NullValue<DictAttrs>());
TVM_DEFINE_OBJECT_REF_METHODS(PrimFunc, BaseFunc, PrimFuncNode);
#ifndef TVM_TIR_OP_H_
#define TVM_TIR_OP_H_
+#include <tvm/ir/op.h>
#include <tvm/ir/type.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt.h>
TVM_DLL PrimExpr LargeUIntImm(DataType dtype, int64_t low, int64_t high);
// Intrinsic operators
-#define TVM_DECLARE_INTRIN_UNARY(OpName) \
- inline PrimExpr OpName(PrimExpr x) { \
- return tir::Call(x.dtype(), #OpName, {x}, tir::CallNode::PureIntrinsic); \
+#define TVM_DECLARE_INTRIN_UNARY(OpName) \
+ inline PrimExpr OpName(PrimExpr x) { \
+ static const Op& op = Op::Get("tir." #OpName); \
+ return tir::Call(x.dtype(), op, {x}, tir::CallNode::PureIntrinsic); \
}
TVM_DECLARE_INTRIN_UNARY(exp);
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/tir/op_attr_types.h
+ * \brief Attribute types in the Op registry for TIR ops.
+ *
+ * These attributes can be set via OpRegEntry::set_attr
+ *
+ * \sa tvm/ir/op.h
+ */
+#ifndef TVM_TIR_OP_ATTR_TYPES_H_
+#define TVM_TIR_OP_ATTR_TYPES_H_
+
+#include <tvm/runtime/container.h>
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief Global symbol of the op after lowering.
+ */
+using TGlobalSymbol = String;
+
+/*!
+ * \brief Whether the op is overloaded for vector form.
+ */
+using TVectorizable = bool;
+
+} // namespace tir
+} // namespace tvm
+#endif // TVM_TIR_OP_ATTR_TYPES_H_
* \param dtype The data type
* \return Expr a expression with dtype.
*/
-inline PrimExpr TypeAnnotation(DataType dtype) {
- return tir::Call(dtype, "type_annotation", {}, tir::CallNode::PureIntrinsic);
-}
+TVM_DLL PrimExpr TypeAnnotation(DataType dtype);
// overload printing of for type.
TVM_DLL std::ostream& operator<<(std::ostream& os, ForType for_type);
(out, _) = proc.communicate()
if proc.returncode != 0:
- msg = "Compilation error:\n"
+ msg = code
+ msg += "\nCompilation error:\n"
msg += py_str(out)
raise RuntimeError(msg)
import tvm._ffi
import tvm.runtime._ffi_api
-from tvm.runtime import convert, DataType
-from tvm.tir.expr import Call as _Call, Cast as _Cast, FloatImm as _FloatImm
+from tvm.runtime import DataType
+import tvm.tir
+from tvm.tir.expr import Cast as _Cast, FloatImm as _FloatImm
def register(type_name, type_code):
if t.lanes > 1:
dtype += "x" + str(t.lanes)
if isinstance(op, (_Cast, _FloatImm)):
- return _Call(dtype, extern_func_name, convert([op.value]),
- _Call.Extern)
- return _Call(dtype, extern_func_name, convert([op.a, op.b]),
- _Call.Extern)
+ return tvm.tir.call_pure_extern(dtype, extern_func_name, op.value)
+ return tvm.tir.call_pure_extern(dtype, extern_func_name, op.a, op.b)
return lower
--------
register_intrin_rule : The registeration function for intrin rule.
"""
+    name = op.op.name
+ assert name.startswith("tir.")
+ prefix = name[4:]
+
if op.dtype == "float32":
- return call_pure_extern(op.dtype, "%sf" % op.name, *op.args)
+ return call_pure_extern(op.dtype, "%sf" % prefix, *op.args)
if op.dtype == "float64":
- return call_pure_extern(op.dtype, op.name, *op.args)
+ return call_pure_extern(op.dtype, prefix, *op.args)
return op
register_intrin_rule : The registeration function for intrin rule.
"""
if str(op.dtype).startswith("float"):
- return call_pure_extern(op.dtype, op.name, *op.args)
+ return call_pure_extern(op.dtype, op.op.name[4:], *op.args)
return None
# opencl pattern for exp
_internal_assert(args.__len__() == 1, \
"Only one expression can be likely")
_internal_assert(func_id == "likely", "This function cannot be directly invoked!")
- return call_pure_intrin(args[0].dtype, 'likely', *args)
+ return call_pure_intrin(args[0].dtype, 'tir.likely', *args)
def max_num_threads(func_id, args):
import tvm._ffi
from tvm.runtime import Object, ObjectGeneric, DataType, DataTypeCode, const
-from tvm.ir import PrimExpr
+from tvm.ir import PrimExpr, Op
import tvm.ir._ffi_api
from . import generic as _generic
from . import _ffi_api
def __invert__(self):
if _dtype_is_float(self):
raise RuntimeError("Cannot use ~ operator on float type Expr.")
- return _ffi_api.Call(self.dtype, "bitwise_not", [self], Call.PureIntrinsic)
+ return _ffi_api.bitwise_not(self)
def __lt__(self, other):
return _ffi_api._OpLT(self, other)
dtype : str
The return data type
- name : str
- The name of the function
+ op : Union[RelayExpr, str]
+        The function to be called, or the name
+        of a global tvm.Op
args : list of Expr
The input arguments to the call
PureExtern = 2
Intrinsic = 4
PureIntrinsic = 5
- def __init__(self, dtype, name, args, call_type):
- self.__init_handle_by_constructor__(
- _ffi_api.Call, dtype, name, args, call_type)
+ def __init__(self, dtype, op, args, call_type):
+ if isinstance(op, str):
+ if not op.startswith("tir."):
+ raise ValueError(
+ ("Cannot handle str op argument %s. This function only handles str " +
+                     "arguments within the tir namespace. If you are " +
+ "certain about the intrinsic name, pass in Op.get(name) instead") % op)
+ op = Op.get(op)
+ self.__init_handle_by_constructor__(
+ _ffi_api.Call, dtype, op, args, call_type)
@tvm._ffi.register_object("tir.Let")
expr : Expr
The expression will likely tag.
"""
- return _expr.Call(expr.dtype, "likely", [expr],
+ return _expr.Call(expr.dtype, "tir.likely", [expr],
_expr.Call.PureIntrinsic)
def get(self):
"""Operators used in TIR expression."""
import tvm._ffi
from tvm.runtime import convert, const
-from tvm.ir import Array
+from tvm.ir import Array, Op
from .buffer import Buffer
-from .expr import Call, Var, CommReducer
+from .expr import Call, StringImm, Var, CommReducer
from . import _ffi_api
"""Build intrinsics that packs the buffer.
"""
assert buf.shape
- shape = Call("handle", "tvm_stack_make_shape", buf.shape,
+ shape = Call("handle", "tir.tvm_stack_make_shape", buf.shape,
Call.Intrinsic)
- strides = Call("handle", "tvm_stack_make_shape", buf.strides,
+ strides = Call("handle", "tir.tvm_stack_make_shape", buf.strides,
Call.Intrinsic) if buf.strides else 0
pack_args = [buf.data,
shape,
len(buf.shape),
const(0, dtype=buf.dtype),
buf.elem_offset]
- return Call("handle", "tvm_stack_make_array",
+ return Call("handle", Op.get("tir.tvm_stack_make_array"),
pack_args, Call.Intrinsic)
def call_packed(*args):
"""
call_args = [_pack_buffer(x) if isinstance(x, Buffer) else x for x in args]
return Call(
- "int32", "tvm_call_packed", call_args, Call.Intrinsic)
+ "int32", Op.get("tir.tvm_call_packed"), call_args, Call.Intrinsic)
def call_pure_intrin(dtype, func_name, *args):
The call expression.
"""
return Call(
- dtype, func_name, convert(args), Call.PureExtern)
+ dtype, Op.get("tir.call_extern"), convert((StringImm(func_name),) + args), Call.PureExtern)
def call_extern(dtype, func_name, *args):
The call expression.
"""
return Call(
- dtype, func_name, convert(args), Call.Extern)
+ dtype, Op.get("tir.call_extern"), convert((StringImm(func_name),) + args), Call.Extern)
def call_llvm_intrin(dtype, name, *args):
from tvm.target import codegen
llvm_id = codegen.llvm_lookup_intrinsic_id(name)
assert llvm_id != 0, "%s is not an LLVM intrinsic" % name
- return call_pure_intrin(dtype, 'llvm_intrin', tvm.tir.const(llvm_id, 'uint32'), *args)
+ return call_pure_intrin(dtype, Op.get("tir.call_llvm_intrin"),
+ tvm.tir.const(llvm_id, 'uint32'), *args)
def any(*args):
call_args = [_pack_buffer(x) if isinstance(x, Buffer) else x for x in args]
call_args.insert(0, trace_action)
return tvm.tir.Call(
- args[-1].dtype, "tvm_call_trace_packed", call_args, tvm.tir.Call.Intrinsic)
+ args[-1].dtype, Op.get("tir.tvm_call_trace_packed"), call_args, tvm.tir.Call.Intrinsic)
y : PrimExpr
The result.
"""
- return call_pure_intrin(x.dtype, "exp", x)
+ return call_pure_intrin(x.dtype, "tir.exp", x)
def exp2(x):
y : PrimExpr
The result.
"""
- return call_pure_intrin(x.dtype, "exp2", x)
+ return call_pure_intrin(x.dtype, "tir.exp2", x)
def exp10(x):
y : PrimExpr
The result.
"""
- return call_pure_intrin(x.dtype, "exp10", x)
+ return call_pure_intrin(x.dtype, "tir.exp10", x)
def erf(x):
y : PrimExpr
The result.
"""
- return call_pure_intrin(x.dtype, "erf", x)
+ return call_pure_intrin(x.dtype, "tir.erf", x)
def tanh(x):
y : PrimExpr
The result.
"""
- return call_pure_intrin(x.dtype, "tanh", x)
+ return call_pure_intrin(x.dtype, "tir.tanh", x)
def sigmoid(x):
y : PrimExpr
The result.
"""
- return call_pure_intrin(x.dtype, "sigmoid", x)
+ return call_pure_intrin(x.dtype, "tir.sigmoid", x)
def log(x):
y : PrimExpr
The result.
"""
- return call_pure_intrin(x.dtype, "log", x)
+ return call_pure_intrin(x.dtype, "tir.log", x)
def log2(x):
y : PrimExpr
The result.
"""
- return call_pure_intrin(x.dtype, "log2", x)
+ return call_pure_intrin(x.dtype, "tir.log2", x)
def log10(x):
y : PrimExpr
The result.
"""
- return call_pure_intrin(x.dtype, "log10", x)
+ return call_pure_intrin(x.dtype, "tir.log10", x)
def log1p(x):
y : PrimExpr
The result.
"""
- return call_pure_intrin(x.dtype, "log1p", x)
+ return call_pure_intrin(x.dtype, "tir.log1p", x)
def tan(x):
y : PrimExpr
The result.
"""
- return call_pure_intrin(x.dtype, "tan", x)
+ return call_pure_intrin(x.dtype, "tir.tan", x)
def cos(x):
y : PrimExpr
The result.
"""
- return call_pure_intrin(x.dtype, "cos", x)
+ return call_pure_intrin(x.dtype, "tir.cos", x)
def cosh(x):
y : PrimExpr
The result.
"""
- return call_pure_intrin(x.dtype, "cosh", x)
+ return call_pure_intrin(x.dtype, "tir.cosh", x)
def acos(x):
y : PrimExpr
The result.
"""
- return call_pure_intrin(x.dtype, "acos", x)
+ return call_pure_intrin(x.dtype, "tir.acos", x)
def acosh(x):
y : PrimExpr
The result.
"""
- return call_pure_intrin(x.dtype, "acosh", x)
+ return call_pure_intrin(x.dtype, "tir.acosh", x)
def sin(x):
y : PrimExpr
The result.
"""
- return call_pure_intrin(x.dtype, "sin", x)
+ return call_pure_intrin(x.dtype, "tir.sin", x)
def sinh(x):
y : PrimExpr
The result.
"""
- return call_pure_intrin(x.dtype, "sinh", x)
+ return call_pure_intrin(x.dtype, "tir.sinh", x)
def asin(x):
y : PrimExpr
The result.
"""
- return call_pure_intrin(x.dtype, "asin", x)
+ return call_pure_intrin(x.dtype, "tir.asin", x)
def asinh(x):
y : PrimExpr
The result.
"""
- return call_pure_intrin(x.dtype, "asinh", x)
+ return call_pure_intrin(x.dtype, "tir.asinh", x)
def atan(x):
y : PrimExpr
The result.
"""
- return call_pure_intrin(x.dtype, "atan", x)
+ return call_pure_intrin(x.dtype, "tir.atan", x)
def atanh(x):
y : PrimExpr
The result.
"""
- return call_pure_intrin(x.dtype, "atanh", x)
+ return call_pure_intrin(x.dtype, "tir.atanh", x)
def atan2(x1, x2):
y : PrimExpr
The result.
"""
- return call_pure_intrin(x1.dtype, "atan2", x1, x2)
+ return call_pure_intrin(x1.dtype, "tir.atan2", x1, x2)
def sqrt(x):
y : PrimExpr
The result.
"""
- return call_pure_intrin(x.dtype, "sqrt", x)
+ return call_pure_intrin(x.dtype, "tir.sqrt", x)
def rsqrt(x):
y : PrimExpr
The result.
"""
- return call_pure_intrin(x.dtype, "rsqrt", x)
+ return call_pure_intrin(x.dtype, "tir.rsqrt", x)
def floor(x):
y : PrimExpr
The result.
"""
- return call_pure_intrin(x1.dtype, "nextafter", x1, x2)
+ return call_pure_intrin(x1.dtype, "tir.nextafter", x1, x2)
def hypot(x1, x2):
y : PrimExpr
The result.
"""
- return call_pure_intrin(x1.dtype, "hypot", x1, x2)
+ return call_pure_intrin(x1.dtype, "tir.hypot", x1, x2)
def copysign(x1, x2):
y : PrimExpr
The result.
"""
- return call_pure_intrin(x1.dtype, "copysign", x1, x2)
+ return call_pure_intrin(x1.dtype, "tir.copysign", x1, x2)
def ldexp(x1, x2):
y : PrimExpr
The result.
"""
- return call_pure_intrin(x1.dtype, "ldexp", x1, x2)
+ return call_pure_intrin(x1.dtype, "tir.ldexp", x1, x2)
def isnan(x):
y : PrimExpr
The result.
"""
- return call_pure_intrin(x.dtype, "popcount", x)
+ return call_pure_intrin(x.dtype, "tir.popcount", x)
def fmod(x, y):
"""Return the remainder of x divided by y with the same sign as x.
z : PrimExpr
The result.
"""
- return call_pure_intrin(x.dtype, "fmod", x, y)
+ return call_pure_intrin(x.dtype, "tir.fmod", x, y)
def if_then_else(cond, t, f):
*/
#include <tvm/arith/analyzer.h>
#include <tvm/runtime/registry.h>
+#include <tvm/tir/builtin.h>
#include <tvm/tir/expr_functor.h>
#include <algorithm>
Entry VisitExpr_(const CallNode* op) final {
// only special handle >> and & which can be
// used for index calculation.
- if (op->is_intrinsic(CallNode::shift_right)) {
+
+ if (op->op.same_as(tir::builtin::shift_right())) {
return VisitRightShift(op);
- } else if (op->is_intrinsic(CallNode::bitwise_and)) {
+ } else if (op->op.same_as(tir::builtin::bitwise_and())) {
return VisitBitwiseAnd(op);
} else {
return Everything(op->dtype);
Stmt IRMutatorWithAnalyzer::VisitStmt_(const IfThenElseNode* op) {
PrimExpr condition = this->VisitExpr(op->condition);
PrimExpr real_condition = condition;
+ static auto op_likely = Op::Get("tir.likely");
+
if (auto call = condition.as<CallNode>()) {
- if (call->is_intrinsic(CallNode::likely)) {
+ if (call->op.same_as(op_likely)) {
real_condition = call->args[0];
}
}
PrimExpr IRMutatorWithAnalyzer::VisitExpr_(const CallNode* op) {
// add condition context to if_then_else
- if (op->is_intrinsic(tir::intrinsic::tvm_if_then_else)) {
+ static auto op_if_then_else = Op::Get("tir.if_then_else");
+ if (op->op.same_as(op_if_then_else)) {
PrimExpr cond = this->VisitExpr(op->args[0]);
PrimExpr true_value, false_value;
{
false_value.same_as(op->args[2])) {
return GetRef<PrimExpr>(op);
} else {
- return Call(op->dtype, op->name, {cond, true_value, false_value}, op->call_type);
+ return Call(op->dtype, op->op, {cond, true_value, false_value}, op->call_type);
}
}
return StmtExprMutator::VisitExpr_(op);
*/
#include <tvm/arith/analyzer.h>
#include <tvm/runtime/registry.h>
+#include <tvm/tir/builtin.h>
#include <tvm/tir/expr_functor.h>
#include <tvm/tir/op.h>
Entry VisitExpr_(const CallNode* op) final {
// only special handle >> which can be
// used for index calculation.
- if (op->is_intrinsic(CallNode::shift_right)) {
+ if (op->op.same_as(tir::builtin::shift_right())) {
return VisitRightShift(op);
} else {
return Everything();
#define TVM_ARITH_PATTERN_MATCH_H_
#include <tvm/tir/analysis.h>
+#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tuple>
bool Match_(const ObjectRef& node) const {
if (const tir::CallNode* ptr = node.as<tir::CallNode>()) {
if (ptr->args.size() != sizeof...(TArgs)) return false;
- if (ptr->name != Op::kName) return false;
+ if (!ptr->op.same_as(Op::GetOp())) return false;
detail::PCallExprMatchFunctor fmatch(ptr);
detail::tuple_for_each(fmatch, args_);
return fmatch.matched_;
};
// arithemetic intrinsics
-#define TVM_PATTERN_BINARY_INTRIN(FuncName, OpName, IntrinStr) \
+#define TVM_PATTERN_BINARY_INTRIN(FuncName, OpName, IntrinOpName) \
struct OpName { \
static PrimExpr Eval(Array<PrimExpr> args) { \
- return tir::Call(args[0].dtype(), kName, args, tir::CallNode::PureIntrinsic); \
+ return tir::Call(args[0].dtype(), GetOp(), args, tir::CallNode::PureIntrinsic); \
} \
- static constexpr const char* kName = IntrinStr; \
+ static const Op& GetOp() { return tir::builtin::IntrinOpName(); } \
}; \
template <typename TA, typename TB> \
inline PCallExpr<OpName, TA, TB> FuncName(const Pattern<TA>& a, const Pattern<TB>& b) { \
return PCallExpr<OpName, TA, TB>(a.derived(), b.derived()); \
}
-TVM_PATTERN_BINARY_INTRIN(operator<<, PLeftShiftOp, "shift_left");
-TVM_PATTERN_BINARY_INTRIN(operator>>, PRightShiftOp, "shift_right");
-TVM_PATTERN_BINARY_INTRIN(operator&, PBitwiseAndOp, "bitwise_and");
-TVM_PATTERN_BINARY_INTRIN(operator|, PBitwiseOrOp, "bitwise_or");
-TVM_PATTERN_BINARY_INTRIN(operator^, PBitwiseXorOp, "bitwise_xor");
+TVM_PATTERN_BINARY_INTRIN(operator<<, PLeftShiftOp, shift_left);
+TVM_PATTERN_BINARY_INTRIN(operator>>, PRightShiftOp, shift_right);
+TVM_PATTERN_BINARY_INTRIN(operator&, PBitwiseAndOp, bitwise_and);
+TVM_PATTERN_BINARY_INTRIN(operator|, PBitwiseOrOp, bitwise_or);
+TVM_PATTERN_BINARY_INTRIN(operator^, PBitwiseXorOp, bitwise_xor);
// unary intrinsics
-#define TVM_PATTERN_UNARY_INTRIN(FuncName, OpName, IntrinStr) \
- struct OpName { \
- static PrimExpr Eval(Array<PrimExpr> args) { \
- return tir::Call(args[0].dtype(), kName, args, tir::CallNode::PureIntrinsic); \
- } \
- static constexpr const char* kName = IntrinStr; \
- }; \
- template <typename TA> \
- inline PCallExpr<OpName, TA> FuncName(const Pattern<TA>& a) { \
- return PCallExpr<OpName, TA>(a.derived()); \
+#define TVM_PATTERN_UNARY_INTRIN(FuncName, OpName, IntrinOpName) \
+ struct OpName { \
+ static PrimExpr Eval(Array<PrimExpr> args) { \
+ return tir::Call(args[0].dtype(), GetOp(), args, tir::CallNode::PureIntrinsic); \
+ } \
+ static const Op& GetOp() { return tir::builtin::IntrinOpName(); } \
+ }; \
+ template <typename TA> \
+ inline PCallExpr<OpName, TA> FuncName(const Pattern<TA>& a) { \
+ return PCallExpr<OpName, TA>(a.derived()); \
}
-TVM_PATTERN_UNARY_INTRIN(operator~, PBitwiseNotOp, "bitwise_not");
+TVM_PATTERN_UNARY_INTRIN(operator~, PBitwiseNotOp, bitwise_not);
// if_then_else
struct PIfThenElseOp {
static PrimExpr Eval(Array<PrimExpr> args) {
- return tir::Call(args[1].dtype(), kName, args, tir::CallNode::PureIntrinsic);
+ return tir::Call(args[1].dtype(), GetOp(), args, tir::CallNode::PureIntrinsic);
}
- static constexpr const char* kName = "tvm_if_then_else";
+ static const Op& GetOp() { return tir::builtin::if_then_else(); }
};
/*!
#include "rewrite_simplify.h"
#include <tvm/arith/analyzer.h>
+#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <algorithm>
PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<CallNode>();
if (op == nullptr) return ret;
- if (op->is_intrinsic(CallNode::likely) && is_const(op->args[0])) {
+
+ if (op->op.same_as(tir::builtin::likely()) && is_const(op->args[0])) {
return op->args[0];
- } else if (op->is_intrinsic(CallNode::shift_right)) {
+ } else if (op->op.same_as(tir::builtin::shift_right())) {
if (op->args[0].as<IntImmNode>() && op->args[1].as<IntImmNode>()) {
// the operator overload will eagerly constant fold.
return op->args[0] >> op->args[1];
}
- } else if (op->is_intrinsic(CallNode::bitwise_and)) {
+  } else if (op->op.same_as(tir::builtin::bitwise_and())) {
if (op->args[0].as<IntImmNode>() && op->args[1].as<IntImmNode>()) {
// the operator overload will eagerly constant fold.
return op->args[0] & op->args[1];
}
}
ExprDeepEqual expr_equal;
- if (op->is_intrinsic(CallNode::likely)) {
+ if (op->op.same_as(tir::builtin::likely())) {
for (const auto& constraint : literal_constraints_) {
// Cases such as for (i, 0, bound) {if (likely(iter_var < bound)) { .. } }
if (expr_equal(constraint, op->args[0])) {
#include "codegen_hybrid.h"
#include <tvm/runtime/registry.h>
+#include <tvm/tir/builtin.h>
#include <cctype>
#include <iomanip>
os << "]";
}
void CodeGenHybrid::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*)
- if (op->is_intrinsic(CallNode::bitwise_and)) {
+ if (op->op.same_as(builtin::bitwise_and())) {
PrintBinaryIntrinsitc(op, "&", os, this);
- } else if (op->is_intrinsic(CallNode::bitwise_xor)) {
+ } else if (op->op.same_as(builtin::bitwise_xor())) {
PrintBinaryIntrinsitc(op, "^", os, this);
- } else if (op->is_intrinsic(CallNode::bitwise_or)) {
+ } else if (op->op.same_as(builtin::bitwise_or())) {
PrintBinaryIntrinsitc(op, "|", os, this);
- } else if (op->is_intrinsic(CallNode::shift_left)) {
+ } else if (op->op.same_as(builtin::shift_left())) {
PrintBinaryIntrinsitc(op, "<<", os, this);
- } else if (op->is_intrinsic(CallNode::shift_right)) {
+ } else if (op->op.same_as(builtin::shift_right())) {
PrintBinaryIntrinsitc(op, ">>", os, this);
- } else if (op->is_intrinsic(CallNode::bitwise_not)) {
+ } else if (op->op.same_as(builtin::bitwise_not())) {
CHECK_EQ(op->args.size(), 1U);
os << "(~";
PrintExpr(op->args[0], os);
os << ')';
- } else if (op->is_intrinsic(intrinsic::tvm_if_then_else)) {
+ } else if (op->op.same_as(builtin::if_then_else())) {
PrintExpr(op->args[1], os);
os << " if ";
PrintExpr(op->args[0], os);
os << " else ";
PrintExpr(op->args[2], os);
+ } else if (op->op.same_as(builtin::call_extern())) {
+ StringImm fname = Downcast<StringImm>(op->args[0]);
+      os << fname->value << "(";
+ for (size_t i = 1; i < op->args.size(); i++) {
+ PrintExpr(op->args[i], os);
+ if (i < op->args.size() - 1) {
+ os << ", ";
+ }
+ }
+ os << ")";
} else {
- os << op->name << "(";
+ auto* ptr_op = op->op.as<OpNode>();
+ CHECK(ptr_op != nullptr);
+ std::string name = ptr_op->name;
+ CHECK_EQ(name.compare(0, 4, "tir."), 0);
+ os << name.substr(4) << "(";
for (size_t i = 0; i < op->args.size(); i++) {
PrintExpr(op->args[i], os);
if (i < op->args.size() - 1) {
// find operator by name
const Op& Op::Get(const String& name) {
const OpRegEntry* reg = OpRegistry::Global()->Get(name);
- CHECK(reg != nullptr) << "Operator " << name << " is not registered";
+ CHECK(reg != nullptr) << "AttributeError: Operator " << name << " is not registered";
return reg->op();
}
Doc TIRTextPrinter::VisitExpr_(const CallNode* op) {
Doc doc;
- doc << "@" << Doc::Text(op->name) << "(";
+ if (auto* ptr_op = op->op.as<OpNode>()) {
+ doc << "@" << Doc::Text(ptr_op->name) << "(";
+ } else {
+    // TODO(bohan): Print out the name by the global var in the module.
+ auto* op_gvar = op->op.as<GlobalVarNode>();
+ CHECK(op_gvar != nullptr);
+ doc << "@" << Doc::Text(op_gvar->name_hint) << "(";
+ }
std::vector<Doc> args;
for (const auto& arg : op->args) {
args.push_back(Print(arg));
Doc TIRTextPrinter::VisitStmt_(const LetStmtNode* op) {
Doc doc;
- doc << "let " << Print(op->var) << " = " << Print(op->value) << PrintBody(op->body);
+ doc << "let " << Print(op->var) << " = " << Print(op->value) << Doc::NewLine() << Print(op->body);
return doc;
}
Doc TIRTextPrinter::VisitStmt_(const AssertStmtNode* op) {
Doc doc;
- doc << "assert(" << Print(op->condition) << ", " << Print(op->message) << ")"
- << PrintBody(op->body);
+ doc << "assert(" << Print(op->condition) << ", " << Print(op->message) << ")" << Doc::NewLine()
+ << Print(op->body);
return doc;
}
* \return compiler_begin op
*/
inline const Op& CompilerBeginOp() {
- static Op op = Op::Get("annotation.compiler_begin");
+ static auto op = Op::Get("annotation.compiler_begin");
return op;
}
* \return compiler_end op
*/
inline const Op& CompilerEndOp() {
- static Op op = Op::Get("annotation.compiler_end");
+ static auto op = Op::Get("annotation.compiler_end");
return op;
}
#define TVM_TARGET_INTRIN_RULE_H_
#include <tvm/runtime/registry.h>
+#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <string>
PrimExpr e = args[0];
const CallNode* call = e.as<CallNode>();
CHECK(call != nullptr);
- std::string name = T()(call->dtype, call->name);
+  // Use string-based dispatch to extern for backward compatibility.
+  // TODO(tvm-team) replace once the new dispatching system is in place.
+ const OpNode* op = call->op.as<OpNode>();
+ CHECK(op != nullptr);
+ std::string name = op->name;
+ CHECK_EQ(name.substr(0, 4), "tir.");
+ name = T()(call->dtype, name.substr(4));
+
if (name.length() != 0) {
- *rv = Call(call->dtype, name, call->args, CallNode::PureExtern);
+ Array<PrimExpr> new_args = {StringImm(name)};
+ for (auto arg : call->args) {
+ new_args.push_back(arg);
+ }
+ *rv = Call(call->dtype, tir::builtin::call_extern(), new_args, CallNode::PureExtern);
} else {
*rv = e;
}
};
llvm::Value* CodeGenARM::CreateIntrinsic(const CallNode* op) {
- if (op->is_intrinsic("llvm_intrin")) {
+ if (op->op.same_as(builtin_call_llvm_intrin_)) {
llvm::Intrinsic::ID id = static_cast<llvm::Intrinsic::ID>(Downcast<IntImm>(op->args[0])->value);
if (id == ::llvm::Intrinsic::ctpop) {
PrimExpr e = ARMPopcount(op);
vcnt_args.push_back(IntImm(DataType::UInt(32), ctpop_id));
vcnt_args.push_back(IntImm(DataType::UInt(32), 1));
vcnt_args.push_back(e);
- return tir::Call(call->dtype, "llvm_intrin", vcnt_args, CallNode::PureIntrinsic);
+ return tir::Call(call->dtype, builtin_call_llvm_intrin_, vcnt_args, CallNode::PureIntrinsic);
}
// Popcount lowering rule:
vcnt8_args.push_back(IntImm(DataType::UInt(32), ctpop_id));
vcnt8_args.push_back(IntImm(DataType::UInt(32), 1));
vcnt8_args.push_back(input8);
- PrimExpr vcnt8 = tir::Call(uint8_type, "llvm_intrin", vcnt8_args, CallNode::PureIntrinsic);
+ PrimExpr vcnt8 =
+ tir::Call(uint8_type, builtin_call_llvm_intrin_, vcnt8_args, CallNode::PureIntrinsic);
// Accumulation 8->16bit
Array<PrimExpr> vcnt16_args;
vcnt16_args.push_back(IntImm(DataType::UInt(32), vpaddlu_id));
vcnt16_args.push_back(IntImm(DataType::UInt(32), 1));
vcnt16_args.push_back(vcnt8);
- PrimExpr vcnt16 = tir::Call(uint16_type, "llvm_intrin", vcnt16_args, CallNode::PureIntrinsic);
+ PrimExpr vcnt16 =
+ tir::Call(uint16_type, builtin_call_llvm_intrin_, vcnt16_args, CallNode::PureIntrinsic);
if (call->dtype.bits() == 16) {
return vcnt16;
}
vcnt32_args.push_back(IntImm(DataType::UInt(32), vpaddlu_id));
vcnt32_args.push_back(IntImm(DataType::UInt(32), 1));
vcnt32_args.push_back(vcnt16);
- PrimExpr vcnt32 = tir::Call(uint32_type, "llvm_intrin", vcnt32_args, CallNode::PureIntrinsic);
+ PrimExpr vcnt32 =
+ tir::Call(uint32_type, builtin_call_llvm_intrin_, vcnt32_args, CallNode::PureIntrinsic);
if (call->dtype.bits() == 32) {
return vcnt32;
}
vcnt64_args.push_back(IntImm(DataType::UInt(32), vpaddlu_id));
vcnt64_args.push_back(IntImm(DataType::UInt(32), 1));
vcnt64_args.push_back(vcnt32);
- return tir::Call(call->dtype, "llvm_intrin", vcnt64_args, CallNode::PureIntrinsic);
+ return tir::Call(call->dtype, builtin_call_llvm_intrin_, vcnt64_args, CallNode::PureIntrinsic);
}
TVM_REGISTER_GLOBAL("tvm.codegen.llvm.target_arm")
}
llvm::Value* CodeGenCPU::CreateStructRefPtr(DataType t, llvm::Value* buf, llvm::Value* index,
int kind) {
- if (kind < intrinsic::kArrKindBound_) {
+ if (kind < builtin::kArrKindBound_) {
if (buf->getType() == t_void_p_) {
buf = builder_->CreatePointerCast(buf, t_tvm_array_->getPointerTo());
} else {
}
}
switch (kind) {
- case intrinsic::kArrAddr: {
+ case builtin::kArrAddr: {
return builder_->CreateInBoundsGEP(buf, index);
}
- case intrinsic::kArrData: {
+ case builtin::kArrData: {
return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(0)});
}
- case intrinsic::kArrShape: {
+ case builtin::kArrShape: {
return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(4)});
}
- case intrinsic::kArrStrides: {
+ case builtin::kArrStrides: {
return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(5)});
}
- case intrinsic::kArrNDim: {
+ case builtin::kArrNDim: {
return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(2)});
}
- case intrinsic::kArrTypeCode: {
+ case builtin::kArrTypeCode: {
return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(3), ConstInt32(0)});
}
- case intrinsic::kArrTypeBits: {
+ case builtin::kArrTypeBits: {
return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(3), ConstInt32(1)});
}
- case intrinsic::kArrTypeLanes: {
+ case builtin::kArrTypeLanes: {
return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(3), ConstInt32(2)});
}
- case intrinsic::kArrByteOffset: {
+ case builtin::kArrByteOffset: {
return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(6)});
}
- case intrinsic::kArrDeviceId: {
+ case builtin::kArrDeviceId: {
return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(1), ConstInt32(1)});
}
- case intrinsic::kArrDeviceType: {
+ case builtin::kArrDeviceType: {
return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(1), ConstInt32(0)});
}
- case intrinsic::kTVMValueContent: {
+ case builtin::kTVMValueContent: {
CHECK_EQ(t.lanes(), 1);
CHECK(t.is_handle() || t.bits() == 64);
if (t.is_int()) {
}
}
-llvm::Value* CodeGenCPU::CreateCallExtern(const CallNode* op) {
- std::vector<llvm::Value*> arg_values(op->args.size());
- for (size_t i = 0; i < op->args.size(); ++i) {
- arg_values[i] = MakeValue(op->args[i]);
+llvm::Value* CodeGenCPU::CreateCallExtern(Type ret_type, String global_symbol,
+ const Array<PrimExpr>& args, bool skip_first_arg) {
+ std::vector<llvm::Value*> arg_values;
+ for (size_t i = static_cast<size_t>(skip_first_arg); i < args.size(); ++i) {
+ arg_values.push_back(MakeValue(args[i]));
}
std::vector<llvm::Type*> arg_types;
for (llvm::Value* v : arg_values) {
arg_types.push_back(v->getType());
}
- llvm::FunctionType* ftype =
- llvm::FunctionType::get(GetLLVMType(GetRef<PrimExpr>(op)), arg_types, false);
+ llvm::FunctionType* ftype = llvm::FunctionType::get(GetLLVMType(ret_type), arg_types, false);
// Check if it is available in global function table as injected function.
- auto it = gv_func_map_.find(op->name);
+ auto it = gv_func_map_.find(global_symbol);
if (it != gv_func_map_.end()) {
if (it->second == nullptr) {
- gv_func_map_[op->name] = InitContextPtr(ftype->getPointerTo(), "__" + op->name);
- it = gv_func_map_.find(op->name);
+ gv_func_map_[global_symbol] = InitContextPtr(ftype->getPointerTo(), "__" + global_symbol);
+ it = gv_func_map_.find(global_symbol);
}
#if TVM_LLVM_VERSION >= 90
auto ext_callee = llvm::FunctionCallee(ftype, GetContextPtr(it->second));
#endif
return builder_->CreateCall(ext_callee, arg_values);
} else {
- llvm::Function* f = module_->getFunction(op->name);
+ llvm::Function* f = module_->getFunction(global_symbol);
if (f == nullptr) {
f = llvm::Function::Create(ftype, llvm::Function::ExternalLinkage,
- op->name.operator llvm::StringRef(), module_.get());
+ global_symbol.operator llvm::StringRef(), module_.get());
}
#if TVM_LLVM_VERSION >= 90
auto ext_callee = llvm::FunctionCallee(f);
}
llvm::Value* CodeGenCPU::CreateIntrinsic(const CallNode* op) {
- if (op->is_intrinsic(intrinsic::tvm_call_packed_lowered)) {
+ if (op->op.same_as(builtin::tvm_call_packed_lowered())) {
return CreateCallPacked(op);
- } else if (op->is_intrinsic(intrinsic::tvm_call_trace_packed_lowered)) {
+ } else if (op->op.same_as(builtin::tvm_call_trace_packed_lowered())) {
return CreateCallTracePacked(op);
- } else if (op->is_intrinsic(intrinsic::tvm_static_handle)) {
+ } else if (op->op.same_as(builtin::tvm_static_handle())) {
return CreateStaticHandle();
- } else if (op->is_intrinsic(intrinsic::tvm_throw_last_error)) {
+ } else if (op->op.same_as(builtin::tvm_throw_last_error())) {
builder_->CreateRet(ConstInt32(-1));
return ConstInt32(-1);
- } else if (op->is_intrinsic(intrinsic::tvm_struct_get)) {
+ } else if (op->op.same_as(builtin::tvm_struct_get())) {
CHECK_EQ(op->args.size(), 3U);
int kind = op->args[2].as<IntImmNode>()->value;
llvm::Value* ref =
this->CreateStructRefPtr(op->dtype, MakeValue(op->args[0]), MakeValue(op->args[1]), kind);
- if (kind == intrinsic::kArrAddr) {
+ if (kind == builtin::kArrAddr) {
return builder_->CreatePointerCast(ref, t_void_p_);
} else {
return builder_->CreateLoad(ref);
}
- } else if (op->is_intrinsic(intrinsic::tvm_struct_set)) {
+ } else if (op->op.same_as(builtin::tvm_struct_set())) {
CHECK_EQ(op->args.size(), 4U);
int kind = op->args[2].as<IntImmNode>()->value;
llvm::Value* value = MakeValue(op->args[3]);
llvm::Value* ref = this->CreateStructRefPtr(op->args[3].dtype(), MakeValue(op->args[0]),
MakeValue(op->args[1]), kind);
- CHECK(kind != intrinsic::kArrAddr);
+ CHECK(kind != builtin::kArrAddr);
if (value->getType()->isPointerTy()) {
value = builder_->CreatePointerCast(value, ref->getType()->getPointerElementType());
}
builder_->CreateStore(value, ref);
return ConstInt32(0);
- } else if (op->is_intrinsic(intrinsic::tvm_stack_alloca)) {
+ } else if (op->op.same_as(builtin::tvm_stack_alloca())) {
CHECK_EQ(op->args.size(), 2U);
const std::string& type = op->args[0].as<StringImmNode>()->value;
return WithFunctionEntry([&]() -> llvm::AllocaInst* {
void VisitStmt_(const AttrStmtNode* op) override;
void VisitStmt_(const ForNode* op) override;
llvm::Value* CreateIntrinsic(const CallNode* op) override;
- llvm::Value* CreateCallExtern(const CallNode* op) override;
+ llvm::Value* CreateCallExtern(Type ret_type, String global_symbol, const Array<PrimExpr>& args,
+ bool skip_first_arg) override;
protected:
void AddStartupFunction() final;
llvm::GlobalVariable* gv_tvm_api_set_last_error_{nullptr};
llvm::GlobalVariable* gv_tvm_parallel_launch_{nullptr};
llvm::GlobalVariable* gv_tvm_parallel_barrier_{nullptr};
- std::unordered_map<std::string, llvm::GlobalVariable*> gv_func_map_;
+ std::unordered_map<String, llvm::GlobalVariable*> gv_func_map_;
// context for direct dynamic lookup
llvm::Function* f_tvm_func_call_{nullptr};
llvm::Function* f_tvm_get_func_from_env_{nullptr};
return it->second;
}
-llvm::Value* CodeGenLLVM::CreateCallExtern(const CallNode* op) {
+llvm::Value* CodeGenLLVM::CreateCallExtern(Type ret_type, String global_symbol,
+ const Array<PrimExpr>& args, bool skip_first_arg) {
std::vector<llvm::Value*> arg_value;
std::vector<llvm::Type*> arg_type;
- for (size_t i = 0; i < op->args.size(); ++i) {
- arg_value.push_back(MakeValue(op->args[i]));
+ for (size_t i = static_cast<size_t>(skip_first_arg); i < args.size(); ++i) {
+ arg_value.push_back(MakeValue(args[i]));
arg_type.push_back(arg_value.back()->getType());
}
- llvm::FunctionType* ftype =
- llvm::FunctionType::get(GetLLVMType(GetRef<PrimExpr>(op)), arg_type, false);
- llvm::Function* f = module_->getFunction(op->name);
+ llvm::FunctionType* ftype = llvm::FunctionType::get(GetLLVMType(ret_type), arg_type, false);
+ llvm::Function* f = module_->getFunction(global_symbol);
if (f == nullptr) {
f = llvm::Function::Create(ftype, llvm::Function::ExternalLinkage,
- op->name.operator llvm::StringRef(), module_.get());
+ global_symbol.operator llvm::StringRef(), module_.get());
}
llvm::CallInst* call = builder_->CreateCall(f, arg_value);
return call;
}
llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) {
- if (op->is_intrinsic("llvm_intrin")) {
+ if (op->op.same_as(builtin_call_llvm_intrin_)) {
CHECK_GE(op->args.size(), 2U);
llvm::Intrinsic::ID id = static_cast<llvm::Intrinsic::ID>(Downcast<IntImm>(op->args[0])->value);
int64_t num_signature = Downcast<IntImm>(op->args[1])->value;
// type as LLVM.
llvm::Type* return_type = (id != llvm::Intrinsic::prefetch) ? GetLLVMType(GetRef<PrimExpr>(op))
: llvm::Type::getVoidTy(*ctx_);
-
llvm::Function* f = GetIntrinsicDecl(id, return_type, arg_type);
CHECK(f) << "Cannot find intrinsic declaration, possible type mismatch: "
<< llvm::Intrinsic::getName(id, {});
return builder_->CreateCall(f, arg_value);
- } else if (op->is_intrinsic(CallNode::bitwise_and)) {
+ } else if (op->op.same_as(builtin::bitwise_and())) {
return builder_->CreateAnd(MakeValue(op->args[0]), MakeValue(op->args[1]));
- } else if (op->is_intrinsic(CallNode::bitwise_or)) {
+ } else if (op->op.same_as(builtin::bitwise_or())) {
return builder_->CreateOr(MakeValue(op->args[0]), MakeValue(op->args[1]));
- } else if (op->is_intrinsic(CallNode::bitwise_not)) {
+ } else if (op->op.same_as(builtin::bitwise_not())) {
return builder_->CreateNot(MakeValue(op->args[0]));
- } else if (op->is_intrinsic(CallNode::bitwise_xor)) {
+ } else if (op->op.same_as(builtin::bitwise_xor())) {
return builder_->CreateXor(MakeValue(op->args[0]), MakeValue(op->args[1]));
- } else if (op->is_intrinsic(CallNode::shift_left)) {
+ } else if (op->op.same_as(builtin::shift_left())) {
return builder_->CreateShl(MakeValue(op->args[0]), MakeValue(op->args[1]));
- } else if (op->is_intrinsic(CallNode::shift_right)) {
+ } else if (op->op.same_as(builtin::shift_right())) {
if (op->args[0].dtype().is_int()) {
return builder_->CreateAShr(MakeValue(op->args[0]), MakeValue(op->args[1]));
} else {
return builder_->CreateLShr(MakeValue(op->args[0]), MakeValue(op->args[1]));
}
- } else if (op->is_intrinsic(intrinsic::tvm_storage_sync)) {
+ } else if (op->op.same_as(builtin::tvm_storage_sync())) {
return CreateStorageSync(op);
- } else if (op->is_intrinsic(intrinsic::tvm_address_of)) {
+ } else if (op->op.same_as(builtin::address_of())) {
const LoadNode* l = op->args[0].as<LoadNode>();
CHECK(op->args.size() == 1 && l);
const RampNode* r = l->index.as<RampNode>();
addrspace = llvm::dyn_cast<llvm::PointerType>(ptr->getType())->getAddressSpace();
}
return builder_->CreatePointerCast(ptr, t_char_->getPointerTo(addrspace));
- } else if (op->is_intrinsic(CallNode::reinterpret) && is_zero(op->args[0])) {
+ } else if (op->op.same_as(builtin::reinterpret()) && is_zero(op->args[0])) {
return llvm::Constant::getNullValue(t_void_p_);
- } else if (op->is_intrinsic(intrinsic::tvm_handle_is_null)) {
+ } else if (op->op.same_as(builtin::isnullptr())) {
return builder_->CreateIsNull(MakeValue(op->args[0]));
- } else if (op->is_intrinsic(intrinsic::tvm_large_uint_imm)) {
+ } else if (op->op.same_as(builtin::large_uint_imm())) {
CHECK_EQ(op->args.size(), 2U);
uint64_t low = static_cast<uint64_t>(Downcast<IntImm>(op->args[0])->value);
uint64_t high = static_cast<uint64_t>(Downcast<IntImm>(op->args[1])->value);
uint64_t val = (high << 32U) | low;
return llvm::ConstantInt::get(DTypeToLLVMType(op->dtype), val);
- } else if (op->is_intrinsic(intrinsic::tvm_if_then_else)) {
+ } else if (op->op.same_as(builtin::if_then_else())) {
CHECK_EQ(op->args[0].dtype().lanes(), 1) << "if_then_else can only take scalar condition";
using llvm::BasicBlock;
BasicBlock* then_block = BasicBlock::Create(*ctx_, "if_then", function_);
value->addIncoming(then_value, then_value_block);
value->addIncoming(else_value, else_value_block);
return value;
- } else if (op->is_intrinsic(CallNode::reinterpret)) {
+ } else if (op->op.same_as(builtin::reinterpret())) {
llvm::Type* target = DTypeToLLVMType(op->dtype);
return builder_->CreateBitCast(MakeValue(op->args[0]), target);
- } else if (op->is_intrinsic(CallNode::isnan)) {
+ } else if (op->op.same_as(builtin::isnan())) {
// TODO(hgt312): set fast math flag
llvm::Value* a = MakeValue(op->args[0]);
return builder_->CreateFCmpUNO(a, a);
- } else if (op->is_intrinsic("vectorlow")) {
+ } else if (op->op.same_as(builtin::vectorlow())) {
llvm::Value* v = MakeValue(op->args[0]);
int l = llvm::cast<llvm::VectorType>(v->getType())->getNumElements();
return CreateVecSlice(v, 0, l / 2);
- } else if (op->is_intrinsic("vectorhigh")) {
+ } else if (op->op.same_as(builtin::vectorhigh())) {
llvm::Value* v = MakeValue(op->args[0]);
int l = llvm::cast<llvm::VectorType>(v->getType())->getNumElements();
return CreateVecSlice(v, l / 2, l / 2);
- } else if (op->is_intrinsic("vectorcombine")) {
+ } else if (op->op.same_as(builtin::vectorcombine())) {
llvm::Value* v0 = MakeValue(op->args[0]);
llvm::Value* v1 = MakeValue(op->args[1]);
int num_elems = llvm::cast<llvm::VectorType>(v0->getType())->getNumElements() * 2;
}
return builder_->CreateShuffleVector(v0, v1, indices);
} else {
- LOG(FATAL) << "unknown intrinsic " << op->name;
+ LOG(FATAL) << "unknown intrinsic " << op->op;
return nullptr;
}
}
}
llvm::Value* CodeGenLLVM::VisitExpr_(const CallNode* op) {
- if (op->call_type == CallNode::Intrinsic || op->call_type == CallNode::PureIntrinsic) {
- return CreateIntrinsic(op);
- } else if (op->call_type == CallNode::Extern || op->call_type == CallNode::PureExtern) {
- return CreateCallExtern(op);
+ if (auto* ptr_op = op->op.as<OpNode>()) {
+ auto call_op = GetRef<Op>(ptr_op);
+ if (op->op.same_as(builtin_call_extern_)) {
+ // call extern intrinsic
+ CHECK_GE(op->args.size(), 1U);
+ auto global_symbol = Downcast<StringImm>(op->args[0]);
+ return this->CreateCallExtern(GetType(GetRef<PrimExpr>(op)), global_symbol->value, op->args,
+ true);
+ } else if (op_attr_global_symbol_.count(call_op)) {
+ // call extern if the op itself has a global symbol.
+ return this->CreateCallExtern(GetType(GetRef<PrimExpr>(op)), op_attr_global_symbol_[call_op],
+ op->args, false);
+ } else {
+ return CreateIntrinsic(op);
+ }
} else {
- LOG(FATAL) << "Unknown call type "
- << "name= " << op->name << " call_type= " << op->call_type;
+ CHECK(op->op.as<GlobalVarNode>());
+ LOG(FATAL) << "Do not yet support cross function call";
return nullptr;
}
}
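The dispatch above relies on per-op attributes instead of name matching: tir.call_extern carries its symbol as the first call argument (hence skip_first_arg = true), while any other op that registers a TGlobalSymbol attribute lowers to an extern call to that symbol, and the rest fall back to CreateIntrinsic. A small sketch, not part of the patch, of how that attribute lookup reads in isolation; the helper name is hypothetical.

#include <tvm/ir/op.h>
#include <tvm/tir/op_attr_types.h>

using namespace tvm;
using namespace tvm::tir;

// Sketch only: query the TGlobalSymbol attribute map for an op.
// Returns an empty string when the op does not lower to a plain extern call.
String GetGlobalSymbolOrEmpty(const Op& op) {
  static OpAttrMap<TGlobalSymbol> global_symbol_map =
      Op::GetAttrMap<TGlobalSymbol>("TGlobalSymbol");
  if (global_symbol_map.count(op)) {
    return global_symbol_map[op];
  }
  return String("");
}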
#include <tvm/tir/expr.h>
#include <tvm/tir/function.h>
#include <tvm/tir/op.h>
+#include <tvm/tir/op_attr_types.h>
#include <tvm/tir/stmt.h>
#include <tvm/tir/stmt_functor.h>
// create intrinsic given call
virtual llvm::Value* CreateIntrinsic(const CallNode* op);
// create extern function call
- virtual llvm::Value* CreateCallExtern(const CallNode* op);
+ // The skip_first_arg mode is used by the call_extern intrinsic, whose first argument is the extern symbol.
+ virtual llvm::Value* CreateCallExtern(Type ret_type, String global_symbol,
+ const Array<PrimExpr>& args, bool skip_first_arg);
// Get the corresponding thread index
virtual llvm::Value* GetThreadIndex(const IterVar& iv);
// Get the corresponding thread index
std::unordered_set<const VarNode*> alias_var_set_;
// set of volatile buffer.
std::unordered_set<const VarNode*> volatile_buf_;
+ // Cache potential common path ops to slightly improve lookup time.
+ // global symbol table.
+ OpAttrMap<TGlobalSymbol> op_attr_global_symbol_ = Op::GetAttrMap<TGlobalSymbol>("TGlobalSymbol");
+ const Op& builtin_call_extern_ = builtin::call_extern();
+ const Op& builtin_call_llvm_intrin_ = builtin::call_llvm_intrin();
/*! \brief Helper struct for debug infos. */
struct DebugInfo {
std::unique_ptr<llvm::DIBuilder> di_builder_;
llvm::Intrinsic::nvvm_shfl_down_i32, llvm::Intrinsic::nvvm_shfl_down_f32};
int offset = 0;
- if (op->is_intrinsic(intrinsic::tvm_warp_shuffle)) {
+ if (op->op.same_as(builtin::tvm_warp_shuffle())) {
offset = 0;
- } else if (op->is_intrinsic(intrinsic::tvm_warp_shuffle_up)) {
+ } else if (op->op.same_as(builtin::tvm_warp_shuffle_up())) {
offset = 2;
- } else if (op->is_intrinsic(intrinsic::tvm_warp_shuffle_down)) {
+ } else if (op->op.same_as(builtin::tvm_warp_shuffle_down())) {
offset = 4;
} else {
return false;
llvm::Type* return_type = arg_type[0];
llvm::Function* func = GetIntrinsicDecl(id, return_type, arg_type);
return builder_->CreateCall(func, arg_value);
- } else if (op->is_intrinsic(intrinsic::tvm_warp_activemask)) {
+ } else if (op->op.same_as(builtin::tvm_warp_activemask())) {
// Only nvptx target may keep this intrinsic at this point.
// PTX assembly: asm "activemask.b32 r1;"
auto fty = llvm::FunctionType::get(t_int32_, false);
::llvm::Intrinsic::x86_avx512_mask_vcvtph2ps_512, 16,
DTypeToLLVMType(DataType::Float(32, from.lanes())),
{
- MakeValue(tir::Call(DataType::Int(16, from.lanes()), tir::CallNode::reinterpret,
+ MakeValue(tir::Call(DataType::Int(16, from.lanes()), tir::builtin::reinterpret(),
{op->value}, tir::CallNode::PureIntrinsic)),
MakeValue(tir::Broadcast(FloatImm(DataType::Float(32), 0), from.lanes())),
/*mask=*/MakeValue(IntImm(DataType::Int(16), -1)),
return CallVectorIntrin(
::llvm::Intrinsic::x86_vcvtph2ps_256, 8,
DTypeToLLVMType(DataType::Float(32, from.lanes())),
- {MakeValue(tir::Call(DataType::Int(16, from.lanes()), tir::CallNode::reinterpret,
+ {MakeValue(tir::Call(DataType::Int(16, from.lanes()), tir::builtin::reinterpret(),
{op->value}, tir::CallNode::PureIntrinsic))});
}
#endif
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.exp2")
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::exp2, 1>);
+// TODO(tvm-team): migrate the legalization transformations as a separate
+// set of rules in TIR that can be shared across backends.
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.exp10")
.set_body([](const TVMArgs& targs, TVMRetValue* rv) {
using tir::make_const;
CHECK(call != nullptr);
const PrimExpr& x = call->args[0];
PrimExpr ln10 = make_const(x.dtype(), 2.302585093);
- PrimExpr ret = tir::Call(x.dtype(), "exp", {x * ln10}, tir::CallNode::PureIntrinsic);
+ PrimExpr ret = exp(x * ln10);
*rv = ret;
});
PrimExpr two = make_const(x.dtype(), 2);
PrimExpr neg_two = make_const(x.dtype(), -2);
- PrimExpr exp_neg2x = tir::Call(x.dtype(), "exp", {neg_two * x}, tir::CallNode::PureIntrinsic);
- PrimExpr exp_pos2x = tir::Call(x.dtype(), "exp", {two * x}, tir::CallNode::PureIntrinsic);
+ PrimExpr exp_neg2x = exp(neg_two * x);
+ PrimExpr exp_pos2x = exp(two * x);
PrimExpr tanh_pos = (one - exp_neg2x) / (one + exp_neg2x);
PrimExpr tanh_neg = (exp_pos2x - one) / (exp_pos2x + one);
const tir::CallNode* call = e.as<tir::CallNode>();
CHECK(call != nullptr);
const PrimExpr& x = call->args[0];
- PrimExpr sin_x = tir::Call(x.dtype(), "sin", {x}, tir::CallNode::PureIntrinsic);
- PrimExpr cos_x = tir::Call(x.dtype(), "cos", {x}, tir::CallNode::PureIntrinsic);
- PrimExpr tan_x = sin_x / cos_x;
+ PrimExpr tan_x = sin(x) / cos(x);
*rv = tan_x;
});
const PrimExpr& x = call->args[0];
PrimExpr two = make_const(x.dtype(), 2);
PrimExpr neg_one = make_const(x.dtype(), -1);
- PrimExpr exp_negx = tir::Call(x.dtype(), "exp", {neg_one * x}, tir::CallNode::PureIntrinsic);
- PrimExpr exp_posx = tir::Call(x.dtype(), "exp", {x}, tir::CallNode::PureIntrinsic);
+ PrimExpr exp_negx = exp(neg_one * x);
+ PrimExpr exp_posx = exp(x);
PrimExpr ret = (exp_posx + exp_negx) / two;
*rv = ret;
});
const PrimExpr& x = call->args[0];
PrimExpr two = make_const(x.dtype(), 2);
PrimExpr neg_one = make_const(x.dtype(), -1);
- PrimExpr exp_negx = tir::Call(x.dtype(), "exp", {neg_one * x}, tir::CallNode::PureIntrinsic);
- PrimExpr exp_posx = tir::Call(x.dtype(), "exp", {x}, tir::CallNode::PureIntrinsic);
+ PrimExpr exp_negx = exp(neg_one * x);
+ PrimExpr exp_posx = exp(x);
PrimExpr ret = (exp_posx - exp_negx) / two;
*rv = ret;
});
#include <tvm/runtime/registry.h>
#include <tvm/target/codegen.h>
+#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <string>
for (PrimExpr arg : call->args) {
cargs.push_back(arg);
}
- *rv = tir::Call(call->dtype, "llvm_intrin", cargs, tir::CallNode::PureIntrinsic);
+ *rv =
+ tir::Call(call->dtype, tir::builtin::call_llvm_intrin(), cargs, tir::CallNode::PureIntrinsic);
}
template <unsigned id, int num_signature>
for (PrimExpr arg : call->args) {
cargs.push_back(arg);
}
- *rv = tir::Call(call->dtype, "llvm_intrin", cargs, tir::CallNode::Intrinsic);
+ *rv = tir::Call(call->dtype, tir::builtin::call_llvm_intrin(), cargs, tir::CallNode::Intrinsic);
}
} // namespace codegen
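Both dispatch helpers above now funnel into the tir.call_llvm_intrin builtin, whose calling convention (visible in the LLVM codegen hunks earlier) puts the LLVM intrinsic id in args[0] and the number of signature arguments in args[1], followed by the operands. A hedged sketch of constructing such a call directly; the helper name is illustrative, not part of the patch.

#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>

using namespace tvm;

// Sketch only: build a tir.call_llvm_intrin call node.
// args[0] = LLVM intrinsic id, args[1] = number of signature arguments,
// remaining args = the operands passed to the intrinsic.
PrimExpr MakeLLVMIntrinCall(DataType dtype, int64_t intrin_id, int num_signature,
                            const Array<PrimExpr>& operands) {
  Array<PrimExpr> cargs = {IntImm(DataType::UInt(32), intrin_id),
                           IntImm(DataType::UInt(32), num_signature)};
  for (PrimExpr arg : operands) {
    cargs.push_back(arg);
  }
  return tir::Call(dtype, tir::builtin::call_llvm_intrin(), cargs, tir::CallNode::PureIntrinsic);
}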
#ifdef TVM_LLVM_VERSION
#include <tvm/runtime/registry.h>
+#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
+#include <tvm/tir/op.h>
#include <sstream>
const CallNode* call = e.as<CallNode>();
CHECK(call != nullptr);
CHECK(call->dtype.bits() == 32 || call->dtype.bits() == 64) << "Only support float32 or float64.";
+
+ const OpNode* op = call->op.as<OpNode>();
+ CHECK(op != nullptr);
+ std::string name = op->name;
+ CHECK_EQ(name.substr(0, 4), "tir.");
+
std::ostringstream intrinsic_name;
- intrinsic_name << "__nv_" << call->name;
+ intrinsic_name << "__nv_" << name.substr(4);
if (call->dtype.bits() == 32) intrinsic_name << "f";
- *rv = Call(call->dtype, intrinsic_name.str(), call->args, CallNode::PureExtern);
+
+ Array<PrimExpr> new_args = {StringImm(intrinsic_name.str())};
+ for (auto arg : call->args) {
+ new_args.push_back(arg);
+ }
+ *rv = Call(call->dtype, builtin::call_extern(), new_args, CallNode::PureExtern);
}
namespace llvm {
#ifdef TVM_LLVM_VERSION
#include <tvm/runtime/registry.h>
+#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
using namespace tir;
const CallNode* call = e.as<CallNode>();
CHECK(call != nullptr);
+
+ const OpNode* op = call->op.as<OpNode>();
+ CHECK(op != nullptr);
+ std::string name = op->name;
+ CHECK_EQ(name.substr(0, 4), "tir.");
+
std::ostringstream intrinsic_name;
- intrinsic_name << "__ocml_" << call->name << "_f" << call->dtype.bits();
- *rv = Call(call->dtype, intrinsic_name.str(), call->args, CallNode::PureExtern);
+ intrinsic_name << "__ocml_" << name.substr(4) << "_f" << call->dtype.bits();
+
+ Array<PrimExpr> new_args = {StringImm(intrinsic_name.str())};
+ for (auto arg : call->args) {
+ new_args.push_back(arg);
+ }
+
+ *rv = Call(call->dtype, builtin::call_extern(), new_args, CallNode::PureExtern);
}
inline void DispatchShuffle(const TVMArgs& targs, TVMRetValue* rv) {
// get own lane in self (__lane_id)
PrimExpr minus_one = tir::make_const(DataType::Int(32), -1);
PrimExpr zero = tir::make_zero(DataType::Int(32));
- PrimExpr lo =
- Call(DataType::Int(32), "llvm.amdgcn.mbcnt.lo", {minus_one, zero}, CallNode::PureExtern);
- PrimExpr self =
- Call(DataType::Int(32), "llvm.amdgcn.mbcnt.hi", {minus_one, lo}, CallNode::PureExtern);
+ PrimExpr lo = Call(DataType::Int(32), builtin::call_extern(),
+ {StringImm("llvm.amdgcn.mbcnt.lo"), minus_one, zero}, CallNode::PureExtern);
+ PrimExpr self = Call(DataType::Int(32), builtin::call_extern(),
+ {StringImm("llvm.amdgcn.mbcnt.hi"), minus_one, lo}, CallNode::PureExtern);
// compute lane to get from
PrimExpr width = call->args[3];
PrimExpr index;
- if (call->name == "tvm_warp_shuffle") {
+ if (call->op.same_as(builtin::tvm_warp_shuffle())) {
PrimExpr src_lane = call->args[2];
index = src_lane + (self & ~(width - 1));
- } else if (call->name == "tvm_warp_shuffle_up") {
+ } else if (call->op.same_as(builtin::tvm_warp_shuffle_up())) {
PrimExpr delta = call->args[2];
index = self - delta;
index = Select(index < (self & ~(width - 1)), self, index);
} else {
- CHECK_EQ(call->name, "tvm_warp_shuffle_down");
+ CHECK(call->op.same_as(builtin::tvm_warp_shuffle_down()));
PrimExpr delta = call->args[2];
index = self + delta;
index = Select((self & (width - 1)) + delta >= width, self, index);
}
PrimExpr res =
- Call(var.dtype(), "llvm.amdgcn.ds.bpermute", {index << 2, var}, CallNode::PureExtern);
+ Call(var.dtype(), builtin::call_extern(),
+ {StringImm("llvm.amdgcn.ds.bpermute"), index << 2, var}, CallNode::PureExtern);
*rv = res;
}
// Print a reference expression to a buffer.
std::string CodeGenC::GetStructRef(DataType t, const PrimExpr& buffer, const PrimExpr& index,
int kind) {
- if (kind < intrinsic::kArrKindBound_) {
+ if (kind < builtin::kArrKindBound_) {
std::ostringstream os;
os << "(((DLTensor*)";
this->PrintExpr(buffer, os);
os << ")";
- if (kind == intrinsic::kArrAddr) {
+ if (kind == builtin::kArrAddr) {
os << " + ";
this->PrintExpr(index, os);
os << ")";
os << "].";
// other case: get fields.
switch (kind) {
- case intrinsic::kArrData:
+ case builtin::kArrData:
os << "data";
break;
- case intrinsic::kArrShape:
+ case builtin::kArrShape:
os << "shape";
break;
- case intrinsic::kArrStrides:
+ case builtin::kArrStrides:
os << "strides";
break;
- case intrinsic::kArrNDim:
+ case builtin::kArrNDim:
os << "ndim";
break;
- case intrinsic::kArrTypeCode:
+ case builtin::kArrTypeCode:
os << "dtype.code";
break;
- case intrinsic::kArrTypeBits:
+ case builtin::kArrTypeBits:
os << "dtype.bits";
break;
- case intrinsic::kArrByteOffset:
+ case builtin::kArrByteOffset:
os << "byte_offset";
break;
- case intrinsic::kArrTypeLanes:
+ case builtin::kArrTypeLanes:
os << "dtype.lanes";
break;
- case intrinsic::kArrDeviceId:
+ case builtin::kArrDeviceId:
os << "ctx.device_id";
break;
- case intrinsic::kArrDeviceType:
+ case builtin::kArrDeviceType:
os << "ctx.device_type";
break;
default:
os << ')';
return os.str();
} else {
- CHECK_LT(kind, intrinsic::kTVMValueKindBound_);
+ CHECK_LT(kind, builtin::kTVMValueKindBound_);
std::ostringstream os;
os << "(((TVMValue*)";
this->PrintExpr(buffer, os);
PrintExpr(op->a, os);
}
-void CodeGenC::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*)
- if (op->call_type == CallNode::Extern || op->call_type == CallNode::PureExtern) {
- os << op->name << "(";
- for (size_t i = 0; i < op->args.size(); i++) {
- this->PrintExpr(op->args[i], os);
- if (i < op->args.size() - 1) {
- os << ", ";
- }
+void CodeGenC::PrintCallExtern(Type ret_type, String global_symbol, const Array<PrimExpr>& args,
+ bool skip_first_arg, std::ostream& os) { // NOLINT(*)
+ os << global_symbol << "(";
+ for (size_t i = static_cast<size_t>(skip_first_arg); i < args.size(); ++i) {
+ this->PrintExpr(args[i], os);
+ if (i < args.size() - 1) {
+ os << ", ";
}
- os << ")";
- } else if (op->is_intrinsic(CallNode::bitwise_and)) {
- PrintBinaryIntrinsic(op, " & ", os, this);
- } else if (op->is_intrinsic(intrinsic::tvm_large_uint_imm)) {
- CHECK_EQ(op->args.size(), 2U);
- uint64_t low = static_cast<uint64_t>(Downcast<IntImm>(op->args[0])->value);
- uint64_t high = static_cast<uint64_t>(Downcast<IntImm>(op->args[1])->value);
- uint64_t val = (high << 32U) | low;
- PrintUIntConst(op->dtype, val, os, this);
- } else if (op->is_intrinsic(CallNode::bitwise_xor)) {
- PrintBinaryIntrinsic(op, " ^ ", os, this);
- } else if (op->is_intrinsic(CallNode::bitwise_or)) {
- PrintBinaryIntrinsic(op, " | ", os, this);
- } else if (op->is_intrinsic(CallNode::bitwise_not)) {
- CHECK_EQ(op->args.size(), 1U);
- os << "(~";
- this->PrintExpr(op->args[0], os);
- os << ')';
- } else if (op->is_intrinsic(CallNode::shift_left)) {
- PrintBinaryIntrinsic(op, " << ", os, this);
- } else if (op->is_intrinsic(CallNode::shift_right)) {
- PrintBinaryIntrinsic(op, " >> ", os, this);
- } else if (op->is_intrinsic(intrinsic::tvm_if_then_else)) {
- os << "(";
- PrintExpr(op->args[0], os);
- os << " ? ";
- PrintExpr(op->args[1], os);
- os << " : ";
- PrintExpr(op->args[2], os);
- os << ")";
- } else if (op->is_intrinsic(intrinsic::tvm_address_of)) {
- const LoadNode* l = op->args[0].as<LoadNode>();
- CHECK(op->args.size() == 1 && l);
- os << "((";
- this->PrintType(l->dtype.element_of(), os);
- os << " *)" << this->GetVarID(l->buffer_var.get()) << " + ";
- this->PrintExpr(l->index, os);
- os << ')';
- } else if (op->is_intrinsic(intrinsic::tvm_struct_get)) {
- CHECK_EQ(op->args.size(), 3U);
- os << GetStructRef(op->dtype, op->args[0], op->args[1], op->args[2].as<IntImmNode>()->value);
- } else if (op->is_intrinsic(intrinsic::tvm_handle_is_null)) {
- CHECK_EQ(op->args.size(), 1U);
- os << "(";
- this->PrintExpr(op->args[0], os);
- os << " == NULL)";
- } else if (op->is_intrinsic(CallNode::reinterpret)) {
- // generate (*( TYPE *)(&(ARG)))
- os << "(*(";
- this->PrintType(op->dtype, os);
- os << " *)(&(";
- this->PrintExpr(op->args[0], os);
- os << ")))";
- } else if (op->is_intrinsic(CallNode::isnan)) {
- os << "(";
- this->PrintExpr(op->args[0], os);
- os << " != ";
- this->PrintExpr(op->args[0], os);
- os << ")";
- } else {
- if (op->call_type == CallNode::Intrinsic || op->call_type == CallNode::PureIntrinsic) {
- LOG(FATAL) << "Unresolved intrinsic " << op->name << " with return type " << op->dtype;
+ }
+ os << ")";
+}
+
+void CodeGenC::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*)
+ if (auto* ptr_op = op->op.as<OpNode>()) {
+ auto call_op = GetRef<Op>(ptr_op);
+
+ if (op->op.same_as(builtin_call_extern_)) {
+ CHECK_GE(op->args.size(), 1U);
+ auto func = Downcast<StringImm>(op->args[0]);
+ this->PrintCallExtern(GetType(GetRef<PrimExpr>(op)), func->value, op->args, true, os);
+ } else if (op_attr_global_symbol_.count(call_op)) {
+ // call extern if the op itself has a global symbol.
+ this->PrintCallExtern(GetType(GetRef<PrimExpr>(op)), op_attr_global_symbol_[call_op],
+ op->args, false, os);
+ } else if (op->op.same_as(builtin::bitwise_and())) {
+ PrintBinaryIntrinsic(op, " & ", os, this);
+ } else if (op->op.same_as(builtin::large_uint_imm())) {
+ CHECK_EQ(op->args.size(), 2U);
+ uint64_t low = static_cast<uint64_t>(Downcast<IntImm>(op->args[0])->value);
+ uint64_t high = static_cast<uint64_t>(Downcast<IntImm>(op->args[1])->value);
+ uint64_t val = (high << 32U) | low;
+ PrintUIntConst(op->dtype, val, os, this);
+ } else if (op->op.same_as(builtin::bitwise_xor())) {
+ PrintBinaryIntrinsic(op, " ^ ", os, this);
+ } else if (op->op.same_as(builtin::bitwise_or())) {
+ PrintBinaryIntrinsic(op, " | ", os, this);
+ } else if (op->op.same_as(builtin::bitwise_not())) {
+ CHECK_EQ(op->args.size(), 1U);
+ os << "(~";
+ this->PrintExpr(op->args[0], os);
+ os << ')';
+ } else if (op->op.same_as(builtin::shift_left())) {
+ PrintBinaryIntrinsic(op, " << ", os, this);
+ } else if (op->op.same_as(builtin::shift_right())) {
+ PrintBinaryIntrinsic(op, " >> ", os, this);
+ } else if (op->op.same_as(builtin::if_then_else())) {
+ os << "(";
+ PrintExpr(op->args[0], os);
+ os << " ? ";
+ PrintExpr(op->args[1], os);
+ os << " : ";
+ PrintExpr(op->args[2], os);
+ os << ")";
+ } else if (op->op.same_as(builtin::address_of())) {
+ const LoadNode* l = op->args[0].as<LoadNode>();
+ CHECK(op->args.size() == 1 && l);
+ os << "((";
+ this->PrintType(l->dtype.element_of(), os);
+ os << " *)" << this->GetVarID(l->buffer_var.get()) << " + ";
+ this->PrintExpr(l->index, os);
+ os << ')';
+ } else if (op->op.same_as(builtin::tvm_struct_get())) {
+ CHECK_EQ(op->args.size(), 3U);
+ os << GetStructRef(op->dtype, op->args[0], op->args[1], op->args[2].as<IntImmNode>()->value);
+ } else if (op->op.same_as(builtin::isnullptr())) {
+ CHECK_EQ(op->args.size(), 1U);
+ os << "(";
+ this->PrintExpr(op->args[0], os);
+ os << " == NULL)";
+ } else if (op->op.same_as(builtin::reinterpret())) {
+ // generate (*( TYPE *)(&(ARG)))
+ os << "(*(";
+ this->PrintType(op->dtype, os);
+ os << " *)(&(";
+ this->PrintExpr(op->args[0], os);
+ os << ")))";
+ } else if (op->op.same_as(builtin::isnan())) {
+ os << "(";
+ this->PrintExpr(op->args[0], os);
+ os << " != ";
+ this->PrintExpr(op->args[0], os);
+ os << ")";
} else {
- LOG(FATAL) << "Unresolved call type " << op->call_type;
+ LOG(FATAL) << "Unresolved call " << op->op;
}
+ } else {
+ CHECK(op->op.as<GlobalVarNode>());
+ LOG(FATAL) << "Do not yet support cross function call";
}
}
if (is_const(op->value)) return;
const CallNode* call = op->value.as<CallNode>();
if (call) {
- if (call->is_intrinsic(intrinsic::tvm_storage_sync)) {
+ if (call->op.same_as(builtin::tvm_storage_sync())) {
this->PrintStorageSync(call);
return;
- } else if (call->is_intrinsic(intrinsic::tvm_struct_set)) {
+ } else if (call->op.same_as(builtin::tvm_struct_set())) {
CHECK_EQ(call->args.size(), 4);
std::string value = PrintExpr(call->args[3]);
std::string ref = GetStructRef(call->args[3].dtype(), call->args[0], call->args[1],
#ifndef TVM_TARGET_SOURCE_CODEGEN_C_H_
#define TVM_TARGET_SOURCE_CODEGEN_C_H_
+#include <tvm/ir/op.h>
#include <tvm/runtime/container.h>
#include <tvm/target/codegen.h>
+#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/function.h>
+#include <tvm/tir/op_attr_types.h>
#include <tvm/tir/stmt.h>
#include <tvm/tir/stmt_functor.h>
virtual bool IsScopePartOfType() const { return true; }
/*!
+ * \brief Print external function call.
+ * \param ret_type The return type.
+ * \param global_symbol The symbol of the target function.
+ * \param args The arguments to the function.
+ * \param skip_first_arg Whether to skip the first argument.
+ * \param os The output stream.
+ */
+ virtual void PrintCallExtern(Type ret_type, String global_symbol, const Array<PrimExpr>& args,
+ bool skip_first_arg, std::ostream& os); // NOLINT(*)
+ /*!
* \brief If buffer is allocated as type t.
* \param buf_var The buffer variable.
* \param t The type to be checked.
std::unordered_map<const VarNode*, std::string> alloc_storage_scope_;
/*! \brief the data type of allocated buffers */
std::unordered_map<const VarNode*, DataType> handle_data_type_;
+ /*! \brief Record of ops that have a pre-defined global symbol. */
+ OpAttrMap<TGlobalSymbol> op_attr_global_symbol_ = Op::GetAttrMap<TGlobalSymbol>("TGlobalSymbol");
+ // cache commonly used ops
+ const Op& builtin_call_extern_ = builtin::call_extern();
private:
/*! \brief whether to print in SSA form */
}
void CodeGenCHost::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*)
- if (op->is_intrinsic(intrinsic::tvm_stack_alloca)) {
+ if (op->op.same_as(builtin::tvm_stack_alloca())) {
std::string stack_name = GetUniqueName("stack");
const std::string& type = op->args[0].as<StringImmNode>()->value;
const IntImmNode* num = op->args[1].as<IntImmNode>();
this->PrintIndent();
this->stream << "TVMValue " << stack_name << "[" << size << "];\n";
os << stack_name;
- } else if (op->is_intrinsic(intrinsic::tvm_call_packed_lowered)) {
+ } else if (op->op.same_as(builtin::tvm_call_packed_lowered())) {
const StringImmNode* s = op->args[0].as<StringImmNode>();
CHECK(s != nullptr) << "tvm_call_packed_lowered expects first argument as function name";
int64_t begin = op->args[3].as<IntImmNode>()->value;
}
this->PrintGetFuncFromBackend(func_name, packed_func_name);
this->PrintFuncCall(packed_func_name, num_args);
- } else if (op->is_intrinsic(intrinsic::tvm_throw_last_error)) {
+ } else if (op->op.same_as(builtin::tvm_throw_last_error())) {
this->PrintIndent();
this->stream << "return -1;\n";
} else {
os << sret;
}
+void CodeGenCUDA::PrintCallExtern(Type ret_type, String global_symbol, const Array<PrimExpr>& args,
+ bool skip_first_arg, std::ostream& os) { // NOLINT(*)
+ DataType ret_dtype = GetRuntimeDataType(ret_type);
+ if (ret_dtype.is_vector()) {
+ //
+ // Emit an unsupported vector call
+ //
+ // v = intrin_f((float4*)A[0], (float4*)B[0])
+ //
+ // as
+ //
+ // float4 __ret;
+ // {
+ // float4 __arg0 = ((float4*)A)[0];
+ // float4 __arg1 = ((float4*)B)[0];
+ // __ret.x = intrin_f(__arg0.x, __arg1.x);
+ // __ret.y = intrin_f(__arg0.y, __arg1.y);
+ // __ret.z = intrin_f(__arg0.z, __arg1.z);
+ // __ret.w = intrin_f(__arg0.w, __arg1.w);
+ // }
+ // v = __ret;
+ //
+ // Declare the result vector.
+ std::string sret = GetUniqueName("_");
+ this->PrintIndent();
+ this->PrintType(ret_dtype, stream);
+ stream << ' ' << sret << ";\n";
+ {
+ // Load arguments.
+ std::vector<std::string> sargs;
+ size_t arg_begin = static_cast<size_t>(skip_first_arg);
+ for (size_t i = arg_begin; i < args.size(); ++i) {
+ std::string val = SSAGetID(PrintExpr(args[i]), args[i].dtype());
+ sargs.push_back(std::move(val));
+ }
+
+ // Emit a scalar call for each lane.
+ for (int i = 0; i < ret_dtype.lanes(); ++i) {
+ std::ostringstream scall;
+ scall << global_symbol << "(";
+ for (size_t j = 0; j < sargs.size(); ++j) {
+ if (j > 0) scall << ", ";
+ PrintVecElemLoad(sargs[j], args[arg_begin + j].dtype(), i, scall);
+ }
+ scall << ")";
+ PrintVecElemStore(sret, ret_dtype, i, scall.str());
+ }
+ }
+ os << sret;
+ } else {
+ CodeGenC::PrintCallExtern(ret_type, global_symbol, args, skip_first_arg, os);
+ }
+}
+
void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
- // This is only for backward compatibility with __shfl_{up/down}.
- // A macro will be used to replace *_sync calls to legacy ones.
- if (op->is_intrinsic("__shfl_sync") || op->is_intrinsic("__shfl_up_sync") ||
- op->is_intrinsic("__shfl_down_sync")) {
- enable_warp_shuffle_ = true;
+ if (auto* ptr_op = op->op.as<OpNode>()) {
+ Op call_op = GetRef<Op>(ptr_op);
+ // This is only for backward compatibility with __shfl_{up/down}.
+ // A macro will be used to replace *_sync calls with legacy ones.
+ if (op_need_warp_shuffle_.get(call_op, false)) {
+ enable_warp_shuffle_ = true;
+ }
}
- if (op->is_intrinsic(intrinsic::tvm_fill_fragment)) {
+ if (op->op.same_as(builtin::tvm_fill_fragment())) {
need_mma_h_ = true;
CHECK_EQ(op->args.size(), 6U);
os << "nvcuda::wmma::fill_fragment(";
os << "], ";
this->PrintExpr(op->args[5], os);
os << ")";
- } else if (op->is_intrinsic(intrinsic::tvm_load_matrix_sync)) {
+ } else if (op->op.same_as(builtin::tvm_load_matrix_sync())) {
need_mma_h_ = true;
CHECK_EQ(op->args.size(), 8U);
os << "nvcuda::wmma::load_matrix_sync(";
os << ", ";
this->PrintExpr(op->args[6], os);
os << ")";
- } else if (op->is_intrinsic(intrinsic::tvm_store_matrix_sync)) {
+ } else if (op->op.same_as(builtin::tvm_store_matrix_sync())) {
need_mma_h_ = true;
CHECK_EQ(op->args.size(), 8U);
os << "nvcuda::wmma::store_matrix_sync(";
LOG(FATAL) << "Invalid parameters";
}
os << ")";
- } else if (op->is_intrinsic(intrinsic::tvm_mma_sync)) {
+ } else if (op->op.same_as(builtin::tvm_mma_sync())) {
need_mma_h_ = true;
CHECK_EQ(op->args.size(), 8U);
os << "nvcuda::wmma::mma_sync(";
this->PrintExpr(op->args[i * 2 + 1], os);
os << "]" << ((i < 3) ? ", " : ")");
}
- } else if (op->is_intrinsic(intrinsic::tvm_bmma_sync)) {
+ } else if (op->op.same_as(builtin::tvm_bmma_sync())) {
need_mma_h_ = true;
CHECK_EQ(op->args.size(), 8U);
os << "nvcuda::wmma::bmma_sync(";
this->PrintExpr(op->args[i * 2 + 1], os);
os << "]" << ((i < 3) ? ", " : ")");
}
- } else if (op->call_type == CallNode::PureExtern && op->dtype.is_vector()) {
- //
- // Emit an unsupported vector call
- //
- // v = intrin_f((float4*)A[0], (float4*)B[0])
- //
- // as
- //
- // float4 __ret;
- // {
- // float4 __arg0 = ((float4*)A)[0];
- // float4 __arg1 = ((float4*)B)[0];
- // __ret.x = intrin_f(__arg0.x, __arg1.x);
- // __ret.y = intrin_f(__arg0.y, __arg1.y);
- // __ret.z = intrin_f(__arg0.z, __arg1.z);
- // __ret.w = intrin_f(__arg0.w, __arg1.w);
- // }
- // v = __ret;
- //
- // Declare the result vector.
- std::string sret = GetUniqueName("_");
- this->PrintIndent();
- this->PrintType(op->dtype, stream);
- stream << ' ' << sret << ";\n";
- {
- // Load arguments.
- std::vector<std::string> sargs;
- for (size_t i = 0; i < op->args.size(); ++i) {
- std::string val = SSAGetID(PrintExpr(op->args[i]), op->args[i].dtype());
- sargs.push_back(std::move(val));
- }
-
- // Emit a scalar call for each lane.
- for (int i = 0; i < op->dtype.lanes(); ++i) {
- std::ostringstream scall;
- scall << op->name << "(";
- for (size_t j = 0; j < op->args.size(); ++j) {
- if (j > 0) scall << ", ";
- PrintVecElemLoad(sargs[j], op->args[j].dtype(), i, scall);
- }
- scall << ")";
- PrintVecElemStore(sret, op->dtype, i, scall.str());
- }
- }
- os << sret;
} else {
CodeGenC::VisitExpr_(op, os);
}
void CodeGenCUDA::VisitStmt_(const EvaluateNode* op) {
if (is_const(op->value)) return;
const CallNode* call = op->value.as<CallNode>();
- if (call && call->is_intrinsic(intrinsic::tvm_global_barrier_kinit)) {
+ if (call && call->op.same_as(builtin::tvm_global_barrier_kinit())) {
PrintIndent();
stream << "__shared__ unsigned " << vid_global_barrier_expect_ << ";\n";
PrintIndent();
#include <tvm/target/codegen.h>
#include <tvm/tir/expr.h>
+#include <tvm/tir/op.h>
#include <string>
#include <unordered_map>
void VisitStmt_(const AllocateNode* op) final;
void VisitStmt_(const AttrStmtNode* op) final;
+ protected:
+ void PrintCallExtern(Type ret_type, String global_symbol, const Array<PrimExpr>& args,
+ bool skip_first_arg, std::ostream& os) final; // NOLINT(*)
+
private:
// Handle volatile loads
void HandleVolatileLoads(const std::string& value, const LoadNode* op, std::ostream& os) final;
bool need_math_constants_h_{false};
// whether need mma.h
bool need_mma_h_{false};
+ // Op attribute map
+ OpAttrMap<bool> op_need_warp_shuffle_ = Op::GetAttrMap<bool>("cuda.need_warp_shuffle");
std::unordered_map<const VarNode*, std::string> fragment_shapes;
std::unordered_map<const VarNode*, std::string> fragment_layouts;
}
void CodeGenMetal::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*)
- if (op->is_intrinsic(CallNode::reinterpret)) {
+ if (op->op.same_as(builtin::reinterpret())) {
// generate as_type<TYPE>(ARG)
os << "(as_type<";
this->PrintType(op->dtype, os);
* \file intrin_rule_cuda.cc
* \brief CUDA intrinsic rules.
*/
+#include <tvm/tir/builtin.h>
+#include <tvm/tir/op_attr_types.h>
+
#include "../intrin_rule.h"
namespace tvm {
};
struct CUDAWarpIntrinsic {
- const char* operator()(DataType t, const std::string& name) const {
- if (name == intrinsic::tvm_warp_shuffle) {
- return "__shfl_sync";
- }
- if (name == intrinsic::tvm_warp_shuffle_up) {
- return "__shfl_up_sync";
- }
- if (name == intrinsic::tvm_warp_shuffle_down) {
- return "__shfl_down_sync";
- }
- if (name == intrinsic::tvm_warp_activemask) {
- return "__activemask";
+ const Op operator()(DataType t, const Op& orig_op) const {
+ if (orig_op.same_as(builtin::tvm_warp_shuffle())) {
+ return Op::Get("tir.cuda.__shfl_sync");
+ } else if (orig_op.same_as(builtin::tvm_warp_shuffle_up())) {
+ return Op::Get("tir.cuda.__shfl_up_sync");
+ } else {
+ CHECK(orig_op.same_as(builtin::tvm_warp_shuffle_down()));
+ return Op::Get("tir.cuda.__shfl_down_sync");
}
- return "";
}
};
+static void DispatchCUDAWarpActiveMask(const TVMArgs& args, TVMRetValue* rv) {
+ Call call = args[0];
+ *rv = Call(call->dtype, Op::Get("tir.cuda.__activemask"), call->args, CallNode::PureExtern);
+}
+
template <typename T>
static void DispatchCUDAShuffle(const TVMArgs& args, TVMRetValue* rv) {
PrimExpr e = args[0];
CHECK(call != nullptr);
CHECK_EQ(call->args.size(), 5); // mask, value, warp_id, width, warp_size
Array<PrimExpr> cuda_args{{call->args[0], call->args[1], call->args[2], call->args[3]}};
- const char* name = T()(call->dtype, call->name);
- *rv = Call(call->dtype, name, cuda_args, CallNode::PureExtern);
+
+ *rv =
+ Call(call->dtype, T()(call->dtype, Downcast<Op>(call->op)), cuda_args, CallNode::PureExtern);
}
TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.floor").set_body(DispatchExtern<CUDAMath>);
.set_body(DispatchCUDAShuffle<CUDAWarpIntrinsic>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tvm_warp_activemask")
- .set_body(DispatchExtern<CUDAWarpIntrinsic>);
+ .set_body(DispatchCUDAWarpActiveMask);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.fmod").set_body(DispatchExtern<CUDAMath>);
+// Register low-level builtin ops.
+// TODO(tvm-team): consider making CUDA its own subfolder and creating a file for low-level builtins.
+TVM_REGISTER_OP("tir.cuda.__shfl_sync")
+ .set_num_inputs(4)
+ .set_attr<TGlobalSymbol>("TGlobalSymbol", "__shfl_sync")
+ .set_attr<bool>("cuda.need_warp_shuffle", true);
+
+TVM_REGISTER_OP("tir.cuda.__shfl_up_sync")
+ .set_num_inputs(4)
+ .set_attr<TGlobalSymbol>("TGlobalSymbol", "__shfl_up_sync")
+ .set_attr<bool>("cuda.need_warp_shuffle", true);
+
+TVM_REGISTER_OP("tir.cuda.__shfl_down_sync")
+ .set_num_inputs(4)
+ .set_attr<TGlobalSymbol>("TGlobalSymbol", "__shfl_down_sync")
+ .set_attr<bool>("cuda.need_warp_shuffle", true);
+
+TVM_REGISTER_OP("tir.cuda.__activemask")
+ .set_num_inputs(0)
+ .set_attr<TGlobalSymbol>("TGlobalSymbol", "__activemask")
+ .set_attr<bool>("cuda.need_warp_shuffle", true);
+
} // namespace intrin
} // namespace codegen
} // namespace tvm
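With these registrations, lowering a warp shuffle becomes an op swap: the generic tir.tvm_warp_shuffle call is rewritten to the CUDA-specific op, and the TGlobalSymbol attribute later makes the C codegen print it as a plain __shfl_sync(...) call. A sketch, not part of the patch, of building such a call; the helper name is hypothetical and the argument order mirrors the mask, value, warp_id, width layout noted in the dispatch code above.

#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>

using namespace tvm;
using namespace tvm::tir;

// Sketch only: a call to the registered tir.cuda.__shfl_sync op.
// Codegen resolves the printed symbol through the TGlobalSymbol attribute.
PrimExpr MakeShflSync(PrimExpr mask, PrimExpr value, PrimExpr warp_id, PrimExpr width) {
  static const Op& shfl_sync = Op::Get("tir.cuda.__shfl_sync");
  return Call(value.dtype(), shfl_sync, {mask, value, warp_id, width}, CallNode::PureExtern);
}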
arith::Analyzer analyzer;
CHECK(analyzer.CanProve(call->args[3] == call->args[4]))
<< "Intel warp shuffle dose not support width != warp_size";
- Array<PrimExpr> opencl_args{{call->args[1], call->args[2]}};
- *rv = Call(call->dtype, "intel_sub_group_shuffle", opencl_args, CallNode::PureExtern);
+ Array<PrimExpr> opencl_args{{StringImm("intel_sub_group_shuffle"), call->args[1], call->args[2]}};
+ *rv = Call(call->dtype, builtin::call_extern(), opencl_args, CallNode::PureExtern);
}
TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.tvm_warp_shuffle").set_body(DispatchIntelShuffle);
#include "codegen_spirv.h"
#include <tvm/runtime/container.h>
+#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
}
spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) {
- if (op->is_intrinsic("spirv_glsl450")) {
+ if (op->op.same_as(builtin::call_spirv_glsl450())) {
CHECK_GE(op->args.size(), 2U);
uint32_t inst_id = static_cast<uint32_t>(op->args[0].as<IntImmNode>()->value);
std::vector<spirv::Value> values;
values.push_back(MakeValue(op->args[i]));
}
return builder_->CallGLSL450(builder_->GetSType(op->dtype), inst_id, values);
- } else if (op->is_intrinsic(CallNode::bitwise_and)) {
+ } else if (op->op.same_as(builtin::bitwise_and())) {
CHECK_EQ(op->args.size(), 2U);
spirv::Value a = MakeValue(op->args[0]);
spirv::Value b = MakeValue(op->args[1]);
return builder_->MakeValue(spv::OpBitwiseAnd, a.stype, a, b);
- } else if (op->is_intrinsic(CallNode::bitwise_xor)) {
+ } else if (op->op.same_as(builtin::bitwise_xor())) {
CHECK_EQ(op->args.size(), 2U);
spirv::Value a = MakeValue(op->args[0]);
spirv::Value b = MakeValue(op->args[1]);
return builder_->MakeValue(spv::OpBitwiseXor, a.stype, a, b);
- } else if (op->is_intrinsic(CallNode::bitwise_or)) {
+ } else if (op->op.same_as(builtin::bitwise_or())) {
CHECK_EQ(op->args.size(), 2U);
spirv::Value a = MakeValue(op->args[0]);
spirv::Value b = MakeValue(op->args[1]);
return builder_->MakeValue(spv::OpBitwiseOr, a.stype, a, b);
- } else if (op->is_intrinsic(CallNode::bitwise_not)) {
+ } else if (op->op.same_as(builtin::bitwise_not())) {
CHECK_EQ(op->args.size(), 1U);
spirv::Value a = MakeValue(op->args[0]);
return builder_->MakeValue(spv::OpNot, a.stype, a);
- } else if (op->is_intrinsic(CallNode::shift_left)) {
+ } else if (op->op.same_as(builtin::shift_left())) {
CHECK_EQ(op->args.size(), 2U);
spirv::Value a = MakeValue(op->args[0]);
spirv::Value b = MakeValue(op->args[1]);
return builder_->MakeValue(spv::OpShiftLeftLogical, a.stype, a, b);
- } else if (op->is_intrinsic(CallNode::shift_right)) {
+ } else if (op->op.same_as(builtin::shift_right())) {
CHECK_EQ(op->args.size(), 2U);
spirv::Value a = MakeValue(op->args[0]);
spirv::Value b = MakeValue(op->args[1]);
} else {
return builder_->MakeValue(spv::OpShiftRightLogical, a.stype, a, b);
}
- } else if (op->is_intrinsic(CallNode::reinterpret)) {
+ } else if (op->op.same_as(builtin::reinterpret())) {
return builder_->MakeValue(spv::OpBitcast, builder_->GetSType(op->dtype),
MakeValue(op->args[0]));
- } else if (op->is_intrinsic(intrinsic::tvm_large_uint_imm)) {
+ } else if (op->op.same_as(builtin::large_uint_imm())) {
CHECK_EQ(op->args.size(), 2U);
uint64_t low = static_cast<uint64_t>(Downcast<IntImm>(op->args[0])->value);
uint64_t high = static_cast<uint64_t>(Downcast<IntImm>(op->args[1])->value);
uint64_t val = (high << 32U) | low;
return builder_->UIntImm(builder_->GetSType(op->dtype), val);
- } else if (op->is_intrinsic(intrinsic::tvm_storage_sync)) {
+ } else if (op->op.same_as(builtin::tvm_storage_sync())) {
return this->CreateStorageSync(op);
- } else if (op->is_intrinsic(intrinsic::tvm_if_then_else)) {
+ } else if (op->op.same_as(builtin::if_then_else())) {
CHECK_EQ(op->args.size(), 3U);
spirv::Value cond = MakeValue(op->args[0]);
spirv::Label then_label = builder_->NewLabel();
phi.SetIncoming(0, then_value, then_value_label);
phi.SetIncoming(1, else_value, else_value_label);
return phi;
- } else if (op->is_intrinsic("popcount")) {
+ } else if (op->op.same_as(builtin::popcount())) {
return builder_->MakeValue(spv::OpBitCount, builder_->GetSType(op->dtype),
MakeValue(op->args[0]));
} else {
if (op->call_type == CallNode::Intrinsic || op->call_type == CallNode::PureIntrinsic) {
- LOG(FATAL) << "Unresolved intrinsic " << op->name << " with return type " << op->dtype;
+ LOG(FATAL) << "Unresolved intrinsic " << op->op << " with return type " << op->dtype;
} else if (op->call_type == CallNode::Extern || op->call_type == CallNode::PureExtern) {
- LOG(FATAL) << "Unresolved extern " << op->name << " with return type " << op->dtype;
+ LOG(FATAL) << "Unresolved extern " << op->op << " with return type " << op->dtype;
} else {
LOG(FATAL) << "Unresolved call type " << op->call_type;
}
*/
#include <GLSL.std.450.h>
#include <tvm/runtime/registry.h>
+#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
namespace tvm {
for (PrimExpr arg : call->args) {
cargs.push_back(arg);
}
- *rv = tir::Call(call->dtype, "spirv_glsl450", cargs, tir::CallNode::PureIntrinsic);
+ *rv = tir::Call(call->dtype, tir::builtin::call_spirv_glsl450(), cargs,
+ tir::CallNode::PureIntrinsic);
}
TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.floor")
#include <tvm/ir/module.h>
#include <tvm/runtime/container.h>
#include <tvm/runtime/registry.h>
+#include <tvm/tir/builtin.h>
#include <tvm/tir/function.h>
#include <tvm/tir/op.h>
// map struct field kind to runtime variants
// We keep two separate enums to ensure runtime/compiler isolation.
StackVM::StructFieldKind MapFieldKind(int64_t kind) {
- auto val = static_cast<intrinsic::TVMStructFieldKind>(kind);
+ auto val = static_cast<builtin::TVMStructFieldKind>(kind);
switch (val) {
- case intrinsic::kArrData:
+ case builtin::kArrData:
return StackVM::kArrData;
- case intrinsic::kArrShape:
+ case builtin::kArrShape:
return StackVM::kArrShape;
- case intrinsic::kArrAddr:
+ case builtin::kArrAddr:
return StackVM::kArrAddr;
- case intrinsic::kArrStrides:
+ case builtin::kArrStrides:
return StackVM::kArrStrides;
- case intrinsic::kArrNDim:
+ case builtin::kArrNDim:
return StackVM::kArrNDim;
- case intrinsic::kArrTypeCode:
+ case builtin::kArrTypeCode:
return StackVM::kArrTypeCode;
- case intrinsic::kArrTypeBits:
+ case builtin::kArrTypeBits:
return StackVM::kArrTypeBits;
- case intrinsic::kArrTypeLanes:
+ case builtin::kArrTypeLanes:
return StackVM::kArrTypeLanes;
- case intrinsic::kArrByteOffset:
+ case builtin::kArrByteOffset:
return StackVM::kArrByteOffset;
- case intrinsic::kArrDeviceId:
+ case builtin::kArrDeviceId:
return StackVM::kArrDeviceId;
- case intrinsic::kArrDeviceType:
+ case builtin::kArrDeviceType:
return StackVM::kArrDeviceType;
- case intrinsic::kTVMValueContent:
+ case builtin::kTVMValueContent:
return StackVM::kTVMValueContent;
default:
LOG(FATAL) << "Do not know how to map field " << kind;
}
void CodeGenStackVM::VisitExpr_(const CallNode* op) {
- if (op->is_intrinsic(intrinsic::tvm_address_of)) {
+ if (op->op.same_as(builtin::address_of())) {
const LoadNode* l = op->args[0].as<LoadNode>();
CHECK(op->args.size() == 1 && l);
this->PushOp(StackVM::LOAD_HEAP, GetVarID(l->buffer_var.get()));
this->PushOp(StackVM::PUSH_I64, l->dtype.element_of().bytes());
this->PushOp(StackVM::MUL_I64);
this->PushOp(StackVM::ADDR_ADD);
- } else if (op->is_intrinsic(CallNode::reinterpret)) {
+ } else if (op->op.same_as(builtin::reinterpret())) {
this->Push(op->args[0]);
- } else if (op->is_intrinsic(intrinsic::tvm_struct_get)) {
+ } else if (op->op.same_as(builtin::tvm_struct_get())) {
CHECK_EQ(op->args.size(), 3U);
int kind = op->args[2].as<IntImmNode>()->value;
this->Push(op->args[0]);
vm_.code.push_back(code);
code.v_int = MapFieldKind(kind);
vm_.code.push_back(code);
- } else if (op->is_intrinsic(intrinsic::tvm_call_packed_lowered)) {
+ } else if (op->op.same_as(builtin::tvm_call_packed_lowered())) {
CHECK_GE(op->args.size(), 5U);
const StringImmNode* s = op->args[0].as<StringImmNode>();
CHECK(s != nullptr) << "tvm_call_global expects first argument as function name";
vm_.code.push_back(code);
code.v_int = end;
vm_.code.push_back(code);
- } else if (op->is_intrinsic(intrinsic::tvm_stack_alloca)) {
+ } else if (op->op.same_as(builtin::tvm_stack_alloca())) {
CHECK_EQ(op->args.size(), 2U);
const std::string& type = op->args[0].as<StringImmNode>()->value;
const IntImmNode* num = op->args[1].as<IntImmNode>();
// add stack size to be safe.
vm_.stack_size += size;
this->PushOp(StackVM::TVM_STACK_ALLOCA_BY_8BYTE, static_cast<int>(size));
- } else if (op->name == "TVMBackendAllocWorkspace") {
+ } else if (op->op.same_as(backend_alloc_workspace_op_)) {
CHECK_EQ(op->args.size(), 5U);
this->Push(op->args[0]);
this->Push(op->args[1]);
this->Push(op->args[3]);
this->Push(op->args[4]);
this->PushOp(StackVM::TVM_DEVICE_ALLOCA);
- } else if (op->name == "TVMBackendFreeWorkspace") {
+ } else if (op->op.same_as(backend_free_workspace_op_)) {
CHECK_EQ(op->args.size(), 3U);
this->Push(op->args[0]);
this->Push(op->args[1]);
this->Push(op->args[2]);
this->PushOp(StackVM::TVM_DEVICE_FREE);
- } else if (op->is_intrinsic(intrinsic::tvm_throw_last_error)) {
+ } else if (op->op.same_as(builtin::tvm_throw_last_error())) {
this->PushOp(StackVM::TVM_THROW_LAST_ERROR);
- } else if (op->is_intrinsic(intrinsic::tvm_handle_is_null)) {
+ } else if (op->op.same_as(builtin::isnullptr())) {
CHECK_EQ(op->args.size(), 1U);
this->Push(op->args[0]);
this->PushOp(StackVM::PUSH_I64, 0);
this->PushOp(StackVM::EQ_HANDLE);
} else {
- LOG(FATAL) << "unknown function call " << op->name;
+ LOG(FATAL) << "unknown function call " << op->op;
}
}
void CodeGenStackVM::VisitStmt_(const EvaluateNode* ev) {
if (is_const(ev->value)) return;
const CallNode* op = ev->value.as<CallNode>();
- if (op && op->is_intrinsic(intrinsic::tvm_struct_set)) {
+ if (op && op->op.same_as(builtin::tvm_struct_set())) {
CHECK_EQ(op->args.size(), 4U);
this->Push(op->args[0]);
this->Push(op->args[3]);
#include <tvm/target/codegen.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/function.h>
+#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <string>
std::unordered_map<std::string, int> str_idmap_;
/*! \brief id of each global function */
std::unordered_map<std::string, int> extern_fun_idmap_;
+
+ Op backend_alloc_workspace_op_ = Op::Get("tir.TVMBackendAllocWorkspace");
+ Op backend_free_workspace_op_ = Op::Get("tir.TVMBackendFreeWorkspace");
};
} // namespace codegen
PrimExpr VisitExpr_(const CallNode* op) {
PrimExpr expr = GetRef<PrimExpr>(op);
if (op->call_type == CallNode::CallType::PureIntrinsic) {
- static std::unordered_set<std::string> piecewise_const = {"floor", "ceil", "trunc", "round"};
- if (op->name == "exp") {
+ if (op->op.same_as(op_exp_)) {
return Mul(Mutate(op->args[0]), expr);
- } else if (op->name == "log") {
+ } else if (op->op.same_as(op_log_)) {
return Div(Mutate(op->args[0]), op->args[0]);
- } else if (op->name == "sigmoid") {
+ } else if (op->op.same_as(op_sigmoid_)) {
return Mul(Mutate(op->args[0]), Mul(expr, Sub(FloatImm(expr.dtype(), 1.0), expr)));
- } else if (op->name == "sqrt") {
+ } else if (op->op.same_as(op_sqrt_)) {
return Div(Mutate(op->args[0]), Mul(expr, FloatImm(expr.dtype(), 2.0)));
- } else if (op->name == "tanh") {
+ } else if (op->op.same_as(op_tanh_)) {
return Mul(Mutate(op->args[0]), Sub(FloatImm(expr.dtype(), 1.0), Mul(expr, expr)));
- } else if (op->name == "pow") {
+ } else if (op->op.same_as(op_pow_)) {
auto x = op->args[0], y = op->args[1];
return expr * (Mutate(y) * log(x) + Mutate(x) * y / x);
- } else if (op->name == "fabs") {
+ } else if (op->op.same_as(op_fabs_)) {
auto type = op->args[0].dtype();
return Mul(Mutate(op->args[0]), Select(GE(op->args[0], make_zero(type)),
FloatImm(type, 1.0), FloatImm(type, -1.0)));
- } else if (op->name == intrinsic::tvm_if_then_else) {
+ } else if (op->op.same_as(op_if_then_else_)) {
Array<PrimExpr> new_args = {op->args[0], Mutate(op->args[1]), Mutate(op->args[2])};
- return Call(op->dtype, op->name, new_args, op->call_type);
- } else if (piecewise_const.count(op->name)) {
+ return Call(op->dtype, op->op, new_args, op->call_type);
+ } else if (piecewise_const.count(op->op)) {
return FloatImm(expr.dtype(), 0.0);
} else {
- throw dmlc::Error("Derivative of this intrinsic is not implemented: " + op->name);
+ LOG(FATAL) << "Derivative of this intrinsic is not implemented: " << op->op;
}
}
NOT_IMPLEMENTED;
Array<PrimExpr> indices_;
Var input_var_;
arith::Analyzer analyzer_;
+
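+  // Op handles for the intrinsics handled by the derivative rules above,
+  // cached once so VisitExpr_ can match them with same_as().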
+ const Op& op_exp_ = Op::Get("tir.exp");
+ const Op& op_log_ = Op::Get("tir.log");
+ const Op& op_sigmoid_ = Op::Get("tir.sigmoid");
+ const Op& op_sqrt_ = Op::Get("tir.sqrt");
+ const Op& op_tanh_ = Op::Get("tir.tanh");
+ const Op& op_pow_ = Op::Get("tir.pow");
+ const Op& op_fabs_ = Op::Get("tir.fabs");
+ const Op& op_if_then_else_ = Op::Get("tir.if_then_else");
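+  // Intrinsics that are piecewise constant, so their derivative is zero
+  // almost everywhere.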
+ std::unordered_set<RelayExpr, ObjectPtrHash, ObjectPtrEqual> piecewise_const = {
+ Op::Get("tir.floor"), Op::Get("tir.ceil"), Op::Get("tir.trunc"), Op::Get("tir.round")};
};
PrimExpr Derivative(const PrimExpr& expr, const Var& var) {
#include <tvm/runtime/registry.h>
#include <tvm/te/operation.h>
#include <tvm/tir/analysis.h>
+#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
attr->dim_align_offset};
realize = tir::AttrStmt(
t, tir::attr::buffer_dim_align,
- Call(DataType::Handle(), tir::intrinsic::tvm_tuple, tuple, CallNode::Intrinsic),
+ Call(DataType::Handle(), tir::builtin::tvm_tuple(), tuple, CallNode::Intrinsic),
realize);
}
}
 * \brief Logic related to cross thread reduction, used by ComputeOpNode.
* \file cross_thread_reduction.cc
*/
+#include <tvm/tir/builtin.h>
+
#include "compute_op.h"
#include "op_util.h"
// Apply the existing input predicate if any.
output_preds.push_back(input_pred);
- Stmt reduce_body = Evaluate(Call(DataType::Handle(), tir::intrinsic::tvm_thread_allreduce,
+ Stmt reduce_body = Evaluate(Call(DataType::Handle(), tir::builtin::tvm_thread_allreduce(),
freduce_args, CallNode::Intrinsic));
reduce_body = AttrStmt(reduces[0]->combiner, tir::attr::reduce_scope,
make_zero(DataType::Handle()), reduce_body);
tuple.push_back(buffer->shape[k]);
}
ret = AttrStmt(bind_spec, tir::attr::buffer_bind_scope,
- Call(DataType::Handle(), intrinsic::tvm_tuple, tuple, CallNode::Intrinsic), ret);
+ Call(DataType::Handle(), builtin::tvm_tuple(), tuple, CallNode::Intrinsic), ret);
};
for (size_t i = output_placeholders.size(); i != 0; --i) {
f_push_bind(output_placeholders[i - 1], stage->op.output(i - 1));
#include <tvm/arith/analyzer.h>
#include <tvm/runtime/registry.h>
#include <tvm/te/operation.h>
+#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
}
input_bind_nest.emplace_back(AttrStmt(
bind_spec, tir::attr::buffer_bind_scope,
- Call(DataType::Handle(), tir::intrinsic::tvm_tuple, tuple, CallNode::Intrinsic), nop));
+ Call(DataType::Handle(), tir::builtin::tvm_tuple(), tuple, CallNode::Intrinsic), nop));
}
// output binding
output_bind_nest.emplace_back(AttrStmt(
bind_spec, tir::attr::buffer_bind_scope,
- Call(DataType::Handle(), tir::intrinsic::tvm_tuple, tuple, CallNode::Intrinsic), nop));
+ Call(DataType::Handle(), tir::builtin::tvm_tuple(), tuple, CallNode::Intrinsic), nop));
}
// Check variable remap
}
input_bind_nest.emplace_back(AttrStmt(
bind_spec, tir::attr::buffer_bind_scope,
- Call(DataType::Handle(), tir::intrinsic::tvm_tuple, tuple, CallNode::Intrinsic), nop));
+ Call(DataType::Handle(), tir::builtin::tvm_tuple(), tuple, CallNode::Intrinsic), nop));
}
// output binding
const ComputeOpNode* intrin_compute = intrin->op.as<ComputeOpNode>();
Array<ObjectRef> bind_spec{buffer, tensor};
output_bind_nest.emplace_back(AttrStmt(
bind_spec, tir::attr::buffer_bind_scope,
- Call(DataType::Handle(), tir::intrinsic::tvm_tuple, tuple, CallNode::Intrinsic), nop));
+ Call(DataType::Handle(), tir::builtin::tvm_tuple(), tuple, CallNode::Intrinsic), nop));
}
// Check variable remap
std::unordered_map<const VarNode*, PrimExpr> vmap;
#include <tvm/target/target_info.h>
#include <tvm/te/operation.h>
#include <tvm/tir/buffer.h>
+#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt.h>
namespace te {
using namespace te;
-using intrinsic::tvm_address_of;
using runtime::StorageRank;
using runtime::StorageScope;
using runtime::ThreadScope;
}
}
- void VisitExpr_(const CallNode* op) final {
+ void VisitExpr_(const ProducerLoadNode* op) final {
StmtExprVisitor::VisitExpr_(op);
- args_.insert(std::make_pair(op->name, op->args));
+ args_.insert(std::make_pair(op->producer->GetNameHint(), op->indices));
}
friend class ScheduleAnalyser;
} else if (op->attr_key == tir::attr::buffer_dim_align) {
te::Tensor tensor = Downcast<te::Tensor>(op->node);
const CallNode* tuple = op->value.as<CallNode>();
- CHECK(tuple && tuple->is_intrinsic(intrinsic::tvm_tuple));
+ CHECK(tuple && tuple->op.same_as(builtin::tvm_tuple()));
auto& vinfo = dim_align_[tensor];
size_t dim = tuple->args[0].as<IntImmNode>()->value;
if (dim >= vinfo.size()) {
Buffer buffer_b(buffer_node_b);
if (ca->dtype == DataType::Int(1) && cb->dtype == DataType::Int(1)) {
return Evaluate(
- Call(DataType::Handle(), intrinsic::tvm_bmma_sync,
+ Call(DataType::Handle(), builtin::tvm_bmma_sync(),
{buffer->data, buffer->elem_offset, buffer_a->data, buffer_a->elem_offset,
buffer_b->data, buffer_b->elem_offset, buffer->data, buffer->elem_offset},
CallNode::Intrinsic));
} else {
return Evaluate(
- Call(DataType::Handle(), intrinsic::tvm_mma_sync,
+ Call(DataType::Handle(), builtin::tvm_mma_sync(),
{buffer->data, buffer->elem_offset, buffer_a->data, buffer_a->elem_offset,
buffer_b->data, buffer_b->elem_offset, buffer->data, buffer->elem_offset},
CallNode::Intrinsic));
auto pload = dst.as<ProducerLoadNode>();
auto fill_fragment_call = [this, &op](const Buffer& buffer) {
- return Evaluate(Call(DataType::Handle(), intrinsic::tvm_fill_fragment,
+ return Evaluate(Call(DataType::Handle(), builtin::tvm_fill_fragment(),
{buffer->data, warp_tile_.m, warp_tile_.n, warp_tile_.k,
buffer->elem_offset, op->value},
CallNode::Intrinsic));
return add_buffer_bind_scope_(pload, buffer_node, fill_fragment_call);
}
- const CallNode* value = op->value.as<CallNode>();
+ const ProducerLoadNode* value = op->value.as<ProducerLoadNode>();
CHECK(value != nullptr) << "Can only load fragment from a buffer";
- auto it = strides_.find(value->name);
- CHECK(it != strides_.end()) << "Cannot find stride for " << value->name;
+ auto it = strides_.find(value->producer->GetNameHint());
+ CHECK(it != strides_.end()) << "Cannot find stride for " << value->producer->GetNameHint();
auto strides = it->second;
CHECK_GE(strides.size(), 2);
PrimExpr stride = strides[strides.size() - 2];
PrimExpr warp_y = IntImm(DataType::Int(32), warp_threads_y_);
ThreadIdxMutator thread_idx_mutator(warp_y);
PrimExpr mutated_value = thread_idx_mutator(op->value);
- PrimExpr src = Call(value->dtype, "&", {mutated_value}, CallNode::Extern);
+ // TODO(tvm-team) The extern function name seems to be a hack.
+ PrimExpr src = Call(value->dtype, builtin::call_extern(), {StringImm("&"), mutated_value},
+ CallNode::Extern);
auto pload = dst.as<ProducerLoadNode>();
PrimExpr matrix_major;
}
auto load_matrix_call = [this, &src, &stride, &matrix_major](const Buffer& buffer) {
- return Evaluate(Call(DataType::Handle(), intrinsic::tvm_load_matrix_sync,
+ return Evaluate(Call(DataType::Handle(), builtin::tvm_load_matrix_sync(),
{buffer->data, warp_tile_.m, warp_tile_.n, warp_tile_.k,
buffer->elem_offset, src, stride, matrix_major},
CallNode::Intrinsic));
PrimExpr warp_y = IntImm(DataType::Int(32), warp_threads_y_);
ThreadIdxMutator thread_idx_mutator(warp_y);
dst = thread_idx_mutator(dst);
- dst = Call(DataType::Handle(), "&", {dst}, CallNode::Extern);
+ dst =
+ Call(DataType::Handle(), builtin::call_extern(), {StringImm("&"), dst}, CallNode::Extern);
auto pload = op->value.as<ProducerLoadNode>();
auto store_matrix_call = [this, &dst, &stride](const Buffer& buffer) {
- return Evaluate(Call(DataType::Handle(), intrinsic::tvm_store_matrix_sync,
+ return Evaluate(Call(DataType::Handle(), builtin::tvm_store_matrix_sync(),
{buffer->data, warp_tile_.m, warp_tile_.n, warp_tile_.k,
buffer->elem_offset, dst, stride, StringImm("col_major")},
CallNode::Intrinsic));
args.push_back(pload->indices[i]);
args.push_back(shape[i]);
}
- auto tuple = Call(DataType::Handle(), intrinsic::tvm_tuple, args, CallNode::Intrinsic);
+ auto tuple = Call(DataType::Handle(), builtin::tvm_tuple(), args, CallNode::Intrinsic);
Array<ObjectRef> node = {buffer, tensor};
return AttrStmt(node, "buffer_bind_scope", tuple, call_back(buffer));
}
#include <tvm/runtime/registry.h>
#include <tvm/target/target.h>
#include <tvm/tir/analysis.h>
+#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
const auto& iter = defs_.find(V);
if (iter == defs_.end()) return false;
const CallNode* C = iter->second.as<const CallNode>();
- if (!C || C->name != intrinsic::tvm_struct_get) return false;
+ if (!C || !C->op.same_as(builtin::tvm_struct_get())) return false;
V = C->args[0].as<VarNode>();
}
return false;
#include <tvm/runtime/registry.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/buffer.h>
+#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
}
Array<PrimExpr> acc_args{e_dtype, self->data, elem_offset, extent,
make_const(DataType::Int(32), access_mask)};
- return tir::Call(ptr_type, tir::intrinsic::tvm_access_ptr, acc_args, tir::CallNode::Intrinsic);
+ return tir::Call(ptr_type, tir::builtin::tvm_access_ptr(), acc_args, tir::CallNode::Intrinsic);
}
Buffer::Buffer(Var data, DataType dtype, Array<PrimExpr> shape, Array<PrimExpr> strides,
});
// Call
-Call::Call(DataType dtype, String name, Array<PrimExpr> args, CallType call_type) {
+Call::Call(DataType dtype, RelayExpr op, Array<PrimExpr> args, CallType call_type) {
for (size_t i = 0; i < args.size(); ++i) {
CHECK(args[i].defined());
}
ObjectPtr<CallNode> node = make_object<CallNode>();
node->dtype = dtype;
- node->name = std::move(name);
+ node->op = std::move(op);
node->args = std::move(args);
node->call_type = call_type;
data_ = std::move(node);
}
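+// For example, under the new signature exp(x) is constructed as
+//   Call(x.dtype(), Op::Get("tir.exp"), {x}, CallNode::PureIntrinsic)
+// rather than by passing the string name "exp".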
-const char* CallNode::vectorizable_intrinsics[] = {"floor",
- "ceil",
- "sign",
- "trunc",
- "fabs",
- "round",
- "exp",
- "tanh",
- "sqrt",
- "log",
- "sin",
- "cos",
- "pow",
- "tan",
- tir::CallNode::shift_left,
- tir::CallNode::shift_right,
- tir::CallNode::likely,
- tir::CallNode::popcount};
-
-bool CallNode::is_vectorizable() const {
- size_t cnt = sizeof(CallNode::vectorizable_intrinsics) / sizeof(char*);
- for (size_t i = 0; i < cnt; ++i) {
- if (name == CallNode::vectorizable_intrinsics[i]) {
- return true;
- }
- }
- return false;
-}
-
TVM_REGISTER_GLOBAL("tir.Call")
- .set_body_typed([](DataType type, String name, Array<ObjectRef> args, int call_type) {
+ .set_body_typed([](DataType type, RelayExpr op, Array<ObjectRef> args, int call_type) {
Array<PrimExpr> prim_expr_args;
for (const auto& it : args) {
CHECK(it->IsInstance<runtime::StringObj>() || it->IsInstance<PrimExprNode>());
prim_expr_args.push_back(Downcast<PrimExpr>(it));
}
}
- return Call(type, name, prim_expr_args, static_cast<CallNode::CallType>(call_type));
+ return Call(type, op, prim_expr_args, static_cast<CallNode::CallType>(call_type));
});
TVM_REGISTER_NODE_TYPE(CallNode);
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<CallNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const CallNode*>(node.get());
- p->stream << op->name << "(";
+ if (auto* ptr_op = op->op.as<OpNode>()) {
+ p->stream << ptr_op->name << "(";
+ } else {
+ auto* ptr_gvar = op->op.as<GlobalVarNode>();
+ CHECK(ptr_gvar != nullptr);
+ p->stream << "@" << ptr_gvar->name_hint << "(";
+ }
for (size_t i = 0; i < op->args.size(); ++i) {
p->Print(op->args[i]);
if (i < op->args.size() - 1) {
if (args.same_as(op->args)) {
return GetRef<PrimExpr>(op);
} else {
- return Call(op->dtype, op->name, args, op->call_type);
+ return Call(op->dtype, op->op, args, op->call_type);
}
}
p->PrintIndent();
p->stream << "}\n";
});
+
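+/*!
+ * \brief Create a type annotation expression: a zero-argument pure intrinsic
+ *        call whose dtype carries the annotated type, used e.g. to record the
+ *        element type of handle arguments.
+ */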
+PrimExpr TypeAnnotation(DataType dtype) {
+ static auto op = Op::Get("tir.type_annotation");
+ return tir::Call(dtype, op, {}, tir::CallNode::PureIntrinsic);
+}
+
+TVM_REGISTER_OP("tir.type_annotation");
+
} // namespace tir
} // namespace tvm
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tir/op/builtin.cc
+ *
+ * \brief Builtin intrinsic operators.
+ */
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/builtin.h>
+#include <tvm/tir/op.h>
+#include <tvm/tir/op_attr_types.h>
+
+namespace tvm {
+namespace tir {
+namespace builtin {
+
+#define TIR_DEFINE_BUILTIN_FUNC(OpName) \
+ const Op& OpName() { \
+ static const Op& op = Op::Get("tir." #OpName); \
+ return op; \
+ } \
+ TVM_REGISTER_OP("tir." #OpName)
+
+TIR_DEFINE_BUILTIN_FUNC(reinterpret).set_num_inputs(1);
+
+TIR_DEFINE_BUILTIN_FUNC(likely).set_num_inputs(1).set_attr<TVectorizable>("TVectorizable", true);
+
+TIR_DEFINE_BUILTIN_FUNC(bitwise_and)
+ .set_num_inputs(2)
+ .set_attr<TVectorizable>("TVectorizable", true);
+
+TIR_DEFINE_BUILTIN_FUNC(bitwise_or)
+ .set_num_inputs(2)
+ .set_attr<TVectorizable>("TVectorizable", true);
+
+TIR_DEFINE_BUILTIN_FUNC(bitwise_xor)
+ .set_num_inputs(2)
+ .set_attr<TVectorizable>("TVectorizable", true);
+
+TIR_DEFINE_BUILTIN_FUNC(bitwise_not)
+ .set_num_inputs(1)
+ .set_attr<TVectorizable>("TVectorizable", true);
+
+TIR_DEFINE_BUILTIN_FUNC(shift_left)
+ .set_num_inputs(2)
+ .set_attr<TVectorizable>("TVectorizable", true);
+
+TIR_DEFINE_BUILTIN_FUNC(shift_right)
+ .set_num_inputs(2)
+ .set_attr<TVectorizable>("TVectorizable", true);
+
+TIR_DEFINE_BUILTIN_FUNC(large_uint_imm).set_num_inputs(2);
+
+TIR_DEFINE_BUILTIN_FUNC(address_of).set_num_inputs(1);
+
+TIR_DEFINE_BUILTIN_FUNC(if_then_else).set_num_inputs(3);
+
+TIR_DEFINE_BUILTIN_FUNC(isnullptr).set_num_inputs(1);
+
+TIR_DEFINE_BUILTIN_FUNC(isnan).set_num_inputs(1);
+
+TIR_DEFINE_BUILTIN_FUNC(popcount).set_num_inputs(1);
+
+TIR_DEFINE_BUILTIN_FUNC(fma).set_num_inputs(3).set_attr<TVectorizable>("TVectorizable", true);
+
+TIR_DEFINE_BUILTIN_FUNC(call_extern);
+
+TIR_DEFINE_BUILTIN_FUNC(call_llvm_intrin);
+
+TIR_DEFINE_BUILTIN_FUNC(call_spirv_glsl450);
+
+TIR_DEFINE_BUILTIN_FUNC(prefetch);
+
+TIR_DEFINE_BUILTIN_FUNC(tvm_access_ptr).set_num_inputs(5);
+
+TIR_DEFINE_BUILTIN_FUNC(tvm_static_handle).set_num_inputs(0);
+
+TIR_DEFINE_BUILTIN_FUNC(tvm_context_id).set_num_inputs(0);
+
+TIR_DEFINE_BUILTIN_FUNC(tvm_tuple);
+
+TIR_DEFINE_BUILTIN_FUNC(tvm_struct_get).set_num_inputs(3);
+
+TIR_DEFINE_BUILTIN_FUNC(tvm_struct_set).set_num_inputs(4);
+
+TIR_DEFINE_BUILTIN_FUNC(tvm_throw_last_error).set_num_inputs(0);
+
+TIR_DEFINE_BUILTIN_FUNC(tvm_stack_alloca).set_num_inputs(2);
+
+TIR_DEFINE_BUILTIN_FUNC(tvm_stack_make_shape);
+
+TIR_DEFINE_BUILTIN_FUNC(tvm_stack_make_array).set_num_inputs(6);
+
+// When num_inputs is not set, the op is assumed to take a variable number of arguments.
+TIR_DEFINE_BUILTIN_FUNC(tvm_call_packed);
+
+TIR_DEFINE_BUILTIN_FUNC(tvm_call_trace_packed);
+
+TIR_DEFINE_BUILTIN_FUNC(tvm_thread_context).set_num_inputs(1);
+
+TIR_DEFINE_BUILTIN_FUNC(tvm_call_packed_lowered);
+
+TIR_DEFINE_BUILTIN_FUNC(tvm_call_trace_packed_lowered);
+
+// TODO(tvm-team): revisit storage sync once we have a good memory hierarchy structure.
+TIR_DEFINE_BUILTIN_FUNC(tvm_storage_sync);
+
+TIR_DEFINE_BUILTIN_FUNC(tvm_warp_shuffle);
+
+TIR_DEFINE_BUILTIN_FUNC(tvm_warp_shuffle_up);
+
+TIR_DEFINE_BUILTIN_FUNC(tvm_warp_shuffle_down);
+
+TIR_DEFINE_BUILTIN_FUNC(tvm_warp_activemask);
+
+TIR_DEFINE_BUILTIN_FUNC(tvm_global_barrier_kinit);
+
+TIR_DEFINE_BUILTIN_FUNC(tvm_thread_allreduce);
+
+TIR_DEFINE_BUILTIN_FUNC(tvm_load_matrix_sync);
+
+TIR_DEFINE_BUILTIN_FUNC(tvm_mma_sync);
+
+TIR_DEFINE_BUILTIN_FUNC(tvm_bmma_sync);
+
+TIR_DEFINE_BUILTIN_FUNC(tvm_fill_fragment);
+
+TIR_DEFINE_BUILTIN_FUNC(tvm_store_matrix_sync);
+
+TIR_DEFINE_BUILTIN_FUNC(vectorhigh);
+
+TIR_DEFINE_BUILTIN_FUNC(vectorlow);
+
+TIR_DEFINE_BUILTIN_FUNC(vectorcombine);
+
+} // namespace builtin
+} // namespace tir
+} // namespace tvm
*/
/*!
- * \file expr_operator.cc
+ * \file tir/op/op.cc
+ *
+ * \brief Common operator definitions for ops in tir/op.h.
*/
#include <tvm/runtime/registry.h>
+#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
+#include <tvm/tir/op_attr_types.h>
#include <cmath>
// Centralized header for constant folders.
using namespace tir;
+// Macro to register a unary op.
+#define TIR_REGISTER_PURE_UNARY_OP(OpName) TVM_REGISTER_OP(OpName).set_num_inputs(1)
+
+// Macro to register a binary op.
+#define TIR_REGISTER_PURE_BINARY_OP(OpName) TVM_REGISTER_OP(OpName).set_num_inputs(2)
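+// These helpers only fix the arity; additional attributes such as
+// TVectorizable are chained on at each registration site below.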
+
runtime::DataType GetRuntimeDataType(const Type& type) {
if (auto* n = type.as<PrimTypeNode>()) {
return n->dtype;
return tir::Cast(t, value);
}
+// LargeUIntImm
PrimExpr LargeUIntImm(DataType t, int64_t low, int64_t high) {
- return tir::Call(t, tir::intrinsic::tvm_large_uint_imm,
+ return tir::Call(t, tir::builtin::large_uint_imm(),
{make_const(DataType::UInt(32), low), make_const(DataType::UInt(32), high)},
tir::CallNode::PureIntrinsic);
}
}
}
+// reinterpret
PrimExpr reinterpret(const DataType& t, PrimExpr value) {
if (value.dtype() == t) return value;
- return tir::Call(t, tir::CallNode::reinterpret, {value}, tir::CallNode::PureIntrinsic);
+ return tir::Call(t, tir::builtin::reinterpret(), {value}, tir::CallNode::PureIntrinsic);
}
+// operator+
PrimExpr operator+(PrimExpr a, PrimExpr b) {
BinaryOpMatchTypes(a, b);
PrimExpr ret = arith::TryConstFold<tir::Add>(a, b);
return tir::Max(a, b);
}
+// if_then_else
PrimExpr if_then_else(PrimExpr cond, PrimExpr true_value, PrimExpr false_value) {
CHECK(cond.dtype() == DataType::Bool(1))
<< "if_then_else only accept the condition to be boolean type.";
return false_value;
}
}
- return tir::Call(true_value.dtype(), tir::intrinsic::tvm_if_then_else,
+
+ return tir::Call(true_value.dtype(), tir::builtin::if_then_else(),
{cond, true_value, false_value}, tir::CallNode::PureIntrinsic);
}
+// likely
PrimExpr likely(PrimExpr cond) {
if (is_const(cond)) return cond;
- return tir::Call(cond.dtype(), tir::CallNode::likely, {cond}, tir::CallNode::PureIntrinsic);
+ return tir::Call(cond.dtype(), tir::builtin::likely(), {cond}, tir::CallNode::PureIntrinsic);
}
+TVM_REGISTER_OP("tir.likely").set_num_inputs(1);
+
+// operator>
PrimExpr operator>(PrimExpr a, PrimExpr b) {
BinaryOpMatchTypes(a, b);
PrimExpr ret = arith::TryConstFold<tir::GT>(a, b);
return tir::Not(a);
}
+// shift right
PrimExpr operator>>(PrimExpr a, PrimExpr b) {
CHECK(a.dtype().is_int() || a.dtype().is_uint());
CHECK(b.dtype().is_int() || b.dtype().is_uint());
if (pb->value == 0) return a;
}
});
- return tir::Call(a.dtype(), tir::CallNode::shift_right, {a, b}, tir::CallNode::PureIntrinsic);
+
+ return tir::Call(a.dtype(), tir::builtin::shift_right(), {a, b}, tir::CallNode::PureIntrinsic);
}
+// shift left
PrimExpr operator<<(PrimExpr a, PrimExpr b) {
CHECK(a.dtype().is_int() || a.dtype().is_uint());
CHECK(b.dtype().is_int() || b.dtype().is_uint());
if (pb->value == 0) return a;
}
});
- return tir::Call(a.dtype(), tir::CallNode::shift_left, {a, b}, tir::CallNode::PureIntrinsic);
+ return tir::Call(a.dtype(), tir::builtin::shift_left(), {a, b}, tir::CallNode::PureIntrinsic);
}
+// bitwise and
PrimExpr operator&(PrimExpr a, PrimExpr b) {
CHECK(a.dtype().is_int() || a.dtype().is_uint());
CHECK(b.dtype().is_int() || b.dtype().is_uint());
const DataType& rtype = a.dtype();
if (pa && pb) return IntImm(rtype, (pa->value & pb->value));
});
- return tir::Call(a.dtype(), tir::CallNode::bitwise_and, {a, b}, tir::CallNode::PureIntrinsic);
+ return tir::Call(a.dtype(), tir::builtin::bitwise_and(), {a, b}, tir::CallNode::PureIntrinsic);
}
+// bitwise_or
PrimExpr operator|(PrimExpr a, PrimExpr b) {
CHECK(a.dtype().is_int() || a.dtype().is_uint());
CHECK(b.dtype().is_int() || b.dtype().is_uint());
const DataType& rtype = a.dtype();
if (pa && pb) return IntImm(rtype, (pa->value | pb->value));
});
- return tir::Call(a.dtype(), tir::CallNode::bitwise_or, {a, b}, tir::CallNode::PureIntrinsic);
+ return tir::Call(a.dtype(), tir::builtin::bitwise_or(), {a, b}, tir::CallNode::PureIntrinsic);
}
+// bitwise_xor
PrimExpr operator^(PrimExpr a, PrimExpr b) {
CHECK(a.dtype().is_int() || a.dtype().is_uint());
CHECK(b.dtype().is_int() || b.dtype().is_uint());
const DataType& rtype = a.dtype();
if (pa && pb) return IntImm(rtype, (pa->value ^ pb->value));
});
- return tir::Call(a.dtype(), tir::CallNode::bitwise_xor, {a, b}, tir::CallNode::PureIntrinsic);
+ return tir::Call(a.dtype(), tir::builtin::bitwise_xor(), {a, b}, tir::CallNode::PureIntrinsic);
}
+// bitwise_not
PrimExpr operator~(PrimExpr a) {
CHECK(a.dtype().is_int() || a.dtype().is_uint());
- return tir::Call(a.dtype(), tir::CallNode::bitwise_not, {a}, tir::CallNode::PureIntrinsic);
+ return tir::Call(a.dtype(), tir::builtin::bitwise_not(), {a}, tir::CallNode::PureIntrinsic);
}
+TVM_REGISTER_OP("tir.bitwise_not");
+
+TVM_REGISTER_GLOBAL("tir.bitwise_not").set_body_typed([](PrimExpr a) { return ~a; });
+
+// pow
PrimExpr pow(PrimExpr x, PrimExpr y) {
BinaryOpMatchTypes(x, y);
CHECK(x.dtype().is_float()) << "power only applies to float";
- return tir::Call(x.dtype(), "pow", {x, y}, tir::CallNode::PureIntrinsic);
+ static auto op = Op::Get("tir.pow");
+ return tir::Call(x.dtype(), op, {x, y}, tir::CallNode::PureIntrinsic);
}
+TVM_REGISTER_OP("tir.pow").set_num_inputs(2).set_attr<TVectorizable>("TVectorizable", true);
+
+// abs
PrimExpr abs(PrimExpr x) {
if (x.dtype().is_int()) {
using tir::IntImmNode;
if (fx) {
return FloatImm(x.dtype(), std::fabs(fx->value));
}
- return tir::Call(x.dtype(), "fabs", {x}, tir::CallNode::PureIntrinsic);
+ static auto op = Op::Get("tir.fabs");
+ return tir::Call(x.dtype(), op, {x}, tir::CallNode::PureIntrinsic);
} else if (x.dtype().is_uint()) {
return x;
} else {
}
}
+TIR_REGISTER_PURE_UNARY_OP("tir.fabs").set_attr<TVectorizable>("TVectorizable", true);
+
+// isnan
PrimExpr isnan(PrimExpr x) {
DataType t = DataType::Bool(x.dtype().lanes());
if (x.dtype().is_int() || x.dtype().is_uint()) {
if (fx) {
return make_const(t, std::isnan(fx->value));
}
+ static auto op = Op::Get("tir.isnan");
if (x.dtype().bits() == 16) {
- return tir::Call(t, tir::CallNode::isnan,
- {cast(DataType::Float(32, t.lanes()), std::move(x))},
+ return tir::Call(t, op, {cast(DataType::Float(32, t.lanes()), std::move(x))},
tir::CallNode::PureIntrinsic);
} else {
- return tir::Call(t, tir::CallNode::isnan, {x}, tir::CallNode::PureIntrinsic);
+ return tir::Call(t, op, {x}, tir::CallNode::PureIntrinsic);
}
} else {
LOG(FATAL) << "Data type " << x.dtype() << " not supported for isnan op. Skipping isnan op...";
}
}
+TIR_REGISTER_PURE_UNARY_OP("tir.isnan");
+
+// isinf
PrimExpr isinf(PrimExpr x) {
DataType t = DataType::Bool(x.dtype().lanes());
if (x.dtype().is_int() || x.dtype().is_uint()) {
}
}
+// isfinite
PrimExpr isfinite(PrimExpr x) { return !isinf(x) && !isnan(x); }
PrimExpr sum(PrimExpr source, Array<IterVar> rdom) {
return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0);
}
+// fmod
PrimExpr fmod(PrimExpr x, PrimExpr y) {
BinaryOpMatchTypes(x, y);
CHECK(x.dtype().is_float()) << "fmod only applies to float";
- return tir::Call(x.dtype(), "fmod", {x, y}, tir::CallNode::PureIntrinsic);
+ static auto op = Op::Get("tir.fmod");
+ return tir::Call(x.dtype(), op, {x, y}, tir::CallNode::PureIntrinsic);
}
+TIR_REGISTER_PURE_UNARY_OP("tir.fmod");
+
+// floor
PrimExpr floor(PrimExpr x) {
if (x.dtype().is_int() || x.dtype().is_uint()) {
return x;
using tir::FloatImmNode;
const FloatImmNode* fx = x.as<FloatImmNode>();
if (fx) return FloatImm(x.dtype(), std::floor(fx->value));
- return tir::Call(x.dtype(), "floor", {x}, tir::CallNode::PureIntrinsic);
+ static auto op = Op::Get("tir.floor");
+ return tir::Call(x.dtype(), op, {x}, tir::CallNode::PureIntrinsic);
}
+TIR_REGISTER_PURE_UNARY_OP("tir.floor").set_attr<TVectorizable>("TVectorizable", true);
+
+// ceil
PrimExpr ceil(PrimExpr x) {
if (x.dtype().is_int() || x.dtype().is_uint()) {
return x;
using tir::FloatImmNode;
const FloatImmNode* fx = x.as<FloatImmNode>();
if (fx) return FloatImm(x.dtype(), std::ceil(fx->value));
- return tir::Call(x.dtype(), "ceil", {x}, tir::CallNode::PureIntrinsic);
+ static auto op = Op::Get("tir.ceil");
+ return tir::Call(x.dtype(), op, {x}, tir::CallNode::PureIntrinsic);
}
+TIR_REGISTER_PURE_UNARY_OP("tir.ceil").set_attr<TVectorizable>("TVectorizable", true);
+
+// round
PrimExpr round(PrimExpr x) {
if (x.dtype().is_int() || x.dtype().is_uint()) {
return x;
using tir::FloatImmNode;
const FloatImmNode* fx = x.as<FloatImmNode>();
if (fx) return FloatImm(x.dtype(), std::nearbyint(fx->value));
- return tir::Call(x.dtype(), "round", {x}, tir::CallNode::PureIntrinsic);
+ static auto op = Op::Get("tir.round");
+ return tir::Call(x.dtype(), op, {x}, tir::CallNode::PureIntrinsic);
}
+TIR_REGISTER_PURE_UNARY_OP("tir.round").set_attr<TVectorizable>("TVectorizable", true);
+
+// nearbyint
PrimExpr nearbyint(PrimExpr x) {
if (x.dtype().is_int() || x.dtype().is_uint()) {
return x;
using tir::FloatImmNode;
const FloatImmNode* fx = x.as<FloatImmNode>();
if (fx) return FloatImm(x.dtype(), std::nearbyint(fx->value));
- return tir::Call(x.dtype(), "nearbyint", {x}, tir::CallNode::PureIntrinsic);
+ static auto op = Op::Get("tir.nearbyint");
+ return tir::Call(x.dtype(), op, {x}, tir::CallNode::PureIntrinsic);
}
+TIR_REGISTER_PURE_UNARY_OP("tir.nearbyint");
+
+// trunc
PrimExpr trunc(PrimExpr x) {
if (x.dtype().is_int() || x.dtype().is_uint()) {
return x;
if (fx) {
return FloatImm(x.dtype(), (fx->value < 0 ? std::ceil(fx->value) : std::floor(fx->value)));
}
- return tir::Call(x.dtype(), "trunc", {x}, tir::CallNode::PureIntrinsic);
+ static auto op = Op::Get("tir.trunc");
+ return tir::Call(x.dtype(), op, {x}, tir::CallNode::PureIntrinsic);
}
+TIR_REGISTER_PURE_UNARY_OP("tir.trunc").set_attr<TVectorizable>("TVectorizable", true);
+
+// unary op registration.
+TIR_REGISTER_PURE_UNARY_OP("tir.exp").set_attr<TVectorizable>("TVectorizable", true);
+
+TIR_REGISTER_PURE_UNARY_OP("tir.exp2").set_attr<TVectorizable>("TVectorizable", true);
+
+TIR_REGISTER_PURE_UNARY_OP("tir.exp10").set_attr<TVectorizable>("TVectorizable", true);
+
+TIR_REGISTER_PURE_UNARY_OP("tir.erf");
+
+TIR_REGISTER_PURE_UNARY_OP("tir.tanh").set_attr<TVectorizable>("TVectorizable", true);
+
+TIR_REGISTER_PURE_UNARY_OP("tir.sigmoid");
+
+TIR_REGISTER_PURE_UNARY_OP("tir.sqrt").set_attr<TVectorizable>("TVectorizable", true);
+
+TIR_REGISTER_PURE_UNARY_OP("tir.rsqrt");
+
+TIR_REGISTER_PURE_UNARY_OP("tir.log").set_attr<TVectorizable>("TVectorizable", true);
+
+TIR_REGISTER_PURE_UNARY_OP("tir.log2").set_attr<TVectorizable>("TVectorizable", true);
+
+TIR_REGISTER_PURE_UNARY_OP("tir.log1p");
+
+TIR_REGISTER_PURE_UNARY_OP("tir.log10").set_attr<TVectorizable>("TVectorizable", true);
+
+TIR_REGISTER_PURE_UNARY_OP("tir.popcount").set_attr<TVectorizable>("TVectorizable", true);
+
+TIR_REGISTER_PURE_UNARY_OP("tir.tan").set_attr<TVectorizable>("TVectorizable", true);
+
+TIR_REGISTER_PURE_UNARY_OP("tir.cos").set_attr<TVectorizable>("TVectorizable", true);
+
+TIR_REGISTER_PURE_UNARY_OP("tir.cosh").set_attr<TVectorizable>("TVectorizable", true);
+
+TIR_REGISTER_PURE_UNARY_OP("tir.sin").set_attr<TVectorizable>("TVectorizable", true);
+
+TIR_REGISTER_PURE_UNARY_OP("tir.sinh").set_attr<TVectorizable>("TVectorizable", true);
+
+TIR_REGISTER_PURE_UNARY_OP("tir.asin");
+
+TIR_REGISTER_PURE_UNARY_OP("tir.acos");
+
+TIR_REGISTER_PURE_UNARY_OP("tir.atan");
+
+TIR_REGISTER_PURE_UNARY_OP("tir.acosh");
+
+TIR_REGISTER_PURE_UNARY_OP("tir.asinh");
+
+TIR_REGISTER_PURE_UNARY_OP("tir.atanh");
+
+// binary intrinsics
+TIR_REGISTER_PURE_BINARY_OP("tir.atan2");
+
+TIR_REGISTER_PURE_BINARY_OP("tir.nextafter");
+
+TIR_REGISTER_PURE_BINARY_OP("tir.hypot");
+
+TIR_REGISTER_PURE_BINARY_OP("tir.copysign");
+
+TIR_REGISTER_PURE_BINARY_OP("tir.ldexp");
+
// expose basic functions to node namespace
TVM_REGISTER_GLOBAL("node._const").set_body([](TVMArgs args, TVMRetValue* ret) {
if (args[0].type_code() == kDLInt) {
.set_body_typed([](PrimExpr cond, PrimExpr true_value, PrimExpr false_value) {
return if_then_else(cond, true_value, false_value);
});
+
} // namespace tvm
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tir/op/runtime.cc
+ * \brief TIR ops for runtime functions.
+ */
+#include <tvm/ir/op.h>
+#include <tvm/tir/op_attr_types.h>
+
+namespace tvm {
+namespace tir {
+
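+// Ops that wrap TVM runtime API functions. The TGlobalSymbol attribute
+// records the external symbol that lowered calls to these ops refer to.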
+TVM_REGISTER_OP("tir.TVMBackendAllocWorkspace")
+ .set_num_inputs(5)
+ .set_attr<TGlobalSymbol>("TGlobalSymbol", "TVMBackendAllocWorkspace");
+
+TVM_REGISTER_OP("tir.TVMBackendFreeWorkspace")
+ .set_num_inputs(3)
+ .set_attr<TGlobalSymbol>("TGlobalSymbol", "TVMBackendFreeWorkspace");
+
+} // namespace tir
+} // namespace tvm
#include "arg_binder.h"
#include <tvm/runtime/device_api.h>
+#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
}
}
-inline PrimExpr TVMArrayGet(DataType t, Var arr, intrinsic::TVMStructFieldKind kind) {
+inline PrimExpr TVMArrayGet(DataType t, Var arr, builtin::TVMStructFieldKind kind) {
return TVMStructGet(t, arr, 0, kind);
}
const DataType tvm_ndim_type = DataType::Int(32);
const Stmt nop = Evaluate(0);
// dimension checks
- PrimExpr v_ndim = TVMArrayGet(tvm_ndim_type, handle, intrinsic::kArrNDim);
+ PrimExpr v_ndim = TVMArrayGet(tvm_ndim_type, handle, builtin::kArrNDim);
PrimExpr a_ndim = make_const(tvm_ndim_type, static_cast<int64_t>(buffer->shape.size()));
std::ostringstream ndim_err_msg;
ndim_err_msg << arg_name << ".ndim is expected to equal " << buffer->shape.size();
DataType dtype = buffer->dtype;
std::ostringstream type_err_msg;
type_err_msg << arg_name << ".dtype is expected to be " << dtype;
- PrimExpr cond = (TVMArrayGet(DataType::UInt(8), handle, intrinsic::kArrTypeCode) ==
+ PrimExpr cond = (TVMArrayGet(DataType::UInt(8), handle, builtin::kArrTypeCode) ==
IntImm(DataType::UInt(8), dtype.code()) &&
- TVMArrayGet(DataType::UInt(8), handle, intrinsic::kArrTypeBits) ==
+ TVMArrayGet(DataType::UInt(8), handle, builtin::kArrTypeBits) ==
IntImm(DataType::UInt(8), dtype.bits()) &&
- TVMArrayGet(DataType::UInt(16), handle, intrinsic::kArrTypeLanes) ==
+ TVMArrayGet(DataType::UInt(16), handle, builtin::kArrTypeLanes) ==
IntImm(DataType::UInt(16), dtype.lanes()));
if (!(dtype == DataType::Int(4) || dtype == DataType::UInt(4) || dtype == DataType::Int(1))) {
auto type_msg = tvm::tir::StringImm(type_err_msg.str());
asserts_.emplace_back(AssertStmt(cond, type_msg, nop));
}
// data field
- if (Bind_(buffer->data, TVMArrayGet(DataType::Handle(), handle, intrinsic::kArrData),
+ if (Bind_(buffer->data, TVMArrayGet(DataType::Handle(), handle, builtin::kArrData),
arg_name + ".data", true)) {
Var vptr(buffer->data);
def_handle_dtype_.Set(vptr, tir::TypeAnnotation(buffer->dtype));
Var v_shape(arg_name + ".shape", DataType::Handle());
def_handle_dtype_.Set(v_shape, make_const(tvm_shape_type, 0));
init_nest_.emplace_back(
- LetStmt(v_shape, TVMArrayGet(DataType::Handle(), handle, intrinsic::kArrShape), nop));
+ LetStmt(v_shape, TVMArrayGet(DataType::Handle(), handle, builtin::kArrShape), nop));
for (size_t k = 0; k < buffer->shape.size(); ++k) {
if (dtype == DataType::Int(4) || dtype == DataType::UInt(4) || dtype == DataType::Int(1)) {
break;
Var v_strides(arg_name + ".strides", DataType::Handle());
def_handle_dtype_.Set(v_strides, tir::TypeAnnotation(tvm_shape_type));
init_nest_.emplace_back(
- LetStmt(v_strides, TVMArrayGet(DataType::Handle(), handle, intrinsic::kArrStrides), nop));
+ LetStmt(v_strides, TVMArrayGet(DataType::Handle(), handle, builtin::kArrStrides), nop));
PrimExpr is_null =
- Call(DataType::Bool(1), intrinsic::tvm_handle_is_null, {v_strides}, CallNode::PureIntrinsic);
+ Call(DataType::Bool(1), builtin::isnullptr(), {v_strides}, CallNode::PureIntrinsic);
if (buffer->strides.size() == 0) {
// Assert the buffer is compact
DataType stype = buffer->DefaultIndexType();
if (const auto* const_offset = buffer->elem_offset.as<IntImmNode>()) {
Bind_(make_const(DataType::UInt(64), const_offset->value * data_bytes),
- TVMArrayGet(DataType::UInt(64), handle, intrinsic::kArrByteOffset),
+ TVMArrayGet(DataType::UInt(64), handle, builtin::kArrByteOffset),
arg_name + ".byte_offset", true);
} else {
if (Bind_(buffer->elem_offset,
cast(buffer->elem_offset.dtype(),
- (TVMArrayGet(DataType::UInt(64), handle, intrinsic::kArrByteOffset) /
+ (TVMArrayGet(DataType::UInt(64), handle, builtin::kArrByteOffset) /
make_const(DataType::UInt(64), data_bytes))),
arg_name + ".elem_offset", true)) {
if (buffer->offset_factor > 1) {
}
}
// device info.
- Bind_(device_type, TVMArrayGet(DataType::Int(32), handle, intrinsic::kArrDeviceType),
+ Bind_(device_type, TVMArrayGet(DataType::Int(32), handle, builtin::kArrDeviceType),
arg_name + ".device_type", true);
- Bind_(device_id, TVMArrayGet(DataType::Int(32), handle, intrinsic::kArrDeviceId),
+ Bind_(device_id, TVMArrayGet(DataType::Int(32), handle, builtin::kArrDeviceId),
arg_name + ".device_id", true);
}
*/
#include <tvm/runtime/registry.h>
+#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/transform.h>
auto uint32_dtype = DataType(kDLUInt, 32, op_val->dtype.lanes());
auto uint32_v = Cast(uint32_dtype, op_val);
// to be endian invariant.
- return Call(op->dtype, CallNode::reinterpret, {uint32_v << 16}, CallNode::PureIntrinsic);
+ return Call(op->dtype, builtin::reinterpret(), {uint32_v << 16}, CallNode::PureIntrinsic);
} else if (op->dtype.is_bfloat16()) {
// if is cast_to_bf16, check if op->value is fp32
CHECK(op->value->dtype.is_float() && op->value->dtype.bits() == 32);
auto uint32_dtype = DataType(kDLUInt, 32, op_val->dtype.lanes());
- auto uint32_v = Call(uint32_dtype, CallNode::reinterpret, {op_val}, CallNode::PureIntrinsic);
+ auto uint32_v = Call(uint32_dtype, builtin::reinterpret(), {op_val}, CallNode::PureIntrinsic);
auto uint16_dtype = DataType(kDLUInt, 16, op_val->dtype.lanes());
/* the following TIR is equivalent to the C++ code below:
uint32_t rounding_bias = ((U32 >> 16) & 1) + UINT32_C(0x7FFF);
#include <tvm/arith/analyzer.h>
#include <tvm/runtime/registry.h>
+#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
}
PrimExpr VisitExpr_(const CallNode* op) final {
- if (process_store_ && op->is_intrinsic(intrinsic::tvm_if_then_else)) {
+ if (process_store_ && op->op.same_as(builtin::if_then_else())) {
unsafe_rewritten_ = true;
}
return StmtExprMutator::VisitExpr_(op);
#include <tvm/node/structural_equal.h>
#include <tvm/node/structural_hash.h>
#include <tvm/runtime/registry.h>
+#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt.h>
#include <tvm/tir/stmt_functor.h>
class ContextCallCombiner final : public StmtExprMutator {
public:
PrimExpr VisitExpr_(const CallNode* op) final {
- if (op->is_intrinsic(intrinsic::tvm_thread_context)) {
+ if (op->op.same_as(builtin::tvm_thread_context())) {
CHECK_EQ(op->args.size(), 1U);
PrimExpr ctx = op->args[0];
auto it = ctx_map_.find(ctx);
return it->second;
} else {
CHECK(ctx.dtype().is_handle());
- std::string name;
- if (const CallNode* call = ctx.as<CallNode>()) {
- name = call->name + "_cache";
- } else {
- name = "ctx_cache_";
- }
- Var ctx_var(name, ctx.dtype());
+ Var ctx_var("ctx_cache_", ctx.dtype());
ctx_map_[ctx] = ctx_var;
return std::move(ctx_var);
}
* \file coproc_sync.cc
*/
#include <tvm/runtime/registry.h>
+#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
StmtExprVisitor::VisitStmt_(op);
}
void VisitExpr_(const CallNode* op) final {
- if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
+ if (op->op.same_as(builtin::tvm_access_ptr())) {
const VarNode* buffer = op->args[1].as<VarNode>();
if (in_scope_) {
touched_[buffer].coproc = true;
}
std::vector<Stmt> GetSync(std::string sync_name) {
- return {Evaluate(Call(DataType::Int(32), sync_name, {}, CallNode::Intrinsic))};
+ return {
+ Evaluate(Call(DataType::Int(32), Op::Get("tir." + sync_name), {}, CallNode::Intrinsic))};
}
const std::unordered_set<const VarNode*>& touched_;
explicit CoProcBarrierDetector(const std::unordered_set<const VarNode*>& touched,
const std::string& coproc_name)
: touched_(touched) {
- read_barrier_name_ = coproc_name + ".coproc_read_barrier";
- write_barrier_name_ = coproc_name + ".coproc_write_barrier";
+ read_barrier_name_ = "tir." + coproc_name + ".coproc_read_barrier";
+ write_barrier_name_ = "tir." + coproc_name + ".coproc_write_barrier";
}
void PlanReadBarrier(const Stmt& stmt) {
CHECK(r.defined()) << "Cannot deduce write range of " << wvec[0].buffer;
PrimExpr min = r->min;
PrimExpr extent = r->extent;
- return Evaluate(Call(DataType::Int(32), func,
+ return Evaluate(Call(DataType::Int(32), Op::Get(func),
{wvec[0].buffer, wvec[0].dtype.bits(), r->min, r->extent},
CallNode::Intrinsic));
}
public:
explicit CoProcInstDepDetector(const IterVar& coproc_axis, const std::string& coproc_name)
: coproc_axis_(coproc_axis) {
- sync_push_name_ = coproc_name + ".coproc_dep_push";
- sync_pop_name_ = coproc_name + ".coproc_dep_pop";
+ sync_push_op_ = Op::Get("tir." + coproc_name + ".coproc_dep_push");
+ sync_pop_op_ = Op::Get("tir." + coproc_name + ".coproc_dep_pop");
}
void Plan(const Stmt& stmt) {
}
Stmt MakePush(int from, int to) {
- return Evaluate(Call(DataType::Int(32), sync_push_name_,
+ return Evaluate(Call(DataType::Int(32), sync_push_op_,
{make_const(DataType::Int(32), from), make_const(DataType::Int(32), to)},
CallNode::Intrinsic));
}
Stmt MakePop(int from, int to) {
- return Evaluate(Call(DataType::Int(32), sync_pop_name_,
+ return Evaluate(Call(DataType::Int(32), sync_pop_op_,
{make_const(DataType::Int(32), from), make_const(DataType::Int(32), to)},
CallNode::Intrinsic));
}
SyncState first_state_, last_state_, curr_state_;
// Variables
IterVar coproc_axis_;
- std::string sync_push_name_, sync_pop_name_;
+ Op sync_push_op_, sync_pop_op_;
};
class CoProcSyncInserter : public StmtMutator {
* \file inject_virtual_thread.cc
*/
#include <tvm/runtime/registry.h>
+#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
}
void VisitExpr_(const VarNode* op) final { HandleUseVar(op); }
void VisitExpr_(const CallNode* op) final {
- if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
+ if (op->op.same_as(builtin::tvm_access_ptr())) {
const auto* rw_mask = op->args[4].as<IntImmNode>();
const VarNode* buffer_var = op->args[1].as<VarNode>();
CHECK(buffer_var);
}
// Expression.
PrimExpr VisitExpr_(const CallNode* op) final {
- if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
+ if (op->op.same_as(builtin::tvm_access_ptr())) {
CHECK_EQ(op->args.size(), 5U);
DataType dtype = op->args[0].dtype();
const VarNode* buffer = op->args[1].as<VarNode>();
PrimExpr extent = this->VisitExpr(op->args[3]);
PrimExpr stride = it->second / make_const(offset.dtype(), dtype.lanes());
offset = stride * var_ + offset;
- return Call(op->dtype, op->name, {op->args[0], op->args[1], offset, extent, op->args[4]},
+ return Call(op->dtype, op->op, {op->args[0], op->args[1], offset, extent, op->args[4]},
op->call_type);
- } else if (op->is_intrinsic(intrinsic::tvm_context_id)) {
+ } else if (op->op.same_as(builtin::tvm_context_id())) {
return allow_share_ ? GetRef<PrimExpr>(op) : var_;
} else {
return StmtExprMutator::VisitExpr_(op);
#define TVM_TIR_TRANSFORMS_IR_UTIL_H_
#include <tvm/runtime/device_api.h>
+#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
* \return the get expression.
*/
inline PrimExpr TVMStructGet(DataType dtype, Var handle, int index,
- intrinsic::TVMStructFieldKind kind) {
+ builtin::TVMStructFieldKind kind) {
Array<PrimExpr> args = {handle, make_const(DataType::Int(32), index),
make_const(DataType::Int(32), static_cast<int>(kind))};
- return Call(dtype, intrinsic::tvm_struct_get, args, CallNode::PureIntrinsic);
+ return Call(dtype, builtin::tvm_struct_get(), args, CallNode::PureIntrinsic);
}
/*!
* \param offset the offset index.
*/
inline PrimExpr AddressOffset(Var handle, DataType dtype, int offset) {
- return Call(DataType::Handle(), intrinsic::tvm_address_of,
+ return Call(DataType::Handle(), builtin::address_of(),
{Load(dtype, handle, make_const(DataType::Int(32), offset * dtype.lanes()),
const_true(dtype.lanes()))},
CallNode::PureIntrinsic);
offset = offset * make_const(offset.dtype(), dtype.lanes());
offset = Ramp(offset, make_const(offset.dtype(), 1), dtype.lanes());
}
- return Call(DataType::Handle(), intrinsic::tvm_address_of,
+ return Call(DataType::Handle(), builtin::address_of(),
{Load(dtype, handle, offset, const_true(dtype.lanes()))}, CallNode::PureIntrinsic);
}
* \param value The value to be set.
* \return the set stmt.
*/
-inline Stmt TVMStructSet(Var handle, int index, intrinsic::TVMStructFieldKind kind,
- PrimExpr value) {
+inline Stmt TVMStructSet(Var handle, int index, builtin::TVMStructFieldKind kind, PrimExpr value) {
Array<PrimExpr> args = {handle, make_const(DataType::Int(32), index),
make_const(DataType::Int(32), static_cast<int>(kind)), value};
- return Evaluate(Call(DataType::Int(32), intrinsic::tvm_struct_set, args, CallNode::Intrinsic));
+ return Evaluate(Call(DataType::Int(32), builtin::tvm_struct_set(), args, CallNode::Intrinsic));
}
/*!
#include <tvm/arith/analyzer.h>
#include <tvm/arith/bound.h>
#include <tvm/runtime/registry.h>
+#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
}
void VisitExpr_(const CallNode* op) final {
- if (op->is_intrinsic(CallNode::likely)) {
+ if (op->op.same_as(builtin::likely())) {
in_likely_ = true;
StmtExprVisitor::VisitExpr_(op);
in_likely_ = false;
- } else if (op->is_intrinsic(intrinsic::tvm_thread_allreduce)) {
+ } else if (op->op.same_as(builtin::tvm_thread_allreduce())) {
// no split if the body contains allreduce.
no_split_ = true;
return;
}
void VisitExpr_(const CallNode* op) final {
- if (op->is_intrinsic(CallNode::likely)) {
+ if (op->op.same_as(builtin::likely())) {
PrimExpr cond = op->args[0];
if (ExprUseVars(cond, std::unordered_set<const VarNode*>({current_var_.get()}))) {
// For cond, find out the interval, if exists, in which we can prove that cond is
class RemoveLikelyTags : public StmtExprMutator {
public:
PrimExpr VisitExpr_(const CallNode* op) final {
- if (op->is_intrinsic(CallNode::likely)) {
+ if (op->op.same_as(builtin::likely())) {
CHECK_EQ(op->args.size(), 1);
return StmtExprMutator::VisitExpr(op->args[0]);
} else {
#include <tvm/runtime/registry.h>
#include <tvm/target/target_info.h>
#include <tvm/tir/buffer.h>
+#include <tvm/tir/builtin.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
}
PrimExpr VisitExpr_(const CallNode* op) final {
- if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
+ if (op->op.same_as(builtin::tvm_access_ptr())) {
return MakeAccessPtr(op);
} else {
return StmtExprMutator::VisitExpr_(op);
}
PrimExpr VisitExpr_(const CallNode* op) final {
+ // NOTE: call_type will eventually be deprecated and the information
+ // will be folded into Op's attr
if (op->call_type == CallNode::Intrinsic || op->call_type == CallNode::PureIntrinsic) {
- PrimExpr r = ApplyPattern(op->name, GetRef<PrimExpr>(op));
- if (r.defined()) return r;
+ if (auto* ptr_op = op->op.as<OpNode>()) {
+      // Still use the legacy string-based rewriting.
+      // TODO(tvm-team): migrate the pattern application from global function lookup
+      // to an OpAttrMap<PackedFunc>.
+ std::string name = ptr_op->name;
+ PrimExpr r = ApplyPattern(name, GetRef<PrimExpr>(op));
+ if (r.defined()) return r;
+ }
}
return IRMutatorWithAnalyzer::VisitExpr_(op);
}
PrimExpr rhs = SwapBroadcastCast(b);
if (fma_ != nullptr && op->dtype.is_float()) {
- PrimExpr r = (*fma_)(Call(op->dtype, "fma", {lhs, rhs, c}, CallNode::PureIntrinsic));
+ PrimExpr r = (*fma_)(Call(op->dtype, builtin::fma(), {lhs, rhs, c}, CallNode::PureIntrinsic));
if (r.defined()) return this->VisitExpr(r);
} else {
if (!lhs.same_as(a) || !rhs.same_as(b)) {
return IRMutatorWithAnalyzer::VisitExpr_(op);
}
- PrimExpr ApplyPattern(const std::string& name, const PrimExpr& e) {
+ PrimExpr ApplyPattern(std::string name, const PrimExpr& e) {
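+    // Strip the "tir." namespace prefix so the legacy lowering rules
+    // (registered as tvm.intrin.rule.<target>.<name>) still match.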
+ if (name.compare(0, 4, "tir.") == 0) {
+ name = name.substr(4);
+ }
+
for (size_t i = 0; i < patterns_.size(); ++i) {
std::string& p = patterns_[i];
size_t psize = p.length();
#include <tvm/arith/analyzer.h>
#include <tvm/runtime/registry.h>
#include <tvm/target/target.h>
+#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<EvaluateNode>();
const CallNode* call = op->value.as<CallNode>();
- if (call && call->is_intrinsic(intrinsic::tvm_thread_allreduce)) {
+ if (call && call->op.same_as(builtin::tvm_thread_allreduce())) {
return MakeAllreduce(call);
} else {
return stmt;
{
PrimExpr pred = const_true(1);
PrimExpr mask =
- Call(DataType::UInt(32), intrinsic::tvm_warp_activemask, {}, CallNode::Intrinsic);
+ Call(DataType::UInt(32), builtin::tvm_warp_activemask(), {}, CallNode::Intrinsic);
seq.emplace_back(Store(mask_var, mask, index, pred));
// Push allocation with an empty body. Later this will be fixed
// when the entire body is ready.
// The former may cause dead lock as there is a divergent
// branch with a warp sync call inside.
//
- const char* shfl_func = intrinsic::tvm_warp_shuffle_down;
- PrimExpr other = WarpShuffle(shfl_func, mask_var, val, offset);
+ PrimExpr other = WarpShuffle(builtin::tvm_warp_shuffle_down(), mask_var, val, offset);
const AllocateNode* repl = local_vars[i].as<AllocateNode>();
Stmt s = Store(repl->buffer_var, other, index, pred);
seq.push_back(s);
for (size_t i = 0; i < size; ++i) {
Var var = shared_bufs[i];
PrimExpr pred = const_true(types[i].lanes());
- const char* shfl_func = intrinsic::tvm_warp_shuffle;
PrimExpr val = Load(types[i], var, index, pred);
- PrimExpr splat = WarpShuffle(shfl_func, mask_var, val, 0);
+ PrimExpr splat = WarpShuffle(builtin::tvm_warp_shuffle(), mask_var, val, 0);
seq.push_back(Store(var, splat, index, pred));
}
}
// sync thread op.
static Stmt SyncThread(const std::string& sync) {
- return Evaluate(Call(DataType::Int(32), intrinsic::tvm_storage_sync, {StringImm(sync)},
+ return Evaluate(Call(DataType::Int(32), builtin::tvm_storage_sync(), {StringImm(sync)},
CallNode::Intrinsic));
}
- // Emit warp shuffle intrinsic calls.
- PrimExpr WarpShuffle(const char* name, Var mask_var, PrimExpr val, int delta_or_lane) {
+ // Emit warp shuffle calls.
+ PrimExpr WarpShuffle(const Op& op, Var mask_var, PrimExpr val, int delta_or_lane) {
PrimExpr pred = const_true(1);
PrimExpr index(0);
PrimExpr mask = Load(DataType::UInt(32), mask_var, index, pred);
PrimExpr width = IntImm(DataType::Int(32), warp_size_);
Array<PrimExpr> args{mask, val, IntImm(DataType::Int(32), delta_or_lane), width, width};
- return Call(val.dtype(), name, args, CallNode::Intrinsic);
+ return Call(val.dtype(), op, args, CallNode::Intrinsic);
}
// Check if this is a reduction on threadIdx.x and its extent matches
* \file tir/transforms/lower_tvm_buildin.cc
*/
#include <tvm/runtime/registry.h>
+#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
inline PrimExpr StackAlloca(std::string type, size_t num) {
Array<PrimExpr> args = {StringImm(type), ConstInt32(num)};
- return Call(DataType::Handle(), intrinsic::tvm_stack_alloca, args, CallNode::Intrinsic);
+ return Call(DataType::Handle(), builtin::tvm_stack_alloca(), args, CallNode::Intrinsic);
}
// Calculate the statistics of packed function.
CHECK(device_type_.defined()) << "Unknown device type in current IR";
CHECK(device_id_.defined()) << "Unknown device id in current IR";
Stmt throw_last_error =
- Evaluate(Call(DataType::Int(32), intrinsic::tvm_throw_last_error, {}, CallNode::Intrinsic));
+ Evaluate(Call(DataType::Int(32), builtin::tvm_throw_last_error(), {}, CallNode::Intrinsic));
- Stmt body = SeqStmt({IfThenElse(Call(DataType::Bool(1), intrinsic::tvm_handle_is_null,
- {op->buffer_var}, CallNode::PureIntrinsic),
+ Stmt body = SeqStmt({IfThenElse(Call(DataType::Bool(1), builtin::isnullptr(), {op->buffer_var},
+ CallNode::PureIntrinsic),
throw_last_error),
op->body});
-
Stmt alloca = LetStmt(
op->buffer_var,
- Call(op->buffer_var.dtype(), "TVMBackendAllocWorkspace",
+ Call(op->buffer_var.dtype(), Op::Get("tir.TVMBackendAllocWorkspace"),
{cast(DataType::Int(32), device_type_), cast(DataType::Int(32), device_id_),
cast(DataType::UInt(64), total_bytes), IntImm(DataType::Int(32), op->dtype.code()),
IntImm(DataType::Int(32), op->dtype.bits())},
CallNode::Extern),
body);
- PrimExpr free_op = Call(DataType::Int(32), "TVMBackendFreeWorkspace",
+ PrimExpr free_op = Call(DataType::Int(32), Op::Get("tir.TVMBackendFreeWorkspace"),
{cast(DataType::Int(32), device_type_),
cast(DataType::Int(32), device_id_), op->buffer_var},
CallNode::Extern);
}
}
PrimExpr VisitExpr_(const CallNode* op) final {
- if (op->is_intrinsic(intrinsic::tvm_call_packed)) {
+ if (op->op.same_as(builtin::tvm_call_packed())) {
return MakeCallPacked(op);
- } else if (op->is_intrinsic(intrinsic::tvm_call_trace_packed)) {
+ } else if (op->op.same_as(builtin::tvm_call_trace_packed())) {
return MakeCallTracePacked(op);
- } else if (op->is_intrinsic(intrinsic::tvm_stack_make_shape)) {
+ } else if (op->op.same_as(builtin::tvm_stack_make_shape())) {
return MakeShape(op);
- } else if (op->is_intrinsic(intrinsic::tvm_stack_make_array)) {
+ } else if (op->op.same_as(builtin::tvm_stack_make_array())) {
return MakeArray(op);
- } else if (op->is_intrinsic(intrinsic::tvm_context_id)) {
+ } else if (op->op.same_as(builtin::tvm_context_id())) {
return make_zero(op->dtype);
} else {
return StmtExprMutator::VisitExpr_(op);
run_array_stack_ += 1;
PrimExpr expr = StmtExprMutator::VisitExpr_(op);
op = expr.as<CallNode>();
- prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, intrinsic::kArrData, op->args[0]));
- prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, intrinsic::kArrShape, op->args[1]));
+ prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, builtin::kArrData, op->args[0]));
+ prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, builtin::kArrShape, op->args[1]));
PrimExpr strides = op->args[2];
if (!strides.defined() || is_zero(strides)) {
strides = make_zero(DataType::Handle());
}
- prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, intrinsic::kArrStrides, strides));
- prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, intrinsic::kArrNDim, op->args[3]));
+ prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, builtin::kArrStrides, strides));
+ prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, builtin::kArrNDim, op->args[3]));
DataType dtype = op->args[4].dtype();
prep_seq_.emplace_back(
- TVMStructSet(stack_array_, idx, intrinsic::kArrTypeCode,
+ TVMStructSet(stack_array_, idx, builtin::kArrTypeCode,
make_const(DataType::UInt(8), static_cast<int>(dtype.code()))));
- prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, intrinsic::kArrTypeBits,
+ prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, builtin::kArrTypeBits,
make_const(DataType::UInt(8), dtype.bits())));
- prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, intrinsic::kArrTypeLanes,
+ prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, builtin::kArrTypeLanes,
make_const(DataType::UInt(16), dtype.lanes())));
// set byte offset
int data_bytes = GetVectorBytes(dtype);
if (!is_zero(byte_offset)) {
byte_offset = byte_offset * make_const(byte_offset.dtype(), data_bytes);
}
- prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, intrinsic::kArrByteOffset,
+ prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, builtin::kArrByteOffset,
cast(DataType::UInt(64), byte_offset)));
CHECK(device_type_.defined()) << "Unknown device type in current IR";
CHECK(device_id_.defined()) << "Unknown device id in current IR";
- prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, intrinsic::kArrDeviceId,
+ prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, builtin::kArrDeviceId,
cast(DataType::Int(32), device_id_)));
- prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, intrinsic::kArrDeviceType,
+ prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, builtin::kArrDeviceType,
cast(DataType::Int(32), device_type_)));
- return TVMStructGet(DataType::Handle(), stack_array_, idx, intrinsic::kArrAddr);
+ return TVMStructGet(DataType::Handle(), stack_array_, idx, builtin::kArrAddr);
}
// call packed.
PrimExpr MakeCallPacked(const CallNode* op) {
arg = Cast(api_type, arg);
}
prep_seq_.emplace_back(TVMStructSet(stack_value_, static_cast<int>(arg_stack_begin + i - 1),
- intrinsic::kTVMValueContent, arg));
+ builtin::kTVMValueContent, arg));
int arg_tcode = api_type.code();
if (api_type.is_handle() && arg.as<StringImmNode>()) {
arg_tcode = kTVMStr;
Array<PrimExpr> packed_args = {op->args[0], stack_value_, stack_tcode_,
ConstInt32(arg_stack_begin),
ConstInt32(arg_stack_begin + op->args.size() - 1)};
- return Call(DataType::Int(32), intrinsic::tvm_call_packed_lowered, packed_args,
+ return Call(DataType::Int(32), builtin::tvm_call_packed_lowered(), packed_args,
CallNode::Intrinsic);
}
arg = Cast(api_type, arg);
}
prep_seq_.emplace_back(TVMStructSet(stack_value_, static_cast<int>(arg_stack_begin + i - 1),
- intrinsic::kTVMValueContent, arg));
+ builtin::kTVMValueContent, arg));
int arg_tcode = api_type.code();
CHECK(!IsArrayHandle(arg)) << "Trace does not support Buffers";
prep_seq_.emplace_back(
ConstInt32(arg_stack_begin + op->args.size() - 1),
// Pass traced value.
op->args[args_size - 1]};
- return Call(op->dtype, intrinsic::tvm_call_trace_packed_lowered, packed_args,
+ return Call(op->dtype, builtin::tvm_call_trace_packed_lowered(), packed_args,
CallNode::Intrinsic);
}
bool IsArrayHandle(const PrimExpr& arg) {
// specially set array handle.
if (const CallNode* buf = arg.as<CallNode>()) {
- if (buf->is_intrinsic(intrinsic::tvm_struct_get) &&
- buf->args[2].as<IntImmNode>()->value == intrinsic::kArrAddr) {
+ if (buf->op.same_as(builtin::tvm_struct_get()) &&
+ buf->args[2].as<IntImmNode>()->value == builtin::kArrAddr) {
return true;
}
}
#include <tvm/runtime/registry.h>
#include <tvm/target/target.h>
#include <tvm/tir/analysis.h>
+#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
<< " local_index=" << local_index;
PrimExpr load_value = Load(op->dtype, op->buffer_var, local_index, op->predicate);
PrimExpr mask =
- Call(DataType::UInt(32), intrinsic::tvm_warp_activemask, {}, CallNode::Intrinsic);
- return Call(load_value.dtype(), intrinsic::tvm_warp_shuffle,
+ Call(DataType::UInt(32), builtin::tvm_warp_activemask(), {}, CallNode::Intrinsic);
+ return Call(load_value.dtype(), builtin::tvm_warp_shuffle(),
{mask, load_value, group, width_, warp_size_}, CallNode::Intrinsic);
} else {
return StmtExprMutator::VisitExpr_(op);
#include <tvm/target/target.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/buffer.h>
+#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
// load i-th argument as type t
auto f_arg_value = [&](DataType t, int i) {
Array<PrimExpr> call_args{v_packed_args, IntImm(DataType::Int(32), i),
- IntImm(DataType::Int(32), intrinsic::kTVMValueContent)};
+ IntImm(DataType::Int(32), builtin::kTVMValueContent)};
// load 64 bit version
DataType api_type = APIType(t);
- PrimExpr res = Call(api_type, intrinsic::tvm_struct_get, call_args, CallNode::PureIntrinsic);
+ PrimExpr res = Call(api_type, builtin::tvm_struct_get(), call_args, CallNode::PureIntrinsic);
// cast to the target version.
if (api_type != t) {
res = Cast(t, res);
if (runtime::DeviceAPI::NeedSetDeviceContext(target_device_type)) {
Stmt set_device =
- Evaluate(Call(DataType::Int(32), intrinsic::tvm_call_packed,
+ Evaluate(Call(DataType::Int(32), builtin::tvm_call_packed(),
{StringImm(runtime::symbol::tvm_set_device), device_type, device_id},
CallNode::Intrinsic));
body = SeqStmt({set_device, body});
*/
#include <tvm/runtime/registry.h>
+#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/transform.h>
std::unordered_map<const IterVarNode*, IterVar> ivmap_;
// indicator of LoadNode::index and StoreNode::index
bool is_index_{false};
+ // cached ops
+ const Op& builtin_pow_ = Op::Get("tir.pow");
};
#define DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(OP, FUNC) \
op = e.as<CallNode>();
CHECK(op != nullptr) << "Expected type to be CallNode"
<< ", but get " << e->GetTypeKey();
- if (op->call_type == CallNode::PureIntrinsic) {
- if (op->name == intrinsic::tvm_if_then_else) {
- return if_then_else(op->args[0], op->args[1], op->args[2]);
- } else if (op->name == CallNode::shift_right) {
- return op->args[0] >> op->args[1];
- } else if (op->name == CallNode::shift_left) {
- return op->args[0] << op->args[1];
- } else if (op->name == CallNode::bitwise_and) {
- return op->args[0] & op->args[1];
- } else if (op->name == CallNode::bitwise_or) {
- return op->args[0] | op->args[1];
- } else if (op->name == CallNode::bitwise_xor) {
- return op->args[0] ^ op->args[1];
- } else if (op->name == "pow") {
- return pow(op->args[0], op->args[1]);
- }
+
+ if (op->op.same_as(builtin::if_then_else())) {
+ return if_then_else(op->args[0], op->args[1], op->args[2]);
+ } else if (op->op.same_as(builtin::shift_right())) {
+ return op->args[0] >> op->args[1];
+ } else if (op->op.same_as(builtin::shift_left())) {
+ return op->args[0] << op->args[1];
+ } else if (op->op.same_as(builtin::bitwise_and())) {
+ return op->args[0] & op->args[1];
+ } else if (op->op.same_as(builtin::bitwise_or())) {
+ return op->args[0] | op->args[1];
+ } else if (op->op.same_as(builtin::bitwise_xor())) {
+ return op->args[0] ^ op->args[1];
+ } else if (op->op.same_as(builtin_pow_)) {
+ return pow(op->args[0], op->args[1]);
}
+
return e;
}
* \brief Rewrite unsafe select expression.
*/
#include <tvm/runtime/registry.h>
+#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
// Because we will issue guard to make sure it is.
bool VisitExpr_(const SelectNode* op) { return VisitExpr(op->condition); }
bool VisitExpr_(const CallNode* op) {
- if (op->is_intrinsic(intrinsic::tvm_if_then_else)) {
+ if (op->op.same_as(builtin::if_then_else())) {
return VisitExpr(op->args[0]);
- } else if (op->is_intrinsic(intrinsic::tvm_address_of)) {
+ } else if (op->op.same_as(builtin::address_of())) {
const LoadNode* l = op->args[0].as<LoadNode>();
return this->VisitExpr(l->index);
} else if (op->is_pure()) {
bool cond_is_scalar_bool = op->condition.dtype().is_bool() && op->condition.dtype().is_scalar();
if ((unsafe.VisitExpr(op->true_value) || unsafe.VisitExpr(op->false_value)) &&
cond_is_scalar_bool) {
- return Call(op->dtype, intrinsic::tvm_if_then_else,
+ return Call(op->dtype, builtin::if_then_else(),
{op->condition, op->true_value, op->false_value}, CallNode::Intrinsic);
} else {
return expr;
#include <tvm/runtime/registry.h>
#include <tvm/target/target.h>
#include <tvm/tir/analysis.h>
+#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
call_args.push_back(ext);
}
return Evaluate(
- Call(DataType::Int(32), intrinsic::tvm_call_packed, call_args, CallNode::Intrinsic));
+ Call(DataType::Int(32), builtin::tvm_call_packed(), call_args, CallNode::Intrinsic));
}
// target ir module
}
void StorageAccessVisitor::VisitExpr_(const CallNode* op) {
- if (op->is_intrinsic(intrinsic::tvm_address_of)) {
+ if (op->op.same_as(builtin::address_of())) {
const LoadNode* l = op->args[0].as<LoadNode>();
StmtExprVisitor::VisitExpr_(l);
- } else if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
+ } else if (op->op.same_as(builtin::tvm_access_ptr())) {
CHECK_EQ(op->args.size(), 5U);
DataType dtype = op->args[0].dtype();
const VarNode* buffer = op->args[1].as<VarNode>();
}
}
StmtExprVisitor::VisitExpr_(op);
- } else if (op->is_intrinsic(intrinsic::tvm_storage_sync)) {
+ } else if (op->op.same_as(builtin::tvm_storage_sync())) {
CHECK(allow_append_);
const std::string& s = op->args[0].as<StringImmNode>()->value;
if (s != "warp") {
#include <tvm/target/target_info.h>
#include <tvm/te/operation.h>
#include <tvm/tir/buffer.h>
+#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt.h>
namespace tvm {
namespace tir {
-using intrinsic::tvm_address_of;
using runtime::StorageRank;
using runtime::StorageScope;
using runtime::ThreadScope;
} else if (op->attr_key == attr::buffer_dim_align) {
auto buffer = Downcast<tir::Buffer>(op->node);
const CallNode* tuple = op->value.as<CallNode>();
- CHECK(tuple && tuple->is_intrinsic(intrinsic::tvm_tuple));
+ CHECK(tuple && tuple->op.same_as(builtin::tvm_tuple()));
auto& vinfo = dim_align_[buffer];
int dim = tuple->args[0].as<IntImmNode>()->value;
if (static_cast<size_t>(dim) >= vinfo.size()) {
} else {
PrimExpr load = e.buffer.vload(e.RelIndex(args), e.buffer->dtype);
PrimExpr address =
- Call(DataType::Handle(), tvm_address_of, {load}, CallNode::PureIntrinsic);
+ Call(DataType::Handle(), builtin::address_of(), {load}, CallNode::PureIntrinsic);
PrimExpr prefetch =
- Call(op->buffer->dtype, CallNode::prefetch, {address, 0, 3, 1}, CallNode::Intrinsic);
+ Call(op->buffer->dtype, builtin::prefetch(), {address, 0, 3, 1}, CallNode::Intrinsic);
stmt = Evaluate(prefetch);
PrimExpr extent = (op->bounds[i]->extent - 1) / stride + 1;
stmt = For(vars[i], 0, extent, ForType::Serial, DeviceAPI::None, stmt);
const BufferNode* target = arr[1].as<BufferNode>();
const CallNode* tuple = op->value.as<CallNode>();
CHECK(buffer && target);
- CHECK(tuple && tuple->is_intrinsic(intrinsic::tvm_tuple));
+ CHECK(tuple && tuple->op.same_as(builtin::tvm_tuple()));
auto key = GetRef<Buffer>(target);
auto it = buf_map_.find(key);
#include <tvm/runtime/registry.h>
#include <tvm/target/target_info.h>
#include <tvm/tir/analysis.h>
+#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
}
}
void VisitExpr_(const CallNode* op) final {
- if (op->is_intrinsic(intrinsic::tvm_address_of)) {
+ if (op->op.same_as(builtin::address_of())) {
const LoadNode* l = op->args[0].as<LoadNode>();
this->VisitExpr(l->index);
} else {
}
}
PrimExpr VisitExpr_(const CallNode* op) final {
- if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
+ if (op->op.same_as(builtin::tvm_access_ptr())) {
CHECK_EQ(op->args.size(), 5U);
DataType dtype = op->args[0].dtype();
const VarNode* buffer = op->args[1].as<VarNode>();
if (se->bits_offset != 0) {
offset = make_const(offset.dtype(), se->bits_offset / elem_bits) + offset;
}
- return Call(op->dtype, op->name, {op->args[0], se->alloc_var, offset, extent, op->args[4]},
+ return Call(op->dtype, op->op, {op->args[0], se->alloc_var, offset, extent, op->args[4]},
op->call_type);
} else {
return StmtExprMutator::VisitExpr_(op);
return StmtExprMutator::VisitStmt_(op);
}
PrimExpr VisitExpr_(const CallNode* op) final {
- if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
+ if (op->op.same_as(builtin::tvm_access_ptr())) {
DataType dtype = op->args[0].dtype();
const VarNode* buffer = op->args[1].as<VarNode>();
UpdateTypeMap(buffer, dtype);
void VisitExpr_(const CallNode* op) final {
StmtExprVisitor::VisitExpr_(op);
- if (op->is_intrinsic(intrinsic::tvm_load_matrix_sync) ||
- op->is_intrinsic(intrinsic::tvm_store_matrix_sync)) {
+ if (op->op.same_as(builtin::tvm_load_matrix_sync()) ||
+ op->op.same_as(builtin::tvm_store_matrix_sync())) {
// Get shape and layout information from load and store intrinsic
CHECK_EQ(op->args.size(), 8U);
const VarNode* buffer_var = op->args[0].as<VarNode>();
}
fragments[buffer_var] = info;
}
- } else if (op->is_intrinsic(intrinsic::tvm_fill_fragment)) {
+ } else if (op->op.same_as(builtin::tvm_fill_fragment())) {
// Get shape information from fill intrinsic
CHECK_EQ(op->args.size(), 6U);
const VarNode* buffer_var = op->args[0].as<VarNode>();
void VisitExpr_(const CallNode* op) final {
StmtExprVisitor::VisitExpr_(op);
// Check shape when calling tvm_mma_sync
- if (op->is_intrinsic(intrinsic::tvm_mma_sync) || op->is_intrinsic(intrinsic::tvm_bmma_sync)) {
+ if (op->op.same_as(builtin::tvm_mma_sync()) || op->op.same_as(builtin::tvm_bmma_sync())) {
CHECK_EQ(op->args.size(), 8U);
const VarNode* buffer_var_d = op->args[0].as<VarNode>();
const VarNode* buffer_var_a = op->args[2].as<VarNode>();
*/
#include <tvm/runtime/registry.h>
#include <tvm/tir/analysis.h>
+#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
if (sync_scope_.rank == StorageRank::kGlobal) {
barrier = MakeGlobalBarrier();
} else {
- barrier = Evaluate(Call(DataType::Int(32), intrinsic::tvm_storage_sync,
+ barrier = Evaluate(Call(DataType::Int(32), builtin::tvm_storage_sync(),
{StringImm(sync_scope_.to_string())}, CallNode::Intrinsic));
}
// Mutate after query, to avoid stmt change.
}
PrimExpr VisitExpr_(const CallNode* op) final {
- if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
+ if (op->op.same_as(builtin::tvm_access_ptr())) {
PrimExpr expr = StmtExprMutator::VisitExpr_(op);
op = expr.as<CallNode>();
CHECK_EQ(op->args.size(), 5U);
CHECK(op != nullptr);
Array<PrimExpr> pargs = {StringImm(runtime::symbol::tvm_prepare_global_barrier)};
Stmt prep =
- Evaluate(Call(DataType::Int(32), intrinsic::tvm_call_packed, pargs, CallNode::Intrinsic));
+ Evaluate(Call(DataType::Int(32), builtin::tvm_call_packed(), pargs, CallNode::Intrinsic));
Stmt body = op->body;
for (const auto& kv : rw_stats_) {
const auto& e = kv.second;
}
rw_stats_.clear();
Stmt kinit = Evaluate(
- Call(DataType::Int(32), intrinsic::tvm_global_barrier_kinit, {}, CallNode::Intrinsic));
+ Call(DataType::Int(32), builtin::tvm_global_barrier_kinit(), {}, CallNode::Intrinsic));
body = SeqStmt({kinit, body});
body = AttrStmt(op->node, op->attr_key, op->value, body);
return SeqStmt({prep, body});
} else {
CHECK_EQ(num_work_dim_, thread_extents_.size());
}
- return Evaluate(Call(DataType::Int(32), intrinsic::tvm_storage_sync,
+ return Evaluate(Call(DataType::Int(32), builtin::tvm_storage_sync(),
{StringImm(sync_scope_.to_string()), is_lead_, num_blocks_},
CallNode::Intrinsic));
}
// Loop vectorizer as in Halide pipeline.
#include <tvm/arith/analyzer.h>
#include <tvm/runtime/registry.h>
+#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
+#include <tvm/tir/op_attr_types.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
int lanes = std::max(t.dtype().lanes(), f.dtype().lanes());
t = BroadcastTo(t, lanes);
f = BroadcastTo(f, lanes);
- return Call(op->dtype.with_lanes(lanes), op->name, {cond, t, f}, op->call_type);
+ return Call(op->dtype.with_lanes(lanes), op->op, {cond, t, f}, op->call_type);
}
}
// Call
PrimExpr VisitExpr_(const CallNode* op) final {
- if (op->name == intrinsic::tvm_if_then_else) {
+ if (op->op.same_as(builtin::if_then_else())) {
return MutateIfThenElseExpr_(op);
}
- if (!op->is_vectorizable()) {
+ auto* op_ptr = op->op.as<OpNode>();
+ bool vectorizable = op_ptr && op_vectorizable_.get(GetRef<Op>(op_ptr), false);
+
+ if (!vectorizable) {
// Cannot vectorize this op
Array<PrimExpr> new_args;
for (auto arg : op->args) {
if (op->args.same_as(new_args)) {
return GetRef<PrimExpr>(op);
} else {
- return Call(op->dtype, op->name, new_args, op->call_type);
+ return Call(op->dtype, op->op, new_args, op->call_type);
}
} else {
int lane = 0;
if (op->args.same_as(new_args)) {
return GetRef<PrimExpr>(op);
} else {
- return Call(op->dtype.with_lanes(lane), op->name, new_args, op->call_type);
+ return Call(op->dtype.with_lanes(lane), op->op, new_args, op->call_type);
}
}
}
bool need_scalarize_{false};
// The lets
std::unordered_map<const VarNode*, PrimExpr> lets_;
+ // vectorizable property
+ OpAttrMap<TVectorizable> op_vectorizable_ = Op::GetAttrMap<TVectorizable>("TVectorizable");
+
// mutate array, with given lane requirement
// when finished, p_lane updates the lane requirement.
Array<PrimExpr> MutateArray(Array<PrimExpr> arr, int* p_lanes) {
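Note: the vectorizer above now consults the per-op TVectorizable attribute from the Op registry instead of a hard-coded is_vectorizable() check, so new intrinsics can opt in without touching the pass. A minimal sketch of such an opt-in ("tir.my_intrin" is a hypothetical placeholder name; the call mirrors the tir.atomic_add and tir.mylog registrations later in this patch):

    import tvm
    # hypothetical intrinsic name; registering any attribute also creates the
    # op entry in the registry if it does not exist yet
    tvm.ir.register_op_attr("tir.my_intrin", "TVectorizable", True)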
#include <dmlc/logging.h>
#include <gtest/gtest.h>
#include <tvm/node/functor.h>
+#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/expr_functor.h>
#include <tvm/tir/op.h>
}
{
- auto body = Evaluate(Call(DataType::Int(32), "xyz", {x + 1}, CallNode::Extern));
+ auto body = Evaluate(Call(DataType::Int(32), builtin::call_extern(), {StringImm("xyz"), x + 1},
+ CallNode::Extern));
auto res = v(std::move(body));
- CHECK(res.as<EvaluateNode>()->value.as<CallNode>()->args[0].same_as(x));
+ CHECK(res.as<EvaluateNode>()->value.as<CallNode>()->args[1].same_as(x));
}
{
Stmt body = fmakealloc();
def test_duplicate_global_var():
parse_text(
"""
- def @id[A](%x: A) -> A { x }
- def @id[A](%x: A) -> A { x }
+ def @id[A](%%x: A) -> A { x }
+ def @id[A](%%x: A) -> A { x }
"""
)
assert tvm.ir.structural_equal(lhs, rhs)
# Test that components with side effects are not removed
- side_effect = lambda *xs: tvm.tir.Call("int32", "dummy", xs, tvm.tir.Call.Intrinsic)
+ dummy = tvm.ir.GlobalVar("dummy")
+ side_effect = lambda *xs: tvm.tir.Call("int32", dummy, xs, tvm.tir.Call.Intrinsic)
ck.verify(sum_and_prod((A[k], side_effect(A[10-k])), k)[0],
sum_and_prod((A[k], side_effect(A[10-k])), k)[0])
ck.verify(sum_and_prod((side_effect(A[k]), A[10-k]), k)[0],
nn = 1024
n = tvm.runtime.convert(nn)
A = te.placeholder((n,), name='A', dtype="int32")
- B = te.compute(A.shape, lambda *i: tvm.tir.call_pure_intrin("float32", "reinterpret", A(*i)), name='B')
+ B = te.compute(A.shape, lambda *i: tvm.tir.call_pure_intrin("float32", "tir.reinterpret", A(*i)), name='B')
s = te.create_schedule(B.op)
def check_c():
n = tvm.runtime.convert(4)
A = ib.pointer("float32", name="A")
args = [
- tvm.tir.call_pure_intrin("handle", "tvm_address_of", A[0]),
+ tvm.tir.call_pure_intrin("handle", "tir.address_of", A[0]),
0, 3, 1
]
ib.emit(tvm.tir.Evaluate(
tvm.tir.Call(
- "int32", "prefetch", args, tvm.tir.Call.Intrinsic)))
+ "int32", "tir.prefetch", args, tvm.tir.Call.Intrinsic)))
body = ib.get()
mod = tvm.IRModule.from_expr(
tvm.testing.assert_allclose(c_.asnumpy(), (a_.asnumpy() * 2).astype('int32'))
def np_float2np_bf16(arr):
- ''' Convert a numpy array of float to a numpy array
+ ''' Convert a numpy array of float to a numpy array
of bf16 in uint16'''
orig = arr.view('<u4')
bias = np.bitwise_and(np.right_shift(orig, 16), 1) + 0x7FFF
return np.right_shift(orig + bias, 16).astype('uint16')
def np_float2tvm_bf16(arr):
- ''' Convert a numpy array of float to a TVM array
+ ''' Convert a numpy array of float to a TVM array
of bf16'''
nparr = np_float2np_bf16(arr)
return tvm.nd.empty(nparr.shape, 'uint16').copyfrom(nparr)
def np_bf162np_float(arr):
- ''' Convert a numpy array of bf16 (uint16) to a numpy array
+ ''' Convert a numpy array of bf16 (uint16) to a numpy array
of float'''
u32 = np.left_shift(arr.astype('uint32'), 16)
return u32.view('<f4')
tvm.testing.assert_allclose(np_bf162np_float(c_.asnumpy()), res)
dotest(True)
dotest(False)
-
+
if __name__ == "__main__":
test_multiple_func()
test_llvm_large_uintimm()
Ab = tvm.tir.decl_buffer((n, ), dtype)
i = te.size_var('i')
ib = tvm.tir.ir_builder.create()
- handle = tvm.tir.call_intrin("handle", "tvm_static_handle")
+ handle = tvm.tir.call_intrin("handle", "tir.tvm_static_handle")
ib.emit(
tvm.tir.call_packed("test_static_callback", handle, Ab))
BA = ins[0]
BC = outs[0]
- ib.emit(tvm.tir.call_intrin('handle', 'tvm_load_matrix_sync',
+ ib.emit(tvm.tir.call_intrin('handle', 'tir.tvm_load_matrix_sync',
BC.data, n, m, l, BC.elem_offset // (row * col),
BA.access_ptr('r'), col, 'row_major'))
return ib.get()
def init():
ib = tvm.tir.ir_builder.create()
- ib.emit(tvm.tir.call_intrin('handle', 'tvm_fill_fragment', BC.data, n, m, l, BC.elem_offset // (n * m), 0.0))
+ ib.emit(tvm.tir.call_intrin('handle', 'tir.tvm_fill_fragment', BC.data, n, m, l, BC.elem_offset // (n * m), 0.0))
return ib.get()
def update():
ib = tvm.tir.ir_builder.create()
- ib.emit(tvm.tir.call_intrin('handle', 'tvm_mma_sync',
+ ib.emit(tvm.tir.call_intrin('handle', 'tir.tvm_mma_sync',
BC.data, BC.elem_offset // (n * m),
BA.data, BA.elem_offset // (n * l),
BB.data, BB.elem_offset // (l * m),
BA = ins[0]
BC = outs[0]
- ib.emit(tvm.tir.call_intrin('handle', 'tvm_store_matrix_sync',
+ ib.emit(tvm.tir.call_intrin('handle', 'tir.tvm_store_matrix_sync',
BA.data, n, m, l, BA.elem_offset // (n * m),
BC.access_ptr('w'), m, 'row_major'))
return ib.get()
assert x.vectors[0] == a
assert x.indices[0].value == 0
- x = tvm.tir.Call("float32", "xyz", [a], tvm.tir.Call.Extern)
+ x = tvm.tir.Call("float32", "tir.call_extern", [tvm.tir.StringImm("xyz"), a], tvm.tir.Call.Extern)
assert isinstance(x, tvm.tir.Call)
assert x.dtype == "float32"
- assert x.name == "xyz"
- assert x.args[0] == a
+ assert x.op.name == "tir.call_extern"
+ assert x.args[1] == a
assert x.call_type == tvm.tir.Call.Extern
v = te.var("aa")
def test_bitwise():
x = te.var('x')
y = te.var('y')
- assert str(x << y) == '@shift_left(x: int32, y: int32, dtype=int32, type="pure_intrin")'
- assert str(x >> y) == '@shift_right(x: int32, y: int32, dtype=int32, type="pure_intrin")'
- assert str(x & y) == '@bitwise_and(x: int32, y: int32, dtype=int32, type="pure_intrin")'
- assert str(x | y) == '@bitwise_or(x: int32, y: int32, dtype=int32, type="pure_intrin")'
- assert str(x ^ y) == '@bitwise_xor(x: int32, y: int32, dtype=int32, type="pure_intrin")'
- assert str(10 & x) == '@bitwise_and(10, x: int32, dtype=int32, type="pure_intrin")'
- assert str(10 | x) == '@bitwise_or(10, x: int32, dtype=int32, type="pure_intrin")'
- assert str(10 ^ x) == '@bitwise_xor(10, x: int32, dtype=int32, type="pure_intrin")'
- assert str(10 >> x) == '@shift_right(10, x: int32, dtype=int32, type="pure_intrin")'
- assert str(10 << x) == '@shift_left(10, x: int32, dtype=int32, type="pure_intrin")'
+ assert str(x << y) == '@tir.shift_left(x: int32, y: int32, dtype=int32, type="pure_intrin")'
+ assert str(x >> y) == '@tir.shift_right(x: int32, y: int32, dtype=int32, type="pure_intrin")'
+ assert str(x & y) == '@tir.bitwise_and(x: int32, y: int32, dtype=int32, type="pure_intrin")'
+ assert str(x | y) == '@tir.bitwise_or(x: int32, y: int32, dtype=int32, type="pure_intrin")'
+ assert str(x ^ y) == '@tir.bitwise_xor(x: int32, y: int32, dtype=int32, type="pure_intrin")'
+ assert str(10 & x) == '@tir.bitwise_and(10, x: int32, dtype=int32, type="pure_intrin")'
+ assert str(10 | x) == '@tir.bitwise_or(10, x: int32, dtype=int32, type="pure_intrin")'
+ assert str(10 ^ x) == '@tir.bitwise_xor(10, x: int32, dtype=int32, type="pure_intrin")'
+ assert str(10 >> x) == '@tir.shift_right(10, x: int32, dtype=int32, type="pure_intrin")'
+ assert str(10 << x) == '@tir.shift_left(10, x: int32, dtype=int32, type="pure_intrin")'
assert str(10 % x) == 'floormod(10, x: int32)'
- assert str(~x) == '@bitwise_not(x: int32, dtype=int32, type="pure_intrin")'
+
+ assert str(~x) == '@tir.bitwise_not(x: int32, dtype=int32, type="pure_intrin")'
assert(tvm.tir.const(1, "int8x2") >> 1).dtype == "int8x2"
assert(x >> tvm.tir.const(1, "int32x2")).dtype == "int32x2"
assert(te.var("z", "int8x2") << tvm.tir.const(1, "int8x2")).dtype == "int8x2"
def test_isnan():
x = te.var('x', 'float32')
- assert str(tvm.tir.isnan(x)) == '@isnan(x: float32, dtype=bool, type="pure_intrin")'
+ assert str(tvm.tir.isnan(x)) == '@tir.isnan(x: float32, dtype=bool, type="pure_intrin")'
assert str(tvm.tir.isnan(x).dtype) == 'bool'
y = te.var('y', 'float16')
- assert str(tvm.tir.isnan(y)) == '@isnan(cast(float32, y: float16), dtype=bool, type="pure_intrin")'
+ assert str(tvm.tir.isnan(y)) == '@tir.isnan(cast(float32, y: float16), dtype=bool, type="pure_intrin")'
z = te.var('z', 'int32')
assert str(tvm.tir.isnan(z)) == 'False'
k = te.var('k', 'int8x2')
ib.emit(tvm.tir.call_extern("int32", "TestB", x))
ib.emit(tvm.tir.call_extern("int32", "TestC", x))
body = ib.get()
+ builtin_call_extern = tvm.ir.Op.get("tir.call_extern")
def preorder(op):
- if op.name == "TestC":
+ if op.op.same_as(builtin_call_extern) and op.args[0].value == "TestC":
return tvm.tir.const(0, "int32")
return None
def postorder(op):
assert isinstance(op, tvm.tir.Call)
- if op.name == "TestA":
- return tvm.tir.call_extern("int32", "TestB", op.args[0] + 1)
+ if op.op.same_as(builtin_call_extern) and op.args[0].value == "TestA":
+ return tvm.tir.call_extern("int32", "TestB", op.args[1] + 1)
return op
body = tvm.tir.stmt_functor.ir_transform(body, preorder, postorder, ["tir.Call"])
stmt_list = tvm.tir.stmt_list(body.body.body)
- assert stmt_list[0].value.args[0].name == "TestB"
+ assert stmt_list[0].value.args[1].args[0].value == "TestB"
assert stmt_list[1].value.value == 0
if __name__ == "__main__":
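Note: with the op-based Call, an extern call is expressed as tir.call_extern whose first argument is the callee symbol as a StringImm, so matching code checks the op identity plus args[0] instead of op.name, as the updated ir_transform test above does. A small sketch of that idiom (is_extern_call_to is a hypothetical helper, not part of the API):

    import tvm
    builtin_call_extern = tvm.ir.Op.get("tir.call_extern")
    def is_extern_call_to(expr, symbol):
        # true if expr is a tir.call_extern call whose callee symbol equals `symbol`
        return (isinstance(expr, tvm.tir.Call)
                and expr.op.same_as(builtin_call_extern)
                and expr.args[0].value == symbol)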
def to32(v):
uint32_v = topi.cast(v, "uint32")
uint32_v = tvm.tir.call_pure_intrin(
- "uint32", "shift_left", uint32_v, tvm.tir.const(16, "uint32"))
- return tvm.tir.call_pure_intrin("float32", "reinterpret", uint32_v)
+ "uint32", "tir.shift_left", uint32_v, tvm.tir.const(16, "uint32"))
+ return tvm.tir.call_pure_intrin("float32", "tir.reinterpret", uint32_v)
def to16(v):
- uint32_v = tvm.tir.call_pure_intrin("uint32", "reinterpret", v)
+ uint32_v = tvm.tir.call_pure_intrin("uint32", "tir.reinterpret", v)
rounding_bias = tvm.tir.call_pure_intrin(
- "uint32", "shift_right", uint32_v, tvm.tir.const(16, "uint32"))
+ "uint32", "tir.shift_right", uint32_v, tvm.tir.const(16, "uint32"))
rounding_bias = tvm.tir.call_pure_intrin(
- "uint32", "bitwise_and", rounding_bias, tvm.tir.const(1, "uint32"))
+ "uint32", "tir.bitwise_and", rounding_bias, tvm.tir.const(1, "uint32"))
rounding_bias = rounding_bias + tvm.tir.const(0x7FFF, "uint16")
uint32_v = uint32_v + rounding_bias
uint32_v = tvm.tir.call_pure_intrin(
- "uint32", "shift_right", uint32_v, tvm.tir.const(16, "uint32"))
+ "uint32", "tir.shift_right", uint32_v, tvm.tir.const(16, "uint32"))
return topi.cast(uint32_v, 'uint16')
def check(fcompute_before, fcompute_after):
def device_context(dev_id):
ctx = tvm.tir.call_extern("handle", "device_context", dev_type, dev_id)
return tvm.tir.Call(
- "handle", "tvm_thread_context", [ctx], tvm.tir.Call.Intrinsic)
+ "handle", "tir.tvm_thread_context", [ctx], tvm.tir.Call.Intrinsic)
ib = tvm.tir.ir_builder.create()
n = te.var("n")
import tvm
from tvm import te
+# register the ops
+tvm.ir.register_op_attr("tir.cop.coproc_sync", "TGlobalSymbol", "coproc_sync")
+tvm.ir.register_op_attr("tir.cop.coproc_read_barrier", "TGlobalSymbol", "coproc_readb")
+tvm.ir.register_op_attr("tir.cop.coproc_write_barrier", "TGlobalSymbol", "coproc_writeb")
+tvm.ir.register_op_attr("tir.cop.coproc_dep_push", "TGlobalSymbol", "coproc_dep_push")
+tvm.ir.register_op_attr("tir.cop.coproc_dep_pop", "TGlobalSymbol", "coproc_dep_pop")
+
+
def test_coproc_sync():
@tvm.register_func("tvm.info.mem.global.cache")
def meminfo_cache():
max_simd_bits=32,
max_num_bits=128,
head_address=tvm.tir.call_extern("handle", "global_cache"))
+
ib = tvm.tir.ir_builder.create()
n = te.size_var("n")
cp = te.thread_axis((0, 1), "cop")
body = stmt.body.body.body
blist = tvm.tir.stmt_list(body)
- assert(blist[1].value.name == "cop.coproc_read_barrier")
+
+ assert(blist[1].value.op.same_as(tvm.ir.Op.get("tir.cop.coproc_read_barrier")))
assert(blist[1].value.args[3].value == 80)
- assert(blist[-2].value.name == "cop.coproc_sync")
- assert(blist[-1].value.name == "cop.coproc_write_barrier")
+ assert(blist[-2].value.op.same_as(tvm.ir.Op.get("tir.cop.coproc_sync")))
+ assert(blist[-1].value.op.same_as(tvm.ir.Op.get("tir.cop.coproc_write_barrier")))
assert(blist[-1].value.args[3].value == 10)
slist = tvm.tir.stmt_list(slist[-1])
pop_st = slist[0].body[0]
- assert(push_st.value.name == "cop.coproc_dep_push")
+ assert(push_st.value.op.same_as(tvm.ir.Op.get("tir.cop.coproc_dep_push")))
assert(__check_list(push_st.value.args, [2,3]))
- assert(pop_st.value.name == "cop.coproc_dep_pop")
+ assert(pop_st.value.op.same_as(tvm.ir.Op.get("tir.cop.coproc_dep_pop")))
assert(__check_list(pop_st.value.args, [2,3]))
f = tvm.tir.transform.ThreadSync("shared")(mod)["db"]
count = [0]
def count_sync(op):
- if isinstance(op, tvm.tir.Call) and op.name == "tvm_storage_sync":
+ if isinstance(op, tvm.tir.Call) and op.op.same_as(tvm.ir.Op.get("tir.tvm_storage_sync")):
count[0] += 1
tvm.tir.stmt_functor.post_order_visit(f.body, count_sync)
assert count[0] == 4
bbuffer = tvm.tir.decl_buffer((m,), dtype=B.dtype, data=B.asobject())
ib.emit(tvm.tir.call_extern("int32", "Run",
bbuffer.access_ptr("r"),
- tvm.tir.call_pure_intrin("int32", "tvm_context_id")))
+ tvm.tir.call_pure_intrin("int32", "tir.tvm_context_id")))
C[i * nthread + tx] = B[i] + 1
return ib.get()
mod = tvm.IRModule.from_expr(
tvm.tir.PrimFunc([i], tvm.tir.Evaluate(a)))
aa = tvm.tir.transform.RewriteUnsafeSelect()(mod)["main"].body.value
- assert yy.name == "tvm_if_then_else"
- assert zz.name == "tvm_if_then_else"
+ builtin_if_then_else = tvm.ir.Op.get("tir.if_then_else")
+
+ assert yy.op.same_as(builtin_if_then_else)
+ assert zz.op.same_as(builtin_if_then_else)
assert isinstance(aa, tvm.tir.Select)
count = [0]
def count_sync(op):
- if isinstance(op, tvm.tir.Call) and op.name == "tvm_storage_sync":
+ if isinstance(op, tvm.tir.Call) and op.op.same_as(tvm.ir.Op.get("tir.tvm_storage_sync")):
count[0] += 1
tvm.tir.stmt_functor.post_order_visit(f.body, count_sync)
assert count[0] == 4
cuda_target = tvm.target.create("cuda")
f = tvm.tir.transform.ThreadSync("shared")(mod)["test_kernel0"]
body_list = tvm.tir.stmt_list(f.body.body.body.body)
- assert(body_list[1].value.name == "tvm_storage_sync")
+ assert(body_list[1].value.op.same_as(tvm.ir.Op.get("tir.tvm_storage_sync")))
ib = tvm.tir.ir_builder.create()
A = ib.pointer("float32", name="A")
with ib.for_range(0, 4, for_type="vectorize") as i:
- A[i] = tvm.tir.call_intrin("float32", "tvm_if_then_else",
+ A[i] = tvm.tir.call_intrin("float32", "tir.if_then_else",
i > 0,
A[i] + 1, A[i])
stmt = ib.get()
A = ib.pointer("float32", name="A")
with ib.for_range(0, n) as k:
with ib.for_range(0, 4, for_type="vectorize") as i:
- A[k * 4 + i] = tvm.tir.call_intrin("float32", "tvm_if_then_else",
+ A[k * 4 + i] = tvm.tir.call_intrin("float32", "tir.if_then_else",
k > 0,
A[k * 4 + i], 0)
stmt = ib.get()
#define TOPI_DETAIL_EXTERN_H_
#include <tvm/te/operation.h>
+#include <tvm/tir/builtin.h>
#include <string>
#include <vector>
*/
inline PrimExpr pack_buffer(Buffer buf) {
CHECK_GT(buf->shape.size(), 0) << "buf shape must have at least one element";
- auto shape = tvm::tir::Call(DataType::Handle(), tvm::tir::intrinsic::tvm_stack_make_shape,
+ auto shape = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_stack_make_shape(),
buf->shape, tvm::tir::CallNode::CallType::Intrinsic);
PrimExpr strides;
if (buf->strides.size() > 0) {
- strides = tvm::tir::Call(DataType::Handle(), tvm::tir::intrinsic::tvm_stack_make_shape,
+ strides = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_stack_make_shape(),
buf->shape, tvm::tir::CallNode::CallType::Intrinsic);
} else {
strides = 0;
make_const(DataType::Int(32), static_cast<int64_t>(buf->shape.size())),
make_const(buf->dtype, 0),
buf->elem_offset};
- return tvm::tir::Call(DataType::Handle(), tvm::tir::intrinsic::tvm_stack_make_array, pack_args,
+ return tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_stack_make_array(), pack_args,
tvm::tir::CallNode::CallType::Intrinsic);
}
* \return An expression representing the invocation
*/
inline PrimExpr call_packed(Array<PrimExpr> args) {
- return tvm::tir::Call(DataType::Int(32), tvm::tir::intrinsic::tvm_call_packed, args,
+ return tvm::tir::Call(DataType::Int(32), tvm::tir::builtin::tvm_call_packed(), args,
tvm::tir::CallNode::CallType::Intrinsic);
}
#define TOPI_ELEMWISE_H_
#include <topi/tags.h>
+#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <algorithm>
return compute(
x->shape,
[&](const Array<Var>& i) {
- return tvm::tir::Call(type, "reinterpret", {x(i)}, tvm::tir::CallNode::PureIntrinsic);
+ return tvm::tir::Call(type, tvm::tir::builtin::reinterpret(), {x(i)},
+ tvm::tir::CallNode::PureIntrinsic);
},
name, tag);
}
cnts = tvm.tir.popcount(w_ & x_) - tvm.tir.popcount(~w_ & x_)
else:
cnts = tvm.tir.popcount(w_ & x_)
- upper_half = tvm.tir.call_pure_intrin(half_dtype, 'vectorhigh', cnts)
- lower_half = tvm.tir.call_pure_intrin(half_dtype, 'vectorlow', cnts)
+ upper_half = tvm.tir.call_pure_intrin(
+ half_dtype, 'tir.vectorhigh', cnts)
+ lower_half = tvm.tir.call_pure_intrin(
+ half_dtype, 'tir.vectorlow', cnts)
cnts8[i] = upper_half + lower_half
for i in range(m//2):
cnts4[i] = tvm.tir.call_llvm_intrin(half_dtype, vpadd,
cnts2[i] = tvm.tir.call_llvm_intrin(half_dtype, vpadd,
args_2, cnts4[i*2], cnts4[i*2+1])
cnts = tvm.tir.call_pure_intrin(
- full_dtype, 'vectorcombine', cnts2[0], cnts2[1])
+ full_dtype, 'tir.vectorcombine', cnts2[0], cnts2[1])
shifted_cnts = cnts << tvm.tir.const(bw+bx, pack_dtype)
out = tvm.tir.call_llvm_intrin(
return_dtype, vpadalu,
cnts2[i] = tvm.tir.call_llvm_intrin(half_dtype, vpadd,
args_2, cnts4[i*2], cnts4[i*2+1])
cnts = tvm.tir.call_pure_intrin(
- full_dtype, 'vectorcombine', cnts2[0], cnts2[1])
+ full_dtype, 'tir.vectorcombine', cnts2[0], cnts2[1])
shifted_cnts = cnts << tvm.tir.const(bw+bx, pack_dtype)
out = tvm.tir.call_llvm_intrin(
return_dtype, vpadalu,
dtype_c = '%s32x%d' % (dtype, int32_lanes)
a_int8 = ins[0].vload([0], dtype_a)
- re_int32 = tvm.tir.call_pure_intrin('%s32' % dtype, 'reinterpret', a_int8)
+ re_int32 = tvm.tir.call_pure_intrin('%s32' % dtype, 'tir.reinterpret', a_int8)
# broadcast a
vec_ai32 = re_int32.astype(dtype_c)
- vec_a = tvm.tir.call_pure_intrin(dtype_b, 'reinterpret', vec_ai32)
+ vec_a = tvm.tir.call_pure_intrin(dtype_b, 'tir.reinterpret', vec_ai32)
vec_b = ins[1].vload([0, 0], dtype_b)
vec_c = outs[0].vload([0], dtype_c)
tvm.target.intrin.register_intrin_rule(
"cuda", "atomic_add", cuda_atomic_add_rule, override=True)
+tvm.ir.register_op_attr("tir.atomic_add", "TVectorizable", False)
def atomic_add(x, y):
- return tvm.tir.call_pure_intrin(y.dtype, "atomic_add", x, y)
+ return tvm.tir.call_pure_intrin(y.dtype, "tir.atomic_add", x, y)
def get_valid_counts_ir(data, valid_count, out, out_indices,
with ib.if_scope(
tvm.tir.all(data[tid * elem_length + score_index] > score_threshold,
tvm.tir.any(id_index < 0, data[tid * elem_length + id_index] >= 0))):
- atomic_add_return[0] = atomic_add(tvm.tir.call_pure_intrin("handle", "tvm_address_of",
+ atomic_add_return[0] = atomic_add(tvm.tir.call_pure_intrin("handle", "tir.address_of",
valid_count[i]), one_count)
with ib.for_range(0, elem_length) as k:
out[tid * elem_length + k] = data[tid * elem_length + k]
temp_index[0] = index_out[offset]
index_out[offset] = index_out[offset + 1]
index_out[offset + 1] = temp_index[0]
- ib.emit(tvm.tir.Call(None, 'tvm_storage_sync',
+ ib.emit(tvm.tir.Call(None, 'tir.tvm_storage_sync',
tvm.runtime.convert(['shared']),
tvm.tir.Call.Intrinsic))
return ib.get()
iou = calculate_overlap(p_data, (base_idx + l) * 5, (base_idx + i) * 5)
with ib.if_scope(iou > nms_threshold):
p_out[base_idx + i] = True
- ib.emit(tvm.tir.Call(None, 'tvm_storage_sync',
+ ib.emit(tvm.tir.Call(None, 'tir.tvm_storage_sync',
tvm.runtime.convert(['shared']),
tvm.tir.Call.Intrinsic))
return ib.get()
if indices_out is not None:
indices_out[base_idx + tid * axis_mul_after] = \
tvm.tir.generic.cast(tid, indices_out.dtype)
- ib.emit(tvm.tir.Call(None, 'tvm_storage_sync',
+ ib.emit(tvm.tir.Call(None, 'tir.tvm_storage_sync',
tvm.runtime.convert(['shared']),
tvm.tir.Call.Intrinsic))
idxd = tvm.tir.indexdiv
temp_index[0] = indices_out[offset]
indices_out[offset] = indices_out[offset + axis_mul_after]
indices_out[offset + axis_mul_after] = temp_index[0]
- ib.emit(tvm.tir.Call(None, 'tvm_storage_sync',
+ ib.emit(tvm.tir.Call(None, 'tir.tvm_storage_sync',
tvm.runtime.convert(['shared']),
tvm.tir.Call.Intrinsic))
temp_index[0] = output[offset]
output[offset] = output[offset + axis_mul_after]
output[offset + axis_mul_after] = temp_index[0]
- ib.emit(tvm.tir.Call(None, 'tvm_storage_sync',
+ ib.emit(tvm.tir.Call(None, 'tir.tvm_storage_sync',
tvm.runtime.convert(['shared']),
tvm.tir.Call.Intrinsic))
BC = outs[0]
row = wmma_m * wmma_k
warp_index = BC.elem_offset // row + BC.elem_offset % row // wmma_k
- ib.emit(tvm.tir.call_intrin('handle', 'tvm_load_matrix_sync',
+ ib.emit(tvm.tir.call_intrin('handle', 'tir.tvm_load_matrix_sync',
BC.data, wmma_m, wmma_n, wmma_k, warp_index,
BA.access_ptr('r'), strides_from[0], layout))
return ib.get()
BC = outs[0]
row = wmma_n * wmma_k
warp_index = BC.elem_offset // row + BC.elem_offset % row // wmma_n
- ib.emit(tvm.tir.call_intrin('handle', 'tvm_load_matrix_sync',
+ ib.emit(tvm.tir.call_intrin('handle', 'tir.tvm_load_matrix_sync',
BC.data, wmma_m, wmma_n, wmma_k, warp_index,
BA.access_ptr('r'), strides_from[0], layout))
return ib.get()
BC = outs[0]
row = wmma_m * wmma_n
warp_index = BA.elem_offset // row + BA.elem_offset % row // wmma_n
- ib.emit(tvm.tir.call_intrin('handle', 'tvm_store_matrix_sync',
+ ib.emit(tvm.tir.call_intrin('handle', 'tir.tvm_store_matrix_sync',
BA.data, wmma_m, wmma_n, wmma_k, warp_index,
BC.access_ptr('w'), strides_dst[0], 'row_major'))
return ib.get()
def init():
ib = tvm.tir.ir_builder.create()
ib.emit(
- tvm.tir.call_intrin('handle', 'tvm_fill_fragment', BC.data, wmma_m, wmma_n, wmma_k,
+ tvm.tir.call_intrin('handle', 'tir.tvm_fill_fragment',
+ BC.data, wmma_m, wmma_n, wmma_k,
warp_index_C, 0.0))
return ib.get()
def update():
ib = tvm.tir.ir_builder.create()
- ib.emit(tvm.tir.call_intrin('handle', 'tvm_mma_sync',
+ ib.emit(tvm.tir.call_intrin('handle', 'tir.tvm_mma_sync',
BC.data, warp_index_C,
BA.data, warp_index_A,
BB.data, warp_index_B,
return ib.get()
a_int8 = ins[0].vload([0], "uint8x4")
- re_int32 = tvm.tir.call_pure_intrin('int32', 'reinterpret', a_int8)
+ re_int32 = tvm.tir.call_pure_intrin('int32', 'tir.reinterpret', a_int8)
vec_ai32 = re_int32.astype('int32x16')
- vec_a = tvm.tir.call_pure_intrin('int8x64', 'reinterpret', vec_ai32)
+ vec_a = tvm.tir.call_pure_intrin('int8x64', 'tir.reinterpret', vec_ai32)
vec_b = ins[1].vload([0, 0], "int8x64")
vec_one = tvm.tir.const(1, "int16x32")
pair_reduction = tvm.tir.call_llvm_intrin('int16x32',
return ib.get()
a_int8 = ins[0].vload([0], "uint8x2")
- re_int16 = tvm.tir.call_pure_intrin('int16', 'reinterpret', a_int8)
+ re_int16 = tvm.tir.call_pure_intrin('int16', 'tir.reinterpret', a_int8)
vec_ai16 = re_int16.astype('int16x32')
- vec_a = tvm.tir.call_pure_intrin('int8x64', 'reinterpret', vec_ai16)
+ vec_a = tvm.tir.call_pure_intrin('int8x64', 'tir.reinterpret', vec_ai16)
for i in range(4):
vec_b = ins[1].vload([i*32, 0], "int8x64")
return ib.get()
a_int8 = ins[0].vload([0], "uint8x4")
- re_int32 = tvm.tir.call_pure_intrin('int32', 'reinterpret', a_int8)
+ re_int32 = tvm.tir.call_pure_intrin('int32', 'tir.reinterpret', a_int8)
vec_ai32 = re_int32.astype('int32x16')
vec_b = ins[1].vload([0, 0], "int8x64")
llvm_id = tvm.target.codegen.llvm_lookup_intrinsic_id(vnni_inst_name)
if llvm_id != 0: # VNNI is available for current LLVM version
- vec_bi32 = tvm.tir.call_pure_intrin('int32x16', 'reinterpret', vec_b)
+ vec_bi32 = tvm.tir.call_pure_intrin('int32x16', 'tir.reinterpret', vec_b)
vec_zero = tvm.tir.const(0, "int32x16")
quad_reduction = tvm.tir.call_llvm_intrin('int32x16',
'llvm.x86.avx512.vpdpbusd.512',
vec_zero,
vec_ai32, vec_bi32)
else: # Fall back to the normal AVX512
- vec_a = tvm.tir.call_pure_intrin('int8x64', 'reinterpret', vec_ai32)
+ vec_a = tvm.tir.call_pure_intrin('int8x64', 'tir.reinterpret', vec_ai32)
vec_one = tvm.tir.const(1, "int16x32")
pair_reduction = tvm.tir.call_llvm_intrin('int16x32',
'llvm.x86.avx512.pmaddubs.w.512',
def test_apply(func, name):
B = func(A)
assert tuple(B.shape) == tuple(A.shape)
- assert B.op.body[0].name == name
+ assert B.op.body[0].op.name == "tir." + name
test_apply(topi.exp, "exp")
test_apply(topi.erf, "erf")
B = func(A)
assert tuple(B.shape) == tuple(A.shape)
if not skip_name_check:
- assert B.op.body[0].name == name
+ assert B.op.body[0].op.name == "tir." + name
a_np = np.random.uniform(low=low, high=high, size=shape).astype(A.dtype) * 10
# avoid round check too close to boundary
if check_round:
- a_np += ((np.abs(np.fmod(a_np, 1)) - 0.5) < 1e-6) * 1e-5
+ a_np += ((np.abs(np.fmod(a_np, 1)) - 0.5) < 1e-6) * 1e-4
b_np = f_numpy(a_np)
def check_device(device):
B = topi.isnan(A)
assert tuple(B.shape) == tuple(A.shape)
if not skip_name_check:
- assert B.op.body[0].name == "isnan"
+ assert B.op.body[0].op.name == "tir.isnan"
a_np = np.random.uniform(low=low, high=high, size=shape).astype(A.dtype) * 10
a_np.ravel()[np.random.choice(a_np.size, int(a_np.size * 0.5), replace=False)] = np.nan
# avoid round check too close to boundary
def my_cuda_math_rule(op):
"""Customized CUDA intrinsic lowering rule"""
assert isinstance(op, tvm.tir.Call)
+ name = op.op.name
+ assert name.startswith("tir.")
+ dispatch_name = name[4:]
if op.dtype == "float32":
# call float function
- return tvm.tir.call_pure_extern("float32", "%sf" % op.name, op.args[0])
+ return tvm.tir.call_pure_extern("float32", "%sf" % dispatch_name, op.args[0])
elif op.dtype == "float64":
# call double function
- return tvm.tir.call_pure_extern("float32", op.name, op.args[0])
+ return tvm.tir.call_pure_extern("float32", dispatch_name, op.args[0])
else:
# cannot do translation, return self.
return op
def mylog(x):
"""customized log intrinsic function"""
- return tvm.tir.call_pure_intrin(x.dtype, "mylog", x)
+ return tvm.tir.call_pure_intrin(x.dtype, "tir.mylog", x)
def my_cuda_mylog_rule(op):
else:
return op
-
+# new op registration is triggered by registering an attribute of the op
+tvm.ir.register_op_attr("tir.mylog", "TVectorizable", True)
tvm.target.register_intrin_rule("cuda", "mylog", my_cuda_mylog_rule, override=True)
n = te.var("n")
BA = ins[0]
BC = outs[0]
- ib.emit(tvm.tir.call_intrin('handle', 'tvm_load_matrix_sync',
+ ib.emit(tvm.tir.call_intrin('handle', 'tir.tvm_load_matrix_sync',
BC.data, n, n, n, BC.elem_offset // 256,
BA.access_ptr('r'), n, 'row_major'))
return ib.get()
def init():
ib = tvm.tir.ir_builder.create()
- ib.emit(tvm.tir.call_intrin('handle', 'tvm_fill_fragment', BC.data, n, n, n, BC.elem_offset // 256, 0.0))
+ ib.emit(tvm.tir.call_intrin('handle', 'tir.tvm_fill_fragment', BC.data, n, n, n, BC.elem_offset // 256, 0.0))
return ib.get()
def update():
ib = tvm.tir.ir_builder.create()
- ib.emit(tvm.tir.call_intrin('handle', 'tvm_mma_sync',
+ ib.emit(tvm.tir.call_intrin('handle', 'tir.tvm_mma_sync',
BC.data, BC.elem_offset // 256,
BA.data, BA.elem_offset // 256,
BB.data, BB.elem_offset // 256,
ib = tvm.tir.ir_builder.create()
BA = ins[0]
BC = outs[0]
- ib.emit(tvm.tir.call_intrin('handle', 'tvm_store_matrix_sync',
+ ib.emit(tvm.tir.call_intrin('handle', 'tir.tvm_store_matrix_sync',
BA.data, n, n, n, BA.elem_offset // 256,
BC.access_ptr('w'), n, 'row_major'))
return ib.get()
def __init__(self, env):
self.vta_axis = te.thread_axis("vta")
self.vta_push_uop = tvm.tir.StringImm("VTAPushGEMMOp")
- ctx = tvm.tir.call_extern("handle", "VTATLSCommandHandle")
+ ctx = tvm.tir.call_intrin("handle", "tir.vta.command_handle")
self.command_handle = tvm.tir.Call(
- "handle", "tvm_thread_context", [ctx],
+ "handle", "tir.tvm_thread_context", [ctx],
tvm.tir.Call.Intrinsic)
self.DEBUG_NO_SYNC = False
env._dev_ctx = self
tvm.runtime.const(1<<31, dtype="uint32"))
+
@tvm.register_func("tvm.intrin.rule.default.vta.coproc_dep_push")
def coproc_dep_push(op):
return tvm.tir.call_extern(
get_env().dev.command_handle,
op.args[0], op.args[1])
+# register a dummy attribute to trigger registration of the ops;
+# change it to a proper lowering rule later.
+tvm.ir.register_op_attr("tir.vta.coproc_sync", "TVectorizable", False)
+tvm.ir.register_op_attr("tir.vta.coproc_dep_push", "TVectorizable", False)
+tvm.ir.register_op_attr("tir.vta.coproc_dep_pop", "TVectorizable", False)
+
+tvm.ir.register_op_attr("tir.vta.uop_push", "TGlobalSymbol", "VTAUopPush")
+tvm.ir.register_op_attr("tir.vta.command_handle", "TGlobalSymbol", "VTATLSCommandHandle")
+
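Note: the VTA runtime entry points are now modeled as tir ops whose TGlobalSymbol attribute names the C symbol they stand for (VTAUopPush, VTATLSCommandHandle), which is why the emission code below switches from call_extern to typed call_intrin calls. A hedged sketch of the new call site, assuming the registrations above have run; the integer arguments are placeholders, not meaningful VTA operands:

    import tvm
    builtin_uop_push = tvm.ir.Op.get("tir.vta.uop_push")
    # typed intrinsic call replacing tvm.tir.call_extern("int32", "VTAUopPush", ...)
    call = tvm.tir.call_intrin("int32", builtin_uop_push, 0, 0, 0, 0, 0, 0, 0, 0)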
def _init_env():
"""Initialize the default global env"""
irb.scope_attr(dev.vta_axis, "coproc_uop_scope",
dev.vta_push_uop)
if index in (0, 2):
- irb.emit(tvm.tir.call_extern(
- "int32", "VTAUopPush",
+ irb.emit(tvm.tir.call_intrin(
+ "int32", "tir.vta.uop_push",
0, 0,
dout.access_ptr("rw", "int32"),
dinp.access_ptr("r", "int32"),
dwgt.access_ptr("r", "int32"),
0, 0, 0))
else:
- irb.emit(tvm.tir.call_extern(
- "int32", "VTAUopPush",
+ irb.emit(tvm.tir.call_intrin(
+ "int32", "tir.vta.uop_push",
0, 1,
dout.access_ptr("rw", "int32"),
0,
loop_var = stmt.loop_var
gemm_offsets = [None, None, None]
fail = [False]
+ builtin_uop_push = tvm.ir.Op.get("tir.vta.uop_push")
def _post_order(op):
assert isinstance(op, tvm.tir.Call)
base_args = 2
- if op.name == "VTAUopPush":
+ if op.op.same_as(builtin_uop_push):
args = []
args += op.args[:base_args]
for i in range(3):
gemm_offsets[i] = m[0]
args.append(m[1])
args += op.args[base_args+3:]
- return tvm.tir.call_extern("int32", "VTAUopPush", *args)
- if op.name not in ("VTATLSCommandHandle", "tvm_thread_context"):
+ return tvm.tir.call_intrin("int32", builtin_uop_push, *args)
+ if op.op.name not in ("tir.vta.command_handle", "tir.tvm_thread_context"):
raise RuntimeError("unexpected op %s" % op)
return op
dev = env.dev
irb.scope_attr(dev.vta_axis, "coproc_scope", dev.get_task_qid(dev.QID_COMPUTE))
irb.scope_attr(dev.vta_axis, "coproc_uop_scope", dev.vta_push_uop)
- irb.emit(tvm.tir.call_extern("int32", "VTAUopPush",
+ irb.emit(tvm.tir.call_intrin("int32", "tir.vta.uop_push",
0, 1,
dout.access_ptr("rw", "int32"),
0, 0,
tpl = (args[0], 1, args[1], 1, args[2], 1, args[3], 1, 0, 1, 0, env.BLOCK_OUT)
inner = tvm.tir.AttrStmt(
[dout, res_buffer], 'buffer_bind_scope',
- tvm.tir.call_intrin('handle', 'tvm_tuple', *tpl), inner)
+ tvm.tir.call_intrin('handle', 'tir.tvm_tuple', *tpl), inner)
return inner
else:
conv_call, data_call, kernel_call = calls[-3:]
irb.scope_attr(
dev.vta_axis, "coproc_scope", dev.get_task_qid(dev.QID_COMPUTE))
irb.scope_attr(dev.vta_axis, "coproc_uop_scope", dev.vta_push_uop)
- irb.emit(tvm.tir.call_extern("int32", "VTAUopPush",
+ irb.emit(tvm.tir.call_intrin("int32", "tir.vta.uop_push",
0, 0,
dout.access_ptr("rw", "int32"),
dinp.access_ptr("r", "int32"),
1, 0, 1, 0, env.BLOCK_OUT)
inner = tvm.tir.AttrStmt(
[dout, res_tensor], 'buffer_bind_scope',
- tvm.tir.call_intrin('handle', 'tvm_tuple', *tpl), inner)
+ tvm.tir.call_intrin('handle', 'tir.tvm_tuple', *tpl), inner)
args = kernel_call.indices
tpl = (args[0], 1, args[1], 1, args[2], 1, args[3],
1, 0, env.BLOCK_OUT, 0, env.BLOCK_IN)
inner = tvm.tir.AttrStmt(
[dwgt, kernel_tensor], 'buffer_bind_scope',
- tvm.tir.call_intrin('handle', 'tvm_tuple', *tpl), inner)
+ tvm.tir.call_intrin('handle', 'tir.tvm_tuple', *tpl), inner)
args = data_call.indices
tpl = (args[0], 1, args[1], 1, args[2], 1, args[3],
1, 0, 1, 0, env.BLOCK_IN)
inner = tvm.tir.AttrStmt(
[dinp, pad_data_tensor], 'buffer_bind_scope',
- tvm.tir.call_intrin('handle', 'tvm_tuple', *tpl), inner)
+ tvm.tir.call_intrin('handle', 'tir.tvm_tuple', *tpl), inner)
return inner
return None
lhs = loop_body.value.a
rhs = loop_body.value.b
elif isinstance(loop_body.value, tvm.tir.Call):
- if loop_body.value.name == 'shift_left':
+ if loop_body.value.op.name == 'tir.shift_left':
alu_opcode = env.dev.ALU_OPCODE_SHR
lhs = loop_body.value.args[0]
rhs = analyzer.simplify(-loop_body.value.args[1])
- elif loop_body.value.name == 'shift_right':
+ elif loop_body.value.op.name == 'tir.shift_right':
alu_opcode = env.dev.ALU_OPCODE_SHR
lhs = loop_body.value.args[0]
rhs = loop_body.value.args[1]
"int32", "VTAUopLoopBegin",
extent, dst_coeff[idx], src_coeff[idx], 0))
use_imm = int(use_imm)
- irb.emit(tvm.tir.call_extern(
- "int32", "VTAUopPush",
+ irb.emit(tvm.tir.call_intrin(
+ "int32", "tir.vta.uop_push",
1, 0,
dst_coeff[len(dst_coeff)-1],
src_coeff[len(src_coeff)-1],