From: Siyuan Feng Date: Thu, 24 Oct 2019 19:04:37 +0000 (-0700) Subject: TensorCore Support using Intrinsic (#4136) X-Git-Tag: upstream/0.7.0~1738 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=324a9607eb563f81e55fdd0c9d078c2f74651817;p=platform%2Fupstream%2Ftvm.git TensorCore Support using Intrinsic (#4136) * add tensor core support * avoid memory bank conflict * fix thread sync & better performance * better performance * add schedule test for conv2d * extend into BatchMatMul * support config fragment shape and layout using intrinsic * add TensorCore tutorial * add int support and fix lint * address comment * add 32*16*8 TensorCore test * fix wmma include logic --- diff --git a/include/tvm/ir.h b/include/tvm/ir.h index b908049..37718fe 100644 --- a/include/tvm/ir.h +++ b/include/tvm/ir.h @@ -1311,6 +1311,16 @@ constexpr const char* opengl_stage_scope = "opengl_stage_scope"; constexpr const char* device_scope = "device_scope"; /*! + * \brief Mark that the shape of TensorCore fragment + */ +constexpr const char* fragment_shape = "fragment_shape"; + +/*! + * \brief Mark that the layout of TensorCore fragment + */ +constexpr const char* fragment_layout = "fragment_layout"; + +/*! * \brief Check if attr_key is a pragma key extension * \param attr_key The attr key to be compared * \return true if it is a pragma key @@ -1552,6 +1562,54 @@ constexpr const char* tvm_global_barrier_kinit = "tvm_global_barrier_kinit"; * } */ constexpr const char* tvm_thread_allreduce = "tvm_thread_allreduce"; +/*! + * \brief tvm intrinsic for tensor core load operators. + * + * void tvm_load_matrix_sync(Var fragment, UIntImm m, UIntImm, n, UIntImm k, + * Expr index, Expr buffer_ptr, Expr stride, + * StringImm layout) { + * // m, n, k are the shape of wmma fragment. + * // Determine fragment layout(column-major or row major) by layout. + * // fragments must be in 'wmma.matrix_a' or 'wmma.matrix_b' scope. + * nvcuda::wmma::load_matrix_sync(fragment[index], buffer_ptr, stride); + * } + */ +constexpr const char* tvm_load_matrix_sync = "tvm_load_matrix_sync"; +/*! + * \brief tvm intrinsic for tensor core mma_sync operators. + * + * void tvm_mma_sync(Var fragment_d, Expr index_d, + * Var fragment_a, Expr index_a, + * Var fragment_b, Expr index_b, + * Var fragment_c, Expr index_c) { + * nvcuda::wmma::mma_sync(fragment_d[index_d], fragment_a[index_a], + * fragment_b[index_b], fragment_c[index_c]); + * } + */ +constexpr const char* tvm_mma_sync = "tvm_mma_sync"; +/*! + * \brief tvm intrinsic for tensor core fill_fragment operators. + * + * void tvm_fill_fragment(Var fragment, UIntImm m, UIntImm, n, UIntImm k, + * Expr index, Expr value) { + * // m, n, k are the shape of wmma fragment + * // fragments must be in 'wmma.accumulator' scope. + * nvcuda::wmma::fill_fragment(fragment[index], value); + * } + */ +constexpr const char* tvm_fill_fragment = "tvm_fill_fragment"; +/*! + * \brief tvm intrinsic for tensor core store operators. + * + * void tvm_store_matrix_sync(Var fragment, UIntImm m, UIntImm, n, UIntImm k, + * Expr index, Expr buffer_ptr, Expr stride, + * StringImm layout) { + * // m, n, k are the shape of wmma fragment + * // fragments must be in 'wmma.accumulator' scope. 
+ * nvcuda::wmma::store_matrix_sync(fragment[index], buffer_ptr, stride, layout); + * } + */ +constexpr const char* tvm_store_matrix_sync = "tvm_store_matrix_sync"; } // namespace intrinsic diff --git a/include/tvm/ir_pass.h b/include/tvm/ir_pass.h index 03078b8..842c6af 100644 --- a/include/tvm/ir_pass.h +++ b/include/tvm/ir_pass.h @@ -514,6 +514,15 @@ LoweredFunc CombineContextCall(LoweredFunc f); LoweredFunc PointerValueTypeRewrite(LoweredFunc f); /*! + * \brief Lower attached storage access information on device. + * Do this pass after all storage access analysis finish. + * + * \param func The device function to be lowered. + * \return Transformed function. + */ +LoweredFunc LowerDeviceStorageAccessInfo(LoweredFunc func); + +/*! * \brief Lower intrinsic function calls. * \param f The device function to be lowered. * \param target The target device. @@ -533,6 +542,14 @@ LoweredFunc LowerIntrin(LoweredFunc f, const std::string& target); LoweredFunc LowerCustomDatatypes(LoweredFunc f, const std::string& target); /*! + * \brief Infer the TensorCore fragment infomation using tensor intrinsics + * + * \param f The device function to be lowered. + * \return Transformed function. + */ +LoweredFunc InferFragment(LoweredFunc f); + +/*! * \brief Verify if memory accesses are legal for a specific target device type. * * In the case that tgt is cuda, if not all workload is bound with diff --git a/python/tvm/build_module.py b/python/tvm/build_module.py index 479bdce..fe2f641 100644 --- a/python/tvm/build_module.py +++ b/python/tvm/build_module.py @@ -413,7 +413,6 @@ def lower(sch, # Phase 3 stmt = ir_pass.Simplify(stmt) - stmt = ir_pass.LowerStorageAccessInfo(stmt) stmt = ir_pass.RemoveNoOp(stmt) if not cfg.disable_select_rewriting: stmt = ir_pass.RewriteUnsafeSelect(stmt) @@ -465,6 +464,7 @@ def _build_for_device(flist, target, target_host): func = ir_pass.ThreadSync(func, "global") func = ir_pass.ThreadSync(func, "shared") func = ir_pass.ThreadSync(func, "warp") + func = ir_pass.InferFragment(func) warp_size = target.thread_warp_size func = ir_pass.LowerThreadAllreduce(func, warp_size) fsplits = [s for s in ir_pass.SplitHostDevice(func)] @@ -494,6 +494,8 @@ def _build_for_device(flist, target, target_host): assert not fdevice target_host = _target.create(target_host) + fdevice = [ir_pass.LowerDeviceStorageAccessInfo(x) for x in fdevice] + fhost = [ir_pass.LowerDeviceStorageAccessInfo(x) for x in fhost] fdevice = [ir_pass.LowerIntrin(x, target.target_name) for x in fdevice] fhost = [ir_pass.LowerIntrin(x, target_host.target_name) for x in fhost] fhost = [ir_pass.CombineContextCall(x) for x in fhost] diff --git a/src/api/api_pass.cc b/src/api/api_pass.cc index dd0415a..d7f621f 100644 --- a/src/api/api_pass.cc +++ b/src/api/api_pass.cc @@ -118,6 +118,14 @@ TVM_REGISTER_API("ir_pass.PostOrderVisit") }); }); +TVM_REGISTER_API("ir_pass.LowerStorageAccess") +.set_body([](TVMArgs args, TVMRetValue *ret) { + LoweredFunc f = args[0]; + auto n = make_node(*f.operator->()); + n->body = LowerStorageAccessInfo(f->body); + *ret = LoweredFunc(n); +}); + // make from two arguments #define REGISTER_PASS(PassName) \ TVM_REGISTER_API("ir_pass."#PassName) \ @@ -140,6 +148,7 @@ REGISTER_PASS(SplitHostDevice); REGISTER_PASS(StorageRewrite); REGISTER_PASS(CoProcSync); REGISTER_PASS(LowerStorageAccessInfo); +REGISTER_PASS(LowerDeviceStorageAccessInfo) REGISTER_PASS(InjectVirtualThread); REGISTER_PASS(InjectPrefetch); REGISTER_PASS(InjectDoubleBuffer); @@ -161,5 +170,6 @@ REGISTER_PASS(DecorateDeviceScope); 
REGISTER_PASS(InstrumentBoundCheckers); REGISTER_PASS(VerifyCompactBuffer); REGISTER_PASS(HoistIfThenElse); +REGISTER_PASS(InferFragment) } // namespace ir } // namespace tvm diff --git a/src/codegen/build_module.cc b/src/codegen/build_module.cc index 66340e9..cfcb060 100644 --- a/src/codegen/build_module.cc +++ b/src/codegen/build_module.cc @@ -422,7 +422,6 @@ Stmt BuildStmt(Schedule sch, // Phase 2 stmt = ir::Simplify(stmt); - stmt = ir::LowerStorageAccessInfo(stmt); stmt = ir::RemoveNoOp(stmt); if (!(config->disable_select_rewriting)) @@ -517,6 +516,7 @@ Array > split_dev_host_funcs(const Array& funcs, for (size_t i = 0; i < fhost.size(); ++i) { auto func = fhost[i]; func = ir::BindDeviceType(func, target->device_type); + func = ir::LowerDeviceStorageAccessInfo(func); func = ir::LowerTVMBuiltin(func); fhost.Set(i, func); } @@ -524,6 +524,7 @@ Array > split_dev_host_funcs(const Array& funcs, for (size_t i = 0; i < fhost.size(); ++i) { auto func = fhost[i]; func = ir::LowerIntrin(func, target_host->target_name); + func = ir::LowerDeviceStorageAccessInfo(func); func = ir::CombineContextCall(func); fhost.Set(i, func); } diff --git a/src/codegen/codegen_cuda.cc b/src/codegen/codegen_cuda.cc index 241310f..55b4810 100644 --- a/src/codegen/codegen_cuda.cc +++ b/src/codegen/codegen_cuda.cc @@ -74,6 +74,10 @@ std::string CodeGenCUDA::Finish() { decl_stream << "#include \n"; } + if (need_mma_h_) { + decl_stream << "#include \n"; + } + return CodeGenC::Finish(); } @@ -102,14 +106,22 @@ void CodeGenCUDA::PrintType(Type t, std::ostream& os) { // NOLINT(*) bool fail = false; if (t.is_float()) { switch (t.bits()) { - case 16: os << "half"; + case 16: enable_fp16_ = true; + if (lanes == 1) { + os << "half"; + } else if (lanes <= 8) { + CHECK_EQ(lanes % 2, 0) << "only support even lane for half type"; + os << "float" << lanes / 2; + } else { + fail = true; + } break; case 32: os << "float"; break; case 64: os << "double"; break; default: fail = true; break; } - if (!fail && lanes == 1) return; + if (!fail && (lanes == 1 || t.bits() == 16)) return; if (!fail && (lanes >= 2 && lanes <= 4)) { os << lanes; return; } @@ -290,6 +302,113 @@ void CodeGenCUDA::PrintStorageScope( } } +void CodeGenCUDA::VisitExpr_(const Call *op, std::ostream& os) { + if (op->is_intrinsic(intrinsic::tvm_fill_fragment)) { + need_mma_h_ = true; + CHECK_EQ(op->args.size(), 6U); + os << "nvcuda::wmma::fill_fragment("; + this->PrintExpr(op->args[0], os); + os << "["; + this->PrintExpr(op->args[4], os); + os << "], "; + this->PrintExpr(op->args[5], os); + os << ")"; + } else if (op->is_intrinsic(intrinsic::tvm_load_matrix_sync)) { + need_mma_h_ = true; + CHECK_EQ(op->args.size(), 8U); + os << "nvcuda::wmma::load_matrix_sync("; + this->PrintExpr(op->args[0], os); + os << "["; + this->PrintExpr(op->args[4], os); + os << "], "; + this->PrintExpr(op->args[5], os); + os << ", "; + this->PrintExpr(op->args[6], os); + os << ")"; + } else if (op->is_intrinsic(intrinsic::tvm_store_matrix_sync)) { + need_mma_h_ = true; + CHECK_EQ(op->args.size(), 8U); + os << "nvcuda::wmma::store_matrix_sync("; + this->PrintExpr(op->args[5], os); + os << ", "; + this->PrintExpr(op->args[0], os); + os << "["; + this->PrintExpr(op->args[4], os); + os << "], "; + this->PrintExpr(op->args[6], os); + if (const StringImm *str = op->args[7].as()) { + os << ", nvcuda::wmma::mem_" << str->value; + } else { + LOG(FATAL) << "Invalid parameters"; + } + os << ")"; + } else if (op->is_intrinsic(intrinsic::tvm_mma_sync)) { + need_mma_h_ = true; + CHECK_EQ(op->args.size(), 8U); 
+ os << "nvcuda::wmma::mma_sync("; + for (int i = 0; i < 4; ++i) { + this->PrintExpr(op->args[i * 2], os); + os << "["; + this->PrintExpr(op->args[i * 2 + 1], os); + os << "]" << ((i < 3) ? ", ": ")"); + } + } else { + CodeGenC::VisitExpr_(op, os); + } +} + +void CodeGenCUDA::VisitStmt_(const AttrStmt* op) { + if (op->attr_key == attr::fragment_shape) { + const Variable* buffer = op->node.as(); + const StringImm* shape_str = op->value.as(); + fragment_shapes[buffer] = shape_str->value; + } else if (op->attr_key == attr::fragment_layout) { + const Variable* buffer = op->node.as(); + const StringImm* layout_str = op->value.as(); + fragment_layouts[buffer] = layout_str->value; + } + CodeGenC::VisitStmt_(op); +} + +void CodeGenCUDA::VisitStmt_(const Allocate* op) { + CHECK(!is_zero(op->condition)); + std::string vid = AllocVarID(op->buffer_var.get()); + if (op->new_expr.defined()) { + // Prefer global static allocation for the program + CHECK_EQ(op->free_function, "nop"); + std::string new_data = PrintExpr(op->new_expr); + this->PrintIndent(); + PrintType(op->type, stream); + stream << "* "<< vid << '=' << new_data << ";\n"; + } else { + this->PrintIndent(); + int32_t constant_size = op->constant_allocation_size(); + CHECK_GT(constant_size, 0) + << "Can only handle constant size stack allocation for now"; + const Variable* buffer = op->buffer_var.as(); + std::string scope = alloc_storage_scope_.at(buffer); + if (scope.find("wmma.") == 0) { + if (scope == "wmma.matrix_a" || scope == "wmma.matrix_b") { + CHECK(op->type == Float(16) || op->type == Int(8) || op->type == UInt(8)) + << "Matrix_a and matrix_b only support half or char or unsigned char type for now"; + } else { + CHECK(op->type == Float(16) || op->type == Float(32) || op->type == Int(32)) + << "Accumulator only support half, float and int type for now"; + } + constant_size = GetWmmaFragmentSize(scope, buffer, constant_size); + PrintWmmaScope(scope, op->type, buffer, stream); + } else { + PrintStorageScope(scope, stream); + stream << ' '; + PrintType(op->type, stream); + } + stream << ' '<< vid << '[' + << constant_size << "];\n"; + } + RegisterHandleType(op->buffer_var.get(), op->type); + this->PrintStmt(op->body); +} + void CodeGenCUDA::VisitStmt_(const Evaluate *op) { if (is_const(op->value)) return; const Call* call = op->value.as(); @@ -392,5 +511,49 @@ void CodeGenCUDA::VisitExpr_(const FloatImm *op, std::ostream& os) { // NOLINT(* PrintConst(op, os, this); } +void CodeGenCUDA::PrintWmmaScope(const std::string &scope, Type t, + const Variable* variable, std::ostream &os) { + std::stringstream type; + PrintType(t, type); + std::string shape_str = fragment_shapes[variable]; + if (scope == "wmma.matrix_a") { + need_mma_h_ = true; + std::string layout_str = fragment_layouts[variable]; + os << "nvcuda::wmma::fragment"; + } else if (scope == "wmma.matrix_b") { + need_mma_h_ = true; + std::string layout_str = fragment_layouts[variable]; + os << "nvcuda::wmma::fragment"; + } else if (scope == "wmma.accumulator") { + need_mma_h_ = true; + os << "nvcuda::wmma::fragment"; + } +} + +int32_t CodeGenCUDA::GetWmmaFragmentSize(const std::string &scope, + const Variable* variable, int32_t size) { + std::string shape_str = fragment_shapes[variable]; + size_t m, n, k; + size_t last_pos = 0, pos = 0; + pos = shape_str.find(", ", last_pos); + m = std::stoi(shape_str.substr(last_pos, pos - last_pos)); + last_pos = pos + 2; + pos = shape_str.find(", ", last_pos); + n = std::stoi(shape_str.substr(last_pos, pos - last_pos)); + last_pos = pos + 2; + k = 
std::stoi(shape_str.substr(last_pos, shape_str.length() - last_pos)); + if (scope == "wmma.matrix_a") { + return size / m / k; + } else if (scope == "wmma.matrix_b") { + return size / n / k; + } else if (scope == "wmma.accumulator") { + return size / m / n; + } + return 0; +} + } // namespace codegen } // namespace tvm diff --git a/src/codegen/codegen_cuda.h b/src/codegen/codegen_cuda.h index 61c6fa3..53e7db4 100644 --- a/src/codegen/codegen_cuda.h +++ b/src/codegen/codegen_cuda.h @@ -28,6 +28,7 @@ #include #include #include +#include #include "codegen_c.h" namespace tvm { @@ -40,7 +41,7 @@ class CodeGenCUDA final : public CodeGenC { void AddFunction(LoweredFunc f); std::string Finish(); bool need_include_path() { - return (enable_fp16_ || enable_int8_ || need_math_constants_h_); + return (enable_fp16_ || enable_int8_ || need_math_constants_h_ || need_mma_h_); } // override behavior void VisitStmt_(const ir::For* op) final; @@ -60,7 +61,10 @@ class CodeGenCUDA final : public CodeGenC { void VisitExpr_(const Shuffle* op, std::ostream& os) final; // NOLINT(*) void VisitExpr_(const Broadcast* op, std::ostream& os) final; // NOLINT(*) void VisitExpr_(const FloatImm *op, std::ostream& os) final; + void VisitExpr_(const Call *op, std::ostream& os) final; void VisitStmt_(const Evaluate *op) final; + void VisitStmt_(const Allocate *op) final; + void VisitStmt_(const AttrStmt *op) final; private: // Whether global barrier is needed. @@ -75,7 +79,14 @@ class CodeGenCUDA final : public CodeGenC { bool enable_int8_{false}; // whether need math_constants.h bool need_math_constants_h_{false}; + // whether need mma.h + bool need_mma_h_{false}; + + std::unordered_map fragment_shapes; + std::unordered_map fragment_layouts; friend void PrintConst(const FloatImm* op, std::ostream& os, CodeGenCUDA* p); + void PrintWmmaScope(const std::string& scope, Type t, const Variable* variable, std::ostream& os); + int32_t GetWmmaFragmentSize(const std::string &scope, const Variable* variable, int32_t size); }; } // namespace codegen diff --git a/src/pass/infer_fragment.cc b/src/pass/infer_fragment.cc new file mode 100644 index 0000000..d9c0ef0 --- /dev/null +++ b/src/pass/infer_fragment.cc @@ -0,0 +1,224 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \brief Infer TensorCore metadata from tensor intrinsic. 
+ * \file tensorcore_fragment.cc + */ +#include +#include +#include +#include +#include +#include +#include "ir_util.h" +#include "storage_access.h" +#include "../runtime/thread_storage_scope.h" + +namespace tvm { +namespace ir { + +// Get fragment information from tensor intrinsics +class FragmentGetter : public IRVisitor { + public: + // fragment metadata + struct FragmentInfo { + // fragment shape + int m, n, k; + // fragment layout (row-major or column-major) + std::string layout; + FragmentInfo() = default; + FragmentInfo(int _m, int _n, int _k, const std::string& _layout) + : m(_m), n(_n), k(_k), layout(_layout) {} + }; + + void Visit_(const Call* op) final { + IRVisitor::Visit_(op); + + if (op->is_intrinsic(intrinsic::tvm_load_matrix_sync) || + op->is_intrinsic(intrinsic::tvm_store_matrix_sync)) { + // Get shape and layout information from load and store intrinsic + CHECK_EQ(op->args.size(), 8U); + const Variable* buffer_var = op->args[0].as(); + CHECK(buffer_var); + // Get shape + const IntImm* m = op->args[1].as(); + const IntImm* n = op->args[2].as(); + const IntImm* k = op->args[3].as(); + const StringImm* layout = op->args[7].as(); + CHECK(m); + CHECK(n); + CHECK(k); + CHECK(layout); + + std::string scope = scopes[buffer_var]; + if (fragments.count(buffer_var)) { + // check if the fragment has met before + FragmentInfo info = fragments[buffer_var]; + CHECK_EQ(m->value, info.m); + CHECK_EQ(n->value, info.n); + CHECK_EQ(k->value, info.k); + if (scope == "wmma.matrix_a" || scope == "wmma.matrix_b") { + CHECK_EQ(layout->value, info.layout); + } + } else { + // store metadata + FragmentInfo info; + if (scope == "wmma.matrix_a" || scope == "wmma.matrix_b") { + info = FragmentInfo(m->value, n->value, k->value, layout->value); + } else if (scope == "wmma.accumulator") { + info = FragmentInfo(m->value, n->value, k->value, ""); + } + fragments[buffer_var] = info; + } + } else if (op->is_intrinsic(intrinsic::tvm_fill_fragment)) { + // Get shape information from fill intrinsic + CHECK_EQ(op->args.size(), 6U); + const Variable* buffer_var = op->args[0].as(); + CHECK(buffer_var); + // Get shape + const IntImm* m = op->args[1].as(); + const IntImm* n = op->args[2].as(); + const IntImm* k = op->args[3].as(); + CHECK(m); + CHECK(n); + CHECK(k); + + std::string scope = scopes[buffer_var]; + // Only wmma.accumulator can use tvm_fill_fragment + CHECK_EQ(scope, "wmma.accumulator"); + if (fragments.count(buffer_var)) { + FragmentInfo info = fragments[buffer_var]; + CHECK_EQ(m->value, info.m); + CHECK_EQ(n->value, info.n); + CHECK_EQ(k->value, info.k); + } else { + FragmentInfo info(m->value, n->value, k->value, ""); + fragments[buffer_var] = info; + } + } + } + + // Get memory scope + void Visit_(const AttrStmt* op) final { + if (op->attr_key == attr::storage_scope) { + const Variable* buffer = op->node.as(); + CHECK(buffer); + scopes[buffer] = op->value.as()->value; + } + IRVisitor::Visit_(op); + } + + // Memory scope for allocations + std::unordered_map scopes; + // Fragment metadata for all fragments + std::unordered_map fragments; +}; + +// Check shape of fragment making sure it is a valid shape for tvm_mma_sync +class FragmentChecker : public IRVisitor { + public: + explicit FragmentChecker(const FragmentGetter &getter) : fragment_getter(getter) {} + + void Visit_(const Call* op) final { + // Check shape when calling tvm_mma_sync + if (op->is_intrinsic(intrinsic::tvm_mma_sync)) { + CHECK_EQ(op->args.size(), 8U); + const Variable* buffer_var_d = op->args[0].as(); + const Variable* buffer_var_a = 
op->args[2].as(); + const Variable* buffer_var_b = op->args[4].as(); + const Variable* buffer_var_c = op->args[6].as(); + CHECK(buffer_var_d); + CHECK(buffer_var_a); + CHECK(buffer_var_b); + CHECK(buffer_var_c); + + // Check all fragment A, B, C and D have the same shape + CHECK(CheckShape(buffer_var_d, buffer_var_a)); + CHECK(CheckShape(buffer_var_d, buffer_var_b)); + CHECK(CheckShape(buffer_var_d, buffer_var_c)); + } + } + + private: + // A tool for checking shapes of two fragments + bool CheckShape(const Variable* buffer1, const Variable* buffer2) { + CHECK(fragment_getter.fragments.count(buffer1)); + CHECK(fragment_getter.fragments.count(buffer2)); + FragmentGetter::FragmentInfo info1 = fragment_getter.fragments.at(buffer1); + FragmentGetter::FragmentInfo info2 = fragment_getter.fragments.at(buffer2); + return info1.m == info2.m && info1.n == info2.n && info1.k == info2.k; + } + // Fragment infomation + const FragmentGetter &fragment_getter; +}; + +// Store the metadata into attributes +class InferFragmenter : public IRMutator { + public: + explicit InferFragmenter(const FragmentGetter &getter) : fragment_getter(getter) {} + + Stmt Mutate_(const Allocate* op, const Stmt& s) final { + Stmt stmt = IRMutator::Mutate_(op, s); + const Variable* buffer = op->buffer_var.get(); + if (fragment_getter.fragments.count(buffer)) { + // Add attribute to fragments allocation + FragmentGetter::FragmentInfo info = fragment_getter.fragments.at(buffer); + + // Add shape attribute to all fragments + std::string shape = std::to_string(info.m) + ", " + + std::to_string(info.n) + ", " + + std::to_string(info.k); + Expr shape_expr = StringImm::make(shape); + Stmt shape_attr = AttrStmt::make(op->buffer_var, attr::fragment_shape, shape_expr, stmt); + if (info.layout != "") { + // Add shape attribute to matrix_a and matrix_b + Stmt layout_attr = AttrStmt::make(op->buffer_var, attr::fragment_layout, + StringImm::make(info.layout), shape_attr); + return layout_attr; + } else { + return shape_attr; + } + } + return stmt; + } + + private: + // Fragment infomation + const FragmentGetter &fragment_getter; +}; + +Stmt InferFragment(Stmt stmt) { + FragmentGetter getter; + getter.Visit(stmt); + FragmentChecker(getter).Visit(stmt); + stmt = InferFragmenter(getter).Mutate(stmt); + return stmt; +} + +LoweredFunc InferFragment(LoweredFunc f) { + CHECK_NE(f->func_type, kHostFunc); + auto n = make_node(*f.operator->()); + n->body = InferFragment(f->body); + return LoweredFunc(n); +} + +} // namespace ir +} // namespace tvm diff --git a/src/pass/storage_access.cc b/src/pass/storage_access.cc index f7deb25..8cad36d 100644 --- a/src/pass/storage_access.cc +++ b/src/pass/storage_access.cc @@ -341,5 +341,11 @@ Stmt LowerStorageAccessInfo(Stmt stmt) { return StorageAccessInfoLower().Mutate(stmt); } +LoweredFunc LowerDeviceStorageAccessInfo(LoweredFunc f) { + auto n = make_node(*f.operator->()); + n->body = LowerStorageAccessInfo(f->body); + return LoweredFunc(n); +} + } // namespace ir } // namespace tvm diff --git a/src/pass/storage_sync.cc b/src/pass/storage_sync.cc index 7c2f321..34dac52 100644 --- a/src/pass/storage_sync.cc +++ b/src/pass/storage_sync.cc @@ -263,6 +263,28 @@ class ThreadSyncInserter : public IRMutator { } } + Expr Mutate_(const Call* op, const Expr& e) final { + if (op->is_intrinsic(intrinsic::tvm_access_ptr)) { + Expr expr = IRMutator::Mutate_(op, e); + op = expr.as(); + CHECK_EQ(op->args.size(), 5U); + const Variable* buffer_var = op->args[1].as(); + Var var(GetRef(buffer_var)); + const IntImm* flag = 
op->args[4].as(); + if ((flag->value & 1) && sync_scope_.rank == StorageRank::kGlobal && + GetScope(buffer_var).rank == StorageRank::kGlobal) { + ++rw_stats_[var].read_count; + } + if (flag->value & 2 && sync_scope_.rank == StorageRank::kGlobal && + GetScope(buffer_var).rank == StorageRank::kGlobal) { + ++rw_stats_[var].write_count; + } + return expr; + } else { + return IRMutator::Mutate_(op, e); + } + } + private: // RW statistics about data struct Entry { diff --git a/src/runtime/thread_storage_scope.h b/src/runtime/thread_storage_scope.h index 0934e46..8e75fab 100644 --- a/src/runtime/thread_storage_scope.h +++ b/src/runtime/thread_storage_scope.h @@ -50,7 +50,13 @@ enum class StorageRank { */ kWarp = 2, /*! \brief thread local memory */ - kLocal = 3 + kLocal = 3, + /*! \brief wmma scope memory of matrix_a */ + kWMMAMatrixA = 4, + /*! \brief wmma scope memory of matrix_b */ + kWMMAMatrixB = 5, + /*! \brief wmma scope memory of accumulator */ + kWMMAAccumulator = 6, }; /*! @@ -89,6 +95,9 @@ struct StorageScope { case StorageRank::kShared: return "shared" + tag; case StorageRank::kWarp: return "warp" + tag; case StorageRank::kLocal: return "local" + tag; + case StorageRank::kWMMAMatrixA: return "wmma.matrix_a" + tag; + case StorageRank::kWMMAMatrixB: return "wmma.matrix_b" + tag; + case StorageRank::kWMMAAccumulator: return "wmma.accumulator" + tag; default: LOG(FATAL) << "unknown storage scope"; return ""; } } @@ -111,6 +120,15 @@ struct StorageScope { } else if (s.compare(0, 5, "local") == 0) { r.rank = StorageRank::kLocal; r.tag = s.substr(5, std::string::npos); + } else if (s.compare(0, 13, "wmma.matrix_a") == 0) { + r.rank = StorageRank::kWMMAMatrixA; + r.tag = s.substr(13, std::string::npos); + } else if (s.compare(0, 13, "wmma.matrix_b") == 0) { + r.rank = StorageRank::kWMMAMatrixB; + r.tag = s.substr(13, std::string::npos); + } else if (s.compare(0, 16, "wmma.accumulator") == 0) { + r.rank = StorageRank::kWMMAAccumulator; + r.tag = s.substr(16, std::string::npos); } else { LOG(FATAL) << "unknown storage scope " << s; } diff --git a/tests/python/unittest/test_schedule_tensor_core.py b/tests/python/unittest/test_schedule_tensor_core.py new file mode 100644 index 0000000..9fe72cd --- /dev/null +++ b/tests/python/unittest/test_schedule_tensor_core.py @@ -0,0 +1,386 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# 'License'); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+import tvm +import numpy as np +from topi.testing import conv2d_nhwc_python +from tvm.contrib import nvcc + +VERIFY = True + + +def intrin_wmma_load_matrix(shape, scope): + n, m, l = shape + if scope == "wmma.matrix_a": + row, col = n, l + elif scope == "wmma.matrix_b": + row, col = l, m + A = tvm.placeholder((row, col), name='A', dtype='float16') + BA = tvm.decl_buffer(A.shape, A.dtype, scope='shared', data_alignment=32, offset_factor=row * col) + C = tvm.compute((row, col), lambda i, j: A[i, j], name='C') + BC = tvm.decl_buffer(C.shape, C.dtype, scope=scope, data_alignment=32, offset_factor=row * col) + + def intrin_func(ins, outs): + ib = tvm.ir_builder.create() + + BA = ins[0] + BC = outs[0] + ib.emit(tvm.call_intrin('handle', 'tvm_load_matrix_sync', + BC.data, n, m, l, BC.elem_offset // (row * col), + BA.access_ptr('r'), col, 'row_major')) + return ib.get() + + return tvm.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, C: BC}) + + +def intrin_wmma_gemm(shape): + n, m, l = shape + A = tvm.placeholder((n, l), name='A', dtype='float16') + B = tvm.placeholder((l, m), name='B', dtype='float16') + k = tvm.reduce_axis((0, l), name="k") + C = tvm.compute((n, m), + lambda ii, jj: + tvm.sum(A[ii, k].astype('float') * B[k, jj].astype('float'), axis=k), + name='C') + BA = tvm.decl_buffer(A.shape, A.dtype, name='BA', scope='wmma.matrix_a', data_alignment=32, offset_factor=n * l) + BB = tvm.decl_buffer(B.shape, B.dtype, name='BB', scope='wmma.matrix_b', data_alignment=32, offset_factor=l * m) + BC = tvm.decl_buffer(C.shape, C.dtype, name='BC', scope='wmma.accumulator', data_alignment=32, offset_factor=n * m) + + def intrin_func(ins, outs): + BA, BB = ins + BC, = outs + + def init(): + ib = tvm.ir_builder.create() + ib.emit(tvm.call_intrin('handle', 'tvm_fill_fragment', BC.data, n, m, l, BC.elem_offset // (n * m), 0.0)) + return ib.get() + + def update(): + ib = tvm.ir_builder.create() + ib.emit(tvm.call_intrin('handle', 'tvm_mma_sync', + BC.data, BC.elem_offset // (n * m), + BA.data, BA.elem_offset // (n * l), + BB.data, BB.elem_offset // (l * m), + BC.data, BC.elem_offset // (n * m))) + return ib.get() + + return update(), init(), update() + + return tvm.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, B: BB, C: BC}) + + +def intrin_wmma_store_matrix(shape): + n, m, l = shape + A = tvm.placeholder((n, m), name='A', dtype='float32') + BA = tvm.decl_buffer(A.shape, A.dtype, scope='wmma.accumulator', data_alignment=32, offset_factor=n * m) + C = tvm.compute((n, m), lambda i, j: A[i, j], name='C') + BC = tvm.decl_buffer(C.shape, C.dtype, scope='global', data_alignment=32, offset_factor=n * m) + + def intrin_func(ins, outs): + ib = tvm.ir_builder.create() + + BA = ins[0] + BC = outs[0] + ib.emit(tvm.call_intrin('handle', 'tvm_store_matrix_sync', + BA.data, n, m, l, BA.elem_offset // (n * m), + BC.access_ptr('w'), m, 'row_major')) + return ib.get() + + return tvm.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, C: BC}) + + +def test_tensor_core_batch_matmal(): + if not tvm.gpu(0).exist or not tvm.module.enabled("cuda"): + print("skip because cuda is not enabled..") + return + if not nvcc.have_tensorcore(tvm.gpu(0).compute_version): + print("skip because gpu does not support tensor core") + return + + batch_size = 4 + n = 512 + m, l = n, n + assert (n % 32 == 0) + assert (m % 8 == 0) + assert (l % 16 == 0) + nn, mm, ll = n // 32, m // 8, l // 16 + A = tvm.placeholder((batch_size, nn, ll, 32, 16), name='A', dtype='float16') + B = tvm.placeholder((batch_size, ll, mm, 16, 8), name='B', 
dtype='float16') + k1 = tvm.reduce_axis((0, ll), name='k1') + k2 = tvm.reduce_axis((0, 16), name='k2') + C = tvm.compute((batch_size, nn, mm, 32, 8), + lambda b, i, j, ii, jj: + tvm.sum(A[b, i, k1, ii, k2].astype('float') * B[b, k1, j, k2, jj].astype('float'), axis=[k1, k2]), + name='Fragment_C') + s = tvm.create_schedule(C.op) + + warp_size = 32 + kernel_size = 16 + block_row_warps = 2 + block_col_warps = 4 + warp_row_tiles = 4 + warp_col_tiles = 2 + chunk = 4 + + block_x = tvm.thread_axis('blockIdx.x') + block_y = tvm.thread_axis('blockIdx.y') + block_z = tvm.thread_axis('blockIdx.z') + thread_x = tvm.thread_axis('threadIdx.x') + thread_y = tvm.thread_axis('threadIdx.y') + thread_z = tvm.thread_axis('threadIdx.z') + + AS = s.cache_read(A, 'shared', [C]) + BS = s.cache_read(B, 'shared', [C]) + AF = s.cache_read(AS, 'wmma.matrix_a', [C]) + BF = s.cache_read(BS, 'wmma.matrix_b', [C]) + CF = s.cache_write(C, 'wmma.accumulator') + + b, i, j, kernel_i, kernel_j = s[C].op.axis + i, ii = s[C].split(i, factor=warp_row_tiles) + block_i, i = s[C].split(i, factor=block_row_warps) + j, jj = s[C].split(j, factor=warp_col_tiles) + block_j, j = s[C].split(j, factor=block_col_warps) + s[C].reorder(block_i, block_j, i, j, ii, jj, kernel_i, kernel_j) + s[C].bind(b, block_z) + s[C].bind(block_i, block_x) + s[C].bind(block_j, block_y) + s[C].bind(i, thread_y) + s[C].bind(j, thread_z) + + s[CF].compute_at(s[C], j) + b, warp_i, warp_j, _i, _j = s[CF].op.axis + k, _k = CF.op.reduce_axis + ko, ki = s[CF].split(k, factor=chunk) + s[CF].reorder(ko, ki, warp_i, warp_j, _i, _j, _k) + + s[AF].compute_at(s[CF], ki) + s[BF].compute_at(s[CF], ki) + + s[AS].compute_at(s[CF], ko) + b, xo, yo, xi, yi = AS.op.axis + tx, xo = s[AS].split(xo, nparts=block_row_warps) + ty, yo = s[AS].split(yo, nparts=block_col_warps) + t = s[AS].fuse(xi, yi) + to, ti = s[AS].split(t, nparts=warp_size) + s[AS].bind(tx, thread_y) + s[AS].bind(ty, thread_z) + s[AS].bind(to, thread_x) + + s[BS].compute_at(s[CF], ko) + b, xo, yo, xi, yi = BS.op.axis + tx, xo = s[BS].split(xo, nparts=block_row_warps) + ty, yo = s[BS].split(yo, nparts=block_col_warps) + t = s[BS].fuse(xi, yi) + to, ti = s[BS].split(t, nparts=warp_size) + s[BS].bind(tx, thread_y) + s[BS].bind(ty, thread_z) + s[BS].bind(to, thread_x) + + s[AF].tensorize(AF.op.axis[-2], intrin_wmma_load_matrix((32, 8, 16), 'wmma.matrix_a')) + s[BF].tensorize(BF.op.axis[-2], intrin_wmma_load_matrix((32, 8, 16), 'wmma.matrix_b')) + s[C].tensorize(kernel_i, intrin_wmma_store_matrix((32, 8, 16))) + s[CF].tensorize(_i, intrin_wmma_gemm((32, 8, 16))) + + func = tvm.build(s, [A, B, C], 'cuda') + + ctx = tvm.gpu(0) + a_np = np.random.uniform(size=(batch_size, nn, ll, 32, 16)).astype(A.dtype) + b_np = np.random.uniform(size=(batch_size, ll, mm, 16, 8)).astype(B.dtype) + a = tvm.nd.array(a_np, ctx) + b = tvm.nd.array(b_np, ctx) + c = tvm.nd.array(np.zeros((batch_size, nn, mm, 32, 8), dtype=C.dtype), ctx) + func(a, b, c) + evaluator = func.time_evaluator(func.entry_name, ctx, number=3) + print('gemm with tensor core: %f ms' % (evaluator(a, b, c).mean * 1e3)) + + if VERIFY: + func(a, b, c) + a_np = a_np.transpose((0, 1, 3, 2, 4)).reshape(batch_size, n, n) + b_np = b_np.transpose((0, 1, 3, 2, 4)).reshape(batch_size, n, n) + c_np = c.asnumpy().transpose((0, 1, 3, 2, 4)).reshape(batch_size, n, n) + np.testing.assert_allclose(c_np, np.matmul(a_np.astype(C.dtype), b_np.astype(C.dtype)), rtol=1e-4, atol=1e-4) + + + +def test_tensor_core_batch_conv(): + if not tvm.gpu(0).exist or not tvm.module.enabled("cuda"): + 
print("skip because cuda is not enabled..") + return + if not nvcc.have_tensorcore(tvm.gpu(0).compute_version): + print("skip because gpu does not support tensor core") + return + + # The sizes of inputs and filters + batch_size = 32 + height = 14 + width = 14 + in_channels = 32 + out_channels = 64 + kernel_h = 3 + kernel_w = 3 + pad_h = 1 + pad_w = 1 + stride_h = 1 + stride_w = 1 + block_size = 16 + + block_row_warps = 2 + block_col_warps = 4 + warp_row_tiles = 4 + warp_col_tiles = 2 + warp_size = 32 + chunk = 2 + + # Input feature map: (N, H, W, IC, n, ic) + data_shape = (batch_size // block_size, + height, + width, + in_channels // block_size, + block_size, + block_size) + # Kernel: (H, W, IC, OC, ic, oc) + kernel_shape = (kernel_h, + kernel_w, + in_channels // block_size, + out_channels // block_size, + block_size, + block_size) + + # Output feature map: (N, H, W, OC, n, oc) + output_shape = (batch_size // block_size, + height, + width, + out_channels // block_size, + block_size, + block_size) + + assert (batch_size % block_size == 0) + assert (in_channels % block_size == 0) + assert (out_channels % block_size == 0) + + kh = tvm.reduce_axis((0, kernel_h), name='kh') + kw = tvm.reduce_axis((0, kernel_w), name='kw') + ic = tvm.reduce_axis((0, in_channels // block_size), name='ic') + ii = tvm.reduce_axis((0, block_size), name='ii') + + # Algorithm + A = tvm.placeholder(data_shape, name='A', dtype="float16") + W = tvm.placeholder(kernel_shape, name='W', dtype="float16") + Apad = tvm.compute( + (batch_size // block_size, height + 2 * pad_h, width + 2 * pad_w, in_channels // block_size, block_size, + block_size), + lambda n, h, w, i, nn, ii: tvm.if_then_else( + tvm.all(h >= pad_h, h - pad_h < height, + w >= pad_w, w - pad_w < width), + A[n, h - pad_h, w - pad_w, i, nn, ii], tvm.const(0., "float16")), + name='Apad') + Conv = tvm.compute(output_shape, + lambda n, h, w, o, nn, oo: tvm.sum( + Apad[n, h * stride_h + kh, w * stride_w + kw, ic, nn, ii].astype("float32") * + W[kh, kw, ic, o, ii, oo].astype("float32"), + axis=[ic, kh, kw, ii]), + name="Conv") + + s = tvm.create_schedule(Conv.op) + s[Apad].compute_inline() + + AS = s.cache_read(Apad, 'shared', [Conv]) + WS = s.cache_read(W, 'shared', [Conv]) + AF = s.cache_read(AS, 'wmma.matrix_a', [Conv]) + WF = s.cache_read(WS, 'wmma.matrix_b', [Conv]) + ConvF = s.cache_write(Conv, 'wmma.accumulator') + + block_x = tvm.thread_axis('blockIdx.x') + block_y = tvm.thread_axis('blockIdx.y') + block_z = tvm.thread_axis('blockIdx.z') + thread_x = tvm.thread_axis('threadIdx.x') + thread_y = tvm.thread_axis('threadIdx.y') + thread_z = tvm.thread_axis('threadIdx.z') + + nc, hc, wc, oc, nnc, ooc = Conv.op.axis + block_k = s[Conv].fuse(hc, wc) + s[Conv].bind(block_k, block_z) + nc, nci = s[Conv].split(nc, factor=warp_row_tiles) + block_i, nc = s[Conv].split(nc, factor=block_row_warps) + oc, oci = s[Conv].split(oc, factor=warp_col_tiles) + block_j, oc = s[Conv].split(oc, factor=block_col_warps) + s[Conv].reorder(block_k, block_i, block_j, nc, oc, nci, oci, nnc, ooc) + s[Conv].bind(block_i, block_x) + s[Conv].bind(block_j, block_y) + s[Conv].bind(nc, thread_y) + s[Conv].bind(oc, thread_z) + + s[ConvF].compute_at(s[Conv], oc) + n, h, w, o, nnf, oof = ConvF.op.axis + ko, ki = s[ConvF].split(ic, factor=chunk) + s[ConvF].reorder(ko, kh, ki, kw, n, o, nnf, oof, ii) + + s[AF].compute_at(s[ConvF], kw) + s[WF].compute_at(s[ConvF], kw) + + s[WS].compute_at(s[ConvF], kh) + s[AS].compute_at(s[ConvF], kh) + + n, h, w, i, nn, ii = AS.op.axis + tx, xo = s[AS].split(n, 
nparts=block_row_warps) + ty, yo = s[AS].split(xo, nparts=block_col_warps) + t = s[AS].fuse(nn, ii) + to, ti = s[AS].split(t, factor=warp_size) + s[AS].bind(tx, thread_y) + s[AS].bind(ty, thread_z) + s[AS].bind(ti, thread_x) + + kh, kw, ic, o, ii, oo = WS.op.axis + tx, xo = s[WS].split(o, nparts=block_row_warps) + ty, yo = s[WS].split(xo, nparts=block_col_warps) + t = s[WS].fuse(ii, oo) + to, ti = s[WS].split(t, nparts=warp_size) + s[WS].bind(tx, thread_y) + s[WS].bind(ty, thread_z) + s[WS].bind(to, thread_x) + s[WS].vectorize(ti) + + s[AF].tensorize(AF.op.axis[-2], intrin_wmma_load_matrix((16, 16, 16), 'wmma.matrix_a')) + s[WF].tensorize(WF.op.axis[-2], intrin_wmma_load_matrix((16, 16, 16), 'wmma.matrix_b')) + s[Conv].tensorize(nnc, intrin_wmma_store_matrix((16, 16, 16))) + s[ConvF].tensorize(nnf, intrin_wmma_gemm((16, 16, 16))) + + func = tvm.build(s, [A, W, Conv], 'cuda') + + ctx = tvm.gpu(0) + a_np = np.random.uniform(size=data_shape).astype(A.dtype) + w_np = np.random.uniform(size=kernel_shape).astype(W.dtype) + a = tvm.nd.array(a_np, ctx) + w = tvm.nd.array(w_np, ctx) + c = tvm.nd.array(np.zeros(output_shape, dtype=Conv.dtype), ctx) + evaluator = func.time_evaluator(func.entry_name, ctx, number=3) + print('conv2d with tensor core: %f ms' % (evaluator(a, w, c).mean * 1e3)) + + if VERIFY: + func(a, w, c) + a_np = a_np.transpose(0, 4, 1, 2, 3, 5).reshape(batch_size, height, width, in_channels) + w_np = w_np.transpose(0, 1, 2, 4, 3, 5).reshape(kernel_h, kernel_w, in_channels, out_channels) + c_np = c.asnumpy().transpose((0, 4, 1, 2, 3, 5)).reshape(batch_size, height, width, out_channels) + c_std = conv2d_nhwc_python(a_np.astype(Conv.dtype), + w_np.astype(Conv.dtype), + (stride_h, stride_w), + (pad_h, pad_w)).astype(Conv.dtype) + np.testing.assert_allclose(c_np, c_std, rtol=1e-4, atol=1e-4) + + +if __name__ == '__main__': + test_tensor_core_batch_matmal() + test_tensor_core_batch_conv() diff --git a/topi/python/topi/testing/conv2d_nhwc_python.py b/topi/python/topi/testing/conv2d_nhwc_python.py index d2ef40c..8a6a467 100644 --- a/topi/python/topi/testing/conv2d_nhwc_python.py +++ b/topi/python/topi/testing/conv2d_nhwc_python.py @@ -40,7 +40,7 @@ def conv2d_nhwc_python(a_np, w_np, stride, padding): Returns ------- b_np : np.ndarray - 4-D with shape [out_height, out_width, out_channel, batch] + 4-D with shape [batch, out_height, out_width, out_channel] """ batch, in_height, in_width, in_channel = a_np.shape kernel_h, kernel_w, _, num_filter = w_np.shape diff --git a/tutorials/optimize/opt_conv_tensorcore.py b/tutorials/optimize/opt_conv_tensorcore.py new file mode 100644 index 0000000..774b4c7 --- /dev/null +++ b/tutorials/optimize/opt_conv_tensorcore.py @@ -0,0 +1,348 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +.. 
_opt-conv-tensorcore:
+
+How to optimize convolution using TensorCores
+=============================================
+**Author**: `Siyuan Feng `_
+
+In this tutorial, we will demonstrate how to write a high-performance convolution
+schedule using TensorCores in TVM. In this example, we assume the input to the
+convolution has a large batch. We strongly recommend covering the :ref:`opt-conv-gpu` tutorial first.
+
+"""
+
+################################################################
+# TensorCore Introduction
+# -------------------------
+# Each Tensor Core provides a 4x4x4 matrix processing array that performs
+# :code:`D = A * B + C`, where A, B, C and D are 4x4 matrices as the figure shows.
+# The matrix multiplication inputs A and B are FP16 matrices, while the accumulation
+# matrices C and D may be FP16 or FP32 matrices.
+#
+# However, CUDA programmers can only use the warp-level primitive
+# :code:`wmma::mma_sync(acc_frag, a_frag, b_frag, acc_frag)` to perform
+# 16x16x16 half-precision matrix multiplication on Tensor Cores. Before invoking
+# the matrix multiplication, programmers must explicitly load data from memory into
+# registers with the :code:`wmma::load_matrix_sync` primitive. The NVCC compiler translates
+# that primitive into multiple memory load instructions. At run time, every thread loads
+# 16 elements from matrix A and 16 elements from B.
+
+################################################################
+# Preparation and Algorithm
+# --------------------------
+# We use fixed sizes for the input tensors: 256 channels and 14 x 14 spatial dimensions.
+# The batch size is 256. The convolution filter contains 512 filters of size 3 x 3.
+# We use stride 1 and padding 1 for the convolution. In this example, we use the
+# NHWCnc memory layout. The following code defines the convolution algorithm in TVM.
+ +import tvm +import numpy as np +from tvm.contrib import nvcc + +# The sizes of inputs and filters +batch_size = 256 +height = 14 +width = 14 +in_channels = 256 +out_channels = 512 +kernel_h = 3 +kernel_w = 3 +pad_h = 1 +pad_w = 1 +stride_h = 1 +stride_w = 1 + +# TensorCore shape +block_size = 16 + +assert (batch_size % block_size == 0) +assert (in_channels % block_size == 0) +assert (out_channels % block_size == 0) + +# Input feature map: (N, H, W, IC, n, ic) +data_shape = (batch_size // block_size, + height, + width, + in_channels // block_size, + block_size, + block_size) +# Kernel: (H, W, IC, OC, ic, oc) +kernel_shape = (kernel_h, + kernel_w, + in_channels // block_size, + out_channels // block_size, + block_size, + block_size) +# Output feature map: (N, H, W, OC, n, oc) +output_shape = (batch_size // block_size, + height, + width, + out_channels // block_size, + block_size, + block_size) + +# Reduction axes +kh = tvm.reduce_axis((0, kernel_h), name='kh') +kw = tvm.reduce_axis((0, kernel_w), name='kw') +ic = tvm.reduce_axis((0, in_channels // block_size), name='ic') +ii = tvm.reduce_axis((0, block_size), name='ii') + +# Algorithm +A = tvm.placeholder(data_shape, name='A', dtype="float16") +W = tvm.placeholder(kernel_shape, name='W', dtype="float16") +Apad = tvm.compute( + (batch_size // block_size, height + 2 * pad_h, width + 2 * pad_w, in_channels // block_size, block_size, + block_size), + lambda n, h, w, i, nn, ii: tvm.if_then_else( + tvm.all(h >= pad_h, h - pad_h < height, + w >= pad_w, w - pad_w < width), + A[n, h - pad_h, w - pad_w, i, nn, ii], tvm.const(0., "float16")), + name='Apad') +Conv = tvm.compute(output_shape, + lambda n, h, w, o, nn, oo: tvm.sum( + Apad[n, h * stride_h + kh, w * stride_w + kw, ic, nn, ii].astype("float32") * + W[kh, kw, ic, o, ii, oo].astype("float32"), + axis=[ic, kh, kw, ii]), + name="Conv") + +s = tvm.create_schedule(Conv.op) +s[Apad].compute_inline() + +############################################################################### +# Memory Scope +# ---------------- +# +# In traditional GPU schedule, we have global, shared and local memory scope. +# To support TensorCores, we add another three special memory scope: :code:`wmma.matrix_a`, +# :code:`wmma.matrix_b` and :code:`wmma.accumulator`. On hardware, all fragments scope +# stores at the on-chip registers level, the same place with local memory. + +# Designate the memory hierarchy +AS = s.cache_read(Apad, 'shared', [Conv]) +WS = s.cache_read(W, 'shared', [Conv]) +AF = s.cache_read(AS, 'wmma.matrix_a', [Conv]) +WF = s.cache_read(WS, 'wmma.matrix_b', [Conv]) +ConvF = s.cache_write(Conv, 'wmma.accumulator') + +############################################################################### +# Define Tensor Intrinsic +# In fact, TensorCore is a special hardware operation. So, we can just use tensorize +# to replace a unit of computation with the TensorCore instruction. The first thing is +# that we need to define tensor intrinsic. +# +# There are four basic operation in TensorCore: :code:`fill_fragment`, :code:`load_matrix`, +# :code:`mma_sync` and :code:`store_matrix`. Since :code:`fill_fragment` and :code:`mma_sync` +# are both used in matrix multiplication, so we can just write following three intrinsics. 
+ +def intrin_wmma_load_matrix(scope): + n = 16 + A = tvm.placeholder((n, n), name='A', dtype='float16') + BA = tvm.decl_buffer(A.shape, A.dtype, scope='shared', data_alignment=32, offset_factor=256) + C = tvm.compute((n, n), lambda i, j: A[i, j], name='C') + BC = tvm.decl_buffer(C.shape, C.dtype, scope=scope, data_alignment=32, offset_factor=256) + + def intrin_func(ins, outs): + ib = tvm.ir_builder.create() + + BA = ins[0] + BC = outs[0] + ib.emit(tvm.call_intrin('handle', 'tvm_load_matrix_sync', + BC.data, n, n, n, BC.elem_offset // 256, + BA.access_ptr('r'), n, 'row_major')) + return ib.get() + + return tvm.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, C: BC}) + + +def intrin_wmma_gemm(): + n = 16 + A = tvm.placeholder((n, n), name='A', dtype='float16') + B = tvm.placeholder((n, n), name='B', dtype='float16') + k = tvm.reduce_axis((0, n), name="k") + C = tvm.compute((n, n), + lambda ii, jj: + tvm.sum(A[ii, k].astype('float') * B[k, jj].astype('float'), axis=k), + name='C') + BA = tvm.decl_buffer(A.shape, A.dtype, name='BA', scope='wmma.matrix_a', data_alignment=32, offset_factor=256) + BB = tvm.decl_buffer(B.shape, B.dtype, name='BB', scope='wmma.matrix_b', data_alignment=32, offset_factor=256) + BC = tvm.decl_buffer(C.shape, C.dtype, name='BC', scope='wmma.accumulator', data_alignment=32, offset_factor=256) + + def intrin_func(ins, outs): + BA, BB = ins + BC, = outs + + def init(): + ib = tvm.ir_builder.create() + ib.emit(tvm.call_intrin('handle', 'tvm_fill_fragment', BC.data, n, n, n, BC.elem_offset // 256, 0.0)) + return ib.get() + + def update(): + ib = tvm.ir_builder.create() + ib.emit(tvm.call_intrin('handle', 'tvm_mma_sync', + BC.data, BC.elem_offset // 256, + BA.data, BA.elem_offset // 256, + BB.data, BB.elem_offset // 256, + BC.data, BC.elem_offset // 256)) + return ib.get() + + return update(), init(), update() + + return tvm.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, B: BB, C: BC}) + + +def intrin_wmma_store_matrix(): + n = 16 + A = tvm.placeholder((n, n), name='A', dtype='float32') + BA = tvm.decl_buffer(A.shape, A.dtype, scope='wmma.accumulator', data_alignment=32, offset_factor=256) + C = tvm.compute((n, n), lambda i, j: A[i, j], name='C') + BC = tvm.decl_buffer(C.shape, C.dtype, scope='global', data_alignment=32, offset_factor=256) + + def intrin_func(ins, outs): + ib = tvm.ir_builder.create() + BA = ins[0] + BC = outs[0] + ib.emit(tvm.call_intrin('handle', 'tvm_store_matrix_sync', + BA.data, n, n, n, BA.elem_offset // 256, + BC.access_ptr('w'), n, 'row_major')) + return ib.get() + + return tvm.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, C: BC}) + +############################################################################### +# Scheduling the Computation +# -------------------------- +# To use TensorCores in TVM, we must schedule the computation into specific structure +# to match the tensor intrinsic. The same as traditional GPU programs, we can also use +# shared memory to boost the speed. If you have any questions about blocking and shared +# memory, please refer :ref:`opt-conv-gpu`. +# +# In this example, each block contains 2x4 warps, and each warp calls 4x2 TensorCore +# instructions. Thus, the output shape of each warp is 64x32 and each block outputs +# 128x128 titles. Due to the limit of shared memory space, we only load 2 blocks (2x128x128 tiles) +# one time. +# +# .. 
note:: +# +# *Warp-level Operation* +# +# Note that all TensorCore instructions are warp-level instructions, which means all 32 threads +# in a warp should do this instruction simultaneously. Making theadIdx.x extent=32 is one of the +# easiest way to solve this. Then We can bind threadIdx.x to any loops except those contain +# TensorCore intrinsics directly or indirectly. Also note that it is not the unique solution. +# The only thing we should do is to make sure all threads in a warp can call TensorCore at the same time. +# + +# Define tiling sizes +block_row_warps = 4 +block_col_warps = 2 +warp_row_tiles = 2 +warp_col_tiles = 4 +warp_size = 32 +chunk = 2 + +block_x = tvm.thread_axis('blockIdx.x') +block_y = tvm.thread_axis('blockIdx.y') +block_z = tvm.thread_axis('blockIdx.z') +thread_x = tvm.thread_axis('threadIdx.x') +thread_y = tvm.thread_axis('threadIdx.y') +thread_z = tvm.thread_axis('threadIdx.z') + +nc, hc, wc, oc, nnc, ooc = Conv.op.axis +block_k = s[Conv].fuse(hc, wc) +s[Conv].bind(block_k, block_z) +nc, nci = s[Conv].split(nc, factor=warp_row_tiles) +block_i, nc = s[Conv].split(nc, factor=block_row_warps) +oc, oci = s[Conv].split(oc, factor=warp_col_tiles) +block_j, oc = s[Conv].split(oc, factor=block_col_warps) +s[Conv].reorder(block_k, block_i, block_j, nc, oc, nci, oci, nnc, ooc) +s[Conv].bind(block_i, block_x) +s[Conv].bind(block_j, block_y) +s[Conv].bind(nc, thread_y) +s[Conv].bind(oc, thread_z) + +# Schedule local computation +s[ConvF].compute_at(s[Conv], oc) +n, h, w, o, nnf, oof = ConvF.op.axis +ko, ki = s[ConvF].split(ic, factor=chunk) +s[ConvF].reorder(ko, kh, ki, kw, n, o, nnf, oof, ii) + +# Move intermediate computation into each output compute tile +s[AF].compute_at(s[ConvF], kw) +s[WF].compute_at(s[ConvF], kw) + +# Schedule for A's share memory +s[AS].compute_at(s[ConvF], kh) +n, h, w, i, nn, ii = AS.op.axis +tx, xo = s[AS].split(n, nparts=block_row_warps) +ty, yo = s[AS].split(xo, nparts=block_col_warps) +t = s[AS].fuse(nn, ii) +to, ti = s[AS].split(t, factor=warp_size) +s[AS].bind(tx, thread_y) +s[AS].bind(ty, thread_z) +s[AS].bind(ti, thread_x) + +# Schedule for W's share memory +s[WS].compute_at(s[ConvF], kh) +kh, kw, ic, o, ii, oo = WS.op.axis +tx, xo = s[WS].split(o, nparts=block_row_warps) +ty, yo = s[WS].split(xo, nparts=block_col_warps) +t = s[WS].fuse(ii, oo) +to, ti = s[WS].split(t, nparts=warp_size) +s[WS].bind(tx, thread_y) +s[WS].bind(ty, thread_z) +s[WS].bind(to, thread_x) +s[WS].vectorize(ti) +print(tvm.lower(s, [A, W, Conv], simple_mode=True)) + +############################################################################### +# Lowering Computation to Intrinsics +# -------------------------- +# The last phase is to lower the computation loops down to TensorCore hardware intrinsics +# by mapping the 2D convolution to tensor intrinsics +# + +s[AF].tensorize(AF.op.axis[-2], intrin_wmma_load_matrix('wmma.matrix_a')) +s[WF].tensorize(WF.op.axis[-2], intrin_wmma_load_matrix('wmma.matrix_b')) +s[Conv].tensorize(nnc, intrin_wmma_store_matrix()) +s[ConvF].tensorize(nnf, intrin_wmma_gemm()) +print(tvm.lower(s, [A, W, Conv], simple_mode=True)) + +############################################################################### +# Generate CUDA Kernel +# -------------------- +# Finally we use TVM to generate and compile the CUDA kernel, and evaluate the latency of convolution. 
+# Since TensorCores are only supported in NVIDIA GPU with Compute Capability 7.0 or higher, it may not +# be able to run on our build server + +ctx = tvm.gpu(0) +if nvcc.have_tensorcore(ctx.compute_version): + with tvm.build_config(auto_unroll_max_step=16): + func = tvm.build(s, [A, W, Conv], 'cuda') + a_np = np.random.uniform(size=data_shape).astype(A.dtype) + w_np = np.random.uniform(size=kernel_shape).astype(W.dtype) + a = tvm.nd.array(a_np, ctx) + w = tvm.nd.array(w_np, ctx) + c = tvm.nd.array(np.zeros(output_shape, dtype=Conv.dtype), ctx) + evaluator = func.time_evaluator(func.entry_name, ctx, number=10) + print('conv2d with tensor core: %f ms' % (evaluator(a, w, c).mean * 1e3)) + +############################################################################### +# Summary +# This tutorial demonstrates how TVM scheduling primitives can be used to +# call TensorCores on specific GPUs. diff --git a/vta/python/vta/build_module.py b/vta/python/vta/build_module.py index 5c24375..cec217c 100644 --- a/vta/python/vta/build_module.py +++ b/vta/python/vta/build_module.py @@ -80,6 +80,7 @@ def build_config(debug_flag=0, **kwargs): if debug_flag: pass_list.append((1, add_debug)) pass_list.append((2, ir_pass.inject_alu_intrin)) + pass_list.append((3, tvm.ir_pass.LowerStorageAccessInfo)) pass_list.append((3, ir_pass.fold_uop_loop)) pass_list.append((3, ir_pass.cpu_access_rewrite)) return tvm.build_config(add_lower_pass=pass_list, **kwargs)
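Editor's note: for readers unfamiliar with the WMMA API that the four new intrinsics (tvm_fill_fragment, tvm_load_matrix_sync, tvm_mma_sync, tvm_store_matrix_sync) are lowered to by CodeGenCUDA, the sketch below is a minimal hand-written CUDA kernel that computes one 16x16x16 tile with nvcuda::wmma. It is an illustrative sketch, not the code TVM generates for the schedules above: a fixed 16x16x16 shape, row-major layouts, a single warp, and no shared-memory staging are assumed, and the kernel name and pointer arguments are hypothetical.

// Minimal sketch (assumed names): one warp computes D = A * B + C for a
// single 16x16x16 tile, using the same nvcuda::wmma calls that
// tvm_fill_fragment, tvm_load_matrix_sync, tvm_mma_sync and
// tvm_store_matrix_sync map to in codegen_cuda.cc.
// Requires sm_70 or newer and a launch with one full warp, e.g.
//   wmma_tile_16x16x16<<<1, 32>>>(a, b, d);
#include <mma.h>
#include <cuda_fp16.h>

__global__ void wmma_tile_16x16x16(const half* a,   // 16x16 row-major fp16
                                   const half* b,   // 16x16 row-major fp16
                                   float* d) {      // 16x16 row-major fp32
  using namespace nvcuda;
  // Fragment declarations: this is what PrintWmmaScope emits for buffers in
  // the wmma.matrix_a / wmma.matrix_b / wmma.accumulator storage scopes, with
  // the shape and layout taken from the fragment_shape and fragment_layout
  // attributes attached by the new InferFragment pass.
  wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> a_frag;
  wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag;
  wmma::fragment<wmma::accumulator, 16, 16, 16, float> acc_frag;

  // tvm_fill_fragment  ->  fill_fragment(fragment[index], value)
  wmma::fill_fragment(acc_frag, 0.0f);

  // tvm_load_matrix_sync  ->  load_matrix_sync(fragment[index], ptr, stride)
  wmma::load_matrix_sync(a_frag, a, 16);
  wmma::load_matrix_sync(b_frag, b, 16);

  // tvm_mma_sync  ->  mma_sync(d_frag, a_frag, b_frag, c_frag)
  wmma::mma_sync(acc_frag, a_frag, b_frag, acc_frag);

  // tvm_store_matrix_sync  ->  store_matrix_sync(ptr, fragment[index], stride, layout)
  wmma::store_matrix_sync(d, acc_frag, 16, wmma::mem_row_major);
}

In the kernels generated for the schedules above, the same four calls appear once per fragment index, and the number of fragments in each allocation is what GetWmmaFragmentSize computes from the fragment_shape attribute: for example, a wmma.matrix_a buffer of 2048 half elements with fragment shape "16, 16, 16" is declared as an array of 2048 / (16 * 16) = 8 fragments.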