Auto TensorCore CodeGen (#4234)
authorMinmin Sun (孙敏敏) <mmsun@foxmail.com>
Sat, 9 Nov 2019 21:01:36 +0000 (05:01 +0800)
committerTianqi Chen <tqchen@users.noreply.github.com>
Sat, 9 Nov 2019 21:01:36 +0000 (13:01 -0800)
* Add Auto TensorCore TensorCore Unit Test

* Rebase to tvm master branch & Add auto tensor core

* Code Refine

* Add tensor core switch by pragma

* Add pragma in tensor core example code

* Get real tile size to replace hard coded 16

* support more than 2 dimensions (e.g. batchmatmul) for buffer bind scope

* support batch matmul

* Move cuda env check to tensor_core.cc

* Coderefine for tensor_core.cc

* Refine comments

* Some refinements of code and comment

* Update TensorCore UT to pass the CPU test

* remove redundant code

* matmul's storage align for different layout

* Add support for differenct position of type cast

* Add formal tutorial for auto tensorcore codegen

* move tensorcore check up to tutorial code

* code and doc refine

* comment out tune_and_evaluate in tutorial

* fix cpplint error

include/tvm/ir.h
include/tvm/ir_pass.h
python/tvm/build_module.py
src/api/api_pass.cc
src/pass/tensor_core.cc [new file with mode: 0644]
tests/python/unittest/test_pass_rewrite_for_tensor_core.py [new file with mode: 0644]
tutorials/optimize/opt_matmul_auto_tensorcore.py [new file with mode: 0644]

index b6c3028..aca45f4 100644 (file)
@@ -1248,6 +1248,8 @@ constexpr const char* reduce_scope = "reduce_scope";
 constexpr const char* pragma_scope_prefix = "pragma_";
 /*! \brief Import llvm source or file into the final code gen module */
 constexpr const char* pragma_import_llvm = "pragma_import_llvm";
+/*! \brief Try to modify the AST to support Tensor Core */
+constexpr const char* pragma_tensor_core = "pragma_tensor_core";
 /*!
  * \brief Mark of prefetch scope, value=offset,
  *  run prefetch of Tensor on the current loop scope
index 842c6af..76d7d61 100644 (file)
@@ -206,6 +206,20 @@ Stmt StorageFlatten(Stmt stmt,
                     Map<Tensor, Buffer> extern_buffer,
                     int cache_line_size,
                     bool create_bound_attribute = false);
+
+/*!
+ * \brief Try to modify the AST to support TensorCore
+ *
+ * \param stmt The stmt to be trasnformed.
+ * \param schedule The original schedule.
+ * \param extern_buffer Map specifies external
+ *    buffer assignment of input and outputs.
+ * \return Transformed stmt.
+ */
+Stmt RewriteForTensorCore(Stmt stmt,
+                          Schedule schedule,
+                          Map<Tensor, Buffer> extern_buffer);
+
 /*!
  * \brief Verify if there is any argument bound to compact buffer.
  *
index fe2f641..217318e 100644 (file)
@@ -387,6 +387,7 @@ def lower(sch,
     binds, arg_list = get_binds(args, compact, binds)
 
     # Phase 1
+    stmt = ir_pass.RewriteForTensorCore(stmt, sch, binds)
     stmt = ir_pass.StorageFlatten(stmt, binds, 64, cfg.instrument_bound_checkers)
     stmt = ir_pass.CanonicalSimplify(stmt)
     for f in lower_phase1:
index d7f621f..6994e2d 100644 (file)
@@ -94,6 +94,12 @@ TVM_REGISTER_API("ir_pass.StorageFlatten")
     }
   });
 
+TVM_REGISTER_API("ir_pass.RewriteForTensorCore")
+.set_body_typed<Stmt(const Stmt&, const Schedule&, const Map<Tensor, Buffer>&)>
+  ([](const Stmt& stmt, const Schedule& schedule, const Map<Tensor, Buffer>& extern_buffer) {
+      return RewriteForTensorCore(stmt, schedule, extern_buffer);
+  });
+
 TVM_REGISTER_API("ir_pass.AttrsEqual")
 .set_body_typed<bool(const NodeRef&, const NodeRef&)>([](const NodeRef& lhs, const NodeRef& rhs) {
     return AttrsEqual()(lhs, rhs);
diff --git a/src/pass/tensor_core.cc b/src/pass/tensor_core.cc
new file mode 100644 (file)
index 0000000..183ed3b
--- /dev/null
@@ -0,0 +1,1202 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *  Copyright (c) 2019 by Contributors
+ * \file tensor_core.cc
+ */
+// IR Passes for TensorCore CodeGen
+#include <tvm/ir.h>
+#include <tvm/expr.h>
+#include <tvm/operation.h>
+#include <tvm/ir_mutator.h>
+#include <tvm/ir_visitor.h>
+#include <tvm/expr_operator.h>
+#include <tvm/ir_pass.h>
+#include <tvm/buffer.h>
+#include <tvm/target_info.h>
+#include <tvm/build_module.h>
+#include <tvm/runtime/device_api.h>
+#include <unordered_map>
+#include "ir_util.h"
+#include "../arithmetic/compute_expr.h"
+#include "../runtime/thread_storage_scope.h"
+
+namespace tvm {
+namespace ir {
+
+using runtime::StorageRank;
+using runtime::StorageScope;
+using runtime::ThreadScope;
+using intrinsic::tvm_address_of;
+
+struct Tile {
+  int m{-1};
+  int n{-1};
+  int k{-1};
+};
+
+std::string simplify_name(std::string input) {
+  auto pos = input.find(".");
+  if (pos != std::string::npos) {
+    return input.substr(0, pos);
+  } else {
+    return input;
+  }
+}
+
+Expr unpack_type_cast(const Expr &input, const Type &target_type) {
+  auto cast = input.as<Cast>();
+  if (cast == nullptr) {
+    return input;
+  } else if (cast->type == target_type) {
+    return cast->value;
+  }
+  return Expr();
+}
+
+// MMAMatcher matches C = Cast(A)*Cast(B)+C,
+// where A & B are fp16/int8 local buffers,
+// and C is fp32/int32 local buffer.
+class MMAMatcher: public IRVisitor {
+ public:
+  explicit MMAMatcher(Map<Tensor, Buffer> extern_buffer) {
+    for (auto kv : extern_buffer) {
+      BufferInfo bi;
+      bi.name = kv.second->name;
+      bi.dtype = kv.second->dtype;
+      bi.external = true;
+      buf_map_[TensorKey{kv.first->op, kv.first->value_index}] = bi;
+    }
+  }
+  using IRVisitor::Visit_;
+
+  void Visit_(const AttrStmt* op) final {
+    if (op->attr_key == attr::pragma_tensor_core) {
+      tensor_core_on_ = true;
+      IRVisitor::Visit_(op);
+    } else if (op->attr_key == attr::realize_scope) {
+      storage_scope_[op->node.get()] = op->value.as<StringImm>()->value;
+      Visit(op->body);
+    } else {
+      IRVisitor::Visit_(op);
+    }
+  }
+
+  void Visit_(const Provide* op) final {
+    IRVisitor::Visit_(op);
+    auto it = buf_map_.find(TensorKey{op->func, op->value_index});
+    if (it == buf_map_.end()) {
+      return;
+    }
+    const BufferInfo& bi = it->second;
+    if (bi.released) {
+      return;
+    }
+    if (tensor_core_on_ && mma_sync_match_(op, bi)) {
+      matched_ = true;
+    }
+  }
+
+  void Visit_(const Realize* op) final {
+    TensorKey key{op->func, op->value_index};
+    if (buf_map_.count(key)) {
+      if (!buf_map_.at(key).external) {
+        return;
+      }
+      Visit(op->body);
+    } else {
+      BufferInfo bi;
+      bi.name = key.GetName();
+      bi.dtype = op->type;
+      buf_map_[key] = bi;
+      Visit(op->body);
+      buf_map_[key].released = true;
+    }
+  }
+
+  inline bool Matched() const {return matched_;}
+
+  friend class ScheduleAnalyser;
+  friend class BufferAnalyser;
+
+ private:
+  struct BufferInfo {
+    std::string name;
+    Type dtype;
+    bool external{false};
+    bool released{false};
+    bool same_as(const BufferInfo &bi) {
+      if (this->dtype != bi.dtype) return false;
+      if (this->name != bi.name) return false;
+      if (this->external != bi.external) return false;
+      if (this->released != bi.released) return false;
+      return true;
+    }
+  };
+
+  // Check whether the storage scope is local
+  bool check_local_buffer_(const Call* op, BufferInfo* bi) {
+    if (op->call_type == Call::Halide) {
+      auto it = storage_scope_.find(op->func.get());
+      if (it == storage_scope_.end()) {
+        return false;
+      }
+      const std::string& strkey = it->second;
+      if (strkey != "local") {
+        return false;
+      }
+      auto it1 = buf_map_.find(TensorKey{op->func, op->value_index});
+      if (it1 == buf_map_.end()) {
+        return false;
+      }
+      *bi = it1->second;
+      if (bi->released) {
+        return false;
+      }
+      return true;
+    }
+    return false;
+  }
+
+  // Do the pattern matching
+  bool mma_sync_match_(const Provide* op, BufferInfo store_buffer) {
+    auto* add = op->value.as<Add>();
+    if (add == nullptr) {
+      return false;
+    }
+
+    auto* load_c = add->a.as<Call>();
+    BufferInfo buffer_c;
+    if (!check_local_buffer_(load_c, &buffer_c)
+        || !buffer_c.same_as(store_buffer)
+        || !(buffer_c.dtype == Float(32) ||
+             buffer_c.dtype == Int(32))) {
+      return false;
+    }
+
+    auto mul = unpack_type_cast(add->b, buffer_c.dtype).as<Mul>();
+    if (mul == nullptr) {
+      return false;
+    }
+
+    auto load_a_expr = unpack_type_cast(mul->a, buffer_c.dtype);
+    auto load_a = load_a_expr.as<Call>();
+    BufferInfo buffer_a;
+    if (!check_local_buffer_(load_a, &buffer_a)
+        || !(buffer_a.dtype == Float(16) ||
+             buffer_a.dtype == Int(8))) {
+      return false;
+    }
+
+    auto load_b_expr = unpack_type_cast(mul->b, buffer_c.dtype);
+    auto load_b = load_b_expr.as<Call>();
+    BufferInfo buffer_b;
+    if (!check_local_buffer_(load_b, &buffer_b)
+        || !(buffer_b.dtype == Float(16) ||
+             buffer_b.dtype == Int(8))) {
+      return false;
+    }
+
+    frag_reg_.insert(buffer_c.name);
+    frag_reg_.insert(buffer_a.name);
+    frag_reg_.insert(buffer_b.name);
+    buf_name_.insert(std::make_pair(load_a, buffer_a.name));
+    buf_name_.insert(std::make_pair(load_b, buffer_b.name));
+    mma_sync_.insert(std::make_pair(op,
+      Array<Expr>{load_a_expr, load_b_expr, add->a}));
+
+    return true;
+  }
+
+  std::unordered_map<TensorKey, BufferInfo> buf_map_;
+  std::unordered_map<const Node*, std::string> storage_scope_;
+  std::unordered_map<const Provide*, Array<Expr>> mma_sync_;
+  std::unordered_map<const Node*, std::string> buf_name_;
+  std::unordered_set<std::string> frag_reg_;
+  bool matched_{false};
+  bool tensor_core_on_{false};
+};
+
+// BodyVisitor visits the body stmt of original ComputeOp
+// to get the access indices of input matrices,
+// if it is recognized as matrix multiply.
+class BodyVisitor : public IRVisitor {
+ public:
+  BodyVisitor() {}
+  using IRVisitor::Visit_;
+
+  void Visit_(const Reduce* op) final {
+    auto* comm_add = op->combiner->result[0].as<Add>();
+    if (comm_add == nullptr || op->combiner->result.size() > 1) {
+      return;
+    }
+    for (Expr source : op->source) {
+      auto mul_0 = unpack_type_cast(source, Float(32)).as<Mul>();
+      auto mul_1 = unpack_type_cast(source, Int(32)).as<Mul>();
+      if (mul_0 == nullptr && mul_1 == nullptr) {
+        continue;
+      }
+
+      tensorcore_candidate_ = true;
+      IRVisitor::Visit(source);
+    }
+  }
+
+  void Visit_(const Call* op) final {
+    IRVisitor::Visit_(op);
+    args_.insert(std::make_pair(op->name, op->args));
+  }
+
+  friend class ScheduleAnalyser;
+
+ private:
+  std::unordered_map<std::string, Array<Expr>> args_;
+  bool tensorcore_candidate_{false};
+};
+
+// ScheduleAnalyser figures out matrix_a/matrix_b and row_major/col_major
+class ScheduleAnalyser {
+ public:
+  explicit ScheduleAnalyser(const MMAMatcher &mma_matcher)
+    : mma_sync_(mma_matcher.mma_sync_),
+      buf_name_(mma_matcher.buf_name_) {}
+
+  bool MatrixIdentify(Schedule schedule) {
+    // TODO(minmin): handle the case where MatMul is not the output stage
+    for (Operation output : schedule->outputs) {
+      const ComputeOpNode* compute = output.as<ComputeOpNode>();
+      if (compute == nullptr) {
+        // Not a ComputeOp
+        continue;
+      }
+      auto axis = compute->axis;
+      auto reduce_axis = compute->reduce_axis;
+      if (axis.size() < 2 || reduce_axis.size() != 1) {
+        continue;
+      }
+      const Variable* axis_var[2];
+      const Variable* reduce_axis_var;
+      axis_var[0] = axis[axis.size()-2]->var.as<Variable>();
+      axis_var[1] = axis[axis.size()-1]->var.as<Variable>();
+      reduce_axis_var = reduce_axis[0]->var.as<Variable>();
+
+      BodyVisitor body_visitor;
+      for (Expr expr : compute->body) {
+        body_visitor.Visit(expr);
+      }
+      if (!body_visitor.tensorcore_candidate_) {
+        continue;
+      }
+      for (auto iter : body_visitor.args_) {
+        auto name = iter.first;
+        auto args = iter.second;
+        if (args.size() < 2) {
+          continue;
+        }
+        const Variable* var0 = args[args.size() - 2].as<Variable>();
+        const Variable* var1 = args[args.size() - 1].as<Variable>();
+        if (var0 == nullptr || var1 == nullptr) {
+          continue;
+        }
+        std::string matrix_abc, major;
+        if (var0 == reduce_axis_var && var1 == axis_var[1]) {
+          matrix_abc = "matrix_a";
+          major = "col_major";
+        } else if (var0 == reduce_axis_var && var1 == axis_var[0]) {
+          matrix_abc = "matrix_b";
+          major = "row_major";
+        } else if (var0 == axis_var[1] && var1 == reduce_axis_var) {
+          matrix_abc = "matrix_a";
+          major = "row_major";
+        } else if (var0 == axis_var[0] && var1 == reduce_axis_var) {
+          matrix_abc = "matrix_b";
+          major = "col_major";
+        }
+        matrix_abc_.insert(std::make_pair(name, matrix_abc));
+        matrix_major_.insert(std::make_pair(name, major));
+      }
+      matrix_abc_.insert(std::make_pair(compute->name, "accumulator"));
+      matrix_major_.insert(std::make_pair(compute->name, "col_major"));
+    }
+
+    for (auto &mma_sync : mma_sync_) {
+      auto &operands = mma_sync.second;
+      auto* load_a = operands[0].as<Call>();
+      auto* load_b = operands[1].as<Call>();
+      auto input0 = simplify_name(buf_name_.find(load_a)->second);
+      auto input1 = simplify_name(buf_name_.find(load_b)->second);
+      auto it0 = matrix_abc_.find(input0);
+      auto it1 = matrix_abc_.find(input1);
+
+      if (it0 == matrix_abc_.end() || it1 == matrix_abc_.end()) {
+        return false;
+      }
+      if (it0->second == "matrix_a" && it1->second == "matrix_b") {
+        return true;
+      } else if (it0->second == "matrix_b" && it1->second == "matrix_a") {
+        mma_sync.second = Array<Expr>{operands[1], operands[0], operands[2]};
+      } else {
+        return false;
+      }
+    }
+    return true;
+  }
+
+  friend class BufferAnalyser;
+  friend class TensorCoreIRMutator;
+
+ private:
+  std::unordered_map<std::string, std::string> matrix_abc_;
+  std::unordered_map<std::string, std::string> matrix_major_;
+  std::unordered_map<const Provide*, Array<Expr>> mma_sync_;
+  std::unordered_map<const Node*, std::string> buf_name_;
+};
+
+// IndexVisitor visits access index of fragment
+// to record variable for loop scaling
+class IndexVisitor : public IRVisitor {
+ public:
+  IndexVisitor() {}
+  using IRVisitor::Visit_;
+
+  void Visit_(const Variable* op) final {
+    loop_scaling_.insert(std::make_pair(op, scaling_factor_));
+  }
+
+  friend class BufferAnalyser;
+  friend class TensorCoreIRMutator;
+
+ private:
+  std::unordered_map<const Variable*, unsigned> loop_scaling_;
+  unsigned scaling_factor_{0};
+};
+
+// BufferAnalyser gets buffer info,
+// e.g. thread tile and warp tile, for TensorCore CodeGen
+class BufferAnalyser : public IRVisitor {
+ public:
+  explicit BufferAnalyser(Map<Tensor, Buffer> extern_buffer,
+                          const ScheduleAnalyser &schedule_analyser,
+                          const MMAMatcher &mma_matcher)
+      : matrix_abc_(schedule_analyser.matrix_abc_),
+        matrix_major_(schedule_analyser.matrix_major_),
+        frag_reg_(mma_matcher.frag_reg_) {
+    for (auto kv : extern_buffer) {
+      BufferInfo bi;
+      bi.name = kv.second->name;
+      bi.dtype = kv.second->dtype;
+      bi.strides = kv.second->strides;
+      bi.shape = kv.second->shape;
+      bi.external = true;
+      buf_map_[TensorKey{kv.first->op, kv.first->value_index}] = bi;
+    }
+  }
+  using IRVisitor::Visit_;
+
+  void Visit_(const AttrStmt* op) final {
+    if (op->attr_key == attr::thread_extent) {
+      if (const IntImm* value = op->value.as<IntImm>()) {
+        thread_extent_.insert(
+            std::make_pair(
+                op->node.as<IterVarNode>()->var->name_hint,
+                value->value));
+      }
+      IRVisitor::Visit_(op);
+    } else if (op->attr_key == attr::realize_scope) {
+      storage_scope_[op->node.get()] = op->value.as<StringImm>()->value;
+      Visit(op->body);
+    } else if (op->attr_key == attr::buffer_dim_align) {
+      Tensor tensor = Downcast<Tensor>(op->node);
+      const Call* tuple = op->value.as<Call>();
+      CHECK(tuple && tuple->is_intrinsic(intrinsic::tvm_tuple));
+      auto& vinfo = dim_align_[TensorKey{tensor->op, tensor->value_index}];
+      size_t dim = tuple->args[0].as<IntImm>()->value;
+      if (dim >= vinfo.size()) {
+        vinfo.resize(dim + 1);
+      }
+      vinfo[dim].align_factor = tuple->args[1].as<IntImm>()->value;
+      vinfo[dim].align_offset = tuple->args[2].as<IntImm>()->value;
+      Visit(op->body);
+    } else {
+      IRVisitor::Visit_(op);
+    }
+  }
+
+  void Visit_(const Provide* op) final {
+    IRVisitor::Visit_(op);
+    TensorKey key{op->func, op->value_index};
+    auto it = buf_map_.find(key);
+    CHECK(it != buf_map_.end())
+        << "Cannot find allocated buffer for " << key.f;
+    const BufferInfo& bi = it->second;
+    CHECK(!bi.released)
+        << "Read a buffer that is already out of scope";
+
+    if (matrix_abc_.count(key.GetName())) {
+      if (bi.shape.size() < 2) {
+        invalid_ = true;
+        return;
+      }
+      for (auto i = bi.shape.size() - 1; i + 2 >= bi.shape.size(); --i) {
+        const IntImm* shape = bi.shape[i].as<IntImm>();
+        if (shape == nullptr || shape->value % 16 != 0) {
+          invalid_ = true;
+          return;
+        }
+      }
+    }
+
+    Array<Expr> strides;
+    if (bi.strides.size() > 0) {
+      strides = bi.strides;
+    } else {
+      for (size_t i = 1; i < bi.shape.size(); ++i) {
+        Expr stride = IntImm::make(Int(32), 1);
+        for (size_t j = bi.shape.size() - 1; j >= i; --j) {
+          stride = Mul::make(stride, bi.shape[j]);
+        }
+        strides.push_back(stride);
+      }
+      strides.push_back(make_const(Int(32), 1));
+    }
+    strides_.insert(std::make_pair(key.GetName(), strides));
+
+    if (frag_reg_.count(bi.name)) {
+      Expr dst = Call::make(bi.dtype,
+                            bi.name,
+                            op->args,
+                            Call::Halide,
+                            op->func,
+                            0);
+      frag_load_.insert(std::make_pair(op, dst));
+
+      auto rel_index = bi.RelIndex(op->args);
+      if (op->args.size() < 2) {
+        invalid_ = true;
+        return;
+      }
+      std::vector<int> tile_size;
+      for (auto i = op->args.size() - 1; i + 2 >= op->args.size(); --i) {
+        index_visitor.scaling_factor_ = 16;
+        if (const IntImm* shape = bi.shape[i].as<IntImm>()) {
+          tile_size.push_back(shape->value);
+          index_visitor.scaling_factor_ = shape->value;
+        } else {
+          invalid_ = true;
+          return;
+        }
+        auto index = rel_index[i];
+        auto simplified_index = ir::Simplify(index);
+        index_visitor.Visit(simplified_index);
+      }
+
+      std::string input_name = simplify_name(bi.name);
+      auto it = matrix_abc_.find(input_name);
+      auto it2 = matrix_major_.find(input_name);
+      bool ret = true;
+      if (it != matrix_abc_.end() && it2 != matrix_major_.end()) {
+        if (it->second == "matrix_a" && it2->second == "col_major") {
+          ret &= assign_or_check_(&thread_tile_.m, tile_size[0]);
+          ret &= assign_or_check_(&thread_tile_.k, tile_size[1]);
+        }
+        if (it->second == "matrix_a" && it2->second == "row_major") {
+          ret &= assign_or_check_(&thread_tile_.k, tile_size[0]);
+          ret &= assign_or_check_(&thread_tile_.m, tile_size[1]);
+        }
+        if (it->second == "matrix_b" && it2->second == "col_major") {
+          ret &= assign_or_check_(&thread_tile_.k, tile_size[0]);
+          ret &= assign_or_check_(&thread_tile_.n, tile_size[1]);
+        }
+        if (it->second == "matrix_b" && it2->second == "row_major") {
+          ret &= assign_or_check_(&thread_tile_.n, tile_size[0]);
+          ret &= assign_or_check_(&thread_tile_.k, tile_size[1]);
+        }
+        if (it->second == "accumulator") {
+          ret &= assign_or_check_(&thread_tile_.m, tile_size[0]);
+          ret &= assign_or_check_(&thread_tile_.n, tile_size[1]);
+        }
+        if (!ret) {
+          invalid_ = true;
+          return;
+        }
+      }
+    }
+
+    const Call* value = op->value.as<Call>();
+    if (value != nullptr && frag_reg_.count(value->name)) {
+      Expr dst = Call::make(bi.dtype,
+                            bi.name,
+                            op->args,
+                            Call::Halide,
+                            op->func,
+                            0);
+      frag_store_.insert(std::make_pair(op, dst));
+    }
+  }
+
+  void Visit_(const Call* op) final {
+    IRVisitor::Visit_(op);
+    if (op->call_type == Call::Halide) {
+      TensorKey key{op->func, op->value_index};
+      auto it = buf_map_.find(key);
+      CHECK(it != buf_map_.end())
+          << "Cannot find allocated buffer for " << key.f;
+      const BufferInfo& bi = it->second;
+      CHECK(!bi.released)
+          << "Read a buffer that is already out of scope";
+
+      if (matrix_abc_.count(op->name)) {
+        if (bi.shape.size() < 2) {
+          invalid_ = true;
+          return;
+        }
+        for (auto i = bi.shape.size() - 1; i + 2 >= bi.shape.size(); --i) {
+          const IntImm* shape = bi.shape[i].as<IntImm>();
+          if (shape == nullptr || shape->value % 16 != 0) {
+            invalid_ = true;
+            return;
+          }
+        }
+      }
+
+      Array<Expr> strides;
+      if (bi.strides.size() > 0) {
+        strides = bi.strides;
+      } else {
+        for (size_t i = 1; i < bi.shape.size(); ++i) {
+          Expr stride = IntImm::make(Int(32), 1);
+          for (size_t j = bi.shape.size() - 1; j >= i; --j) {
+            stride = Mul::make(stride, bi.shape[j]);
+          }
+          strides.push_back(stride);
+        }
+        strides.push_back(make_const(Int(32), 1));
+      }
+      strides_.insert(std::make_pair(key.GetName(), strides));
+
+      if (!frag_reg_.count(bi.name)) {
+        return;
+      }
+
+      auto rel_index = bi.RelIndex(op->args);
+      if (op->args.size() < 2) {
+        invalid_ = true;
+        return;
+      }
+      for (auto i = op->args.size() - 1; i + 2 >= op->args.size(); --i) {
+        index_visitor.scaling_factor_ = 16;
+        if (const IntImm* shape = bi.shape[i].as<IntImm>()) {
+          index_visitor.scaling_factor_ = shape->value;
+        }
+        auto index = rel_index[i];
+        auto simplified_index = ir::Simplify(index);
+        index_visitor.Visit(simplified_index);
+      }
+    }
+  }
+
+  void Visit_(const Realize* op) final {
+    TensorKey key{op->func, op->value_index};
+    if (buf_map_.count(key)) {
+      CHECK(buf_map_.at(key).external);
+      Visit(op->body);
+    } else {
+      // create a buffer entry
+      BufferInfo bi;
+
+      bi.bounds = op->bounds;
+      Array<Expr> shape;
+      for (auto r : bi.bounds) {
+        shape.push_back(r->extent);
+      }
+
+      Array<Expr> strides;
+      if (dim_align_.count(key) != 0 && shape.size() != 0) {
+        std::vector<Expr> rstrides;
+        const std::vector<DimAlignInfo>& avec = dim_align_[key];
+        int first_dim = 0;
+        Expr stride = make_const(shape[first_dim].type(), 1);
+        for (size_t i = shape.size(); i != 0; --i) {
+          size_t dim = i - 1;
+          if (dim < avec.size() && avec[dim].align_factor != 0) {
+            Expr factor = make_const(stride.type(), avec[dim].align_factor);
+            Expr offset = make_const(stride.type(), avec[dim].align_offset);
+            stride = stride + \
+              indexmod(factor + offset - indexmod(stride, factor), factor);
+            stride = ir::Simplify(stride);
+          }
+          rstrides.push_back(stride);
+          stride = stride * shape[dim];
+        }
+        strides = Array<Expr>(rstrides.rbegin(), rstrides.rend());
+      }
+
+      bi.name = key.GetName();
+      bi.dtype = op->type;
+      bi.strides = strides;
+      bi.shape = shape;
+
+      buf_map_[key] = bi;
+      Visit(op->body);
+      buf_map_[key].released = true;
+    }
+  }
+
+  // Derive warp tile from thread tile,
+  // and check whether it is qualified for TensorCore.
+  bool QualifiedForTensorCore() {
+    if (invalid_) {
+      return false;
+    }
+    auto itx = thread_extent_.find("threadIdx.x");
+    if (itx == thread_extent_.end()) {
+      return false;
+    }
+    int warp_threads_x = itx->second;
+    warp_tile_.m = warp_threads_x * thread_tile_.m;
+    warp_threads_y_ = 32 / warp_threads_x;
+    auto ity = thread_extent_.find("threadIdx.y");
+    if (ity == thread_extent_.end()) {
+      return false;
+    }
+    if (ity->second < warp_threads_y_ || ity->second % warp_threads_y_ != 0) {
+      return false;
+    }
+    warp_tile_.n = warp_threads_y_ * thread_tile_.n;
+    warp_tile_.k = thread_tile_.k;
+    return supported_warp_tile_();
+  }
+
+  friend class TensorCoreIRMutator;
+
+ private:
+  struct DimAlignInfo {
+    int align_factor{0};
+    int align_offset{0};
+  };
+
+  struct BufferInfo {
+    std::string name;
+    Type dtype;
+    Array<Expr> strides;
+    Array<Expr> shape;
+    Region bounds;
+    bool external{false};
+    bool released{false};
+    inline Array<Expr> RelIndex(Array<Expr> args) const {
+      if (bounds.size() != 0) {
+        Array<Expr> index;
+        CHECK_EQ(bounds.size(), args.size());
+        for (size_t i = 0; i < bounds.size(); ++i) {
+          index.push_back(args[i] - bounds[i]->min);
+        }
+        return index;
+      } else {
+        return args;
+      }
+    }
+  };
+
+  bool assign_or_check_(int* dst, int src) {
+    if (*dst <= 0) {
+      *dst = src;
+      return true;
+    }
+    if (*dst == src) {
+      return true;
+    }
+    return false;
+  }
+
+  bool supported_warp_tile_() {
+    if (warp_tile_.m == 16 &&
+        warp_tile_.n == 16 &&
+        warp_tile_.k == 16) {
+      return true;
+    }
+    if (warp_tile_.m == 8 &&
+        warp_tile_.n == 32 &&
+        warp_tile_.k == 16) {
+      return true;
+    }
+    if (warp_tile_.m == 32 &&
+        warp_tile_.n == 8 &&
+        warp_tile_.k == 16) {
+      return true;
+    }
+    return false;
+  }
+
+  std::unordered_map<TensorKey, BufferInfo> buf_map_;
+  std::unordered_map<TensorKey, std::vector<DimAlignInfo> > dim_align_;
+  std::unordered_map<const Node*, std::string> storage_scope_;
+  std::unordered_map<std::string, std::string> matrix_abc_;
+  std::unordered_map<std::string, std::string> matrix_major_;
+  std::unordered_set<std::string> frag_reg_;
+  std::unordered_map<std::string, Array<Expr>> strides_;
+  std::unordered_map<const Provide*, Expr> frag_load_;
+  std::unordered_map<const Provide*, Expr> frag_store_;
+  std::unordered_map<std::string, int> thread_extent_;
+  IndexVisitor index_visitor;
+  Tile warp_tile_;
+  Tile thread_tile_;
+  int warp_threads_y_{-1};
+  bool invalid_{false};
+};
+
+// ThreadIdxMutator does the thread index unification inside a warp
+class ThreadIdxMutator : public IRMutator {
+ public:
+  explicit ThreadIdxMutator(Expr warp_y): warp_y_(warp_y) {}
+
+  Expr Mutate_(const Variable* op, const Expr& olde) final {
+    Expr expr = IRMutator::Mutate_(op, olde);
+    op = expr.as<Variable>();
+    if (op != nullptr) {
+      if (op->name_hint == "threadIdx.x") {
+        Expr zero = IntImm::make(Int(32), 0);
+        return zero;
+      }
+      if (op->name_hint == "threadIdx.y") {
+        Expr div = Div::make(expr, warp_y_);
+        Expr mul = Mul::make(div, warp_y_);
+        return mul;
+      }
+    }
+    return expr;
+  }
+
+ private:
+  Expr warp_y_;
+};
+
+// TensorCoreIRMutator mutates the AST for TensorCore CodeGen
+// based on tensor core intrinsics
+class TensorCoreIRMutator : public IRMutator {
+ public:
+  explicit TensorCoreIRMutator(const ScheduleAnalyser &schedule_analyser,
+    const BufferAnalyser &buffer_analyser)
+      : matrix_abc_(schedule_analyser.matrix_abc_),
+      matrix_major_(schedule_analyser.matrix_major_),
+      mma_sync_(schedule_analyser.mma_sync_),
+      strides_(buffer_analyser.strides_),
+      frag_reg_(buffer_analyser.frag_reg_),
+      loop_scaling_(buffer_analyser.index_visitor.loop_scaling_),
+      frag_load_(buffer_analyser.frag_load_),
+      frag_store_(buffer_analyser.frag_store_),
+      warp_tile_(buffer_analyser.warp_tile_),
+      warp_threads_y_(buffer_analyser.warp_threads_y_) {}
+
+  Stmt Mutate_(const Realize* op, const Stmt& s) final {
+    TensorKey key{op->func, op->value_index};
+    bounds_[key] = op->bounds;
+    Stmt stmt = IRMutator::Mutate_(op, s);
+    op = stmt.as<Realize>();
+    if (op != nullptr) {
+      if (!frag_reg_.count(key.GetName())) {
+        return stmt;
+      }
+
+      auto new_extents = get_tile_size_(simplify_name(key.GetName()));
+
+      Region new_bounds;
+      for (size_t i = 0; i < op->bounds.size() - 2; ++i) {
+        new_bounds.push_back(op->bounds[i]);
+      }
+      CHECK_GE(op->bounds.size(), 2)
+          << "Less than 2 dimensions for matrix " << key.GetName();
+      new_bounds.push_back(Range::make_by_min_extent(
+          op->bounds[op->bounds.size() - 2]->min, new_extents[0]));
+      new_bounds.push_back(Range::make_by_min_extent(
+          op->bounds[op->bounds.size() - 1]->min, new_extents[1]));
+
+      return Realize::make(op->func, op->value_index,
+                           op->type, new_bounds,
+                           op->condition, op->body);
+    }
+    return stmt;
+  }
+
+  Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
+    Stmt stmt = IRMutator::Mutate_(op, s);
+    if (op->attr_key == attr::realize_scope) {
+      auto node = op->node.as<OperationNode>();
+      if (node != nullptr) {
+        if (!frag_reg_.count(node->name)) {
+          return stmt;
+        }
+
+        auto it = matrix_abc_.find(simplify_name(node->name));
+        CHECK(it != matrix_abc_.end())
+              << "Cannot find matrix info for " << node->name;
+        auto matrix_abc = "wmma." + it->second;
+        Stmt body = Mutate(op->body);
+        return AttrStmt::make(op->node,
+                              op->attr_key,
+                              matrix_abc,
+                              body);
+      }
+    }
+    return stmt;
+  }
+
+  Stmt Mutate_(const Provide* op, const Stmt& s) final {
+    Stmt stmt = IRMutator::Mutate_(op, s);
+    auto it = mma_sync_.find(op);
+    if (it != mma_sync_.end()) {
+      const auto &operands = it->second;
+      Expr a = operands[0];
+      auto ca = a.as<Call>();
+      Expr b = operands[1];
+      auto cb = b.as<Call>();
+      Expr c = operands[2];
+      auto cc = c.as<Call>();
+
+      NodePtr<BufferNode> buffer_node_a = make_node<BufferNode>();
+      NodePtr<BufferNode> buffer_node_b = make_node<BufferNode>();
+      NodePtr<BufferNode> buffer_node_c = make_node<BufferNode>();
+
+      auto mma_sync_call =
+        [&buffer_node_a, &buffer_node_b]
+        (const Buffer &buffer) {
+          Buffer buffer_a(buffer_node_a);
+          Buffer buffer_b(buffer_node_b);
+          return Evaluate::make(
+                  Call::make(Handle(),
+                        intrinsic::tvm_mma_sync,
+                        {buffer->data, buffer->elem_offset,
+                        buffer_a->data, buffer_a->elem_offset,
+                        buffer_b->data, buffer_b->elem_offset,
+                        buffer->data, buffer->elem_offset},
+                        Call::Intrinsic));
+        };
+
+      auto call_add_c =
+        [this, &cc, &buffer_node_c, &mma_sync_call](const Buffer &buffer) {
+          return add_buffer_bind_scope_(cc, buffer_node_c,
+            TensorKey{cc->func, cc->value_index}, mma_sync_call, cc->type);
+        };
+
+      auto call_add_b =
+        [this, &cb, &buffer_node_b, &call_add_c](const Buffer &buffer) {
+          return add_buffer_bind_scope_(cb, buffer_node_b,
+            TensorKey{cb->func, cb->value_index}, call_add_c, cb->type);
+        };
+
+      return add_buffer_bind_scope_(ca, buffer_node_a,
+        TensorKey{ca->func, ca->value_index}, call_add_b, ca->type);
+    }
+
+    auto it2 = frag_load_.find(op);
+    if (it2 != frag_load_.end()) {
+      Expr dst = it2->second;
+      if (op->value.as<FloatImm>() != nullptr ||
+          op->value.as<IntImm>() != nullptr) {
+        auto call = dst.as<Call>();
+
+        auto fill_fragment_call =
+          [this, &op](const Buffer &buffer) {
+            return Evaluate::make(
+                    Call::make(Handle(),
+                              intrinsic::tvm_fill_fragment,
+                              {buffer->data,
+                              warp_tile_.m, warp_tile_.n, warp_tile_.k,
+                              buffer->elem_offset, op->value},
+                              Call::Intrinsic));
+          };
+
+        NodePtr<BufferNode> buffer_node = make_node<BufferNode>();
+        return add_buffer_bind_scope_(call, buffer_node,
+                                      TensorKey{call->func, call->value_index},
+                                      fill_fragment_call, call->type);
+      }
+
+      const Call* value = op->value.as<Call>();
+      CHECK(value != nullptr)
+          << "Can only load fragment from a buffer";
+
+      auto it = strides_.find(value->name);
+      CHECK(it != strides_.end())
+          << "Cannot find stride for " << value->name;
+      auto strides = it->second;
+      CHECK_GE(strides.size(), 2);
+      Expr stride = strides[strides.size()-2];
+
+      // thread index unification inside a warp
+      Expr warp_y = IntImm::make(Int(32), warp_threads_y_);
+      ThreadIdxMutator thread_idx_mutator(warp_y);
+      Expr mutated_value = thread_idx_mutator.Mutate(op->value);
+      Expr src = Call::make(value->type,
+                            "&",
+                            {mutated_value},
+                            Call::Extern);
+
+      auto call = dst.as<Call>();
+      Expr matrix_major;
+      auto iter2 = matrix_major_.find(simplify_name(call->name));
+      CHECK(iter2 != matrix_major_.end())
+          << "Can not determine matrix major for " << call->name;
+      if (iter2->second == "col_major") {
+        matrix_major = StringImm::make("col_major");
+      } else if (iter2->second == "row_major") {
+        matrix_major = StringImm::make("row_major");
+      } else {
+        LOG(FATAL) << "invalid matrix major for " << call->name;
+      }
+
+      auto load_matrix_call =
+        [this, &src, &stride, &matrix_major](const Buffer &buffer) {
+        return Evaluate::make(
+                Call::make(Handle(),
+                          intrinsic::tvm_load_matrix_sync,
+                          {buffer->data,
+                          warp_tile_.m, warp_tile_.n, warp_tile_.k,
+                          buffer->elem_offset, src, stride, matrix_major},
+                          Call::Intrinsic));
+      };
+
+      NodePtr<BufferNode> buffer_node = make_node<BufferNode>();
+      return add_buffer_bind_scope_(call, buffer_node,
+                                    TensorKey{op->func, op->value_index},
+                                    load_matrix_call, call->type);
+    }
+
+    auto it3 = frag_store_.find(op);
+    if (it3 != frag_store_.end()) {
+      TensorKey key{op->func, op->value_index};
+      auto it = strides_.find(key.GetName());
+      CHECK(it != strides_.end())
+          << "Cannot find stride for " << key.GetName();
+      auto strides = it->second;
+      CHECK_GE(strides.size(), 2);
+      Expr stride = strides[strides.size()-2];
+
+      Expr dst = it3->second;
+      // thread index unification inside a warp
+      Expr warp_y = IntImm::make(Int(32), warp_threads_y_);
+      ThreadIdxMutator thread_idx_mutator(warp_y);
+      dst = thread_idx_mutator.Mutate(dst);
+      dst = Call::make(Handle(),
+                       "&",
+                       {dst},
+                       Call::Extern);
+
+      auto call = op->value.as<Call>();
+
+      auto store_matrix_call =
+        [this, &dst, &stride](const Buffer &buffer) {
+          return Evaluate::make(
+                  Call::make(Handle(),
+                            intrinsic::tvm_store_matrix_sync,
+                            {buffer->data,
+                            warp_tile_.m, warp_tile_.n, warp_tile_.k,
+                            buffer->elem_offset, dst, stride,
+                            StringImm::make("col_major")},
+                            Call::Intrinsic));
+        };
+
+      NodePtr<BufferNode> buffer_node = make_node<BufferNode>();
+      return add_buffer_bind_scope_(call, buffer_node,
+                                    TensorKey{call->func, call->value_index},
+                                    store_matrix_call, call->type);
+    }
+
+    return stmt;
+  }
+
+  Stmt Mutate_(const For* op, const Stmt& s) final {
+    Stmt stmt = IRMutator::Mutate_(op, s);
+    op = stmt.as<For>();
+    if (op != nullptr) {
+      auto it = loop_scaling_.find(op->loop_var.get());
+      if (it != loop_scaling_.end()) {
+        int scale_factor = it->second;
+        int scaled_extent_value = 1;
+        if (const IntImm *ori_extent = op->extent.as<IntImm>()) {
+          int ori_extent_value = ori_extent->value;
+          scaled_extent_value = ori_extent_value / scale_factor;
+        }
+        Expr scaled_extent = make_const(op->extent.type(), scaled_extent_value);
+        stmt = For::make(op->loop_var, op->min, scaled_extent, op->for_type,
+          op->device_api, op->body);
+      }
+    }
+    return stmt;
+  }
+
+ private:
+  Array<Expr> get_tile_size_(const std::string &name) {
+      auto it = matrix_abc_.find(name);
+      auto it2 = matrix_major_.find(name);
+      CHECK(it != matrix_abc_.end() && it2 != matrix_major_.end())
+          << "Cannot find matrix info for " << name;
+      Expr size0 = make_const(Int(32), 16);
+      Expr size1 = make_const(Int(32), 16);
+      if (it->second == "matrix_a" && it2->second == "col_major") {
+        size0 = make_const(Int(32), warp_tile_.k);
+        size1 = make_const(Int(32), warp_tile_.m);
+      }
+      if (it->second == "matrix_a" && it2->second == "row_major") {
+        size0 = make_const(Int(32), warp_tile_.m);
+        size1 = make_const(Int(32), warp_tile_.k);
+      }
+      if (it->second == "matrix_b" && it2->second == "row_major") {
+        size0 = make_const(Int(32), warp_tile_.k);
+        size1 = make_const(Int(32), warp_tile_.n);
+      }
+      if (it->second == "matrix_b" && it2->second == "col_major") {
+        size0 = make_const(Int(32), warp_tile_.n);
+        size1 = make_const(Int(32), warp_tile_.k);
+      }
+      if (it->second == "matrix_c") {
+        size0 = make_const(Int(32), warp_tile_.n);
+        size1 = make_const(Int(32), warp_tile_.m);
+      }
+      Array<Expr> tile_size = {size0, size1};
+      return tile_size;
+  }
+
+  Stmt add_buffer_bind_scope_(const Call* call,
+      const NodePtr<BufferNode> &buffer_node, const TensorKey &key,
+      const std::function<Stmt(const Buffer &buffer)> &call_back,
+      DataType datatype) {
+    auto it = bounds_.find(key);
+    CHECK(it != bounds_.end());
+    Array<Expr> min_bound;
+    for (auto i : it->second) {
+      min_bound.push_back(i->min);
+    }
+
+    CHECK_GE(it->second.size(), 2);
+    Array<Expr> shape;
+    for (size_t i = 0; i < it->second.size() - 2; ++i) {
+      shape.push_back(it->second[i]->extent);
+    }
+    auto tile_size = get_tile_size_(simplify_name(call->name));
+    shape.push_back(tile_size[0]);
+    shape.push_back(tile_size[1]);
+
+    Array<Expr> strides;
+    for (size_t i = 1; i < shape.size(); ++i) {
+      Expr stride = IntImm::make(Int(32), 1);
+      for (size_t j = shape.size() - 1; j >= i; --j) {
+        stride = Mul::make(stride, shape[j]);
+      }
+      strides.push_back(stride);
+    }
+    strides.push_back(make_const(Int(32), 1));
+
+    Expr elem_offset = IntImm::make(Int(32), 0);
+    CHECK_EQ(call->args.size(), min_bound.size());
+    for (size_t i = 0; i < min_bound.size(); i++) {
+      elem_offset = Add::make(
+        elem_offset, Mul::make(
+          strides[i], Sub::make(call->args[i], min_bound[i])));
+    }
+
+    auto it2 = matrix_abc_.find(simplify_name(call->name));
+    CHECK(it2 != matrix_abc_.end())
+          << "Cannot find matrix info for " << call->name;
+    buffer_node->data = Variable::make(Handle(), call->name);
+    buffer_node->name = call->name;
+    buffer_node->scope = "wmma." + it2->second;
+    buffer_node->dtype = datatype;
+    buffer_node->strides = strides;
+    buffer_node->shape = shape;
+    buffer_node->data_alignment = 1;
+    buffer_node->elem_offset = Simplify(elem_offset);
+    buffer_node->offset_factor = 1;
+    Buffer buffer(buffer_node);
+
+    NodePtr<TensorNode> tensor_node = make_node<TensorNode>();
+    tensor_node->value_index = key.value_index;
+    tensor_node->op = Downcast<Operation>(key.f);
+    tensor_node->shape = shape;
+    tensor_node->dtype = datatype;
+    Tensor tensor(tensor_node);
+
+    Array<Expr> args;
+    for (size_t i = 0; i < call->args.size(); ++i) {
+      args.push_back(call->args[i]);
+      args.push_back(shape[i]);
+    }
+    auto tuple = Call::make(Handle(),
+                            intrinsic::tvm_tuple,
+                            args,
+                            Call::Intrinsic);
+    Array<NodeRef> node = {buffer, tensor};
+    return AttrStmt::make(node,
+                          "buffer_bind_scope",
+                          tuple,
+                          call_back(buffer));
+  }
+
+  std::unordered_map<std::string, std::string> matrix_abc_;
+  std::unordered_map<std::string, std::string> matrix_major_;
+  std::unordered_map<const Provide*, Array<Expr>> mma_sync_;
+  std::unordered_map<std::string, Array<Expr>> strides_;
+  std::unordered_set<std::string> frag_reg_;
+  std::unordered_map<const Variable*, unsigned> loop_scaling_;
+  std::unordered_map<const Provide*, Expr> frag_load_;
+  std::unordered_map<const Provide*, Expr> frag_store_;
+  std::unordered_map<TensorKey, Region> bounds_;
+  Tile warp_tile_;
+  int warp_threads_y_{-1};
+};
+
+Stmt RewriteForTensorCore(Stmt stmt,
+                          Schedule schedule,
+                          Map<Tensor, Buffer> extern_buffer) {
+  // Check if current lower target is CUDA
+  auto target = tvm::Target::Current(true);
+  if (target.defined() && target->target_name != "cuda") {
+    return stmt;
+  }
+
+  // Check if current runtime support GPU CUDA
+  TVMContext ctx{kDLGPU, 0};
+  auto api = tvm::runtime::DeviceAPI::Get(ctx, true);
+  if (api == nullptr) {
+    return stmt;
+  }
+
+  MMAMatcher mma_matcher(extern_buffer);
+  mma_matcher.Visit(stmt);
+  if (!mma_matcher.Matched()) {
+    return stmt;
+  }
+
+  ScheduleAnalyser schedule_analyser(mma_matcher);
+  if (!schedule_analyser.MatrixIdentify(schedule)) {
+    return stmt;
+  }
+
+  BufferAnalyser buffer_analyser(extern_buffer,
+                                 schedule_analyser, mma_matcher);
+  buffer_analyser.Visit(stmt);
+  if (!buffer_analyser.QualifiedForTensorCore()) {
+    return stmt;
+  }
+
+  return TensorCoreIRMutator(schedule_analyser, buffer_analyser).Mutate(stmt);
+}
+
+}  // namespace ir
+}  // namespace tvm
diff --git a/tests/python/unittest/test_pass_rewrite_for_tensor_core.py b/tests/python/unittest/test_pass_rewrite_for_tensor_core.py
new file mode 100644 (file)
index 0000000..272d243
--- /dev/null
@@ -0,0 +1,230 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# 'License'); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm
+import topi
+import numpy as np
+from tvm.contrib import nvcc
+
+def tensor_core_matmul(warp_tile_m=16, m=64, n=32, l=96):
+    A = tvm.placeholder((n, l), name='A', dtype='float16')
+    B = tvm.placeholder((l, m), name='B', dtype='float16')
+    k = tvm.reduce_axis((0, l), name='k')
+    C = tvm.compute((n, m), lambda i, j: tvm.sum(A[i, k].astype('float32') * B[k, j].astype('float32'), axis=k))
+    s = tvm.create_schedule(C.op)
+    y, x = s[C].op.axis
+    k = s[C].op.reduce_axis[0]
+
+    AA = s.cache_read(A, "shared", [C])
+    AL = s.cache_read(AA, "local", [C])
+    BB = s.cache_read(B, "shared", [C])
+    BL = s.cache_read(BB, "local", [C])
+    CL = s.cache_write(C, "local")
+
+    bx = 4
+    by = 32
+    step_k = 8
+    v = 4
+    TX = 8
+    TY = 1
+    tile_x = bx * TX
+    tile_y = by * TY
+    WX = min(warp_tile_m, tile_x)
+    tile_k = 16
+    vthread = 1
+
+    yo, ty = s[C].split(y, tile_y*vthread)
+    vy, ty = s[C].split(ty, tile_y)
+    ty, yi = s[C].split(ty, TY)
+
+    xo, xi = s[C].split(x, tile_x)
+    tz, xi = s[C].split(xi, WX)
+    tx, xi = s[C].split(xi, TX)
+    ko, ki = s[CL].split(k, step_k * tile_k)
+    kl, ki = s[CL].split(ki, tile_k)
+
+    s[C].reorder(yo, xo, tz, ty, tx, yi, xi)
+    s[C].bind(yo, tvm.thread_axis("blockIdx.y"))
+    s[C].bind(xo, tvm.thread_axis("blockIdx.x"))
+    s[C].bind(ty, tvm.thread_axis("threadIdx.y"))
+    s[C].bind(tz, tvm.thread_axis("threadIdx.z"))
+    s[C].bind(tx, tvm.thread_axis("threadIdx.x"))
+    s[C].bind(vy, tvm.thread_axis((0, vthread), "vthread", name="vy"))
+    s[CL].compute_at(s[C], tx)
+    yo, xo = CL.op.axis
+    s[CL].reorder(ko, kl, ki, yo, xo)
+
+    s[AA].compute_at(s[CL], ko)
+    xo, xi = s[AA].split(s[AA].op.axis[1], factor=bx*v)
+    tz, tx = s[AA].split(xi, factor=(WX//TX)*v)
+    tx, vec = s[AA].split(tx, factor=v)
+    fused = s[AA].fuse(s[AA].op.axis[0], xo)
+    _, ty = s[AA].split(fused, factor=by)
+    s[AA].bind(ty, tvm.thread_axis("threadIdx.y"))
+    s[AA].bind(tz, tvm.thread_axis("threadIdx.z"))
+    s[AA].bind(tx, tvm.thread_axis("threadIdx.x"))
+    s[AA].vectorize(vec)
+
+    s[BB].compute_at(s[CL], ko)
+    xo, xi = s[BB].split(s[BB].op.axis[1], factor=bx*v)
+    tz, tx = s[BB].split(xi, factor=(WX//TX)*v)
+    tx, vec = s[BB].split(tx, factor=v)
+    fused = s[BB].fuse(s[BB].op.axis[0], xo)
+    _, ty = s[BB].split(fused, factor=by)
+    s[BB].bind(ty, tvm.thread_axis("threadIdx.y"))
+    s[BB].bind(tz, tvm.thread_axis("threadIdx.z"))
+    s[BB].bind(tx, tvm.thread_axis("threadIdx.x"))
+    s[BB].vectorize(vec)
+
+    s[AL].compute_at(s[CL], kl)
+    s[BL].compute_at(s[CL], kl)
+
+    s[CL].pragma(ko, 'tensor_core')
+
+    func = tvm.build(s, [A, B, C], 'cuda')
+
+    ctx = tvm.gpu(0)
+    a_np = np.random.uniform(size=(n, l)).astype(A.dtype)
+    b_np = np.random.uniform(size=(l, m)).astype(B.dtype)
+    c_np = np.zeros((n, m), dtype=np.float32)
+    a = tvm.nd.array(a_np, ctx)
+    b = tvm.nd.array(b_np, ctx)
+    c = tvm.nd.array(np.zeros((n, m), dtype=C.dtype), ctx)
+    func(a, b, c)
+    evaluator = func.time_evaluator(func.entry_name, ctx, number=3)
+    print('gemm m=%d n=%d k=%d: %f ms' % (m, n, l, evaluator(a, b, c).mean * 1e3))
+
+    c_np = np.dot(a_np, b_np)
+    np.testing.assert_allclose(c_np, c.asnumpy(), rtol=1e-3)
+
+def tensor_core_batch_matmul(warp_tile_m=16, m=64, n=32, l=96, batch=2):
+    A = tvm.placeholder((batch, n, l), name='A', dtype='float16')
+    B = tvm.placeholder((batch, l, m), name='B', dtype='float16')
+    k = tvm.reduce_axis((0, l), name='k')
+    C = tvm.compute((batch, n, m), lambda b, i, j: tvm.sum((A[b, i, k] * B[b, k, j]).astype('float32'), axis=k))
+    s = tvm.create_schedule(C.op)
+    z, y, x = s[C].op.axis
+    k = s[C].op.reduce_axis[0]
+
+    AA = s.cache_read(A, "shared", [C])
+    AL = s.cache_read(AA, "local", [C])
+    BB = s.cache_read(B, "shared", [C])
+    BL = s.cache_read(BB, "local", [C])
+    CL = s.cache_write(C, "local")
+
+    bx = 2
+    by = 32
+    step_k = 8
+    v = 4
+    TX = 8
+    TY = 1
+    tile_x = bx * TX
+    tile_y = by * TY
+    WX = min(warp_tile_m, tile_x)
+    tile_k = 16
+    vthread = 1
+
+    yo, ty = s[C].split(y, tile_y*vthread)
+    vy, ty = s[C].split(ty, tile_y)
+    ty, yi = s[C].split(ty, TY)
+
+    xo, xi = s[C].split(x, tile_x)
+    tz, xi = s[C].split(xi, WX)
+    tx, xi = s[C].split(xi, TX)
+    ko, ki = s[CL].split(k, step_k * tile_k)
+    kl, ki = s[CL].split(ki, tile_k)
+
+    s[C].reorder(z, yo, xo, tz, ty, tx, yi, xi)
+    s[C].bind(z, tvm.thread_axis("blockIdx.z"))
+    s[C].bind(yo, tvm.thread_axis("blockIdx.y"))
+    s[C].bind(xo, tvm.thread_axis("blockIdx.x"))
+    s[C].bind(ty, tvm.thread_axis("threadIdx.y"))
+    s[C].bind(tz, tvm.thread_axis("threadIdx.z"))
+    s[C].bind(tx, tvm.thread_axis("threadIdx.x"))
+    s[C].bind(vy, tvm.thread_axis((0, vthread), "vthread", name="vy"))
+    s[CL].compute_at(s[C], tx)
+    zo, yo, xo = CL.op.axis
+    s[CL].reorder(ko, kl, ki, zo, yo, xo)
+
+    s[AA].compute_at(s[CL], ko)
+    xo, xi = s[AA].split(s[AA].op.axis[2], factor=bx*v)
+    tz, tx = s[AA].split(xi, factor=(WX//TX)*v)
+    tx, vec = s[AA].split(tx, factor=v)
+    fused = s[AA].fuse(s[AA].op.axis[1], xo)
+    _, ty = s[AA].split(fused, factor=by)
+    s[AA].bind(ty, tvm.thread_axis("threadIdx.y"))
+    s[AA].bind(tz, tvm.thread_axis("threadIdx.z"))
+    s[AA].bind(tx, tvm.thread_axis("threadIdx.x"))
+    s[AA].vectorize(vec)
+
+    s[BB].compute_at(s[CL], ko)
+    xo, xi = s[BB].split(s[BB].op.axis[2], factor=bx*v)
+    tz, tx = s[BB].split(xi, factor=(WX//TX)*v)
+    tx, vec = s[BB].split(tx, factor=v)
+    fused = s[BB].fuse(s[BB].op.axis[1], xo)
+    _, ty = s[BB].split(fused, factor=by)
+    s[BB].bind(ty, tvm.thread_axis("threadIdx.y"))
+    s[BB].bind(tz, tvm.thread_axis("threadIdx.z"))
+    s[BB].bind(tx, tvm.thread_axis("threadIdx.x"))
+    s[BB].vectorize(vec)
+
+    s[AL].compute_at(s[CL], kl)
+    s[BL].compute_at(s[CL], kl)
+
+    s[CL].pragma(ko, 'tensor_core')
+
+    func = tvm.build(s, [A, B, C], 'cuda')
+
+    ctx = tvm.gpu(0)
+    a_np = np.random.uniform(size=(batch, n, l)).astype(A.dtype)
+    b_np = np.random.uniform(size=(batch, l, m)).astype(B.dtype)
+    c_np = np.zeros((batch, n, m), dtype=np.float32)
+    a = tvm.nd.array(a_np, ctx)
+    b = tvm.nd.array(b_np, ctx)
+    c = tvm.nd.array(np.zeros((batch, n, m), dtype=C.dtype), ctx)
+    func(a, b, c)
+    evaluator = func.time_evaluator(func.entry_name, ctx, number=3)
+    print('batch gemm m=%d n=%d k=%d batch=%d: %f ms' % (m, n, l, batch, evaluator(a, b, c).mean * 1e3))
+
+    for bs in range(batch):
+      c_np[bs, :, :] = np.dot(a_np[bs, :, :], b_np[bs, :, :])
+    np.testing.assert_allclose(c_np, c.asnumpy(), rtol=1e-3)
+
+def test_tensor_core_matmul():
+    if not tvm.gpu(0).exist or not tvm.module.enabled("cuda"):
+        print("skip because cuda is not enabled..")
+        return
+    if not nvcc.have_tensorcore(tvm.gpu(0).compute_version):
+        print("skip because gpu does not support tensor core")
+        return
+
+    tensor_core_matmul(16) #test with warp_tile 16x16x16
+    tensor_core_matmul(8) #test with warp_tile 8x32x16
+    tensor_core_matmul(32) #test with warp_tile 32x8x16
+
+def test_tensor_core_batch_matmul():
+    if not tvm.gpu(0).exist or not tvm.module.enabled("cuda"):
+        print("skip because cuda is not enabled..")
+        return
+    if not nvcc.have_tensorcore(tvm.gpu(0).compute_version):
+        print("skip because gpu does not support tensor core")
+        return
+
+    tensor_core_batch_matmul()
+
+if __name__ == '__main__':
+    test_tensor_core_matmul()
+    test_tensor_core_batch_matmul()
diff --git a/tutorials/optimize/opt_matmul_auto_tensorcore.py b/tutorials/optimize/opt_matmul_auto_tensorcore.py
new file mode 100644 (file)
index 0000000..00dbe9d
--- /dev/null
@@ -0,0 +1,465 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+.. _opt-matmul-auto-tensorcore:
+
+How to optimize matmul with Auto TensorCore CodeGen
+==================================
+**Author**: `Minmin Sun <https://github.com/minminsun>`_, \
+            `Lanbo Li <https://github.com/Orion34C>`_, \
+            `Chenfan Jia <https://github.com/jcf94>`_, \
+            `Jun Yang <https://github.com/yangjunpro>`_
+
+In this tutorial, we will demonstrate how to write a high performance matmul
+schedule on Volta/Turing GPUs with TVM Auto TensorCore CodeGen.
+This is a transparent solution to generate tensorcore kernel
+with most transformations done in ir passes.
+Users can also write schedule with tensorization to generate TensorCore code.
+Both solutions use the same tensorcore intrinsics.
+Please refer to :ref:`opt-conv-tensorcore` tutorial for more details.
+
+"""
+
+################################################################
+# Preparation and Algorithm
+# --------------------------
+# 2 kinds of input data types are supported: float16 and int8.
+# For float16, the accumulator is float32.
+# For int8, the accumulator is int32.
+# For data layouts, 'N' means None-transpose while 'T' means Transpose.
+
+import logging
+import sys
+
+import numpy as np
+import tvm
+
+from tvm import autotvm
+from tvm.contrib import nvcc
+
+def matmul_nn(A, B, L, dtype='float16', layout='NN'):
+    k = tvm.reduce_axis((0, L), name='k')
+    if dtype == 'float16':
+      out_type = 'float'
+    elif dtype == 'int8':
+      out_type = 'int'
+    if (layout == 'NN'):
+      return tvm.compute((N, M), lambda i, j: tvm.sum(A[i, k].astype(out_type) * B[k, j].astype(out_type), axis=k))
+    if (layout == 'NT'):
+      return tvm.compute((N, M), lambda i, j: tvm.sum(A[k, i].astype(out_type) * B[k, j].astype(out_type), axis=k))
+    if (layout == 'TN'):
+      return tvm.compute((N, M), lambda i, j: tvm.sum(A[i, k].astype(out_type) * B[j, k].astype(out_type), axis=k))
+    if (layout == 'TT'):
+      return tvm.compute((N, M), lambda i, j: tvm.sum(A[k, i].astype(out_type) * B[j, k].astype(out_type), axis=k))
+
+###############################################################################
+# Scheduling the Computation
+# --------------------------
+# This schedule is no different than a non-tensorcore matmul schedule on GPU.
+# Please refer to :ref:`opt-gemm` tutorial for basics of optimizing matmul schedule.
+# When the "tensor_core" pragma is set, the "rewrite for tensorcore" ir pass
+# will automatically transform the schedule for tensorcore codegen,
+# otherwise normal CUDA code, with lower performance but equal functionality, will be generated.
+#
+# .. note::
+#
+#   *Requirements of TesnsorCore*
+#
+#   Note that in the following 2 cases, even though the "tensor_core" pragma is set, TVM will still fall back to normal CUDA codegen:
+#   (1) The m, n or k of input matrices is not multiple of 16;
+#   (2) The warp tile size is not 16x16x16 on CUDA9, or not one of {16x16x16, 32x8x16, 8x32x16} on CUDA version >= 10.0.
+#
+# In this schedule, storage_align is used to reduce bank conflicts of shared memory. Please refer to this
+# `doc <https://docs.tvm.ai/api/python/schedule.html#tvm.schedule.Stage.storage_align>`_
+# for the usage of storage_align primitive. In short, we need to add an offset to some shared memory buffer
+# to reduce bank conflicts.
+# According to the `wmma doc <https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#wmma-description>`_,
+# the stride of load_matrix_sync must be a multiple of 16 bytes,
+# so we choose 8 as offset for float16 and 16 as offset for int8.
+#
+# We use AutoTVM to search for best configurations in this schedule.
+
+@autotvm.template
+def test_gemm(N, L, M, dtype, layout):
+    if (layout == "NN"):
+      shape_a = (N, L)
+      shape_b = (L, M)
+    elif (layout == "NT"):
+      shape_a = (L, N)
+      shape_b = (L, M)
+    elif (layout == "TN"):
+      shape_a = (N, L)
+      shape_b = (M, L)
+    elif (layout == "TT"):
+      shape_a = (L, N)
+      shape_b = (M, L)
+    else:
+      print ("Unsupported layout:", layout)
+      sys.exit(1);
+    A = tvm.placeholder(shape_a, name='A', dtype=dtype)
+    B = tvm.placeholder(shape_b, name='B', dtype=dtype)
+    C = matmul_nn(A, B, L, dtype, layout)
+
+    s = tvm.create_schedule(C.op)
+    y, x = s[C].op.axis
+    k = s[C].op.reduce_axis[0]
+
+    # storage_align params
+    factor = 16
+    offset = 8
+    if dtype == 'int8':
+      factor = 32
+      offset = 16
+
+    # create cache stages
+    AA = s.cache_read(A, "shared", [C])
+    if (layout == "NN" or layout == "TN"):
+      s[AA].storage_align(AA.op.axis[0], factor, offset)
+    AL = s.cache_read(AA, "local", [C])
+    BB = s.cache_read(B, "shared", [C])
+    if (layout == "TT" or layout == "NT"):
+      s[BB].storage_align(BB.op.axis[0], factor, offset)
+    BL = s.cache_read(BB, "local", [C])
+    CL = s.cache_write(C, "local")
+
+    #autotvm search space definition
+    cfg = autotvm.get_config()
+
+    cfg.define_knob("bx", [2, 4, 8])
+    cfg.define_knob("by", [16, 32, 64])
+    cfg.define_knob("step_k", [8, 16, 32])
+    cfg.define_knob("v", [4, 8])
+    by = cfg['by'].val
+    bx = cfg['bx'].val
+    step_k = cfg['step_k'].val
+    v = cfg['v'].val
+
+    # thread tile
+    TX = 8
+    TY = 1
+    # warp tile
+    warp_tile_m = 16 # it could also be 8 or 32 on CUDA version >= 10.0
+    warp_tile_k = 16 # it must be 16
+    # block tile
+    tile_x = bx * TX
+    tile_y = by * TY
+
+    yo, ty = s[C].split(y, tile_y)
+    ty, yi = s[C].split(ty, TY)
+
+    # schedule for C stage
+    xo, xi = s[C].split(x, tile_x)
+    WX = min(warp_tile_m, tile_x)
+    tz, xi = s[C].split(xi, WX)
+    tx, xi = s[C].split(xi, TX)
+    s[C].reorder(yo, xo, tz, ty, tx, yi, xi)
+    s[C].bind(yo, tvm.thread_axis("blockIdx.y"))
+    s[C].bind(xo, tvm.thread_axis("blockIdx.x"))
+    s[C].bind(ty, tvm.thread_axis("threadIdx.y"))
+    s[C].bind(tz, tvm.thread_axis("threadIdx.z"))
+    s[C].bind(tx, tvm.thread_axis("threadIdx.x"))
+
+    # schedule for CL stage
+    ko, ki = s[CL].split(k, step_k * warp_tile_k)
+    kl, ki = s[CL].split(ki, warp_tile_k)
+    s[CL].compute_at(s[C], tx)
+    yo, xo = CL.op.axis
+    s[CL].reorder(ko, kl, ki, yo, xo)
+
+    # schedule for AA stage
+    s[AA].compute_at(s[CL], ko)
+    xo, xi = s[AA].split(s[AA].op.axis[1], factor=bx*v)
+    tz, tx = s[AA].split(xi, factor=(WX//TX)*v)
+    tx, vec = s[AA].split(tx, factor=v)
+    fused = s[AA].fuse(s[AA].op.axis[0], xo)
+    _, ty = s[AA].split(fused, factor=by)
+    s[AA].bind(ty, tvm.thread_axis("threadIdx.y"))
+    s[AA].bind(tz, tvm.thread_axis("threadIdx.z"))
+    s[AA].bind(tx, tvm.thread_axis("threadIdx.x"))
+    # vectorization is very important for float16/int8 inputs
+    s[AA].vectorize(vec)
+
+    # schedule for BB stage
+    s[BB].compute_at(s[CL], ko)
+    xo, xi = s[BB].split(s[BB].op.axis[1], factor=bx*v)
+    tz, tx = s[BB].split(xi, factor=(WX//TX)*v)
+    tx, vec = s[BB].split(tx, factor=v)
+    fused = s[BB].fuse(s[BB].op.axis[0], xo)
+    _, ty = s[BB].split(fused, factor=by)
+    s[BB].bind(ty, tvm.thread_axis("threadIdx.y"))
+    s[BB].bind(tz, tvm.thread_axis("threadIdx.z"))
+    s[BB].bind(tx, tvm.thread_axis("threadIdx.x"))
+    s[BB].vectorize(vec)
+
+    s[AL].compute_at(s[CL], kl)
+    s[BL].compute_at(s[CL], kl)
+
+    # set the 'tensor_core' pragma for tensorcore codegen
+    s[CL].pragma(ko, 'tensor_core')
+
+    return s, [A, B, C]
+
+###############################################################################
+# AutoTune and Test
+# --------------------
+# Finally we use a tuner to tune the schedule, generate code with best config
+# and run the kernel to compare with numpy to check whether the results are correct.
+
+# check whether the gpu has tensorcore
+ctx = tvm.gpu()
+if not nvcc.have_tensorcore(ctx.compute_version):
+  print('the gpu has no tensorcore, skipping...')
+  sys.exit(0)
+
+M, N, L = 512, 32, 512
+dtype = 'float16'
+layout = 'NN'
+if len(sys.argv) >= 4:
+  M, N, L = int(sys.argv[1]), int(sys.argv[2]), int(sys.argv[3])
+if len(sys.argv) >= 5:
+  dtype = sys.argv[4]
+if len(sys.argv) >= 6:
+  layout = sys.argv[5]
+
+def tune_and_evaluate(M, N, L, dtype, layout):
+  task = autotvm.task.create(test_gemm, args=(N, L, M, dtype, layout), target='cuda')
+  print(task.config_space)
+
+  logging.getLogger('autotvm').setLevel(logging.DEBUG)
+  logging.getLogger('autotvm').addHandler(logging.StreamHandler(sys.stdout))
+
+  measure_option = autotvm.measure_option(
+    builder='local',
+    runner=autotvm.LocalRunner(number=5))
+
+  tuner = autotvm.tuner.XGBTuner(task)
+  tuner.tune(n_trial=1000,
+             measure_option=measure_option,
+             callbacks=[autotvm.callback.log_to_file('matmul.log')])
+
+  dispatch_context = autotvm.apply_history_best("matmul.log")
+  best_config = dispatch_context.query(task.target, task.workload)
+  print("\nBest config:")
+  print(best_config)
+  with autotvm.apply_history_best('matmul.log'):
+    with tvm.target.create("cuda"):
+        with tvm.build_config():
+            s, arg_bufs = test_gemm(N, L, M, dtype, layout)
+            print(tvm.lower(s, arg_bufs, simple_mode=True))
+            func = tvm.build(s, arg_bufs)
+  dev_module = func.imported_modules[0]
+  print(dev_module.get_source())
+
+  # check correctness
+  if (layout == "NN"):
+    shape_a = (N, L)
+    shape_b = (L, M)
+  elif (layout == "NT"):
+    shape_a = (L, N)
+    shape_b = (L, M)
+  elif (layout == "TN"):
+    shape_a = (N, L)
+    shape_b = (M, L)
+  elif (layout == "TT"):
+    shape_a = (L, N)
+    shape_b = (M, L)
+
+  a_np = None
+  b_np = None
+  c_np = None
+  c_np_type = None
+  if dtype == 'float16':
+    c_np_type = np.float32
+    a_np = np.random.uniform(size=shape_a).astype(np.float16)
+    b_np = np.random.uniform(size=shape_b).astype(np.float16)
+    if (layout == "NN"):
+      c_np = np.dot(a_np, b_np)
+    elif (layout == "NT"):
+      c_np = np.dot(a_np.T, b_np)
+    elif (layout == "TN"):
+      c_np = np.dot(a_np, b_np.T)
+    elif (layout == "TT"):
+      c_np = np.dot(a_np.T, b_np.T)
+  elif dtype == 'int8':
+    c_np_type = np.int32
+    a_np = np.random.randint(low=-128, high=127, size=shape_a).astype(np.int8)
+    b_np = np.random.randint(low=-128, high=127, size=shape_b).astype(np.int8)
+    if (layout == "NN"):
+      c_np = np.dot(a_np.astype(np.int32), b_np.astype(np.int32))
+    elif (layout == "NT"):
+      c_np = np.dot(a_np.astype(np.int32).T, b_np.astype(np.int32))
+    elif (layout == "TN"):
+      c_np = np.dot(a_np.astype(np.int32), b_np.astype(np.int32).T)
+    elif (layout == "TT"):
+      c_np = np.dot(a_np.astype(np.int32).T, b_np.astype(np.int32).T)
+
+  c_tvm = tvm.nd.array(np.zeros(c_np.shape, dtype=c_np_type), ctx=ctx)
+  a_tvm = tvm.nd.array(a_np, ctx=ctx)
+  b_tvm = tvm.nd.array(b_np, ctx=ctx)
+  func(a_tvm, b_tvm, c_tvm)
+
+  tvm.testing.assert_allclose(c_np, c_tvm.asnumpy(), rtol=1e-3)
+
+  evaluator = func.time_evaluator(func.entry_name, ctx, number=100)
+  print('Time cost of this operator: %f' % evaluator(a_tvm, b_tvm, c_tvm).mean)
+
+# We do not run the tuning in our webpage server since it takes some time.
+# Uncomment the following line to run it by yourself.
+
+# tune_and_evaluate(M, N, L, dtype, layout)
+
+######################################################################
+# Sample Output
+# -------------
+# .. code-block:: bash
+#
+#    Best config:
+#    [('bx', 4), ('by', 32), ('step_k', 16), ('v', 8)],,None,40
+#    Finish loading 162 records
+#    produce compute {
+#      // attr [iter_var(blockIdx.y, , blockIdx.y)] thread_extent = 1
+#      // attr [compute.local] storage_scope = "wmma.accumulator"
+#      allocate compute.local[float32 * 256]
+#      // attr [A.shared] storage_scope = "shared"
+#      allocate A.shared[float16 * 8448]
+#      // attr [B.shared] storage_scope = "shared"
+#      allocate B.shared[float16 * 8192]
+#      // attr [A.shared.local] storage_scope = "wmma.matrix_b"
+#      allocate A.shared.local[float16 * 256]
+#      // attr [B.shared.local] storage_scope = "wmma.matrix_a"
+#      allocate B.shared.local[float16 * 256]
+#      // attr [iter_var(blockIdx.x, , blockIdx.x)] thread_extent = 16
+#      // attr [iter_var(threadIdx.z, , threadIdx.z)] thread_extent = 2
+#      // attr [iter_var(threadIdx.y, , threadIdx.y)] thread_extent = 32
+#      // attr [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 2
+#      produce compute.local {
+#        for (j.c.init, 0, 1) {
+#          tvm_fill_fragment(compute.local, 16, 16, 16, 0, 0f)
+#        }
+#        // attr [iter_var(k.outer, )] pragma_tensor_core = 1
+#        for (k.outer, 0, 2) {
+#          produce A.shared {
+#            for (ax0.ax1.outer.fused.outer, 0, 8) {
+#              // attr [iter_var(threadIdx.y, , threadIdx.y)] thread_extent = 32
+#              // attr [iter_var(threadIdx.z, , threadIdx.z)] thread_extent = 2
+#              // attr [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 2
+#              A.shared[ramp((((((ax0.ax1.outer.fused.outer*1056) + (floordiv(threadIdx.y, 8)*264)) + (floormod(threadIdx.y, 8)*32)) + (threadIdx.z*16)) + (threadIdx.x*8)), 1, 8)] = A[ramp(((((((ax0.ax1.outer.fused.outer*2048) + (floordiv(threadIdx.y, 8)*512)) + (k.outer*256)) + (floormod(threadIdx.y, 8)*32)) + (threadIdx.z*16)) + (threadIdx.x*8)), 1, 8)]
+#            }
+#          }
+#          produce B.shared {
+#            for (ax0.ax1.outer.fused.outer, 0, 8) {
+#              // attr [iter_var(threadIdx.y, , threadIdx.y)] thread_extent = 32
+#              // attr [iter_var(threadIdx.z, , threadIdx.z)] thread_extent = 2
+#              // attr [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 2
+#              B.shared[ramp(((((ax0.ax1.outer.fused.outer*1024) + (threadIdx.y*32)) + (threadIdx.z*16)) + (threadIdx.x*8)), 1, 8)] = B[ramp(((((((k.outer*131072) + (ax0.ax1.outer.fused.outer*16384)) + (threadIdx.y*512)) + (blockIdx.x*32)) + (threadIdx.z*16)) + (threadIdx.x*8)), 1, 8)]
+#            }
+#          }
+#          for (k.inner.outer, 0, 16) {
+#            produce A.shared.local {
+#              for (ax1, 0, 1) {
+#                tvm_load_matrix_sync(A.shared.local, 16, 16, 16, 0, &(A.shared[(((threadIdx.y/16)*4224) + (k.inner.outer*16))]), 264, "col_major")
+#              }
+#            }
+#            produce B.shared.local {
+#              for (ax0, 0, 1) {
+#                for (ax1, 0, 1) {
+#                  tvm_load_matrix_sync(B.shared.local, 16, 16, 16, 0, &(B.shared[((k.inner.outer*512) + (threadIdx.z*16))]), 32, "col_major")
+#                }
+#              }
+#            }
+#            for (k.inner.inner, 0, 1) {
+#              for (j.c, 0, 1) {
+#                tvm_mma_sync(compute.local, 0, B.shared.local, 0, A.shared.local, 0, compute.local, 0)
+#              }
+#            }
+#          }
+#        }
+#      }
+#      for (j.inner.inner.inner, 0, 1) {
+#        tvm_store_matrix_sync(compute.local, 16, 16, 16, 0, &(compute[((((threadIdx.y/16)*8192) + (blockIdx.x*32)) + (threadIdx.z*16))]), 512, "col_major")
+#      }
+#    }
+#
+#    #include <cuda_fp16.h>
+#    __device__ half max(const half a, const half b)
+#    {
+#      return __hgt(__half(a), __half(b)) ? a : b;
+#    }
+#    __device__ half min(const half a, const half b)
+#    {
+#      return __hlt(__half(a), __half(b)) ? a : b;
+#    }
+#    __device__ half operator+(const volatile __half &a,  const volatile __half &b)
+#    {
+#      return __hadd(a, b);
+#    }
+#    __device__ half operator<=(const volatile __half &a,  const volatile __half &b)
+#    {
+#      return __hlt(a, b);
+#    }
+#    __device__ half operator*(const volatile __half &a,  const volatile __half &b)
+#    {
+#      return __hmul(a, b);
+#    }
+#    #include <mma.h>
+#    extern "C" __global__ void default_function_kernel0( half* __restrict__ A,  half* __restrict__ B,  float* __restrict__ compute) {
+#      nvcuda::wmma::fragment<nvcuda::wmma::accumulator, 16, 16, 16, float> compute_local[1];
+#      __shared__ half A_shared[8448];
+#      __shared__ half B_shared[8192];
+#      nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 16, 16, 16, half, nvcuda::wmma::col_major> A_shared_local[1];
+#      nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, 16, 16, 16, half, nvcuda::wmma::col_major> B_shared_local[1];
+#      for (int j_c_init = 0; j_c_init < 1; ++j_c_init) {
+#        (void)nvcuda::wmma::fill_fragment(compute_local[0], 0.000000e+00f);
+#      }
+#      for (int k_outer = 0; k_outer < 2; ++k_outer) {
+#        __syncthreads();
+#        for (int ax0_ax1_outer_fused_outer = 0; ax0_ax1_outer_fused_outer < 8; ++ax0_ax1_outer_fused_outer) {
+#          ((__shared__ float4*)(A_shared + (((((ax0_ax1_outer_fused_outer * 1056) + ((((int)threadIdx.y) >> 3) * 264)) + ((((int)threadIdx.y) & 7) * 32)) + (((int)threadIdx.z) * 16)) + (((int)threadIdx.x) * 8))))[0] = (( float4*)(A + ((((((ax0_ax1_outer_fused_outer * 2048) + ((((int)threadIdx.y) >> 3) * 512)) + (k_outer * 256)) + ((((int)threadIdx.y) & 7) * 32)) + (((int)threadIdx.z) * 16)) + (((int)threadIdx.x) * 8))))[0];
+#        }
+#        for (int ax0_ax1_outer_fused_outer1 = 0; ax0_ax1_outer_fused_outer1 < 8; ++ax0_ax1_outer_fused_outer1) {
+#          ((__shared__ float4*)(B_shared + ((((ax0_ax1_outer_fused_outer1 * 1024) + (((int)threadIdx.y) * 32)) + (((int)threadIdx.z) * 16)) + (((int)threadIdx.x) * 8))))[0] = (( float4*)(B + ((((((k_outer * 131072) + (ax0_ax1_outer_fused_outer1 * 16384)) + (((int)threadIdx.y) * 512)) + (((int)blockIdx.x) * 32)) + (((int)threadIdx.z) * 16)) + (((int)threadIdx.x) * 8))))[0];
+#        }
+#        __syncthreads();
+#        for (int k_inner_outer = 0; k_inner_outer < 16; ++k_inner_outer) {
+#          for (int ax1 = 0; ax1 < 1; ++ax1) {
+#            (void)nvcuda::wmma::load_matrix_sync(A_shared_local[0], &(A_shared[(((((int)threadIdx.y) / 16) * 4224) + (k_inner_outer * 16))]), 264);
+#          }
+#          for (int ax0 = 0; ax0 < 1; ++ax0) {
+#            for (int ax11 = 0; ax11 < 1; ++ax11) {
+#              (void)nvcuda::wmma::load_matrix_sync(B_shared_local[0], &(B_shared[((k_inner_outer * 512) + (((int)threadIdx.z) * 16))]), 32);
+#            }
+#          }
+#          for (int k_inner_inner = 0; k_inner_inner < 1; ++k_inner_inner) {
+#            for (int j_c = 0; j_c < 1; ++j_c) {
+#              (void)nvcuda::wmma::mma_sync(compute_local[0], B_shared_local[0], A_shared_local[0], compute_local[0]);
+#            }
+#          }
+#        }
+#      }
+#      for (int j_inner_inner_inner = 0; j_inner_inner_inner < 1; ++j_inner_inner_inner) {
+#        (void)nvcuda::wmma::store_matrix_sync(&(compute[((((((int)threadIdx.y) / 16) * 8192) + (((int)blockIdx.x) * 32)) + (((int)threadIdx.z) * 16))]), compute_local[0], 512, nvcuda::wmma::mem_col_major);
+#      }
+#    }
+#
+#
+#    Time cost of this operator: 0.000008
+
+###############################################################################
+# Summary
+# --------------------------
+# This tutorial demonstrates how to use the AutoTensorCoreCodeGen of TVM
+# to generate tensorcore kernels.