[Matrix] Add TileInfo abstraction for tiled matrix code-gen.
authorFlorian Hahn <flo@fhahn.com>
Mon, 20 Jul 2020 17:31:04 +0000 (18:31 +0100)
committerFlorian Hahn <flo@fhahn.com>
Mon, 20 Jul 2020 17:49:08 +0000 (18:49 +0100)
This patch adds a TileInfo abstraction and utilities to
create a 3-level loop nest for tiling.

Reviewers: anemet

Reviewed By: anemet

Differential Revision: https://reviews.llvm.org/D77550

llvm/include/llvm/Transforms/Utils/MatrixUtils.h [new file with mode: 0644]
llvm/lib/Transforms/Utils/CMakeLists.txt
llvm/lib/Transforms/Utils/MatrixUtils.cpp [new file with mode: 0644]

diff --git a/llvm/include/llvm/Transforms/Utils/MatrixUtils.h b/llvm/include/llvm/Transforms/Utils/MatrixUtils.h
new file mode 100644 (file)
index 0000000..39a0d4b
--- /dev/null
@@ -0,0 +1,94 @@
+//===- MatrixUtils.h - Utilities to lower matrix intrinsics -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Utilities for generating tiled loops for matrix operations.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_TRANSFORMS_UTILS_MATRIXUTILS_H
+#define LLVM_TRANSFORMS_UTILS_MATRIXUTILS_H
+
+#include "llvm/ADT/StringRef.h"
+
+namespace llvm {
+class DomTreeUpdater;
+class BasicBlock;
+class Value;
+class Loop;
+class LoopInfo;
+class IRBuilderBase;
+
+/// A helper struct to create IR loop nests for tiling in IR of the following
+/// form:
+///   for CurrentColumn = 0..NumColumns
+///     for CurrentRow = 0..NumRows
+///       for CurrentInner = 0..NumInner
+struct TileInfo {
+  /// Number of rows of the matrix.
+  unsigned NumRows;
+
+  /// Number of columns of the matrix.
+  unsigned NumColumns;
+
+  /// Number of columns of the first matrix of a multiply /
+  /// number of rows of the second matrix of a multiply.
+  unsigned NumInner;
+
+  /// Number of rows/columns in a tile.
+  unsigned TileSize = -1;
+
+  /// Start row of the current tile to compute.
+  Value *CurrentRow;
+
+  /// Start column of the current tile to compute.
+  Value *CurrentCol;
+
+  /// Current tile offset during the tile computation.
+  Value *CurrentK;
+
+  /// Header of the outermost loop iterating from 0..NumColumns.
+  BasicBlock *ColumnLoopHeader = nullptr;
+
+  /// Header of the second loop iterating from 0..NumRows.
+  BasicBlock *RowLoopHeader = nullptr;
+  /// Latch of the second loop iterating from 0..NumRows.
+  BasicBlock *RowLoopLatch = nullptr;
+  /// Header of the innermost loop iterating from 0..NumInner.
+  BasicBlock *InnerLoopHeader = nullptr;
+  /// Latch of the innermost loop iterating from 0..NumInner.
+  BasicBlock *InnerLoopLatch = nullptr;
+
+  TileInfo(unsigned NumRows, unsigned NumColumns, unsigned NumInner,
+           unsigned TileSize)
+      : NumRows(NumRows), NumColumns(NumColumns), NumInner(NumInner),
+        TileSize(TileSize) {}
+
+  /// Creates an IR loop nests for tiling of the form below. Returns the block
+  /// for the inner loop body and sets {Column,Row,Inner}LoopHeader/Latch
+  /// fields.
+  ///
+  /// for CurrentColumn = 0..NumColumns
+  ///   for CurrentRow = 0..NumRows
+  ///     for CurrentInner = 0..NumInner
+  BasicBlock *CreateTiledLoops(BasicBlock *Start, BasicBlock *End,
+                               IRBuilderBase &B, DomTreeUpdater &DTU,
+                               LoopInfo &LI);
+
+private:
+  /// Creates a new loop with header, body and latch blocks that iterates from
+  /// [0, Bound). Updates \p Preheader to branch to the new header and uses \p
+  /// Exit as exit block.  Adds the new loop blocks to \L and applies dominator
+  /// tree updates to \p DTU.
+  static BasicBlock *CreateLoop(BasicBlock *Preheader, BasicBlock *Exit,
+                                Value *Bound, Value *Step, StringRef Name,
+                                IRBuilderBase &B, DomTreeUpdater &DTU, Loop *L,
+                                LoopInfo &LI);
+};
+} // namespace llvm
+
+#endif
index 5c26767..19f655c 100644 (file)
@@ -46,6 +46,7 @@ add_llvm_component_library(LLVMTransformUtils
   LowerInvoke.cpp
   LowerMemIntrinsics.cpp
   LowerSwitch.cpp
+  MatrixUtils.cpp
   Mem2Reg.cpp
   MetaRenamer.cpp
   MisExpect.cpp
diff --git a/llvm/lib/Transforms/Utils/MatrixUtils.cpp b/llvm/lib/Transforms/Utils/MatrixUtils.cpp
new file mode 100644 (file)
index 0000000..6a13763
--- /dev/null
@@ -0,0 +1,104 @@
+//===- MatrixUtils.cpp - Utilities to lower matrix intrinsics ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Utilities for generating tiled loops for matrix operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/Utils/MatrixUtils.h"
+#include "llvm/Analysis/DomTreeUpdater.h"
+#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Type.h"
+
+using namespace llvm;
+
+BasicBlock *TileInfo::CreateLoop(BasicBlock *Preheader, BasicBlock *Exit,
+                                 Value *Bound, Value *Step, StringRef Name,
+                                 IRBuilderBase &B, DomTreeUpdater &DTU, Loop *L,
+                                 LoopInfo &LI) {
+  LLVMContext &Ctx = Preheader->getContext();
+  BasicBlock *Header = BasicBlock::Create(
+      Preheader->getContext(), Name + ".header", Preheader->getParent(), Exit);
+  BasicBlock *Body = BasicBlock::Create(Header->getContext(), Name + ".body",
+                                        Header->getParent(), Exit);
+  BasicBlock *Latch = BasicBlock::Create(Header->getContext(), Name + ".latch",
+                                         Header->getParent(), Exit);
+
+  Type *I32Ty = Type::getInt64Ty(Ctx);
+  BranchInst::Create(Body, Header);
+  BranchInst::Create(Latch, Body);
+  PHINode *IV =
+      PHINode::Create(I32Ty, 2, Name + ".iv", Header->getTerminator());
+  IV->addIncoming(ConstantInt::get(I32Ty, 0), Preheader);
+
+  B.SetInsertPoint(Latch);
+  Value *Inc = B.CreateAdd(IV, Step, Name + ".step");
+  Value *Cond = B.CreateICmpNE(Inc, Bound, Name + ".cond");
+  BranchInst::Create(Header, Exit, Cond, Latch);
+  IV->addIncoming(Inc, Latch);
+
+  BranchInst *PreheaderBr = cast<BranchInst>(Preheader->getTerminator());
+  BasicBlock *Tmp = PreheaderBr->getSuccessor(0);
+  PreheaderBr->setSuccessor(0, Header);
+  DTU.applyUpdatesPermissive({
+      {DominatorTree::Delete, Preheader, Tmp},
+      {DominatorTree::Insert, Header, Body},
+      {DominatorTree::Insert, Body, Latch},
+      {DominatorTree::Insert, Latch, Header},
+      {DominatorTree::Insert, Latch, Exit},
+      {DominatorTree::Insert, Preheader, Header},
+  });
+
+  L->addBasicBlockToLoop(Header, LI);
+  L->addBasicBlockToLoop(Body, LI);
+  L->addBasicBlockToLoop(Latch, LI);
+  return Body;
+}
+
+// Creates the following loop nest skeleton:
+//  for C = 0; C < NumColumns; C += TileSize
+//    for R = 0; R < NumRows; R += TileSize
+//      for K = 0; K < Inner ; K += TileSize
+BasicBlock *TileInfo::CreateTiledLoops(BasicBlock *Start, BasicBlock *End,
+                                       IRBuilderBase &B, DomTreeUpdater &DTU,
+                                       LoopInfo &LI) {
+  Loop *ColLoop = LI.AllocateLoop();
+  Loop *RowLoop = LI.AllocateLoop();
+  Loop *InnerLoop = LI.AllocateLoop();
+  RowLoop->addChildLoop(InnerLoop);
+  ColLoop->addChildLoop(RowLoop);
+  if (Loop *ParentL = LI.getLoopFor(Start))
+    ParentL->addChildLoop(ColLoop);
+  else
+    LI.addTopLevelLoop(ColLoop);
+
+  BasicBlock *ColBody =
+      CreateLoop(Start, End, B.getInt64(NumColumns), B.getInt64(TileSize),
+                 "cols", B, DTU, ColLoop, LI);
+  BasicBlock *ColLatch = ColBody->getSingleSuccessor();
+  BasicBlock *RowBody =
+      CreateLoop(ColBody, ColLatch, B.getInt64(NumRows), B.getInt64(TileSize),
+                 "rows", B, DTU, RowLoop, LI);
+  RowLoopLatch = RowBody->getSingleSuccessor();
+
+  BasicBlock *InnerBody =
+      CreateLoop(RowBody, RowLoopLatch, B.getInt64(NumInner),
+                 B.getInt64(TileSize), "inner", B, DTU, InnerLoop, LI);
+  InnerLoopLatch = InnerBody->getSingleSuccessor();
+  ColumnLoopHeader = ColBody->getSinglePredecessor();
+  RowLoopHeader = RowBody->getSinglePredecessor();
+  InnerLoopHeader = InnerBody->getSinglePredecessor();
+  CurrentRow = &*RowLoopHeader->begin();
+  CurrentCol = &*ColumnLoopHeader->begin();
+  CurrentK = &*InnerLoopHeader->begin();
+
+  return InnerBody;
+}