From 82d157f0b83ae17fde7bbfca14110aa2f2b80b61 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Mon, 22 Jun 2020 17:47:01 -0700 Subject: [PATCH] [TIR][REFACTOR][API-CHANGE] Change Call.name to Call.op(RelayExpr) (#5863) * [TIR][REFACTOR][API-CHANGE] Change Call.name(string) to Call.op(tvm::Op/RelayExpr) This PR brings a major refactor to the tir::Call structure. The current Call structure uses a string field(name) to identify the function/intrinsic being called. This approach is limited as we start to expand TIR to be more structured. In particular, we are interested in the following aspects: - Type a function and perform better compile time type checking so that we can find errors early. - Register additional properties about an operator, such as: - Whether an intrinsic can be vectorized - What is the adjoint function of the intrinsic(for tensor expression AD) - Whether the operator has side effect. - Perform specific codegen about an intrinsic if necessary. - Call into another function in the same module. The refactor changes the Call.name field to Call.op. The Call.op field has a RelayExpr type, and we can pass: - A tvm::Op which represents the corresponding intrinsic. - A tvm::GlobalVar for calling into another function in the IRModule. All the current intrinsics are migrated by registering an tvm::Op. Because the unified IR shares a single Op registry. We use the "tir" namespace for tir related intrinsics, for example bitwise and is now registered under `tir.bitwise_and`. To simplify upgrade, we introduce a `tir.call_extern` intrinsic that allows us to call into arbitary external function without type checking. However, we should move towards more type checked variants in the system. Under the new op design. We should no longer try to pattern match all the specific intrincis. Instead, we should rely on attr of each Op to do transformation. For example, the vectorization pass depends on the TVectorizable property of the op, which can be registered independently. In this way, we can still grow the number of intrinsics when necessary without having to change all the passes. The same rule applies for tensor expression AD. Currently we are performing AD by pattern match on operators like exp, sin, cos. We should instead change to the ajoint registeration mechanism like those in relay. Followup refactors need to be performed, including: - Fold the Call.call_type into operator's attribute. - Enrich the operator registry information - Refactor passes(e.g. AD, intrin lowering) to use the attribute based transformation * Fix nms * Fix remaining testcase * Address review comment --- include/tvm/relay/expr.h | 2 +- include/tvm/tir/builtin.h | 540 +++++++++++++++++++++ include/tvm/tir/expr.h | 399 +-------------- include/tvm/tir/function.h | 4 +- include/tvm/tir/op.h | 8 +- include/tvm/tir/op_attr_types.h | 48 ++ include/tvm/tir/stmt.h | 4 +- python/tvm/contrib/nvcc.py | 3 +- python/tvm/target/datatype.py | 11 +- python/tvm/target/intrin.py | 10 +- python/tvm/te/hybrid/calls.py | 2 +- python/tvm/tir/expr.py | 22 +- python/tvm/tir/ir_builder.py | 2 +- python/tvm/tir/op.py | 81 ++-- src/arith/const_int_bound.cc | 6 +- src/arith/ir_mutator_with_analyzer.cc | 9 +- src/arith/modular_set.cc | 3 +- src/arith/pattern_match.h | 45 +- src/arith/rewrite_simplify.cc | 10 +- src/contrib/hybrid/codegen_hybrid.cc | 31 +- src/ir/op.cc | 2 +- src/printer/tir_text_printer.cc | 15 +- src/relay/transforms/pass_util.h | 4 +- src/target/intrin_rule.h | 16 +- src/target/llvm/codegen_arm.cc | 15 +- src/target/llvm/codegen_cpu.cc | 66 +-- src/target/llvm/codegen_cpu.h | 5 +- src/target/llvm/codegen_llvm.cc | 76 +-- src/target/llvm/codegen_llvm.h | 10 +- src/target/llvm/codegen_nvptx.cc | 8 +- src/target/llvm/codegen_x86_64.cc | 4 +- src/target/llvm/intrin_rule_llvm.cc | 20 +- src/target/llvm/intrin_rule_llvm.h | 6 +- src/target/llvm/intrin_rule_nvptx.cc | 17 +- src/target/llvm/intrin_rule_rocm.cc | 34 +- src/target/source/codegen_c.cc | 186 +++---- src/target/source/codegen_c.h | 17 + src/target/source/codegen_c_host.cc | 6 +- src/target/source/codegen_cuda.cc | 123 ++--- src/target/source/codegen_cuda.h | 7 + src/target/source/codegen_metal.cc | 2 +- src/target/source/intrin_rule_cuda.cc | 58 ++- src/target/source/intrin_rule_opencl.cc | 4 +- src/target/spirv/codegen_spirv.cc | 29 +- src/target/spirv/intrin_rule_spirv.cc | 4 +- src/target/stackvm/codegen_stackvm.cc | 49 +- src/target/stackvm/codegen_stackvm.h | 4 + src/te/autodiff/jacobian.cc | 34 +- src/te/operation/compute_op.cc | 3 +- src/te/operation/cross_thread_reduction.cc | 4 +- src/te/operation/extern_op.cc | 2 +- src/te/operation/tensor_compute_op.cc | 5 +- src/te/operation/tensorize.cc | 4 +- .../schedule_postproc_rewrite_for_tensor_core.cc | 33 +- src/tir/analysis/verify_memory.cc | 3 +- src/tir/ir/buffer.cc | 3 +- src/tir/ir/expr.cc | 45 +- src/tir/ir/expr_functor.cc | 2 +- src/tir/ir/stmt.cc | 8 + src/tir/op/builtin.cc | 155 ++++++ src/tir/{ir => op}/op.cc | 172 ++++++- src/tir/op/runtime.cc | 39 ++ src/tir/transforms/arg_binder.cc | 27 +- src/tir/transforms/bf16_legalize.cc | 5 +- src/tir/transforms/bound_checker.cc | 3 +- src/tir/transforms/combine_context_call.cc | 11 +- src/tir/transforms/coproc_sync.cc | 22 +- src/tir/transforms/inject_virtual_thread.cc | 9 +- src/tir/transforms/ir_util.h | 14 +- src/tir/transforms/loop_partition.cc | 9 +- .../transforms/lower_device_storage_access_info.cc | 3 +- src/tir/transforms/lower_intrin.cc | 20 +- src/tir/transforms/lower_thread_allreduce.cc | 19 +- src/tir/transforms/lower_tvm_builtin.cc | 58 +-- src/tir/transforms/lower_warp_memory.cc | 5 +- src/tir/transforms/make_packed_api.cc | 7 +- src/tir/transforms/narrow_datatype.cc | 35 +- src/tir/transforms/rewrite_unsafe_select.cc | 7 +- src/tir/transforms/split_host_device.cc | 3 +- src/tir/transforms/storage_access.cc | 6 +- src/tir/transforms/storage_flatten.cc | 10 +- src/tir/transforms/storage_rewrite.cc | 9 +- src/tir/transforms/tensorcore_infer_fragment.cc | 8 +- src/tir/transforms/thread_storage_sync.cc | 11 +- src/tir/transforms/vectorize_loop.cc | 18 +- tests/cpp/ir_functor_test.cc | 6 +- tests/python/relay/test_ir_parser.py | 4 +- .../unittest/test_arith_canonical_simplify.py | 3 +- .../python/unittest/test_target_codegen_c_host.py | 2 +- tests/python/unittest/test_target_codegen_llvm.py | 12 +- .../unittest/test_target_codegen_static_init.py | 2 +- .../unittest/test_te_schedule_tensor_core.py | 8 +- tests/python/unittest/test_tir_constructor.py | 6 +- tests/python/unittest/test_tir_nodes.py | 27 +- .../unittest/test_tir_stmt_functor_ir_transform.py | 9 +- .../unittest/test_tir_transform_bf16_legalize.py | 12 +- .../test_tir_transform_combine_context_call.py | 2 +- .../unittest/test_tir_transform_coproc_sync.py | 20 +- .../test_tir_transform_inject_double_buffer.py | 2 +- .../test_tir_transform_inject_virtual_thread.py | 2 +- .../test_tir_transform_rewrite_unsafe_select.py | 6 +- .../unittest/test_tir_transform_storage_flatten.py | 2 +- .../unittest/test_tir_transform_thread_sync.py | 2 +- .../unittest/test_tir_transform_vectorize.py | 4 +- topi/include/topi/detail/extern.h | 9 +- topi/include/topi/elemwise.h | 4 +- topi/python/topi/arm_cpu/bitserial_conv2d.py | 10 +- topi/python/topi/arm_cpu/tensor_intrin.py | 4 +- topi/python/topi/cuda/nms.py | 5 +- topi/python/topi/cuda/rcnn/proposal.py | 4 +- topi/python/topi/cuda/sort.py | 6 +- topi/python/topi/cuda/tensor_intrin.py | 11 +- topi/python/topi/x86/tensor_intrin.py | 14 +- topi/tests/python/test_topi_basic.py | 2 +- topi/tests/python/test_topi_math.py | 6 +- tutorials/language/intrin_math.py | 12 +- tutorials/optimize/opt_conv_tensorcore.py | 8 +- vta/python/vta/environment.py | 14 +- vta/python/vta/intrin.py | 8 +- vta/python/vta/transform.py | 27 +- 120 files changed, 1965 insertions(+), 1168 deletions(-) create mode 100644 include/tvm/tir/builtin.h create mode 100644 include/tvm/tir/op_attr_types.h create mode 100644 src/tir/op/builtin.cc rename src/tir/{ir => op}/op.cc (82%) create mode 100644 src/tir/op/runtime.cc diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index 779bcc3..3c156df 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -226,7 +226,7 @@ class CallNode : public ExprNode { /*! * \brief The operator(function) being invoked * - * - It can be relay::Op which corresponds to the primitive operators. + * - It can be tvm::Op which corresponds to the primitive operators. * - It can also be user defined functions (Function, GlobalVar, Var). */ Expr op; diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h new file mode 100644 index 0000000..96526cc --- /dev/null +++ b/include/tvm/tir/builtin.h @@ -0,0 +1,540 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/tir/builtin.h + * \brief TIR builtin intrinsics. + * + * TIR builtin intrinsics are stored as tvm:Op. + * They are processed in the same way as we process Ops. + * + * It is not necessary to create a function for every Op, + * as we can obtain them through Op::Get. + * + * This file contains the most commonly used intrinsics or + * those that have special semantics and need compiler support. + */ +#ifndef TVM_TIR_BUILTIN_H_ +#define TVM_TIR_BUILTIN_H_ + +#include +#include + +namespace tvm { +namespace tir { + +/*! \brief Collection of builtin intrinsics as ops */ +namespace builtin { +/*! + * \brief Reinterpret the value using the target type. + */ +TVM_DLL const Op& reinterpret(); + +/*! + * \brief Marks a condition is likely going to happen. + */ +TVM_DLL const Op& likely(); + +/*! + * \brief Bitwise and operator. + */ +TVM_DLL const Op& bitwise_and(); + +/*! + * \brief Bitwise or operator. + */ +TVM_DLL const Op& bitwise_or(); + +/*! + * \brief Bitwise xor operator. + */ +TVM_DLL const Op& bitwise_xor(); + +/*! + * \brief Bitwise not operator. + */ +TVM_DLL const Op& bitwise_not(); + +/*! + * \brief Left shift + */ +TVM_DLL const Op& shift_left(); + +/*! + * \brief Right shift + */ +TVM_DLL const Op& shift_right(); + +/*! + * \brief See pesudo code + * + * Construct a big uint that may not be representable by int64 + * + * Expr large_uint_imm(uint32_t v0, uin32_t v1) { + * return (v1 << 32) | v0; + * } + */ +TVM_DLL const Op& large_uint_imm(); + +/*! + * \brief See pesudo code + * + * Handle address_of(Load *op) { + * return &op->buffer_var[index]; + * } + */ +TVM_DLL const Op& address_of(); + +/*! + * \brief Same as select, used for unsafe memory access. + * + * Type tvm_if_then_else(cond, a, b) { + * return cond ? a : b; + * } + */ +TVM_DLL const Op& if_then_else(); + +/*! + * \brief See pesudo code + * + * bool isnullptr(void* handle) { + * return handle == nullptr + * } + */ +TVM_DLL const Op& isnullptr(); + +/*! + * \brief Check if value is nan + */ +TVM_DLL const Op& isnan(); + +/*! + * \brief Popcount + */ +TVM_DLL const Op& popcount(); + +/*! + * \brief Fused multiply add + * + * Type fma(a, b, c) { + * return a * b + c; + * } + */ +TVM_DLL const Op& fma(); + +/*! + * \brief Call an extern C function with given name + * and signature from the types of args in the runtime environment. + * + * Type call_extern(name, args...) { + * return dlsym(name)(args...); + * } + * + * \note This intrinsic does not provide any type checking, + * and is main used for backward compatibility reasons. + * Always consider use pre-registered and typed tvm::Op first. + */ +TVM_DLL const Op& call_extern(); + +/*! + * \brief Call an LLVM intrinsic with a given intrinsic id + * and signature from the types of args in the runtime environment. + * + * Type call_llvm_intrin(intrin_id, args...) { + * return dlsym(name)(args...); + * } + * + * \note This op does not provide any type checking. + */ +TVM_DLL const Op& call_llvm_intrin(); + +/*! + * \brief Call an SPIRV GLSL450 intrinsic. + * + * Type call_spirv_glsl450(intrin_id, args...) { + * return dlsym(name)(args...); + * } + * + * \note This op does not provide any type checking. + */ +TVM_DLL const Op& call_spirv_glsl450(); + +// TODO(tvm-team) revisit the builtins below +// some of them can simply become ops with special codegen attr. +/*! + * \brief Prefetch a cacheline + */ +TVM_DLL const Op& prefetch(); + +/*! + * \brief Get head access address with memory access pattern info. + * + * This operator also marks range of the memory access + * The offset and extent are in unit of the DType(including vectorization factor). + * rw_mask is a bit_mask setting whether the access is a read(1) or write(2). + * The access is assume to happen in the current expression. + * + * PtrType tvm_access_ptr(Expr dtype, DType* data, + * int offset, int extent, + * int rw_mask) { + * // DType == dtype.type(); + * return &data[offset]; + * } + */ +TVM_DLL const Op& tvm_access_ptr(); + +/*! + * \brief Create a function local static handle that iniitalizes to nullptr. + * can be used to cache function local static resources. + */ +TVM_DLL const Op& tvm_static_handle(); + +/*! + * \brief Return a unique context id, used for hint of workspace separation. + * Different context id ganrantees not having overlapping workspace. + */ +TVM_DLL const Op& tvm_context_id(); + +/*! + * \brief tvm_tuple is not an actual function and cannot codegen. + * It is used to represent tuple structure in value field of AttrStmt, + * for the sake of giving hint to optimization. + * + * Handle tvm_tuple(value0, value1, ..., value_n); + */ +TVM_DLL const Op& tvm_tuple(); + +/*! + * \brief See pesudo code + * + * Type tvm_struct_get(StructType* arr, int index, int field_id) { + * return arr[index]->field; + * } + * \sa TVMStructFieldKind + */ +TVM_DLL const Op& tvm_struct_get(); + +/*! + * \brief See pesudo code + * + * Handle tvm_struct_set(StructType* arr, int index, int field_id, value) { + * arr[index]->field = value; + * } + * \sa TVMStructFieldKind + */ +TVM_DLL const Op& tvm_struct_set(); + +/*! + * \brief See pesudo code + * + * void tvm_throw_last_error() { + * throw TVMGetLastError(); + * } + */ +TVM_DLL const Op& tvm_throw_last_error(); + +/*! + * \brief See pesudo code + * + * dtype in {shape, array, arg_value, arg_tcode} + * + * Handle tvm_stack_alloca(string dtype, int num) { + * return new on stack dtype[num]; + * } + */ +TVM_DLL const Op& tvm_stack_alloca(); + +/*! + * \brief Allocate a shape tuple on stack, return the handle. + * + * Handle tvm_stack_make_shape(list args) { + * ret = alloca stack int64_t[len(args)]; + * for i in range(len(args)): + * ret[i] = args[i] + * return &ret[0]; + * } + */ +TVM_DLL const Op& tvm_stack_make_shape(); + +/*! + * \brief Allocate a NDArray(DLTensor) on stack, return the handle. + * + * Type tvm_stack_make_array(Expr data, + * Expr shape, + * Expr strides, + * Expr ndim, + * Expr dtype, + * Expr elem_offset) { + * ret = alloca stack DLTensor(); + * ret->data = data; + * ret->shape = shape; + * ret->strides = strides != 0 ? strides : nullptr; + * ret->ndim = ndim; + * ret->dtype = dtype.type(); + * ret->byte_offset = elem_offset * sizeof(dtype); + * return ret; + * } + */ +TVM_DLL const Op& tvm_stack_make_array(); + +/*! + * \brief See pesudo code + * + * int tvm_call_packed(name, TVMValue* args) { + * ModuleNode* env = GetCurrentEnv(); + * const PackedFunc* f = env->GetFuncFromEnv(name); + * (*f)(args, type_code_of(args), len(args)); + * return 0; + * } + */ +TVM_DLL const Op& tvm_call_packed(); + +/*! + * \brief See pesudo code + * + * int tvm_call_trace_packed(name, TVMValue* args) { + * ModuleNode* env = GetCurrentEnv(); + * const PackedFunc* f = env->GetFuncFromEnv(name); + * (*f)(args, type_code_of(args), len(args)); + * return 0; + * } + */ +TVM_DLL const Op& tvm_call_trace_packed(); + +/*! + * \brief See pesudo code + * Mark the content as thread local context, can get optimized + * by only call the call once at thread start. + * + * Do not allow nesting(getting a thread context from another). + * + * Handle tvm_thread_context(Expr call) { + * return call; + * } + */ +TVM_DLL const Op& tvm_thread_context(); + +/*! + * \brief Lowered version of call packed, the space of value and + * type codes are explicitly allocated. + * + * int tvm_call_packed_lowered(name, + * TVMValue* value_stack, + * int* tcode_stack, + * int begin, + * int end) { + * ModuleNode* env = GetCurrentEnv(); + * const PackedFunc* f = env->GetFuncFromEnv(name); + * f->CallPacked(TVMArgs(value_stack[begin:end], + * tcode_stack[begin:end]), + * TVMRetValue(value_stack + end, tcode_stack + end)); + * } + */ +TVM_DLL const Op& tvm_call_packed_lowered(); + +/*! + * \brief Lowered version of trace intrinsic, the space of value and + * type codes are explicitly allocated. The return value is the + * (end - 1) value on the stack. + * + * int tvm_call_trace_packed_lowered(name, + * TVMValue* value_stack, + * int* tcode_stack, + * int begin, + * int end) { + * ModuleNode* env = GetCurrentEnv(); + * const PackedFunc* f = env->GetFuncFromEnv(name); + * f->CallPacked(TVMArgs(value_stack[begin:end], + * tcode_stack[begin:end]), + * TVMRetValue(value_stack + end, tcode_stack + end)); + * } + */ +TVM_DLL const Op& tvm_call_trace_packed_lowered(); + +/*! + * \brief See pseudo code + * + * int tvm_storage_sync(std::string storage_scope) { + * __sync(storage_scope); + * return 0; + * } + */ +TVM_DLL const Op& tvm_storage_sync(); + +/*! + * \brief See pseudo code + * + * Type tvm_warp_shuffle(mask, Type value, warp_id, width, warp_size) { + * return (value passed in by warp indicated by this_warp_id); + * } + * + * Type tvm_warp_shuffle_up(mask, Type value, offset, width, warp_size) { + * return (value passed in by warp indicated by this_warp_id - offset); + * } + * + * Type tvm_warp_shuffle_down(mask, Type value, offset, width, warp_size) { + * return (value passed in by warp indicated by this_warp_id + offset); + * } + * + * unsigned tvm_warp_activemask() { + * return (32-bit mask of currently active threads in the calling warp); + * } + * + * Parameter warp_id indicates the source thread ID in a warp. + * + * Parameter offset indicates the relative distance to this_warp_id. + * + * Parameter width indicates the number of threads involved in one + * shuffle. See CUDA document for __shfl_sync, __shfl_up_sync, + * __shfl_down_sync and __activemask. + * + * Parameter warp_size is the size of a warp, which helps a backend + * to determine wheter the width paramter is legal. + * + */ +TVM_DLL const Op& tvm_warp_shuffle(); +TVM_DLL const Op& tvm_warp_shuffle_up(); +TVM_DLL const Op& tvm_warp_shuffle_down(); +TVM_DLL const Op& tvm_warp_activemask(); + +/*! + * \brief Initialize the global barrier. + * Call this at beginning of kernel that need global barrier. + */ +TVM_DLL const Op& tvm_global_barrier_kinit(); + +/*! + * \brief See pesudo code + * + * void tvm_thread_allreduce(UIntImm size, Expr source0, ..., Expr cond, + * Var reduce_temp0, .., Var thread_idx1, ...) { + * // constraint by the other thread_idx remain the same. + * // reduce_temp is used to save intermediate result. + * reduce_temp0, ... = reduce(combiner, source0, ..., cond + * over [thread_idx1, thread_idx2] passed by any caller) + * } + */ +TVM_DLL const Op& tvm_thread_allreduce(); + +// TODO(tvm-team) TensorCore specific intrinsics should be directly registered under +// cuda. namespace and used through op. +/*! + * \brief tvm intrinsic for tensor core load operators. + * + * void tvm_load_matrix_sync(Var fragment, UIntImm m, UIntImm, n, UIntImm k, + * Expr index, Expr buffer_ptr, Expr stride, + * StringImm layout) { + * // m, n, k are the shape of wmma fragment. + * // Determine fragment layout(column-major or row major) by layout. + * // fragments must be in 'wmma.matrix_a' or 'wmma.matrix_b' scope. + * nvcuda::wmma::load_matrix_sync(fragment[index], buffer_ptr, stride); + * } + */ +TVM_DLL const Op& tvm_load_matrix_sync(); + +/*! + * \brief tvm intrinsic for tensor core mma_sync operators. + * + * void tvm_mma_sync(Var fragment_d, Expr index_d, + * Var fragment_a, Expr index_a, + * Var fragment_b, Expr index_b, + * Var fragment_c, Expr index_c) { + * nvcuda::wmma::mma_sync(fragment_d[index_d], fragment_a[index_a], + * fragment_b[index_b], fragment_c[index_c]); + * } + */ +TVM_DLL const Op& tvm_mma_sync(); + +/*! + * \brief tvm intrinsic for tensor core bmma_sync operators. + * + * void tvm_bmma_sync(Var fragment_d, Expr index_d, + * Var fragment_a, Expr index_a, + * Var fragment_b, Expr index_b, + * Var fragment_c, Expr index_c) { + * nvcuda::wmma::bmma_sync(fragment_d[index_d], fragment_a[index_a], + * fragment_b[index_b], fragment_c[index_c]); + * } + */ +TVM_DLL const Op& tvm_bmma_sync(); + +/*! + * \brief tvm intrinsic for tensor core fill_fragment operators. + * + * void tvm_fill_fragment(Var fragment, UIntImm m, UIntImm, n, UIntImm k, + * Expr index, Expr value) { + * // m, n, k are the shape of wmma fragment + * // fragments must be in 'wmma.accumulator' scope. + * nvcuda::wmma::fill_fragment(fragment[index], value); + * } + */ +TVM_DLL const Op& tvm_fill_fragment(); + +/*! + * \brief tvm intrinsic for tensor core store operators. + * + * void tvm_store_matrix_sync(Var fragment, UIntImm m, UIntImm, n, UIntImm k, + * Expr index, Expr buffer_ptr, Expr stride, + * StringImm layout) { + * // m, n, k are the shape of wmma fragment + * // fragments must be in 'wmma.accumulator' scope. + * nvcuda::wmma::store_matrix_sync(fragment[index], buffer_ptr, stride, layout); + * } + */ +TVM_DLL const Op& tvm_store_matrix_sync(); + +// TODO(tvm-team) replace the usage of the vector operations by Shuffle. +/*! + * \brief Get the high level half of the vector + */ +TVM_DLL const Op& vectorhigh(); + +/*! + * \brief Get the low-level half of the vector + */ +TVM_DLL const Op& vectorlow(); + +/*! + * \brief Concat two vectors. + */ +TVM_DLL const Op& vectorcombine(); + +/*! \brief The kind of structure field info used in intrinsic */ +enum TVMStructFieldKind : int { + // array head address + kArrAddr, + kArrData, + kArrShape, + kArrStrides, + kArrNDim, + kArrTypeCode, + kArrTypeBits, + kArrTypeLanes, + kArrByteOffset, + kArrDeviceId, + kArrDeviceType, + kArrKindBound_, + // TVMValue field + kTVMValueContent, + kTVMValueKindBound_ +}; +} // namespace builtin +} // namespace tir +} // namespace tvm +#endif // TVM_TIR_BUILTIN_H_ diff --git a/include/tvm/tir/expr.h b/include/tvm/tir/expr.h index 1518d1f..a51f709 100644 --- a/include/tvm/tir/expr.h +++ b/include/tvm/tir/expr.h @@ -888,8 +888,14 @@ class CallNode : public PrimExprNode { /*! \brief Intrinsic functions that are pure. */ PureIntrinsic = 5 }; - /*! \brief The name of the function/intrinsic. */ - String name; + /*! + * \brief The operator(function) being invoked + * + * - It can be tvm::Op which corresponds to the primitive operators(intrinsics). + * - It can also be another function in the IRModule (GlobalVar). + */ + RelayExpr op; + /*! \brief The arguments. */ Array args; /*! \brief Type of calls. */ @@ -897,19 +903,19 @@ class CallNode : public PrimExprNode { void VisitAttrs(AttrVisitor* v) { v->Visit("dtype", &dtype); - v->Visit("name", &name); + v->Visit("op", &op); v->Visit("args", &args); v->Visit("call_type", &call_type); } bool SEqualReduce(const CallNode* other, SEqualReducer equal) const { - return equal(dtype, other->dtype) && equal(name, other->name) && equal(args, other->args) && + return equal(dtype, other->dtype) && equal(op, other->op) && equal(args, other->args) && equal(call_type, other->call_type); } void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(dtype); - hash_reduce(name); + hash_reduce(op); hash_reduce(args); hash_reduce(call_type); } @@ -917,37 +923,8 @@ class CallNode : public PrimExprNode { /*! \return Whether call node is pure. */ bool is_pure() const { return (call_type == PureExtern || call_type == PureIntrinsic); } - /*! - * \return Whether call node corresponds to a defined intrinsic. - * \param intrin_name The name of the intrinsic. - */ - bool is_intrinsic(const char* intrin_name) const { - return ((call_type == Intrinsic || call_type == PureIntrinsic) && name == intrin_name); - } - - /*! \return Whether call node can be vectorized. */ - bool is_vectorizable() const; - static constexpr const char* _type_key = "tir.Call"; TVM_DECLARE_FINAL_OBJECT_INFO(CallNode, PrimExprNode); - - // Build-in intrinsics - static constexpr const char* reinterpret = "reinterpret"; - static constexpr const char* bitwise_and = "bitwise_and"; - static constexpr const char* bitwise_not = "bitwise_not"; - static constexpr const char* bitwise_xor = "bitwise_xor"; - static constexpr const char* bitwise_or = "bitwise_or"; - static constexpr const char* shift_left = "shift_left"; - static constexpr const char* shift_right = "shift_right"; - static constexpr const char* popcount = "popcount"; - static constexpr const char* likely = "likely"; - static constexpr const char* prefetch = "prefetch"; - static constexpr const char* isnan = "isnan"; - static constexpr const char* isfinite = "isfinite"; - static constexpr const char* isinf = "isinf"; - - /*! \brief Vectorizable intrinsic list. */ - static const char* vectorizable_intrinsics[]; }; /*! @@ -958,7 +935,7 @@ class Call : public PrimExpr { public: using CallType = CallNode::CallType; - TVM_DLL Call(DataType dtype, String name, Array args, CallType call_type); + TVM_DLL Call(DataType dtype, RelayExpr op, Array args, CallType call_type); TVM_DEFINE_OBJECT_REF_METHODS(Call, PrimExpr, CallNode); }; @@ -1167,358 +1144,6 @@ inline std::unordered_map as_unordered_map(const Map& dmap) { } return ret; } - -/*! \brief namespace of TVM Intrinsic functions */ -namespace intrinsic { -/*! - * \brief See pesudo code - * - * Construct a big uint that may not be representable by int64 - * - * Expr tvm_large_uint_imm(uint32_t v0, uin32_t v1) { - * return (v1 << 32) | v0; - * } - */ -constexpr const char* tvm_large_uint_imm = "tvm_large_uint_imm"; -/*! - * \brief See pesudo code - * - * Handle tvm_address_of(Load *op) { - * return &op->buffer_var[index]; - * } - */ -constexpr const char* tvm_address_of = "tvm_address_of"; -/*! - * \brief Same as select, used for unsafe memory access. - * - * Type tvm_if_then_else(cond, a, b) { - * return cond ? a : b; - * } - */ -constexpr const char* tvm_if_then_else = "tvm_if_then_else"; -/*! - * \brief Get head access address with memory access pattern info. - * - * This operator also marks range of the memory access - * The offset and extent are in unit of the DType(including vectorization factor). - * rw_mask is a bit_mask setting whether the access is a read(1) or write(2). - * The access is assume to happen in the current expression. - * - * PtrType tvm_access_ptr(Expr dtype, DType* data, - * int offset, int extent, - * int rw_mask) { - * // DType == dtype.type(); - * return &data[offset]; - * } - */ -constexpr const char* tvm_access_ptr = "tvm_access_ptr"; -/*! - * \brief Create a function local static handle that iniitalizes to nullptr. - * can be used to cache function local static resources. - */ -constexpr const char* tvm_static_handle = "tvm_static_handle"; -/*! - * \brief Return a unique context id, used for hint of workspace separation. - * Different context id ganrantees not having overlapping workspace. - */ -constexpr const char* tvm_context_id = "tvm_context_id"; -/*! - * \brief tvm_tuple is not an actual function and cannot codegen. - * It is used to represent tuple structure in value field of AttrStmt, - * for the sake of giving hint to optimization. - * - * Handle tvm_tuple(value0, value1, ..., value_n); - */ -constexpr const char* tvm_tuple = "tvm_tuple"; -/*! - * \brief See pesudo code - * - * Type tvm_struct_get(StructType* arr, int index, int field_id) { - * return arr[index]->field; - * } - * \sa TVMStructFieldKind - */ -constexpr const char* tvm_struct_get = "tvm_struct_get"; -/*! - * \brief See pesudo code - * - * Handle tvm_struct_set(StructType* arr, int index, int field_id, value) { - * arr[index]->field = value; - * } - * \sa TVMStructFieldKind - */ -constexpr const char* tvm_struct_set = "tvm_struct_set"; -/*! - * \brief See pesudo code - * - * bool tvm_handle_is_null(void* handle) { - * return handle == nullptr - * } - */ -constexpr const char* tvm_handle_is_null = "tvm_handle_is_null"; -/*! - * \brief See pesudo code - * - * void tvm_throw_last_error() { - * throw TVMGetLastError(); - * } - */ -constexpr const char* tvm_throw_last_error = "tvm_throw_last_error"; -/*! - * \brief See pesudo code - * - * dtype in {shape, array, arg_value, arg_tcode} - * - * Handle tvm_stack_alloca(string dtype, int num) { - * return new on stack dtype[num]; - * } - */ -constexpr const char* tvm_stack_alloca = "tvm_stack_alloca"; -/*! - * \brief Allocate a shape tuple on stack, return the handle. - * - * Handle tvm_stack_make_shape(list args) { - * ret = alloca stack int64_t[len(args)]; - * for i in range(len(args)): - * ret[i] = args[i] - * return &ret[0]; - * } - */ -constexpr const char* tvm_stack_make_shape = "tvm_stack_make_shape"; -/*! - * \brief Allocate a NDArray(DLTensor) on stack, return the handle. - * - * Type tvm_stack_make_array(Expr data, - * Expr shape, - * Expr strides, - * Expr ndim, - * Expr dtype, - * Expr elem_offset) { - * ret = alloca stack DLTensor(); - * ret->data = data; - * ret->shape = shape; - * ret->strides = strides != 0 ? strides : nullptr; - * ret->ndim = ndim; - * ret->dtype = dtype.type(); - * ret->byte_offset = elem_offset * sizeof(dtype); - * return ret; - * } - */ -constexpr const char* tvm_stack_make_array = "tvm_stack_make_array"; -/*! - * \brief See pesudo code - * - * int tvm_call_packed(name, TVMValue* args) { - * ModuleNode* env = GetCurrentEnv(); - * const PackedFunc* f = env->GetFuncFromEnv(name); - * (*f)(args, type_code_of(args), len(args)); - * return 0; - * } - */ -constexpr const char* tvm_call_packed = "tvm_call_packed"; -/*! - * \brief See pesudo code - * - * int tvm_call_trace_packed(name, TVMValue* args) { - * ModuleNode* env = GetCurrentEnv(); - * const PackedFunc* f = env->GetFuncFromEnv(name); - * (*f)(args, type_code_of(args), len(args)); - * return 0; - * } - */ -constexpr const char* tvm_call_trace_packed = "tvm_call_trace_packed"; -/*! - * \brief See pesudo code - * Mark the content as thread local context, can get optimized - * by only call the call once at thread start. - * - * Do not allow nesting(getting a thread context from another). - * - * Handle tvm_thread_context(Expr call) { - * return call; - * } - */ -constexpr const char* tvm_thread_context = "tvm_thread_context"; -/*! - * \brief Lowered version of call packed, the space of value and - * type codes are explicitly allocated. - * - * int tvm_call_packed_lowered(name, - * TVMValue* value_stack, - * int* tcode_stack, - * int begin, - * int end) { - * ModuleNode* env = GetCurrentEnv(); - * const PackedFunc* f = env->GetFuncFromEnv(name); - * f->CallPacked(TVMArgs(value_stack[begin:end], - * tcode_stack[begin:end]), - * TVMRetValue(value_stack + end, tcode_stack + end)); - * } - */ -constexpr const char* tvm_call_packed_lowered = "tvm_call_packed_lowered"; -/*! - * \brief Lowered version of trace intrinsic, the space of value and - * type codes are explicitly allocated. The return value is the - * (end - 1) value on the stack. - * - * int tvm_call_trace_packed_lowered(name, - * TVMValue* value_stack, - * int* tcode_stack, - * int begin, - * int end) { - * ModuleNode* env = GetCurrentEnv(); - * const PackedFunc* f = env->GetFuncFromEnv(name); - * f->CallPacked(TVMArgs(value_stack[begin:end], - * tcode_stack[begin:end]), - * TVMRetValue(value_stack + end, tcode_stack + end)); - * } - */ -constexpr const char* tvm_call_trace_packed_lowered = "tvm_call_trace_packed_lowered"; -/*! - * \brief See pseudo code - * - * int tvm_storage_sync(std::string storage_scope) { - * __sync(storage_scope); - * return 0; - * } - */ -constexpr const char* tvm_storage_sync = "tvm_storage_sync"; - -/*! - * \brief See pseudo code - * - * Type tvm_warp_shuffle(mask, Type value, warp_id, width, warp_size) { - * return (value passed in by warp indicated by this_warp_id); - * } - * - * Type tvm_warp_shuffle_up(mask, Type value, offset, width, warp_size) { - * return (value passed in by warp indicated by this_warp_id - offset); - * } - * - * Type tvm_warp_shuffle_down(mask, Type value, offset, width, warp_size) { - * return (value passed in by warp indicated by this_warp_id + offset); - * } - * - * unsigned tvm_warp_activemask() { - * return (32-bit mask of currently active threads in the calling warp); - * } - * - * Parameter warp_id indicates the source thread ID in a warp. - * - * Parameter offset indicates the relative distance to this_warp_id. - * - * Parameter width indicates the number of threads involved in one - * shuffle. See CUDA document for __shfl_sync, __shfl_up_sync, - * __shfl_down_sync and __activemask. - * - * Parameter warp_size is the size of a warp, which helps a backend - * to determine wheter the width paramter is legal. - * - */ -constexpr const char* tvm_warp_shuffle = "tvm_warp_shuffle"; -constexpr const char* tvm_warp_shuffle_up = "tvm_warp_shuffle_up"; -constexpr const char* tvm_warp_shuffle_down = "tvm_warp_shuffle_down"; -constexpr const char* tvm_warp_activemask = "tvm_warp_activemask"; - -/*! - * \brief Initialize the global barrier. - * Call this at beginning of kernel that need global barrier. - */ -constexpr const char* tvm_global_barrier_kinit = "tvm_global_barrier_kinit"; -/*! - * \brief See pesudo code - * - * void tvm_thread_allreduce(UIntImm size, Expr source0, ..., Expr cond, - * Var reduce_temp0, .., Var thread_idx1, ...) { - * // constraint by the other thread_idx remain the same. - * // reduce_temp is used to save intermediate result. - * reduce_temp0, ... = reduce(combiner, source0, ..., cond - * over [thread_idx1, thread_idx2] passed by any caller) - * } - */ -constexpr const char* tvm_thread_allreduce = "tvm_thread_allreduce"; -/*! - * \brief tvm intrinsic for tensor core load operators. - * - * void tvm_load_matrix_sync(Var fragment, UIntImm m, UIntImm, n, UIntImm k, - * Expr index, Expr buffer_ptr, Expr stride, - * StringImm layout) { - * // m, n, k are the shape of wmma fragment. - * // Determine fragment layout(column-major or row major) by layout. - * // fragments must be in 'wmma.matrix_a' or 'wmma.matrix_b' scope. - * nvcuda::wmma::load_matrix_sync(fragment[index], buffer_ptr, stride); - * } - */ -constexpr const char* tvm_load_matrix_sync = "tvm_load_matrix_sync"; -/*! - * \brief tvm intrinsic for tensor core mma_sync operators. - * - * void tvm_mma_sync(Var fragment_d, Expr index_d, - * Var fragment_a, Expr index_a, - * Var fragment_b, Expr index_b, - * Var fragment_c, Expr index_c) { - * nvcuda::wmma::mma_sync(fragment_d[index_d], fragment_a[index_a], - * fragment_b[index_b], fragment_c[index_c]); - * } - */ -constexpr const char* tvm_mma_sync = "tvm_mma_sync"; -/*! - * \brief tvm intrinsic for tensor core bmma_sync operators. - * - * void tvm_bmma_sync(Var fragment_d, Expr index_d, - * Var fragment_a, Expr index_a, - * Var fragment_b, Expr index_b, - * Var fragment_c, Expr index_c) { - * nvcuda::wmma::bmma_sync(fragment_d[index_d], fragment_a[index_a], - * fragment_b[index_b], fragment_c[index_c]); - * } - */ -constexpr const char* tvm_bmma_sync = "tvm_bmma_sync"; -/*! - * \brief tvm intrinsic for tensor core fill_fragment operators. - * - * void tvm_fill_fragment(Var fragment, UIntImm m, UIntImm, n, UIntImm k, - * Expr index, Expr value) { - * // m, n, k are the shape of wmma fragment - * // fragments must be in 'wmma.accumulator' scope. - * nvcuda::wmma::fill_fragment(fragment[index], value); - * } - */ -constexpr const char* tvm_fill_fragment = "tvm_fill_fragment"; -/*! - * \brief tvm intrinsic for tensor core store operators. - * - * void tvm_store_matrix_sync(Var fragment, UIntImm m, UIntImm, n, UIntImm k, - * Expr index, Expr buffer_ptr, Expr stride, - * StringImm layout) { - * // m, n, k are the shape of wmma fragment - * // fragments must be in 'wmma.accumulator' scope. - * nvcuda::wmma::store_matrix_sync(fragment[index], buffer_ptr, stride, layout); - * } - */ -constexpr const char* tvm_store_matrix_sync = "tvm_store_matrix_sync"; - -/*! \brief The kind of structure field info used in intrinsic */ -enum TVMStructFieldKind : int { - // array head address - kArrAddr, - kArrData, - kArrShape, - kArrStrides, - kArrNDim, - kArrTypeCode, - kArrTypeBits, - kArrTypeLanes, - kArrByteOffset, - kArrDeviceId, - kArrDeviceType, - kArrKindBound_, - // TVMValue field - kTVMValueContent, - kTVMValueKindBound_ -}; -} // namespace intrinsic - } // namespace tir } // namespace tvm diff --git a/include/tvm/tir/function.h b/include/tvm/tir/function.h index 919391e..caddd99 100644 --- a/include/tvm/tir/function.h +++ b/include/tvm/tir/function.h @@ -87,8 +87,6 @@ class PrimFuncNode : public BaseFuncNode { * While we could have express parameter unpacking and constraint using * normal statements, making buffer_map as first class citizen of PrimFunc * will make program analysis much easier. - * - * \note This field can be nullptr */ Map buffer_map; @@ -144,7 +142,7 @@ class PrimFunc : public BaseFunc { * \param attrs Additional function attributes. */ TVM_DLL PrimFunc(Array params, Stmt body, Type ret_type = VoidType(), - Map buffer_map = NullValue>(), + Map buffer_map = Map(), DictAttrs attrs = NullValue()); TVM_DEFINE_OBJECT_REF_METHODS(PrimFunc, BaseFunc, PrimFuncNode); diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h index 2948bb2..286b6d7 100644 --- a/include/tvm/tir/op.h +++ b/include/tvm/tir/op.h @@ -28,6 +28,7 @@ #ifndef TVM_TIR_OP_H_ #define TVM_TIR_OP_H_ +#include #include #include #include @@ -552,9 +553,10 @@ TVM_DLL PrimExpr trunc(PrimExpr x); TVM_DLL PrimExpr LargeUIntImm(DataType dtype, int64_t low, int64_t high); // Intrinsic operators -#define TVM_DECLARE_INTRIN_UNARY(OpName) \ - inline PrimExpr OpName(PrimExpr x) { \ - return tir::Call(x.dtype(), #OpName, {x}, tir::CallNode::PureIntrinsic); \ +#define TVM_DECLARE_INTRIN_UNARY(OpName) \ + inline PrimExpr OpName(PrimExpr x) { \ + static const Op& op = Op::Get("tir." #OpName); \ + return tir::Call(x.dtype(), op, {x}, tir::CallNode::PureIntrinsic); \ } TVM_DECLARE_INTRIN_UNARY(exp); diff --git a/include/tvm/tir/op_attr_types.h b/include/tvm/tir/op_attr_types.h new file mode 100644 index 0000000..d7c1350 --- /dev/null +++ b/include/tvm/tir/op_attr_types.h @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/tir/op_attr_types.h + * \brief Attribute types in the Op registry for TIR ops. + * + * These attributes can be set via OpRegEntry::set_attr + * + * \sa tvm/ir/op.h + */ +#ifndef TVM_TIR_OP_ATTR_TYPES_H_ +#define TVM_TIR_OP_ATTR_TYPES_H_ + +#include + +namespace tvm { +namespace tir { + +/*! + * \brief Global symbol of the op after lowering. + */ +using TGlobalSymbol = String; + +/*! + * \brief Whether the op is overloaded for vector form. + */ +using TVectorizable = bool; + +} // namespace tir +} // namespace tvm +#endif // TVM_TIR_OP_ATTR_TYPES_H_ diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index be1c567..b928aec 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -1004,9 +1004,7 @@ inline bool IsPragmaKey(const std::string& attr_key) { * \param dtype The data type * \return Expr a expression with dtype. */ -inline PrimExpr TypeAnnotation(DataType dtype) { - return tir::Call(dtype, "type_annotation", {}, tir::CallNode::PureIntrinsic); -} +TVM_DLL PrimExpr TypeAnnotation(DataType dtype); // overload printing of for type. TVM_DLL std::ostream& operator<<(std::ostream& os, ForType for_type); diff --git a/python/tvm/contrib/nvcc.py b/python/tvm/contrib/nvcc.py index fc82320..8c3d34a 100644 --- a/python/tvm/contrib/nvcc.py +++ b/python/tvm/contrib/nvcc.py @@ -98,7 +98,8 @@ def compile_cuda(code, (out, _) = proc.communicate() if proc.returncode != 0: - msg = "Compilation error:\n" + msg = code + msg += "\nCompilation error:\n" msg += py_str(out) raise RuntimeError(msg) diff --git a/python/tvm/target/datatype.py b/python/tvm/target/datatype.py index e42ac6b..f93a943 100644 --- a/python/tvm/target/datatype.py +++ b/python/tvm/target/datatype.py @@ -18,8 +18,9 @@ import tvm._ffi import tvm.runtime._ffi_api -from tvm.runtime import convert, DataType -from tvm.tir.expr import Call as _Call, Cast as _Cast, FloatImm as _FloatImm +from tvm.runtime import DataType +import tvm.tir +from tvm.tir.expr import Cast as _Cast, FloatImm as _FloatImm def register(type_name, type_code): @@ -135,9 +136,7 @@ def create_lower_func(extern_func_name): if t.lanes > 1: dtype += "x" + str(t.lanes) if isinstance(op, (_Cast, _FloatImm)): - return _Call(dtype, extern_func_name, convert([op.value]), - _Call.Extern) - return _Call(dtype, extern_func_name, convert([op.a, op.b]), - _Call.Extern) + return tvm.tir.call_pure_extern(dtype, extern_func_name, op.value) + return tvm.tir.call_pure_extern(dtype, extern_func_name, op.a, op.b) return lower diff --git a/python/tvm/target/intrin.py b/python/tvm/target/intrin.py index acb0efe..78da8a6 100644 --- a/python/tvm/target/intrin.py +++ b/python/tvm/target/intrin.py @@ -83,10 +83,14 @@ def _rule_float_suffix(op): -------- register_intrin_rule : The registeration function for intrin rule. """ + name = op.name + assert name.startswith("tir.") + prefix = name[4:] + if op.dtype == "float32": - return call_pure_extern(op.dtype, "%sf" % op.name, *op.args) + return call_pure_extern(op.dtype, "%sf" % prefix, *op.args) if op.dtype == "float64": - return call_pure_extern(op.dtype, op.name, *op.args) + return call_pure_extern(op.dtype, prefix, *op.args) return op @@ -111,7 +115,7 @@ def _rule_float_direct(op): register_intrin_rule : The registeration function for intrin rule. """ if str(op.dtype).startswith("float"): - return call_pure_extern(op.dtype, op.name, *op.args) + return call_pure_extern(op.dtype, op.op.name[4:], *op.args) return None # opencl pattern for exp diff --git a/python/tvm/te/hybrid/calls.py b/python/tvm/te/hybrid/calls.py index dfbb185..a119c20 100644 --- a/python/tvm/te/hybrid/calls.py +++ b/python/tvm/te/hybrid/calls.py @@ -148,7 +148,7 @@ def likely(func_id, args): _internal_assert(args.__len__() == 1, \ "Only one expression can be likely") _internal_assert(func_id == "likely", "This function cannot be directly invoked!") - return call_pure_intrin(args[0].dtype, 'likely', *args) + return call_pure_intrin(args[0].dtype, 'tir.likely', *args) def max_num_threads(func_id, args): diff --git a/python/tvm/tir/expr.py b/python/tvm/tir/expr.py index 3b580ef..386badf 100644 --- a/python/tvm/tir/expr.py +++ b/python/tvm/tir/expr.py @@ -30,7 +30,7 @@ For example, you can use addexp.a to get the left operand of an Add node. import tvm._ffi from tvm.runtime import Object, ObjectGeneric, DataType, DataTypeCode, const -from tvm.ir import PrimExpr +from tvm.ir import PrimExpr, Op import tvm.ir._ffi_api from . import generic as _generic from . import _ffi_api @@ -144,7 +144,7 @@ class ExprOp(object): def __invert__(self): if _dtype_is_float(self): raise RuntimeError("Cannot use ~ operator on float type Expr.") - return _ffi_api.Call(self.dtype, "bitwise_not", [self], Call.PureIntrinsic) + return _ffi_api.bitwise_not(self) def __lt__(self, other): return _ffi_api._OpLT(self, other) @@ -968,8 +968,9 @@ class Call(PrimExprWithOp): dtype : str The return data type - name : str - The name of the function + op : Union[RelayExpr, str] + The function to be called, or the name + to the global tvm.Op args : list of Expr The input arguments to the call @@ -982,9 +983,16 @@ class Call(PrimExprWithOp): PureExtern = 2 Intrinsic = 4 PureIntrinsic = 5 - def __init__(self, dtype, name, args, call_type): - self.__init_handle_by_constructor__( - _ffi_api.Call, dtype, name, args, call_type) + def __init__(self, dtype, op, args, call_type): + if isinstance(op, str): + if not op.startswith("tir."): + raise ValueError( + ("Cannot handle str op argument %s. This function only handles str " + + "argument with the tir namespace. If you are " + + "certain about the intrinsic name, pass in Op.get(name) instead") % op) + op = Op.get(op) + self.__init_handle_by_constructor__( + _ffi_api.Call, dtype, op, args, call_type) @tvm._ffi.register_object("tir.Let") diff --git a/python/tvm/tir/ir_builder.py b/python/tvm/tir/ir_builder.py index 47ba2e2..089127c 100644 --- a/python/tvm/tir/ir_builder.py +++ b/python/tvm/tir/ir_builder.py @@ -379,7 +379,7 @@ class IRBuilder(object): expr : Expr The expression will likely tag. """ - return _expr.Call(expr.dtype, "likely", [expr], + return _expr.Call(expr.dtype, "tir.likely", [expr], _expr.Call.PureIntrinsic) def get(self): diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py index 929d422..6826241 100644 --- a/python/tvm/tir/op.py +++ b/python/tvm/tir/op.py @@ -18,10 +18,10 @@ """Operators used in TIR expression.""" import tvm._ffi from tvm.runtime import convert, const -from tvm.ir import Array +from tvm.ir import Array, Op from .buffer import Buffer -from .expr import Call, Var, CommReducer +from .expr import Call, StringImm, Var, CommReducer from . import _ffi_api @@ -29,9 +29,9 @@ def _pack_buffer(buf): """Build intrinsics that packs the buffer. """ assert buf.shape - shape = Call("handle", "tvm_stack_make_shape", buf.shape, + shape = Call("handle", "tir.tvm_stack_make_shape", buf.shape, Call.Intrinsic) - strides = Call("handle", "tvm_stack_make_shape", buf.strides, + strides = Call("handle", "tir.tvm_stack_make_shape", buf.strides, Call.Intrinsic) if buf.strides else 0 pack_args = [buf.data, shape, @@ -39,7 +39,7 @@ def _pack_buffer(buf): len(buf.shape), const(0, dtype=buf.dtype), buf.elem_offset] - return Call("handle", "tvm_stack_make_array", + return Call("handle", Op.get("tir.tvm_stack_make_array"), pack_args, Call.Intrinsic) def call_packed(*args): @@ -68,7 +68,7 @@ def call_packed(*args): """ call_args = [_pack_buffer(x) if isinstance(x, Buffer) else x for x in args] return Call( - "int32", "tvm_call_packed", call_args, Call.Intrinsic) + "int32", Op.get("tir.tvm_call_packed"), call_args, Call.Intrinsic) def call_pure_intrin(dtype, func_name, *args): @@ -145,7 +145,7 @@ def call_pure_extern(dtype, func_name, *args): The call expression. """ return Call( - dtype, func_name, convert(args), Call.PureExtern) + dtype, Op.get("tir.call_extern"), convert((StringImm(func_name),) + args), Call.PureExtern) def call_extern(dtype, func_name, *args): @@ -168,7 +168,7 @@ def call_extern(dtype, func_name, *args): The call expression. """ return Call( - dtype, func_name, convert(args), Call.Extern) + dtype, Op.get("tir.call_extern"), convert((StringImm(func_name),) + args), Call.Extern) def call_llvm_intrin(dtype, name, *args): @@ -194,7 +194,8 @@ def call_llvm_intrin(dtype, name, *args): from tvm.target import codegen llvm_id = codegen.llvm_lookup_intrinsic_id(name) assert llvm_id != 0, "%s is not an LLVM intrinsic" % name - return call_pure_intrin(dtype, 'llvm_intrin', tvm.tir.const(llvm_id, 'uint32'), *args) + return call_pure_intrin(dtype, Op.get("tir.call_llvm_intrin"), + tvm.tir.const(llvm_id, 'uint32'), *args) def any(*args): @@ -278,7 +279,7 @@ def trace(args, trace_action="tvm.default_trace_action"): call_args = [_pack_buffer(x) if isinstance(x, Buffer) else x for x in args] call_args.insert(0, trace_action) return tvm.tir.Call( - args[-1].dtype, "tvm_call_trace_packed", call_args, tvm.tir.Call.Intrinsic) + args[-1].dtype, Op.get("tir.tvm_call_trace_packed"), call_args, tvm.tir.Call.Intrinsic) @@ -327,7 +328,7 @@ def exp(x): y : PrimExpr The result. """ - return call_pure_intrin(x.dtype, "exp", x) + return call_pure_intrin(x.dtype, "tir.exp", x) def exp2(x): @@ -343,7 +344,7 @@ def exp2(x): y : PrimExpr The result. """ - return call_pure_intrin(x.dtype, "exp2", x) + return call_pure_intrin(x.dtype, "tir.exp2", x) def exp10(x): @@ -359,7 +360,7 @@ def exp10(x): y : PrimExpr The result. """ - return call_pure_intrin(x.dtype, "exp10", x) + return call_pure_intrin(x.dtype, "tir.exp10", x) def erf(x): @@ -375,7 +376,7 @@ def erf(x): y : PrimExpr The result. """ - return call_pure_intrin(x.dtype, "erf", x) + return call_pure_intrin(x.dtype, "tir.erf", x) def tanh(x): @@ -391,7 +392,7 @@ def tanh(x): y : PrimExpr The result. """ - return call_pure_intrin(x.dtype, "tanh", x) + return call_pure_intrin(x.dtype, "tir.tanh", x) def sigmoid(x): @@ -407,7 +408,7 @@ def sigmoid(x): y : PrimExpr The result. """ - return call_pure_intrin(x.dtype, "sigmoid", x) + return call_pure_intrin(x.dtype, "tir.sigmoid", x) def log(x): @@ -423,7 +424,7 @@ def log(x): y : PrimExpr The result. """ - return call_pure_intrin(x.dtype, "log", x) + return call_pure_intrin(x.dtype, "tir.log", x) def log2(x): @@ -439,7 +440,7 @@ def log2(x): y : PrimExpr The result. """ - return call_pure_intrin(x.dtype, "log2", x) + return call_pure_intrin(x.dtype, "tir.log2", x) def log10(x): @@ -455,7 +456,7 @@ def log10(x): y : PrimExpr The result. """ - return call_pure_intrin(x.dtype, "log10", x) + return call_pure_intrin(x.dtype, "tir.log10", x) def log1p(x): @@ -471,7 +472,7 @@ def log1p(x): y : PrimExpr The result. """ - return call_pure_intrin(x.dtype, "log1p", x) + return call_pure_intrin(x.dtype, "tir.log1p", x) def tan(x): @@ -487,7 +488,7 @@ def tan(x): y : PrimExpr The result. """ - return call_pure_intrin(x.dtype, "tan", x) + return call_pure_intrin(x.dtype, "tir.tan", x) def cos(x): @@ -503,7 +504,7 @@ def cos(x): y : PrimExpr The result. """ - return call_pure_intrin(x.dtype, "cos", x) + return call_pure_intrin(x.dtype, "tir.cos", x) def cosh(x): @@ -519,7 +520,7 @@ def cosh(x): y : PrimExpr The result. """ - return call_pure_intrin(x.dtype, "cosh", x) + return call_pure_intrin(x.dtype, "tir.cosh", x) def acos(x): @@ -535,7 +536,7 @@ def acos(x): y : PrimExpr The result. """ - return call_pure_intrin(x.dtype, "acos", x) + return call_pure_intrin(x.dtype, "tir.acos", x) def acosh(x): @@ -551,7 +552,7 @@ def acosh(x): y : PrimExpr The result. """ - return call_pure_intrin(x.dtype, "acosh", x) + return call_pure_intrin(x.dtype, "tir.acosh", x) def sin(x): @@ -567,7 +568,7 @@ def sin(x): y : PrimExpr The result. """ - return call_pure_intrin(x.dtype, "sin", x) + return call_pure_intrin(x.dtype, "tir.sin", x) def sinh(x): @@ -583,7 +584,7 @@ def sinh(x): y : PrimExpr The result. """ - return call_pure_intrin(x.dtype, "sinh", x) + return call_pure_intrin(x.dtype, "tir.sinh", x) def asin(x): @@ -599,7 +600,7 @@ def asin(x): y : PrimExpr The result. """ - return call_pure_intrin(x.dtype, "asin", x) + return call_pure_intrin(x.dtype, "tir.asin", x) def asinh(x): @@ -615,7 +616,7 @@ def asinh(x): y : PrimExpr The result. """ - return call_pure_intrin(x.dtype, "asinh", x) + return call_pure_intrin(x.dtype, "tir.asinh", x) def atan(x): @@ -631,7 +632,7 @@ def atan(x): y : PrimExpr The result. """ - return call_pure_intrin(x.dtype, "atan", x) + return call_pure_intrin(x.dtype, "tir.atan", x) def atanh(x): @@ -647,7 +648,7 @@ def atanh(x): y : PrimExpr The result. """ - return call_pure_intrin(x.dtype, "atanh", x) + return call_pure_intrin(x.dtype, "tir.atanh", x) def atan2(x1, x2): @@ -666,7 +667,7 @@ def atan2(x1, x2): y : PrimExpr The result. """ - return call_pure_intrin(x1.dtype, "atan2", x1, x2) + return call_pure_intrin(x1.dtype, "tir.atan2", x1, x2) def sqrt(x): @@ -682,7 +683,7 @@ def sqrt(x): y : PrimExpr The result. """ - return call_pure_intrin(x.dtype, "sqrt", x) + return call_pure_intrin(x.dtype, "tir.sqrt", x) def rsqrt(x): @@ -698,7 +699,7 @@ def rsqrt(x): y : PrimExpr The result. """ - return call_pure_intrin(x.dtype, "rsqrt", x) + return call_pure_intrin(x.dtype, "tir.rsqrt", x) def floor(x): @@ -823,7 +824,7 @@ def nextafter(x1, x2): y : PrimExpr The result. """ - return call_pure_intrin(x1.dtype, "nextafter", x1, x2) + return call_pure_intrin(x1.dtype, "tir.nextafter", x1, x2) def hypot(x1, x2): @@ -842,7 +843,7 @@ def hypot(x1, x2): y : PrimExpr The result. """ - return call_pure_intrin(x1.dtype, "hypot", x1, x2) + return call_pure_intrin(x1.dtype, "tir.hypot", x1, x2) def copysign(x1, x2): @@ -861,7 +862,7 @@ def copysign(x1, x2): y : PrimExpr The result. """ - return call_pure_intrin(x1.dtype, "copysign", x1, x2) + return call_pure_intrin(x1.dtype, "tir.copysign", x1, x2) def ldexp(x1, x2): @@ -880,7 +881,7 @@ def ldexp(x1, x2): y : PrimExpr The result. """ - return call_pure_intrin(x1.dtype, "ldexp", x1, x2) + return call_pure_intrin(x1.dtype, "tir.ldexp", x1, x2) def isnan(x): @@ -963,7 +964,7 @@ def popcount(x): y : PrimExpr The result. """ - return call_pure_intrin(x.dtype, "popcount", x) + return call_pure_intrin(x.dtype, "tir.popcount", x) def fmod(x, y): """Return the remainder of x divided by y with the same sign as x. @@ -980,7 +981,7 @@ def fmod(x, y): z : PrimExpr The result. """ - return call_pure_intrin(x.dtype, "fmod", x, y) + return call_pure_intrin(x.dtype, "tir.fmod", x, y) def if_then_else(cond, t, f): diff --git a/src/arith/const_int_bound.cc b/src/arith/const_int_bound.cc index c33990c..8c90249 100644 --- a/src/arith/const_int_bound.cc +++ b/src/arith/const_int_bound.cc @@ -22,6 +22,7 @@ */ #include #include +#include #include #include @@ -284,9 +285,10 @@ class ConstIntBoundAnalyzer::Impl Entry VisitExpr_(const CallNode* op) final { // only special handle >> and & which can be // used for index calculation. - if (op->is_intrinsic(CallNode::shift_right)) { + + if (op->op.same_as(tir::builtin::shift_right())) { return VisitRightShift(op); - } else if (op->is_intrinsic(CallNode::bitwise_and)) { + } else if (op->op.same_as(tir::builtin::bitwise_and())) { return VisitBitwiseAnd(op); } else { return Everything(op->dtype); diff --git a/src/arith/ir_mutator_with_analyzer.cc b/src/arith/ir_mutator_with_analyzer.cc index 84e2093..c367d0c 100644 --- a/src/arith/ir_mutator_with_analyzer.cc +++ b/src/arith/ir_mutator_with_analyzer.cc @@ -56,8 +56,10 @@ Stmt IRMutatorWithAnalyzer::VisitStmt_(const LetStmtNode* op) { Stmt IRMutatorWithAnalyzer::VisitStmt_(const IfThenElseNode* op) { PrimExpr condition = this->VisitExpr(op->condition); PrimExpr real_condition = condition; + static auto op_likely = Op::Get("tir.likely"); + if (auto call = condition.as()) { - if (call->is_intrinsic(CallNode::likely)) { + if (call->op.same_as(op_likely)) { real_condition = call->args[0]; } } @@ -122,7 +124,8 @@ Stmt IRMutatorWithAnalyzer::VisitStmt_(const AssertStmtNode* op) { PrimExpr IRMutatorWithAnalyzer::VisitExpr_(const CallNode* op) { // add condition context to if_then_else - if (op->is_intrinsic(tir::intrinsic::tvm_if_then_else)) { + static auto op_if_then_else = Op::Get("tir.if_then_else"); + if (op->op.same_as(op_if_then_else)) { PrimExpr cond = this->VisitExpr(op->args[0]); PrimExpr true_value, false_value; { @@ -143,7 +146,7 @@ PrimExpr IRMutatorWithAnalyzer::VisitExpr_(const CallNode* op) { false_value.same_as(op->args[2])) { return GetRef(op); } else { - return Call(op->dtype, op->name, {cond, true_value, false_value}, op->call_type); + return Call(op->dtype, op->op, {cond, true_value, false_value}, op->call_type); } } return StmtExprMutator::VisitExpr_(op); diff --git a/src/arith/modular_set.cc b/src/arith/modular_set.cc index 3457674..108f08c 100644 --- a/src/arith/modular_set.cc +++ b/src/arith/modular_set.cc @@ -23,6 +23,7 @@ */ #include #include +#include #include #include @@ -203,7 +204,7 @@ class ModularSetAnalyzer::Impl : public ExprFunctor> which can be // used for index calculation. - if (op->is_intrinsic(CallNode::shift_right)) { + if (op->op.same_as(tir::builtin::shift_right())) { return VisitRightShift(op); } else { return Everything(); diff --git a/src/arith/pattern_match.h b/src/arith/pattern_match.h index ff01941..de84251 100644 --- a/src/arith/pattern_match.h +++ b/src/arith/pattern_match.h @@ -66,6 +66,7 @@ #define TVM_ARITH_PATTERN_MATCH_H_ #include +#include #include #include @@ -655,7 +656,7 @@ class PCallExpr : public Pattern> { bool Match_(const ObjectRef& node) const { if (const tir::CallNode* ptr = node.as()) { if (ptr->args.size() != sizeof...(TArgs)) return false; - if (ptr->name != Op::kName) return false; + if (!ptr->op.same_as(Op::GetOp())) return false; detail::PCallExprMatchFunctor fmatch(ptr); detail::tuple_for_each(fmatch, args_); return fmatch.matched_; @@ -675,45 +676,45 @@ class PCallExpr : public Pattern> { }; // arithemetic intrinsics -#define TVM_PATTERN_BINARY_INTRIN(FuncName, OpName, IntrinStr) \ +#define TVM_PATTERN_BINARY_INTRIN(FuncName, OpName, IntrinOpName) \ struct OpName { \ static PrimExpr Eval(Array args) { \ - return tir::Call(args[0].dtype(), kName, args, tir::CallNode::PureIntrinsic); \ + return tir::Call(args[0].dtype(), GetOp(), args, tir::CallNode::PureIntrinsic); \ } \ - static constexpr const char* kName = IntrinStr; \ + static const Op& GetOp() { return tir::builtin::IntrinOpName(); } \ }; \ template \ inline PCallExpr FuncName(const Pattern& a, const Pattern& b) { \ return PCallExpr(a.derived(), b.derived()); \ } -TVM_PATTERN_BINARY_INTRIN(operator<<, PLeftShiftOp, "shift_left"); -TVM_PATTERN_BINARY_INTRIN(operator>>, PRightShiftOp, "shift_right"); -TVM_PATTERN_BINARY_INTRIN(operator&, PBitwiseAndOp, "bitwise_and"); -TVM_PATTERN_BINARY_INTRIN(operator|, PBitwiseOrOp, "bitwise_or"); -TVM_PATTERN_BINARY_INTRIN(operator^, PBitwiseXorOp, "bitwise_xor"); +TVM_PATTERN_BINARY_INTRIN(operator<<, PLeftShiftOp, shift_left); +TVM_PATTERN_BINARY_INTRIN(operator>>, PRightShiftOp, shift_right); +TVM_PATTERN_BINARY_INTRIN(operator&, PBitwiseAndOp, bitwise_and); +TVM_PATTERN_BINARY_INTRIN(operator|, PBitwiseOrOp, bitwise_or); +TVM_PATTERN_BINARY_INTRIN(operator^, PBitwiseXorOp, bitwise_xor); // unary intrinsics -#define TVM_PATTERN_UNARY_INTRIN(FuncName, OpName, IntrinStr) \ - struct OpName { \ - static PrimExpr Eval(Array args) { \ - return tir::Call(args[0].dtype(), kName, args, tir::CallNode::PureIntrinsic); \ - } \ - static constexpr const char* kName = IntrinStr; \ - }; \ - template \ - inline PCallExpr FuncName(const Pattern& a) { \ - return PCallExpr(a.derived()); \ +#define TVM_PATTERN_UNARY_INTRIN(FuncName, OpName, IntrinOpName) \ + struct OpName { \ + static PrimExpr Eval(Array args) { \ + return tir::Call(args[0].dtype(), GetOp(), args, tir::CallNode::PureIntrinsic); \ + } \ + static const Op& GetOp() { return tir::builtin::IntrinOpName(); } \ + }; \ + template \ + inline PCallExpr FuncName(const Pattern& a) { \ + return PCallExpr(a.derived()); \ } -TVM_PATTERN_UNARY_INTRIN(operator~, PBitwiseNotOp, "bitwise_not"); +TVM_PATTERN_UNARY_INTRIN(operator~, PBitwiseNotOp, bitwise_not); // if_then_else struct PIfThenElseOp { static PrimExpr Eval(Array args) { - return tir::Call(args[1].dtype(), kName, args, tir::CallNode::PureIntrinsic); + return tir::Call(args[1].dtype(), GetOp(), args, tir::CallNode::PureIntrinsic); } - static constexpr const char* kName = "tvm_if_then_else"; + static const Op& GetOp() { return tir::builtin::if_then_else(); } }; /*! diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index 4887ef0..6758c9b 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -25,6 +25,7 @@ #include "rewrite_simplify.h" #include +#include #include #include @@ -1508,21 +1509,22 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const CallNode* op) { PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); if (op == nullptr) return ret; - if (op->is_intrinsic(CallNode::likely) && is_const(op->args[0])) { + + if (op->op.same_as(tir::builtin::likely()) && is_const(op->args[0])) { return op->args[0]; - } else if (op->is_intrinsic(CallNode::shift_right)) { + } else if (op->op.same_as(tir::builtin::shift_right())) { if (op->args[0].as() && op->args[1].as()) { // the operator overload will eagerly constant fold. return op->args[0] >> op->args[1]; } - } else if (op->is_intrinsic(CallNode::bitwise_and)) { + } else if (op->op.same_as(tir::builtin::shift_left())) { if (op->args[0].as() && op->args[1].as()) { // the operator overload will eagerly constant fold. return op->args[0] & op->args[1]; } } ExprDeepEqual expr_equal; - if (op->is_intrinsic(CallNode::likely)) { + if (op->op.same_as(tir::builtin::likely())) { for (const auto& constraint : literal_constraints_) { // Cases such as for (i, 0, bound) {if (likely(iter_var < bound)) { .. } } if (expr_equal(constraint, op->args[0])) { diff --git a/src/contrib/hybrid/codegen_hybrid.cc b/src/contrib/hybrid/codegen_hybrid.cc index e08f39f..0d5d654 100644 --- a/src/contrib/hybrid/codegen_hybrid.cc +++ b/src/contrib/hybrid/codegen_hybrid.cc @@ -23,6 +23,7 @@ #include "codegen_hybrid.h" #include +#include #include #include @@ -216,29 +217,43 @@ void CodeGenHybrid::VisitExpr_(const ProducerLoadNode* op, std::ostream& os) { os << "]"; } void CodeGenHybrid::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*) - if (op->is_intrinsic(CallNode::bitwise_and)) { + if (op->op.same_as(builtin::bitwise_and())) { PrintBinaryIntrinsitc(op, "&", os, this); - } else if (op->is_intrinsic(CallNode::bitwise_xor)) { + } else if (op->op.same_as(builtin::bitwise_xor())) { PrintBinaryIntrinsitc(op, "^", os, this); - } else if (op->is_intrinsic(CallNode::bitwise_or)) { + } else if (op->op.same_as(builtin::bitwise_or())) { PrintBinaryIntrinsitc(op, "|", os, this); - } else if (op->is_intrinsic(CallNode::shift_left)) { + } else if (op->op.same_as(builtin::shift_left())) { PrintBinaryIntrinsitc(op, "<<", os, this); - } else if (op->is_intrinsic(CallNode::shift_right)) { + } else if (op->op.same_as(builtin::shift_right())) { PrintBinaryIntrinsitc(op, ">>", os, this); - } else if (op->is_intrinsic(CallNode::bitwise_not)) { + } else if (op->op.same_as(builtin::bitwise_not())) { CHECK_EQ(op->args.size(), 1U); os << "(~"; PrintExpr(op->args[0], os); os << ')'; - } else if (op->is_intrinsic(intrinsic::tvm_if_then_else)) { + } else if (op->op.same_as(builtin::if_then_else())) { PrintExpr(op->args[1], os); os << " if "; PrintExpr(op->args[0], os); os << " else "; PrintExpr(op->args[2], os); + } else if (op->op.same_as(builtin::call_extern())) { + StringImm fname = Downcast(op->args[0]); + os << fname << "("; + for (size_t i = 1; i < op->args.size(); i++) { + PrintExpr(op->args[i], os); + if (i < op->args.size() - 1) { + os << ", "; + } + } + os << ")"; } else { - os << op->name << "("; + auto* ptr_op = op->op.as(); + CHECK(ptr_op != nullptr); + std::string name = ptr_op->name; + CHECK_EQ(name.compare(0, 4, "tir."), 0); + os << name.substr(4) << "("; for (size_t i = 0; i < op->args.size(); i++) { PrintExpr(op->args[i], os); if (i < op->args.size() - 1) { diff --git a/src/ir/op.cc b/src/ir/op.cc index 63d2230..45c3196 100644 --- a/src/ir/op.cc +++ b/src/ir/op.cc @@ -42,7 +42,7 @@ using OpRegistry = AttrRegistry; // find operator by name const Op& Op::Get(const String& name) { const OpRegEntry* reg = OpRegistry::Global()->Get(name); - CHECK(reg != nullptr) << "Operator " << name << " is not registered"; + CHECK(reg != nullptr) << "AttributeError: Operator " << name << " is not registered"; return reg->op(); } diff --git a/src/printer/tir_text_printer.cc b/src/printer/tir_text_printer.cc index 2992737..233a739 100644 --- a/src/printer/tir_text_printer.cc +++ b/src/printer/tir_text_printer.cc @@ -345,7 +345,14 @@ inline const char* CallType2String(CallNode::CallType t) { Doc TIRTextPrinter::VisitExpr_(const CallNode* op) { Doc doc; - doc << "@" << Doc::Text(op->name) << "("; + if (auto* ptr_op = op->op.as()) { + doc << "@" << Doc::Text(ptr_op->name) << "("; + } else { + // TODO(bohan): Print out the name by he global var in the module. + auto* op_gvar = op->op.as(); + CHECK(op_gvar != nullptr); + doc << "@" << Doc::Text(op_gvar->name_hint) << "("; + } std::vector args; for (const auto& arg : op->args) { args.push_back(Print(arg)); @@ -370,7 +377,7 @@ Doc TIRTextPrinter::VisitExpr_(const ReduceNode* op) { Doc TIRTextPrinter::VisitStmt_(const LetStmtNode* op) { Doc doc; - doc << "let " << Print(op->var) << " = " << Print(op->value) << PrintBody(op->body); + doc << "let " << Print(op->var) << " = " << Print(op->value) << Doc::NewLine() << Print(op->body); return doc; } @@ -389,8 +396,8 @@ Doc TIRTextPrinter::VisitStmt_(const AttrStmtNode* op) { Doc TIRTextPrinter::VisitStmt_(const AssertStmtNode* op) { Doc doc; - doc << "assert(" << Print(op->condition) << ", " << Print(op->message) << ")" - << PrintBody(op->body); + doc << "assert(" << Print(op->condition) << ", " << Print(op->message) << ")" << Doc::NewLine() + << Print(op->body); return doc; } diff --git a/src/relay/transforms/pass_util.h b/src/relay/transforms/pass_util.h index 35bbb23..5f58762 100644 --- a/src/relay/transforms/pass_util.h +++ b/src/relay/transforms/pass_util.h @@ -121,7 +121,7 @@ inline bool IsAtomic(const Expr& e) { * \return compiler_begin op */ inline const Op& CompilerBeginOp() { - static Op op = Op::Get("annotation.compiler_begin"); + static auto op = Op::Get("annotation.compiler_begin"); return op; } @@ -131,7 +131,7 @@ inline const Op& CompilerBeginOp() { * \return compiler_end op */ inline const Op& CompilerEndOp() { - static Op op = Op::Get("annotation.compiler_end"); + static auto op = Op::Get("annotation.compiler_end"); return op; } diff --git a/src/target/intrin_rule.h b/src/target/intrin_rule.h index 5a23e83..36e5539 100644 --- a/src/target/intrin_rule.h +++ b/src/target/intrin_rule.h @@ -25,6 +25,7 @@ #define TVM_TARGET_INTRIN_RULE_H_ #include +#include #include #include @@ -58,9 +59,20 @@ inline void DispatchExtern(const TVMArgs& args, TVMRetValue* rv) { PrimExpr e = args[0]; const CallNode* call = e.as(); CHECK(call != nullptr); - std::string name = T()(call->dtype, call->name); + // Use string based dispatch to extern for backward compact + // TODO(tvm-team) replace once the new dispatching system is inplace. + const OpNode* op = call->op.as(); + CHECK(op != nullptr); + std::string name = op->name; + CHECK_EQ(name.substr(0, 4), "tir."); + name = T()(call->dtype, name.substr(4)); + if (name.length() != 0) { - *rv = Call(call->dtype, name, call->args, CallNode::PureExtern); + Array new_args = {StringImm(name)}; + for (auto arg : call->args) { + new_args.push_back(arg); + } + *rv = Call(call->dtype, tir::builtin::call_extern(), new_args, CallNode::PureExtern); } else { *rv = e; } diff --git a/src/target/llvm/codegen_arm.cc b/src/target/llvm/codegen_arm.cc index 991d473..13ce59d 100644 --- a/src/target/llvm/codegen_arm.cc +++ b/src/target/llvm/codegen_arm.cc @@ -46,7 +46,7 @@ class CodeGenARM final : public CodeGenCPU { }; llvm::Value* CodeGenARM::CreateIntrinsic(const CallNode* op) { - if (op->is_intrinsic("llvm_intrin")) { + if (op->op.same_as(builtin_call_llvm_intrin_)) { llvm::Intrinsic::ID id = static_cast(Downcast(op->args[0])->value); if (id == ::llvm::Intrinsic::ctpop) { PrimExpr e = ARMPopcount(op); @@ -70,7 +70,7 @@ PrimExpr CodeGenARM::ARMPopcount(const CallNode* call) { vcnt_args.push_back(IntImm(DataType::UInt(32), ctpop_id)); vcnt_args.push_back(IntImm(DataType::UInt(32), 1)); vcnt_args.push_back(e); - return tir::Call(call->dtype, "llvm_intrin", vcnt_args, CallNode::PureIntrinsic); + return tir::Call(call->dtype, builtin_call_llvm_intrin_, vcnt_args, CallNode::PureIntrinsic); } // Popcount lowering rule: @@ -94,14 +94,16 @@ PrimExpr CodeGenARM::ARMPopcount(const CallNode* call) { vcnt8_args.push_back(IntImm(DataType::UInt(32), ctpop_id)); vcnt8_args.push_back(IntImm(DataType::UInt(32), 1)); vcnt8_args.push_back(input8); - PrimExpr vcnt8 = tir::Call(uint8_type, "llvm_intrin", vcnt8_args, CallNode::PureIntrinsic); + PrimExpr vcnt8 = + tir::Call(uint8_type, builtin_call_llvm_intrin_, vcnt8_args, CallNode::PureIntrinsic); // Accumulation 8->16bit Array vcnt16_args; vcnt16_args.push_back(IntImm(DataType::UInt(32), vpaddlu_id)); vcnt16_args.push_back(IntImm(DataType::UInt(32), 1)); vcnt16_args.push_back(vcnt8); - PrimExpr vcnt16 = tir::Call(uint16_type, "llvm_intrin", vcnt16_args, CallNode::PureIntrinsic); + PrimExpr vcnt16 = + tir::Call(uint16_type, builtin_call_llvm_intrin_, vcnt16_args, CallNode::PureIntrinsic); if (call->dtype.bits() == 16) { return vcnt16; } @@ -111,7 +113,8 @@ PrimExpr CodeGenARM::ARMPopcount(const CallNode* call) { vcnt32_args.push_back(IntImm(DataType::UInt(32), vpaddlu_id)); vcnt32_args.push_back(IntImm(DataType::UInt(32), 1)); vcnt32_args.push_back(vcnt16); - PrimExpr vcnt32 = tir::Call(uint32_type, "llvm_intrin", vcnt32_args, CallNode::PureIntrinsic); + PrimExpr vcnt32 = + tir::Call(uint32_type, builtin_call_llvm_intrin_, vcnt32_args, CallNode::PureIntrinsic); if (call->dtype.bits() == 32) { return vcnt32; } @@ -121,7 +124,7 @@ PrimExpr CodeGenARM::ARMPopcount(const CallNode* call) { vcnt64_args.push_back(IntImm(DataType::UInt(32), vpaddlu_id)); vcnt64_args.push_back(IntImm(DataType::UInt(32), 1)); vcnt64_args.push_back(vcnt32); - return tir::Call(call->dtype, "llvm_intrin", vcnt64_args, CallNode::PureIntrinsic); + return tir::Call(call->dtype, builtin_call_llvm_intrin_, vcnt64_args, CallNode::PureIntrinsic); } TVM_REGISTER_GLOBAL("tvm.codegen.llvm.target_arm") diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc index 6ad050a..f855dd5 100644 --- a/src/target/llvm/codegen_cpu.cc +++ b/src/target/llvm/codegen_cpu.cc @@ -226,7 +226,7 @@ std::unique_ptr CodeGenCPU::Finish() { } llvm::Value* CodeGenCPU::CreateStructRefPtr(DataType t, llvm::Value* buf, llvm::Value* index, int kind) { - if (kind < intrinsic::kArrKindBound_) { + if (kind < builtin::kArrKindBound_) { if (buf->getType() == t_void_p_) { buf = builder_->CreatePointerCast(buf, t_tvm_array_->getPointerTo()); } else { @@ -234,40 +234,40 @@ llvm::Value* CodeGenCPU::CreateStructRefPtr(DataType t, llvm::Value* buf, llvm:: } } switch (kind) { - case intrinsic::kArrAddr: { + case builtin::kArrAddr: { return builder_->CreateInBoundsGEP(buf, index); } - case intrinsic::kArrData: { + case builtin::kArrData: { return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(0)}); } - case intrinsic::kArrShape: { + case builtin::kArrShape: { return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(4)}); } - case intrinsic::kArrStrides: { + case builtin::kArrStrides: { return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(5)}); } - case intrinsic::kArrNDim: { + case builtin::kArrNDim: { return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(2)}); } - case intrinsic::kArrTypeCode: { + case builtin::kArrTypeCode: { return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(3), ConstInt32(0)}); } - case intrinsic::kArrTypeBits: { + case builtin::kArrTypeBits: { return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(3), ConstInt32(1)}); } - case intrinsic::kArrTypeLanes: { + case builtin::kArrTypeLanes: { return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(3), ConstInt32(2)}); } - case intrinsic::kArrByteOffset: { + case builtin::kArrByteOffset: { return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(6)}); } - case intrinsic::kArrDeviceId: { + case builtin::kArrDeviceId: { return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(1), ConstInt32(1)}); } - case intrinsic::kArrDeviceType: { + case builtin::kArrDeviceType: { return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(1), ConstInt32(0)}); } - case intrinsic::kTVMValueContent: { + case builtin::kTVMValueContent: { CHECK_EQ(t.lanes(), 1); CHECK(t.is_handle() || t.bits() == 64); if (t.is_int()) { @@ -289,23 +289,23 @@ llvm::Value* CodeGenCPU::CreateStructRefPtr(DataType t, llvm::Value* buf, llvm:: } } -llvm::Value* CodeGenCPU::CreateCallExtern(const CallNode* op) { - std::vector arg_values(op->args.size()); - for (size_t i = 0; i < op->args.size(); ++i) { - arg_values[i] = MakeValue(op->args[i]); +llvm::Value* CodeGenCPU::CreateCallExtern(Type ret_type, String global_symbol, + const Array& args, bool skip_first_arg) { + std::vector arg_values; + for (size_t i = static_cast(skip_first_arg); i < args.size(); ++i) { + arg_values.push_back(MakeValue(args[i])); } std::vector arg_types; for (llvm::Value* v : arg_values) { arg_types.push_back(v->getType()); } - llvm::FunctionType* ftype = - llvm::FunctionType::get(GetLLVMType(GetRef(op)), arg_types, false); + llvm::FunctionType* ftype = llvm::FunctionType::get(GetLLVMType(ret_type), arg_types, false); // Check if it is available in global function table as injected function. - auto it = gv_func_map_.find(op->name); + auto it = gv_func_map_.find(global_symbol); if (it != gv_func_map_.end()) { if (it->second == nullptr) { - gv_func_map_[op->name] = InitContextPtr(ftype->getPointerTo(), "__" + op->name); - it = gv_func_map_.find(op->name); + gv_func_map_[global_symbol] = InitContextPtr(ftype->getPointerTo(), "__" + global_symbol); + it = gv_func_map_.find(global_symbol); } #if TVM_LLVM_VERSION >= 90 auto ext_callee = llvm::FunctionCallee(ftype, GetContextPtr(it->second)); @@ -314,10 +314,10 @@ llvm::Value* CodeGenCPU::CreateCallExtern(const CallNode* op) { #endif return builder_->CreateCall(ext_callee, arg_values); } else { - llvm::Function* f = module_->getFunction(op->name); + llvm::Function* f = module_->getFunction(global_symbol); if (f == nullptr) { f = llvm::Function::Create(ftype, llvm::Function::ExternalLinkage, - op->name.operator llvm::StringRef(), module_.get()); + global_symbol.operator llvm::StringRef(), module_.get()); } #if TVM_LLVM_VERSION >= 90 auto ext_callee = llvm::FunctionCallee(f); @@ -773,38 +773,38 @@ void CodeGenCPU::AddStartupFunction() { } llvm::Value* CodeGenCPU::CreateIntrinsic(const CallNode* op) { - if (op->is_intrinsic(intrinsic::tvm_call_packed_lowered)) { + if (op->op.same_as(builtin::tvm_call_packed_lowered())) { return CreateCallPacked(op); - } else if (op->is_intrinsic(intrinsic::tvm_call_trace_packed_lowered)) { + } else if (op->op.same_as(builtin::tvm_call_trace_packed_lowered())) { return CreateCallTracePacked(op); - } else if (op->is_intrinsic(intrinsic::tvm_static_handle)) { + } else if (op->op.same_as(builtin::tvm_static_handle())) { return CreateStaticHandle(); - } else if (op->is_intrinsic(intrinsic::tvm_throw_last_error)) { + } else if (op->op.same_as(builtin::tvm_throw_last_error())) { builder_->CreateRet(ConstInt32(-1)); return ConstInt32(-1); - } else if (op->is_intrinsic(intrinsic::tvm_struct_get)) { + } else if (op->op.same_as(builtin::tvm_struct_get())) { CHECK_EQ(op->args.size(), 3U); int kind = op->args[2].as()->value; llvm::Value* ref = this->CreateStructRefPtr(op->dtype, MakeValue(op->args[0]), MakeValue(op->args[1]), kind); - if (kind == intrinsic::kArrAddr) { + if (kind == builtin::kArrAddr) { return builder_->CreatePointerCast(ref, t_void_p_); } else { return builder_->CreateLoad(ref); } - } else if (op->is_intrinsic(intrinsic::tvm_struct_set)) { + } else if (op->op.same_as(builtin::tvm_struct_set())) { CHECK_EQ(op->args.size(), 4U); int kind = op->args[2].as()->value; llvm::Value* value = MakeValue(op->args[3]); llvm::Value* ref = this->CreateStructRefPtr(op->args[3].dtype(), MakeValue(op->args[0]), MakeValue(op->args[1]), kind); - CHECK(kind != intrinsic::kArrAddr); + CHECK(kind != builtin::kArrAddr); if (value->getType()->isPointerTy()) { value = builder_->CreatePointerCast(value, ref->getType()->getPointerElementType()); } builder_->CreateStore(value, ref); return ConstInt32(0); - } else if (op->is_intrinsic(intrinsic::tvm_stack_alloca)) { + } else if (op->op.same_as(builtin::tvm_stack_alloca())) { CHECK_EQ(op->args.size(), 2U); const std::string& type = op->args[0].as()->value; return WithFunctionEntry([&]() -> llvm::AllocaInst* { diff --git a/src/target/llvm/codegen_cpu.h b/src/target/llvm/codegen_cpu.h index 7a14b8f..fdeab41 100644 --- a/src/target/llvm/codegen_cpu.h +++ b/src/target/llvm/codegen_cpu.h @@ -47,7 +47,8 @@ class CodeGenCPU : public CodeGenLLVM { void VisitStmt_(const AttrStmtNode* op) override; void VisitStmt_(const ForNode* op) override; llvm::Value* CreateIntrinsic(const CallNode* op) override; - llvm::Value* CreateCallExtern(const CallNode* op) override; + llvm::Value* CreateCallExtern(Type ret_type, String global_symbol, const Array& args, + bool skip_first_arg) override; protected: void AddStartupFunction() final; @@ -122,7 +123,7 @@ class CodeGenCPU : public CodeGenLLVM { llvm::GlobalVariable* gv_tvm_api_set_last_error_{nullptr}; llvm::GlobalVariable* gv_tvm_parallel_launch_{nullptr}; llvm::GlobalVariable* gv_tvm_parallel_barrier_{nullptr}; - std::unordered_map gv_func_map_; + std::unordered_map gv_func_map_; // context for direct dynamic lookup llvm::Function* f_tvm_func_call_{nullptr}; llvm::Function* f_tvm_get_func_from_env_{nullptr}; diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 85e3de5..49f14c3 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -653,19 +653,19 @@ llvm::Value* CodeGenLLVM::GetVarValue(const VarNode* v) const { return it->second; } -llvm::Value* CodeGenLLVM::CreateCallExtern(const CallNode* op) { +llvm::Value* CodeGenLLVM::CreateCallExtern(Type ret_type, String global_symbol, + const Array& args, bool skip_first_arg) { std::vector arg_value; std::vector arg_type; - for (size_t i = 0; i < op->args.size(); ++i) { - arg_value.push_back(MakeValue(op->args[i])); + for (size_t i = static_cast(skip_first_arg); i < args.size(); ++i) { + arg_value.push_back(MakeValue(args[i])); arg_type.push_back(arg_value.back()->getType()); } - llvm::FunctionType* ftype = - llvm::FunctionType::get(GetLLVMType(GetRef(op)), arg_type, false); - llvm::Function* f = module_->getFunction(op->name); + llvm::FunctionType* ftype = llvm::FunctionType::get(GetLLVMType(ret_type), arg_type, false); + llvm::Function* f = module_->getFunction(global_symbol); if (f == nullptr) { f = llvm::Function::Create(ftype, llvm::Function::ExternalLinkage, - op->name.operator llvm::StringRef(), module_.get()); + global_symbol.operator llvm::StringRef(), module_.get()); } llvm::CallInst* call = builder_->CreateCall(f, arg_value); return call; @@ -738,7 +738,7 @@ llvm::Function* CodeGenLLVM::GetIntrinsicDecl(llvm::Intrinsic::ID id, llvm::Type } llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { - if (op->is_intrinsic("llvm_intrin")) { + if (op->op.same_as(builtin_call_llvm_intrin_)) { CHECK_GE(op->args.size(), 2U); llvm::Intrinsic::ID id = static_cast(Downcast(op->args[0])->value); int64_t num_signature = Downcast(op->args[1])->value; @@ -759,30 +759,29 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { // type as LLVM. llvm::Type* return_type = (id != llvm::Intrinsic::prefetch) ? GetLLVMType(GetRef(op)) : llvm::Type::getVoidTy(*ctx_); - llvm::Function* f = GetIntrinsicDecl(id, return_type, arg_type); CHECK(f) << "Cannot find intrinsic declaration, possible type mismatch: " << llvm::Intrinsic::getName(id, {}); return builder_->CreateCall(f, arg_value); - } else if (op->is_intrinsic(CallNode::bitwise_and)) { + } else if (op->op.same_as(builtin::bitwise_and())) { return builder_->CreateAnd(MakeValue(op->args[0]), MakeValue(op->args[1])); - } else if (op->is_intrinsic(CallNode::bitwise_or)) { + } else if (op->op.same_as(builtin::bitwise_or())) { return builder_->CreateOr(MakeValue(op->args[0]), MakeValue(op->args[1])); - } else if (op->is_intrinsic(CallNode::bitwise_not)) { + } else if (op->op.same_as(builtin::bitwise_not())) { return builder_->CreateNot(MakeValue(op->args[0])); - } else if (op->is_intrinsic(CallNode::bitwise_xor)) { + } else if (op->op.same_as(builtin::bitwise_xor())) { return builder_->CreateXor(MakeValue(op->args[0]), MakeValue(op->args[1])); - } else if (op->is_intrinsic(CallNode::shift_left)) { + } else if (op->op.same_as(builtin::shift_left())) { return builder_->CreateShl(MakeValue(op->args[0]), MakeValue(op->args[1])); - } else if (op->is_intrinsic(CallNode::shift_right)) { + } else if (op->op.same_as(builtin::shift_right())) { if (op->args[0].dtype().is_int()) { return builder_->CreateAShr(MakeValue(op->args[0]), MakeValue(op->args[1])); } else { return builder_->CreateLShr(MakeValue(op->args[0]), MakeValue(op->args[1])); } - } else if (op->is_intrinsic(intrinsic::tvm_storage_sync)) { + } else if (op->op.same_as(builtin::tvm_storage_sync())) { return CreateStorageSync(op); - } else if (op->is_intrinsic(intrinsic::tvm_address_of)) { + } else if (op->op.same_as(builtin::address_of())) { const LoadNode* l = op->args[0].as(); CHECK(op->args.size() == 1 && l); const RampNode* r = l->index.as(); @@ -797,17 +796,17 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { addrspace = llvm::dyn_cast(ptr->getType())->getAddressSpace(); } return builder_->CreatePointerCast(ptr, t_char_->getPointerTo(addrspace)); - } else if (op->is_intrinsic(CallNode::reinterpret) && is_zero(op->args[0])) { + } else if (op->op.same_as(builtin::reinterpret()) && is_zero(op->args[0])) { return llvm::Constant::getNullValue(t_void_p_); - } else if (op->is_intrinsic(intrinsic::tvm_handle_is_null)) { + } else if (op->op.same_as(builtin::isnullptr())) { return builder_->CreateIsNull(MakeValue(op->args[0])); - } else if (op->is_intrinsic(intrinsic::tvm_large_uint_imm)) { + } else if (op->op.same_as(builtin::large_uint_imm())) { CHECK_EQ(op->args.size(), 2U); uint64_t low = static_cast(Downcast(op->args[0])->value); uint64_t high = static_cast(Downcast(op->args[1])->value); uint64_t val = (high << 32U) | low; return llvm::ConstantInt::get(DTypeToLLVMType(op->dtype), val); - } else if (op->is_intrinsic(intrinsic::tvm_if_then_else)) { + } else if (op->op.same_as(builtin::if_then_else())) { CHECK_EQ(op->args[0].dtype().lanes(), 1) << "if_then_else can only take scalar condition"; using llvm::BasicBlock; BasicBlock* then_block = BasicBlock::Create(*ctx_, "if_then", function_); @@ -827,22 +826,22 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { value->addIncoming(then_value, then_value_block); value->addIncoming(else_value, else_value_block); return value; - } else if (op->is_intrinsic(CallNode::reinterpret)) { + } else if (op->op.same_as(builtin::reinterpret())) { llvm::Type* target = DTypeToLLVMType(op->dtype); return builder_->CreateBitCast(MakeValue(op->args[0]), target); - } else if (op->is_intrinsic(CallNode::isnan)) { + } else if (op->op.same_as(builtin::isnan())) { // TODO(hgt312): set fast math flag llvm::Value* a = MakeValue(op->args[0]); return builder_->CreateFCmpUNO(a, a); - } else if (op->is_intrinsic("vectorlow")) { + } else if (op->op.same_as(builtin::vectorlow())) { llvm::Value* v = MakeValue(op->args[0]); int l = llvm::cast(v->getType())->getNumElements(); return CreateVecSlice(v, 0, l / 2); - } else if (op->is_intrinsic("vectorhigh")) { + } else if (op->op.same_as(builtin::vectorhigh())) { llvm::Value* v = MakeValue(op->args[0]); int l = llvm::cast(v->getType())->getNumElements(); return CreateVecSlice(v, l / 2, l / 2); - } else if (op->is_intrinsic("vectorcombine")) { + } else if (op->op.same_as(builtin::vectorcombine())) { llvm::Value* v0 = MakeValue(op->args[0]); llvm::Value* v1 = MakeValue(op->args[1]); int num_elems = llvm::cast(v0->getType())->getNumElements() * 2; @@ -856,7 +855,7 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { } return builder_->CreateShuffleVector(v0, v1, indices); } else { - LOG(FATAL) << "unknown intrinsic " << op->name; + LOG(FATAL) << "unknown intrinsic " << op->op; return nullptr; } } @@ -1076,13 +1075,24 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const LoadNode* op) { } llvm::Value* CodeGenLLVM::VisitExpr_(const CallNode* op) { - if (op->call_type == CallNode::Intrinsic || op->call_type == CallNode::PureIntrinsic) { - return CreateIntrinsic(op); - } else if (op->call_type == CallNode::Extern || op->call_type == CallNode::PureExtern) { - return CreateCallExtern(op); + if (auto* ptr_op = op->op.as()) { + auto call_op = GetRef(ptr_op); + if (op->op.same_as(builtin_call_extern_)) { + // call extern intrinsic + CHECK_GE(op->args.size(), 1U); + auto global_symbol = Downcast(op->args[0]); + return this->CreateCallExtern(GetType(GetRef(op)), global_symbol->value, op->args, + true); + } else if (op_attr_global_symbol_.count(call_op)) { + // call extern if the op itself have a global symbol. + return this->CreateCallExtern(GetType(GetRef(op)), op_attr_global_symbol_[call_op], + op->args, false); + } else { + return CreateIntrinsic(op); + } } else { - LOG(FATAL) << "Unknown call type " - << "name= " << op->name << " call_type= " << op->call_type; + CHECK(op->op.as()); + LOG(FATAL) << "Do not yet support cross function call"; return nullptr; } } diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h index 0bca2a1..2bfe047 100644 --- a/src/target/llvm/codegen_llvm.h +++ b/src/target/llvm/codegen_llvm.h @@ -32,6 +32,7 @@ #include #include #include +#include #include #include @@ -175,7 +176,9 @@ class CodeGenLLVM : public ExprFunctor, // create intrinstic given call virtual llvm::Value* CreateIntrinsic(const CallNode* op); // create extern function call - virtual llvm::Value* CreateCallExtern(const CallNode* op); + // skip first arg mode used for call extern intrinsic. + virtual llvm::Value* CreateCallExtern(Type ret_type, String global_symbol, + const Array& args, bool skip_first_arg); // Get the corresponding thread index virtual llvm::Value* GetThreadIndex(const IterVar& iv); // Get the corresponding thread index @@ -319,6 +322,11 @@ class CodeGenLLVM : public ExprFunctor, std::unordered_set alias_var_set_; // set of volatile buffer. std::unordered_set volatile_buf_; + // Cache potential common path ops to slightly improve lookup time. + // global symbol table. + OpAttrMap op_attr_global_symbol_ = Op::GetAttrMap("TGlobalSymbol"); + const Op& builtin_call_extern_ = builtin::call_extern(); + const Op& builtin_call_llvm_intrin_ = builtin::call_llvm_intrin(); /*! \brief Helper struct for debug infos. */ struct DebugInfo { std::unique_ptr di_builder_; diff --git a/src/target/llvm/codegen_nvptx.cc b/src/target/llvm/codegen_nvptx.cc index bc47ce1..71c8e78 100644 --- a/src/target/llvm/codegen_nvptx.cc +++ b/src/target/llvm/codegen_nvptx.cc @@ -197,11 +197,11 @@ static bool GetWarpShuffleIntrinsic(const CallNode* op, llvm::Intrinsic::ID* id) llvm::Intrinsic::nvvm_shfl_down_i32, llvm::Intrinsic::nvvm_shfl_down_f32}; int offset = 0; - if (op->is_intrinsic(intrinsic::tvm_warp_shuffle)) { + if (op->op.same_as(builtin::tvm_warp_shuffle())) { offset = 0; - } else if (op->is_intrinsic(intrinsic::tvm_warp_shuffle_up)) { + } else if (op->op.same_as(builtin::tvm_warp_shuffle_up())) { offset = 2; - } else if (op->is_intrinsic(intrinsic::tvm_warp_shuffle_down)) { + } else if (op->op.same_as(builtin::tvm_warp_shuffle_down())) { offset = 4; } else { return false; @@ -226,7 +226,7 @@ llvm::Value* CodeGenNVPTX::CreateIntrinsic(const CallNode* op) { llvm::Type* return_type = arg_type[0]; llvm::Function* func = GetIntrinsicDecl(id, return_type, arg_type); return builder_->CreateCall(func, arg_value); - } else if (op->is_intrinsic(intrinsic::tvm_warp_activemask)) { + } else if (op->op.same_as(builtin::tvm_warp_activemask())) { // Only nvptx target may keep this intrinsic at this point. // PTX assembly: asm "activemask.b32 r1;" auto fty = llvm::FunctionType::get(t_int32_, false); diff --git a/src/target/llvm/codegen_x86_64.cc b/src/target/llvm/codegen_x86_64.cc index edffda2..5d269fa 100644 --- a/src/target/llvm/codegen_x86_64.cc +++ b/src/target/llvm/codegen_x86_64.cc @@ -89,7 +89,7 @@ llvm::Value* CodeGenX86_64::VisitExpr_(const CastNode* op) { ::llvm::Intrinsic::x86_avx512_mask_vcvtph2ps_512, 16, DTypeToLLVMType(DataType::Float(32, from.lanes())), { - MakeValue(tir::Call(DataType::Int(16, from.lanes()), tir::CallNode::reinterpret, + MakeValue(tir::Call(DataType::Int(16, from.lanes()), tir::builtin::reinterpret(), {op->value}, tir::CallNode::PureIntrinsic)), MakeValue(tir::Broadcast(FloatImm(DataType::Float(32), 0), from.lanes())), /*mask=*/MakeValue(IntImm(DataType::Int(16), -1)), @@ -105,7 +105,7 @@ llvm::Value* CodeGenX86_64::VisitExpr_(const CastNode* op) { return CallVectorIntrin( ::llvm::Intrinsic::x86_vcvtph2ps_256, 8, DTypeToLLVMType(DataType::Float(32, from.lanes())), - {MakeValue(tir::Call(DataType::Int(16, from.lanes()), tir::CallNode::reinterpret, + {MakeValue(tir::Call(DataType::Int(16, from.lanes()), tir::builtin::reinterpret(), {op->value}, tir::CallNode::PureIntrinsic))}); } #endif diff --git a/src/target/llvm/intrin_rule_llvm.cc b/src/target/llvm/intrin_rule_llvm.cc index 8804b1e..abf350e 100644 --- a/src/target/llvm/intrin_rule_llvm.cc +++ b/src/target/llvm/intrin_rule_llvm.cc @@ -39,6 +39,8 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.exp") TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.exp2") .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::exp2, 1>); +// TODO(tvm-team): migrate the legalization transformations as a separate +// set of rules in TIR that can be shared across backends. TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.exp10") .set_body([](const TVMArgs& targs, TVMRetValue* rv) { using tir::make_const; @@ -48,7 +50,7 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.exp10") CHECK(call != nullptr); const PrimExpr& x = call->args[0]; PrimExpr ln10 = make_const(x.dtype(), 2.302585093); - PrimExpr ret = tir::Call(x.dtype(), "exp", {x * ln10}, tir::CallNode::PureIntrinsic); + PrimExpr ret = exp(x * ln10); *rv = ret; }); @@ -97,8 +99,8 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.tanh") PrimExpr two = make_const(x.dtype(), 2); PrimExpr neg_two = make_const(x.dtype(), -2); - PrimExpr exp_neg2x = tir::Call(x.dtype(), "exp", {neg_two * x}, tir::CallNode::PureIntrinsic); - PrimExpr exp_pos2x = tir::Call(x.dtype(), "exp", {two * x}, tir::CallNode::PureIntrinsic); + PrimExpr exp_neg2x = exp(neg_two * x); + PrimExpr exp_pos2x = exp(two * x); PrimExpr tanh_pos = (one - exp_neg2x) / (one + exp_neg2x); PrimExpr tanh_neg = (exp_pos2x - one) / (exp_pos2x + one); @@ -116,9 +118,7 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.tan").set_body([](const TVMArgs& targs const tir::CallNode* call = e.as(); CHECK(call != nullptr); const PrimExpr& x = call->args[0]; - PrimExpr sin_x = tir::Call(x.dtype(), "sin", {x}, tir::CallNode::PureIntrinsic); - PrimExpr cos_x = tir::Call(x.dtype(), "cos", {x}, tir::CallNode::PureIntrinsic); - PrimExpr tan_x = sin_x / cos_x; + PrimExpr tan_x = sin(x) / cos(x); *rv = tan_x; }); @@ -135,8 +135,8 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.cosh") const PrimExpr& x = call->args[0]; PrimExpr two = make_const(x.dtype(), 2); PrimExpr neg_one = make_const(x.dtype(), -1); - PrimExpr exp_negx = tir::Call(x.dtype(), "exp", {neg_one * x}, tir::CallNode::PureIntrinsic); - PrimExpr exp_posx = tir::Call(x.dtype(), "exp", {x}, tir::CallNode::PureIntrinsic); + PrimExpr exp_negx = exp(neg_one * x); + PrimExpr exp_posx = exp(x); PrimExpr ret = (exp_posx + exp_negx) / two; *rv = ret; }); @@ -154,8 +154,8 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.sinh") const PrimExpr& x = call->args[0]; PrimExpr two = make_const(x.dtype(), 2); PrimExpr neg_one = make_const(x.dtype(), -1); - PrimExpr exp_negx = tir::Call(x.dtype(), "exp", {neg_one * x}, tir::CallNode::PureIntrinsic); - PrimExpr exp_posx = tir::Call(x.dtype(), "exp", {x}, tir::CallNode::PureIntrinsic); + PrimExpr exp_negx = exp(neg_one * x); + PrimExpr exp_posx = exp(x); PrimExpr ret = (exp_posx - exp_negx) / two; *rv = ret; }); diff --git a/src/target/llvm/intrin_rule_llvm.h b/src/target/llvm/intrin_rule_llvm.h index 5613621..cc9437d 100644 --- a/src/target/llvm/intrin_rule_llvm.h +++ b/src/target/llvm/intrin_rule_llvm.h @@ -27,6 +27,7 @@ #include #include +#include #include #include @@ -49,7 +50,8 @@ inline void DispatchLLVMPureIntrin(const TVMArgs& targs, TVMRetValue* rv) { for (PrimExpr arg : call->args) { cargs.push_back(arg); } - *rv = tir::Call(call->dtype, "llvm_intrin", cargs, tir::CallNode::PureIntrinsic); + *rv = + tir::Call(call->dtype, tir::builtin::call_llvm_intrin(), cargs, tir::CallNode::PureIntrinsic); } template @@ -64,7 +66,7 @@ inline void DispatchLLVMIntrin(const TVMArgs& targs, TVMRetValue* rv) { for (PrimExpr arg : call->args) { cargs.push_back(arg); } - *rv = tir::Call(call->dtype, "llvm_intrin", cargs, tir::CallNode::Intrinsic); + *rv = tir::Call(call->dtype, tir::builtin::call_llvm_intrin(), cargs, tir::CallNode::Intrinsic); } } // namespace codegen diff --git a/src/target/llvm/intrin_rule_nvptx.cc b/src/target/llvm/intrin_rule_nvptx.cc index 49c2224..a0ffe11 100644 --- a/src/target/llvm/intrin_rule_nvptx.cc +++ b/src/target/llvm/intrin_rule_nvptx.cc @@ -23,7 +23,9 @@ #ifdef TVM_LLVM_VERSION #include +#include #include +#include #include @@ -36,10 +38,21 @@ inline void DispatchExternLibDevice(const TVMArgs& args, TVMRetValue* rv) { const CallNode* call = e.as(); CHECK(call != nullptr); CHECK(call->dtype.bits() == 32 || call->dtype.bits() == 64) << "Only support float32 or float64."; + + const OpNode* op = call->op.as(); + CHECK(op != nullptr); + std::string name = op->name; + CHECK_EQ(name.substr(0, 4), "tir."); + std::ostringstream intrinsic_name; - intrinsic_name << "__nv_" << call->name; + intrinsic_name << "__nv_" << name.substr(4); if (call->dtype.bits() == 32) intrinsic_name << "f"; - *rv = Call(call->dtype, intrinsic_name.str(), call->args, CallNode::PureExtern); + + Array new_args = {StringImm(intrinsic_name.str())}; + for (auto arg : call->args) { + new_args.push_back(arg); + } + *rv = Call(call->dtype, builtin::call_extern(), new_args, CallNode::PureExtern); } namespace llvm { diff --git a/src/target/llvm/intrin_rule_rocm.cc b/src/target/llvm/intrin_rule_rocm.cc index 3a2b8ac..07520ae 100644 --- a/src/target/llvm/intrin_rule_rocm.cc +++ b/src/target/llvm/intrin_rule_rocm.cc @@ -23,6 +23,7 @@ #ifdef TVM_LLVM_VERSION #include +#include #include #include @@ -36,9 +37,21 @@ inline void DispatchExternOCML(const TVMArgs& args, TVMRetValue* rv) { using namespace tir; const CallNode* call = e.as(); CHECK(call != nullptr); + + const OpNode* op = call->op.as(); + CHECK(op != nullptr); + std::string name = op->name; + CHECK_EQ(name.substr(0, 4), "tir."); + std::ostringstream intrinsic_name; - intrinsic_name << "__ocml_" << call->name << "_f" << call->dtype.bits(); - *rv = Call(call->dtype, intrinsic_name.str(), call->args, CallNode::PureExtern); + intrinsic_name << "__ocml_" << name.substr(4) << "_f" << call->dtype.bits(); + + Array new_args = {StringImm(intrinsic_name.str())}; + for (auto arg : call->args) { + new_args.push_back(arg); + } + + *rv = Call(call->dtype, builtin::call_extern(), new_args, CallNode::PureExtern); } inline void DispatchShuffle(const TVMArgs& targs, TVMRetValue* rv) { @@ -53,29 +66,30 @@ inline void DispatchShuffle(const TVMArgs& targs, TVMRetValue* rv) { // get own lane in self (__lane_id) PrimExpr minus_one = tir::make_const(DataType::Int(32), -1); PrimExpr zero = tir::make_zero(DataType::Int(32)); - PrimExpr lo = - Call(DataType::Int(32), "llvm.amdgcn.mbcnt.lo", {minus_one, zero}, CallNode::PureExtern); - PrimExpr self = - Call(DataType::Int(32), "llvm.amdgcn.mbcnt.hi", {minus_one, lo}, CallNode::PureExtern); + PrimExpr lo = Call(DataType::Int(32), builtin::call_extern(), + {StringImm("llvm.amdgcn.mbcnt.lo"), minus_one, zero}, CallNode::PureExtern); + PrimExpr self = Call(DataType::Int(32), builtin::call_extern(), + {StringImm("llvm.amdgcn.mbcnt.hi"), minus_one, lo}, CallNode::PureExtern); // compute lane to get from PrimExpr width = call->args[3]; PrimExpr index; - if (call->name == "tvm_warp_shuffle") { + if (call->op.same_as(builtin::tvm_warp_shuffle())) { PrimExpr src_lane = call->args[2]; index = src_lane + (self & ~(width - 1)); - } else if (call->name == "tvm_warp_shuffle_up") { + } else if (call->op.same_as(builtin::tvm_warp_shuffle_up())) { PrimExpr delta = call->args[2]; index = self - delta; index = Select(index < (self & ~(width - 1)), self, index); } else { - CHECK_EQ(call->name, "tvm_warp_shuffle_down"); + CHECK(call->op.same_as(builtin::tvm_warp_shuffle_down())); PrimExpr delta = call->args[2]; index = self + delta; index = Select((self & (width - 1)) + delta >= width, self, index); } PrimExpr res = - Call(var.dtype(), "llvm.amdgcn.ds.bpermute", {index << 2, var}, CallNode::PureExtern); + Call(var.dtype(), builtin::call_extern(), + {StringImm("llvm.amdgcn.ds.bpermute"), index << 2, var}, CallNode::PureExtern); *rv = res; } diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc index 9255d7c..ffeaba0 100644 --- a/src/target/source/codegen_c.cc +++ b/src/target/source/codegen_c.cc @@ -223,12 +223,12 @@ std::string CodeGenC::GetBufferRef(DataType t, const VarNode* buffer, PrimExpr i // Print a reference expression to a buffer. std::string CodeGenC::GetStructRef(DataType t, const PrimExpr& buffer, const PrimExpr& index, int kind) { - if (kind < intrinsic::kArrKindBound_) { + if (kind < builtin::kArrKindBound_) { std::ostringstream os; os << "(((DLTensor*)"; this->PrintExpr(buffer, os); os << ")"; - if (kind == intrinsic::kArrAddr) { + if (kind == builtin::kArrAddr) { os << " + "; this->PrintExpr(index, os); os << ")"; @@ -239,34 +239,34 @@ std::string CodeGenC::GetStructRef(DataType t, const PrimExpr& buffer, const Pri os << "]."; // other case: get fields. switch (kind) { - case intrinsic::kArrData: + case builtin::kArrData: os << "data"; break; - case intrinsic::kArrShape: + case builtin::kArrShape: os << "shape"; break; - case intrinsic::kArrStrides: + case builtin::kArrStrides: os << "strides"; break; - case intrinsic::kArrNDim: + case builtin::kArrNDim: os << "ndim"; break; - case intrinsic::kArrTypeCode: + case builtin::kArrTypeCode: os << "dtype.code"; break; - case intrinsic::kArrTypeBits: + case builtin::kArrTypeBits: os << "dtype.bits"; break; - case intrinsic::kArrByteOffset: + case builtin::kArrByteOffset: os << "byte_offset"; break; - case intrinsic::kArrTypeLanes: + case builtin::kArrTypeLanes: os << "dtype.lanes"; break; - case intrinsic::kArrDeviceId: + case builtin::kArrDeviceId: os << "ctx.device_id"; break; - case intrinsic::kArrDeviceType: + case builtin::kArrDeviceType: os << "ctx.device_type"; break; default: @@ -275,7 +275,7 @@ std::string CodeGenC::GetStructRef(DataType t, const PrimExpr& buffer, const Pri os << ')'; return os.str(); } else { - CHECK_LT(kind, intrinsic::kTVMValueKindBound_); + CHECK_LT(kind, builtin::kTVMValueKindBound_); std::ostringstream os; os << "(((TVMValue*)"; this->PrintExpr(buffer, os); @@ -559,80 +559,94 @@ void CodeGenC::VisitExpr_(const NotNode* op, std::ostream& os) { // NOLINT(*) PrintExpr(op->a, os); } -void CodeGenC::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*) - if (op->call_type == CallNode::Extern || op->call_type == CallNode::PureExtern) { - os << op->name << "("; - for (size_t i = 0; i < op->args.size(); i++) { - this->PrintExpr(op->args[i], os); - if (i < op->args.size() - 1) { - os << ", "; - } +void CodeGenC::PrintCallExtern(Type ret_type, String global_symbol, const Array& args, + bool skip_first_arg, std::ostream& os) { // NOLINT(*) + os << global_symbol << "("; + for (size_t i = static_cast(skip_first_arg); i < args.size(); ++i) { + this->PrintExpr(args[i], os); + if (i < args.size() - 1) { + os << ", "; } - os << ")"; - } else if (op->is_intrinsic(CallNode::bitwise_and)) { - PrintBinaryIntrinsic(op, " & ", os, this); - } else if (op->is_intrinsic(intrinsic::tvm_large_uint_imm)) { - CHECK_EQ(op->args.size(), 2U); - uint64_t low = static_cast(Downcast(op->args[0])->value); - uint64_t high = static_cast(Downcast(op->args[1])->value); - uint64_t val = (high << 32U) | low; - PrintUIntConst(op->dtype, val, os, this); - } else if (op->is_intrinsic(CallNode::bitwise_xor)) { - PrintBinaryIntrinsic(op, " ^ ", os, this); - } else if (op->is_intrinsic(CallNode::bitwise_or)) { - PrintBinaryIntrinsic(op, " | ", os, this); - } else if (op->is_intrinsic(CallNode::bitwise_not)) { - CHECK_EQ(op->args.size(), 1U); - os << "(~"; - this->PrintExpr(op->args[0], os); - os << ')'; - } else if (op->is_intrinsic(CallNode::shift_left)) { - PrintBinaryIntrinsic(op, " << ", os, this); - } else if (op->is_intrinsic(CallNode::shift_right)) { - PrintBinaryIntrinsic(op, " >> ", os, this); - } else if (op->is_intrinsic(intrinsic::tvm_if_then_else)) { - os << "("; - PrintExpr(op->args[0], os); - os << " ? "; - PrintExpr(op->args[1], os); - os << " : "; - PrintExpr(op->args[2], os); - os << ")"; - } else if (op->is_intrinsic(intrinsic::tvm_address_of)) { - const LoadNode* l = op->args[0].as(); - CHECK(op->args.size() == 1 && l); - os << "(("; - this->PrintType(l->dtype.element_of(), os); - os << " *)" << this->GetVarID(l->buffer_var.get()) << " + "; - this->PrintExpr(l->index, os); - os << ')'; - } else if (op->is_intrinsic(intrinsic::tvm_struct_get)) { - CHECK_EQ(op->args.size(), 3U); - os << GetStructRef(op->dtype, op->args[0], op->args[1], op->args[2].as()->value); - } else if (op->is_intrinsic(intrinsic::tvm_handle_is_null)) { - CHECK_EQ(op->args.size(), 1U); - os << "("; - this->PrintExpr(op->args[0], os); - os << " == NULL)"; - } else if (op->is_intrinsic(CallNode::reinterpret)) { - // generate (*( TYPE *)(&(ARG))) - os << "(*("; - this->PrintType(op->dtype, os); - os << " *)(&("; - this->PrintExpr(op->args[0], os); - os << ")))"; - } else if (op->is_intrinsic(CallNode::isnan)) { - os << "("; - this->PrintExpr(op->args[0], os); - os << " != "; - this->PrintExpr(op->args[0], os); - os << ")"; - } else { - if (op->call_type == CallNode::Intrinsic || op->call_type == CallNode::PureIntrinsic) { - LOG(FATAL) << "Unresolved intrinsic " << op->name << " with return type " << op->dtype; + } + os << ")"; +} + +void CodeGenC::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*) + if (auto* ptr_op = op->op.as()) { + auto call_op = GetRef(ptr_op); + + if (op->op.same_as(builtin_call_extern_)) { + CHECK_GE(op->args.size(), 1U); + auto func = Downcast(op->args[0]); + this->PrintCallExtern(GetType(GetRef(op)), func->value, op->args, true, os); + } else if (op_attr_global_symbol_.count(call_op)) { + // call extern if the op itself have a global symbol. + this->PrintCallExtern(GetType(GetRef(op)), op_attr_global_symbol_[call_op], + op->args, false, os); + } else if (op->op.same_as(builtin::bitwise_and())) { + PrintBinaryIntrinsic(op, " & ", os, this); + } else if (op->op.same_as(builtin::large_uint_imm())) { + CHECK_EQ(op->args.size(), 2U); + uint64_t low = static_cast(Downcast(op->args[0])->value); + uint64_t high = static_cast(Downcast(op->args[1])->value); + uint64_t val = (high << 32U) | low; + PrintUIntConst(op->dtype, val, os, this); + } else if (op->op.same_as(builtin::bitwise_xor())) { + PrintBinaryIntrinsic(op, " ^ ", os, this); + } else if (op->op.same_as(builtin::bitwise_or())) { + PrintBinaryIntrinsic(op, " | ", os, this); + } else if (op->op.same_as(builtin::bitwise_not())) { + CHECK_EQ(op->args.size(), 1U); + os << "(~"; + this->PrintExpr(op->args[0], os); + os << ')'; + } else if (op->op.same_as(builtin::shift_left())) { + PrintBinaryIntrinsic(op, " << ", os, this); + } else if (op->op.same_as(builtin::shift_right())) { + PrintBinaryIntrinsic(op, " >> ", os, this); + } else if (op->op.same_as(builtin::if_then_else())) { + os << "("; + PrintExpr(op->args[0], os); + os << " ? "; + PrintExpr(op->args[1], os); + os << " : "; + PrintExpr(op->args[2], os); + os << ")"; + } else if (op->op.same_as(builtin::address_of())) { + const LoadNode* l = op->args[0].as(); + CHECK(op->args.size() == 1 && l); + os << "(("; + this->PrintType(l->dtype.element_of(), os); + os << " *)" << this->GetVarID(l->buffer_var.get()) << " + "; + this->PrintExpr(l->index, os); + os << ')'; + } else if (op->op.same_as(builtin::tvm_struct_get())) { + CHECK_EQ(op->args.size(), 3U); + os << GetStructRef(op->dtype, op->args[0], op->args[1], op->args[2].as()->value); + } else if (op->op.same_as(builtin::isnullptr())) { + CHECK_EQ(op->args.size(), 1U); + os << "("; + this->PrintExpr(op->args[0], os); + os << " == NULL)"; + } else if (op->op.same_as(builtin::reinterpret())) { + // generate (*( TYPE *)(&(ARG))) + os << "(*("; + this->PrintType(op->dtype, os); + os << " *)(&("; + this->PrintExpr(op->args[0], os); + os << ")))"; + } else if (op->op.same_as(builtin::isnan())) { + os << "("; + this->PrintExpr(op->args[0], os); + os << " != "; + this->PrintExpr(op->args[0], os); + os << ")"; } else { - LOG(FATAL) << "Unresolved call type " << op->call_type; + LOG(FATAL) << "Unresolved call " << op->op; } + } else { + CHECK(op->op.as()); + LOG(FATAL) << "Do not yet support cross function call"; } } @@ -903,10 +917,10 @@ void CodeGenC::VisitStmt_(const EvaluateNode* op) { if (is_const(op->value)) return; const CallNode* call = op->value.as(); if (call) { - if (call->is_intrinsic(intrinsic::tvm_storage_sync)) { + if (call->op.same_as(builtin::tvm_storage_sync())) { this->PrintStorageSync(call); return; - } else if (call->is_intrinsic(intrinsic::tvm_struct_set)) { + } else if (call->op.same_as(builtin::tvm_struct_set())) { CHECK_EQ(call->args.size(), 4); std::string value = PrintExpr(call->args[3]); std::string ref = GetStructRef(call->args[3].dtype(), call->args[0], call->args[1], diff --git a/src/target/source/codegen_c.h b/src/target/source/codegen_c.h index 309eb06..9346f87 100644 --- a/src/target/source/codegen_c.h +++ b/src/target/source/codegen_c.h @@ -24,10 +24,13 @@ #ifndef TVM_TARGET_SOURCE_CODEGEN_C_H_ #define TVM_TARGET_SOURCE_CODEGEN_C_H_ +#include #include #include +#include #include #include +#include #include #include @@ -220,6 +223,16 @@ class CodeGenC : public ExprFunctor, virtual bool IsScopePartOfType() const { return true; } /*! + * \brief Print external function call. + * \param ret_type The return type. + * \param global_symbol The symbolc of the target function. + * \param args The arguments to the function. + * \param skip_first_arg Whether to skip the first arguments. + * \param os The output stream. + */ + virtual void PrintCallExtern(Type ret_type, String global_symbol, const Array& args, + bool skip_first_arg, std::ostream& os); // NOLINT(*) + /*! * \brief If buffer is allocated as type t. * \param buf_var The buffer variable. * \param t The type to be checked. @@ -245,6 +258,10 @@ class CodeGenC : public ExprFunctor, std::unordered_map alloc_storage_scope_; /*! \brief the data type of allocated buffers */ std::unordered_map handle_data_type_; + /*! \brief Record of ops that have pre-defined global symbol. */ + OpAttrMap op_attr_global_symbol_ = Op::GetAttrMap("TGlobalSymbol"); + // cache commonly used ops + const Op& builtin_call_extern_ = builtin::call_extern(); private: /*! \brief whether to print in SSA form */ diff --git a/src/target/source/codegen_c_host.cc b/src/target/source/codegen_c_host.cc index b11b3d8..839962a 100644 --- a/src/target/source/codegen_c_host.cc +++ b/src/target/source/codegen_c_host.cc @@ -175,7 +175,7 @@ void CodeGenCHost::PrintFuncCall(const std::string& packed_func_name, int num_ar } void CodeGenCHost::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*) - if (op->is_intrinsic(intrinsic::tvm_stack_alloca)) { + if (op->op.same_as(builtin::tvm_stack_alloca())) { std::string stack_name = GetUniqueName("stack"); const std::string& type = op->args[0].as()->value; const IntImmNode* num = op->args[1].as(); @@ -197,7 +197,7 @@ void CodeGenCHost::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT this->PrintIndent(); this->stream << "TVMValue " << stack_name << "[" << size << "];\n"; os << stack_name; - } else if (op->is_intrinsic(intrinsic::tvm_call_packed_lowered)) { + } else if (op->op.same_as(builtin::tvm_call_packed_lowered())) { const StringImmNode* s = op->args[0].as(); CHECK(s != nullptr) << "tvm_call_packed_lowered expects first argument as function name"; int64_t begin = op->args[3].as()->value; @@ -216,7 +216,7 @@ void CodeGenCHost::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT } this->PrintGetFuncFromBackend(func_name, packed_func_name); this->PrintFuncCall(packed_func_name, num_args); - } else if (op->is_intrinsic(intrinsic::tvm_throw_last_error)) { + } else if (op->op.same_as(builtin::tvm_throw_last_error())) { this->PrintIndent(); this->stream << "return -1;\n"; } else { diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index cf7a74f..ae5e40a 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -429,15 +429,71 @@ void CodeGenCUDA::VisitExpr_(const CastNode* op, std::ostream& os) { os << sret; } +void CodeGenCUDA::PrintCallExtern(Type ret_type, String global_symbol, const Array& args, + bool skip_first_arg, std::ostream& os) { // NOLINT(*) + DataType ret_dtype = GetRuntimeDataType(ret_type); + if (ret_dtype.is_vector()) { + // + // Emit an unsupported vector call + // + // v = intrin_f((float4*)A[0], (float4*)B[0]) + // + // as + // + // float4 __ret; + // { + // float4 __arg0 = ((float4*)A)[0]; + // float4 __arg1 = ((float4*)B)[0]; + // __ret.x = intrin_f(__arg0.x, __arg1.x); + // __ret.y = intrin_f(__arg0.y, __arg1.y); + // __ret.z = intrin_f(__arg0.z, __arg1.z); + // __ret.w = intrin_f(__arg0.w, __arg1.w); + // } + // v = __ret; + // + // Declare the result vector. + std::string sret = GetUniqueName("_"); + this->PrintIndent(); + this->PrintType(ret_dtype, stream); + stream << ' ' << sret << ";\n"; + { + // Load arguments. + std::vector sargs; + size_t arg_begin = static_cast(skip_first_arg); + for (size_t i = arg_begin; i < args.size(); ++i) { + std::string val = SSAGetID(PrintExpr(args[i]), args[i].dtype()); + sargs.push_back(std::move(val)); + } + + // Emit a scalar call for each lane. + for (int i = 0; i < ret_dtype.lanes(); ++i) { + std::ostringstream scall; + scall << global_symbol << "("; + for (size_t j = 0; j < sargs.size(); ++j) { + if (j > 0) scall << ", "; + PrintVecElemLoad(sargs[j], args[arg_begin + j].dtype(), i, scall); + } + scall << ")"; + PrintVecElemStore(sret, ret_dtype, i, scall.str()); + } + } + os << sret; + } else { + CodeGenC::PrintCallExtern(ret_type, global_symbol, args, skip_first_arg, os); + } +} + void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { - // This is only for backward compatibility with __shfl_{up/down}. - // A macro will be used to replace *_sync calls to legacy ones. - if (op->is_intrinsic("__shfl_sync") || op->is_intrinsic("__shfl_up_sync") || - op->is_intrinsic("__shfl_down_sync")) { - enable_warp_shuffle_ = true; + if (auto* ptr_op = op->op.as()) { + Op call_op = GetRef(ptr_op); + // This is only for backward compatibility with __shfl_{up/down}. + // A macro will be used to replace *_sync calls to legacy ones. + if (op_need_warp_shuffle_.get(call_op, false)) { + enable_warp_shuffle_ = true; + } } - if (op->is_intrinsic(intrinsic::tvm_fill_fragment)) { + if (op->op.same_as(builtin::tvm_fill_fragment())) { need_mma_h_ = true; CHECK_EQ(op->args.size(), 6U); os << "nvcuda::wmma::fill_fragment("; @@ -447,7 +503,7 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { os << "], "; this->PrintExpr(op->args[5], os); os << ")"; - } else if (op->is_intrinsic(intrinsic::tvm_load_matrix_sync)) { + } else if (op->op.same_as(builtin::tvm_load_matrix_sync())) { need_mma_h_ = true; CHECK_EQ(op->args.size(), 8U); os << "nvcuda::wmma::load_matrix_sync("; @@ -459,7 +515,7 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { os << ", "; this->PrintExpr(op->args[6], os); os << ")"; - } else if (op->is_intrinsic(intrinsic::tvm_store_matrix_sync)) { + } else if (op->op.same_as(builtin::tvm_store_matrix_sync())) { need_mma_h_ = true; CHECK_EQ(op->args.size(), 8U); os << "nvcuda::wmma::store_matrix_sync("; @@ -476,7 +532,7 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { LOG(FATAL) << "Invalid parameters"; } os << ")"; - } else if (op->is_intrinsic(intrinsic::tvm_mma_sync)) { + } else if (op->op.same_as(builtin::tvm_mma_sync())) { need_mma_h_ = true; CHECK_EQ(op->args.size(), 8U); os << "nvcuda::wmma::mma_sync("; @@ -486,7 +542,7 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { this->PrintExpr(op->args[i * 2 + 1], os); os << "]" << ((i < 3) ? ", " : ")"); } - } else if (op->is_intrinsic(intrinsic::tvm_bmma_sync)) { + } else if (op->op.same_as(builtin::tvm_bmma_sync())) { need_mma_h_ = true; CHECK_EQ(op->args.size(), 8U); os << "nvcuda::wmma::bmma_sync("; @@ -496,51 +552,6 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { this->PrintExpr(op->args[i * 2 + 1], os); os << "]" << ((i < 3) ? ", " : ")"); } - } else if (op->call_type == CallNode::PureExtern && op->dtype.is_vector()) { - // - // Emit an unsupported vector call - // - // v = intrin_f((float4*)A[0], (float4*)B[0]) - // - // as - // - // float4 __ret; - // { - // float4 __arg0 = ((float4*)A)[0]; - // float4 __arg1 = ((float4*)B)[0]; - // __ret.x = intrin_f(__arg0.x, __arg1.x); - // __ret.y = intrin_f(__arg0.y, __arg1.y); - // __ret.z = intrin_f(__arg0.z, __arg1.z); - // __ret.w = intrin_f(__arg0.w, __arg1.w); - // } - // v = __ret; - // - // Declare the result vector. - std::string sret = GetUniqueName("_"); - this->PrintIndent(); - this->PrintType(op->dtype, stream); - stream << ' ' << sret << ";\n"; - { - // Load arguments. - std::vector sargs; - for (size_t i = 0; i < op->args.size(); ++i) { - std::string val = SSAGetID(PrintExpr(op->args[i]), op->args[i].dtype()); - sargs.push_back(std::move(val)); - } - - // Emit a scalar call for each lane. - for (int i = 0; i < op->dtype.lanes(); ++i) { - std::ostringstream scall; - scall << op->name << "("; - for (size_t j = 0; j < op->args.size(); ++j) { - if (j > 0) scall << ", "; - PrintVecElemLoad(sargs[j], op->args[j].dtype(), i, scall); - } - scall << ")"; - PrintVecElemStore(sret, op->dtype, i, scall.str()); - } - } - os << sret; } else { CodeGenC::VisitExpr_(op, os); } @@ -600,7 +611,7 @@ void CodeGenCUDA::VisitStmt_(const AllocateNode* op) { void CodeGenCUDA::VisitStmt_(const EvaluateNode* op) { if (is_const(op->value)) return; const CallNode* call = op->value.as(); - if (call && call->is_intrinsic(intrinsic::tvm_global_barrier_kinit)) { + if (call && call->op.same_as(builtin::tvm_global_barrier_kinit())) { PrintIndent(); stream << "__shared__ unsigned " << vid_global_barrier_expect_ << ";\n"; PrintIndent(); diff --git a/src/target/source/codegen_cuda.h b/src/target/source/codegen_cuda.h index f9ab0ad..3cde8e3 100644 --- a/src/target/source/codegen_cuda.h +++ b/src/target/source/codegen_cuda.h @@ -26,6 +26,7 @@ #include #include +#include #include #include @@ -68,6 +69,10 @@ class CodeGenCUDA final : public CodeGenC { void VisitStmt_(const AllocateNode* op) final; void VisitStmt_(const AttrStmtNode* op) final; + protected: + void PrintCallExtern(Type ret_type, String global_symbol, const Array& args, + bool skip_first_arg, std::ostream& os) final; // NOLINT(*) + private: // Handle volatile loads void HandleVolatileLoads(const std::string& value, const LoadNode* op, std::ostream& os) final; @@ -91,6 +96,8 @@ class CodeGenCUDA final : public CodeGenC { bool need_math_constants_h_{false}; // whether need mma.h bool need_mma_h_{false}; + // Op attribute map + OpAttrMap op_need_warp_shuffle_ = Op::GetAttrMap("cuda.need_warp_shuffle"); std::unordered_map fragment_shapes; std::unordered_map fragment_layouts; diff --git a/src/target/source/codegen_metal.cc b/src/target/source/codegen_metal.cc index 2c26ee9..1c4256c 100644 --- a/src/target/source/codegen_metal.cc +++ b/src/target/source/codegen_metal.cc @@ -270,7 +270,7 @@ void CodeGenMetal::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // N } void CodeGenMetal::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*) - if (op->is_intrinsic(CallNode::reinterpret)) { + if (op->op.same_as(builtin::reinterpret())) { // generate as_type(ARG) os << "(as_type<"; this->PrintType(op->dtype, os); diff --git a/src/target/source/intrin_rule_cuda.cc b/src/target/source/intrin_rule_cuda.cc index 45746b8..53a2799 100644 --- a/src/target/source/intrin_rule_cuda.cc +++ b/src/target/source/intrin_rule_cuda.cc @@ -21,6 +21,9 @@ * \file intrin_rule_cuda.cc * \brief CUDA intrinsic rules. */ +#include +#include + #include "../intrin_rule.h" namespace tvm { @@ -93,23 +96,23 @@ struct CUDAPopcount { }; struct CUDAWarpIntrinsic { - const char* operator()(DataType t, const std::string& name) const { - if (name == intrinsic::tvm_warp_shuffle) { - return "__shfl_sync"; - } - if (name == intrinsic::tvm_warp_shuffle_up) { - return "__shfl_up_sync"; - } - if (name == intrinsic::tvm_warp_shuffle_down) { - return "__shfl_down_sync"; - } - if (name == intrinsic::tvm_warp_activemask) { - return "__activemask"; + const Op operator()(DataType t, const Op& orig_op) const { + if (orig_op.same_as(builtin::tvm_warp_shuffle())) { + return Op::Get("tir.cuda.__shfl_sync"); + } else if (orig_op.same_as(builtin::tvm_warp_shuffle_up())) { + return Op::Get("tir.cuda.__shfl_up_sync"); + } else { + CHECK(orig_op.same_as(builtin::tvm_warp_shuffle_down())); + return Op::Get("tir.cuda.__shfl_down_sync"); } - return ""; } }; +static void DispatchCUDAWarpActiveMask(const TVMArgs& args, TVMRetValue* rv) { + Call call = args[0]; + *rv = Call(call->dtype, Op::Get("tir.cuda.__activemask"), call->args, CallNode::PureExtern); +} + template static void DispatchCUDAShuffle(const TVMArgs& args, TVMRetValue* rv) { PrimExpr e = args[0]; @@ -117,8 +120,9 @@ static void DispatchCUDAShuffle(const TVMArgs& args, TVMRetValue* rv) { CHECK(call != nullptr); CHECK_EQ(call->args.size(), 5); // mask, value, warp_id, width, warp_size Array cuda_args{{call->args[0], call->args[1], call->args[2], call->args[3]}}; - const char* name = T()(call->dtype, call->name); - *rv = Call(call->dtype, name, cuda_args, CallNode::PureExtern); + + *rv = + Call(call->dtype, T()(call->dtype, Downcast(call->op)), cuda_args, CallNode::PureExtern); } TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.floor").set_body(DispatchExtern); @@ -175,10 +179,32 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tvm_warp_shuffle_down") .set_body(DispatchCUDAShuffle); TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tvm_warp_activemask") - .set_body(DispatchExtern); + .set_body(DispatchCUDAWarpActiveMask); TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.fmod").set_body(DispatchExtern); +// Register low-level builtin ops. +// TODO(tvm-team): consider make CUDA its own subfolder and create a file for low-level builtins. +TVM_REGISTER_OP("tir.cuda.__shfl_sync") + .set_num_inputs(4) + .set_attr("TGlobalSymbol", "__shfl_sync") + .set_attr("cuda.need_warp_shuffle", true); + +TVM_REGISTER_OP("tir.cuda.__shfl_up_sync") + .set_num_inputs(4) + .set_attr("TGlobalSymbol", "__shfl_up_sync") + .set_attr("cuda.need_warp_shuffle", true); + +TVM_REGISTER_OP("tir.cuda.__shfl_down_sync") + .set_num_inputs(4) + .set_attr("TGlobalSymbol", "__shfl_down_sync") + .set_attr("cuda.need_warp_shuffle", true); + +TVM_REGISTER_OP("tir.cuda.__activemask") + .set_num_inputs(0) + .set_attr("TGlobalSymbol", "__activemask") + .set_attr("cuda.need_warp_shuffle", true); + } // namespace intrin } // namespace codegen } // namespace tvm diff --git a/src/target/source/intrin_rule_opencl.cc b/src/target/source/intrin_rule_opencl.cc index 8453b33..82eabdd 100644 --- a/src/target/source/intrin_rule_opencl.cc +++ b/src/target/source/intrin_rule_opencl.cc @@ -79,8 +79,8 @@ static void DispatchIntelShuffle(const TVMArgs& args, TVMRetValue* rv) { arith::Analyzer analyzer; CHECK(analyzer.CanProve(call->args[3] == call->args[4])) << "Intel warp shuffle dose not support width != warp_size"; - Array opencl_args{{call->args[1], call->args[2]}}; - *rv = Call(call->dtype, "intel_sub_group_shuffle", opencl_args, CallNode::PureExtern); + Array opencl_args{{StringImm("intel_sub_group_shuffle"), call->args[1], call->args[2]}}; + *rv = Call(call->dtype, builtin::call_extern(), opencl_args, CallNode::PureExtern); } TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.tvm_warp_shuffle").set_body(DispatchIntelShuffle); diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc index 699d395..6c12343 100644 --- a/src/target/spirv/codegen_spirv.cc +++ b/src/target/spirv/codegen_spirv.cc @@ -24,6 +24,7 @@ #include "codegen_spirv.h" #include +#include #include #include @@ -236,7 +237,7 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const LetNode* op) { } spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) { - if (op->is_intrinsic("spirv_glsl450")) { + if (op->op.same_as(builtin::call_spirv_glsl450())) { CHECK_GE(op->args.size(), 2U); uint32_t inst_id = static_cast(op->args[0].as()->value); std::vector values; @@ -244,31 +245,31 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) { values.push_back(MakeValue(op->args[i])); } return builder_->CallGLSL450(builder_->GetSType(op->dtype), inst_id, values); - } else if (op->is_intrinsic(CallNode::bitwise_and)) { + } else if (op->op.same_as(builtin::bitwise_and())) { CHECK_EQ(op->args.size(), 2U); spirv::Value a = MakeValue(op->args[0]); spirv::Value b = MakeValue(op->args[1]); return builder_->MakeValue(spv::OpBitwiseAnd, a.stype, a, b); - } else if (op->is_intrinsic(CallNode::bitwise_xor)) { + } else if (op->op.same_as(builtin::bitwise_xor())) { CHECK_EQ(op->args.size(), 2U); spirv::Value a = MakeValue(op->args[0]); spirv::Value b = MakeValue(op->args[1]); return builder_->MakeValue(spv::OpBitwiseXor, a.stype, a, b); - } else if (op->is_intrinsic(CallNode::bitwise_or)) { + } else if (op->op.same_as(builtin::bitwise_or())) { CHECK_EQ(op->args.size(), 2U); spirv::Value a = MakeValue(op->args[0]); spirv::Value b = MakeValue(op->args[1]); return builder_->MakeValue(spv::OpBitwiseOr, a.stype, a, b); - } else if (op->is_intrinsic(CallNode::bitwise_not)) { + } else if (op->op.same_as(builtin::bitwise_not())) { CHECK_EQ(op->args.size(), 1U); spirv::Value a = MakeValue(op->args[0]); return builder_->MakeValue(spv::OpNot, a.stype, a); - } else if (op->is_intrinsic(CallNode::shift_left)) { + } else if (op->op.same_as(builtin::shift_left())) { CHECK_EQ(op->args.size(), 2U); spirv::Value a = MakeValue(op->args[0]); spirv::Value b = MakeValue(op->args[1]); return builder_->MakeValue(spv::OpShiftLeftLogical, a.stype, a, b); - } else if (op->is_intrinsic(CallNode::shift_right)) { + } else if (op->op.same_as(builtin::shift_right())) { CHECK_EQ(op->args.size(), 2U); spirv::Value a = MakeValue(op->args[0]); spirv::Value b = MakeValue(op->args[1]); @@ -277,18 +278,18 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) { } else { return builder_->MakeValue(spv::OpShiftRightLogical, a.stype, a, b); } - } else if (op->is_intrinsic(CallNode::reinterpret)) { + } else if (op->op.same_as(builtin::reinterpret())) { return builder_->MakeValue(spv::OpBitcast, builder_->GetSType(op->dtype), MakeValue(op->args[0])); - } else if (op->is_intrinsic(intrinsic::tvm_large_uint_imm)) { + } else if (op->op.same_as(builtin::large_uint_imm())) { CHECK_EQ(op->args.size(), 2U); uint64_t low = static_cast(Downcast(op->args[0])->value); uint64_t high = static_cast(Downcast(op->args[1])->value); uint64_t val = (high << 32U) | low; return builder_->UIntImm(builder_->GetSType(op->dtype), val); - } else if (op->is_intrinsic(intrinsic::tvm_storage_sync)) { + } else if (op->op.same_as(builtin::tvm_storage_sync())) { return this->CreateStorageSync(op); - } else if (op->is_intrinsic(intrinsic::tvm_if_then_else)) { + } else if (op->op.same_as(builtin::if_then_else())) { CHECK_EQ(op->args.size(), 3U); spirv::Value cond = MakeValue(op->args[0]); spirv::Label then_label = builder_->NewLabel(); @@ -312,14 +313,14 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) { phi.SetIncoming(0, then_value, then_value_label); phi.SetIncoming(1, else_value, else_value_label); return phi; - } else if (op->is_intrinsic("popcount")) { + } else if (op->op.same_as(builtin::popcount())) { return builder_->MakeValue(spv::OpBitCount, builder_->GetSType(op->dtype), MakeValue(op->args[0])); } else { if (op->call_type == CallNode::Intrinsic || op->call_type == CallNode::PureIntrinsic) { - LOG(FATAL) << "Unresolved intrinsic " << op->name << " with return type " << op->dtype; + LOG(FATAL) << "Unresolved intrinsic " << op->op << " with return type " << op->dtype; } else if (op->call_type == CallNode::Extern || op->call_type == CallNode::PureExtern) { - LOG(FATAL) << "Unresolved extern " << op->name << " with return type " << op->dtype; + LOG(FATAL) << "Unresolved extern " << op->op << " with return type " << op->dtype; } else { LOG(FATAL) << "Unresolved call type " << op->call_type; } diff --git a/src/target/spirv/intrin_rule_spirv.cc b/src/target/spirv/intrin_rule_spirv.cc index a6b2547..1b9d2e4 100644 --- a/src/target/spirv/intrin_rule_spirv.cc +++ b/src/target/spirv/intrin_rule_spirv.cc @@ -22,6 +22,7 @@ */ #include #include +#include #include namespace tvm { @@ -43,7 +44,8 @@ inline void DispatchGLSLPureIntrin(const TVMArgs& targs, TVMRetValue* rv) { for (PrimExpr arg : call->args) { cargs.push_back(arg); } - *rv = tir::Call(call->dtype, "spirv_glsl450", cargs, tir::CallNode::PureIntrinsic); + *rv = tir::Call(call->dtype, tir::builtin::call_spirv_glsl450(), cargs, + tir::CallNode::PureIntrinsic); } TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.floor") diff --git a/src/target/stackvm/codegen_stackvm.cc b/src/target/stackvm/codegen_stackvm.cc index 6dd2ca0..84b1492 100644 --- a/src/target/stackvm/codegen_stackvm.cc +++ b/src/target/stackvm/codegen_stackvm.cc @@ -25,6 +25,7 @@ #include #include #include +#include #include #include @@ -41,31 +42,31 @@ using namespace tir; // map struct field kind to runtime variants // We keep two separate enums to ensure runtime/compiler isolation. StackVM::StructFieldKind MapFieldKind(int64_t kind) { - auto val = static_cast(kind); + auto val = static_cast(kind); switch (val) { - case intrinsic::kArrData: + case builtin::kArrData: return StackVM::kArrData; - case intrinsic::kArrShape: + case builtin::kArrShape: return StackVM::kArrShape; - case intrinsic::kArrAddr: + case builtin::kArrAddr: return StackVM::kArrAddr; - case intrinsic::kArrStrides: + case builtin::kArrStrides: return StackVM::kArrStrides; - case intrinsic::kArrNDim: + case builtin::kArrNDim: return StackVM::kArrNDim; - case intrinsic::kArrTypeCode: + case builtin::kArrTypeCode: return StackVM::kArrTypeCode; - case intrinsic::kArrTypeBits: + case builtin::kArrTypeBits: return StackVM::kArrTypeBits; - case intrinsic::kArrTypeLanes: + case builtin::kArrTypeLanes: return StackVM::kArrTypeLanes; - case intrinsic::kArrByteOffset: + case builtin::kArrByteOffset: return StackVM::kArrByteOffset; - case intrinsic::kArrDeviceId: + case builtin::kArrDeviceId: return StackVM::kArrDeviceId; - case intrinsic::kArrDeviceType: + case builtin::kArrDeviceType: return StackVM::kArrDeviceType; - case intrinsic::kTVMValueContent: + case builtin::kTVMValueContent: return StackVM::kTVMValueContent; default: LOG(FATAL) << "Do not know how to map field " << kind; @@ -174,7 +175,7 @@ void CodeGenStackVM::VisitStmt_(const AllocateNode* op) { } void CodeGenStackVM::VisitExpr_(const CallNode* op) { - if (op->is_intrinsic(intrinsic::tvm_address_of)) { + if (op->op.same_as(builtin::address_of())) { const LoadNode* l = op->args[0].as(); CHECK(op->args.size() == 1 && l); this->PushOp(StackVM::LOAD_HEAP, GetVarID(l->buffer_var.get())); @@ -182,9 +183,9 @@ void CodeGenStackVM::VisitExpr_(const CallNode* op) { this->PushOp(StackVM::PUSH_I64, l->dtype.element_of().bytes()); this->PushOp(StackVM::MUL_I64); this->PushOp(StackVM::ADDR_ADD); - } else if (op->is_intrinsic(CallNode::reinterpret)) { + } else if (op->op.same_as(builtin::reinterpret())) { this->Push(op->args[0]); - } else if (op->is_intrinsic(intrinsic::tvm_struct_get)) { + } else if (op->op.same_as(builtin::tvm_struct_get())) { CHECK_EQ(op->args.size(), 3U); int kind = op->args[2].as()->value; this->Push(op->args[0]); @@ -197,7 +198,7 @@ void CodeGenStackVM::VisitExpr_(const CallNode* op) { vm_.code.push_back(code); code.v_int = MapFieldKind(kind); vm_.code.push_back(code); - } else if (op->is_intrinsic(intrinsic::tvm_call_packed_lowered)) { + } else if (op->op.same_as(builtin::tvm_call_packed_lowered())) { CHECK_GE(op->args.size(), 5U); const StringImmNode* s = op->args[0].as(); CHECK(s != nullptr) << "tvm_call_global expect first argument as function name"; @@ -226,7 +227,7 @@ void CodeGenStackVM::VisitExpr_(const CallNode* op) { vm_.code.push_back(code); code.v_int = end; vm_.code.push_back(code); - } else if (op->is_intrinsic(intrinsic::tvm_stack_alloca)) { + } else if (op->op.same_as(builtin::tvm_stack_alloca())) { CHECK_EQ(op->args.size(), 2U); const std::string& type = op->args[0].as()->value; const IntImmNode* num = op->args[1].as(); @@ -249,7 +250,7 @@ void CodeGenStackVM::VisitExpr_(const CallNode* op) { // add stack size to be safe. vm_.stack_size += size; this->PushOp(StackVM::TVM_STACK_ALLOCA_BY_8BYTE, static_cast(size)); - } else if (op->name == "TVMBackendAllocWorkspace") { + } else if (op->op.same_as(backend_alloc_workspace_op_)) { CHECK_EQ(op->args.size(), 5U); this->Push(op->args[0]); this->Push(op->args[1]); @@ -257,21 +258,21 @@ void CodeGenStackVM::VisitExpr_(const CallNode* op) { this->Push(op->args[3]); this->Push(op->args[4]); this->PushOp(StackVM::TVM_DEVICE_ALLOCA); - } else if (op->name == "TVMBackendFreeWorkspace") { + } else if (op->op.same_as(backend_free_workspace_op_)) { CHECK_EQ(op->args.size(), 3U); this->Push(op->args[0]); this->Push(op->args[1]); this->Push(op->args[2]); this->PushOp(StackVM::TVM_DEVICE_FREE); - } else if (op->is_intrinsic(intrinsic::tvm_throw_last_error)) { + } else if (op->op.same_as(builtin::tvm_throw_last_error())) { this->PushOp(StackVM::TVM_THROW_LAST_ERROR); - } else if (op->is_intrinsic(intrinsic::tvm_handle_is_null)) { + } else if (op->op.same_as(builtin::isnullptr())) { CHECK_EQ(op->args.size(), 1U); this->Push(op->args[0]); this->PushOp(StackVM::PUSH_I64, 0); this->PushOp(StackVM::EQ_HANDLE); } else { - LOG(FATAL) << "unknown function call " << op->name; + LOG(FATAL) << "unknown function call " << op->op; } } @@ -430,7 +431,7 @@ void CodeGenStackVM::VisitStmt_(const SeqStmtNode* op) { void CodeGenStackVM::VisitStmt_(const EvaluateNode* ev) { if (is_const(ev->value)) return; const CallNode* op = ev->value.as(); - if (op && op->is_intrinsic(intrinsic::tvm_struct_set)) { + if (op && op->op.same_as(builtin::tvm_struct_set())) { CHECK_EQ(op->args.size(), 4U); this->Push(op->args[0]); this->Push(op->args[3]); diff --git a/src/target/stackvm/codegen_stackvm.h b/src/target/stackvm/codegen_stackvm.h index b77c406..480ffc7 100644 --- a/src/target/stackvm/codegen_stackvm.h +++ b/src/target/stackvm/codegen_stackvm.h @@ -27,6 +27,7 @@ #include #include #include +#include #include #include @@ -153,6 +154,9 @@ class CodeGenStackVM : public ExprFunctor, std::unordered_map str_idmap_; /*! \brief id of each global function */ std::unordered_map extern_fun_idmap_; + + Op backend_alloc_workspace_op_ = Op::Get("tir.TVMBackendAllocWorkspace"); + Op backend_free_workspace_op_ = Op::Get("tir.TVMBackendFreeWorkspace"); }; } // namespace codegen diff --git a/src/te/autodiff/jacobian.cc b/src/te/autodiff/jacobian.cc index 1834aa3..f625412 100644 --- a/src/te/autodiff/jacobian.cc +++ b/src/te/autodiff/jacobian.cc @@ -96,31 +96,30 @@ class JacobianMutator : public ExprMutator { PrimExpr VisitExpr_(const CallNode* op) { PrimExpr expr = GetRef(op); if (op->call_type == CallNode::CallType::PureIntrinsic) { - static std::unordered_set piecewise_const = {"floor", "ceil", "trunc", "round"}; - if (op->name == "exp") { + if (op->op.same_as(op_exp_)) { return Mul(Mutate(op->args[0]), expr); - } else if (op->name == "log") { + } else if (op->op.same_as(op_log_)) { return Div(Mutate(op->args[0]), op->args[0]); - } else if (op->name == "sigmoid") { + } else if (op->op.same_as(op_sigmoid_)) { return Mul(Mutate(op->args[0]), Mul(expr, Sub(FloatImm(expr.dtype(), 1.0), expr))); - } else if (op->name == "sqrt") { + } else if (op->op.same_as(op_sqrt_)) { return Div(Mutate(op->args[0]), Mul(expr, FloatImm(expr.dtype(), 2.0))); - } else if (op->name == "tanh") { + } else if (op->op.same_as(op_tanh_)) { return Mul(Mutate(op->args[0]), Sub(FloatImm(expr.dtype(), 1.0), Mul(expr, expr))); - } else if (op->name == "pow") { + } else if (op->op.same_as(op_pow_)) { auto x = op->args[0], y = op->args[1]; return expr * (Mutate(y) * log(x) + Mutate(x) * y / x); - } else if (op->name == "fabs") { + } else if (op->op.same_as(op_fabs_)) { auto type = op->args[0].dtype(); return Mul(Mutate(op->args[0]), Select(GE(op->args[0], make_zero(type)), FloatImm(type, 1.0), FloatImm(type, -1.0))); - } else if (op->name == intrinsic::tvm_if_then_else) { + } else if (op->op.same_as(op_if_then_else_)) { Array new_args = {op->args[0], Mutate(op->args[1]), Mutate(op->args[2])}; - return Call(op->dtype, op->name, new_args, op->call_type); - } else if (piecewise_const.count(op->name)) { + return Call(op->dtype, op->op, new_args, op->call_type); + } else if (piecewise_const.count(op->op)) { return FloatImm(expr.dtype(), 0.0); } else { - throw dmlc::Error("Derivative of this intrinsic is not implemented: " + op->name); + LOG(FATAL) << "Derivative of this intrinsic is not implemented: " << op->op; } } NOT_IMPLEMENTED; @@ -281,6 +280,17 @@ class JacobianMutator : public ExprMutator { Array indices_; Var input_var_; arith::Analyzer analyzer_; + + const Op& op_exp_ = Op::Get("tir.exp"); + const Op& op_log_ = Op::Get("tir.log"); + const Op& op_sigmoid_ = Op::Get("tir.sigmoid"); + const Op& op_sqrt_ = Op::Get("tir.sqrt"); + const Op& op_tanh_ = Op::Get("tir.tanh"); + const Op& op_pow_ = Op::Get("tir.pow"); + const Op& op_fabs_ = Op::Get("tir.fabs"); + const Op& op_if_then_else_ = Op::Get("tir.if_then_else"); + std::unordered_set piecewise_const = { + Op::Get("tir.floor"), Op::Get("tir.ceil"), Op::Get("tir.trunc"), Op::Get("tir.round")}; }; PrimExpr Derivative(const PrimExpr& expr, const Var& var) { diff --git a/src/te/operation/compute_op.cc b/src/te/operation/compute_op.cc index 1fc0520..b4725c5 100644 --- a/src/te/operation/compute_op.cc +++ b/src/te/operation/compute_op.cc @@ -27,6 +27,7 @@ #include #include #include +#include #include #include @@ -278,7 +279,7 @@ Stmt BaseComputeOpNode::BuildRealize(const Stage& stage, attr->dim_align_offset}; realize = tir::AttrStmt( t, tir::attr::buffer_dim_align, - Call(DataType::Handle(), tir::intrinsic::tvm_tuple, tuple, CallNode::Intrinsic), + Call(DataType::Handle(), tir::builtin::tvm_tuple(), tuple, CallNode::Intrinsic), realize); } } diff --git a/src/te/operation/cross_thread_reduction.cc b/src/te/operation/cross_thread_reduction.cc index e834ff2..eeaab30 100644 --- a/src/te/operation/cross_thread_reduction.cc +++ b/src/te/operation/cross_thread_reduction.cc @@ -21,6 +21,8 @@ * \brief Logics related to cross thread reduction, used by ComputeOpNode. * \file cross_thread_reduction.cc */ +#include + #include "compute_op.h" #include "op_util.h" @@ -194,7 +196,7 @@ Stmt MakeCrossThreadReduction(const ComputeOpNode* self, const Stage& stage, // Apply the existing input predicate if any. output_preds.push_back(input_pred); - Stmt reduce_body = Evaluate(Call(DataType::Handle(), tir::intrinsic::tvm_thread_allreduce, + Stmt reduce_body = Evaluate(Call(DataType::Handle(), tir::builtin::tvm_thread_allreduce(), freduce_args, CallNode::Intrinsic)); reduce_body = AttrStmt(reduces[0]->combiner, tir::attr::reduce_scope, make_zero(DataType::Handle()), reduce_body); diff --git a/src/te/operation/extern_op.cc b/src/te/operation/extern_op.cc index ef55c44..01019e4 100644 --- a/src/te/operation/extern_op.cc +++ b/src/te/operation/extern_op.cc @@ -153,7 +153,7 @@ Stmt ExternOpNode::BuildProvide(const Stage& stage, tuple.push_back(buffer->shape[k]); } ret = AttrStmt(bind_spec, tir::attr::buffer_bind_scope, - Call(DataType::Handle(), intrinsic::tvm_tuple, tuple, CallNode::Intrinsic), ret); + Call(DataType::Handle(), builtin::tvm_tuple(), tuple, CallNode::Intrinsic), ret); }; for (size_t i = output_placeholders.size(); i != 0; --i) { f_push_bind(output_placeholders[i - 1], stage->op.output(i - 1)); diff --git a/src/te/operation/tensor_compute_op.cc b/src/te/operation/tensor_compute_op.cc index 8d5265b..714e885 100644 --- a/src/te/operation/tensor_compute_op.cc +++ b/src/te/operation/tensor_compute_op.cc @@ -24,6 +24,7 @@ #include #include #include +#include #include #include @@ -153,7 +154,7 @@ Stmt TensorComputeOpNode::BuildProvide(const Stage& stage, } input_bind_nest.emplace_back(AttrStmt( bind_spec, tir::attr::buffer_bind_scope, - Call(DataType::Handle(), tir::intrinsic::tvm_tuple, tuple, CallNode::Intrinsic), nop)); + Call(DataType::Handle(), tir::builtin::tvm_tuple(), tuple, CallNode::Intrinsic), nop)); } // output binding @@ -177,7 +178,7 @@ Stmt TensorComputeOpNode::BuildProvide(const Stage& stage, output_bind_nest.emplace_back(AttrStmt( bind_spec, tir::attr::buffer_bind_scope, - Call(DataType::Handle(), tir::intrinsic::tvm_tuple, tuple, CallNode::Intrinsic), nop)); + Call(DataType::Handle(), tir::builtin::tvm_tuple(), tuple, CallNode::Intrinsic), nop)); } // Check variable remap diff --git a/src/te/operation/tensorize.cc b/src/te/operation/tensorize.cc index 82832c9..dd978a4 100644 --- a/src/te/operation/tensorize.cc +++ b/src/te/operation/tensorize.cc @@ -370,7 +370,7 @@ Stmt MakeTensorize(const ComputeOpNode* self, const Stage& stage, } input_bind_nest.emplace_back(AttrStmt( bind_spec, tir::attr::buffer_bind_scope, - Call(DataType::Handle(), tir::intrinsic::tvm_tuple, tuple, CallNode::Intrinsic), nop)); + Call(DataType::Handle(), tir::builtin::tvm_tuple(), tuple, CallNode::Intrinsic), nop)); } // output binding const ComputeOpNode* intrin_compute = intrin->op.as(); @@ -390,7 +390,7 @@ Stmt MakeTensorize(const ComputeOpNode* self, const Stage& stage, Array bind_spec{buffer, tensor}; output_bind_nest.emplace_back(AttrStmt( bind_spec, tir::attr::buffer_bind_scope, - Call(DataType::Handle(), tir::intrinsic::tvm_tuple, tuple, CallNode::Intrinsic), nop)); + Call(DataType::Handle(), tir::builtin::tvm_tuple(), tuple, CallNode::Intrinsic), nop)); } // Check variable remap std::unordered_map vmap; diff --git a/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc b/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc index 1ff569f..67121b8 100644 --- a/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc +++ b/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc @@ -29,6 +29,7 @@ #include #include #include +#include #include #include #include @@ -42,7 +43,6 @@ namespace tvm { namespace te { using namespace te; -using intrinsic::tvm_address_of; using runtime::StorageRank; using runtime::StorageScope; using runtime::ThreadScope; @@ -255,9 +255,9 @@ class BodyVisitor : public StmtExprVisitor { } } - void VisitExpr_(const CallNode* op) final { + void VisitExpr_(const ProducerLoadNode* op) final { StmtExprVisitor::VisitExpr_(op); - args_.insert(std::make_pair(op->name, op->args)); + args_.insert(std::make_pair(op->producer->GetNameHint(), op->indices)); } friend class ScheduleAnalyser; @@ -415,7 +415,7 @@ class BufferAnalyser : public StmtExprVisitor { } else if (op->attr_key == tir::attr::buffer_dim_align) { te::Tensor tensor = Downcast(op->node); const CallNode* tuple = op->value.as(); - CHECK(tuple && tuple->is_intrinsic(intrinsic::tvm_tuple)); + CHECK(tuple && tuple->op.same_as(builtin::tvm_tuple())); auto& vinfo = dim_align_[tensor]; size_t dim = tuple->args[0].as()->value; if (dim >= vinfo.size()) { @@ -848,13 +848,13 @@ class TensorCoreIRMutator : public StmtExprMutator { Buffer buffer_b(buffer_node_b); if (ca->dtype == DataType::Int(1) && cb->dtype == DataType::Int(1)) { return Evaluate( - Call(DataType::Handle(), intrinsic::tvm_bmma_sync, + Call(DataType::Handle(), builtin::tvm_bmma_sync(), {buffer->data, buffer->elem_offset, buffer_a->data, buffer_a->elem_offset, buffer_b->data, buffer_b->elem_offset, buffer->data, buffer->elem_offset}, CallNode::Intrinsic)); } else { return Evaluate( - Call(DataType::Handle(), intrinsic::tvm_mma_sync, + Call(DataType::Handle(), builtin::tvm_mma_sync(), {buffer->data, buffer->elem_offset, buffer_a->data, buffer_a->elem_offset, buffer_b->data, buffer_b->elem_offset, buffer->data, buffer->elem_offset}, CallNode::Intrinsic)); @@ -879,7 +879,7 @@ class TensorCoreIRMutator : public StmtExprMutator { auto pload = dst.as(); auto fill_fragment_call = [this, &op](const Buffer& buffer) { - return Evaluate(Call(DataType::Handle(), intrinsic::tvm_fill_fragment, + return Evaluate(Call(DataType::Handle(), builtin::tvm_fill_fragment(), {buffer->data, warp_tile_.m, warp_tile_.n, warp_tile_.k, buffer->elem_offset, op->value}, CallNode::Intrinsic)); @@ -889,11 +889,11 @@ class TensorCoreIRMutator : public StmtExprMutator { return add_buffer_bind_scope_(pload, buffer_node, fill_fragment_call); } - const CallNode* value = op->value.as(); + const ProducerLoadNode* value = op->value.as(); CHECK(value != nullptr) << "Can only load fragment from a buffer"; - auto it = strides_.find(value->name); - CHECK(it != strides_.end()) << "Cannot find stride for " << value->name; + auto it = strides_.find(value->producer->GetNameHint()); + CHECK(it != strides_.end()) << "Cannot find stride for " << value->producer->GetNameHint(); auto strides = it->second; CHECK_GE(strides.size(), 2); PrimExpr stride = strides[strides.size() - 2]; @@ -902,7 +902,9 @@ class TensorCoreIRMutator : public StmtExprMutator { PrimExpr warp_y = IntImm(DataType::Int(32), warp_threads_y_); ThreadIdxMutator thread_idx_mutator(warp_y); PrimExpr mutated_value = thread_idx_mutator(op->value); - PrimExpr src = Call(value->dtype, "&", {mutated_value}, CallNode::Extern); + // TODO(tvm-team) The extern function name seems to be a hack. + PrimExpr src = Call(value->dtype, builtin::call_extern(), {StringImm("&"), mutated_value}, + CallNode::Extern); auto pload = dst.as(); PrimExpr matrix_major; @@ -918,7 +920,7 @@ class TensorCoreIRMutator : public StmtExprMutator { } auto load_matrix_call = [this, &src, &stride, &matrix_major](const Buffer& buffer) { - return Evaluate(Call(DataType::Handle(), intrinsic::tvm_load_matrix_sync, + return Evaluate(Call(DataType::Handle(), builtin::tvm_load_matrix_sync(), {buffer->data, warp_tile_.m, warp_tile_.n, warp_tile_.k, buffer->elem_offset, src, stride, matrix_major}, CallNode::Intrinsic)); @@ -941,12 +943,13 @@ class TensorCoreIRMutator : public StmtExprMutator { PrimExpr warp_y = IntImm(DataType::Int(32), warp_threads_y_); ThreadIdxMutator thread_idx_mutator(warp_y); dst = thread_idx_mutator(dst); - dst = Call(DataType::Handle(), "&", {dst}, CallNode::Extern); + dst = + Call(DataType::Handle(), builtin::call_extern(), {StringImm("&"), dst}, CallNode::Extern); auto pload = op->value.as(); auto store_matrix_call = [this, &dst, &stride](const Buffer& buffer) { - return Evaluate(Call(DataType::Handle(), intrinsic::tvm_store_matrix_sync, + return Evaluate(Call(DataType::Handle(), builtin::tvm_store_matrix_sync(), {buffer->data, warp_tile_.m, warp_tile_.n, warp_tile_.k, buffer->elem_offset, dst, stride, StringImm("col_major")}, CallNode::Intrinsic)); @@ -1064,7 +1067,7 @@ class TensorCoreIRMutator : public StmtExprMutator { args.push_back(pload->indices[i]); args.push_back(shape[i]); } - auto tuple = Call(DataType::Handle(), intrinsic::tvm_tuple, args, CallNode::Intrinsic); + auto tuple = Call(DataType::Handle(), builtin::tvm_tuple(), args, CallNode::Intrinsic); Array node = {buffer, tensor}; return AttrStmt(node, "buffer_bind_scope", tuple, call_back(buffer)); } diff --git a/src/tir/analysis/verify_memory.cc b/src/tir/analysis/verify_memory.cc index 8eb846b..12ec270 100644 --- a/src/tir/analysis/verify_memory.cc +++ b/src/tir/analysis/verify_memory.cc @@ -25,6 +25,7 @@ #include #include #include +#include #include #include @@ -120,7 +121,7 @@ class MemoryAccessVerifier final : protected StmtExprVisitor { const auto& iter = defs_.find(V); if (iter == defs_.end()) return false; const CallNode* C = iter->second.as(); - if (!C || C->name != intrinsic::tvm_struct_get) return false; + if (!C || !C->op.same_as(builtin::tvm_struct_get())) return false; V = C->args[0].as(); } return false; diff --git a/src/tir/ir/buffer.cc b/src/tir/ir/buffer.cc index 4e433fc..6cccfa0 100644 --- a/src/tir/ir/buffer.cc +++ b/src/tir/ir/buffer.cc @@ -25,6 +25,7 @@ #include #include #include +#include #include #include @@ -376,7 +377,7 @@ PrimExpr Buffer::access_ptr(int access_mask, DataType ptr_type, int content_lane } Array acc_args{e_dtype, self->data, elem_offset, extent, make_const(DataType::Int(32), access_mask)}; - return tir::Call(ptr_type, tir::intrinsic::tvm_access_ptr, acc_args, tir::CallNode::Intrinsic); + return tir::Call(ptr_type, tir::builtin::tvm_access_ptr(), acc_args, tir::CallNode::Intrinsic); } Buffer::Buffer(Var data, DataType dtype, Array shape, Array strides, diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc index 9390fea..4b20351 100644 --- a/src/tir/ir/expr.cc +++ b/src/tir/ir/expr.cc @@ -698,50 +698,21 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) }); // Call -Call::Call(DataType dtype, String name, Array args, CallType call_type) { +Call::Call(DataType dtype, RelayExpr op, Array args, CallType call_type) { for (size_t i = 0; i < args.size(); ++i) { CHECK(args[i].defined()); } ObjectPtr node = make_object(); node->dtype = dtype; - node->name = std::move(name); + node->op = std::move(op); node->args = std::move(args); node->call_type = call_type; data_ = std::move(node); } -const char* CallNode::vectorizable_intrinsics[] = {"floor", - "ceil", - "sign", - "trunc", - "fabs", - "round", - "exp", - "tanh", - "sqrt", - "log", - "sin", - "cos", - "pow", - "tan", - tir::CallNode::shift_left, - tir::CallNode::shift_right, - tir::CallNode::likely, - tir::CallNode::popcount}; - -bool CallNode::is_vectorizable() const { - size_t cnt = sizeof(CallNode::vectorizable_intrinsics) / sizeof(char*); - for (size_t i = 0; i < cnt; ++i) { - if (name == CallNode::vectorizable_intrinsics[i]) { - return true; - } - } - return false; -} - TVM_REGISTER_GLOBAL("tir.Call") - .set_body_typed([](DataType type, String name, Array args, int call_type) { + .set_body_typed([](DataType type, RelayExpr op, Array args, int call_type) { Array prim_expr_args; for (const auto& it : args) { CHECK(it->IsInstance() || it->IsInstance()); @@ -751,7 +722,7 @@ TVM_REGISTER_GLOBAL("tir.Call") prim_expr_args.push_back(Downcast(it)); } } - return Call(type, name, prim_expr_args, static_cast(call_type)); + return Call(type, op, prim_expr_args, static_cast(call_type)); }); TVM_REGISTER_NODE_TYPE(CallNode); @@ -759,7 +730,13 @@ TVM_REGISTER_NODE_TYPE(CallNode); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); - p->stream << op->name << "("; + if (auto* ptr_op = op->op.as()) { + p->stream << ptr_op->name << "("; + } else { + auto* ptr_gvar = op->op.as(); + CHECK(ptr_gvar != nullptr); + p->stream << "@" << ptr_gvar->name_hint << "("; + } for (size_t i = 0; i < op->args.size(); ++i) { p->Print(op->args[i]); if (i < op->args.size() - 1) { diff --git a/src/tir/ir/expr_functor.cc b/src/tir/ir/expr_functor.cc index b92127b..98b9fd0 100644 --- a/src/tir/ir/expr_functor.cc +++ b/src/tir/ir/expr_functor.cc @@ -166,7 +166,7 @@ PrimExpr ExprMutator::VisitExpr_(const CallNode* op) { if (args.same_as(op->args)) { return GetRef(op); } else { - return Call(op->dtype, op->name, args, op->call_type); + return Call(op->dtype, op->op, args, op->call_type); } } diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index 6649775..c3ddb66 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -582,5 +582,13 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->PrintIndent(); p->stream << "}\n"; }); + +PrimExpr TypeAnnotation(DataType dtype) { + static auto op = Op::Get("tir.type_annotation"); + return tir::Call(dtype, op, {}, tir::CallNode::PureIntrinsic); +} + +TVM_REGISTER_OP("tir.type_annotation"); + } // namespace tir } // namespace tvm diff --git a/src/tir/op/builtin.cc b/src/tir/op/builtin.cc new file mode 100644 index 0000000..8efcf3f --- /dev/null +++ b/src/tir/op/builtin.cc @@ -0,0 +1,155 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tir/op/builtin.cc + * + * builtin intrinsic operators. + */ +#include +#include +#include +#include + +namespace tvm { +namespace tir { +namespace builtin { + +#define TIR_DEFINE_BUILTIN_FUNC(OpName) \ + const Op& OpName() { \ + static const Op& op = Op::Get("tir." #OpName); \ + return op; \ + } \ + TVM_REGISTER_OP("tir." #OpName) + +TIR_DEFINE_BUILTIN_FUNC(reinterpret).set_num_inputs(1); + +TIR_DEFINE_BUILTIN_FUNC(likely).set_num_inputs(1).set_attr("TVectorizable", true); + +TIR_DEFINE_BUILTIN_FUNC(bitwise_and) + .set_num_inputs(2) + .set_attr("TVectorizable", true); + +TIR_DEFINE_BUILTIN_FUNC(bitwise_or) + .set_num_inputs(2) + .set_attr("TVectorizable", true); + +TIR_DEFINE_BUILTIN_FUNC(bitwise_xor) + .set_num_inputs(2) + .set_attr("TVectorizable", true); + +TIR_DEFINE_BUILTIN_FUNC(bitwise_not) + .set_num_inputs(1) + .set_attr("TVectorizable", true); + +TIR_DEFINE_BUILTIN_FUNC(shift_left) + .set_num_inputs(2) + .set_attr("TVectorizable", true); + +TIR_DEFINE_BUILTIN_FUNC(shift_right) + .set_num_inputs(2) + .set_attr("TVectorizable", true); + +TIR_DEFINE_BUILTIN_FUNC(large_uint_imm).set_num_inputs(2); + +TIR_DEFINE_BUILTIN_FUNC(address_of).set_num_inputs(1); + +TIR_DEFINE_BUILTIN_FUNC(if_then_else).set_num_inputs(3); + +TIR_DEFINE_BUILTIN_FUNC(isnullptr).set_num_inputs(1); + +TIR_DEFINE_BUILTIN_FUNC(isnan).set_num_inputs(1); + +TIR_DEFINE_BUILTIN_FUNC(popcount).set_num_inputs(1); + +TIR_DEFINE_BUILTIN_FUNC(fma).set_num_inputs(3).set_attr("TVectorizable", true); + +TIR_DEFINE_BUILTIN_FUNC(call_extern); + +TIR_DEFINE_BUILTIN_FUNC(call_llvm_intrin); + +TIR_DEFINE_BUILTIN_FUNC(call_spirv_glsl450); + +TIR_DEFINE_BUILTIN_FUNC(prefetch); + +TIR_DEFINE_BUILTIN_FUNC(tvm_access_ptr).set_num_inputs(5); + +TIR_DEFINE_BUILTIN_FUNC(tvm_static_handle).set_num_inputs(0); + +TIR_DEFINE_BUILTIN_FUNC(tvm_context_id).set_num_inputs(0); + +TIR_DEFINE_BUILTIN_FUNC(tvm_tuple); + +TIR_DEFINE_BUILTIN_FUNC(tvm_struct_get).set_num_inputs(3); + +TIR_DEFINE_BUILTIN_FUNC(tvm_struct_set).set_num_inputs(4); + +TIR_DEFINE_BUILTIN_FUNC(tvm_throw_last_error).set_num_inputs(0); + +TIR_DEFINE_BUILTIN_FUNC(tvm_stack_alloca).set_num_inputs(2); + +TIR_DEFINE_BUILTIN_FUNC(tvm_stack_make_shape); + +TIR_DEFINE_BUILTIN_FUNC(tvm_stack_make_array).set_num_inputs(6); + +// When num_inputs are not set, the function is assumed to be variable length. +TIR_DEFINE_BUILTIN_FUNC(tvm_call_packed); + +TIR_DEFINE_BUILTIN_FUNC(tvm_call_trace_packed); + +TIR_DEFINE_BUILTIN_FUNC(tvm_thread_context).set_num_inputs(1); + +TIR_DEFINE_BUILTIN_FUNC(tvm_call_packed_lowered); + +TIR_DEFINE_BUILTIN_FUNC(tvm_call_trace_packed_lowered); + +// TODO(tvm-team) revisit storage sync once we have a good memory hierachy structure. +TIR_DEFINE_BUILTIN_FUNC(tvm_storage_sync); + +TIR_DEFINE_BUILTIN_FUNC(tvm_warp_shuffle); + +TIR_DEFINE_BUILTIN_FUNC(tvm_warp_shuffle_up); + +TIR_DEFINE_BUILTIN_FUNC(tvm_warp_shuffle_down); + +TIR_DEFINE_BUILTIN_FUNC(tvm_warp_activemask); + +TIR_DEFINE_BUILTIN_FUNC(tvm_global_barrier_kinit); + +TIR_DEFINE_BUILTIN_FUNC(tvm_thread_allreduce); + +TIR_DEFINE_BUILTIN_FUNC(tvm_load_matrix_sync); + +TIR_DEFINE_BUILTIN_FUNC(tvm_mma_sync); + +TIR_DEFINE_BUILTIN_FUNC(tvm_bmma_sync); + +TIR_DEFINE_BUILTIN_FUNC(tvm_fill_fragment); + +TIR_DEFINE_BUILTIN_FUNC(tvm_store_matrix_sync); + +TIR_DEFINE_BUILTIN_FUNC(vectorhigh); + +TIR_DEFINE_BUILTIN_FUNC(vectorlow); + +TIR_DEFINE_BUILTIN_FUNC(vectorcombine); + +} // namespace builtin +} // namespace tir +} // namespace tvm diff --git a/src/tir/ir/op.cc b/src/tir/op/op.cc similarity index 82% rename from src/tir/ir/op.cc rename to src/tir/op/op.cc index 5ac9f59..f8049ea 100644 --- a/src/tir/ir/op.cc +++ b/src/tir/op/op.cc @@ -18,12 +18,16 @@ */ /*! - * \file expr_operator.cc + * \file tir/op/op.cc + * + * Common operator definitions for ops in tir/op.h */ #include +#include #include #include +#include #include // Centralized header for constant folders. @@ -33,6 +37,12 @@ namespace tvm { using namespace tir; +// macro to register an unary op +#define TIR_REGISTER_PURE_UNARY_OP(OpName) TVM_REGISTER_OP(OpName).set_num_inputs(1) + +// macro to register an binary op +#define TIR_REGISTER_PURE_BINARY_OP(OpName) TVM_REGISTER_OP(OpName).set_num_inputs(2) + runtime::DataType GetRuntimeDataType(const Type& type) { if (auto* n = type.as()) { return n->dtype; @@ -70,8 +80,9 @@ inline PrimExpr SimpleCast(const DataType& t, PrimExpr value) { return tir::Cast(t, value); } +// LargeUIntImm PrimExpr LargeUIntImm(DataType t, int64_t low, int64_t high) { - return tir::Call(t, tir::intrinsic::tvm_large_uint_imm, + return tir::Call(t, tir::builtin::large_uint_imm(), {make_const(DataType::UInt(32), low), make_const(DataType::UInt(32), high)}, tir::CallNode::PureIntrinsic); } @@ -248,11 +259,13 @@ PrimExpr cast(const DataType& t, PrimExpr value) { } } +// reinterpret PrimExpr reinterpret(const DataType& t, PrimExpr value) { if (value.dtype() == t) return value; - return tir::Call(t, tir::CallNode::reinterpret, {value}, tir::CallNode::PureIntrinsic); + return tir::Call(t, tir::builtin::reinterpret(), {value}, tir::CallNode::PureIntrinsic); } +// operator+ PrimExpr operator+(PrimExpr a, PrimExpr b) { BinaryOpMatchTypes(a, b); PrimExpr ret = arith::TryConstFold(a, b); @@ -360,6 +373,7 @@ PrimExpr max(PrimExpr a, PrimExpr b) { return tir::Max(a, b); } +// if_then_else PrimExpr if_then_else(PrimExpr cond, PrimExpr true_value, PrimExpr false_value) { CHECK(cond.dtype() == DataType::Bool(1)) << "if_then_else only accept the condition to be boolean type."; @@ -371,15 +385,20 @@ PrimExpr if_then_else(PrimExpr cond, PrimExpr true_value, PrimExpr false_value) return false_value; } } - return tir::Call(true_value.dtype(), tir::intrinsic::tvm_if_then_else, + + return tir::Call(true_value.dtype(), tir::builtin::if_then_else(), {cond, true_value, false_value}, tir::CallNode::PureIntrinsic); } +// likely PrimExpr likely(PrimExpr cond) { if (is_const(cond)) return cond; - return tir::Call(cond.dtype(), tir::CallNode::likely, {cond}, tir::CallNode::PureIntrinsic); + return tir::Call(cond.dtype(), tir::builtin::likely(), {cond}, tir::CallNode::PureIntrinsic); } +TVM_REGISTER_OP("tir.likely").set_num_inputs(1); + +// operator> PrimExpr operator>(PrimExpr a, PrimExpr b) { BinaryOpMatchTypes(a, b); PrimExpr ret = arith::TryConstFold(a, b); @@ -445,6 +464,7 @@ PrimExpr operator!(PrimExpr a) { return tir::Not(a); } +// shirt right PrimExpr operator>>(PrimExpr a, PrimExpr b) { CHECK(a.dtype().is_int() || a.dtype().is_uint()); CHECK(b.dtype().is_int() || b.dtype().is_uint()); @@ -460,9 +480,11 @@ PrimExpr operator>>(PrimExpr a, PrimExpr b) { if (pb->value == 0) return a; } }); - return tir::Call(a.dtype(), tir::CallNode::shift_right, {a, b}, tir::CallNode::PureIntrinsic); + + return tir::Call(a.dtype(), tir::builtin::shift_right(), {a, b}, tir::CallNode::PureIntrinsic); } +// shift left PrimExpr operator<<(PrimExpr a, PrimExpr b) { CHECK(a.dtype().is_int() || a.dtype().is_uint()); CHECK(b.dtype().is_int() || b.dtype().is_uint()); @@ -478,9 +500,10 @@ PrimExpr operator<<(PrimExpr a, PrimExpr b) { if (pb->value == 0) return a; } }); - return tir::Call(a.dtype(), tir::CallNode::shift_left, {a, b}, tir::CallNode::PureIntrinsic); + return tir::Call(a.dtype(), tir::builtin::shift_left(), {a, b}, tir::CallNode::PureIntrinsic); } +// bitwise and PrimExpr operator&(PrimExpr a, PrimExpr b) { CHECK(a.dtype().is_int() || a.dtype().is_uint()); CHECK(b.dtype().is_int() || b.dtype().is_uint()); @@ -489,9 +512,10 @@ PrimExpr operator&(PrimExpr a, PrimExpr b) { const DataType& rtype = a.dtype(); if (pa && pb) return IntImm(rtype, (pa->value & pb->value)); }); - return tir::Call(a.dtype(), tir::CallNode::bitwise_and, {a, b}, tir::CallNode::PureIntrinsic); + return tir::Call(a.dtype(), tir::builtin::bitwise_and(), {a, b}, tir::CallNode::PureIntrinsic); } +// bitwise_or PrimExpr operator|(PrimExpr a, PrimExpr b) { CHECK(a.dtype().is_int() || a.dtype().is_uint()); CHECK(b.dtype().is_int() || b.dtype().is_uint()); @@ -500,9 +524,10 @@ PrimExpr operator|(PrimExpr a, PrimExpr b) { const DataType& rtype = a.dtype(); if (pa && pb) return IntImm(rtype, (pa->value | pb->value)); }); - return tir::Call(a.dtype(), tir::CallNode::bitwise_or, {a, b}, tir::CallNode::PureIntrinsic); + return tir::Call(a.dtype(), tir::builtin::bitwise_or(), {a, b}, tir::CallNode::PureIntrinsic); } +// bitwise_xor PrimExpr operator^(PrimExpr a, PrimExpr b) { CHECK(a.dtype().is_int() || a.dtype().is_uint()); CHECK(b.dtype().is_int() || b.dtype().is_uint()); @@ -511,20 +536,30 @@ PrimExpr operator^(PrimExpr a, PrimExpr b) { const DataType& rtype = a.dtype(); if (pa && pb) return IntImm(rtype, (pa->value ^ pb->value)); }); - return tir::Call(a.dtype(), tir::CallNode::bitwise_xor, {a, b}, tir::CallNode::PureIntrinsic); + return tir::Call(a.dtype(), tir::builtin::bitwise_xor(), {a, b}, tir::CallNode::PureIntrinsic); } +// bitwie_not PrimExpr operator~(PrimExpr a) { CHECK(a.dtype().is_int() || a.dtype().is_uint()); - return tir::Call(a.dtype(), tir::CallNode::bitwise_not, {a}, tir::CallNode::PureIntrinsic); + return tir::Call(a.dtype(), tir::builtin::bitwise_not(), {a}, tir::CallNode::PureIntrinsic); } +TVM_REGISTER_OP("tir.bitwise_not"); + +TVM_REGISTER_GLOBAL("tir.bitwise_not").set_body_typed([](PrimExpr a) { return ~a; }); + +// pow PrimExpr pow(PrimExpr x, PrimExpr y) { BinaryOpMatchTypes(x, y); CHECK(x.dtype().is_float()) << "power only applies to float"; - return tir::Call(x.dtype(), "pow", {x, y}, tir::CallNode::PureIntrinsic); + static auto op = Op::Get("tir.pow"); + return tir::Call(x.dtype(), op, {x, y}, tir::CallNode::PureIntrinsic); } +TVM_REGISTER_OP("tir.pow").set_num_inputs(2).set_attr("TVectorizable", true); + +// abs PrimExpr abs(PrimExpr x) { if (x.dtype().is_int()) { using tir::IntImmNode; @@ -539,7 +574,8 @@ PrimExpr abs(PrimExpr x) { if (fx) { return FloatImm(x.dtype(), std::fabs(fx->value)); } - return tir::Call(x.dtype(), "fabs", {x}, tir::CallNode::PureIntrinsic); + static auto op = Op::Get("tir.fabs"); + return tir::Call(x.dtype(), op, {x}, tir::CallNode::PureIntrinsic); } else if (x.dtype().is_uint()) { return x; } else { @@ -549,6 +585,9 @@ PrimExpr abs(PrimExpr x) { } } +TIR_REGISTER_PURE_UNARY_OP("tir.fabs").set_attr("TVectorizable", true); + +// isnan PrimExpr isnan(PrimExpr x) { DataType t = DataType::Bool(x.dtype().lanes()); if (x.dtype().is_int() || x.dtype().is_uint()) { @@ -559,12 +598,12 @@ PrimExpr isnan(PrimExpr x) { if (fx) { return make_const(t, std::isnan(fx->value)); } + static auto op = Op::Get("tir.isnan"); if (x.dtype().bits() == 16) { - return tir::Call(t, tir::CallNode::isnan, - {cast(DataType::Float(32, t.lanes()), std::move(x))}, + return tir::Call(t, op, {cast(DataType::Float(32, t.lanes()), std::move(x))}, tir::CallNode::PureIntrinsic); } else { - return tir::Call(t, tir::CallNode::isnan, {x}, tir::CallNode::PureIntrinsic); + return tir::Call(t, op, {x}, tir::CallNode::PureIntrinsic); } } else { LOG(FATAL) << "Data type " << x.dtype() << " not supported for isnan op. Skipping isnan op..."; @@ -572,6 +611,9 @@ PrimExpr isnan(PrimExpr x) { } } +TIR_REGISTER_PURE_UNARY_OP("tir.isnan"); + +// isinf PrimExpr isinf(PrimExpr x) { DataType t = DataType::Bool(x.dtype().lanes()); if (x.dtype().is_int() || x.dtype().is_uint()) { @@ -585,6 +627,7 @@ PrimExpr isinf(PrimExpr x) { } } +// isfinite PrimExpr isfinite(PrimExpr x) { return !isinf(x) && !isnan(x); } PrimExpr sum(PrimExpr source, Array rdom) { @@ -637,12 +680,17 @@ PrimExpr prod(PrimExpr source, Array rdom) { return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0); } +// fmod PrimExpr fmod(PrimExpr x, PrimExpr y) { BinaryOpMatchTypes(x, y); CHECK(x.dtype().is_float()) << "fmod only applies to float"; - return tir::Call(x.dtype(), "fmod", {x, y}, tir::CallNode::PureIntrinsic); + static auto op = Op::Get("tir.fmod"); + return tir::Call(x.dtype(), op, {x, y}, tir::CallNode::PureIntrinsic); } +TIR_REGISTER_PURE_UNARY_OP("tir.fmod"); + +// floor PrimExpr floor(PrimExpr x) { if (x.dtype().is_int() || x.dtype().is_uint()) { return x; @@ -650,9 +698,13 @@ PrimExpr floor(PrimExpr x) { using tir::FloatImmNode; const FloatImmNode* fx = x.as(); if (fx) return FloatImm(x.dtype(), std::floor(fx->value)); - return tir::Call(x.dtype(), "floor", {x}, tir::CallNode::PureIntrinsic); + static auto op = Op::Get("tir.floor"); + return tir::Call(x.dtype(), op, {x}, tir::CallNode::PureIntrinsic); } +TIR_REGISTER_PURE_UNARY_OP("tir.floor").set_attr("TVectorizable", true); + +// ceil PrimExpr ceil(PrimExpr x) { if (x.dtype().is_int() || x.dtype().is_uint()) { return x; @@ -660,9 +712,13 @@ PrimExpr ceil(PrimExpr x) { using tir::FloatImmNode; const FloatImmNode* fx = x.as(); if (fx) return FloatImm(x.dtype(), std::ceil(fx->value)); - return tir::Call(x.dtype(), "ceil", {x}, tir::CallNode::PureIntrinsic); + static auto op = Op::Get("tir.ceil"); + return tir::Call(x.dtype(), op, {x}, tir::CallNode::PureIntrinsic); } +TIR_REGISTER_PURE_UNARY_OP("tir.ceil").set_attr("TVectorizable", true); + +// round PrimExpr round(PrimExpr x) { if (x.dtype().is_int() || x.dtype().is_uint()) { return x; @@ -670,9 +726,13 @@ PrimExpr round(PrimExpr x) { using tir::FloatImmNode; const FloatImmNode* fx = x.as(); if (fx) return FloatImm(x.dtype(), std::nearbyint(fx->value)); - return tir::Call(x.dtype(), "round", {x}, tir::CallNode::PureIntrinsic); + static auto op = Op::Get("tir.round"); + return tir::Call(x.dtype(), op, {x}, tir::CallNode::PureIntrinsic); } +TIR_REGISTER_PURE_UNARY_OP("tir.round").set_attr("TVectorizable", true); + +// nearbyint PrimExpr nearbyint(PrimExpr x) { if (x.dtype().is_int() || x.dtype().is_uint()) { return x; @@ -680,9 +740,13 @@ PrimExpr nearbyint(PrimExpr x) { using tir::FloatImmNode; const FloatImmNode* fx = x.as(); if (fx) return FloatImm(x.dtype(), std::nearbyint(fx->value)); - return tir::Call(x.dtype(), "nearbyint", {x}, tir::CallNode::PureIntrinsic); + static auto op = Op::Get("tir.nearbyint"); + return tir::Call(x.dtype(), op, {x}, tir::CallNode::PureIntrinsic); } +TIR_REGISTER_PURE_UNARY_OP("tir.nearbyint"); + +// trunc PrimExpr trunc(PrimExpr x) { if (x.dtype().is_int() || x.dtype().is_uint()) { return x; @@ -692,9 +756,72 @@ PrimExpr trunc(PrimExpr x) { if (fx) { return FloatImm(x.dtype(), (fx->value < 0 ? std::ceil(fx->value) : std::floor(fx->value))); } - return tir::Call(x.dtype(), "trunc", {x}, tir::CallNode::PureIntrinsic); + static auto op = Op::Get("tir.trunc"); + return tir::Call(x.dtype(), op, {x}, tir::CallNode::PureIntrinsic); } +TIR_REGISTER_PURE_UNARY_OP("tir.trunc").set_attr("TVectorizable", true); + +// unary op registration. +TIR_REGISTER_PURE_UNARY_OP("tir.exp").set_attr("TVectorizable", true); + +TIR_REGISTER_PURE_UNARY_OP("tir.exp2").set_attr("TVectorizable", true); + +TIR_REGISTER_PURE_UNARY_OP("tir.exp10").set_attr("TVectorizable", true); + +TIR_REGISTER_PURE_UNARY_OP("tir.erf"); + +TIR_REGISTER_PURE_UNARY_OP("tir.tanh").set_attr("TVectorizable", true); + +TIR_REGISTER_PURE_UNARY_OP("tir.sigmoid"); + +TIR_REGISTER_PURE_UNARY_OP("tir.sqrt").set_attr("TVectorizable", true); + +TIR_REGISTER_PURE_UNARY_OP("tir.rsqrt"); + +TIR_REGISTER_PURE_UNARY_OP("tir.log").set_attr("TVectorizable", true); + +TIR_REGISTER_PURE_UNARY_OP("tir.log2").set_attr("TVectorizable", true); + +TIR_REGISTER_PURE_UNARY_OP("tir.log1p"); + +TIR_REGISTER_PURE_UNARY_OP("tir.log10").set_attr("TVectorizable", true); + +TIR_REGISTER_PURE_UNARY_OP("tir.popcount").set_attr("TVectorizable", true); + +TIR_REGISTER_PURE_UNARY_OP("tir.tan").set_attr("TVectorizable", true); + +TIR_REGISTER_PURE_UNARY_OP("tir.cos").set_attr("TVectorizable", true); + +TIR_REGISTER_PURE_UNARY_OP("tir.cosh").set_attr("TVectorizable", true); + +TIR_REGISTER_PURE_UNARY_OP("tir.sin").set_attr("TVectorizable", true); + +TIR_REGISTER_PURE_UNARY_OP("tir.sinh").set_attr("TVectorizable", true); + +TIR_REGISTER_PURE_UNARY_OP("tir.asin"); + +TIR_REGISTER_PURE_UNARY_OP("tir.acos"); + +TIR_REGISTER_PURE_UNARY_OP("tir.atan"); + +TIR_REGISTER_PURE_UNARY_OP("tir.acosh"); + +TIR_REGISTER_PURE_UNARY_OP("tir.asinh"); + +TIR_REGISTER_PURE_UNARY_OP("tir.atanh"); + +// binary intrinsics +TIR_REGISTER_PURE_BINARY_OP("tir.atan2"); + +TIR_REGISTER_PURE_BINARY_OP("tir.nextafter"); + +TIR_REGISTER_PURE_BINARY_OP("tir.hypot"); + +TIR_REGISTER_PURE_BINARY_OP("tir.copysign"); + +TIR_REGISTER_PURE_BINARY_OP("tir.ldexp"); + // expose basic functions to node namespace TVM_REGISTER_GLOBAL("node._const").set_body([](TVMArgs args, TVMRetValue* ret) { if (args[0].type_code() == kDLInt) { @@ -783,4 +910,5 @@ TVM_REGISTER_GLOBAL("tir._OpIfThenElse") .set_body_typed([](PrimExpr cond, PrimExpr true_value, PrimExpr false_value) { return if_then_else(cond, true_value, false_value); }); + } // namespace tvm diff --git a/src/tir/op/runtime.cc b/src/tir/op/runtime.cc new file mode 100644 index 0000000..1c540e3 --- /dev/null +++ b/src/tir/op/runtime.cc @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tir/op/runtime.cc + * \brief TIR ops for runtime functions. + */ +#include +#include + +namespace tvm { +namespace tir { + +TVM_REGISTER_OP("tir.TVMBackendAllocWorkspace") + .set_num_inputs(5) + .set_attr("TGlobalSymbol", "TVMBackendAllocWorkspace"); + +TVM_REGISTER_OP("tir.TVMBackendFreeWorkspace") + .set_num_inputs(3) + .set_attr("TGlobalSymbol", "TVMBackendFreeWorkspace"); + +} // namespace tir +} // namespace tvm diff --git a/src/tir/transforms/arg_binder.cc b/src/tir/transforms/arg_binder.cc index ae7065d..80c5268 100644 --- a/src/tir/transforms/arg_binder.cc +++ b/src/tir/transforms/arg_binder.cc @@ -24,6 +24,7 @@ #include "arg_binder.h" #include +#include #include #include @@ -141,7 +142,7 @@ void ArgBinder::BindBuffer(const Buffer& arg, const Buffer& value, const std::st } } -inline PrimExpr TVMArrayGet(DataType t, Var arr, intrinsic::TVMStructFieldKind kind) { +inline PrimExpr TVMArrayGet(DataType t, Var arr, builtin::TVMStructFieldKind kind) { return TVMStructGet(t, arr, 0, kind); } @@ -152,7 +153,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type, const DataType tvm_ndim_type = DataType::Int(32); const Stmt nop = Evaluate(0); // dimension checks - PrimExpr v_ndim = TVMArrayGet(tvm_ndim_type, handle, intrinsic::kArrNDim); + PrimExpr v_ndim = TVMArrayGet(tvm_ndim_type, handle, builtin::kArrNDim); PrimExpr a_ndim = make_const(tvm_ndim_type, static_cast(buffer->shape.size())); std::ostringstream ndim_err_msg; ndim_err_msg << arg_name << ".ndim is expected to equal " << buffer->shape.size(); @@ -162,11 +163,11 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type, DataType dtype = buffer->dtype; std::ostringstream type_err_msg; type_err_msg << arg_name << ".dtype is expected to be " << dtype; - PrimExpr cond = (TVMArrayGet(DataType::UInt(8), handle, intrinsic::kArrTypeCode) == + PrimExpr cond = (TVMArrayGet(DataType::UInt(8), handle, builtin::kArrTypeCode) == IntImm(DataType::UInt(8), dtype.code()) && - TVMArrayGet(DataType::UInt(8), handle, intrinsic::kArrTypeBits) == + TVMArrayGet(DataType::UInt(8), handle, builtin::kArrTypeBits) == IntImm(DataType::UInt(8), dtype.bits()) && - TVMArrayGet(DataType::UInt(16), handle, intrinsic::kArrTypeLanes) == + TVMArrayGet(DataType::UInt(16), handle, builtin::kArrTypeLanes) == IntImm(DataType::UInt(16), dtype.lanes())); if (!(dtype == DataType::Int(4) || dtype == DataType::UInt(4) || dtype == DataType::Int(1))) { auto type_msg = tvm::tir::StringImm(type_err_msg.str()); @@ -174,7 +175,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type, asserts_.emplace_back(AssertStmt(cond, type_msg, nop)); } // data field - if (Bind_(buffer->data, TVMArrayGet(DataType::Handle(), handle, intrinsic::kArrData), + if (Bind_(buffer->data, TVMArrayGet(DataType::Handle(), handle, builtin::kArrData), arg_name + ".data", true)) { Var vptr(buffer->data); def_handle_dtype_.Set(vptr, tir::TypeAnnotation(buffer->dtype)); @@ -186,7 +187,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type, Var v_shape(arg_name + ".shape", DataType::Handle()); def_handle_dtype_.Set(v_shape, make_const(tvm_shape_type, 0)); init_nest_.emplace_back( - LetStmt(v_shape, TVMArrayGet(DataType::Handle(), handle, intrinsic::kArrShape), nop)); + LetStmt(v_shape, TVMArrayGet(DataType::Handle(), handle, builtin::kArrShape), nop)); for (size_t k = 0; k < buffer->shape.size(); ++k) { if (dtype == DataType::Int(4) || dtype == DataType::UInt(4) || dtype == DataType::Int(1)) { break; @@ -202,9 +203,9 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type, Var v_strides(arg_name + ".strides", DataType::Handle()); def_handle_dtype_.Set(v_strides, tir::TypeAnnotation(tvm_shape_type)); init_nest_.emplace_back( - LetStmt(v_strides, TVMArrayGet(DataType::Handle(), handle, intrinsic::kArrStrides), nop)); + LetStmt(v_strides, TVMArrayGet(DataType::Handle(), handle, builtin::kArrStrides), nop)); PrimExpr is_null = - Call(DataType::Bool(1), intrinsic::tvm_handle_is_null, {v_strides}, CallNode::PureIntrinsic); + Call(DataType::Bool(1), builtin::isnullptr(), {v_strides}, CallNode::PureIntrinsic); if (buffer->strides.size() == 0) { // Assert the buffer is compact DataType stype = buffer->DefaultIndexType(); @@ -262,12 +263,12 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type, if (const auto* const_offset = buffer->elem_offset.as()) { Bind_(make_const(DataType::UInt(64), const_offset->value * data_bytes), - TVMArrayGet(DataType::UInt(64), handle, intrinsic::kArrByteOffset), + TVMArrayGet(DataType::UInt(64), handle, builtin::kArrByteOffset), arg_name + ".byte_offset", true); } else { if (Bind_(buffer->elem_offset, cast(buffer->elem_offset.dtype(), - (TVMArrayGet(DataType::UInt(64), handle, intrinsic::kArrByteOffset) / + (TVMArrayGet(DataType::UInt(64), handle, builtin::kArrByteOffset) / make_const(DataType::UInt(64), data_bytes))), arg_name + ".elem_offset", true)) { if (buffer->offset_factor > 1) { @@ -280,9 +281,9 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type, } } // device info. - Bind_(device_type, TVMArrayGet(DataType::Int(32), handle, intrinsic::kArrDeviceType), + Bind_(device_type, TVMArrayGet(DataType::Int(32), handle, builtin::kArrDeviceType), arg_name + ".device_type", true); - Bind_(device_id, TVMArrayGet(DataType::Int(32), handle, intrinsic::kArrDeviceId), + Bind_(device_id, TVMArrayGet(DataType::Int(32), handle, builtin::kArrDeviceId), arg_name + ".device_id", true); } diff --git a/src/tir/transforms/bf16_legalize.cc b/src/tir/transforms/bf16_legalize.cc index 445ac1c..9722d11 100644 --- a/src/tir/transforms/bf16_legalize.cc +++ b/src/tir/transforms/bf16_legalize.cc @@ -23,6 +23,7 @@ */ #include +#include #include #include @@ -188,13 +189,13 @@ class BF16LowerRewriter : StmtExprMutator { auto uint32_dtype = DataType(kDLUInt, 32, op_val->dtype.lanes()); auto uint32_v = Cast(uint32_dtype, op_val); // to be endian invariant. - return Call(op->dtype, CallNode::reinterpret, {uint32_v << 16}, CallNode::PureIntrinsic); + return Call(op->dtype, builtin::reinterpret(), {uint32_v << 16}, CallNode::PureIntrinsic); } else if (op->dtype.is_bfloat16()) { // if is cast_to_bf16, check if op->value is fp32 CHECK(op->value->dtype.is_float() && op->value->dtype.bits() == 32); auto uint32_dtype = DataType(kDLUInt, 32, op_val->dtype.lanes()); - auto uint32_v = Call(uint32_dtype, CallNode::reinterpret, {op_val}, CallNode::PureIntrinsic); + auto uint32_v = Call(uint32_dtype, builtin::reinterpret(), {op_val}, CallNode::PureIntrinsic); auto uint16_dtype = DataType(kDLUInt, 16, op_val->dtype.lanes()); /* the following TIR is equivalent to the C++ code below: uint32_t rounding_bias = ((U32 >> 16) & 1) + UINT32_C(0x7FFF); diff --git a/src/tir/transforms/bound_checker.cc b/src/tir/transforms/bound_checker.cc index 94464a0..3b6af06 100644 --- a/src/tir/transforms/bound_checker.cc +++ b/src/tir/transforms/bound_checker.cc @@ -24,6 +24,7 @@ #include #include +#include #include #include #include @@ -66,7 +67,7 @@ class BoundChecker : public StmtExprMutator { } PrimExpr VisitExpr_(const CallNode* op) final { - if (process_store_ && op->is_intrinsic(intrinsic::tvm_if_then_else)) { + if (process_store_ && op->op.same_as(builtin::if_then_else())) { unsafe_rewritten_ = true; } return StmtExprMutator::VisitExpr_(op); diff --git a/src/tir/transforms/combine_context_call.cc b/src/tir/transforms/combine_context_call.cc index 73bf4c6..0485bb1 100644 --- a/src/tir/transforms/combine_context_call.cc +++ b/src/tir/transforms/combine_context_call.cc @@ -25,6 +25,7 @@ #include #include #include +#include #include #include #include @@ -40,7 +41,7 @@ namespace tir { class ContextCallCombiner final : public StmtExprMutator { public: PrimExpr VisitExpr_(const CallNode* op) final { - if (op->is_intrinsic(intrinsic::tvm_thread_context)) { + if (op->op.same_as(builtin::tvm_thread_context())) { CHECK_EQ(op->args.size(), 1U); PrimExpr ctx = op->args[0]; auto it = ctx_map_.find(ctx); @@ -48,13 +49,7 @@ class ContextCallCombiner final : public StmtExprMutator { return it->second; } else { CHECK(ctx.dtype().is_handle()); - std::string name; - if (const CallNode* call = ctx.as()) { - name = call->name + "_cache"; - } else { - name = "ctx_cache_"; - } - Var ctx_var(name, ctx.dtype()); + Var ctx_var("ctx_cache_", ctx.dtype()); ctx_map_[ctx] = ctx_var; return std::move(ctx_var); } diff --git a/src/tir/transforms/coproc_sync.cc b/src/tir/transforms/coproc_sync.cc index 384dbcb..092a7cd 100644 --- a/src/tir/transforms/coproc_sync.cc +++ b/src/tir/transforms/coproc_sync.cc @@ -21,6 +21,7 @@ * \file coproc_sync.cc */ #include +#include #include #include #include @@ -54,7 +55,7 @@ class CoProcTouchedBuffer : public StmtExprVisitor { StmtExprVisitor::VisitStmt_(op); } void VisitExpr_(const CallNode* op) final { - if (op->is_intrinsic(intrinsic::tvm_access_ptr)) { + if (op->op.same_as(builtin::tvm_access_ptr())) { const VarNode* buffer = op->args[1].as(); if (in_scope_) { touched_[buffer].coproc = true; @@ -195,7 +196,8 @@ class CoProcSyncPlanner : public StorageAccessVisitor { } std::vector GetSync(std::string sync_name) { - return {Evaluate(Call(DataType::Int(32), sync_name, {}, CallNode::Intrinsic))}; + return { + Evaluate(Call(DataType::Int(32), Op::Get("tir." + sync_name), {}, CallNode::Intrinsic))}; } const std::unordered_set& touched_; @@ -208,8 +210,8 @@ class CoProcBarrierDetector : public StorageAccessVisitor { explicit CoProcBarrierDetector(const std::unordered_set& touched, const std::string& coproc_name) : touched_(touched) { - read_barrier_name_ = coproc_name + ".coproc_read_barrier"; - write_barrier_name_ = coproc_name + ".coproc_write_barrier"; + read_barrier_name_ = "tir." + coproc_name + ".coproc_read_barrier"; + write_barrier_name_ = "tir." + coproc_name + ".coproc_write_barrier"; } void PlanReadBarrier(const Stmt& stmt) { @@ -331,7 +333,7 @@ class CoProcBarrierDetector : public StorageAccessVisitor { CHECK(r.defined()) << "Cannot deduce write range of " << wvec[0].buffer; PrimExpr min = r->min; PrimExpr extent = r->extent; - return Evaluate(Call(DataType::Int(32), func, + return Evaluate(Call(DataType::Int(32), Op::Get(func), {wvec[0].buffer, wvec[0].dtype.bits(), r->min, r->extent}, CallNode::Intrinsic)); } @@ -346,8 +348,8 @@ class CoProcInstDepDetector : public StmtVisitor { public: explicit CoProcInstDepDetector(const IterVar& coproc_axis, const std::string& coproc_name) : coproc_axis_(coproc_axis) { - sync_push_name_ = coproc_name + ".coproc_dep_push"; - sync_pop_name_ = coproc_name + ".coproc_dep_pop"; + sync_push_op_ = Op::Get("tir." + coproc_name + ".coproc_dep_push"); + sync_pop_op_ = Op::Get("tir." + coproc_name + ".coproc_dep_pop"); } void Plan(const Stmt& stmt) { @@ -555,12 +557,12 @@ class CoProcInstDepDetector : public StmtVisitor { } Stmt MakePush(int from, int to) { - return Evaluate(Call(DataType::Int(32), sync_push_name_, + return Evaluate(Call(DataType::Int(32), sync_push_op_, {make_const(DataType::Int(32), from), make_const(DataType::Int(32), to)}, CallNode::Intrinsic)); } Stmt MakePop(int from, int to) { - return Evaluate(Call(DataType::Int(32), sync_pop_name_, + return Evaluate(Call(DataType::Int(32), sync_pop_op_, {make_const(DataType::Int(32), from), make_const(DataType::Int(32), to)}, CallNode::Intrinsic)); } @@ -568,7 +570,7 @@ class CoProcInstDepDetector : public StmtVisitor { SyncState first_state_, last_state_, curr_state_; // Variables IterVar coproc_axis_; - std::string sync_push_name_, sync_pop_name_; + Op sync_push_op_, sync_pop_op_; }; class CoProcSyncInserter : public StmtMutator { diff --git a/src/tir/transforms/inject_virtual_thread.cc b/src/tir/transforms/inject_virtual_thread.cc index 042ddab..7180dd2 100644 --- a/src/tir/transforms/inject_virtual_thread.cc +++ b/src/tir/transforms/inject_virtual_thread.cc @@ -21,6 +21,7 @@ * \file inject_virtual_thread.cc */ #include +#include #include #include #include @@ -54,7 +55,7 @@ class ExprTouched final : public StmtExprVisitor { } void VisitExpr_(const VarNode* op) final { HandleUseVar(op); } void VisitExpr_(const CallNode* op) final { - if (op->is_intrinsic(intrinsic::tvm_access_ptr)) { + if (op->op.same_as(builtin::tvm_access_ptr())) { const auto* rw_mask = op->args[4].as(); const VarNode* buffer_var = op->args[1].as(); CHECK(buffer_var); @@ -219,7 +220,7 @@ class VTInjector : public StmtExprMutator { } // Expression. PrimExpr VisitExpr_(const CallNode* op) final { - if (op->is_intrinsic(intrinsic::tvm_access_ptr)) { + if (op->op.same_as(builtin::tvm_access_ptr())) { CHECK_EQ(op->args.size(), 5U); DataType dtype = op->args[0].dtype(); const VarNode* buffer = op->args[1].as(); @@ -230,9 +231,9 @@ class VTInjector : public StmtExprMutator { PrimExpr extent = this->VisitExpr(op->args[3]); PrimExpr stride = it->second / make_const(offset.dtype(), dtype.lanes()); offset = stride * var_ + offset; - return Call(op->dtype, op->name, {op->args[0], op->args[1], offset, extent, op->args[4]}, + return Call(op->dtype, op->op, {op->args[0], op->args[1], offset, extent, op->args[4]}, op->call_type); - } else if (op->is_intrinsic(intrinsic::tvm_context_id)) { + } else if (op->op.same_as(builtin::tvm_context_id())) { return allow_share_ ? GetRef(op) : var_; } else { return StmtExprMutator::VisitExpr_(op); diff --git a/src/tir/transforms/ir_util.h b/src/tir/transforms/ir_util.h index 6c0eeea..758923b 100644 --- a/src/tir/transforms/ir_util.h +++ b/src/tir/transforms/ir_util.h @@ -25,6 +25,7 @@ #define TVM_TIR_TRANSFORMS_IR_UTIL_H_ #include +#include #include #include @@ -83,10 +84,10 @@ inline Array UpdateArray(Array arr, F fupdate) { * \return the get expression. */ inline PrimExpr TVMStructGet(DataType dtype, Var handle, int index, - intrinsic::TVMStructFieldKind kind) { + builtin::TVMStructFieldKind kind) { Array args = {handle, make_const(DataType::Int(32), index), make_const(DataType::Int(32), static_cast(kind))}; - return Call(dtype, intrinsic::tvm_struct_get, args, CallNode::PureIntrinsic); + return Call(dtype, builtin::tvm_struct_get(), args, CallNode::PureIntrinsic); } /*! @@ -96,7 +97,7 @@ inline PrimExpr TVMStructGet(DataType dtype, Var handle, int index, * \param offset the offset index. */ inline PrimExpr AddressOffset(Var handle, DataType dtype, int offset) { - return Call(DataType::Handle(), intrinsic::tvm_address_of, + return Call(DataType::Handle(), builtin::address_of(), {Load(dtype, handle, make_const(DataType::Int(32), offset * dtype.lanes()), const_true(dtype.lanes()))}, CallNode::PureIntrinsic); @@ -113,7 +114,7 @@ inline PrimExpr AddressOffset(Var handle, DataType dtype, PrimExpr offset) { offset = offset * make_const(offset.dtype(), dtype.lanes()); offset = Ramp(offset, make_const(offset.dtype(), 1), dtype.lanes()); } - return Call(DataType::Handle(), intrinsic::tvm_address_of, + return Call(DataType::Handle(), builtin::address_of(), {Load(dtype, handle, offset, const_true(dtype.lanes()))}, CallNode::PureIntrinsic); } @@ -125,11 +126,10 @@ inline PrimExpr AddressOffset(Var handle, DataType dtype, PrimExpr offset) { * \param value The value to be set. * \return the set stmt. */ -inline Stmt TVMStructSet(Var handle, int index, intrinsic::TVMStructFieldKind kind, - PrimExpr value) { +inline Stmt TVMStructSet(Var handle, int index, builtin::TVMStructFieldKind kind, PrimExpr value) { Array args = {handle, make_const(DataType::Int(32), index), make_const(DataType::Int(32), static_cast(kind)), value}; - return Evaluate(Call(DataType::Int(32), intrinsic::tvm_struct_set, args, CallNode::Intrinsic)); + return Evaluate(Call(DataType::Int(32), builtin::tvm_struct_set(), args, CallNode::Intrinsic)); } /*! diff --git a/src/tir/transforms/loop_partition.cc b/src/tir/transforms/loop_partition.cc index 3b2580c..2fb8003 100644 --- a/src/tir/transforms/loop_partition.cc +++ b/src/tir/transforms/loop_partition.cc @@ -23,6 +23,7 @@ #include #include #include +#include #include #include #include @@ -140,11 +141,11 @@ class CandidateSelector final : public StmtExprVisitor { } void VisitExpr_(const CallNode* op) final { - if (op->is_intrinsic(CallNode::likely)) { + if (op->op.same_as(builtin::likely())) { in_likely_ = true; StmtExprVisitor::VisitExpr_(op); in_likely_ = false; - } else if (op->is_intrinsic(intrinsic::tvm_thread_allreduce)) { + } else if (op->op.same_as(builtin::tvm_thread_allreduce())) { // no split if the body contains allreduce. no_split_ = true; return; @@ -214,7 +215,7 @@ class PartitionFinder : public StmtExprVisitor { } void VisitExpr_(const CallNode* op) final { - if (op->is_intrinsic(CallNode::likely)) { + if (op->op.same_as(builtin::likely())) { PrimExpr cond = op->args[0]; if (ExprUseVars(cond, std::unordered_set({current_var_.get()}))) { // For cond, find out the interval, if exists, in which we can prove that cond is @@ -596,7 +597,7 @@ inline Stmt LoopPartitioner::MakeFor(const Object* node, PrimExpr extent, Stmt b class RemoveLikelyTags : public StmtExprMutator { public: PrimExpr VisitExpr_(const CallNode* op) final { - if (op->is_intrinsic(CallNode::likely)) { + if (op->op.same_as(builtin::likely())) { CHECK_EQ(op->args.size(), 1); return StmtExprMutator::VisitExpr(op->args[0]); } else { diff --git a/src/tir/transforms/lower_device_storage_access_info.cc b/src/tir/transforms/lower_device_storage_access_info.cc index 9d6b47a..fac50a0 100644 --- a/src/tir/transforms/lower_device_storage_access_info.cc +++ b/src/tir/transforms/lower_device_storage_access_info.cc @@ -25,6 +25,7 @@ #include #include #include +#include #include #include @@ -79,7 +80,7 @@ class StorageAccessInfoLower : public StmtExprMutator { } PrimExpr VisitExpr_(const CallNode* op) final { - if (op->is_intrinsic(intrinsic::tvm_access_ptr)) { + if (op->op.same_as(builtin::tvm_access_ptr())) { return MakeAccessPtr(op); } else { return StmtExprMutator::VisitExpr_(op); diff --git a/src/tir/transforms/lower_intrin.cc b/src/tir/transforms/lower_intrin.cc index c7aa949..d38cb7b 100644 --- a/src/tir/transforms/lower_intrin.cc +++ b/src/tir/transforms/lower_intrin.cc @@ -51,9 +51,17 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { } PrimExpr VisitExpr_(const CallNode* op) final { + // NOTE: call_type will eventually be deprecated and the information + // will be folded into Op's attr if (op->call_type == CallNode::Intrinsic || op->call_type == CallNode::PureIntrinsic) { - PrimExpr r = ApplyPattern(op->name, GetRef(op)); - if (r.defined()) return r; + if (auto* ptr_op = op->op.as()) { + // Still use legacy string based rewriting + // TODO(tvm-team): migrate the pattern application from global function look up + // to an OpAttrMap + std::string name = ptr_op->name; + PrimExpr r = ApplyPattern(name, GetRef(op)); + if (r.defined()) return r; + } } return IRMutatorWithAnalyzer::VisitExpr_(op); } @@ -230,7 +238,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { PrimExpr rhs = SwapBroadcastCast(b); if (fma_ != nullptr && op->dtype.is_float()) { - PrimExpr r = (*fma_)(Call(op->dtype, "fma", {lhs, rhs, c}, CallNode::PureIntrinsic)); + PrimExpr r = (*fma_)(Call(op->dtype, builtin::fma(), {lhs, rhs, c}, CallNode::PureIntrinsic)); if (r.defined()) return this->VisitExpr(r); } else { if (!lhs.same_as(a) || !rhs.same_as(b)) { @@ -241,7 +249,11 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { return IRMutatorWithAnalyzer::VisitExpr_(op); } - PrimExpr ApplyPattern(const std::string& name, const PrimExpr& e) { + PrimExpr ApplyPattern(std::string name, const PrimExpr& e) { + if (name.compare(0, 4, "tir.") == 0) { + name = name.substr(4); + } + for (size_t i = 0; i < patterns_.size(); ++i) { std::string& p = patterns_[i]; size_t psize = p.length(); diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc index ee17f08..dab8d5a 100644 --- a/src/tir/transforms/lower_thread_allreduce.cc +++ b/src/tir/transforms/lower_thread_allreduce.cc @@ -24,6 +24,7 @@ #include #include #include +#include #include #include #include @@ -71,7 +72,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { Stmt stmt = StmtExprMutator::VisitStmt_(op); op = stmt.as(); const CallNode* call = op->value.as(); - if (call && call->is_intrinsic(intrinsic::tvm_thread_allreduce)) { + if (call && call->op.same_as(builtin::tvm_thread_allreduce())) { return MakeAllreduce(call); } else { return stmt; @@ -242,7 +243,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { { PrimExpr pred = const_true(1); PrimExpr mask = - Call(DataType::UInt(32), intrinsic::tvm_warp_activemask, {}, CallNode::Intrinsic); + Call(DataType::UInt(32), builtin::tvm_warp_activemask(), {}, CallNode::Intrinsic); seq.emplace_back(Store(mask_var, mask, index, pred)); // Push allocation with an empty body. Later this will be fixed // when the entire body is ready. @@ -273,8 +274,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { // The former may cause dead lock as there is a divergent // branch with a warp sync call inside. // - const char* shfl_func = intrinsic::tvm_warp_shuffle_down; - PrimExpr other = WarpShuffle(shfl_func, mask_var, val, offset); + PrimExpr other = WarpShuffle(builtin::tvm_warp_shuffle_down(), mask_var, val, offset); const AllocateNode* repl = local_vars[i].as(); Stmt s = Store(repl->buffer_var, other, index, pred); seq.push_back(s); @@ -303,9 +303,8 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { for (size_t i = 0; i < size; ++i) { Var var = shared_bufs[i]; PrimExpr pred = const_true(types[i].lanes()); - const char* shfl_func = intrinsic::tvm_warp_shuffle; PrimExpr val = Load(types[i], var, index, pred); - PrimExpr splat = WarpShuffle(shfl_func, mask_var, val, 0); + PrimExpr splat = WarpShuffle(builtin::tvm_warp_shuffle(), mask_var, val, 0); seq.push_back(Store(var, splat, index, pred)); } @@ -465,18 +464,18 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { } // sync thread op. static Stmt SyncThread(const std::string& sync) { - return Evaluate(Call(DataType::Int(32), intrinsic::tvm_storage_sync, {StringImm(sync)}, + return Evaluate(Call(DataType::Int(32), builtin::tvm_storage_sync(), {StringImm(sync)}, CallNode::Intrinsic)); } - // Emit warp shuffle intrinsic calls. - PrimExpr WarpShuffle(const char* name, Var mask_var, PrimExpr val, int delta_or_lane) { + // Emit warp shuffle calls. + PrimExpr WarpShuffle(const Op& op, Var mask_var, PrimExpr val, int delta_or_lane) { PrimExpr pred = const_true(1); PrimExpr index(0); PrimExpr mask = Load(DataType::UInt(32), mask_var, index, pred); PrimExpr width = IntImm(DataType::Int(32), warp_size_); Array args{mask, val, IntImm(DataType::Int(32), delta_or_lane), width, width}; - return Call(val.dtype(), name, args, CallNode::Intrinsic); + return Call(val.dtype(), op, args, CallNode::Intrinsic); } // Check if this is a reduction on threadIdx.x and its extent matches diff --git a/src/tir/transforms/lower_tvm_builtin.cc b/src/tir/transforms/lower_tvm_builtin.cc index 7611e0f..e618230 100644 --- a/src/tir/transforms/lower_tvm_builtin.cc +++ b/src/tir/transforms/lower_tvm_builtin.cc @@ -22,6 +22,7 @@ * \file tir/transforms/lower_tvm_buildin.cc */ #include +#include #include #include #include @@ -40,7 +41,7 @@ inline PrimExpr ConstInt32(size_t index) { inline PrimExpr StackAlloca(std::string type, size_t num) { Array args = {StringImm(type), ConstInt32(num)}; - return Call(DataType::Handle(), intrinsic::tvm_stack_alloca, args, CallNode::Intrinsic); + return Call(DataType::Handle(), builtin::tvm_stack_alloca(), args, CallNode::Intrinsic); } // Calculate the statistics of packed function. @@ -103,23 +104,22 @@ class BuiltinLower : public StmtExprMutator { CHECK(device_type_.defined()) << "Unknown device type in current IR"; CHECK(device_id_.defined()) << "Unknown device id in current IR"; Stmt throw_last_error = - Evaluate(Call(DataType::Int(32), intrinsic::tvm_throw_last_error, {}, CallNode::Intrinsic)); + Evaluate(Call(DataType::Int(32), builtin::tvm_throw_last_error(), {}, CallNode::Intrinsic)); - Stmt body = SeqStmt({IfThenElse(Call(DataType::Bool(1), intrinsic::tvm_handle_is_null, - {op->buffer_var}, CallNode::PureIntrinsic), + Stmt body = SeqStmt({IfThenElse(Call(DataType::Bool(1), builtin::isnullptr(), {op->buffer_var}, + CallNode::PureIntrinsic), throw_last_error), op->body}); - Stmt alloca = LetStmt( op->buffer_var, - Call(op->buffer_var.dtype(), "TVMBackendAllocWorkspace", + Call(op->buffer_var.dtype(), Op::Get("tir.TVMBackendAllocWorkspace"), {cast(DataType::Int(32), device_type_), cast(DataType::Int(32), device_id_), cast(DataType::UInt(64), total_bytes), IntImm(DataType::Int(32), op->dtype.code()), IntImm(DataType::Int(32), op->dtype.bits())}, CallNode::Extern), body); - PrimExpr free_op = Call(DataType::Int(32), "TVMBackendFreeWorkspace", + PrimExpr free_op = Call(DataType::Int(32), Op::Get("tir.TVMBackendFreeWorkspace"), {cast(DataType::Int(32), device_type_), cast(DataType::Int(32), device_id_), op->buffer_var}, CallNode::Extern); @@ -144,15 +144,15 @@ class BuiltinLower : public StmtExprMutator { } } PrimExpr VisitExpr_(const CallNode* op) final { - if (op->is_intrinsic(intrinsic::tvm_call_packed)) { + if (op->op.same_as(builtin::tvm_call_packed())) { return MakeCallPacked(op); - } else if (op->is_intrinsic(intrinsic::tvm_call_trace_packed)) { + } else if (op->op.same_as(builtin::tvm_call_trace_packed())) { return MakeCallTracePacked(op); - } else if (op->is_intrinsic(intrinsic::tvm_stack_make_shape)) { + } else if (op->op.same_as(builtin::tvm_stack_make_shape())) { return MakeShape(op); - } else if (op->is_intrinsic(intrinsic::tvm_stack_make_array)) { + } else if (op->op.same_as(builtin::tvm_stack_make_array())) { return MakeArray(op); - } else if (op->is_intrinsic(intrinsic::tvm_context_id)) { + } else if (op->op.same_as(builtin::tvm_context_id())) { return make_zero(op->dtype); } else { return StmtExprMutator::VisitExpr_(op); @@ -176,21 +176,21 @@ class BuiltinLower : public StmtExprMutator { run_array_stack_ += 1; PrimExpr expr = StmtExprMutator::VisitExpr_(op); op = expr.as(); - prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, intrinsic::kArrData, op->args[0])); - prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, intrinsic::kArrShape, op->args[1])); + prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, builtin::kArrData, op->args[0])); + prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, builtin::kArrShape, op->args[1])); PrimExpr strides = op->args[2]; if (!strides.defined() || is_zero(strides)) { strides = make_zero(DataType::Handle()); } - prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, intrinsic::kArrStrides, strides)); - prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, intrinsic::kArrNDim, op->args[3])); + prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, builtin::kArrStrides, strides)); + prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, builtin::kArrNDim, op->args[3])); DataType dtype = op->args[4].dtype(); prep_seq_.emplace_back( - TVMStructSet(stack_array_, idx, intrinsic::kArrTypeCode, + TVMStructSet(stack_array_, idx, builtin::kArrTypeCode, make_const(DataType::UInt(8), static_cast(dtype.code())))); - prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, intrinsic::kArrTypeBits, + prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, builtin::kArrTypeBits, make_const(DataType::UInt(8), dtype.bits()))); - prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, intrinsic::kArrTypeLanes, + prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, builtin::kArrTypeLanes, make_const(DataType::UInt(16), dtype.lanes()))); // set byte offset int data_bytes = GetVectorBytes(dtype); @@ -198,15 +198,15 @@ class BuiltinLower : public StmtExprMutator { if (!is_zero(byte_offset)) { byte_offset = byte_offset * make_const(byte_offset.dtype(), data_bytes); } - prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, intrinsic::kArrByteOffset, + prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, builtin::kArrByteOffset, cast(DataType::UInt(64), byte_offset))); CHECK(device_type_.defined()) << "Unknown device type in current IR"; CHECK(device_id_.defined()) << "Unknown device id in current IR"; - prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, intrinsic::kArrDeviceId, + prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, builtin::kArrDeviceId, cast(DataType::Int(32), device_id_))); - prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, intrinsic::kArrDeviceType, + prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, builtin::kArrDeviceType, cast(DataType::Int(32), device_type_))); - return TVMStructGet(DataType::Handle(), stack_array_, idx, intrinsic::kArrAddr); + return TVMStructGet(DataType::Handle(), stack_array_, idx, builtin::kArrAddr); } // call packed. PrimExpr MakeCallPacked(const CallNode* op) { @@ -226,7 +226,7 @@ class BuiltinLower : public StmtExprMutator { arg = Cast(api_type, arg); } prep_seq_.emplace_back(TVMStructSet(stack_value_, static_cast(arg_stack_begin + i - 1), - intrinsic::kTVMValueContent, arg)); + builtin::kTVMValueContent, arg)); int arg_tcode = api_type.code(); if (api_type.is_handle() && arg.as()) { arg_tcode = kTVMStr; @@ -245,7 +245,7 @@ class BuiltinLower : public StmtExprMutator { Array packed_args = {op->args[0], stack_value_, stack_tcode_, ConstInt32(arg_stack_begin), ConstInt32(arg_stack_begin + op->args.size() - 1)}; - return Call(DataType::Int(32), intrinsic::tvm_call_packed_lowered, packed_args, + return Call(DataType::Int(32), builtin::tvm_call_packed_lowered(), packed_args, CallNode::Intrinsic); } @@ -267,7 +267,7 @@ class BuiltinLower : public StmtExprMutator { arg = Cast(api_type, arg); } prep_seq_.emplace_back(TVMStructSet(stack_value_, static_cast(arg_stack_begin + i - 1), - intrinsic::kTVMValueContent, arg)); + builtin::kTVMValueContent, arg)); int arg_tcode = api_type.code(); CHECK(!IsArrayHandle(arg)) << "Trace does not support Buffers"; prep_seq_.emplace_back( @@ -287,7 +287,7 @@ class BuiltinLower : public StmtExprMutator { ConstInt32(arg_stack_begin + op->args.size() - 1), // Pass traced value. op->args[args_size - 1]}; - return Call(op->dtype, intrinsic::tvm_call_trace_packed_lowered, packed_args, + return Call(op->dtype, builtin::tvm_call_trace_packed_lowered(), packed_args, CallNode::Intrinsic); } @@ -295,8 +295,8 @@ class BuiltinLower : public StmtExprMutator { bool IsArrayHandle(const PrimExpr& arg) { // specially set array handle. if (const CallNode* buf = arg.as()) { - if (buf->is_intrinsic(intrinsic::tvm_struct_get) && - buf->args[2].as()->value == intrinsic::kArrAddr) { + if (buf->op.same_as(builtin::tvm_struct_get()) && + buf->args[2].as()->value == builtin::kArrAddr) { return true; } } diff --git a/src/tir/transforms/lower_warp_memory.cc b/src/tir/transforms/lower_warp_memory.cc index 92f9ab5..3e7d13b 100644 --- a/src/tir/transforms/lower_warp_memory.cc +++ b/src/tir/transforms/lower_warp_memory.cc @@ -30,6 +30,7 @@ #include #include #include +#include #include #include #include @@ -250,8 +251,8 @@ class WarpAccessRewriter : protected StmtExprMutator { << " local_index=" << local_index; PrimExpr load_value = Load(op->dtype, op->buffer_var, local_index, op->predicate); PrimExpr mask = - Call(DataType::UInt(32), intrinsic::tvm_warp_activemask, {}, CallNode::Intrinsic); - return Call(load_value.dtype(), intrinsic::tvm_warp_shuffle, + Call(DataType::UInt(32), builtin::tvm_warp_activemask(), {}, CallNode::Intrinsic); + return Call(load_value.dtype(), builtin::tvm_warp_shuffle(), {mask, load_value, group, width_, warp_size_}, CallNode::Intrinsic); } else { return StmtExprMutator::VisitExpr_(op); diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc index a91e350..9bb5fc6 100644 --- a/src/tir/transforms/make_packed_api.cc +++ b/src/tir/transforms/make_packed_api.cc @@ -26,6 +26,7 @@ #include #include #include +#include #include #include #include @@ -82,10 +83,10 @@ PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) { // load i-th argument as type t auto f_arg_value = [&](DataType t, int i) { Array call_args{v_packed_args, IntImm(DataType::Int(32), i), - IntImm(DataType::Int(32), intrinsic::kTVMValueContent)}; + IntImm(DataType::Int(32), builtin::kTVMValueContent)}; // load 64 bit version DataType api_type = APIType(t); - PrimExpr res = Call(api_type, intrinsic::tvm_struct_get, call_args, CallNode::PureIntrinsic); + PrimExpr res = Call(api_type, builtin::tvm_struct_get(), call_args, CallNode::PureIntrinsic); // cast to the target version. if (api_type != t) { res = Cast(t, res); @@ -189,7 +190,7 @@ PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) { if (runtime::DeviceAPI::NeedSetDeviceContext(target_device_type)) { Stmt set_device = - Evaluate(Call(DataType::Int(32), intrinsic::tvm_call_packed, + Evaluate(Call(DataType::Int(32), builtin::tvm_call_packed(), {StringImm(runtime::symbol::tvm_set_device), device_type, device_id}, CallNode::Intrinsic)); body = SeqStmt({set_device, body}); diff --git a/src/tir/transforms/narrow_datatype.cc b/src/tir/transforms/narrow_datatype.cc index 07b0ea2..a14fd02 100644 --- a/src/tir/transforms/narrow_datatype.cc +++ b/src/tir/transforms/narrow_datatype.cc @@ -23,6 +23,7 @@ */ #include +#include #include #include @@ -318,6 +319,8 @@ class DataTypeRewriter : public StmtExprMutator { std::unordered_map ivmap_; // indicator of LoadNode::index and StoreNode::index bool is_index_{false}; + // cached ops + const Op& builtin_pow_ = Op::Get("tir.pow"); }; #define DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(OP, FUNC) \ @@ -352,23 +355,23 @@ PrimExpr DataTypeRewriter::VisitExpr_(const CallNode* op) { op = e.as(); CHECK(op != nullptr) << "Expected type to be CallNode" << ", but get " << e->GetTypeKey(); - if (op->call_type == CallNode::PureIntrinsic) { - if (op->name == intrinsic::tvm_if_then_else) { - return if_then_else(op->args[0], op->args[1], op->args[2]); - } else if (op->name == CallNode::shift_right) { - return op->args[0] >> op->args[1]; - } else if (op->name == CallNode::shift_left) { - return op->args[0] << op->args[1]; - } else if (op->name == CallNode::bitwise_and) { - return op->args[0] & op->args[1]; - } else if (op->name == CallNode::bitwise_or) { - return op->args[0] | op->args[1]; - } else if (op->name == CallNode::bitwise_xor) { - return op->args[0] ^ op->args[1]; - } else if (op->name == "pow") { - return pow(op->args[0], op->args[1]); - } + + if (op->op.same_as(builtin::if_then_else())) { + return if_then_else(op->args[0], op->args[1], op->args[2]); + } else if (op->op.same_as(builtin::shift_right())) { + return op->args[0] >> op->args[1]; + } else if (op->op.same_as(builtin::shift_left())) { + return op->args[0] << op->args[1]; + } else if (op->op.same_as(builtin::bitwise_and())) { + return op->args[0] & op->args[1]; + } else if (op->op.same_as(builtin::bitwise_or())) { + return op->args[0] | op->args[1]; + } else if (op->op.same_as(builtin::bitwise_xor())) { + return op->args[0] ^ op->args[1]; + } else if (op->op.same_as(builtin_pow_)) { + return pow(op->args[0], op->args[1]); } + return e; } diff --git a/src/tir/transforms/rewrite_unsafe_select.cc b/src/tir/transforms/rewrite_unsafe_select.cc index 701f0ce..e553536 100644 --- a/src/tir/transforms/rewrite_unsafe_select.cc +++ b/src/tir/transforms/rewrite_unsafe_select.cc @@ -22,6 +22,7 @@ * \brief Rewrite uinsafe select expression. */ #include +#include #include #include #include @@ -37,9 +38,9 @@ class UnsafeExprDetector : public ExprFunctor { // Because we will issue guard to make sure it is. bool VisitExpr_(const SelectNode* op) { return VisitExpr(op->condition); } bool VisitExpr_(const CallNode* op) { - if (op->is_intrinsic(intrinsic::tvm_if_then_else)) { + if (op->op.same_as(builtin::if_then_else())) { return VisitExpr(op->args[0]); - } else if (op->is_intrinsic(intrinsic::tvm_address_of)) { + } else if (op->op.same_as(builtin::address_of())) { const LoadNode* l = op->args[0].as(); return this->VisitExpr(l->index); } else if (op->is_pure()) { @@ -104,7 +105,7 @@ class UnsafeSelectRewriter : public StmtExprMutator { bool cond_is_scalar_bool = op->condition.dtype().is_bool() && op->condition.dtype().is_scalar(); if ((unsafe.VisitExpr(op->true_value) || unsafe.VisitExpr(op->false_value)) && cond_is_scalar_bool) { - return Call(op->dtype, intrinsic::tvm_if_then_else, + return Call(op->dtype, builtin::if_then_else(), {op->condition, op->true_value, op->false_value}, CallNode::Intrinsic); } else { return expr; diff --git a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc index 0684189..c35caf5 100644 --- a/src/tir/transforms/split_host_device.cc +++ b/src/tir/transforms/split_host_device.cc @@ -26,6 +26,7 @@ #include #include #include +#include #include #include #include @@ -238,7 +239,7 @@ class HostDeviceSplitter : public StmtMutator { call_args.push_back(ext); } return Evaluate( - Call(DataType::Int(32), intrinsic::tvm_call_packed, call_args, CallNode::Intrinsic)); + Call(DataType::Int(32), builtin::tvm_call_packed(), call_args, CallNode::Intrinsic)); } // target ir module diff --git a/src/tir/transforms/storage_access.cc b/src/tir/transforms/storage_access.cc index 20cc640..24f8b75 100644 --- a/src/tir/transforms/storage_access.cc +++ b/src/tir/transforms/storage_access.cc @@ -181,10 +181,10 @@ void StorageAccessVisitor::VisitStmt_(const IfThenElseNode* op) { } void StorageAccessVisitor::VisitExpr_(const CallNode* op) { - if (op->is_intrinsic(intrinsic::tvm_address_of)) { + if (op->op.same_as(builtin::address_of())) { const LoadNode* l = op->args[0].as(); StmtExprVisitor::VisitExpr_(l); - } else if (op->is_intrinsic(intrinsic::tvm_access_ptr)) { + } else if (op->op.same_as(builtin::tvm_access_ptr())) { CHECK_EQ(op->args.size(), 5U); DataType dtype = op->args[0].dtype(); const VarNode* buffer = op->args[1].as(); @@ -211,7 +211,7 @@ void StorageAccessVisitor::VisitExpr_(const CallNode* op) { } } StmtExprVisitor::VisitExpr_(op); - } else if (op->is_intrinsic(intrinsic::tvm_storage_sync)) { + } else if (op->op.same_as(builtin::tvm_storage_sync())) { CHECK(allow_append_); const std::string& s = op->args[0].as()->value; if (s != "warp") { diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc index e29d978..3080550 100644 --- a/src/tir/transforms/storage_flatten.cc +++ b/src/tir/transforms/storage_flatten.cc @@ -29,6 +29,7 @@ #include #include #include +#include #include #include #include @@ -45,7 +46,6 @@ namespace tvm { namespace tir { -using intrinsic::tvm_address_of; using runtime::StorageRank; using runtime::StorageScope; using runtime::ThreadScope; @@ -101,7 +101,7 @@ class StorageFlattener : public StmtExprMutator { } else if (op->attr_key == attr::buffer_dim_align) { auto buffer = Downcast(op->node); const CallNode* tuple = op->value.as(); - CHECK(tuple && tuple->is_intrinsic(intrinsic::tvm_tuple)); + CHECK(tuple && tuple->op.same_as(builtin::tvm_tuple())); auto& vinfo = dim_align_[buffer]; int dim = tuple->args[0].as()->value; if (static_cast(dim) >= vinfo.size()) { @@ -322,9 +322,9 @@ class StorageFlattener : public StmtExprMutator { } else { PrimExpr load = e.buffer.vload(e.RelIndex(args), e.buffer->dtype); PrimExpr address = - Call(DataType::Handle(), tvm_address_of, {load}, CallNode::PureIntrinsic); + Call(DataType::Handle(), builtin::address_of(), {load}, CallNode::PureIntrinsic); PrimExpr prefetch = - Call(op->buffer->dtype, CallNode::prefetch, {address, 0, 3, 1}, CallNode::Intrinsic); + Call(op->buffer->dtype, builtin::prefetch(), {address, 0, 3, 1}, CallNode::Intrinsic); stmt = Evaluate(prefetch); PrimExpr extent = (op->bounds[i]->extent - 1) / stride + 1; stmt = For(vars[i], 0, extent, ForType::Serial, DeviceAPI::None, stmt); @@ -392,7 +392,7 @@ class StorageFlattener : public StmtExprMutator { const BufferNode* target = arr[1].as(); const CallNode* tuple = op->value.as(); CHECK(buffer && target); - CHECK(tuple && tuple->is_intrinsic(intrinsic::tvm_tuple)); + CHECK(tuple && tuple->op.same_as(builtin::tvm_tuple())); auto key = GetRef(target); auto it = buf_map_.find(key); diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc index 283ab0f..d7a258c 100644 --- a/src/tir/transforms/storage_rewrite.cc +++ b/src/tir/transforms/storage_rewrite.cc @@ -26,6 +26,7 @@ #include #include #include +#include #include #include #include @@ -131,7 +132,7 @@ class LinearAccessPatternFinder final : public StmtExprVisitor { } } void VisitExpr_(const CallNode* op) final { - if (op->is_intrinsic(intrinsic::tvm_address_of)) { + if (op->op.same_as(builtin::address_of())) { const LoadNode* l = op->args[0].as(); this->VisitExpr(l->index); } else { @@ -387,7 +388,7 @@ class StoragePlanRewriter : public StmtExprMutator { } } PrimExpr VisitExpr_(const CallNode* op) final { - if (op->is_intrinsic(intrinsic::tvm_access_ptr)) { + if (op->op.same_as(builtin::tvm_access_ptr())) { CHECK_EQ(op->args.size(), 5U); DataType dtype = op->args[0].dtype(); const VarNode* buffer = op->args[1].as(); @@ -403,7 +404,7 @@ class StoragePlanRewriter : public StmtExprMutator { if (se->bits_offset != 0) { offset = make_const(offset.dtype(), se->bits_offset / elem_bits) + offset; } - return Call(op->dtype, op->name, {op->args[0], se->alloc_var, offset, extent, op->args[4]}, + return Call(op->dtype, op->op, {op->args[0], se->alloc_var, offset, extent, op->args[4]}, op->call_type); } else { return StmtExprMutator::VisitExpr_(op); @@ -911,7 +912,7 @@ class VectorAllocRewriter : public StmtExprMutator { return StmtExprMutator::VisitStmt_(op); } PrimExpr VisitExpr_(const CallNode* op) final { - if (op->is_intrinsic(intrinsic::tvm_access_ptr)) { + if (op->op.same_as(builtin::tvm_access_ptr())) { DataType dtype = op->args[0].dtype(); const VarNode* buffer = op->args[1].as(); UpdateTypeMap(buffer, dtype); diff --git a/src/tir/transforms/tensorcore_infer_fragment.cc b/src/tir/transforms/tensorcore_infer_fragment.cc index 493aa51..1b3b3c4 100644 --- a/src/tir/transforms/tensorcore_infer_fragment.cc +++ b/src/tir/transforms/tensorcore_infer_fragment.cc @@ -53,8 +53,8 @@ class FragmentGetter : public StmtExprVisitor { void VisitExpr_(const CallNode* op) final { StmtExprVisitor::VisitExpr_(op); - if (op->is_intrinsic(intrinsic::tvm_load_matrix_sync) || - op->is_intrinsic(intrinsic::tvm_store_matrix_sync)) { + if (op->op.same_as(builtin::tvm_load_matrix_sync()) || + op->op.same_as(builtin::tvm_store_matrix_sync())) { // Get shape and layout information from load and store intrinsic CHECK_EQ(op->args.size(), 8U); const VarNode* buffer_var = op->args[0].as(); @@ -89,7 +89,7 @@ class FragmentGetter : public StmtExprVisitor { } fragments[buffer_var] = info; } - } else if (op->is_intrinsic(intrinsic::tvm_fill_fragment)) { + } else if (op->op.same_as(builtin::tvm_fill_fragment())) { // Get shape information from fill intrinsic CHECK_EQ(op->args.size(), 6U); const VarNode* buffer_var = op->args[0].as(); @@ -141,7 +141,7 @@ class FragmentChecker : public StmtExprVisitor { void VisitExpr_(const CallNode* op) final { StmtExprVisitor::VisitExpr_(op); // Check shape when calling tvm_mma_sync - if (op->is_intrinsic(intrinsic::tvm_mma_sync) || op->is_intrinsic(intrinsic::tvm_bmma_sync)) { + if (op->op.same_as(builtin::tvm_mma_sync()) || op->op.same_as(builtin::tvm_bmma_sync())) { CHECK_EQ(op->args.size(), 8U); const VarNode* buffer_var_d = op->args[0].as(); const VarNode* buffer_var_a = op->args[2].as(); diff --git a/src/tir/transforms/thread_storage_sync.cc b/src/tir/transforms/thread_storage_sync.cc index 612efb0..cdd9377 100644 --- a/src/tir/transforms/thread_storage_sync.cc +++ b/src/tir/transforms/thread_storage_sync.cc @@ -22,6 +22,7 @@ */ #include #include +#include #include #include #include @@ -209,7 +210,7 @@ class ThreadSyncInserter : public StmtExprMutator { if (sync_scope_.rank == StorageRank::kGlobal) { barrier = MakeGlobalBarrier(); } else { - barrier = Evaluate(Call(DataType::Int(32), intrinsic::tvm_storage_sync, + barrier = Evaluate(Call(DataType::Int(32), builtin::tvm_storage_sync(), {StringImm(sync_scope_.to_string())}, CallNode::Intrinsic)); } // Mutate after query, to avoid stmt change. @@ -259,7 +260,7 @@ class ThreadSyncInserter : public StmtExprMutator { } PrimExpr VisitExpr_(const CallNode* op) final { - if (op->is_intrinsic(intrinsic::tvm_access_ptr)) { + if (op->op.same_as(builtin::tvm_access_ptr())) { PrimExpr expr = StmtExprMutator::VisitExpr_(op); op = expr.as(); CHECK_EQ(op->args.size(), 5U); @@ -299,7 +300,7 @@ class ThreadSyncInserter : public StmtExprMutator { CHECK(op != nullptr); Array pargs = {StringImm(runtime::symbol::tvm_prepare_global_barrier)}; Stmt prep = - Evaluate(Call(DataType::Int(32), intrinsic::tvm_call_packed, pargs, CallNode::Intrinsic)); + Evaluate(Call(DataType::Int(32), builtin::tvm_call_packed(), pargs, CallNode::Intrinsic)); Stmt body = op->body; for (const auto& kv : rw_stats_) { const auto& e = kv.second; @@ -309,7 +310,7 @@ class ThreadSyncInserter : public StmtExprMutator { } rw_stats_.clear(); Stmt kinit = Evaluate( - Call(DataType::Int(32), intrinsic::tvm_global_barrier_kinit, {}, CallNode::Intrinsic)); + Call(DataType::Int(32), builtin::tvm_global_barrier_kinit(), {}, CallNode::Intrinsic)); body = SeqStmt({kinit, body}); body = AttrStmt(op->node, op->attr_key, op->value, body); return SeqStmt({prep, body}); @@ -332,7 +333,7 @@ class ThreadSyncInserter : public StmtExprMutator { } else { CHECK_EQ(num_work_dim_, thread_extents_.size()); } - return Evaluate(Call(DataType::Int(32), intrinsic::tvm_storage_sync, + return Evaluate(Call(DataType::Int(32), builtin::tvm_storage_sync(), {StringImm(sync_scope_.to_string()), is_lead_, num_blocks_}, CallNode::Intrinsic)); } diff --git a/src/tir/transforms/vectorize_loop.cc b/src/tir/transforms/vectorize_loop.cc index 227aea2..1a2ec50 100644 --- a/src/tir/transforms/vectorize_loop.cc +++ b/src/tir/transforms/vectorize_loop.cc @@ -23,8 +23,10 @@ // Loop vectorizer as in Halide pipeline. #include #include +#include #include #include +#include #include #include @@ -212,15 +214,18 @@ class Vectorizer : public StmtExprMutator { int lanes = std::max(t.dtype().lanes(), f.dtype().lanes()); t = BroadcastTo(t, lanes); f = BroadcastTo(f, lanes); - return Call(op->dtype.with_lanes(lanes), op->name, {cond, t, f}, op->call_type); + return Call(op->dtype.with_lanes(lanes), op->op, {cond, t, f}, op->call_type); } } // Call PrimExpr VisitExpr_(const CallNode* op) final { - if (op->name == intrinsic::tvm_if_then_else) { + if (op->op.same_as(builtin::if_then_else())) { return MutateIfThenElseExpr_(op); } - if (!op->is_vectorizable()) { + auto* op_ptr = op->op.as(); + bool vectorizable = op_ptr && op_vectorizable_.get(GetRef(op_ptr), false); + + if (!vectorizable) { // Cannot vectorize this op Array new_args; for (auto arg : op->args) { @@ -234,7 +239,7 @@ class Vectorizer : public StmtExprMutator { if (op->args.same_as(new_args)) { return GetRef(op); } else { - return Call(op->dtype, op->name, new_args, op->call_type); + return Call(op->dtype, op->op, new_args, op->call_type); } } else { int lane = 0; @@ -243,7 +248,7 @@ class Vectorizer : public StmtExprMutator { if (op->args.same_as(new_args)) { return GetRef(op); } else { - return Call(op->dtype.with_lanes(lane), op->name, new_args, op->call_type); + return Call(op->dtype.with_lanes(lane), op->op, new_args, op->call_type); } } } @@ -380,6 +385,9 @@ class Vectorizer : public StmtExprMutator { bool need_scalarize_{false}; // The lets std::unordered_map lets_; + // vectorizable property + OpAttrMap op_vectorizable_ = Op::GetAttrMap("TVectorizable"); + // mutate array, with given lane requirement // when finished, p_lane updates the lane requirement. Array MutateArray(Array arr, int* p_lanes) { diff --git a/tests/cpp/ir_functor_test.cc b/tests/cpp/ir_functor_test.cc index 8dae799..ce50ed0 100644 --- a/tests/cpp/ir_functor_test.cc +++ b/tests/cpp/ir_functor_test.cc @@ -20,6 +20,7 @@ #include #include #include +#include #include #include #include @@ -192,9 +193,10 @@ TEST(IRF, StmtMutator) { } { - auto body = Evaluate(Call(DataType::Int(32), "xyz", {x + 1}, CallNode::Extern)); + auto body = Evaluate(Call(DataType::Int(32), builtin::call_extern(), {StringImm("xyz"), x + 1}, + CallNode::Extern)); auto res = v(std::move(body)); - CHECK(res.as()->value.as()->args[0].same_as(x)); + CHECK(res.as()->value.as()->args[1].same_as(x)); } { Stmt body = fmakealloc(); diff --git a/tests/python/relay/test_ir_parser.py b/tests/python/relay/test_ir_parser.py index c4ac042..1e4fe6b 100644 --- a/tests/python/relay/test_ir_parser.py +++ b/tests/python/relay/test_ir_parser.py @@ -853,8 +853,8 @@ def test_duplicate_adt_cons_defn(): def test_duplicate_global_var(): parse_text( """ - def @id[A](%x: A) -> A { x } - def @id[A](%x: A) -> A { x } + def @id[A](%%x: A) -> A { x } + def @id[A](%%x: A) -> A { x } """ ) diff --git a/tests/python/unittest/test_arith_canonical_simplify.py b/tests/python/unittest/test_arith_canonical_simplify.py index 525cd6c..2fbc82f 100644 --- a/tests/python/unittest/test_arith_canonical_simplify.py +++ b/tests/python/unittest/test_arith_canonical_simplify.py @@ -202,7 +202,8 @@ def test_reduce_combiner_simplify(): assert tvm.ir.structural_equal(lhs, rhs) # Test that components with side effects are not removed - side_effect = lambda *xs: tvm.tir.Call("int32", "dummy", xs, tvm.tir.Call.Intrinsic) + dummy = tvm.ir.GlobalVar("dummy") + side_effect = lambda *xs: tvm.tir.Call("int32", dummy, xs, tvm.tir.Call.Intrinsic) ck.verify(sum_and_prod((A[k], side_effect(A[10-k])), k)[0], sum_and_prod((A[k], side_effect(A[10-k])), k)[0]) ck.verify(sum_and_prod((side_effect(A[k]), A[10-k]), k)[0], diff --git a/tests/python/unittest/test_target_codegen_c_host.py b/tests/python/unittest/test_target_codegen_c_host.py index 0f00e08..18a98ee 100644 --- a/tests/python/unittest/test_target_codegen_c_host.py +++ b/tests/python/unittest/test_target_codegen_c_host.py @@ -98,7 +98,7 @@ def test_reinterpret(): nn = 1024 n = tvm.runtime.convert(nn) A = te.placeholder((n,), name='A', dtype="int32") - B = te.compute(A.shape, lambda *i: tvm.tir.call_pure_intrin("float32", "reinterpret", A(*i)), name='B') + B = te.compute(A.shape, lambda *i: tvm.tir.call_pure_intrin("float32", "tir.reinterpret", A(*i)), name='B') s = te.create_schedule(B.op) def check_c(): diff --git a/tests/python/unittest/test_target_codegen_llvm.py b/tests/python/unittest/test_target_codegen_llvm.py index 0b415b0..a6a2315 100644 --- a/tests/python/unittest/test_target_codegen_llvm.py +++ b/tests/python/unittest/test_target_codegen_llvm.py @@ -29,12 +29,12 @@ def test_llvm_intrin(): n = tvm.runtime.convert(4) A = ib.pointer("float32", name="A") args = [ - tvm.tir.call_pure_intrin("handle", "tvm_address_of", A[0]), + tvm.tir.call_pure_intrin("handle", "tir.address_of", A[0]), 0, 3, 1 ] ib.emit(tvm.tir.Evaluate( tvm.tir.Call( - "int32", "prefetch", args, tvm.tir.Call.Intrinsic))) + "int32", "tir.prefetch", args, tvm.tir.Call.Intrinsic))) body = ib.get() mod = tvm.IRModule.from_expr( @@ -738,20 +738,20 @@ def test_llvm_shuffle(): tvm.testing.assert_allclose(c_.asnumpy(), (a_.asnumpy() * 2).astype('int32')) def np_float2np_bf16(arr): - ''' Convert a numpy array of float to a numpy array + ''' Convert a numpy array of float to a numpy array of bf16 in uint16''' orig = arr.view('> y) == '@shift_right(x: int32, y: int32, dtype=int32, type="pure_intrin")' - assert str(x & y) == '@bitwise_and(x: int32, y: int32, dtype=int32, type="pure_intrin")' - assert str(x | y) == '@bitwise_or(x: int32, y: int32, dtype=int32, type="pure_intrin")' - assert str(x ^ y) == '@bitwise_xor(x: int32, y: int32, dtype=int32, type="pure_intrin")' - assert str(10 & x) == '@bitwise_and(10, x: int32, dtype=int32, type="pure_intrin")' - assert str(10 | x) == '@bitwise_or(10, x: int32, dtype=int32, type="pure_intrin")' - assert str(10 ^ x) == '@bitwise_xor(10, x: int32, dtype=int32, type="pure_intrin")' - assert str(10 >> x) == '@shift_right(10, x: int32, dtype=int32, type="pure_intrin")' - assert str(10 << x) == '@shift_left(10, x: int32, dtype=int32, type="pure_intrin")' + assert str(x << y) == '@tir.shift_left(x: int32, y: int32, dtype=int32, type="pure_intrin")' + assert str(x >> y) == '@tir.shift_right(x: int32, y: int32, dtype=int32, type="pure_intrin")' + assert str(x & y) == '@tir.bitwise_and(x: int32, y: int32, dtype=int32, type="pure_intrin")' + assert str(x | y) == '@tir.bitwise_or(x: int32, y: int32, dtype=int32, type="pure_intrin")' + assert str(x ^ y) == '@tir.bitwise_xor(x: int32, y: int32, dtype=int32, type="pure_intrin")' + assert str(10 & x) == '@tir.bitwise_and(10, x: int32, dtype=int32, type="pure_intrin")' + assert str(10 | x) == '@tir.bitwise_or(10, x: int32, dtype=int32, type="pure_intrin")' + assert str(10 ^ x) == '@tir.bitwise_xor(10, x: int32, dtype=int32, type="pure_intrin")' + assert str(10 >> x) == '@tir.shift_right(10, x: int32, dtype=int32, type="pure_intrin")' + assert str(10 << x) == '@tir.shift_left(10, x: int32, dtype=int32, type="pure_intrin")' assert str(10 % x) == 'floormod(10, x: int32)' - assert str(~x) == '@bitwise_not(x: int32, dtype=int32, type="pure_intrin")' + + assert str(~x) == '@tir.bitwise_not(x: int32, dtype=int32, type="pure_intrin")' assert(tvm.tir.const(1, "int8x2") >> 1).dtype == "int8x2" assert(x >> tvm.tir.const(1, "int32x2")).dtype == "int32x2" assert(te.var("z", "int8x2") << tvm.tir.const(1, "int8x2")).dtype == "int8x2" @@ -239,10 +240,10 @@ def test_divide_by_zero(): def test_isnan(): x = te.var('x', 'float32') - assert str(tvm.tir.isnan(x)) == '@isnan(x: float32, dtype=bool, type="pure_intrin")' + assert str(tvm.tir.isnan(x)) == '@tir.isnan(x: float32, dtype=bool, type="pure_intrin")' assert str(tvm.tir.isnan(x).dtype) == 'bool' y = te.var('y', 'float16') - assert str(tvm.tir.isnan(y)) == '@isnan(cast(float32, y: float16), dtype=bool, type="pure_intrin")' + assert str(tvm.tir.isnan(y)) == '@tir.isnan(cast(float32, y: float16), dtype=bool, type="pure_intrin")' z = te.var('z', 'int32') assert str(tvm.tir.isnan(z)) == 'False' k = te.var('k', 'int8x2') diff --git a/tests/python/unittest/test_tir_stmt_functor_ir_transform.py b/tests/python/unittest/test_tir_stmt_functor_ir_transform.py index 38529e9..61accf2 100644 --- a/tests/python/unittest/test_tir_stmt_functor_ir_transform.py +++ b/tests/python/unittest/test_tir_stmt_functor_ir_transform.py @@ -26,20 +26,21 @@ def test_ir_transform(): ib.emit(tvm.tir.call_extern("int32", "TestB", x)) ib.emit(tvm.tir.call_extern("int32", "TestC", x)) body = ib.get() + builtin_call_extern = tvm.ir.Op.get("tir.call_extern") def preorder(op): - if op.name == "TestC": + if op.op.same_as(builtin_call_extern) and op.args[0].value == "TestC": return tvm.tir.const(0, "int32") return None def postorder(op): assert isinstance(op, tvm.tir.Call) - if op.name == "TestA": - return tvm.tir.call_extern("int32", "TestB", op.args[0] + 1) + if op.op.same_as(builtin_call_extern) and op.args[0].value == "TestA": + return tvm.tir.call_extern("int32", "TestB", op.args[1] + 1) return op body = tvm.tir.stmt_functor.ir_transform(body, preorder, postorder, ["tir.Call"]) stmt_list = tvm.tir.stmt_list(body.body.body) - assert stmt_list[0].value.args[0].name == "TestB" + assert stmt_list[0].value.args[1].args[0].value == "TestB" assert stmt_list[1].value.value == 0 if __name__ == "__main__": diff --git a/tests/python/unittest/test_tir_transform_bf16_legalize.py b/tests/python/unittest/test_tir_transform_bf16_legalize.py index 77a0602..55a6819 100644 --- a/tests/python/unittest/test_tir_transform_bf16_legalize.py +++ b/tests/python/unittest/test_tir_transform_bf16_legalize.py @@ -116,19 +116,19 @@ def test_legalize(): def to32(v): uint32_v = topi.cast(v, "uint32") uint32_v = tvm.tir.call_pure_intrin( - "uint32", "shift_left", uint32_v, tvm.tir.const(16, "uint32")) - return tvm.tir.call_pure_intrin("float32", "reinterpret", uint32_v) + "uint32", "tir.shift_left", uint32_v, tvm.tir.const(16, "uint32")) + return tvm.tir.call_pure_intrin("float32", "tir.reinterpret", uint32_v) def to16(v): - uint32_v = tvm.tir.call_pure_intrin("uint32", "reinterpret", v) + uint32_v = tvm.tir.call_pure_intrin("uint32", "tir.reinterpret", v) rounding_bias = tvm.tir.call_pure_intrin( - "uint32", "shift_right", uint32_v, tvm.tir.const(16, "uint32")) + "uint32", "tir.shift_right", uint32_v, tvm.tir.const(16, "uint32")) rounding_bias = tvm.tir.call_pure_intrin( - "uint32", "bitwise_and", rounding_bias, tvm.tir.const(1, "uint32")) + "uint32", "tir.bitwise_and", rounding_bias, tvm.tir.const(1, "uint32")) rounding_bias = rounding_bias + tvm.tir.const(0x7FFF, "uint16") uint32_v = uint32_v + rounding_bias uint32_v = tvm.tir.call_pure_intrin( - "uint32", "shift_right", uint32_v, tvm.tir.const(16, "uint32")) + "uint32", "tir.shift_right", uint32_v, tvm.tir.const(16, "uint32")) return topi.cast(uint32_v, 'uint16') def check(fcompute_before, fcompute_after): diff --git a/tests/python/unittest/test_tir_transform_combine_context_call.py b/tests/python/unittest/test_tir_transform_combine_context_call.py index 29a3303..d7a25ca 100644 --- a/tests/python/unittest/test_tir_transform_combine_context_call.py +++ b/tests/python/unittest/test_tir_transform_combine_context_call.py @@ -22,7 +22,7 @@ def test_for(): def device_context(dev_id): ctx = tvm.tir.call_extern("handle", "device_context", dev_type, dev_id) return tvm.tir.Call( - "handle", "tvm_thread_context", [ctx], tvm.tir.Call.Intrinsic) + "handle", "tir.tvm_thread_context", [ctx], tvm.tir.Call.Intrinsic) ib = tvm.tir.ir_builder.create() n = te.var("n") diff --git a/tests/python/unittest/test_tir_transform_coproc_sync.py b/tests/python/unittest/test_tir_transform_coproc_sync.py index f658349..8469bc9 100644 --- a/tests/python/unittest/test_tir_transform_coproc_sync.py +++ b/tests/python/unittest/test_tir_transform_coproc_sync.py @@ -17,6 +17,14 @@ import tvm from tvm import te +# register the ops +tvm.ir.register_op_attr("tir.cop.coproc_sync", "TGlobalSymbol", "coproc_sync") +tvm.ir.register_op_attr("tir.cop.coproc_read_barrier", "TGlobalSymbol", "coproc_readb") +tvm.ir.register_op_attr("tir.cop.coproc_write_barrier", "TGlobalSymbol", "coproc_writeb") +tvm.ir.register_op_attr("tir.cop.coproc_dep_push", "TGlobalSymbol", "coproc_dep_push") +tvm.ir.register_op_attr("tir.cop.coproc_dep_pop", "TGlobalSymbol", "coproc_dep_pop") + + def test_coproc_sync(): @tvm.register_func("tvm.info.mem.global.cache") def meminfo_cache(): @@ -26,6 +34,7 @@ def test_coproc_sync(): max_simd_bits=32, max_num_bits=128, head_address=tvm.tir.call_extern("handle", "global_cache")) + ib = tvm.tir.ir_builder.create() n = te.size_var("n") cp = te.thread_axis((0, 1), "cop") @@ -43,10 +52,11 @@ def test_coproc_sync(): body = stmt.body.body.body blist = tvm.tir.stmt_list(body) - assert(blist[1].value.name == "cop.coproc_read_barrier") + + assert(blist[1].value.op.same_as(tvm.ir.Op.get("tir.cop.coproc_read_barrier"))) assert(blist[1].value.args[3].value == 80) - assert(blist[-2].value.name == "cop.coproc_sync") - assert(blist[-1].value.name == "cop.coproc_write_barrier") + assert(blist[-2].value.op.same_as(tvm.ir.Op.get("tir.cop.coproc_sync"))) + assert(blist[-1].value.op.same_as(tvm.ir.Op.get("tir.cop.coproc_write_barrier"))) assert(blist[-1].value.args[3].value == 10) @@ -106,9 +116,9 @@ def test_coproc_sync3(): slist = tvm.tir.stmt_list(slist[-1]) pop_st = slist[0].body[0] - assert(push_st.value.name == "cop.coproc_dep_push") + assert(push_st.value.op.same_as(tvm.ir.Op.get("tir.cop.coproc_dep_push"))) assert(__check_list(push_st.value.args, [2,3])) - assert(pop_st.value.name == "cop.coproc_dep_pop") + assert(pop_st.value.op.same_as(tvm.ir.Op.get("tir.cop.coproc_dep_pop"))) assert(__check_list(pop_st.value.args, [2,3])) diff --git a/tests/python/unittest/test_tir_transform_inject_double_buffer.py b/tests/python/unittest/test_tir_transform_inject_double_buffer.py index 0b6b167..cf58632 100644 --- a/tests/python/unittest/test_tir_transform_inject_double_buffer.py +++ b/tests/python/unittest/test_tir_transform_inject_double_buffer.py @@ -56,7 +56,7 @@ def test_double_buffer(): f = tvm.tir.transform.ThreadSync("shared")(mod)["db"] count = [0] def count_sync(op): - if isinstance(op, tvm.tir.Call) and op.name == "tvm_storage_sync": + if isinstance(op, tvm.tir.Call) and op.op.same_as(tvm.ir.Op.get("tir.tvm_storage_sync")): count[0] += 1 tvm.tir.stmt_functor.post_order_visit(f.body, count_sync) assert count[0] == 4 diff --git a/tests/python/unittest/test_tir_transform_inject_virtual_thread.py b/tests/python/unittest/test_tir_transform_inject_virtual_thread.py index c0789c6..4964039 100644 --- a/tests/python/unittest/test_tir_transform_inject_virtual_thread.py +++ b/tests/python/unittest/test_tir_transform_inject_virtual_thread.py @@ -36,7 +36,7 @@ def test_vthread(): bbuffer = tvm.tir.decl_buffer((m,), dtype=B.dtype, data=B.asobject()) ib.emit(tvm.tir.call_extern("int32", "Run", bbuffer.access_ptr("r"), - tvm.tir.call_pure_intrin("int32", "tvm_context_id"))) + tvm.tir.call_pure_intrin("int32", "tir.tvm_context_id"))) C[i * nthread + tx] = B[i] + 1 return ib.get() diff --git a/tests/python/unittest/test_tir_transform_rewrite_unsafe_select.py b/tests/python/unittest/test_tir_transform_rewrite_unsafe_select.py index 229c11b..9f1104d 100644 --- a/tests/python/unittest/test_tir_transform_rewrite_unsafe_select.py +++ b/tests/python/unittest/test_tir_transform_rewrite_unsafe_select.py @@ -39,8 +39,10 @@ def test_rewrite_Select(): mod = tvm.IRModule.from_expr( tvm.tir.PrimFunc([i], tvm.tir.Evaluate(a))) aa = tvm.tir.transform.RewriteUnsafeSelect()(mod)["main"].body.value - assert yy.name == "tvm_if_then_else" - assert zz.name == "tvm_if_then_else" + builtin_if_then_else = tvm.ir.Op.get("tir.if_then_else") + + assert yy.op.same_as(builtin_if_then_else) + assert yy.op.same_as(builtin_if_then_else) assert isinstance(aa, tvm.tir.Select) diff --git a/tests/python/unittest/test_tir_transform_storage_flatten.py b/tests/python/unittest/test_tir_transform_storage_flatten.py index 5fea580..468867a 100644 --- a/tests/python/unittest/test_tir_transform_storage_flatten.py +++ b/tests/python/unittest/test_tir_transform_storage_flatten.py @@ -125,7 +125,7 @@ def test_flatten_double_buffer(): count = [0] def count_sync(op): - if isinstance(op, tvm.tir.Call) and op.name == "tvm_storage_sync": + if isinstance(op, tvm.tir.Call) and op.op.same_as(tvm.ir.Op.get("tir.tvm_storage_sync")): count[0] += 1 tvm.tir.stmt_functor.post_order_visit(f.body, count_sync) assert count[0] == 4 diff --git a/tests/python/unittest/test_tir_transform_thread_sync.py b/tests/python/unittest/test_tir_transform_thread_sync.py index 783b669..3ff6804 100644 --- a/tests/python/unittest/test_tir_transform_thread_sync.py +++ b/tests/python/unittest/test_tir_transform_thread_sync.py @@ -49,7 +49,7 @@ def test_thread_storage_sync(): cuda_target = tvm.target.create("cuda") f = tvm.tir.transform.ThreadSync("shared")(mod)["test_kernel0"] body_list = tvm.tir.stmt_list(f.body.body.body.body) - assert(body_list[1].value.name == "tvm_storage_sync") + assert(body_list[1].value.op.same_as(tvm.ir.Op.get("tir.tvm_storage_sync"))) diff --git a/tests/python/unittest/test_tir_transform_vectorize.py b/tests/python/unittest/test_tir_transform_vectorize.py index d7124b6..a69c9d3 100644 --- a/tests/python/unittest/test_tir_transform_vectorize.py +++ b/tests/python/unittest/test_tir_transform_vectorize.py @@ -117,7 +117,7 @@ def test_vectorize_if_then_else(): ib = tvm.tir.ir_builder.create() A = ib.pointer("float32", name="A") with ib.for_range(0, 4, for_type="vectorize") as i: - A[i] = tvm.tir.call_intrin("float32", "tvm_if_then_else", + A[i] = tvm.tir.call_intrin("float32", "tir.if_then_else", i > 0, A[i] + 1, A[i]) stmt = ib.get() @@ -132,7 +132,7 @@ def test_vectorize_if_then_else(): A = ib.pointer("float32", name="A") with ib.for_range(0, n) as k: with ib.for_range(0, 4, for_type="vectorize") as i: - A[k * 4 + i] = tvm.tir.call_intrin("float32", "tvm_if_then_else", + A[k * 4 + i] = tvm.tir.call_intrin("float32", "tir.if_then_else", k > 0, A[k * 4 + i], 0) stmt = ib.get() diff --git a/topi/include/topi/detail/extern.h b/topi/include/topi/detail/extern.h index b84fbc7..7068b95 100644 --- a/topi/include/topi/detail/extern.h +++ b/topi/include/topi/detail/extern.h @@ -25,6 +25,7 @@ #define TOPI_DETAIL_EXTERN_H_ #include +#include #include #include @@ -111,11 +112,11 @@ inline Array make_extern(const Array >& out_shapes, */ inline PrimExpr pack_buffer(Buffer buf) { CHECK_GT(buf->shape.size(), 0) << "buf shape must have at least one element"; - auto shape = tvm::tir::Call(DataType::Handle(), tvm::tir::intrinsic::tvm_stack_make_shape, + auto shape = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_stack_make_shape(), buf->shape, tvm::tir::CallNode::CallType::Intrinsic); PrimExpr strides; if (buf->strides.size() > 0) { - strides = tvm::tir::Call(DataType::Handle(), tvm::tir::intrinsic::tvm_stack_make_shape, + strides = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_stack_make_shape(), buf->shape, tvm::tir::CallNode::CallType::Intrinsic); } else { strides = 0; @@ -126,7 +127,7 @@ inline PrimExpr pack_buffer(Buffer buf) { make_const(DataType::Int(32), static_cast(buf->shape.size())), make_const(buf->dtype, 0), buf->elem_offset}; - return tvm::tir::Call(DataType::Handle(), tvm::tir::intrinsic::tvm_stack_make_array, pack_args, + return tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_stack_make_array(), pack_args, tvm::tir::CallNode::CallType::Intrinsic); } @@ -140,7 +141,7 @@ inline PrimExpr pack_buffer(Buffer buf) { * \return An expression representing the invocation */ inline PrimExpr call_packed(Array args) { - return tvm::tir::Call(DataType::Int(32), tvm::tir::intrinsic::tvm_call_packed, args, + return tvm::tir::Call(DataType::Int(32), tvm::tir::builtin::tvm_call_packed(), args, tvm::tir::CallNode::CallType::Intrinsic); } diff --git a/topi/include/topi/elemwise.h b/topi/include/topi/elemwise.h index a92d21c..0ec7e4d 100644 --- a/topi/include/topi/elemwise.h +++ b/topi/include/topi/elemwise.h @@ -25,6 +25,7 @@ #define TOPI_ELEMWISE_H_ #include +#include #include #include @@ -309,7 +310,8 @@ inline Tensor reinterpret(const Tensor& x, DataType type, std::string name = "te return compute( x->shape, [&](const Array& i) { - return tvm::tir::Call(type, "reinterpret", {x(i)}, tvm::tir::CallNode::PureIntrinsic); + return tvm::tir::Call(type, tvm::tir::builtin::reinterpret(), {x(i)}, + tvm::tir::CallNode::PureIntrinsic); }, name, tag); } diff --git a/topi/python/topi/arm_cpu/bitserial_conv2d.py b/topi/python/topi/arm_cpu/bitserial_conv2d.py index ac1ac45..f035251 100644 --- a/topi/python/topi/arm_cpu/bitserial_conv2d.py +++ b/topi/python/topi/arm_cpu/bitserial_conv2d.py @@ -231,8 +231,10 @@ def _intrin_popcount(m, k_i, w_b, x_b, unipolar): cnts = tvm.tir.popcount(w_ & x_) - tvm.tir.popcount(~w_ & x_) else: cnts = tvm.tir.popcount(w_ & x_) - upper_half = tvm.tir.call_pure_intrin(half_dtype, 'vectorhigh', cnts) - lower_half = tvm.tir.call_pure_intrin(half_dtype, 'vectorlow', cnts) + upper_half = tvm.tir.call_pure_intrin( + half_dtype, 'tir.vectorhigh', cnts) + lower_half = tvm.tir.call_pure_intrin( + half_dtype, 'tir.vectorlow', cnts) cnts8[i] = upper_half + lower_half for i in range(m//2): cnts4[i] = tvm.tir.call_llvm_intrin(half_dtype, vpadd, @@ -241,7 +243,7 @@ def _intrin_popcount(m, k_i, w_b, x_b, unipolar): cnts2[i] = tvm.tir.call_llvm_intrin(half_dtype, vpadd, args_2, cnts4[i*2], cnts4[i*2+1]) cnts = tvm.tir.call_pure_intrin( - full_dtype, 'vectorcombine', cnts2[0], cnts2[1]) + full_dtype, 'tir.vectorcombine', cnts2[0], cnts2[1]) shifted_cnts = cnts << tvm.tir.const(bw+bx, pack_dtype) out = tvm.tir.call_llvm_intrin( return_dtype, vpadalu, @@ -261,7 +263,7 @@ def _intrin_popcount(m, k_i, w_b, x_b, unipolar): cnts2[i] = tvm.tir.call_llvm_intrin(half_dtype, vpadd, args_2, cnts4[i*2], cnts4[i*2+1]) cnts = tvm.tir.call_pure_intrin( - full_dtype, 'vectorcombine', cnts2[0], cnts2[1]) + full_dtype, 'tir.vectorcombine', cnts2[0], cnts2[1]) shifted_cnts = cnts << tvm.tir.const(bw+bx, pack_dtype) out = tvm.tir.call_llvm_intrin( return_dtype, vpadalu, diff --git a/topi/python/topi/arm_cpu/tensor_intrin.py b/topi/python/topi/arm_cpu/tensor_intrin.py index bab9157..da9c71a 100644 --- a/topi/python/topi/arm_cpu/tensor_intrin.py +++ b/topi/python/topi/arm_cpu/tensor_intrin.py @@ -86,11 +86,11 @@ def dot_int8_int8_int32(int32_lanes, dtype='uint'): dtype_c = '%s32x%d' % (dtype, int32_lanes) a_int8 = ins[0].vload([0], dtype_a) - re_int32 = tvm.tir.call_pure_intrin('%s32' % dtype, 'reinterpret', a_int8) + re_int32 = tvm.tir.call_pure_intrin('%s32' % dtype, 'tir.reinterpret', a_int8) # broadcast a vec_ai32 = re_int32.astype(dtype_c) - vec_a = tvm.tir.call_pure_intrin(dtype_b, 'reinterpret', vec_ai32) + vec_a = tvm.tir.call_pure_intrin(dtype_b, 'tir.reinterpret', vec_ai32) vec_b = ins[1].vload([0, 0], dtype_b) vec_c = outs[0].vload([0], dtype_c) diff --git a/topi/python/topi/cuda/nms.py b/topi/python/topi/cuda/nms.py index f2c1143..c98d7e9 100644 --- a/topi/python/topi/cuda/nms.py +++ b/topi/python/topi/cuda/nms.py @@ -38,9 +38,10 @@ def cuda_atomic_add_rule(op): tvm.target.intrin.register_intrin_rule( "cuda", "atomic_add", cuda_atomic_add_rule, override=True) +tvm.ir.register_op_attr("tir.atomic_add", "TVectorizable", False) def atomic_add(x, y): - return tvm.tir.call_pure_intrin(y.dtype, "atomic_add", x, y) + return tvm.tir.call_pure_intrin(y.dtype, "tir.atomic_add", x, y) def get_valid_counts_ir(data, valid_count, out, out_indices, @@ -113,7 +114,7 @@ def get_valid_counts_ir(data, valid_count, out, out_indices, with ib.if_scope( tvm.tir.all(data[tid * elem_length + score_index] > score_threshold, tvm.tir.any(id_index < 0, data[tid * elem_length + id_index] >= 0))): - atomic_add_return[0] = atomic_add(tvm.tir.call_pure_intrin("handle", "tvm_address_of", + atomic_add_return[0] = atomic_add(tvm.tir.call_pure_intrin("handle", "tir.address_of", valid_count[i]), one_count) with ib.for_range(0, elem_length) as k: out[tid * elem_length + k] = data[tid * elem_length + k] diff --git a/topi/python/topi/cuda/rcnn/proposal.py b/topi/python/topi/cuda/rcnn/proposal.py index f713bb2..5b7e090 100644 --- a/topi/python/topi/cuda/rcnn/proposal.py +++ b/topi/python/topi/cuda/rcnn/proposal.py @@ -185,7 +185,7 @@ def argsort_ir(data_buf, out_index_buf): temp_index[0] = index_out[offset] index_out[offset] = index_out[offset + 1] index_out[offset + 1] = temp_index[0] - ib.emit(tvm.tir.Call(None, 'tvm_storage_sync', + ib.emit(tvm.tir.Call(None, 'tir.tvm_storage_sync', tvm.runtime.convert(['shared']), tvm.tir.Call.Intrinsic)) return ib.get() @@ -246,7 +246,7 @@ def nms_ir(sorted_bbox_buf, out_buf, nms_threshold): iou = calculate_overlap(p_data, (base_idx + l) * 5, (base_idx + i) * 5) with ib.if_scope(iou > nms_threshold): p_out[base_idx + i] = True - ib.emit(tvm.tir.Call(None, 'tvm_storage_sync', + ib.emit(tvm.tir.Call(None, 'tir.tvm_storage_sync', tvm.runtime.convert(['shared']), tvm.tir.Call.Intrinsic)) return ib.get() diff --git a/topi/python/topi/cuda/sort.py b/topi/python/topi/cuda/sort.py index ddae2bd..7181d57 100644 --- a/topi/python/topi/cuda/sort.py +++ b/topi/python/topi/cuda/sort.py @@ -115,7 +115,7 @@ def sort_ir(data, values_out, axis, is_ascend, indices_out=None): if indices_out is not None: indices_out[base_idx + tid * axis_mul_after] = \ tvm.tir.generic.cast(tid, indices_out.dtype) - ib.emit(tvm.tir.Call(None, 'tvm_storage_sync', + ib.emit(tvm.tir.Call(None, 'tir.tvm_storage_sync', tvm.runtime.convert(['shared']), tvm.tir.Call.Intrinsic)) idxd = tvm.tir.indexdiv @@ -143,7 +143,7 @@ def sort_ir(data, values_out, axis, is_ascend, indices_out=None): temp_index[0] = indices_out[offset] indices_out[offset] = indices_out[offset + axis_mul_after] indices_out[offset + axis_mul_after] = temp_index[0] - ib.emit(tvm.tir.Call(None, 'tvm_storage_sync', + ib.emit(tvm.tir.Call(None, 'tir.tvm_storage_sync', tvm.runtime.convert(['shared']), tvm.tir.Call.Intrinsic)) @@ -235,7 +235,7 @@ def sort_nms_ir(data, valid_count, output, axis, is_ascend): temp_index[0] = output[offset] output[offset] = output[offset + axis_mul_after] output[offset + axis_mul_after] = temp_index[0] - ib.emit(tvm.tir.Call(None, 'tvm_storage_sync', + ib.emit(tvm.tir.Call(None, 'tir.tvm_storage_sync', tvm.runtime.convert(['shared']), tvm.tir.Call.Intrinsic)) diff --git a/topi/python/topi/cuda/tensor_intrin.py b/topi/python/topi/cuda/tensor_intrin.py index 3941c00..c2b7d25 100644 --- a/topi/python/topi/cuda/tensor_intrin.py +++ b/topi/python/topi/cuda/tensor_intrin.py @@ -100,7 +100,7 @@ def intrin_wmma_load_matrix_A(strides_dst, strides_from, shape, layout, A_shape, BC = outs[0] row = wmma_m * wmma_k warp_index = BC.elem_offset // row + BC.elem_offset % row // wmma_k - ib.emit(tvm.tir.call_intrin('handle', 'tvm_load_matrix_sync', + ib.emit(tvm.tir.call_intrin('handle', 'tir.tvm_load_matrix_sync', BC.data, wmma_m, wmma_n, wmma_k, warp_index, BA.access_ptr('r'), strides_from[0], layout)) return ib.get() @@ -128,7 +128,7 @@ def intrin_wmma_load_matrix_W(strides_dst, strides_from, shape, layout, A_shape, BC = outs[0] row = wmma_n * wmma_k warp_index = BC.elem_offset // row + BC.elem_offset % row // wmma_n - ib.emit(tvm.tir.call_intrin('handle', 'tvm_load_matrix_sync', + ib.emit(tvm.tir.call_intrin('handle', 'tir.tvm_load_matrix_sync', BC.data, wmma_m, wmma_n, wmma_k, warp_index, BA.access_ptr('r'), strides_from[0], layout)) return ib.get() @@ -156,7 +156,7 @@ def intrin_wmma_store_matrix(strides_dst, strides_from, shape, out_dtype, A_shap BC = outs[0] row = wmma_m * wmma_n warp_index = BA.elem_offset // row + BA.elem_offset % row // wmma_n - ib.emit(tvm.tir.call_intrin('handle', 'tvm_store_matrix_sync', + ib.emit(tvm.tir.call_intrin('handle', 'tir.tvm_store_matrix_sync', BA.data, wmma_m, wmma_n, wmma_k, warp_index, BC.access_ptr('w'), strides_dst[0], 'row_major')) return ib.get() @@ -207,13 +207,14 @@ def intrin_wmma_gemm(AL_gemm, WL_gemm, CL_compute, strides_A, def init(): ib = tvm.tir.ir_builder.create() ib.emit( - tvm.tir.call_intrin('handle', 'tvm_fill_fragment', BC.data, wmma_m, wmma_n, wmma_k, + tvm.tir.call_intrin('handle', 'tir.tvm_fill_fragment', + BC.data, wmma_m, wmma_n, wmma_k, warp_index_C, 0.0)) return ib.get() def update(): ib = tvm.tir.ir_builder.create() - ib.emit(tvm.tir.call_intrin('handle', 'tvm_mma_sync', + ib.emit(tvm.tir.call_intrin('handle', 'tir.tvm_mma_sync', BC.data, warp_index_C, BA.data, warp_index_A, BB.data, warp_index_B, diff --git a/topi/python/topi/x86/tensor_intrin.py b/topi/python/topi/x86/tensor_intrin.py index ee8d83d..31de70e 100644 --- a/topi/python/topi/x86/tensor_intrin.py +++ b/topi/python/topi/x86/tensor_intrin.py @@ -88,9 +88,9 @@ def dot_16x1x16_uint8_int8_int32_skylake(): return ib.get() a_int8 = ins[0].vload([0], "uint8x4") - re_int32 = tvm.tir.call_pure_intrin('int32', 'reinterpret', a_int8) + re_int32 = tvm.tir.call_pure_intrin('int32', 'tir.reinterpret', a_int8) vec_ai32 = re_int32.astype('int32x16') - vec_a = tvm.tir.call_pure_intrin('int8x64', 'reinterpret', vec_ai32) + vec_a = tvm.tir.call_pure_intrin('int8x64', 'tir.reinterpret', vec_ai32) vec_b = ins[1].vload([0, 0], "int8x64") vec_one = tvm.tir.const(1, "int16x32") pair_reduction = tvm.tir.call_llvm_intrin('int16x32', @@ -174,9 +174,9 @@ def dot_16x1x16_uint8_int8_int16(): return ib.get() a_int8 = ins[0].vload([0], "uint8x2") - re_int16 = tvm.tir.call_pure_intrin('int16', 'reinterpret', a_int8) + re_int16 = tvm.tir.call_pure_intrin('int16', 'tir.reinterpret', a_int8) vec_ai16 = re_int16.astype('int16x32') - vec_a = tvm.tir.call_pure_intrin('int8x64', 'reinterpret', vec_ai16) + vec_a = tvm.tir.call_pure_intrin('int8x64', 'tir.reinterpret', vec_ai16) for i in range(4): vec_b = ins[1].vload([i*32, 0], "int8x64") @@ -254,7 +254,7 @@ def dot_16x1x16_uint8_int8_int32_cascadelake(): return ib.get() a_int8 = ins[0].vload([0], "uint8x4") - re_int32 = tvm.tir.call_pure_intrin('int32', 'reinterpret', a_int8) + re_int32 = tvm.tir.call_pure_intrin('int32', 'tir.reinterpret', a_int8) vec_ai32 = re_int32.astype('int32x16') vec_b = ins[1].vload([0, 0], "int8x64") @@ -262,7 +262,7 @@ def dot_16x1x16_uint8_int8_int32_cascadelake(): llvm_id = tvm.target.codegen.llvm_lookup_intrinsic_id(vnni_inst_name) if llvm_id != 0: # VNNI is available for current LLVM version - vec_bi32 = tvm.tir.call_pure_intrin('int32x16', 'reinterpret', vec_b) + vec_bi32 = tvm.tir.call_pure_intrin('int32x16', 'tir.reinterpret', vec_b) vec_zero = tvm.tir.const(0, "int32x16") quad_reduction = tvm.tir.call_llvm_intrin('int32x16', 'llvm.x86.avx512.vpdpbusd.512', @@ -270,7 +270,7 @@ def dot_16x1x16_uint8_int8_int32_cascadelake(): vec_zero, vec_ai32, vec_bi32) else: # Fall back to the normal AVX512 - vec_a = tvm.tir.call_pure_intrin('int8x64', 'reinterpret', vec_ai32) + vec_a = tvm.tir.call_pure_intrin('int8x64', 'tir.reinterpret', vec_ai32) vec_one = tvm.tir.const(1, "int16x32") pair_reduction = tvm.tir.call_llvm_intrin('int16x32', 'llvm.x86.avx512.pmaddubs.w.512', diff --git a/topi/tests/python/test_topi_basic.py b/topi/tests/python/test_topi_basic.py index 13f1463..a83ff50 100644 --- a/topi/tests/python/test_topi_basic.py +++ b/topi/tests/python/test_topi_basic.py @@ -34,7 +34,7 @@ def test_ewise(): def test_apply(func, name): B = func(A) assert tuple(B.shape) == tuple(A.shape) - assert B.op.body[0].name == name + assert B.op.body[0].op.name == "tir." + name test_apply(topi.exp, "exp") test_apply(topi.erf, "erf") diff --git a/topi/tests/python/test_topi_math.py b/topi/tests/python/test_topi_math.py index ea98083..6f1e858 100644 --- a/topi/tests/python/test_topi_math.py +++ b/topi/tests/python/test_topi_math.py @@ -50,11 +50,11 @@ def test_ewise(): B = func(A) assert tuple(B.shape) == tuple(A.shape) if not skip_name_check: - assert B.op.body[0].name == name + assert B.op.body[0].op.name == "tir." + name a_np = np.random.uniform(low=low, high=high, size=shape).astype(A.dtype) * 10 # avoid round check too close to boundary if check_round: - a_np += ((np.abs(np.fmod(a_np, 1)) - 0.5) < 1e-6) * 1e-5 + a_np += ((np.abs(np.fmod(a_np, 1)) - 0.5) < 1e-6) * 1e-4 b_np = f_numpy(a_np) def check_device(device): @@ -89,7 +89,7 @@ def test_ewise(): B = topi.isnan(A) assert tuple(B.shape) == tuple(A.shape) if not skip_name_check: - assert B.op.body[0].name == "isnan" + assert B.op.body[0].op.name == "tir.isnan" a_np = np.random.uniform(low=low, high=high, size=shape).astype(A.dtype) * 10 a_np.ravel()[np.random.choice(a_np.size, int(a_np.size * 0.5), replace=False)] = np.nan # avoid round check too close to boundary diff --git a/tutorials/language/intrin_math.py b/tutorials/language/intrin_math.py index 146263d..65bfd4c 100644 --- a/tutorials/language/intrin_math.py +++ b/tutorials/language/intrin_math.py @@ -100,12 +100,15 @@ print(fopencl.imported_modules[0].get_source()) def my_cuda_math_rule(op): """Customized CUDA intrinsic lowering rule""" assert isinstance(op, tvm.tir.Call) + name = op.op.name + assert name.startswith("tir.") + dispatch_name = name[4:] if op.dtype == "float32": # call float function - return tvm.tir.call_pure_extern("float32", "%sf" % op.name, op.args[0]) + return tvm.tir.call_pure_extern("float32", "%sf" % dispatch_name, op.args[0]) elif op.dtype == "float64": # call double function - return tvm.tir.call_pure_extern("float32", op.name, op.args[0]) + return tvm.tir.call_pure_extern("float32", dispatch_name, op.args[0]) else: # cannot do translation, return self. return op @@ -132,7 +135,7 @@ print(fcuda.imported_modules[0].get_source()) def mylog(x): """customized log intrinsic function""" - return tvm.tir.call_pure_intrin(x.dtype, "mylog", x) + return tvm.tir.call_pure_intrin(x.dtype, "tir.mylog", x) def my_cuda_mylog_rule(op): @@ -144,7 +147,8 @@ def my_cuda_mylog_rule(op): else: return op - +# new op registration is triggered by registering an attribute of the op +tvm.ir.register_op_attr("tir.mylog", "TVectorizable", True) tvm.target.register_intrin_rule("cuda", "mylog", my_cuda_mylog_rule, override=True) n = te.var("n") diff --git a/tutorials/optimize/opt_conv_tensorcore.py b/tutorials/optimize/opt_conv_tensorcore.py index cd40a91..4b2823c 100644 --- a/tutorials/optimize/opt_conv_tensorcore.py +++ b/tutorials/optimize/opt_conv_tensorcore.py @@ -163,7 +163,7 @@ def intrin_wmma_load_matrix(scope): BA = ins[0] BC = outs[0] - ib.emit(tvm.tir.call_intrin('handle', 'tvm_load_matrix_sync', + ib.emit(tvm.tir.call_intrin('handle', 'tir.tvm_load_matrix_sync', BC.data, n, n, n, BC.elem_offset // 256, BA.access_ptr('r'), n, 'row_major')) return ib.get() @@ -190,12 +190,12 @@ def intrin_wmma_gemm(): def init(): ib = tvm.tir.ir_builder.create() - ib.emit(tvm.tir.call_intrin('handle', 'tvm_fill_fragment', BC.data, n, n, n, BC.elem_offset // 256, 0.0)) + ib.emit(tvm.tir.call_intrin('handle', 'tir.tvm_fill_fragment', BC.data, n, n, n, BC.elem_offset // 256, 0.0)) return ib.get() def update(): ib = tvm.tir.ir_builder.create() - ib.emit(tvm.tir.call_intrin('handle', 'tvm_mma_sync', + ib.emit(tvm.tir.call_intrin('handle', 'tir.tvm_mma_sync', BC.data, BC.elem_offset // 256, BA.data, BA.elem_offset // 256, BB.data, BB.elem_offset // 256, @@ -218,7 +218,7 @@ def intrin_wmma_store_matrix(): ib = tvm.tir.ir_builder.create() BA = ins[0] BC = outs[0] - ib.emit(tvm.tir.call_intrin('handle', 'tvm_store_matrix_sync', + ib.emit(tvm.tir.call_intrin('handle', 'tir.tvm_store_matrix_sync', BA.data, n, n, n, BA.elem_offset // 256, BC.access_ptr('w'), n, 'row_major')) return ib.get() diff --git a/vta/python/vta/environment.py b/vta/python/vta/environment.py index e68f098..947c583 100644 --- a/vta/python/vta/environment.py +++ b/vta/python/vta/environment.py @@ -77,9 +77,9 @@ class DevContext(object): def __init__(self, env): self.vta_axis = te.thread_axis("vta") self.vta_push_uop = tvm.tir.StringImm("VTAPushGEMMOp") - ctx = tvm.tir.call_extern("handle", "VTATLSCommandHandle") + ctx = tvm.tir.call_intrin("handle", "tir.vta.command_handle") self.command_handle = tvm.tir.Call( - "handle", "tvm_thread_context", [ctx], + "handle", "tir.tvm_thread_context", [ctx], tvm.tir.Call.Intrinsic) self.DEBUG_NO_SYNC = False env._dev_ctx = self @@ -298,6 +298,7 @@ def coproc_sync(op): tvm.runtime.const(1<<31, dtype="uint32")) + @tvm.register_func("tvm.intrin.rule.default.vta.coproc_dep_push") def coproc_dep_push(op): return tvm.tir.call_extern( @@ -313,6 +314,15 @@ def coproc_dep_pop(op): get_env().dev.command_handle, op.args[0], op.args[1]) +# register a dummy into to trigger registration of the ops +# change the info to lowering rule later. +tvm.ir.register_op_attr("tir.vta.coproc_sync", "TVectorizable", False) +tvm.ir.register_op_attr("tir.vta.coproc_dep_push", "TVectorizable", False) +tvm.ir.register_op_attr("tir.vta.coproc_dep_pop", "TVectorizable", False) + +tvm.ir.register_op_attr("tir.vta.uop_push", "TGlobalSymbol", "VTAUopPush") +tvm.ir.register_op_attr("tir.vta.command_handle", "TGlobalSymbol", "VTATLSCommandHandle") + def _init_env(): """Initialize the default global env""" diff --git a/vta/python/vta/intrin.py b/vta/python/vta/intrin.py index 8532ffa..897bbcb 100644 --- a/vta/python/vta/intrin.py +++ b/vta/python/vta/intrin.py @@ -82,16 +82,16 @@ def gemm(env, mock=False): irb.scope_attr(dev.vta_axis, "coproc_uop_scope", dev.vta_push_uop) if index in (0, 2): - irb.emit(tvm.tir.call_extern( - "int32", "VTAUopPush", + irb.emit(tvm.tir.call_intrin( + "int32", "tir.vta.uop_push", 0, 0, dout.access_ptr("rw", "int32"), dinp.access_ptr("r", "int32"), dwgt.access_ptr("r", "int32"), 0, 0, 0)) else: - irb.emit(tvm.tir.call_extern( - "int32", "VTAUopPush", + irb.emit(tvm.tir.call_intrin( + "int32", "tir.vta.uop_push", 0, 1, dout.access_ptr("rw", "int32"), 0, diff --git a/vta/python/vta/transform.py b/vta/python/vta/transform.py index 207f784..e92b178 100644 --- a/vta/python/vta/transform.py +++ b/vta/python/vta/transform.py @@ -59,11 +59,12 @@ def FoldUopLoop(): loop_var = stmt.loop_var gemm_offsets = [None, None, None] fail = [False] + builtin_uop_push = tvm.ir.Op.get("tir.vta.uop_push") def _post_order(op): assert isinstance(op, tvm.tir.Call) base_args = 2 - if op.name == "VTAUopPush": + if op.op.same_as(builtin_uop_push): args = [] args += op.args[:base_args] for i in range(3): @@ -81,8 +82,8 @@ def FoldUopLoop(): gemm_offsets[i] = m[0] args.append(m[1]) args += op.args[base_args+3:] - return tvm.tir.call_extern("int32", "VTAUopPush", *args) - if op.name not in ("VTATLSCommandHandle", "tvm_thread_context"): + return tvm.tir.call_intrin("int32", builtin_uop_push, *args) + if op.op.name not in ("tir.vta.command_handle", "tir.tvm_thread_context"): raise RuntimeError("unexpected op %s" % op) return op @@ -643,7 +644,7 @@ def InjectConv2DTransposeSkip(): dev = env.dev irb.scope_attr(dev.vta_axis, "coproc_scope", dev.get_task_qid(dev.QID_COMPUTE)) irb.scope_attr(dev.vta_axis, "coproc_uop_scope", dev.vta_push_uop) - irb.emit(tvm.tir.call_extern("int32", "VTAUopPush", + irb.emit(tvm.tir.call_intrin("int32", "tir.vta.uop_push", 0, 1, dout.access_ptr("rw", "int32"), 0, 0, @@ -658,7 +659,7 @@ def InjectConv2DTransposeSkip(): tpl = (args[0], 1, args[1], 1, args[2], 1, args[3], 1, 0, 1, 0, env.BLOCK_OUT) inner = tvm.tir.AttrStmt( [dout, res_buffer], 'buffer_bind_scope', - tvm.tir.call_intrin('handle', 'tvm_tuple', *tpl), inner) + tvm.tir.call_intrin('handle', 'tir.tvm_tuple', *tpl), inner) return inner else: conv_call, data_call, kernel_call = calls[-3:] @@ -678,7 +679,7 @@ def InjectConv2DTransposeSkip(): irb.scope_attr( dev.vta_axis, "coproc_scope", dev.get_task_qid(dev.QID_COMPUTE)) irb.scope_attr(dev.vta_axis, "coproc_uop_scope", dev.vta_push_uop) - irb.emit(tvm.tir.call_extern("int32", "VTAUopPush", + irb.emit(tvm.tir.call_intrin("int32", "tir.vta.uop_push", 0, 0, dout.access_ptr("rw", "int32"), dinp.access_ptr("r", "int32"), @@ -691,19 +692,19 @@ def InjectConv2DTransposeSkip(): 1, 0, 1, 0, env.BLOCK_OUT) inner = tvm.tir.AttrStmt( [dout, res_tensor], 'buffer_bind_scope', - tvm.tir.call_intrin('handle', 'tvm_tuple', *tpl), inner) + tvm.tir.call_intrin('handle', 'tir.tvm_tuple', *tpl), inner) args = kernel_call.indices tpl = (args[0], 1, args[1], 1, args[2], 1, args[3], 1, 0, env.BLOCK_OUT, 0, env.BLOCK_IN) inner = tvm.tir.AttrStmt( [dwgt, kernel_tensor], 'buffer_bind_scope', - tvm.tir.call_intrin('handle', 'tvm_tuple', *tpl), inner) + tvm.tir.call_intrin('handle', 'tir.tvm_tuple', *tpl), inner) args = data_call.indices tpl = (args[0], 1, args[1], 1, args[2], 1, args[3], 1, 0, 1, 0, env.BLOCK_IN) inner = tvm.tir.AttrStmt( [dinp, pad_data_tensor], 'buffer_bind_scope', - tvm.tir.call_intrin('handle', 'tvm_tuple', *tpl), inner) + tvm.tir.call_intrin('handle', 'tir.tvm_tuple', *tpl), inner) return inner return None @@ -833,11 +834,11 @@ def InjectALUIntrin(): lhs = loop_body.value.a rhs = loop_body.value.b elif isinstance(loop_body.value, tvm.tir.Call): - if loop_body.value.name == 'shift_left': + if loop_body.value.op.name == 'tir.shift_left': alu_opcode = env.dev.ALU_OPCODE_SHR lhs = loop_body.value.args[0] rhs = analyzer.simplify(-loop_body.value.args[1]) - elif loop_body.value.name == 'shift_right': + elif loop_body.value.op.name == 'tir.shift_right': alu_opcode = env.dev.ALU_OPCODE_SHR lhs = loop_body.value.args[0] rhs = loop_body.value.args[1] @@ -942,8 +943,8 @@ def InjectALUIntrin(): "int32", "VTAUopLoopBegin", extent, dst_coeff[idx], src_coeff[idx], 0)) use_imm = int(use_imm) - irb.emit(tvm.tir.call_extern( - "int32", "VTAUopPush", + irb.emit(tvm.tir.call_intrin( + "int32", "tir.vta.uop_push", 1, 0, dst_coeff[len(dst_coeff)-1], src_coeff[len(src_coeff)-1], -- 2.7.4