From: Minmin Sun (孙敏敏)
Date: Sat, 9 Nov 2019 21:01:36 +0000 (+0800)
Subject: Auto TensorCore CodeGen (#4234)
X-Git-Tag: upstream/0.7.0~1682
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=d64bf6b51b1d720cd437bd356bd8aed9ef655575;p=platform%2Fupstream%2Ftvm.git

Auto TensorCore CodeGen (#4234)

* Add Auto TensorCore Unit Test
* Rebase to tvm master branch & Add auto tensor core
* Code Refine
* Add tensor core switch by pragma
* Add pragma in tensor core example code
* Get real tile size to replace hard coded 16
* Support more than 2 dimensions (e.g. batch matmul) for buffer bind scope
* Support batch matmul
* Move cuda env check to tensor_core.cc
* Code refine for tensor_core.cc
* Refine comments
* Some refinements of code and comment
* Update TensorCore UT to pass the CPU test
* Remove redundant code
* Matmul's storage align for different layout
* Add support for different position of type cast
* Add formal tutorial for auto tensorcore codegen
* Move tensorcore check up to tutorial code
* Code and doc refine
* Comment out tune_and_evaluate in tutorial
* Fix cpplint error
---

diff --git a/include/tvm/ir.h b/include/tvm/ir.h
index b6c3028..aca45f4 100644
--- a/include/tvm/ir.h
+++ b/include/tvm/ir.h
@@ -1248,6 +1248,8 @@ constexpr const char* reduce_scope = "reduce_scope";
 constexpr const char* pragma_scope_prefix = "pragma_";
 /*! \brief Import llvm source or file into the final code gen module */
 constexpr const char* pragma_import_llvm = "pragma_import_llvm";
+/*! \brief Try to modify the AST to support Tensor Core */
+constexpr const char* pragma_tensor_core = "pragma_tensor_core";
 /*!
  * \brief Mark of prefetch scope, value=offset,
  * run prefetch of Tensor on the current loop scope
diff --git a/include/tvm/ir_pass.h b/include/tvm/ir_pass.h
index 842c6af..76d7d61 100644
--- a/include/tvm/ir_pass.h
+++ b/include/tvm/ir_pass.h
@@ -206,6 +206,20 @@ Stmt StorageFlatten(Stmt stmt,
                     Map<Tensor, Buffer> extern_buffer,
                     int cache_line_size,
                     bool create_bound_attribute = false);
+
+/*!
+ * \brief Try to modify the AST to support TensorCore
+ *
+ * \param stmt The stmt to be transformed.
+ * \param schedule The original schedule.
+ * \param extern_buffer Map that specifies the external
+ *        buffer assignment of inputs and outputs.
+ * \return Transformed stmt.
+ */
+Stmt RewriteForTensorCore(Stmt stmt,
+                          Schedule schedule,
+                          Map<Tensor, Buffer> extern_buffer);
+
 /*!
  * \brief Verify if there is any argument bound to compact buffer.
* diff --git a/python/tvm/build_module.py b/python/tvm/build_module.py index fe2f641..217318e 100644 --- a/python/tvm/build_module.py +++ b/python/tvm/build_module.py @@ -387,6 +387,7 @@ def lower(sch, binds, arg_list = get_binds(args, compact, binds) # Phase 1 + stmt = ir_pass.RewriteForTensorCore(stmt, sch, binds) stmt = ir_pass.StorageFlatten(stmt, binds, 64, cfg.instrument_bound_checkers) stmt = ir_pass.CanonicalSimplify(stmt) for f in lower_phase1: diff --git a/src/api/api_pass.cc b/src/api/api_pass.cc index d7f621f..6994e2d 100644 --- a/src/api/api_pass.cc +++ b/src/api/api_pass.cc @@ -94,6 +94,12 @@ TVM_REGISTER_API("ir_pass.StorageFlatten") } }); +TVM_REGISTER_API("ir_pass.RewriteForTensorCore") +.set_body_typed&)> + ([](const Stmt& stmt, const Schedule& schedule, const Map& extern_buffer) { + return RewriteForTensorCore(stmt, schedule, extern_buffer); + }); + TVM_REGISTER_API("ir_pass.AttrsEqual") .set_body_typed([](const NodeRef& lhs, const NodeRef& rhs) { return AttrsEqual()(lhs, rhs); diff --git a/src/pass/tensor_core.cc b/src/pass/tensor_core.cc new file mode 100644 index 0000000..183ed3b --- /dev/null +++ b/src/pass/tensor_core.cc @@ -0,0 +1,1202 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file tensor_core.cc + */ +// IR Passes for TensorCore CodeGen +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "ir_util.h" +#include "../arithmetic/compute_expr.h" +#include "../runtime/thread_storage_scope.h" + +namespace tvm { +namespace ir { + +using runtime::StorageRank; +using runtime::StorageScope; +using runtime::ThreadScope; +using intrinsic::tvm_address_of; + +struct Tile { + int m{-1}; + int n{-1}; + int k{-1}; +}; + +std::string simplify_name(std::string input) { + auto pos = input.find("."); + if (pos != std::string::npos) { + return input.substr(0, pos); + } else { + return input; + } +} + +Expr unpack_type_cast(const Expr &input, const Type &target_type) { + auto cast = input.as(); + if (cast == nullptr) { + return input; + } else if (cast->type == target_type) { + return cast->value; + } + return Expr(); +} + +// MMAMatcher matches C = Cast(A)*Cast(B)+C, +// where A & B are fp16/int8 local buffers, +// and C is fp32/int32 local buffer. 
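+// Editor's illustration (not part of the pass itself): the pattern above
+// corresponds to a provide statement of the form
+//   C[i, j] = C[i, j] + cast<float32>(A[i, k]) * cast<float32>(B[k, j])
+// (or the int8/int32 equivalent), where A, B and C all live in "local"
+// storage scope and the statement sits under a loop scope annotated with the
+// "tensor_core" pragma. The index names i, j, k here are only illustrative.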
+class MMAMatcher: public IRVisitor { + public: + explicit MMAMatcher(Map extern_buffer) { + for (auto kv : extern_buffer) { + BufferInfo bi; + bi.name = kv.second->name; + bi.dtype = kv.second->dtype; + bi.external = true; + buf_map_[TensorKey{kv.first->op, kv.first->value_index}] = bi; + } + } + using IRVisitor::Visit_; + + void Visit_(const AttrStmt* op) final { + if (op->attr_key == attr::pragma_tensor_core) { + tensor_core_on_ = true; + IRVisitor::Visit_(op); + } else if (op->attr_key == attr::realize_scope) { + storage_scope_[op->node.get()] = op->value.as()->value; + Visit(op->body); + } else { + IRVisitor::Visit_(op); + } + } + + void Visit_(const Provide* op) final { + IRVisitor::Visit_(op); + auto it = buf_map_.find(TensorKey{op->func, op->value_index}); + if (it == buf_map_.end()) { + return; + } + const BufferInfo& bi = it->second; + if (bi.released) { + return; + } + if (tensor_core_on_ && mma_sync_match_(op, bi)) { + matched_ = true; + } + } + + void Visit_(const Realize* op) final { + TensorKey key{op->func, op->value_index}; + if (buf_map_.count(key)) { + if (!buf_map_.at(key).external) { + return; + } + Visit(op->body); + } else { + BufferInfo bi; + bi.name = key.GetName(); + bi.dtype = op->type; + buf_map_[key] = bi; + Visit(op->body); + buf_map_[key].released = true; + } + } + + inline bool Matched() const {return matched_;} + + friend class ScheduleAnalyser; + friend class BufferAnalyser; + + private: + struct BufferInfo { + std::string name; + Type dtype; + bool external{false}; + bool released{false}; + bool same_as(const BufferInfo &bi) { + if (this->dtype != bi.dtype) return false; + if (this->name != bi.name) return false; + if (this->external != bi.external) return false; + if (this->released != bi.released) return false; + return true; + } + }; + + // Check whether the storage scope is local + bool check_local_buffer_(const Call* op, BufferInfo* bi) { + if (op->call_type == Call::Halide) { + auto it = storage_scope_.find(op->func.get()); + if (it == storage_scope_.end()) { + return false; + } + const std::string& strkey = it->second; + if (strkey != "local") { + return false; + } + auto it1 = buf_map_.find(TensorKey{op->func, op->value_index}); + if (it1 == buf_map_.end()) { + return false; + } + *bi = it1->second; + if (bi->released) { + return false; + } + return true; + } + return false; + } + + // Do the pattern matching + bool mma_sync_match_(const Provide* op, BufferInfo store_buffer) { + auto* add = op->value.as(); + if (add == nullptr) { + return false; + } + + auto* load_c = add->a.as(); + BufferInfo buffer_c; + if (!check_local_buffer_(load_c, &buffer_c) + || !buffer_c.same_as(store_buffer) + || !(buffer_c.dtype == Float(32) || + buffer_c.dtype == Int(32))) { + return false; + } + + auto mul = unpack_type_cast(add->b, buffer_c.dtype).as(); + if (mul == nullptr) { + return false; + } + + auto load_a_expr = unpack_type_cast(mul->a, buffer_c.dtype); + auto load_a = load_a_expr.as(); + BufferInfo buffer_a; + if (!check_local_buffer_(load_a, &buffer_a) + || !(buffer_a.dtype == Float(16) || + buffer_a.dtype == Int(8))) { + return false; + } + + auto load_b_expr = unpack_type_cast(mul->b, buffer_c.dtype); + auto load_b = load_b_expr.as(); + BufferInfo buffer_b; + if (!check_local_buffer_(load_b, &buffer_b) + || !(buffer_b.dtype == Float(16) || + buffer_b.dtype == Int(8))) { + return false; + } + + frag_reg_.insert(buffer_c.name); + frag_reg_.insert(buffer_a.name); + frag_reg_.insert(buffer_b.name); + buf_name_.insert(std::make_pair(load_a, buffer_a.name)); + 
buf_name_.insert(std::make_pair(load_b, buffer_b.name)); + mma_sync_.insert(std::make_pair(op, + Array{load_a_expr, load_b_expr, add->a})); + + return true; + } + + std::unordered_map buf_map_; + std::unordered_map storage_scope_; + std::unordered_map> mma_sync_; + std::unordered_map buf_name_; + std::unordered_set frag_reg_; + bool matched_{false}; + bool tensor_core_on_{false}; +}; + +// BodyVisitor visits the body stmt of original ComputeOp +// to get the access indices of input matrices, +// if it is recognized as matrix multiply. +class BodyVisitor : public IRVisitor { + public: + BodyVisitor() {} + using IRVisitor::Visit_; + + void Visit_(const Reduce* op) final { + auto* comm_add = op->combiner->result[0].as(); + if (comm_add == nullptr || op->combiner->result.size() > 1) { + return; + } + for (Expr source : op->source) { + auto mul_0 = unpack_type_cast(source, Float(32)).as(); + auto mul_1 = unpack_type_cast(source, Int(32)).as(); + if (mul_0 == nullptr && mul_1 == nullptr) { + continue; + } + + tensorcore_candidate_ = true; + IRVisitor::Visit(source); + } + } + + void Visit_(const Call* op) final { + IRVisitor::Visit_(op); + args_.insert(std::make_pair(op->name, op->args)); + } + + friend class ScheduleAnalyser; + + private: + std::unordered_map> args_; + bool tensorcore_candidate_{false}; +}; + +// ScheduleAnalyser figures out matrix_a/matrix_b and row_major/col_major +class ScheduleAnalyser { + public: + explicit ScheduleAnalyser(const MMAMatcher &mma_matcher) + : mma_sync_(mma_matcher.mma_sync_), + buf_name_(mma_matcher.buf_name_) {} + + bool MatrixIdentify(Schedule schedule) { + // TODO(minmin): handle the case where MatMul is not the output stage + for (Operation output : schedule->outputs) { + const ComputeOpNode* compute = output.as(); + if (compute == nullptr) { + // Not a ComputeOp + continue; + } + auto axis = compute->axis; + auto reduce_axis = compute->reduce_axis; + if (axis.size() < 2 || reduce_axis.size() != 1) { + continue; + } + const Variable* axis_var[2]; + const Variable* reduce_axis_var; + axis_var[0] = axis[axis.size()-2]->var.as(); + axis_var[1] = axis[axis.size()-1]->var.as(); + reduce_axis_var = reduce_axis[0]->var.as(); + + BodyVisitor body_visitor; + for (Expr expr : compute->body) { + body_visitor.Visit(expr); + } + if (!body_visitor.tensorcore_candidate_) { + continue; + } + for (auto iter : body_visitor.args_) { + auto name = iter.first; + auto args = iter.second; + if (args.size() < 2) { + continue; + } + const Variable* var0 = args[args.size() - 2].as(); + const Variable* var1 = args[args.size() - 1].as(); + if (var0 == nullptr || var1 == nullptr) { + continue; + } + std::string matrix_abc, major; + if (var0 == reduce_axis_var && var1 == axis_var[1]) { + matrix_abc = "matrix_a"; + major = "col_major"; + } else if (var0 == reduce_axis_var && var1 == axis_var[0]) { + matrix_abc = "matrix_b"; + major = "row_major"; + } else if (var0 == axis_var[1] && var1 == reduce_axis_var) { + matrix_abc = "matrix_a"; + major = "row_major"; + } else if (var0 == axis_var[0] && var1 == reduce_axis_var) { + matrix_abc = "matrix_b"; + major = "col_major"; + } + matrix_abc_.insert(std::make_pair(name, matrix_abc)); + matrix_major_.insert(std::make_pair(name, major)); + } + matrix_abc_.insert(std::make_pair(compute->name, "accumulator")); + matrix_major_.insert(std::make_pair(compute->name, "col_major")); + } + + for (auto &mma_sync : mma_sync_) { + auto &operands = mma_sync.second; + auto* load_a = operands[0].as(); + auto* load_b = operands[1].as(); + auto input0 = 
simplify_name(buf_name_.find(load_a)->second); + auto input1 = simplify_name(buf_name_.find(load_b)->second); + auto it0 = matrix_abc_.find(input0); + auto it1 = matrix_abc_.find(input1); + + if (it0 == matrix_abc_.end() || it1 == matrix_abc_.end()) { + return false; + } + if (it0->second == "matrix_a" && it1->second == "matrix_b") { + return true; + } else if (it0->second == "matrix_b" && it1->second == "matrix_a") { + mma_sync.second = Array{operands[1], operands[0], operands[2]}; + } else { + return false; + } + } + return true; + } + + friend class BufferAnalyser; + friend class TensorCoreIRMutator; + + private: + std::unordered_map matrix_abc_; + std::unordered_map matrix_major_; + std::unordered_map> mma_sync_; + std::unordered_map buf_name_; +}; + +// IndexVisitor visits access index of fragment +// to record variable for loop scaling +class IndexVisitor : public IRVisitor { + public: + IndexVisitor() {} + using IRVisitor::Visit_; + + void Visit_(const Variable* op) final { + loop_scaling_.insert(std::make_pair(op, scaling_factor_)); + } + + friend class BufferAnalyser; + friend class TensorCoreIRMutator; + + private: + std::unordered_map loop_scaling_; + unsigned scaling_factor_{0}; +}; + +// BufferAnalyser gets buffer info, +// e.g. thread tile and warp tile, for TensorCore CodeGen +class BufferAnalyser : public IRVisitor { + public: + explicit BufferAnalyser(Map extern_buffer, + const ScheduleAnalyser &schedule_analyser, + const MMAMatcher &mma_matcher) + : matrix_abc_(schedule_analyser.matrix_abc_), + matrix_major_(schedule_analyser.matrix_major_), + frag_reg_(mma_matcher.frag_reg_) { + for (auto kv : extern_buffer) { + BufferInfo bi; + bi.name = kv.second->name; + bi.dtype = kv.second->dtype; + bi.strides = kv.second->strides; + bi.shape = kv.second->shape; + bi.external = true; + buf_map_[TensorKey{kv.first->op, kv.first->value_index}] = bi; + } + } + using IRVisitor::Visit_; + + void Visit_(const AttrStmt* op) final { + if (op->attr_key == attr::thread_extent) { + if (const IntImm* value = op->value.as()) { + thread_extent_.insert( + std::make_pair( + op->node.as()->var->name_hint, + value->value)); + } + IRVisitor::Visit_(op); + } else if (op->attr_key == attr::realize_scope) { + storage_scope_[op->node.get()] = op->value.as()->value; + Visit(op->body); + } else if (op->attr_key == attr::buffer_dim_align) { + Tensor tensor = Downcast(op->node); + const Call* tuple = op->value.as(); + CHECK(tuple && tuple->is_intrinsic(intrinsic::tvm_tuple)); + auto& vinfo = dim_align_[TensorKey{tensor->op, tensor->value_index}]; + size_t dim = tuple->args[0].as()->value; + if (dim >= vinfo.size()) { + vinfo.resize(dim + 1); + } + vinfo[dim].align_factor = tuple->args[1].as()->value; + vinfo[dim].align_offset = tuple->args[2].as()->value; + Visit(op->body); + } else { + IRVisitor::Visit_(op); + } + } + + void Visit_(const Provide* op) final { + IRVisitor::Visit_(op); + TensorKey key{op->func, op->value_index}; + auto it = buf_map_.find(key); + CHECK(it != buf_map_.end()) + << "Cannot find allocated buffer for " << key.f; + const BufferInfo& bi = it->second; + CHECK(!bi.released) + << "Read a buffer that is already out of scope"; + + if (matrix_abc_.count(key.GetName())) { + if (bi.shape.size() < 2) { + invalid_ = true; + return; + } + for (auto i = bi.shape.size() - 1; i + 2 >= bi.shape.size(); --i) { + const IntImm* shape = bi.shape[i].as(); + if (shape == nullptr || shape->value % 16 != 0) { + invalid_ = true; + return; + } + } + } + + Array strides; + if (bi.strides.size() > 0) { + strides 
= bi.strides; + } else { + for (size_t i = 1; i < bi.shape.size(); ++i) { + Expr stride = IntImm::make(Int(32), 1); + for (size_t j = bi.shape.size() - 1; j >= i; --j) { + stride = Mul::make(stride, bi.shape[j]); + } + strides.push_back(stride); + } + strides.push_back(make_const(Int(32), 1)); + } + strides_.insert(std::make_pair(key.GetName(), strides)); + + if (frag_reg_.count(bi.name)) { + Expr dst = Call::make(bi.dtype, + bi.name, + op->args, + Call::Halide, + op->func, + 0); + frag_load_.insert(std::make_pair(op, dst)); + + auto rel_index = bi.RelIndex(op->args); + if (op->args.size() < 2) { + invalid_ = true; + return; + } + std::vector tile_size; + for (auto i = op->args.size() - 1; i + 2 >= op->args.size(); --i) { + index_visitor.scaling_factor_ = 16; + if (const IntImm* shape = bi.shape[i].as()) { + tile_size.push_back(shape->value); + index_visitor.scaling_factor_ = shape->value; + } else { + invalid_ = true; + return; + } + auto index = rel_index[i]; + auto simplified_index = ir::Simplify(index); + index_visitor.Visit(simplified_index); + } + + std::string input_name = simplify_name(bi.name); + auto it = matrix_abc_.find(input_name); + auto it2 = matrix_major_.find(input_name); + bool ret = true; + if (it != matrix_abc_.end() && it2 != matrix_major_.end()) { + if (it->second == "matrix_a" && it2->second == "col_major") { + ret &= assign_or_check_(&thread_tile_.m, tile_size[0]); + ret &= assign_or_check_(&thread_tile_.k, tile_size[1]); + } + if (it->second == "matrix_a" && it2->second == "row_major") { + ret &= assign_or_check_(&thread_tile_.k, tile_size[0]); + ret &= assign_or_check_(&thread_tile_.m, tile_size[1]); + } + if (it->second == "matrix_b" && it2->second == "col_major") { + ret &= assign_or_check_(&thread_tile_.k, tile_size[0]); + ret &= assign_or_check_(&thread_tile_.n, tile_size[1]); + } + if (it->second == "matrix_b" && it2->second == "row_major") { + ret &= assign_or_check_(&thread_tile_.n, tile_size[0]); + ret &= assign_or_check_(&thread_tile_.k, tile_size[1]); + } + if (it->second == "accumulator") { + ret &= assign_or_check_(&thread_tile_.m, tile_size[0]); + ret &= assign_or_check_(&thread_tile_.n, tile_size[1]); + } + if (!ret) { + invalid_ = true; + return; + } + } + } + + const Call* value = op->value.as(); + if (value != nullptr && frag_reg_.count(value->name)) { + Expr dst = Call::make(bi.dtype, + bi.name, + op->args, + Call::Halide, + op->func, + 0); + frag_store_.insert(std::make_pair(op, dst)); + } + } + + void Visit_(const Call* op) final { + IRVisitor::Visit_(op); + if (op->call_type == Call::Halide) { + TensorKey key{op->func, op->value_index}; + auto it = buf_map_.find(key); + CHECK(it != buf_map_.end()) + << "Cannot find allocated buffer for " << key.f; + const BufferInfo& bi = it->second; + CHECK(!bi.released) + << "Read a buffer that is already out of scope"; + + if (matrix_abc_.count(op->name)) { + if (bi.shape.size() < 2) { + invalid_ = true; + return; + } + for (auto i = bi.shape.size() - 1; i + 2 >= bi.shape.size(); --i) { + const IntImm* shape = bi.shape[i].as(); + if (shape == nullptr || shape->value % 16 != 0) { + invalid_ = true; + return; + } + } + } + + Array strides; + if (bi.strides.size() > 0) { + strides = bi.strides; + } else { + for (size_t i = 1; i < bi.shape.size(); ++i) { + Expr stride = IntImm::make(Int(32), 1); + for (size_t j = bi.shape.size() - 1; j >= i; --j) { + stride = Mul::make(stride, bi.shape[j]); + } + strides.push_back(stride); + } + strides.push_back(make_const(Int(32), 1)); + } + 
strides_.insert(std::make_pair(key.GetName(), strides)); + + if (!frag_reg_.count(bi.name)) { + return; + } + + auto rel_index = bi.RelIndex(op->args); + if (op->args.size() < 2) { + invalid_ = true; + return; + } + for (auto i = op->args.size() - 1; i + 2 >= op->args.size(); --i) { + index_visitor.scaling_factor_ = 16; + if (const IntImm* shape = bi.shape[i].as()) { + index_visitor.scaling_factor_ = shape->value; + } + auto index = rel_index[i]; + auto simplified_index = ir::Simplify(index); + index_visitor.Visit(simplified_index); + } + } + } + + void Visit_(const Realize* op) final { + TensorKey key{op->func, op->value_index}; + if (buf_map_.count(key)) { + CHECK(buf_map_.at(key).external); + Visit(op->body); + } else { + // create a buffer entry + BufferInfo bi; + + bi.bounds = op->bounds; + Array shape; + for (auto r : bi.bounds) { + shape.push_back(r->extent); + } + + Array strides; + if (dim_align_.count(key) != 0 && shape.size() != 0) { + std::vector rstrides; + const std::vector& avec = dim_align_[key]; + int first_dim = 0; + Expr stride = make_const(shape[first_dim].type(), 1); + for (size_t i = shape.size(); i != 0; --i) { + size_t dim = i - 1; + if (dim < avec.size() && avec[dim].align_factor != 0) { + Expr factor = make_const(stride.type(), avec[dim].align_factor); + Expr offset = make_const(stride.type(), avec[dim].align_offset); + stride = stride + \ + indexmod(factor + offset - indexmod(stride, factor), factor); + stride = ir::Simplify(stride); + } + rstrides.push_back(stride); + stride = stride * shape[dim]; + } + strides = Array(rstrides.rbegin(), rstrides.rend()); + } + + bi.name = key.GetName(); + bi.dtype = op->type; + bi.strides = strides; + bi.shape = shape; + + buf_map_[key] = bi; + Visit(op->body); + buf_map_[key].released = true; + } + } + + // Derive warp tile from thread tile, + // and check whether it is qualified for TensorCore. 
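+// Editor's worked example (hypothetical values): with a per-thread tile of
+// m=8, n=1, k=16, a threadIdx.x extent of 2, and a threadIdx.y extent that is
+// a multiple of 32/2 = 16, the derivation below gives
+//   warp_tile.m = 2 * 8   = 16
+//   warp_tile.n = 16 * 1  = 16   (warp_threads_y_ = 32 / warp_threads_x)
+//   warp_tile.k = thread_tile.k = 16
+// i.e. the supported 16x16x16 wmma shape; 32x8x16 and 8x32x16 are the other
+// accepted warp tiles.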
+ bool QualifiedForTensorCore() { + if (invalid_) { + return false; + } + auto itx = thread_extent_.find("threadIdx.x"); + if (itx == thread_extent_.end()) { + return false; + } + int warp_threads_x = itx->second; + warp_tile_.m = warp_threads_x * thread_tile_.m; + warp_threads_y_ = 32 / warp_threads_x; + auto ity = thread_extent_.find("threadIdx.y"); + if (ity == thread_extent_.end()) { + return false; + } + if (ity->second < warp_threads_y_ || ity->second % warp_threads_y_ != 0) { + return false; + } + warp_tile_.n = warp_threads_y_ * thread_tile_.n; + warp_tile_.k = thread_tile_.k; + return supported_warp_tile_(); + } + + friend class TensorCoreIRMutator; + + private: + struct DimAlignInfo { + int align_factor{0}; + int align_offset{0}; + }; + + struct BufferInfo { + std::string name; + Type dtype; + Array strides; + Array shape; + Region bounds; + bool external{false}; + bool released{false}; + inline Array RelIndex(Array args) const { + if (bounds.size() != 0) { + Array index; + CHECK_EQ(bounds.size(), args.size()); + for (size_t i = 0; i < bounds.size(); ++i) { + index.push_back(args[i] - bounds[i]->min); + } + return index; + } else { + return args; + } + } + }; + + bool assign_or_check_(int* dst, int src) { + if (*dst <= 0) { + *dst = src; + return true; + } + if (*dst == src) { + return true; + } + return false; + } + + bool supported_warp_tile_() { + if (warp_tile_.m == 16 && + warp_tile_.n == 16 && + warp_tile_.k == 16) { + return true; + } + if (warp_tile_.m == 8 && + warp_tile_.n == 32 && + warp_tile_.k == 16) { + return true; + } + if (warp_tile_.m == 32 && + warp_tile_.n == 8 && + warp_tile_.k == 16) { + return true; + } + return false; + } + + std::unordered_map buf_map_; + std::unordered_map > dim_align_; + std::unordered_map storage_scope_; + std::unordered_map matrix_abc_; + std::unordered_map matrix_major_; + std::unordered_set frag_reg_; + std::unordered_map> strides_; + std::unordered_map frag_load_; + std::unordered_map frag_store_; + std::unordered_map thread_extent_; + IndexVisitor index_visitor; + Tile warp_tile_; + Tile thread_tile_; + int warp_threads_y_{-1}; + bool invalid_{false}; +}; + +// ThreadIdxMutator does the thread index unification inside a warp +class ThreadIdxMutator : public IRMutator { + public: + explicit ThreadIdxMutator(Expr warp_y): warp_y_(warp_y) {} + + Expr Mutate_(const Variable* op, const Expr& olde) final { + Expr expr = IRMutator::Mutate_(op, olde); + op = expr.as(); + if (op != nullptr) { + if (op->name_hint == "threadIdx.x") { + Expr zero = IntImm::make(Int(32), 0); + return zero; + } + if (op->name_hint == "threadIdx.y") { + Expr div = Div::make(expr, warp_y_); + Expr mul = Mul::make(div, warp_y_); + return mul; + } + } + return expr; + } + + private: + Expr warp_y_; +}; + +// TensorCoreIRMutator mutates the AST for TensorCore CodeGen +// based on tensor core intrinsics +class TensorCoreIRMutator : public IRMutator { + public: + explicit TensorCoreIRMutator(const ScheduleAnalyser &schedule_analyser, + const BufferAnalyser &buffer_analyser) + : matrix_abc_(schedule_analyser.matrix_abc_), + matrix_major_(schedule_analyser.matrix_major_), + mma_sync_(schedule_analyser.mma_sync_), + strides_(buffer_analyser.strides_), + frag_reg_(buffer_analyser.frag_reg_), + loop_scaling_(buffer_analyser.index_visitor.loop_scaling_), + frag_load_(buffer_analyser.frag_load_), + frag_store_(buffer_analyser.frag_store_), + warp_tile_(buffer_analyser.warp_tile_), + warp_threads_y_(buffer_analyser.warp_threads_y_) {} + + Stmt Mutate_(const Realize* op, const 
Stmt& s) final { + TensorKey key{op->func, op->value_index}; + bounds_[key] = op->bounds; + Stmt stmt = IRMutator::Mutate_(op, s); + op = stmt.as(); + if (op != nullptr) { + if (!frag_reg_.count(key.GetName())) { + return stmt; + } + + auto new_extents = get_tile_size_(simplify_name(key.GetName())); + + Region new_bounds; + for (size_t i = 0; i < op->bounds.size() - 2; ++i) { + new_bounds.push_back(op->bounds[i]); + } + CHECK_GE(op->bounds.size(), 2) + << "Less than 2 dimensions for matrix " << key.GetName(); + new_bounds.push_back(Range::make_by_min_extent( + op->bounds[op->bounds.size() - 2]->min, new_extents[0])); + new_bounds.push_back(Range::make_by_min_extent( + op->bounds[op->bounds.size() - 1]->min, new_extents[1])); + + return Realize::make(op->func, op->value_index, + op->type, new_bounds, + op->condition, op->body); + } + return stmt; + } + + Stmt Mutate_(const AttrStmt* op, const Stmt& s) final { + Stmt stmt = IRMutator::Mutate_(op, s); + if (op->attr_key == attr::realize_scope) { + auto node = op->node.as(); + if (node != nullptr) { + if (!frag_reg_.count(node->name)) { + return stmt; + } + + auto it = matrix_abc_.find(simplify_name(node->name)); + CHECK(it != matrix_abc_.end()) + << "Cannot find matrix info for " << node->name; + auto matrix_abc = "wmma." + it->second; + Stmt body = Mutate(op->body); + return AttrStmt::make(op->node, + op->attr_key, + matrix_abc, + body); + } + } + return stmt; + } + + Stmt Mutate_(const Provide* op, const Stmt& s) final { + Stmt stmt = IRMutator::Mutate_(op, s); + auto it = mma_sync_.find(op); + if (it != mma_sync_.end()) { + const auto &operands = it->second; + Expr a = operands[0]; + auto ca = a.as(); + Expr b = operands[1]; + auto cb = b.as(); + Expr c = operands[2]; + auto cc = c.as(); + + NodePtr buffer_node_a = make_node(); + NodePtr buffer_node_b = make_node(); + NodePtr buffer_node_c = make_node(); + + auto mma_sync_call = + [&buffer_node_a, &buffer_node_b] + (const Buffer &buffer) { + Buffer buffer_a(buffer_node_a); + Buffer buffer_b(buffer_node_b); + return Evaluate::make( + Call::make(Handle(), + intrinsic::tvm_mma_sync, + {buffer->data, buffer->elem_offset, + buffer_a->data, buffer_a->elem_offset, + buffer_b->data, buffer_b->elem_offset, + buffer->data, buffer->elem_offset}, + Call::Intrinsic)); + }; + + auto call_add_c = + [this, &cc, &buffer_node_c, &mma_sync_call](const Buffer &buffer) { + return add_buffer_bind_scope_(cc, buffer_node_c, + TensorKey{cc->func, cc->value_index}, mma_sync_call, cc->type); + }; + + auto call_add_b = + [this, &cb, &buffer_node_b, &call_add_c](const Buffer &buffer) { + return add_buffer_bind_scope_(cb, buffer_node_b, + TensorKey{cb->func, cb->value_index}, call_add_c, cb->type); + }; + + return add_buffer_bind_scope_(ca, buffer_node_a, + TensorKey{ca->func, ca->value_index}, call_add_b, ca->type); + } + + auto it2 = frag_load_.find(op); + if (it2 != frag_load_.end()) { + Expr dst = it2->second; + if (op->value.as() != nullptr || + op->value.as() != nullptr) { + auto call = dst.as(); + + auto fill_fragment_call = + [this, &op](const Buffer &buffer) { + return Evaluate::make( + Call::make(Handle(), + intrinsic::tvm_fill_fragment, + {buffer->data, + warp_tile_.m, warp_tile_.n, warp_tile_.k, + buffer->elem_offset, op->value}, + Call::Intrinsic)); + }; + + NodePtr buffer_node = make_node(); + return add_buffer_bind_scope_(call, buffer_node, + TensorKey{call->func, call->value_index}, + fill_fragment_call, call->type); + } + + const Call* value = op->value.as(); + CHECK(value != nullptr) + << "Can only 
load fragment from a buffer"; + + auto it = strides_.find(value->name); + CHECK(it != strides_.end()) + << "Cannot find stride for " << value->name; + auto strides = it->second; + CHECK_GE(strides.size(), 2); + Expr stride = strides[strides.size()-2]; + + // thread index unification inside a warp + Expr warp_y = IntImm::make(Int(32), warp_threads_y_); + ThreadIdxMutator thread_idx_mutator(warp_y); + Expr mutated_value = thread_idx_mutator.Mutate(op->value); + Expr src = Call::make(value->type, + "&", + {mutated_value}, + Call::Extern); + + auto call = dst.as(); + Expr matrix_major; + auto iter2 = matrix_major_.find(simplify_name(call->name)); + CHECK(iter2 != matrix_major_.end()) + << "Can not determine matrix major for " << call->name; + if (iter2->second == "col_major") { + matrix_major = StringImm::make("col_major"); + } else if (iter2->second == "row_major") { + matrix_major = StringImm::make("row_major"); + } else { + LOG(FATAL) << "invalid matrix major for " << call->name; + } + + auto load_matrix_call = + [this, &src, &stride, &matrix_major](const Buffer &buffer) { + return Evaluate::make( + Call::make(Handle(), + intrinsic::tvm_load_matrix_sync, + {buffer->data, + warp_tile_.m, warp_tile_.n, warp_tile_.k, + buffer->elem_offset, src, stride, matrix_major}, + Call::Intrinsic)); + }; + + NodePtr buffer_node = make_node(); + return add_buffer_bind_scope_(call, buffer_node, + TensorKey{op->func, op->value_index}, + load_matrix_call, call->type); + } + + auto it3 = frag_store_.find(op); + if (it3 != frag_store_.end()) { + TensorKey key{op->func, op->value_index}; + auto it = strides_.find(key.GetName()); + CHECK(it != strides_.end()) + << "Cannot find stride for " << key.GetName(); + auto strides = it->second; + CHECK_GE(strides.size(), 2); + Expr stride = strides[strides.size()-2]; + + Expr dst = it3->second; + // thread index unification inside a warp + Expr warp_y = IntImm::make(Int(32), warp_threads_y_); + ThreadIdxMutator thread_idx_mutator(warp_y); + dst = thread_idx_mutator.Mutate(dst); + dst = Call::make(Handle(), + "&", + {dst}, + Call::Extern); + + auto call = op->value.as(); + + auto store_matrix_call = + [this, &dst, &stride](const Buffer &buffer) { + return Evaluate::make( + Call::make(Handle(), + intrinsic::tvm_store_matrix_sync, + {buffer->data, + warp_tile_.m, warp_tile_.n, warp_tile_.k, + buffer->elem_offset, dst, stride, + StringImm::make("col_major")}, + Call::Intrinsic)); + }; + + NodePtr buffer_node = make_node(); + return add_buffer_bind_scope_(call, buffer_node, + TensorKey{call->func, call->value_index}, + store_matrix_call, call->type); + } + + return stmt; + } + + Stmt Mutate_(const For* op, const Stmt& s) final { + Stmt stmt = IRMutator::Mutate_(op, s); + op = stmt.as(); + if (op != nullptr) { + auto it = loop_scaling_.find(op->loop_var.get()); + if (it != loop_scaling_.end()) { + int scale_factor = it->second; + int scaled_extent_value = 1; + if (const IntImm *ori_extent = op->extent.as()) { + int ori_extent_value = ori_extent->value; + scaled_extent_value = ori_extent_value / scale_factor; + } + Expr scaled_extent = make_const(op->extent.type(), scaled_extent_value); + stmt = For::make(op->loop_var, op->min, scaled_extent, op->for_type, + op->device_api, op->body); + } + } + return stmt; + } + + private: + Array get_tile_size_(const std::string &name) { + auto it = matrix_abc_.find(name); + auto it2 = matrix_major_.find(name); + CHECK(it != matrix_abc_.end() && it2 != matrix_major_.end()) + << "Cannot find matrix info for " << name; + Expr size0 = 
make_const(Int(32), 16); + Expr size1 = make_const(Int(32), 16); + if (it->second == "matrix_a" && it2->second == "col_major") { + size0 = make_const(Int(32), warp_tile_.k); + size1 = make_const(Int(32), warp_tile_.m); + } + if (it->second == "matrix_a" && it2->second == "row_major") { + size0 = make_const(Int(32), warp_tile_.m); + size1 = make_const(Int(32), warp_tile_.k); + } + if (it->second == "matrix_b" && it2->second == "row_major") { + size0 = make_const(Int(32), warp_tile_.k); + size1 = make_const(Int(32), warp_tile_.n); + } + if (it->second == "matrix_b" && it2->second == "col_major") { + size0 = make_const(Int(32), warp_tile_.n); + size1 = make_const(Int(32), warp_tile_.k); + } + if (it->second == "matrix_c") { + size0 = make_const(Int(32), warp_tile_.n); + size1 = make_const(Int(32), warp_tile_.m); + } + Array tile_size = {size0, size1}; + return tile_size; + } + + Stmt add_buffer_bind_scope_(const Call* call, + const NodePtr &buffer_node, const TensorKey &key, + const std::function &call_back, + DataType datatype) { + auto it = bounds_.find(key); + CHECK(it != bounds_.end()); + Array min_bound; + for (auto i : it->second) { + min_bound.push_back(i->min); + } + + CHECK_GE(it->second.size(), 2); + Array shape; + for (size_t i = 0; i < it->second.size() - 2; ++i) { + shape.push_back(it->second[i]->extent); + } + auto tile_size = get_tile_size_(simplify_name(call->name)); + shape.push_back(tile_size[0]); + shape.push_back(tile_size[1]); + + Array strides; + for (size_t i = 1; i < shape.size(); ++i) { + Expr stride = IntImm::make(Int(32), 1); + for (size_t j = shape.size() - 1; j >= i; --j) { + stride = Mul::make(stride, shape[j]); + } + strides.push_back(stride); + } + strides.push_back(make_const(Int(32), 1)); + + Expr elem_offset = IntImm::make(Int(32), 0); + CHECK_EQ(call->args.size(), min_bound.size()); + for (size_t i = 0; i < min_bound.size(); i++) { + elem_offset = Add::make( + elem_offset, Mul::make( + strides[i], Sub::make(call->args[i], min_bound[i]))); + } + + auto it2 = matrix_abc_.find(simplify_name(call->name)); + CHECK(it2 != matrix_abc_.end()) + << "Cannot find matrix info for " << call->name; + buffer_node->data = Variable::make(Handle(), call->name); + buffer_node->name = call->name; + buffer_node->scope = "wmma." 
+ it2->second; + buffer_node->dtype = datatype; + buffer_node->strides = strides; + buffer_node->shape = shape; + buffer_node->data_alignment = 1; + buffer_node->elem_offset = Simplify(elem_offset); + buffer_node->offset_factor = 1; + Buffer buffer(buffer_node); + + NodePtr tensor_node = make_node(); + tensor_node->value_index = key.value_index; + tensor_node->op = Downcast(key.f); + tensor_node->shape = shape; + tensor_node->dtype = datatype; + Tensor tensor(tensor_node); + + Array args; + for (size_t i = 0; i < call->args.size(); ++i) { + args.push_back(call->args[i]); + args.push_back(shape[i]); + } + auto tuple = Call::make(Handle(), + intrinsic::tvm_tuple, + args, + Call::Intrinsic); + Array node = {buffer, tensor}; + return AttrStmt::make(node, + "buffer_bind_scope", + tuple, + call_back(buffer)); + } + + std::unordered_map matrix_abc_; + std::unordered_map matrix_major_; + std::unordered_map> mma_sync_; + std::unordered_map> strides_; + std::unordered_set frag_reg_; + std::unordered_map loop_scaling_; + std::unordered_map frag_load_; + std::unordered_map frag_store_; + std::unordered_map bounds_; + Tile warp_tile_; + int warp_threads_y_{-1}; +}; + +Stmt RewriteForTensorCore(Stmt stmt, + Schedule schedule, + Map extern_buffer) { + // Check if current lower target is CUDA + auto target = tvm::Target::Current(true); + if (target.defined() && target->target_name != "cuda") { + return stmt; + } + + // Check if current runtime support GPU CUDA + TVMContext ctx{kDLGPU, 0}; + auto api = tvm::runtime::DeviceAPI::Get(ctx, true); + if (api == nullptr) { + return stmt; + } + + MMAMatcher mma_matcher(extern_buffer); + mma_matcher.Visit(stmt); + if (!mma_matcher.Matched()) { + return stmt; + } + + ScheduleAnalyser schedule_analyser(mma_matcher); + if (!schedule_analyser.MatrixIdentify(schedule)) { + return stmt; + } + + BufferAnalyser buffer_analyser(extern_buffer, + schedule_analyser, mma_matcher); + buffer_analyser.Visit(stmt); + if (!buffer_analyser.QualifiedForTensorCore()) { + return stmt; + } + + return TensorCoreIRMutator(schedule_analyser, buffer_analyser).Mutate(stmt); +} + +} // namespace ir +} // namespace tvm diff --git a/tests/python/unittest/test_pass_rewrite_for_tensor_core.py b/tests/python/unittest/test_pass_rewrite_for_tensor_core.py new file mode 100644 index 0000000..272d243 --- /dev/null +++ b/tests/python/unittest/test_pass_rewrite_for_tensor_core.py @@ -0,0 +1,230 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# 'License'); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
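+# Editor's note: the tests below exercise the full pipeline added in this
+# commit: the schedule marks the outer reduction loop with
+# s[CL].pragma(ko, 'tensor_core'), tvm.build() then runs
+# ir_pass.RewriteForTensorCore during lowering, and the resulting CUDA kernel
+# uses the wmma load/mma/store intrinsics whenever the warp-tile requirements
+# are met (otherwise normal CUDA code is generated).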
+import tvm +import topi +import numpy as np +from tvm.contrib import nvcc + +def tensor_core_matmul(warp_tile_m=16, m=64, n=32, l=96): + A = tvm.placeholder((n, l), name='A', dtype='float16') + B = tvm.placeholder((l, m), name='B', dtype='float16') + k = tvm.reduce_axis((0, l), name='k') + C = tvm.compute((n, m), lambda i, j: tvm.sum(A[i, k].astype('float32') * B[k, j].astype('float32'), axis=k)) + s = tvm.create_schedule(C.op) + y, x = s[C].op.axis + k = s[C].op.reduce_axis[0] + + AA = s.cache_read(A, "shared", [C]) + AL = s.cache_read(AA, "local", [C]) + BB = s.cache_read(B, "shared", [C]) + BL = s.cache_read(BB, "local", [C]) + CL = s.cache_write(C, "local") + + bx = 4 + by = 32 + step_k = 8 + v = 4 + TX = 8 + TY = 1 + tile_x = bx * TX + tile_y = by * TY + WX = min(warp_tile_m, tile_x) + tile_k = 16 + vthread = 1 + + yo, ty = s[C].split(y, tile_y*vthread) + vy, ty = s[C].split(ty, tile_y) + ty, yi = s[C].split(ty, TY) + + xo, xi = s[C].split(x, tile_x) + tz, xi = s[C].split(xi, WX) + tx, xi = s[C].split(xi, TX) + ko, ki = s[CL].split(k, step_k * tile_k) + kl, ki = s[CL].split(ki, tile_k) + + s[C].reorder(yo, xo, tz, ty, tx, yi, xi) + s[C].bind(yo, tvm.thread_axis("blockIdx.y")) + s[C].bind(xo, tvm.thread_axis("blockIdx.x")) + s[C].bind(ty, tvm.thread_axis("threadIdx.y")) + s[C].bind(tz, tvm.thread_axis("threadIdx.z")) + s[C].bind(tx, tvm.thread_axis("threadIdx.x")) + s[C].bind(vy, tvm.thread_axis((0, vthread), "vthread", name="vy")) + s[CL].compute_at(s[C], tx) + yo, xo = CL.op.axis + s[CL].reorder(ko, kl, ki, yo, xo) + + s[AA].compute_at(s[CL], ko) + xo, xi = s[AA].split(s[AA].op.axis[1], factor=bx*v) + tz, tx = s[AA].split(xi, factor=(WX//TX)*v) + tx, vec = s[AA].split(tx, factor=v) + fused = s[AA].fuse(s[AA].op.axis[0], xo) + _, ty = s[AA].split(fused, factor=by) + s[AA].bind(ty, tvm.thread_axis("threadIdx.y")) + s[AA].bind(tz, tvm.thread_axis("threadIdx.z")) + s[AA].bind(tx, tvm.thread_axis("threadIdx.x")) + s[AA].vectorize(vec) + + s[BB].compute_at(s[CL], ko) + xo, xi = s[BB].split(s[BB].op.axis[1], factor=bx*v) + tz, tx = s[BB].split(xi, factor=(WX//TX)*v) + tx, vec = s[BB].split(tx, factor=v) + fused = s[BB].fuse(s[BB].op.axis[0], xo) + _, ty = s[BB].split(fused, factor=by) + s[BB].bind(ty, tvm.thread_axis("threadIdx.y")) + s[BB].bind(tz, tvm.thread_axis("threadIdx.z")) + s[BB].bind(tx, tvm.thread_axis("threadIdx.x")) + s[BB].vectorize(vec) + + s[AL].compute_at(s[CL], kl) + s[BL].compute_at(s[CL], kl) + + s[CL].pragma(ko, 'tensor_core') + + func = tvm.build(s, [A, B, C], 'cuda') + + ctx = tvm.gpu(0) + a_np = np.random.uniform(size=(n, l)).astype(A.dtype) + b_np = np.random.uniform(size=(l, m)).astype(B.dtype) + c_np = np.zeros((n, m), dtype=np.float32) + a = tvm.nd.array(a_np, ctx) + b = tvm.nd.array(b_np, ctx) + c = tvm.nd.array(np.zeros((n, m), dtype=C.dtype), ctx) + func(a, b, c) + evaluator = func.time_evaluator(func.entry_name, ctx, number=3) + print('gemm m=%d n=%d k=%d: %f ms' % (m, n, l, evaluator(a, b, c).mean * 1e3)) + + c_np = np.dot(a_np, b_np) + np.testing.assert_allclose(c_np, c.asnumpy(), rtol=1e-3) + +def tensor_core_batch_matmul(warp_tile_m=16, m=64, n=32, l=96, batch=2): + A = tvm.placeholder((batch, n, l), name='A', dtype='float16') + B = tvm.placeholder((batch, l, m), name='B', dtype='float16') + k = tvm.reduce_axis((0, l), name='k') + C = tvm.compute((batch, n, m), lambda b, i, j: tvm.sum((A[b, i, k] * B[b, k, j]).astype('float32'), axis=k)) + s = tvm.create_schedule(C.op) + z, y, x = s[C].op.axis + k = s[C].op.reduce_axis[0] + + AA = s.cache_read(A, 
"shared", [C]) + AL = s.cache_read(AA, "local", [C]) + BB = s.cache_read(B, "shared", [C]) + BL = s.cache_read(BB, "local", [C]) + CL = s.cache_write(C, "local") + + bx = 2 + by = 32 + step_k = 8 + v = 4 + TX = 8 + TY = 1 + tile_x = bx * TX + tile_y = by * TY + WX = min(warp_tile_m, tile_x) + tile_k = 16 + vthread = 1 + + yo, ty = s[C].split(y, tile_y*vthread) + vy, ty = s[C].split(ty, tile_y) + ty, yi = s[C].split(ty, TY) + + xo, xi = s[C].split(x, tile_x) + tz, xi = s[C].split(xi, WX) + tx, xi = s[C].split(xi, TX) + ko, ki = s[CL].split(k, step_k * tile_k) + kl, ki = s[CL].split(ki, tile_k) + + s[C].reorder(z, yo, xo, tz, ty, tx, yi, xi) + s[C].bind(z, tvm.thread_axis("blockIdx.z")) + s[C].bind(yo, tvm.thread_axis("blockIdx.y")) + s[C].bind(xo, tvm.thread_axis("blockIdx.x")) + s[C].bind(ty, tvm.thread_axis("threadIdx.y")) + s[C].bind(tz, tvm.thread_axis("threadIdx.z")) + s[C].bind(tx, tvm.thread_axis("threadIdx.x")) + s[C].bind(vy, tvm.thread_axis((0, vthread), "vthread", name="vy")) + s[CL].compute_at(s[C], tx) + zo, yo, xo = CL.op.axis + s[CL].reorder(ko, kl, ki, zo, yo, xo) + + s[AA].compute_at(s[CL], ko) + xo, xi = s[AA].split(s[AA].op.axis[2], factor=bx*v) + tz, tx = s[AA].split(xi, factor=(WX//TX)*v) + tx, vec = s[AA].split(tx, factor=v) + fused = s[AA].fuse(s[AA].op.axis[1], xo) + _, ty = s[AA].split(fused, factor=by) + s[AA].bind(ty, tvm.thread_axis("threadIdx.y")) + s[AA].bind(tz, tvm.thread_axis("threadIdx.z")) + s[AA].bind(tx, tvm.thread_axis("threadIdx.x")) + s[AA].vectorize(vec) + + s[BB].compute_at(s[CL], ko) + xo, xi = s[BB].split(s[BB].op.axis[2], factor=bx*v) + tz, tx = s[BB].split(xi, factor=(WX//TX)*v) + tx, vec = s[BB].split(tx, factor=v) + fused = s[BB].fuse(s[BB].op.axis[1], xo) + _, ty = s[BB].split(fused, factor=by) + s[BB].bind(ty, tvm.thread_axis("threadIdx.y")) + s[BB].bind(tz, tvm.thread_axis("threadIdx.z")) + s[BB].bind(tx, tvm.thread_axis("threadIdx.x")) + s[BB].vectorize(vec) + + s[AL].compute_at(s[CL], kl) + s[BL].compute_at(s[CL], kl) + + s[CL].pragma(ko, 'tensor_core') + + func = tvm.build(s, [A, B, C], 'cuda') + + ctx = tvm.gpu(0) + a_np = np.random.uniform(size=(batch, n, l)).astype(A.dtype) + b_np = np.random.uniform(size=(batch, l, m)).astype(B.dtype) + c_np = np.zeros((batch, n, m), dtype=np.float32) + a = tvm.nd.array(a_np, ctx) + b = tvm.nd.array(b_np, ctx) + c = tvm.nd.array(np.zeros((batch, n, m), dtype=C.dtype), ctx) + func(a, b, c) + evaluator = func.time_evaluator(func.entry_name, ctx, number=3) + print('batch gemm m=%d n=%d k=%d batch=%d: %f ms' % (m, n, l, batch, evaluator(a, b, c).mean * 1e3)) + + for bs in range(batch): + c_np[bs, :, :] = np.dot(a_np[bs, :, :], b_np[bs, :, :]) + np.testing.assert_allclose(c_np, c.asnumpy(), rtol=1e-3) + +def test_tensor_core_matmul(): + if not tvm.gpu(0).exist or not tvm.module.enabled("cuda"): + print("skip because cuda is not enabled..") + return + if not nvcc.have_tensorcore(tvm.gpu(0).compute_version): + print("skip because gpu does not support tensor core") + return + + tensor_core_matmul(16) #test with warp_tile 16x16x16 + tensor_core_matmul(8) #test with warp_tile 8x32x16 + tensor_core_matmul(32) #test with warp_tile 32x8x16 + +def test_tensor_core_batch_matmul(): + if not tvm.gpu(0).exist or not tvm.module.enabled("cuda"): + print("skip because cuda is not enabled..") + return + if not nvcc.have_tensorcore(tvm.gpu(0).compute_version): + print("skip because gpu does not support tensor core") + return + + tensor_core_batch_matmul() + +if __name__ == '__main__': + test_tensor_core_matmul() + 
test_tensor_core_batch_matmul() diff --git a/tutorials/optimize/opt_matmul_auto_tensorcore.py b/tutorials/optimize/opt_matmul_auto_tensorcore.py new file mode 100644 index 0000000..00dbe9d --- /dev/null +++ b/tutorials/optimize/opt_matmul_auto_tensorcore.py @@ -0,0 +1,465 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +.. _opt-matmul-auto-tensorcore: + +How to optimize matmul with Auto TensorCore CodeGen +================================== +**Author**: `Minmin Sun `_, \ + `Lanbo Li `_, \ + `Chenfan Jia `_, \ + `Jun Yang `_ + +In this tutorial, we will demonstrate how to write a high performance matmul +schedule on Volta/Turing GPUs with TVM Auto TensorCore CodeGen. +This is a transparent solution to generate tensorcore kernel +with most transformations done in ir passes. +Users can also write schedule with tensorization to generate TensorCore code. +Both solutions use the same tensorcore intrinsics. +Please refer to :ref:`opt-conv-tensorcore` tutorial for more details. + +""" + +################################################################ +# Preparation and Algorithm +# -------------------------- +# 2 kinds of input data types are supported: float16 and int8. +# For float16, the accumulator is float32. +# For int8, the accumulator is int32. +# For data layouts, 'N' means None-transpose while 'T' means Transpose. + +import logging +import sys + +import numpy as np +import tvm + +from tvm import autotvm +from tvm.contrib import nvcc + +def matmul_nn(A, B, L, dtype='float16', layout='NN'): + k = tvm.reduce_axis((0, L), name='k') + if dtype == 'float16': + out_type = 'float' + elif dtype == 'int8': + out_type = 'int' + if (layout == 'NN'): + return tvm.compute((N, M), lambda i, j: tvm.sum(A[i, k].astype(out_type) * B[k, j].astype(out_type), axis=k)) + if (layout == 'NT'): + return tvm.compute((N, M), lambda i, j: tvm.sum(A[k, i].astype(out_type) * B[k, j].astype(out_type), axis=k)) + if (layout == 'TN'): + return tvm.compute((N, M), lambda i, j: tvm.sum(A[i, k].astype(out_type) * B[j, k].astype(out_type), axis=k)) + if (layout == 'TT'): + return tvm.compute((N, M), lambda i, j: tvm.sum(A[k, i].astype(out_type) * B[j, k].astype(out_type), axis=k)) + +############################################################################### +# Scheduling the Computation +# -------------------------- +# This schedule is no different than a non-tensorcore matmul schedule on GPU. +# Please refer to :ref:`opt-gemm` tutorial for basics of optimizing matmul schedule. +# When the "tensor_core" pragma is set, the "rewrite for tensorcore" ir pass +# will automatically transform the schedule for tensorcore codegen, +# otherwise normal CUDA code, with lower performance but equal functionality, will be generated. +# +# .. 
note:: +# +# *Requirements of TesnsorCore* +# +# Note that in the following 2 cases, even though the "tensor_core" pragma is set, TVM will still fall back to normal CUDA codegen: +# (1) The m, n or k of input matrices is not multiple of 16; +# (2) The warp tile size is not 16x16x16 on CUDA9, or not one of {16x16x16, 32x8x16, 8x32x16} on CUDA version >= 10.0. +# +# In this schedule, storage_align is used to reduce bank conflicts of shared memory. Please refer to this +# `doc `_ +# for the usage of storage_align primitive. In short, we need to add an offset to some shared memory buffer +# to reduce bank conflicts. +# According to the `wmma doc `_, +# the stride of load_matrix_sync must be a multiple of 16 bytes, +# so we choose 8 as offset for float16 and 16 as offset for int8. +# +# We use AutoTVM to search for best configurations in this schedule. + +@autotvm.template +def test_gemm(N, L, M, dtype, layout): + if (layout == "NN"): + shape_a = (N, L) + shape_b = (L, M) + elif (layout == "NT"): + shape_a = (L, N) + shape_b = (L, M) + elif (layout == "TN"): + shape_a = (N, L) + shape_b = (M, L) + elif (layout == "TT"): + shape_a = (L, N) + shape_b = (M, L) + else: + print ("Unsupported layout:", layout) + sys.exit(1); + A = tvm.placeholder(shape_a, name='A', dtype=dtype) + B = tvm.placeholder(shape_b, name='B', dtype=dtype) + C = matmul_nn(A, B, L, dtype, layout) + + s = tvm.create_schedule(C.op) + y, x = s[C].op.axis + k = s[C].op.reduce_axis[0] + + # storage_align params + factor = 16 + offset = 8 + if dtype == 'int8': + factor = 32 + offset = 16 + + # create cache stages + AA = s.cache_read(A, "shared", [C]) + if (layout == "NN" or layout == "TN"): + s[AA].storage_align(AA.op.axis[0], factor, offset) + AL = s.cache_read(AA, "local", [C]) + BB = s.cache_read(B, "shared", [C]) + if (layout == "TT" or layout == "NT"): + s[BB].storage_align(BB.op.axis[0], factor, offset) + BL = s.cache_read(BB, "local", [C]) + CL = s.cache_write(C, "local") + + #autotvm search space definition + cfg = autotvm.get_config() + + cfg.define_knob("bx", [2, 4, 8]) + cfg.define_knob("by", [16, 32, 64]) + cfg.define_knob("step_k", [8, 16, 32]) + cfg.define_knob("v", [4, 8]) + by = cfg['by'].val + bx = cfg['bx'].val + step_k = cfg['step_k'].val + v = cfg['v'].val + + # thread tile + TX = 8 + TY = 1 + # warp tile + warp_tile_m = 16 # it could also be 8 or 32 on CUDA version >= 10.0 + warp_tile_k = 16 # it must be 16 + # block tile + tile_x = bx * TX + tile_y = by * TY + + yo, ty = s[C].split(y, tile_y) + ty, yi = s[C].split(ty, TY) + + # schedule for C stage + xo, xi = s[C].split(x, tile_x) + WX = min(warp_tile_m, tile_x) + tz, xi = s[C].split(xi, WX) + tx, xi = s[C].split(xi, TX) + s[C].reorder(yo, xo, tz, ty, tx, yi, xi) + s[C].bind(yo, tvm.thread_axis("blockIdx.y")) + s[C].bind(xo, tvm.thread_axis("blockIdx.x")) + s[C].bind(ty, tvm.thread_axis("threadIdx.y")) + s[C].bind(tz, tvm.thread_axis("threadIdx.z")) + s[C].bind(tx, tvm.thread_axis("threadIdx.x")) + + # schedule for CL stage + ko, ki = s[CL].split(k, step_k * warp_tile_k) + kl, ki = s[CL].split(ki, warp_tile_k) + s[CL].compute_at(s[C], tx) + yo, xo = CL.op.axis + s[CL].reorder(ko, kl, ki, yo, xo) + + # schedule for AA stage + s[AA].compute_at(s[CL], ko) + xo, xi = s[AA].split(s[AA].op.axis[1], factor=bx*v) + tz, tx = s[AA].split(xi, factor=(WX//TX)*v) + tx, vec = s[AA].split(tx, factor=v) + fused = s[AA].fuse(s[AA].op.axis[0], xo) + _, ty = s[AA].split(fused, factor=by) + s[AA].bind(ty, tvm.thread_axis("threadIdx.y")) + s[AA].bind(tz, 
tvm.thread_axis("threadIdx.z")) + s[AA].bind(tx, tvm.thread_axis("threadIdx.x")) + # vectorization is very important for float16/int8 inputs + s[AA].vectorize(vec) + + # schedule for BB stage + s[BB].compute_at(s[CL], ko) + xo, xi = s[BB].split(s[BB].op.axis[1], factor=bx*v) + tz, tx = s[BB].split(xi, factor=(WX//TX)*v) + tx, vec = s[BB].split(tx, factor=v) + fused = s[BB].fuse(s[BB].op.axis[0], xo) + _, ty = s[BB].split(fused, factor=by) + s[BB].bind(ty, tvm.thread_axis("threadIdx.y")) + s[BB].bind(tz, tvm.thread_axis("threadIdx.z")) + s[BB].bind(tx, tvm.thread_axis("threadIdx.x")) + s[BB].vectorize(vec) + + s[AL].compute_at(s[CL], kl) + s[BL].compute_at(s[CL], kl) + + # set the 'tensor_core' pragma for tensorcore codegen + s[CL].pragma(ko, 'tensor_core') + + return s, [A, B, C] + +############################################################################### +# AutoTune and Test +# -------------------- +# Finally we use a tuner to tune the schedule, generate code with best config +# and run the kernel to compare with numpy to check whether the results are correct. + +# check whether the gpu has tensorcore +ctx = tvm.gpu() +if not nvcc.have_tensorcore(ctx.compute_version): + print('the gpu has no tensorcore, skipping...') + sys.exit(0) + +M, N, L = 512, 32, 512 +dtype = 'float16' +layout = 'NN' +if len(sys.argv) >= 4: + M, N, L = int(sys.argv[1]), int(sys.argv[2]), int(sys.argv[3]) +if len(sys.argv) >= 5: + dtype = sys.argv[4] +if len(sys.argv) >= 6: + layout = sys.argv[5] + +def tune_and_evaluate(M, N, L, dtype, layout): + task = autotvm.task.create(test_gemm, args=(N, L, M, dtype, layout), target='cuda') + print(task.config_space) + + logging.getLogger('autotvm').setLevel(logging.DEBUG) + logging.getLogger('autotvm').addHandler(logging.StreamHandler(sys.stdout)) + + measure_option = autotvm.measure_option( + builder='local', + runner=autotvm.LocalRunner(number=5)) + + tuner = autotvm.tuner.XGBTuner(task) + tuner.tune(n_trial=1000, + measure_option=measure_option, + callbacks=[autotvm.callback.log_to_file('matmul.log')]) + + dispatch_context = autotvm.apply_history_best("matmul.log") + best_config = dispatch_context.query(task.target, task.workload) + print("\nBest config:") + print(best_config) + with autotvm.apply_history_best('matmul.log'): + with tvm.target.create("cuda"): + with tvm.build_config(): + s, arg_bufs = test_gemm(N, L, M, dtype, layout) + print(tvm.lower(s, arg_bufs, simple_mode=True)) + func = tvm.build(s, arg_bufs) + dev_module = func.imported_modules[0] + print(dev_module.get_source()) + + # check correctness + if (layout == "NN"): + shape_a = (N, L) + shape_b = (L, M) + elif (layout == "NT"): + shape_a = (L, N) + shape_b = (L, M) + elif (layout == "TN"): + shape_a = (N, L) + shape_b = (M, L) + elif (layout == "TT"): + shape_a = (L, N) + shape_b = (M, L) + + a_np = None + b_np = None + c_np = None + c_np_type = None + if dtype == 'float16': + c_np_type = np.float32 + a_np = np.random.uniform(size=shape_a).astype(np.float16) + b_np = np.random.uniform(size=shape_b).astype(np.float16) + if (layout == "NN"): + c_np = np.dot(a_np, b_np) + elif (layout == "NT"): + c_np = np.dot(a_np.T, b_np) + elif (layout == "TN"): + c_np = np.dot(a_np, b_np.T) + elif (layout == "TT"): + c_np = np.dot(a_np.T, b_np.T) + elif dtype == 'int8': + c_np_type = np.int32 + a_np = np.random.randint(low=-128, high=127, size=shape_a).astype(np.int8) + b_np = np.random.randint(low=-128, high=127, size=shape_b).astype(np.int8) + if (layout == "NN"): + c_np = np.dot(a_np.astype(np.int32), 
b_np.astype(np.int32)) + elif (layout == "NT"): + c_np = np.dot(a_np.astype(np.int32).T, b_np.astype(np.int32)) + elif (layout == "TN"): + c_np = np.dot(a_np.astype(np.int32), b_np.astype(np.int32).T) + elif (layout == "TT"): + c_np = np.dot(a_np.astype(np.int32).T, b_np.astype(np.int32).T) + + c_tvm = tvm.nd.array(np.zeros(c_np.shape, dtype=c_np_type), ctx=ctx) + a_tvm = tvm.nd.array(a_np, ctx=ctx) + b_tvm = tvm.nd.array(b_np, ctx=ctx) + func(a_tvm, b_tvm, c_tvm) + + tvm.testing.assert_allclose(c_np, c_tvm.asnumpy(), rtol=1e-3) + + evaluator = func.time_evaluator(func.entry_name, ctx, number=100) + print('Time cost of this operator: %f' % evaluator(a_tvm, b_tvm, c_tvm).mean) + +# We do not run the tuning in our webpage server since it takes some time. +# Uncomment the following line to run it by yourself. + +# tune_and_evaluate(M, N, L, dtype, layout) + +###################################################################### +# Sample Output +# ------------- +# .. code-block:: bash +# +# Best config: +# [('bx', 4), ('by', 32), ('step_k', 16), ('v', 8)],,None,40 +# Finish loading 162 records +# produce compute { +# // attr [iter_var(blockIdx.y, , blockIdx.y)] thread_extent = 1 +# // attr [compute.local] storage_scope = "wmma.accumulator" +# allocate compute.local[float32 * 256] +# // attr [A.shared] storage_scope = "shared" +# allocate A.shared[float16 * 8448] +# // attr [B.shared] storage_scope = "shared" +# allocate B.shared[float16 * 8192] +# // attr [A.shared.local] storage_scope = "wmma.matrix_b" +# allocate A.shared.local[float16 * 256] +# // attr [B.shared.local] storage_scope = "wmma.matrix_a" +# allocate B.shared.local[float16 * 256] +# // attr [iter_var(blockIdx.x, , blockIdx.x)] thread_extent = 16 +# // attr [iter_var(threadIdx.z, , threadIdx.z)] thread_extent = 2 +# // attr [iter_var(threadIdx.y, , threadIdx.y)] thread_extent = 32 +# // attr [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 2 +# produce compute.local { +# for (j.c.init, 0, 1) { +# tvm_fill_fragment(compute.local, 16, 16, 16, 0, 0f) +# } +# // attr [iter_var(k.outer, )] pragma_tensor_core = 1 +# for (k.outer, 0, 2) { +# produce A.shared { +# for (ax0.ax1.outer.fused.outer, 0, 8) { +# // attr [iter_var(threadIdx.y, , threadIdx.y)] thread_extent = 32 +# // attr [iter_var(threadIdx.z, , threadIdx.z)] thread_extent = 2 +# // attr [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 2 +# A.shared[ramp((((((ax0.ax1.outer.fused.outer*1056) + (floordiv(threadIdx.y, 8)*264)) + (floormod(threadIdx.y, 8)*32)) + (threadIdx.z*16)) + (threadIdx.x*8)), 1, 8)] = A[ramp(((((((ax0.ax1.outer.fused.outer*2048) + (floordiv(threadIdx.y, 8)*512)) + (k.outer*256)) + (floormod(threadIdx.y, 8)*32)) + (threadIdx.z*16)) + (threadIdx.x*8)), 1, 8)] +# } +# } +# produce B.shared { +# for (ax0.ax1.outer.fused.outer, 0, 8) { +# // attr [iter_var(threadIdx.y, , threadIdx.y)] thread_extent = 32 +# // attr [iter_var(threadIdx.z, , threadIdx.z)] thread_extent = 2 +# // attr [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 2 +# B.shared[ramp(((((ax0.ax1.outer.fused.outer*1024) + (threadIdx.y*32)) + (threadIdx.z*16)) + (threadIdx.x*8)), 1, 8)] = B[ramp(((((((k.outer*131072) + (ax0.ax1.outer.fused.outer*16384)) + (threadIdx.y*512)) + (blockIdx.x*32)) + (threadIdx.z*16)) + (threadIdx.x*8)), 1, 8)] +# } +# } +# for (k.inner.outer, 0, 16) { +# produce A.shared.local { +# for (ax1, 0, 1) { +# tvm_load_matrix_sync(A.shared.local, 16, 16, 16, 0, &(A.shared[(((threadIdx.y/16)*4224) + (k.inner.outer*16))]), 264, "col_major") +# } +# } +# produce 
B.shared.local { +# for (ax0, 0, 1) { +# for (ax1, 0, 1) { +# tvm_load_matrix_sync(B.shared.local, 16, 16, 16, 0, &(B.shared[((k.inner.outer*512) + (threadIdx.z*16))]), 32, "col_major") +# } +# } +# } +# for (k.inner.inner, 0, 1) { +# for (j.c, 0, 1) { +# tvm_mma_sync(compute.local, 0, B.shared.local, 0, A.shared.local, 0, compute.local, 0) +# } +# } +# } +# } +# } +# for (j.inner.inner.inner, 0, 1) { +# tvm_store_matrix_sync(compute.local, 16, 16, 16, 0, &(compute[((((threadIdx.y/16)*8192) + (blockIdx.x*32)) + (threadIdx.z*16))]), 512, "col_major") +# } +# } +# +# #include +# __device__ half max(const half a, const half b) +# { +# return __hgt(__half(a), __half(b)) ? a : b; +# } +# __device__ half min(const half a, const half b) +# { +# return __hlt(__half(a), __half(b)) ? a : b; +# } +# __device__ half operator+(const volatile __half &a, const volatile __half &b) +# { +# return __hadd(a, b); +# } +# __device__ half operator<=(const volatile __half &a, const volatile __half &b) +# { +# return __hlt(a, b); +# } +# __device__ half operator*(const volatile __half &a, const volatile __half &b) +# { +# return __hmul(a, b); +# } +# #include +# extern "C" __global__ void default_function_kernel0( half* __restrict__ A, half* __restrict__ B, float* __restrict__ compute) { +# nvcuda::wmma::fragment compute_local[1]; +# __shared__ half A_shared[8448]; +# __shared__ half B_shared[8192]; +# nvcuda::wmma::fragment A_shared_local[1]; +# nvcuda::wmma::fragment B_shared_local[1]; +# for (int j_c_init = 0; j_c_init < 1; ++j_c_init) { +# (void)nvcuda::wmma::fill_fragment(compute_local[0], 0.000000e+00f); +# } +# for (int k_outer = 0; k_outer < 2; ++k_outer) { +# __syncthreads(); +# for (int ax0_ax1_outer_fused_outer = 0; ax0_ax1_outer_fused_outer < 8; ++ax0_ax1_outer_fused_outer) { +# ((__shared__ float4*)(A_shared + (((((ax0_ax1_outer_fused_outer * 1056) + ((((int)threadIdx.y) >> 3) * 264)) + ((((int)threadIdx.y) & 7) * 32)) + (((int)threadIdx.z) * 16)) + (((int)threadIdx.x) * 8))))[0] = (( float4*)(A + ((((((ax0_ax1_outer_fused_outer * 2048) + ((((int)threadIdx.y) >> 3) * 512)) + (k_outer * 256)) + ((((int)threadIdx.y) & 7) * 32)) + (((int)threadIdx.z) * 16)) + (((int)threadIdx.x) * 8))))[0]; +# } +# for (int ax0_ax1_outer_fused_outer1 = 0; ax0_ax1_outer_fused_outer1 < 8; ++ax0_ax1_outer_fused_outer1) { +# ((__shared__ float4*)(B_shared + ((((ax0_ax1_outer_fused_outer1 * 1024) + (((int)threadIdx.y) * 32)) + (((int)threadIdx.z) * 16)) + (((int)threadIdx.x) * 8))))[0] = (( float4*)(B + ((((((k_outer * 131072) + (ax0_ax1_outer_fused_outer1 * 16384)) + (((int)threadIdx.y) * 512)) + (((int)blockIdx.x) * 32)) + (((int)threadIdx.z) * 16)) + (((int)threadIdx.x) * 8))))[0]; +# } +# __syncthreads(); +# for (int k_inner_outer = 0; k_inner_outer < 16; ++k_inner_outer) { +# for (int ax1 = 0; ax1 < 1; ++ax1) { +# (void)nvcuda::wmma::load_matrix_sync(A_shared_local[0], &(A_shared[(((((int)threadIdx.y) / 16) * 4224) + (k_inner_outer * 16))]), 264); +# } +# for (int ax0 = 0; ax0 < 1; ++ax0) { +# for (int ax11 = 0; ax11 < 1; ++ax11) { +# (void)nvcuda::wmma::load_matrix_sync(B_shared_local[0], &(B_shared[((k_inner_outer * 512) + (((int)threadIdx.z) * 16))]), 32); +# } +# } +# for (int k_inner_inner = 0; k_inner_inner < 1; ++k_inner_inner) { +# for (int j_c = 0; j_c < 1; ++j_c) { +# (void)nvcuda::wmma::mma_sync(compute_local[0], B_shared_local[0], A_shared_local[0], compute_local[0]); +# } +# } +# } +# } +# for (int j_inner_inner_inner = 0; j_inner_inner_inner < 1; ++j_inner_inner_inner) { +# 
(void)nvcuda::wmma::store_matrix_sync(&(compute[((((((int)threadIdx.y) / 16) * 8192) + (((int)blockIdx.x) * 32)) + (((int)threadIdx.z) * 16))]), compute_local[0], 512, nvcuda::wmma::mem_col_major); +# } +# } +# +# +# Time cost of this operator: 0.000008 + +############################################################################### +# Summary +# -------------------------- +# This tutorial demonstrates how to use the AutoTensorCoreCodeGen of TVM +# to generate tensorcore kernels.
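###############################################################################
# As a quick recap, the commented-out sketch below is an editor's addition; it
# assumes the ``test_gemm`` template and the M, N, L, dtype and layout values
# defined earlier in this tutorial. Only two steps are specific to the auto
# tensorcore path: keeping the 'tensor_core' pragma on the outer reduction
# loop inside the schedule, and building for the CUDA target so that the
# RewriteForTensorCore ir pass runs during lowering. If the warp-tile
# requirements listed above are not met, the same code simply produces a
# normal CUDA kernel.
#
# with tvm.target.create("cuda"):
#     s, arg_bufs = test_gemm(N, L, M, dtype, layout)  # schedule sets s[CL].pragma(ko, 'tensor_core')
#     print(tvm.lower(s, arg_bufs, simple_mode=True))  # wmma intrinsics appear here if eligible
#     func = tvm.build(s, arg_bufs)                    # builds the CUDA kernel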