--- /dev/null
+//===- MatrixUtils.h - Utilities to lower matrix intrinsics -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Utilities for generating tiled loops for matrix operations.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_TRANSFORMS_UTILS_MATRIXUTILS_H
+#define LLVM_TRANSFORMS_UTILS_MATRIXUTILS_H
+
+#include "llvm/ADT/StringRef.h"
+
+namespace llvm {
+class DomTreeUpdater;
+class BasicBlock;
+class Value;
+class Loop;
+class LoopInfo;
+class IRBuilderBase;
+
+/// A helper struct to create IR loop nests for tiling in IR of the following
+/// form:
+/// for CurrentColumn = 0..NumColumns
+/// for CurrentRow = 0..NumRows
+/// for CurrentInner = 0..NumInner
+struct TileInfo {
+ /// Number of rows of the matrix.
+ unsigned NumRows;
+
+ /// Number of columns of the matrix.
+ unsigned NumColumns;
+
+ /// Number of columns of the first matrix of a multiply /
+ /// number of rows of the second matrix of a multiply.
+ unsigned NumInner;
+
+ /// Number of rows/columns in a tile.
+ unsigned TileSize = -1;
+
+ /// Start row of the current tile to compute.
+ Value *CurrentRow;
+
+ /// Start column of the current tile to compute.
+ Value *CurrentCol;
+
+ /// Current tile offset during the tile computation.
+ Value *CurrentK;
+
+ /// Header of the outermost loop iterating from 0..NumColumns.
+ BasicBlock *ColumnLoopHeader = nullptr;
+
+ /// Header of the second loop iterating from 0..NumRows.
+ BasicBlock *RowLoopHeader = nullptr;
+ /// Latch of the second loop iterating from 0..NumRows.
+ BasicBlock *RowLoopLatch = nullptr;
+ /// Header of the innermost loop iterating from 0..NumInner.
+ BasicBlock *InnerLoopHeader = nullptr;
+ /// Latch of the innermost loop iterating from 0..NumInner.
+ BasicBlock *InnerLoopLatch = nullptr;
+
+ TileInfo(unsigned NumRows, unsigned NumColumns, unsigned NumInner,
+ unsigned TileSize)
+ : NumRows(NumRows), NumColumns(NumColumns), NumInner(NumInner),
+ TileSize(TileSize) {}
+
+ /// Creates an IR loop nests for tiling of the form below. Returns the block
+ /// for the inner loop body and sets {Column,Row,Inner}LoopHeader/Latch
+ /// fields.
+ ///
+ /// for CurrentColumn = 0..NumColumns
+ /// for CurrentRow = 0..NumRows
+ /// for CurrentInner = 0..NumInner
+ BasicBlock *CreateTiledLoops(BasicBlock *Start, BasicBlock *End,
+ IRBuilderBase &B, DomTreeUpdater &DTU,
+ LoopInfo &LI);
+
+private:
+ /// Creates a new loop with header, body and latch blocks that iterates from
+ /// [0, Bound). Updates \p Preheader to branch to the new header and uses \p
+ /// Exit as exit block. Adds the new loop blocks to \L and applies dominator
+ /// tree updates to \p DTU.
+ static BasicBlock *CreateLoop(BasicBlock *Preheader, BasicBlock *Exit,
+ Value *Bound, Value *Step, StringRef Name,
+ IRBuilderBase &B, DomTreeUpdater &DTU, Loop *L,
+ LoopInfo &LI);
+};
+} // namespace llvm
+
+#endif
--- /dev/null
+//===- MatrixUtils.cpp - Utilities to lower matrix intrinsics ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Utilities for generating tiled loops for matrix operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/Utils/MatrixUtils.h"
+#include "llvm/Analysis/DomTreeUpdater.h"
+#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Type.h"
+
+using namespace llvm;
+
+BasicBlock *TileInfo::CreateLoop(BasicBlock *Preheader, BasicBlock *Exit,
+ Value *Bound, Value *Step, StringRef Name,
+ IRBuilderBase &B, DomTreeUpdater &DTU, Loop *L,
+ LoopInfo &LI) {
+ LLVMContext &Ctx = Preheader->getContext();
+ BasicBlock *Header = BasicBlock::Create(
+ Preheader->getContext(), Name + ".header", Preheader->getParent(), Exit);
+ BasicBlock *Body = BasicBlock::Create(Header->getContext(), Name + ".body",
+ Header->getParent(), Exit);
+ BasicBlock *Latch = BasicBlock::Create(Header->getContext(), Name + ".latch",
+ Header->getParent(), Exit);
+
+ Type *I32Ty = Type::getInt64Ty(Ctx);
+ BranchInst::Create(Body, Header);
+ BranchInst::Create(Latch, Body);
+ PHINode *IV =
+ PHINode::Create(I32Ty, 2, Name + ".iv", Header->getTerminator());
+ IV->addIncoming(ConstantInt::get(I32Ty, 0), Preheader);
+
+ B.SetInsertPoint(Latch);
+ Value *Inc = B.CreateAdd(IV, Step, Name + ".step");
+ Value *Cond = B.CreateICmpNE(Inc, Bound, Name + ".cond");
+ BranchInst::Create(Header, Exit, Cond, Latch);
+ IV->addIncoming(Inc, Latch);
+
+ BranchInst *PreheaderBr = cast<BranchInst>(Preheader->getTerminator());
+ BasicBlock *Tmp = PreheaderBr->getSuccessor(0);
+ PreheaderBr->setSuccessor(0, Header);
+ DTU.applyUpdatesPermissive({
+ {DominatorTree::Delete, Preheader, Tmp},
+ {DominatorTree::Insert, Header, Body},
+ {DominatorTree::Insert, Body, Latch},
+ {DominatorTree::Insert, Latch, Header},
+ {DominatorTree::Insert, Latch, Exit},
+ {DominatorTree::Insert, Preheader, Header},
+ });
+
+ L->addBasicBlockToLoop(Header, LI);
+ L->addBasicBlockToLoop(Body, LI);
+ L->addBasicBlockToLoop(Latch, LI);
+ return Body;
+}
+
+// Creates the following loop nest skeleton:
+// for C = 0; C < NumColumns; C += TileSize
+// for R = 0; R < NumRows; R += TileSize
+// for K = 0; K < Inner ; K += TileSize
+BasicBlock *TileInfo::CreateTiledLoops(BasicBlock *Start, BasicBlock *End,
+ IRBuilderBase &B, DomTreeUpdater &DTU,
+ LoopInfo &LI) {
+ Loop *ColLoop = LI.AllocateLoop();
+ Loop *RowLoop = LI.AllocateLoop();
+ Loop *InnerLoop = LI.AllocateLoop();
+ RowLoop->addChildLoop(InnerLoop);
+ ColLoop->addChildLoop(RowLoop);
+ if (Loop *ParentL = LI.getLoopFor(Start))
+ ParentL->addChildLoop(ColLoop);
+ else
+ LI.addTopLevelLoop(ColLoop);
+
+ BasicBlock *ColBody =
+ CreateLoop(Start, End, B.getInt64(NumColumns), B.getInt64(TileSize),
+ "cols", B, DTU, ColLoop, LI);
+ BasicBlock *ColLatch = ColBody->getSingleSuccessor();
+ BasicBlock *RowBody =
+ CreateLoop(ColBody, ColLatch, B.getInt64(NumRows), B.getInt64(TileSize),
+ "rows", B, DTU, RowLoop, LI);
+ RowLoopLatch = RowBody->getSingleSuccessor();
+
+ BasicBlock *InnerBody =
+ CreateLoop(RowBody, RowLoopLatch, B.getInt64(NumInner),
+ B.getInt64(TileSize), "inner", B, DTU, InnerLoop, LI);
+ InnerLoopLatch = InnerBody->getSingleSuccessor();
+ ColumnLoopHeader = ColBody->getSinglePredecessor();
+ RowLoopHeader = RowBody->getSinglePredecessor();
+ InnerLoopHeader = InnerBody->getSinglePredecessor();
+ CurrentRow = &*RowLoopHeader->begin();
+ CurrentCol = &*ColumnLoopHeader->begin();
+ CurrentK = &*InnerLoopHeader->begin();
+
+ return InnerBody;
+}