TensorCore Support using Intrinsic (#4136)
authorSiyuan Feng <Hzfengsy@vip.qq.com>
Thu, 24 Oct 2019 19:04:37 +0000 (12:04 -0700)
committerLeyuan Wang <laurawly@gmail.com>
Thu, 24 Oct 2019 19:04:37 +0000 (12:04 -0700)
* add tensor core support

* avoid memory bank conflict

* fix thread sync & better performance

* better performance

* add schedule test for conv2d

* extend into BatchMatMul

* support config fragment shape and layout using intrinsic

* add TensorCore tutorial

* add int support and fix lint

* address comment

* add 32*16*8 TensorCore test

* fix wmma include logic

15 files changed:
include/tvm/ir.h
include/tvm/ir_pass.h
python/tvm/build_module.py
src/api/api_pass.cc
src/codegen/build_module.cc
src/codegen/codegen_cuda.cc
src/codegen/codegen_cuda.h
src/pass/infer_fragment.cc [new file with mode: 0644]
src/pass/storage_access.cc
src/pass/storage_sync.cc
src/runtime/thread_storage_scope.h
tests/python/unittest/test_schedule_tensor_core.py [new file with mode: 0644]
topi/python/topi/testing/conv2d_nhwc_python.py
tutorials/optimize/opt_conv_tensorcore.py [new file with mode: 0644]
vta/python/vta/build_module.py

index b908049..37718fe 100644 (file)
@@ -1311,6 +1311,16 @@ constexpr const char* opengl_stage_scope = "opengl_stage_scope";
 constexpr const char* device_scope = "device_scope";
 
 /*!
+ * \brief Mark that the shape of TensorCore fragment
+ */
+constexpr const char* fragment_shape = "fragment_shape";
+
+/*!
+ * \brief Mark that the layout of TensorCore fragment
+ */
+constexpr const char* fragment_layout = "fragment_layout";
+
+/*!
  * \brief Check if attr_key is a pragma key extension
  * \param attr_key The attr key to be compared
  * \return true if it is a pragma key
@@ -1552,6 +1562,54 @@ constexpr const char* tvm_global_barrier_kinit = "tvm_global_barrier_kinit";
  *  }
  */
 constexpr const char* tvm_thread_allreduce = "tvm_thread_allreduce";
+/*!
+ * \brief tvm intrinsic for tensor core load operators.
+ *
+ *  void tvm_load_matrix_sync(Var fragment, UIntImm m, UIntImm, n, UIntImm k,
+ *                            Expr index, Expr buffer_ptr, Expr stride,
+ *                            StringImm layout) {
+ *    // m, n, k are the shape of wmma fragment.
+ *    // Determine fragment layout(column-major or row major) by layout.
+ *    // fragments must be in 'wmma.matrix_a' or 'wmma.matrix_b' scope.
+ *    nvcuda::wmma::load_matrix_sync(fragment[index], buffer_ptr, stride);
+ *  }
+ */
+constexpr const char* tvm_load_matrix_sync = "tvm_load_matrix_sync";
+/*!
+ * \brief tvm intrinsic for tensor core mma_sync operators.
+ *
+ *  void tvm_mma_sync(Var fragment_d, Expr index_d,
+ *                    Var fragment_a, Expr index_a,
+ *                    Var fragment_b, Expr index_b,
+ *                    Var fragment_c, Expr index_c) {
+ *    nvcuda::wmma::mma_sync(fragment_d[index_d], fragment_a[index_a],
+ *                           fragment_b[index_b], fragment_c[index_c]);
+ *  }
+ */
+constexpr const char* tvm_mma_sync = "tvm_mma_sync";
+/*!
+ * \brief tvm intrinsic for tensor core fill_fragment operators.
+ *
+ *  void tvm_fill_fragment(Var fragment, UIntImm m, UIntImm, n, UIntImm k,
+ *                         Expr index, Expr value) {
+ *    // m, n, k are the shape of wmma fragment
+ *    // fragments must be in 'wmma.accumulator' scope.
+ *    nvcuda::wmma::fill_fragment(fragment[index], value);
+ *  }
+ */
+constexpr const char* tvm_fill_fragment = "tvm_fill_fragment";
+/*!
+ * \brief tvm intrinsic for tensor core store operators.
+ *
+ *  void tvm_store_matrix_sync(Var fragment, UIntImm m, UIntImm, n, UIntImm k,
+ *                             Expr index, Expr buffer_ptr, Expr stride,
+ *                             StringImm layout) {
+ *    // m, n, k are the shape of wmma fragment
+ *    // fragments must be in 'wmma.accumulator' scope.
+ *    nvcuda::wmma::store_matrix_sync(fragment[index], buffer_ptr, stride, layout);
+ *  }
+ */
+constexpr const char* tvm_store_matrix_sync = "tvm_store_matrix_sync";
 
 }   // namespace intrinsic
 
index 03078b8..842c6af 100644 (file)
@@ -514,6 +514,15 @@ LoweredFunc CombineContextCall(LoweredFunc f);
 LoweredFunc PointerValueTypeRewrite(LoweredFunc f);
 
 /*!
+ * \brief Lower attached storage access information on device.
+ * Do this pass after all storage access analysis finish.
+ *
+ * \param func The device function to be lowered.
+ * \return Transformed function.
+ */
+LoweredFunc LowerDeviceStorageAccessInfo(LoweredFunc func);
+
+/*!
  * \brief Lower intrinsic function calls.
  * \param f The device function to be lowered.
  * \param target The target device.
@@ -533,6 +542,14 @@ LoweredFunc LowerIntrin(LoweredFunc f, const std::string& target);
 LoweredFunc LowerCustomDatatypes(LoweredFunc f, const std::string& target);
 
 /*!
+ * \brief Infer the TensorCore fragment infomation using tensor intrinsics
+ *
+ * \param f The device function to be lowered.
+ * \return Transformed function.
+ */
+LoweredFunc InferFragment(LoweredFunc f);
+
+/*!
  * \brief Verify if memory accesses are legal for a specific target device type.
  *
  *  In the case that tgt is cuda, if not all workload is bound with
index 479bdce..fe2f641 100644 (file)
@@ -413,7 +413,6 @@ def lower(sch,
 
     # Phase 3
     stmt = ir_pass.Simplify(stmt)
-    stmt = ir_pass.LowerStorageAccessInfo(stmt)
     stmt = ir_pass.RemoveNoOp(stmt)
     if not cfg.disable_select_rewriting:
         stmt = ir_pass.RewriteUnsafeSelect(stmt)
@@ -465,6 +464,7 @@ def _build_for_device(flist, target, target_host):
                 func = ir_pass.ThreadSync(func, "global")
             func = ir_pass.ThreadSync(func, "shared")
             func = ir_pass.ThreadSync(func, "warp")
+            func = ir_pass.InferFragment(func)
             warp_size = target.thread_warp_size
             func = ir_pass.LowerThreadAllreduce(func, warp_size)
             fsplits = [s for s in ir_pass.SplitHostDevice(func)]
@@ -494,6 +494,8 @@ def _build_for_device(flist, target, target_host):
         assert not fdevice
 
     target_host = _target.create(target_host)
+    fdevice = [ir_pass.LowerDeviceStorageAccessInfo(x) for x in fdevice]
+    fhost = [ir_pass.LowerDeviceStorageAccessInfo(x) for x in fhost]
     fdevice = [ir_pass.LowerIntrin(x, target.target_name) for x in fdevice]
     fhost = [ir_pass.LowerIntrin(x, target_host.target_name) for x in fhost]
     fhost = [ir_pass.CombineContextCall(x) for x in fhost]
index dd0415a..d7f621f 100644 (file)
@@ -118,6 +118,14 @@ TVM_REGISTER_API("ir_pass.PostOrderVisit")
       });
   });
 
+TVM_REGISTER_API("ir_pass.LowerStorageAccess")
+.set_body([](TVMArgs args, TVMRetValue *ret) {
+  LoweredFunc f = args[0];
+  auto n = make_node<LoweredFuncNode>(*f.operator->());
+  n->body = LowerStorageAccessInfo(f->body);
+  *ret = LoweredFunc(n);
+});
+
 // make from two arguments
 #define REGISTER_PASS(PassName)                                   \
   TVM_REGISTER_API("ir_pass."#PassName)                           \
@@ -140,6 +148,7 @@ REGISTER_PASS(SplitHostDevice);
 REGISTER_PASS(StorageRewrite);
 REGISTER_PASS(CoProcSync);
 REGISTER_PASS(LowerStorageAccessInfo);
+REGISTER_PASS(LowerDeviceStorageAccessInfo)
 REGISTER_PASS(InjectVirtualThread);
 REGISTER_PASS(InjectPrefetch);
 REGISTER_PASS(InjectDoubleBuffer);
@@ -161,5 +170,6 @@ REGISTER_PASS(DecorateDeviceScope);
 REGISTER_PASS(InstrumentBoundCheckers);
 REGISTER_PASS(VerifyCompactBuffer);
 REGISTER_PASS(HoistIfThenElse);
+REGISTER_PASS(InferFragment)
 }  // namespace ir
 }  // namespace tvm
index 66340e9..cfcb060 100644 (file)
@@ -422,7 +422,6 @@ Stmt BuildStmt(Schedule sch,
 
   // Phase 2
   stmt = ir::Simplify(stmt);
-  stmt = ir::LowerStorageAccessInfo(stmt);
   stmt = ir::RemoveNoOp(stmt);
 
   if (!(config->disable_select_rewriting))
@@ -517,6 +516,7 @@ Array<Array<LoweredFunc> > split_dev_host_funcs(const Array<LoweredFunc>& funcs,
   for (size_t i = 0; i < fhost.size(); ++i) {
     auto func = fhost[i];
     func = ir::BindDeviceType(func, target->device_type);
+    func = ir::LowerDeviceStorageAccessInfo(func);
     func = ir::LowerTVMBuiltin(func);
     fhost.Set(i, func);
   }
@@ -524,6 +524,7 @@ Array<Array<LoweredFunc> > split_dev_host_funcs(const Array<LoweredFunc>& funcs,
   for (size_t i = 0; i < fhost.size(); ++i) {
     auto func = fhost[i];
     func = ir::LowerIntrin(func, target_host->target_name);
+    func = ir::LowerDeviceStorageAccessInfo(func);
     func = ir::CombineContextCall(func);
     fhost.Set(i, func);
   }
index 241310f..55b4810 100644 (file)
@@ -74,6 +74,10 @@ std::string CodeGenCUDA::Finish() {
     decl_stream << "#include <math_constants.h>\n";
   }
 
+  if (need_mma_h_) {
+    decl_stream << "#include <mma.h>\n";
+  }
+
   return CodeGenC::Finish();
 }
 
@@ -102,14 +106,22 @@ void CodeGenCUDA::PrintType(Type t, std::ostream& os) {  // NOLINT(*)
   bool fail = false;
   if (t.is_float()) {
     switch (t.bits()) {
-      case 16: os << "half";
+      case 16:
         enable_fp16_ = true;
+        if (lanes == 1) {
+          os << "half";
+        } else if (lanes <= 8) {
+          CHECK_EQ(lanes % 2, 0) << "only support even lane for half type";
+          os << "float" << lanes / 2;
+        } else {
+          fail = true;
+        }
         break;
       case 32: os << "float"; break;
       case 64: os << "double"; break;
       default: fail = true; break;
     }
-    if (!fail && lanes == 1) return;
+    if (!fail && (lanes == 1 || t.bits() == 16)) return;
     if (!fail && (lanes >= 2 && lanes <= 4)) {
       os << lanes; return;
     }
@@ -290,6 +302,113 @@ void CodeGenCUDA::PrintStorageScope(
   }
 }
 
+void CodeGenCUDA::VisitExpr_(const Call *op, std::ostream& os) {
+  if (op->is_intrinsic(intrinsic::tvm_fill_fragment)) {
+    need_mma_h_ = true;
+    CHECK_EQ(op->args.size(), 6U);
+    os << "nvcuda::wmma::fill_fragment(";
+    this->PrintExpr(op->args[0], os);
+    os << "[";
+    this->PrintExpr(op->args[4], os);
+    os << "], ";
+    this->PrintExpr(op->args[5], os);
+    os << ")";
+  } else if (op->is_intrinsic(intrinsic::tvm_load_matrix_sync)) {
+    need_mma_h_ = true;
+    CHECK_EQ(op->args.size(), 8U);
+    os << "nvcuda::wmma::load_matrix_sync(";
+    this->PrintExpr(op->args[0], os);
+    os << "[";
+    this->PrintExpr(op->args[4], os);
+    os << "], ";
+    this->PrintExpr(op->args[5], os);
+    os << ", ";
+    this->PrintExpr(op->args[6], os);
+    os << ")";
+  } else if (op->is_intrinsic(intrinsic::tvm_store_matrix_sync)) {
+    need_mma_h_ = true;
+    CHECK_EQ(op->args.size(), 8U);
+    os << "nvcuda::wmma::store_matrix_sync(";
+    this->PrintExpr(op->args[5], os);
+    os << ", ";
+    this->PrintExpr(op->args[0], os);
+    os << "[";
+    this->PrintExpr(op->args[4], os);
+    os << "], ";
+    this->PrintExpr(op->args[6], os);
+    if (const StringImm *str = op->args[7].as<StringImm>()) {
+      os << ", nvcuda::wmma::mem_" << str->value;
+    } else {
+      LOG(FATAL) << "Invalid parameters";
+    }
+    os << ")";
+  } else if (op->is_intrinsic(intrinsic::tvm_mma_sync)) {
+    need_mma_h_ = true;
+    CHECK_EQ(op->args.size(), 8U);
+    os << "nvcuda::wmma::mma_sync(";
+    for (int i = 0; i < 4; ++i) {
+      this->PrintExpr(op->args[i * 2], os);
+      os << "[";
+      this->PrintExpr(op->args[i * 2 + 1], os);
+      os << "]" << ((i < 3) ? ", ": ")");
+    }
+  } else {
+    CodeGenC::VisitExpr_(op, os);
+  }
+}
+
+void CodeGenCUDA::VisitStmt_(const AttrStmt* op) {
+  if (op->attr_key == attr::fragment_shape) {
+    const Variable* buffer = op->node.as<Variable>();
+    const StringImm* shape_str = op->value.as<StringImm>();
+    fragment_shapes[buffer] = shape_str->value;
+  } else if (op->attr_key == attr::fragment_layout) {
+    const Variable* buffer = op->node.as<Variable>();
+    const StringImm* layout_str = op->value.as<StringImm>();
+    fragment_layouts[buffer] = layout_str->value;
+  }
+  CodeGenC::VisitStmt_(op);
+}
+
+void CodeGenCUDA::VisitStmt_(const Allocate* op) {
+  CHECK(!is_zero(op->condition));
+  std::string vid = AllocVarID(op->buffer_var.get());
+  if (op->new_expr.defined()) {
+    // Prefer global static allocation for the program
+    CHECK_EQ(op->free_function, "nop");
+    std::string new_data = PrintExpr(op->new_expr);
+    this->PrintIndent();
+    PrintType(op->type, stream);
+    stream << "* "<< vid << '=' << new_data << ";\n";
+  } else {
+    this->PrintIndent();
+    int32_t constant_size = op->constant_allocation_size();
+    CHECK_GT(constant_size, 0)
+      << "Can only handle constant size stack allocation for now";
+    const Variable* buffer = op->buffer_var.as<Variable>();
+    std::string scope = alloc_storage_scope_.at(buffer);
+    if (scope.find("wmma.") == 0) {
+      if (scope == "wmma.matrix_a" || scope == "wmma.matrix_b") {
+        CHECK(op->type == Float(16) || op->type == Int(8) || op->type == UInt(8))
+          << "Matrix_a and matrix_b only support half or char or unsigned char type for now";
+      } else {
+        CHECK(op->type == Float(16) || op->type == Float(32) || op->type == Int(32))
+          << "Accumulator only support half, float and int type for now";
+      }
+      constant_size = GetWmmaFragmentSize(scope, buffer, constant_size);
+      PrintWmmaScope(scope, op->type, buffer, stream);
+    } else {
+      PrintStorageScope(scope, stream);
+      stream << ' ';
+      PrintType(op->type, stream);
+    }
+    stream << ' '<< vid << '['
+           << constant_size << "];\n";
+  }
+  RegisterHandleType(op->buffer_var.get(), op->type);
+  this->PrintStmt(op->body);
+}
+
 void CodeGenCUDA::VisitStmt_(const Evaluate *op) {
   if (is_const(op->value)) return;
   const Call* call = op->value.as<Call>();
@@ -392,5 +511,49 @@ void CodeGenCUDA::VisitExpr_(const FloatImm *op, std::ostream& os) { // NOLINT(*
   PrintConst(op, os, this);
 }
 
+void CodeGenCUDA::PrintWmmaScope(const std::string &scope, Type t,
+    const Variable* variable, std::ostream &os) {
+  std::stringstream type;
+  PrintType(t, type);
+  std::string shape_str = fragment_shapes[variable];
+  if (scope == "wmma.matrix_a") {
+    need_mma_h_ = true;
+    std::string layout_str = fragment_layouts[variable];
+    os << "nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, "
+      << shape_str << ", " << type.str() << ", nvcuda::wmma::" << layout_str <<">";
+  } else if (scope == "wmma.matrix_b") {
+    need_mma_h_ = true;
+    std::string layout_str = fragment_layouts[variable];
+    os << "nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, "
+       << shape_str << ", " << type.str() << ", nvcuda::wmma::" << layout_str <<">";
+  } else if (scope == "wmma.accumulator") {
+    need_mma_h_ = true;
+    os << "nvcuda::wmma::fragment<nvcuda::wmma::accumulator, "
+       << shape_str << ", "<< type.str() << ">";
+  }
+}
+
+int32_t CodeGenCUDA::GetWmmaFragmentSize(const std::string &scope,
+                                         const Variable* variable, int32_t size) {
+  std::string shape_str = fragment_shapes[variable];
+  size_t m, n, k;
+  size_t last_pos = 0, pos = 0;
+  pos = shape_str.find(", ", last_pos);
+  m = std::stoi(shape_str.substr(last_pos, pos - last_pos));
+  last_pos = pos + 2;
+  pos = shape_str.find(", ", last_pos);
+  n = std::stoi(shape_str.substr(last_pos, pos - last_pos));
+  last_pos = pos + 2;
+  k = std::stoi(shape_str.substr(last_pos, shape_str.length() - last_pos));
+  if (scope == "wmma.matrix_a") {
+    return size / m / k;
+  } else if (scope == "wmma.matrix_b") {
+    return size / n / k;
+  } else if (scope == "wmma.accumulator") {
+    return size / m / n;
+  }
+  return 0;
+}
+
 }  // namespace codegen
 }  // namespace tvm
index 61c6fa3..53e7db4 100644 (file)
@@ -28,6 +28,7 @@
 #include <tvm/codegen.h>
 #include <tvm/packed_func_ext.h>
 #include <string>
+#include <unordered_map>
 #include "codegen_c.h"
 
 namespace tvm {
@@ -40,7 +41,7 @@ class CodeGenCUDA final : public CodeGenC {
   void AddFunction(LoweredFunc f);
   std::string Finish();
   bool need_include_path() {
-    return (enable_fp16_ || enable_int8_ || need_math_constants_h_);
+    return (enable_fp16_ || enable_int8_ || need_math_constants_h_ || need_mma_h_);
   }
   // override behavior
   void VisitStmt_(const ir::For* op) final;
@@ -60,7 +61,10 @@ class CodeGenCUDA final : public CodeGenC {
   void VisitExpr_(const Shuffle* op, std::ostream& os) final; // NOLINT(*)
   void VisitExpr_(const Broadcast* op, std::ostream& os) final; // NOLINT(*)
   void VisitExpr_(const FloatImm *op, std::ostream& os) final;
+  void VisitExpr_(const Call *op, std::ostream& os) final;
   void VisitStmt_(const Evaluate *op) final;
+  void VisitStmt_(const Allocate *op) final;
+  void VisitStmt_(const AttrStmt *op) final;
 
  private:
   // Whether global barrier is needed.
@@ -75,7 +79,14 @@ class CodeGenCUDA final : public CodeGenC {
   bool enable_int8_{false};
   // whether need math_constants.h
   bool need_math_constants_h_{false};
+  // whether need mma.h
+  bool need_mma_h_{false};
+
+  std::unordered_map<const Variable*, std::string> fragment_shapes;
+  std::unordered_map<const Variable*, std::string> fragment_layouts;
   friend void PrintConst(const FloatImm* op, std::ostream& os, CodeGenCUDA* p);
+  void PrintWmmaScope(const std::string& scope, Type t, const Variable* variable, std::ostream& os);
+  int32_t GetWmmaFragmentSize(const std::string &scope, const Variable* variable, int32_t size);
 };
 
 }  // namespace codegen
diff --git a/src/pass/infer_fragment.cc b/src/pass/infer_fragment.cc
new file mode 100644 (file)
index 0000000..d9c0ef0
--- /dev/null
@@ -0,0 +1,224 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2019 by Contributors
+ * \brief Infer TensorCore metadata from tensor intrinsic.
+ * \file tensorcore_fragment.cc
+ */
+#include <tvm/ir.h>
+#include <tvm/ir_pass.h>
+#include <tvm/ir_mutator.h>
+#include <tvm/ir_visitor.h>
+#include <unordered_map>
+#include <unordered_set>
+#include "ir_util.h"
+#include "storage_access.h"
+#include "../runtime/thread_storage_scope.h"
+
+namespace tvm {
+namespace ir {
+
+// Get fragment information from tensor intrinsics
+class FragmentGetter : public IRVisitor {
+ public:
+  // fragment metadata
+  struct FragmentInfo {
+    // fragment shape
+    int m, n, k;
+    // fragment layout (row-major or column-major)
+    std::string layout;
+    FragmentInfo() = default;
+    FragmentInfo(int _m, int _n, int _k, const std::string& _layout)
+      : m(_m), n(_n), k(_k), layout(_layout) {}
+  };
+
+  void Visit_(const Call* op) final {
+    IRVisitor::Visit_(op);
+
+    if (op->is_intrinsic(intrinsic::tvm_load_matrix_sync) ||
+        op->is_intrinsic(intrinsic::tvm_store_matrix_sync)) {
+      // Get shape and layout information from load and store intrinsic
+      CHECK_EQ(op->args.size(), 8U);
+      const Variable* buffer_var = op->args[0].as<Variable>();
+      CHECK(buffer_var);
+      // Get shape
+      const IntImm* m = op->args[1].as<IntImm>();
+      const IntImm* n = op->args[2].as<IntImm>();
+      const IntImm* k = op->args[3].as<IntImm>();
+      const StringImm* layout = op->args[7].as<StringImm>();
+      CHECK(m);
+      CHECK(n);
+      CHECK(k);
+      CHECK(layout);
+
+      std::string scope = scopes[buffer_var];
+      if (fragments.count(buffer_var)) {
+        // check if the fragment has met before
+        FragmentInfo info = fragments[buffer_var];
+        CHECK_EQ(m->value, info.m);
+        CHECK_EQ(n->value, info.n);
+        CHECK_EQ(k->value, info.k);
+        if (scope == "wmma.matrix_a" || scope == "wmma.matrix_b") {
+          CHECK_EQ(layout->value, info.layout);
+        }
+      } else {
+        // store metadata
+        FragmentInfo info;
+        if (scope == "wmma.matrix_a" || scope == "wmma.matrix_b") {
+          info = FragmentInfo(m->value, n->value, k->value, layout->value);
+        } else if (scope == "wmma.accumulator") {
+          info = FragmentInfo(m->value, n->value, k->value, "");
+        }
+        fragments[buffer_var] = info;
+      }
+    } else if (op->is_intrinsic(intrinsic::tvm_fill_fragment)) {
+      // Get shape information from fill intrinsic
+      CHECK_EQ(op->args.size(), 6U);
+      const Variable* buffer_var = op->args[0].as<Variable>();
+      CHECK(buffer_var);
+      // Get shape
+      const IntImm* m = op->args[1].as<IntImm>();
+      const IntImm* n = op->args[2].as<IntImm>();
+      const IntImm* k = op->args[3].as<IntImm>();
+      CHECK(m);
+      CHECK(n);
+      CHECK(k);
+
+      std::string scope = scopes[buffer_var];
+      // Only wmma.accumulator can use tvm_fill_fragment
+      CHECK_EQ(scope, "wmma.accumulator");
+      if (fragments.count(buffer_var)) {
+        FragmentInfo info = fragments[buffer_var];
+        CHECK_EQ(m->value, info.m);
+        CHECK_EQ(n->value, info.n);
+        CHECK_EQ(k->value, info.k);
+      } else {
+        FragmentInfo info(m->value, n->value, k->value, "");
+        fragments[buffer_var] = info;
+      }
+    }
+  }
+
+  // Get memory scope
+  void Visit_(const AttrStmt* op) final {
+    if (op->attr_key == attr::storage_scope) {
+      const Variable* buffer = op->node.as<Variable>();
+      CHECK(buffer);
+      scopes[buffer] = op->value.as<StringImm>()->value;
+    }
+    IRVisitor::Visit_(op);
+  }
+
+  // Memory scope for allocations
+  std::unordered_map<const Variable*, std::string> scopes;
+  // Fragment metadata for all fragments
+  std::unordered_map<const Variable*, FragmentInfo> fragments;
+};
+
+// Check shape of fragment making sure it is a valid shape for tvm_mma_sync
+class FragmentChecker : public IRVisitor {
+ public:
+  explicit FragmentChecker(const FragmentGetter &getter) : fragment_getter(getter) {}
+
+  void Visit_(const Call* op) final {
+    // Check shape when calling tvm_mma_sync
+    if (op->is_intrinsic(intrinsic::tvm_mma_sync)) {
+      CHECK_EQ(op->args.size(), 8U);
+      const Variable* buffer_var_d = op->args[0].as<Variable>();
+      const Variable* buffer_var_a = op->args[2].as<Variable>();
+      const Variable* buffer_var_b = op->args[4].as<Variable>();
+      const Variable* buffer_var_c = op->args[6].as<Variable>();
+      CHECK(buffer_var_d);
+      CHECK(buffer_var_a);
+      CHECK(buffer_var_b);
+      CHECK(buffer_var_c);
+
+      // Check all fragment A, B, C and D have the same shape
+      CHECK(CheckShape(buffer_var_d, buffer_var_a));
+      CHECK(CheckShape(buffer_var_d, buffer_var_b));
+      CHECK(CheckShape(buffer_var_d, buffer_var_c));
+    }
+  }
+
+ private:
+  // A tool for checking shapes of two fragments
+  bool CheckShape(const Variable* buffer1, const Variable* buffer2) {
+    CHECK(fragment_getter.fragments.count(buffer1));
+    CHECK(fragment_getter.fragments.count(buffer2));
+    FragmentGetter::FragmentInfo info1 = fragment_getter.fragments.at(buffer1);
+    FragmentGetter::FragmentInfo info2 = fragment_getter.fragments.at(buffer2);
+    return info1.m == info2.m && info1.n == info2.n && info1.k == info2.k;
+  }
+  // Fragment infomation
+  const FragmentGetter &fragment_getter;
+};
+
+// Store the metadata into attributes
+class InferFragmenter : public IRMutator {
+ public:
+  explicit InferFragmenter(const FragmentGetter &getter) : fragment_getter(getter) {}
+
+  Stmt Mutate_(const Allocate* op, const Stmt& s) final {
+    Stmt stmt = IRMutator::Mutate_(op, s);
+    const Variable* buffer = op->buffer_var.get();
+    if (fragment_getter.fragments.count(buffer)) {
+      // Add attribute to fragments allocation
+      FragmentGetter::FragmentInfo info = fragment_getter.fragments.at(buffer);
+
+      // Add shape attribute to all fragments
+      std::string shape = std::to_string(info.m) + ", " +
+                          std::to_string(info.n) + ", " +
+                          std::to_string(info.k);
+      Expr shape_expr = StringImm::make(shape);
+      Stmt shape_attr = AttrStmt::make(op->buffer_var, attr::fragment_shape, shape_expr, stmt);
+      if (info.layout != "") {
+        // Add shape attribute to matrix_a and matrix_b
+        Stmt layout_attr = AttrStmt::make(op->buffer_var, attr::fragment_layout,
+                                          StringImm::make(info.layout), shape_attr);
+        return layout_attr;
+      } else {
+        return shape_attr;
+      }
+    }
+    return stmt;
+  }
+
+ private:
+  // Fragment infomation
+  const FragmentGetter &fragment_getter;
+};
+
+Stmt InferFragment(Stmt stmt) {
+  FragmentGetter getter;
+  getter.Visit(stmt);
+  FragmentChecker(getter).Visit(stmt);
+  stmt = InferFragmenter(getter).Mutate(stmt);
+  return stmt;
+}
+
+LoweredFunc InferFragment(LoweredFunc f) {
+  CHECK_NE(f->func_type, kHostFunc);
+  auto n = make_node<LoweredFuncNode>(*f.operator->());
+  n->body = InferFragment(f->body);
+  return LoweredFunc(n);
+}
+
+}  // namespace ir
+}  // namespace tvm
index f7deb25..8cad36d 100644 (file)
@@ -341,5 +341,11 @@ Stmt LowerStorageAccessInfo(Stmt stmt) {
   return StorageAccessInfoLower().Mutate(stmt);
 }
 
+LoweredFunc LowerDeviceStorageAccessInfo(LoweredFunc f) {
+  auto n = make_node<LoweredFuncNode>(*f.operator->());
+  n->body = LowerStorageAccessInfo(f->body);
+  return LoweredFunc(n);
+}
+
 }  // namespace ir
 }  // namespace tvm
index 7c2f321..34dac52 100644 (file)
@@ -263,6 +263,28 @@ class ThreadSyncInserter : public IRMutator {
     }
   }
 
+  Expr Mutate_(const Call* op, const Expr& e) final {
+    if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
+      Expr expr = IRMutator::Mutate_(op, e);
+      op = expr.as<Call>();
+      CHECK_EQ(op->args.size(), 5U);
+      const Variable* buffer_var = op->args[1].as<Variable>();
+      Var var(GetRef<Var>(buffer_var));
+      const IntImm* flag = op->args[4].as<IntImm>();
+      if ((flag->value & 1) && sync_scope_.rank == StorageRank::kGlobal &&
+          GetScope(buffer_var).rank == StorageRank::kGlobal) {
+        ++rw_stats_[var].read_count;
+      }
+      if (flag->value & 2 && sync_scope_.rank == StorageRank::kGlobal &&
+          GetScope(buffer_var).rank == StorageRank::kGlobal) {
+        ++rw_stats_[var].write_count;
+      }
+      return expr;
+    } else {
+      return IRMutator::Mutate_(op, e);
+    }
+  }
+
  private:
   // RW statistics about data
   struct Entry {
index 0934e46..8e75fab 100644 (file)
@@ -50,7 +50,13 @@ enum class StorageRank {
    */
   kWarp = 2,
   /*! \brief thread local memory */
-  kLocal = 3
+  kLocal = 3,
+  /*! \brief wmma scope memory of matrix_a */
+  kWMMAMatrixA = 4,
+  /*! \brief wmma scope memory of matrix_b */
+  kWMMAMatrixB = 5,
+  /*! \brief wmma scope memory of accumulator */
+  kWMMAAccumulator = 6,
 };
 
 /*!
@@ -89,6 +95,9 @@ struct StorageScope {
       case StorageRank::kShared: return "shared" + tag;
       case StorageRank::kWarp: return "warp" + tag;
       case StorageRank::kLocal: return "local" + tag;
+      case StorageRank::kWMMAMatrixA: return "wmma.matrix_a" + tag;
+      case StorageRank::kWMMAMatrixB: return "wmma.matrix_b" + tag;
+      case StorageRank::kWMMAAccumulator: return "wmma.accumulator" + tag;
       default: LOG(FATAL) << "unknown storage scope"; return "";
     }
   }
@@ -111,6 +120,15 @@ struct StorageScope {
     } else if (s.compare(0, 5, "local") == 0) {
       r.rank = StorageRank::kLocal;
       r.tag = s.substr(5, std::string::npos);
+    } else if (s.compare(0, 13, "wmma.matrix_a") == 0) {
+      r.rank = StorageRank::kWMMAMatrixA;
+      r.tag = s.substr(13, std::string::npos);
+    } else if (s.compare(0, 13, "wmma.matrix_b") == 0) {
+      r.rank = StorageRank::kWMMAMatrixB;
+      r.tag = s.substr(13, std::string::npos);
+    } else if (s.compare(0, 16, "wmma.accumulator") == 0) {
+      r.rank = StorageRank::kWMMAAccumulator;
+      r.tag = s.substr(16, std::string::npos);
     } else {
       LOG(FATAL) << "unknown storage scope " << s;
     }
diff --git a/tests/python/unittest/test_schedule_tensor_core.py b/tests/python/unittest/test_schedule_tensor_core.py
new file mode 100644 (file)
index 0000000..9fe72cd
--- /dev/null
@@ -0,0 +1,386 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# 'License'); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm
+import numpy as np
+from topi.testing import conv2d_nhwc_python
+from tvm.contrib import nvcc
+
+VERIFY = True
+
+
+def intrin_wmma_load_matrix(shape, scope):
+    n, m, l = shape
+    if scope == "wmma.matrix_a":
+        row, col = n, l
+    elif scope == "wmma.matrix_b":
+        row, col = l, m
+    A = tvm.placeholder((row, col), name='A', dtype='float16')
+    BA = tvm.decl_buffer(A.shape, A.dtype, scope='shared', data_alignment=32, offset_factor=row * col)
+    C = tvm.compute((row, col), lambda i, j: A[i, j], name='C')
+    BC = tvm.decl_buffer(C.shape, C.dtype, scope=scope, data_alignment=32, offset_factor=row * col)
+
+    def intrin_func(ins, outs):
+        ib = tvm.ir_builder.create()
+
+        BA = ins[0]
+        BC = outs[0]
+        ib.emit(tvm.call_intrin('handle', 'tvm_load_matrix_sync',
+                                BC.data, n, m, l, BC.elem_offset // (row * col),
+                                BA.access_ptr('r'), col, 'row_major'))
+        return ib.get()
+
+    return tvm.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, C: BC})
+
+
+def intrin_wmma_gemm(shape):
+    n, m, l = shape
+    A = tvm.placeholder((n, l), name='A', dtype='float16')
+    B = tvm.placeholder((l, m), name='B', dtype='float16')
+    k = tvm.reduce_axis((0, l), name="k")
+    C = tvm.compute((n, m),
+                    lambda ii, jj:
+                    tvm.sum(A[ii, k].astype('float') * B[k, jj].astype('float'), axis=k),
+                    name='C')
+    BA = tvm.decl_buffer(A.shape, A.dtype, name='BA', scope='wmma.matrix_a', data_alignment=32, offset_factor=n * l)
+    BB = tvm.decl_buffer(B.shape, B.dtype, name='BB', scope='wmma.matrix_b', data_alignment=32, offset_factor=l * m)
+    BC = tvm.decl_buffer(C.shape, C.dtype, name='BC', scope='wmma.accumulator', data_alignment=32, offset_factor=n * m)
+
+    def intrin_func(ins, outs):
+        BA, BB = ins
+        BC, = outs
+
+        def init():
+            ib = tvm.ir_builder.create()
+            ib.emit(tvm.call_intrin('handle', 'tvm_fill_fragment', BC.data, n, m, l, BC.elem_offset // (n * m), 0.0))
+            return ib.get()
+
+        def update():
+            ib = tvm.ir_builder.create()
+            ib.emit(tvm.call_intrin('handle', 'tvm_mma_sync',
+                                    BC.data, BC.elem_offset // (n * m),
+                                    BA.data, BA.elem_offset // (n * l),
+                                    BB.data, BB.elem_offset // (l * m),
+                                    BC.data, BC.elem_offset // (n * m)))
+            return ib.get()
+
+        return update(), init(), update()
+
+    return tvm.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, B: BB, C: BC})
+
+
+def intrin_wmma_store_matrix(shape):
+    n, m, l = shape
+    A = tvm.placeholder((n, m), name='A', dtype='float32')
+    BA = tvm.decl_buffer(A.shape, A.dtype, scope='wmma.accumulator', data_alignment=32, offset_factor=n * m)
+    C = tvm.compute((n, m), lambda i, j: A[i, j], name='C')
+    BC = tvm.decl_buffer(C.shape, C.dtype, scope='global', data_alignment=32, offset_factor=n * m)
+
+    def intrin_func(ins, outs):
+        ib = tvm.ir_builder.create()
+
+        BA = ins[0]
+        BC = outs[0]
+        ib.emit(tvm.call_intrin('handle', 'tvm_store_matrix_sync',
+                                BA.data, n, m, l, BA.elem_offset // (n * m),
+                                BC.access_ptr('w'), m, 'row_major'))
+        return ib.get()
+
+    return tvm.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, C: BC})
+
+
+def test_tensor_core_batch_matmal():
+    if not tvm.gpu(0).exist or not tvm.module.enabled("cuda"):
+        print("skip because cuda is not enabled..")
+        return
+    if not nvcc.have_tensorcore(tvm.gpu(0).compute_version):
+        print("skip because gpu does not support tensor core")
+        return
+
+    batch_size = 4
+    n = 512
+    m, l = n, n
+    assert (n % 32 == 0)
+    assert (m % 8 == 0)
+    assert (l % 16 == 0)
+    nn, mm, ll = n // 32, m // 8, l // 16
+    A = tvm.placeholder((batch_size, nn, ll, 32, 16), name='A', dtype='float16')
+    B = tvm.placeholder((batch_size, ll, mm, 16, 8), name='B', dtype='float16')
+    k1 = tvm.reduce_axis((0, ll), name='k1')
+    k2 = tvm.reduce_axis((0, 16), name='k2')
+    C = tvm.compute((batch_size, nn, mm, 32, 8),
+                    lambda b, i, j, ii, jj:
+                    tvm.sum(A[b, i, k1, ii, k2].astype('float') * B[b, k1, j, k2, jj].astype('float'), axis=[k1, k2]),
+                    name='Fragment_C')
+    s = tvm.create_schedule(C.op)
+
+    warp_size = 32
+    kernel_size = 16
+    block_row_warps = 2
+    block_col_warps = 4
+    warp_row_tiles = 4
+    warp_col_tiles = 2
+    chunk = 4
+
+    block_x = tvm.thread_axis('blockIdx.x')
+    block_y = tvm.thread_axis('blockIdx.y')
+    block_z = tvm.thread_axis('blockIdx.z')
+    thread_x = tvm.thread_axis('threadIdx.x')
+    thread_y = tvm.thread_axis('threadIdx.y')
+    thread_z = tvm.thread_axis('threadIdx.z')
+
+    AS = s.cache_read(A, 'shared', [C])
+    BS = s.cache_read(B, 'shared', [C])
+    AF = s.cache_read(AS, 'wmma.matrix_a', [C])
+    BF = s.cache_read(BS, 'wmma.matrix_b', [C])
+    CF = s.cache_write(C, 'wmma.accumulator')
+
+    b, i, j, kernel_i, kernel_j = s[C].op.axis
+    i, ii = s[C].split(i, factor=warp_row_tiles)
+    block_i, i = s[C].split(i, factor=block_row_warps)
+    j, jj = s[C].split(j, factor=warp_col_tiles)
+    block_j, j = s[C].split(j, factor=block_col_warps)
+    s[C].reorder(block_i, block_j, i, j, ii, jj, kernel_i, kernel_j)
+    s[C].bind(b, block_z)
+    s[C].bind(block_i, block_x)
+    s[C].bind(block_j, block_y)
+    s[C].bind(i, thread_y)
+    s[C].bind(j, thread_z)
+
+    s[CF].compute_at(s[C], j)
+    b, warp_i, warp_j, _i, _j = s[CF].op.axis
+    k, _k = CF.op.reduce_axis
+    ko, ki = s[CF].split(k, factor=chunk)
+    s[CF].reorder(ko, ki, warp_i, warp_j, _i, _j, _k)
+
+    s[AF].compute_at(s[CF], ki)
+    s[BF].compute_at(s[CF], ki)
+
+    s[AS].compute_at(s[CF], ko)
+    b, xo, yo, xi, yi = AS.op.axis
+    tx, xo = s[AS].split(xo, nparts=block_row_warps)
+    ty, yo = s[AS].split(yo, nparts=block_col_warps)
+    t = s[AS].fuse(xi, yi)
+    to, ti = s[AS].split(t, nparts=warp_size)
+    s[AS].bind(tx, thread_y)
+    s[AS].bind(ty, thread_z)
+    s[AS].bind(to, thread_x)
+
+    s[BS].compute_at(s[CF], ko)
+    b, xo, yo, xi, yi = BS.op.axis
+    tx, xo = s[BS].split(xo, nparts=block_row_warps)
+    ty, yo = s[BS].split(yo, nparts=block_col_warps)
+    t = s[BS].fuse(xi, yi)
+    to, ti = s[BS].split(t, nparts=warp_size)
+    s[BS].bind(tx, thread_y)
+    s[BS].bind(ty, thread_z)
+    s[BS].bind(to, thread_x)
+
+    s[AF].tensorize(AF.op.axis[-2], intrin_wmma_load_matrix((32, 8, 16), 'wmma.matrix_a'))
+    s[BF].tensorize(BF.op.axis[-2], intrin_wmma_load_matrix((32, 8, 16), 'wmma.matrix_b'))
+    s[C].tensorize(kernel_i, intrin_wmma_store_matrix((32, 8, 16)))
+    s[CF].tensorize(_i, intrin_wmma_gemm((32, 8, 16)))
+
+    func = tvm.build(s, [A, B, C], 'cuda')
+
+    ctx = tvm.gpu(0)
+    a_np = np.random.uniform(size=(batch_size, nn, ll, 32, 16)).astype(A.dtype)
+    b_np = np.random.uniform(size=(batch_size, ll, mm, 16, 8)).astype(B.dtype)
+    a = tvm.nd.array(a_np, ctx)
+    b = tvm.nd.array(b_np, ctx)
+    c = tvm.nd.array(np.zeros((batch_size, nn, mm, 32, 8), dtype=C.dtype), ctx)
+    func(a, b, c)
+    evaluator = func.time_evaluator(func.entry_name, ctx, number=3)
+    print('gemm with tensor core: %f ms' % (evaluator(a, b, c).mean * 1e3))
+
+    if VERIFY:
+        func(a, b, c)
+        a_np = a_np.transpose((0, 1, 3, 2, 4)).reshape(batch_size, n, n)
+        b_np = b_np.transpose((0, 1, 3, 2, 4)).reshape(batch_size, n, n)
+        c_np = c.asnumpy().transpose((0, 1, 3, 2, 4)).reshape(batch_size, n, n)
+        np.testing.assert_allclose(c_np, np.matmul(a_np.astype(C.dtype), b_np.astype(C.dtype)), rtol=1e-4, atol=1e-4)
+
+
+
+def test_tensor_core_batch_conv():
+    if not tvm.gpu(0).exist or not tvm.module.enabled("cuda"):
+        print("skip because cuda is not enabled..")
+        return
+    if not nvcc.have_tensorcore(tvm.gpu(0).compute_version):
+        print("skip because gpu does not support tensor core")
+        return
+
+    # The sizes of inputs and filters
+    batch_size = 32
+    height = 14
+    width = 14
+    in_channels = 32
+    out_channels = 64
+    kernel_h = 3
+    kernel_w = 3
+    pad_h = 1
+    pad_w = 1
+    stride_h = 1
+    stride_w = 1
+    block_size = 16
+
+    block_row_warps = 2
+    block_col_warps = 4
+    warp_row_tiles = 4
+    warp_col_tiles = 2
+    warp_size = 32
+    chunk = 2
+
+    # Input feature map: (N, H, W, IC, n, ic)
+    data_shape = (batch_size // block_size,
+                  height,
+                  width,
+                  in_channels // block_size,
+                  block_size,
+                  block_size)
+    # Kernel: (H, W, IC, OC, ic, oc)
+    kernel_shape = (kernel_h,
+                    kernel_w,
+                    in_channels // block_size,
+                    out_channels // block_size,
+                    block_size,
+                    block_size)
+
+    # Output feature map: (N, H, W, OC, n, oc)
+    output_shape = (batch_size // block_size,
+                    height,
+                    width,
+                    out_channels // block_size,
+                    block_size,
+                    block_size)
+
+    assert (batch_size % block_size == 0)
+    assert (in_channels % block_size == 0)
+    assert (out_channels % block_size == 0)
+
+    kh = tvm.reduce_axis((0, kernel_h), name='kh')
+    kw = tvm.reduce_axis((0, kernel_w), name='kw')
+    ic = tvm.reduce_axis((0, in_channels // block_size), name='ic')
+    ii = tvm.reduce_axis((0, block_size), name='ii')
+
+    # Algorithm
+    A = tvm.placeholder(data_shape, name='A', dtype="float16")
+    W = tvm.placeholder(kernel_shape, name='W', dtype="float16")
+    Apad = tvm.compute(
+        (batch_size // block_size, height + 2 * pad_h, width + 2 * pad_w, in_channels // block_size, block_size,
+         block_size),
+        lambda n, h, w, i, nn, ii: tvm.if_then_else(
+            tvm.all(h >= pad_h, h - pad_h < height,
+                    w >= pad_w, w - pad_w < width),
+            A[n, h - pad_h, w - pad_w, i, nn, ii], tvm.const(0., "float16")),
+        name='Apad')
+    Conv = tvm.compute(output_shape,
+                       lambda n, h, w, o, nn, oo: tvm.sum(
+                           Apad[n, h * stride_h + kh, w * stride_w + kw, ic, nn, ii].astype("float32") *
+                           W[kh, kw, ic, o, ii, oo].astype("float32"),
+                           axis=[ic, kh, kw, ii]),
+                       name="Conv")
+
+    s = tvm.create_schedule(Conv.op)
+    s[Apad].compute_inline()
+
+    AS = s.cache_read(Apad, 'shared', [Conv])
+    WS = s.cache_read(W, 'shared', [Conv])
+    AF = s.cache_read(AS, 'wmma.matrix_a', [Conv])
+    WF = s.cache_read(WS, 'wmma.matrix_b', [Conv])
+    ConvF = s.cache_write(Conv, 'wmma.accumulator')
+
+    block_x = tvm.thread_axis('blockIdx.x')
+    block_y = tvm.thread_axis('blockIdx.y')
+    block_z = tvm.thread_axis('blockIdx.z')
+    thread_x = tvm.thread_axis('threadIdx.x')
+    thread_y = tvm.thread_axis('threadIdx.y')
+    thread_z = tvm.thread_axis('threadIdx.z')
+
+    nc, hc, wc, oc, nnc, ooc = Conv.op.axis
+    block_k = s[Conv].fuse(hc, wc)
+    s[Conv].bind(block_k, block_z)
+    nc, nci = s[Conv].split(nc, factor=warp_row_tiles)
+    block_i, nc = s[Conv].split(nc, factor=block_row_warps)
+    oc, oci = s[Conv].split(oc, factor=warp_col_tiles)
+    block_j, oc = s[Conv].split(oc, factor=block_col_warps)
+    s[Conv].reorder(block_k, block_i, block_j, nc, oc, nci, oci, nnc, ooc)
+    s[Conv].bind(block_i, block_x)
+    s[Conv].bind(block_j, block_y)
+    s[Conv].bind(nc, thread_y)
+    s[Conv].bind(oc, thread_z)
+
+    s[ConvF].compute_at(s[Conv], oc)
+    n, h, w, o, nnf, oof = ConvF.op.axis
+    ko, ki = s[ConvF].split(ic, factor=chunk)
+    s[ConvF].reorder(ko, kh, ki, kw, n, o, nnf, oof, ii)
+
+    s[AF].compute_at(s[ConvF], kw)
+    s[WF].compute_at(s[ConvF], kw)
+
+    s[WS].compute_at(s[ConvF], kh)
+    s[AS].compute_at(s[ConvF], kh)
+
+    n, h, w, i, nn, ii = AS.op.axis
+    tx, xo = s[AS].split(n, nparts=block_row_warps)
+    ty, yo = s[AS].split(xo, nparts=block_col_warps)
+    t = s[AS].fuse(nn, ii)
+    to, ti = s[AS].split(t, factor=warp_size)
+    s[AS].bind(tx, thread_y)
+    s[AS].bind(ty, thread_z)
+    s[AS].bind(ti, thread_x)
+
+    kh, kw, ic, o, ii, oo = WS.op.axis
+    tx, xo = s[WS].split(o, nparts=block_row_warps)
+    ty, yo = s[WS].split(xo, nparts=block_col_warps)
+    t = s[WS].fuse(ii, oo)
+    to, ti = s[WS].split(t, nparts=warp_size)
+    s[WS].bind(tx, thread_y)
+    s[WS].bind(ty, thread_z)
+    s[WS].bind(to, thread_x)
+    s[WS].vectorize(ti)
+
+    s[AF].tensorize(AF.op.axis[-2], intrin_wmma_load_matrix((16, 16, 16), 'wmma.matrix_a'))
+    s[WF].tensorize(WF.op.axis[-2], intrin_wmma_load_matrix((16, 16, 16), 'wmma.matrix_b'))
+    s[Conv].tensorize(nnc, intrin_wmma_store_matrix((16, 16, 16)))
+    s[ConvF].tensorize(nnf, intrin_wmma_gemm((16, 16, 16)))
+
+    func = tvm.build(s, [A, W, Conv], 'cuda')
+
+    ctx = tvm.gpu(0)
+    a_np = np.random.uniform(size=data_shape).astype(A.dtype)
+    w_np = np.random.uniform(size=kernel_shape).astype(W.dtype)
+    a = tvm.nd.array(a_np, ctx)
+    w = tvm.nd.array(w_np, ctx)
+    c = tvm.nd.array(np.zeros(output_shape, dtype=Conv.dtype), ctx)
+    evaluator = func.time_evaluator(func.entry_name, ctx, number=3)
+    print('conv2d with tensor core: %f ms' % (evaluator(a, w, c).mean * 1e3))
+
+    if VERIFY:
+        func(a, w, c)
+        a_np = a_np.transpose(0, 4, 1, 2, 3, 5).reshape(batch_size, height, width, in_channels)
+        w_np = w_np.transpose(0, 1, 2, 4, 3, 5).reshape(kernel_h, kernel_w, in_channels, out_channels)
+        c_np = c.asnumpy().transpose((0, 4, 1, 2, 3, 5)).reshape(batch_size, height, width, out_channels)
+        c_std = conv2d_nhwc_python(a_np.astype(Conv.dtype),
+                                   w_np.astype(Conv.dtype),
+                                   (stride_h, stride_w),
+                                   (pad_h, pad_w)).astype(Conv.dtype)
+        np.testing.assert_allclose(c_np, c_std, rtol=1e-4, atol=1e-4)
+
+
+if __name__ == '__main__':
+    test_tensor_core_batch_matmal()
+    test_tensor_core_batch_conv()
index d2ef40c..8a6a467 100644 (file)
@@ -40,7 +40,7 @@ def conv2d_nhwc_python(a_np, w_np, stride, padding):
     Returns
     -------
     b_np : np.ndarray
-        4-D with shape [out_height, out_width, out_channel, batch]
+        4-D with shape [batch, out_height, out_width, out_channel]
     """
     batch, in_height, in_width, in_channel = a_np.shape
     kernel_h, kernel_w, _, num_filter = w_np.shape
diff --git a/tutorials/optimize/opt_conv_tensorcore.py b/tutorials/optimize/opt_conv_tensorcore.py
new file mode 100644 (file)
index 0000000..774b4c7
--- /dev/null
@@ -0,0 +1,348 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+.. _opt-conv-tensorcore:
+
+How to optimize convolution using TensorCores
+==================================
+**Author**: `Siyuan Feng <https://github.com/Hzfengsy>`_
+
+In this tutorial, we will demonstrate how to write a high performance convolution
+schedule using TensorCores in TVM. In this example, we assume the input to
+convolution has a large batch. We strongly recommend covering the :ref:`opt-conv-gpu` tutorial first.
+
+"""
+
+################################################################
+# TensorCore Introduction
+# -------------------------
+# Each Tensor Core provides a 4x4x4 matrix processing array that operates
+# :code:`D = A * B + C`, where A, B, C and D are 4x4 matrices as Figure shows.
+# The matrix multiplication inputs A and B are FP16 matrices, while the accumulation
+# matrices C and D may be FP16 or FP32 matrices.
+#
+# However, CUDA programmers can only use warp-level primitive
+# :code:`wmma::mma_sync(acc_frag, a_frag, b_frag, acc_frag)` to perform
+# 16x16x16 half-precision matrix multiplication on tensor cores. Before invoking
+# the matrix multiplication, programmers must load data from memory into registers
+# with primitive :code:`wmma::load_matrix_sync`, explicitly. The NVCC compiler translates
+# that primitive into multiple memory load instructions. At run time, every thread loads
+# 16 elements from matrix A and 16 elements from B.
+
+################################################################
+# Preparation and Algorithm
+# --------------------------
+# We use the fixed size for input tensors with 256 channels and 14 x 14 dimensions.
+# The batch size is 256. Convolution filters contain 512 filters of size 3 x 3.
+# We use stride size 1 and padding size 1 for the convolution. In the example, we use
+# NHWCnc memory layout.The following code defines the convolution algorithm in TVM.
+
+import tvm
+import numpy as np
+from tvm.contrib import nvcc
+
+# The sizes of inputs and filters
+batch_size = 256
+height = 14
+width = 14
+in_channels = 256
+out_channels = 512
+kernel_h = 3
+kernel_w = 3
+pad_h = 1
+pad_w = 1
+stride_h = 1
+stride_w = 1
+
+# TensorCore shape
+block_size = 16
+
+assert (batch_size % block_size == 0)
+assert (in_channels % block_size == 0)
+assert (out_channels % block_size == 0)
+
+# Input feature map: (N, H, W, IC, n, ic)
+data_shape = (batch_size // block_size,
+              height,
+              width,
+              in_channels // block_size,
+              block_size,
+              block_size)
+# Kernel: (H, W, IC, OC, ic, oc)
+kernel_shape = (kernel_h,
+                kernel_w,
+                in_channels // block_size,
+                out_channels // block_size,
+                block_size,
+                block_size)
+# Output feature map: (N, H, W, OC, n, oc)
+output_shape = (batch_size // block_size,
+                height,
+                width,
+                out_channels // block_size,
+                block_size,
+                block_size)
+
+# Reduction axes
+kh = tvm.reduce_axis((0, kernel_h), name='kh')
+kw = tvm.reduce_axis((0, kernel_w), name='kw')
+ic = tvm.reduce_axis((0, in_channels // block_size), name='ic')
+ii = tvm.reduce_axis((0, block_size), name='ii')
+
+# Algorithm
+A = tvm.placeholder(data_shape, name='A', dtype="float16")
+W = tvm.placeholder(kernel_shape, name='W', dtype="float16")
+Apad = tvm.compute(
+    (batch_size // block_size, height + 2 * pad_h, width + 2 * pad_w, in_channels // block_size, block_size,
+     block_size),
+    lambda n, h, w, i, nn, ii: tvm.if_then_else(
+        tvm.all(h >= pad_h, h - pad_h < height,
+                w >= pad_w, w - pad_w < width),
+        A[n, h - pad_h, w - pad_w, i, nn, ii], tvm.const(0., "float16")),
+    name='Apad')
+Conv = tvm.compute(output_shape,
+                   lambda n, h, w, o, nn, oo: tvm.sum(
+                       Apad[n, h * stride_h + kh, w * stride_w + kw, ic, nn, ii].astype("float32") *
+                       W[kh, kw, ic, o, ii, oo].astype("float32"),
+                       axis=[ic, kh, kw, ii]),
+                   name="Conv")
+
+s = tvm.create_schedule(Conv.op)
+s[Apad].compute_inline()
+
+###############################################################################
+# Memory Scope
+# ----------------
+#
+# In traditional GPU schedule, we have global, shared and local memory scope.
+# To support TensorCores, we add another three special memory scope: :code:`wmma.matrix_a`,
+# :code:`wmma.matrix_b` and :code:`wmma.accumulator`. On hardware, all fragments scope
+# stores at the on-chip registers level, the same place with local memory.
+
+# Designate the memory hierarchy
+AS = s.cache_read(Apad, 'shared', [Conv])
+WS = s.cache_read(W, 'shared', [Conv])
+AF = s.cache_read(AS, 'wmma.matrix_a', [Conv])
+WF = s.cache_read(WS, 'wmma.matrix_b', [Conv])
+ConvF = s.cache_write(Conv, 'wmma.accumulator')
+
+###############################################################################
+# Define Tensor Intrinsic
+# In fact, TensorCore is a special hardware operation. So, we can just use tensorize
+# to replace a unit of computation with the TensorCore instruction. The first thing is
+# that we need to define tensor intrinsic.
+#
+# There are four basic operation in TensorCore: :code:`fill_fragment`, :code:`load_matrix`,
+# :code:`mma_sync` and :code:`store_matrix`. Since :code:`fill_fragment` and :code:`mma_sync`
+# are both used in matrix multiplication, so we can just write following three intrinsics.
+
+def intrin_wmma_load_matrix(scope):
+    n = 16
+    A = tvm.placeholder((n, n), name='A', dtype='float16')
+    BA = tvm.decl_buffer(A.shape, A.dtype, scope='shared', data_alignment=32, offset_factor=256)
+    C = tvm.compute((n, n), lambda i, j: A[i, j], name='C')
+    BC = tvm.decl_buffer(C.shape, C.dtype, scope=scope, data_alignment=32, offset_factor=256)
+
+    def intrin_func(ins, outs):
+        ib = tvm.ir_builder.create()
+
+        BA = ins[0]
+        BC = outs[0]
+        ib.emit(tvm.call_intrin('handle', 'tvm_load_matrix_sync',
+                                BC.data, n, n, n, BC.elem_offset // 256,
+                                BA.access_ptr('r'), n, 'row_major'))
+        return ib.get()
+
+    return tvm.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, C: BC})
+
+
+def intrin_wmma_gemm():
+    n = 16
+    A = tvm.placeholder((n, n), name='A', dtype='float16')
+    B = tvm.placeholder((n, n), name='B', dtype='float16')
+    k = tvm.reduce_axis((0, n), name="k")
+    C = tvm.compute((n, n),
+                    lambda ii, jj:
+                    tvm.sum(A[ii, k].astype('float') * B[k, jj].astype('float'), axis=k),
+                    name='C')
+    BA = tvm.decl_buffer(A.shape, A.dtype, name='BA', scope='wmma.matrix_a', data_alignment=32, offset_factor=256)
+    BB = tvm.decl_buffer(B.shape, B.dtype, name='BB', scope='wmma.matrix_b', data_alignment=32, offset_factor=256)
+    BC = tvm.decl_buffer(C.shape, C.dtype, name='BC', scope='wmma.accumulator', data_alignment=32, offset_factor=256)
+
+    def intrin_func(ins, outs):
+        BA, BB = ins
+        BC, = outs
+
+        def init():
+            ib = tvm.ir_builder.create()
+            ib.emit(tvm.call_intrin('handle', 'tvm_fill_fragment', BC.data, n, n, n, BC.elem_offset // 256, 0.0))
+            return ib.get()
+
+        def update():
+            ib = tvm.ir_builder.create()
+            ib.emit(tvm.call_intrin('handle', 'tvm_mma_sync',
+                                    BC.data, BC.elem_offset // 256,
+                                    BA.data, BA.elem_offset // 256,
+                                    BB.data, BB.elem_offset // 256,
+                                    BC.data, BC.elem_offset // 256))
+            return ib.get()
+
+        return update(), init(), update()
+
+    return tvm.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, B: BB, C: BC})
+
+
+def intrin_wmma_store_matrix():
+    n = 16
+    A = tvm.placeholder((n, n), name='A', dtype='float32')
+    BA = tvm.decl_buffer(A.shape, A.dtype, scope='wmma.accumulator', data_alignment=32, offset_factor=256)
+    C = tvm.compute((n, n), lambda i, j: A[i, j], name='C')
+    BC = tvm.decl_buffer(C.shape, C.dtype, scope='global', data_alignment=32, offset_factor=256)
+
+    def intrin_func(ins, outs):
+        ib = tvm.ir_builder.create()
+        BA = ins[0]
+        BC = outs[0]
+        ib.emit(tvm.call_intrin('handle', 'tvm_store_matrix_sync',
+                                BA.data, n, n, n, BA.elem_offset // 256,
+                                BC.access_ptr('w'), n, 'row_major'))
+        return ib.get()
+
+    return tvm.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, C: BC})
+
+###############################################################################
+# Scheduling the Computation
+# --------------------------
+# To use TensorCores in TVM, we must schedule the computation into specific structure
+# to match the tensor intrinsic. The same as traditional GPU programs, we can also use
+# shared memory to boost the speed. If you have any questions about blocking and shared
+# memory, please refer :ref:`opt-conv-gpu`.
+#
+# In this example, each block contains 2x4 warps, and each warp calls 4x2 TensorCore
+# instructions. Thus, the output shape of each warp is 64x32 and each block outputs
+# 128x128 titles. Due to the limit of shared memory space, we only load 2 blocks (2x128x128 tiles)
+# one time.
+#
+# .. note::
+#
+#   *Warp-level Operation*
+#
+#   Note that all TensorCore instructions are warp-level instructions, which means all 32 threads
+#   in a warp should do this instruction simultaneously. Making theadIdx.x extent=32 is one of the
+#   easiest way to solve this. Then We can bind threadIdx.x to any loops except those contain
+#   TensorCore intrinsics directly or indirectly. Also note that it is not the unique solution.
+#   The only thing we should do is to make sure all threads in a warp can call TensorCore at the same time.
+#
+
+# Define tiling sizes
+block_row_warps = 4
+block_col_warps = 2
+warp_row_tiles = 2
+warp_col_tiles = 4
+warp_size = 32
+chunk = 2
+
+block_x = tvm.thread_axis('blockIdx.x')
+block_y = tvm.thread_axis('blockIdx.y')
+block_z = tvm.thread_axis('blockIdx.z')
+thread_x = tvm.thread_axis('threadIdx.x')
+thread_y = tvm.thread_axis('threadIdx.y')
+thread_z = tvm.thread_axis('threadIdx.z')
+
+nc, hc, wc, oc, nnc, ooc = Conv.op.axis
+block_k = s[Conv].fuse(hc, wc)
+s[Conv].bind(block_k, block_z)
+nc, nci = s[Conv].split(nc, factor=warp_row_tiles)
+block_i, nc = s[Conv].split(nc, factor=block_row_warps)
+oc, oci = s[Conv].split(oc, factor=warp_col_tiles)
+block_j, oc = s[Conv].split(oc, factor=block_col_warps)
+s[Conv].reorder(block_k, block_i, block_j, nc, oc, nci, oci, nnc, ooc)
+s[Conv].bind(block_i, block_x)
+s[Conv].bind(block_j, block_y)
+s[Conv].bind(nc, thread_y)
+s[Conv].bind(oc, thread_z)
+
+# Schedule local computation
+s[ConvF].compute_at(s[Conv], oc)
+n, h, w, o, nnf, oof = ConvF.op.axis
+ko, ki = s[ConvF].split(ic, factor=chunk)
+s[ConvF].reorder(ko, kh, ki, kw, n, o, nnf, oof, ii)
+
+# Move intermediate computation into each output compute tile
+s[AF].compute_at(s[ConvF], kw)
+s[WF].compute_at(s[ConvF], kw)
+
+# Schedule for A's share memory
+s[AS].compute_at(s[ConvF], kh)
+n, h, w, i, nn, ii = AS.op.axis
+tx, xo = s[AS].split(n, nparts=block_row_warps)
+ty, yo = s[AS].split(xo, nparts=block_col_warps)
+t = s[AS].fuse(nn, ii)
+to, ti = s[AS].split(t, factor=warp_size)
+s[AS].bind(tx, thread_y)
+s[AS].bind(ty, thread_z)
+s[AS].bind(ti, thread_x)
+
+# Schedule for W's share memory
+s[WS].compute_at(s[ConvF], kh)
+kh, kw, ic, o, ii, oo = WS.op.axis
+tx, xo = s[WS].split(o, nparts=block_row_warps)
+ty, yo = s[WS].split(xo, nparts=block_col_warps)
+t = s[WS].fuse(ii, oo)
+to, ti = s[WS].split(t, nparts=warp_size)
+s[WS].bind(tx, thread_y)
+s[WS].bind(ty, thread_z)
+s[WS].bind(to, thread_x)
+s[WS].vectorize(ti)
+print(tvm.lower(s, [A, W, Conv], simple_mode=True))
+
+###############################################################################
+# Lowering Computation to Intrinsics
+# --------------------------
+# The last phase is to lower the computation loops down to TensorCore hardware intrinsics
+# by mapping the 2D convolution to tensor intrinsics
+#
+
+s[AF].tensorize(AF.op.axis[-2], intrin_wmma_load_matrix('wmma.matrix_a'))
+s[WF].tensorize(WF.op.axis[-2], intrin_wmma_load_matrix('wmma.matrix_b'))
+s[Conv].tensorize(nnc, intrin_wmma_store_matrix())
+s[ConvF].tensorize(nnf, intrin_wmma_gemm())
+print(tvm.lower(s, [A, W, Conv], simple_mode=True))
+
+###############################################################################
+# Generate CUDA Kernel
+# --------------------
+# Finally we use TVM to generate and compile the CUDA kernel, and evaluate the latency of convolution.
+# Since TensorCores are only supported in NVIDIA GPU with Compute Capability 7.0 or higher, it may not
+# be able to run on our build server
+
+ctx = tvm.gpu(0)
+if nvcc.have_tensorcore(ctx.compute_version):
+    with tvm.build_config(auto_unroll_max_step=16):
+        func = tvm.build(s, [A, W, Conv], 'cuda')
+    a_np = np.random.uniform(size=data_shape).astype(A.dtype)
+    w_np = np.random.uniform(size=kernel_shape).astype(W.dtype)
+    a = tvm.nd.array(a_np, ctx)
+    w = tvm.nd.array(w_np, ctx)
+    c = tvm.nd.array(np.zeros(output_shape, dtype=Conv.dtype), ctx)
+    evaluator = func.time_evaluator(func.entry_name, ctx, number=10)
+    print('conv2d with tensor core: %f ms' % (evaluator(a, w, c).mean * 1e3))
+
+###############################################################################
+# Summary
+# This tutorial demonstrates how TVM scheduling primitives can be used to
+# call TensorCores on specific GPUs.
index 5c24375..cec217c 100644 (file)
@@ -80,6 +80,7 @@ def build_config(debug_flag=0, **kwargs):
     if debug_flag:
         pass_list.append((1, add_debug))
     pass_list.append((2, ir_pass.inject_alu_intrin))
+    pass_list.append((3, tvm.ir_pass.LowerStorageAccessInfo))
     pass_list.append((3, ir_pass.fold_uop_loop))
     pass_list.append((3, ir_pass.cpu_access_rewrite))
     return tvm.build_config(add_lower_pass=pass_list, **kwargs)