From 2c6e8b4636700f22a76eeda01e4a3258692b80f3 Mon Sep 17 00:00:00 2001 From: Francis Visoiu Mistrih Date: Wed, 20 Jul 2022 11:12:30 +0200 Subject: [PATCH] [Matrix] Refactor tiled loops in a struct. NFC The three loops have the same structure: index, header, latch. --- llvm/include/llvm/Transforms/Utils/MatrixUtils.h | 47 ++++++++++------------ .../Transforms/Scalar/LowerMatrixIntrinsics.cpp | 22 +++++----- llvm/lib/Transforms/Utils/MatrixUtils.cpp | 42 +++++++++---------- 3 files changed, 54 insertions(+), 57 deletions(-) diff --git a/llvm/include/llvm/Transforms/Utils/MatrixUtils.h b/llvm/include/llvm/Transforms/Utils/MatrixUtils.h index 39a0d4b..ffad570 100644 --- a/llvm/include/llvm/Transforms/Utils/MatrixUtils.h +++ b/llvm/include/llvm/Transforms/Utils/MatrixUtils.h @@ -25,9 +25,9 @@ class IRBuilderBase; /// A helper struct to create IR loop nests for tiling in IR of the following /// form: -/// for CurrentColumn = 0..NumColumns -/// for CurrentRow = 0..NumRows -/// for CurrentInner = 0..NumInner +/// for ColumnLoop.Index = 0..NumColumns +/// for RowLoop.Index = 0..NumRows +/// for KLoop.Index = 0..NumInner struct TileInfo { /// Number of rows of the matrix. unsigned NumRows; @@ -42,26 +42,21 @@ struct TileInfo { /// Number of rows/columns in a tile. unsigned TileSize = -1; - /// Start row of the current tile to compute. - Value *CurrentRow; - - /// Start column of the current tile to compute. - Value *CurrentCol; - - /// Current tile offset during the tile computation. - Value *CurrentK; - - /// Header of the outermost loop iterating from 0..NumColumns. - BasicBlock *ColumnLoopHeader = nullptr; - - /// Header of the second loop iterating from 0..NumRows. - BasicBlock *RowLoopHeader = nullptr; - /// Latch of the second loop iterating from 0..NumRows. - BasicBlock *RowLoopLatch = nullptr; - /// Header of the innermost loop iterating from 0..NumInner. - BasicBlock *InnerLoopHeader = nullptr; - /// Latch of the innermost loop iterating from 0..NumInner. 
- BasicBlock *InnerLoopLatch = nullptr; + /// Properties of a single loop used when generating the tiled loop nest. + struct MatrixLoop { + /// The index updated on every iteration. + Value *Index = nullptr; + /// The header and latch of the loop. + BasicBlock *Header = nullptr; + BasicBlock *Latch = nullptr; + }; + + /// The loop iterating on the rows. + MatrixLoop RowLoop; + /// The loop iterating on the columns. + MatrixLoop ColumnLoop; + /// The loop iterating on k (inner dimension). + MatrixLoop KLoop; TileInfo(unsigned NumRows, unsigned NumColumns, unsigned NumInner, unsigned TileSize) @@ -72,9 +67,9 @@ struct TileInfo { /// for the inner loop body and sets {Column,Row,Inner}LoopHeader/Latch /// fields. /// - /// for CurrentColumn = 0..NumColumns - /// for CurrentRow = 0..NumRows - /// for CurrentInner = 0..NumInner + /// for ColumnLoop.Index = 0..NumColumns + /// for RowLoop.Index = 0..NumRows + /// for KLoop.Index = 0..NumInner BasicBlock *CreateTiledLoops(BasicBlock *Start, BasicBlock *End, IRBuilderBase &B, DomTreeUpdater &DTU, LoopInfo &LI); diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp index c059066..73cd92d 100644 --- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp +++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp @@ -1423,13 +1423,13 @@ public: FixedVectorType::get(MatMul->getType()->getScalarType(), TileSize); MatrixTy TileResult; // Insert in the inner loop header. - Builder.SetInsertPoint(TI.InnerLoopHeader->getTerminator()); + Builder.SetInsertPoint(TI.KLoop.Header->getTerminator()); // Create PHI nodes for the result columns to accumulate across iterations. SmallVector ColumnPhis; for (unsigned I = 0; I < TileSize; I++) { auto *Phi = Builder.CreatePHI(TileVecTy, 2, "result.vec."
+ Twine(I)); Phi->addIncoming(ConstantAggregateZero::get(TileVecTy), - TI.RowLoopHeader->getSingleSuccessor()); + TI.RowLoop.Header->getSingleSuccessor()); TileResult.addVector(Phi); ColumnPhis.push_back(Phi); } @@ -1438,27 +1438,29 @@ public: // Res += Load(CurrentRow, K) * Load(K, CurrentColumn) Builder.SetInsertPoint(InnerBody->getTerminator()); // Load tiles of the operands. - MatrixTy A = loadMatrix(LPtr, {}, false, LShape, TI.CurrentRow, TI.CurrentK, - {TileSize, TileSize}, EltType, Builder); - MatrixTy B = loadMatrix(RPtr, {}, false, RShape, TI.CurrentK, TI.CurrentCol, - {TileSize, TileSize}, EltType, Builder); + MatrixTy A = + loadMatrix(LPtr, {}, false, LShape, TI.RowLoop.Index, TI.KLoop.Index, + {TileSize, TileSize}, EltType, Builder); + MatrixTy B = + loadMatrix(RPtr, {}, false, RShape, TI.KLoop.Index, TI.ColumnLoop.Index, + {TileSize, TileSize}, EltType, Builder); emitMatrixMultiply(TileResult, A, B, Builder, true, false, getFastMathFlags(MatMul)); // Store result after the inner loop is done. - Builder.SetInsertPoint(TI.RowLoopLatch->getTerminator()); + Builder.SetInsertPoint(TI.RowLoop.Latch->getTerminator()); storeMatrix(TileResult, Store->getPointerOperand(), Store->getAlign(), Store->isVolatile(), {LShape.NumRows, RShape.NumColumns}, - TI.CurrentRow, TI.CurrentCol, EltType, Builder); + TI.RowLoop.Index, TI.ColumnLoop.Index, EltType, Builder); for (unsigned I = 0; I < TileResult.getNumVectors(); I++) - ColumnPhis[I]->addIncoming(TileResult.getVector(I), TI.InnerLoopLatch); + ColumnPhis[I]->addIncoming(TileResult.getVector(I), TI.KLoop.Latch); // Force unrolling of a few iterations of the inner loop, to make sure there // is enough work per iteration. // FIXME: The unroller should make this decision directly instead, but // currently the cost-model is not up to the task. 
unsigned InnerLoopUnrollCount = std::min(10u, LShape.NumColumns / TileSize); - addStringMetadataToLoop(LI->getLoopFor(TI.InnerLoopHeader), + addStringMetadataToLoop(LI->getLoopFor(TI.KLoop.Header), "llvm.loop.unroll.count", InnerLoopUnrollCount); } diff --git a/llvm/lib/Transforms/Utils/MatrixUtils.cpp b/llvm/lib/Transforms/Utils/MatrixUtils.cpp index 6a13763..e218773 100644 --- a/llvm/lib/Transforms/Utils/MatrixUtils.cpp +++ b/llvm/lib/Transforms/Utils/MatrixUtils.cpp @@ -70,35 +70,35 @@ BasicBlock *TileInfo::CreateLoop(BasicBlock *Preheader, BasicBlock *Exit, BasicBlock *TileInfo::CreateTiledLoops(BasicBlock *Start, BasicBlock *End, IRBuilderBase &B, DomTreeUpdater &DTU, LoopInfo &LI) { - Loop *ColLoop = LI.AllocateLoop(); - Loop *RowLoop = LI.AllocateLoop(); - Loop *InnerLoop = LI.AllocateLoop(); - RowLoop->addChildLoop(InnerLoop); - ColLoop->addChildLoop(RowLoop); + Loop *ColumnLoopInfo = LI.AllocateLoop(); + Loop *RowLoopInfo = LI.AllocateLoop(); + Loop *KLoopInfo = LI.AllocateLoop(); + RowLoopInfo->addChildLoop(KLoopInfo); + ColumnLoopInfo->addChildLoop(RowLoopInfo); if (Loop *ParentL = LI.getLoopFor(Start)) - ParentL->addChildLoop(ColLoop); + ParentL->addChildLoop(ColumnLoopInfo); else - LI.addTopLevelLoop(ColLoop); + LI.addTopLevelLoop(ColumnLoopInfo); BasicBlock *ColBody = CreateLoop(Start, End, B.getInt64(NumColumns), B.getInt64(TileSize), - "cols", B, DTU, ColLoop, LI); - BasicBlock *ColLatch = ColBody->getSingleSuccessor(); + "cols", B, DTU, ColumnLoopInfo, LI); + ColumnLoop.Latch = ColBody->getSingleSuccessor(); BasicBlock *RowBody = - CreateLoop(ColBody, ColLatch, B.getInt64(NumRows), B.getInt64(TileSize), - "rows", B, DTU, RowLoop, LI); - RowLoopLatch = RowBody->getSingleSuccessor(); + CreateLoop(ColBody, ColumnLoop.Latch, B.getInt64(NumRows), + B.getInt64(TileSize), "rows", B, DTU, RowLoopInfo, LI); + RowLoop.Latch = RowBody->getSingleSuccessor(); BasicBlock *InnerBody = - CreateLoop(RowBody, RowLoopLatch, B.getInt64(NumInner), - 
B.getInt64(TileSize), "inner", B, DTU, InnerLoop, LI); - InnerLoopLatch = InnerBody->getSingleSuccessor(); - ColumnLoopHeader = ColBody->getSinglePredecessor(); - RowLoopHeader = RowBody->getSinglePredecessor(); - InnerLoopHeader = InnerBody->getSinglePredecessor(); - CurrentRow = &*RowLoopHeader->begin(); - CurrentCol = &*ColumnLoopHeader->begin(); - CurrentK = &*InnerLoopHeader->begin(); + CreateLoop(RowBody, RowLoop.Latch, B.getInt64(NumInner), + B.getInt64(TileSize), "inner", B, DTU, KLoopInfo, LI); + KLoop.Latch = InnerBody->getSingleSuccessor(); + ColumnLoop.Header = ColBody->getSinglePredecessor(); + RowLoop.Header = RowBody->getSinglePredecessor(); + KLoop.Header = InnerBody->getSinglePredecessor(); + RowLoop.Index = &*RowLoop.Header->begin(); + ColumnLoop.Index = &*ColumnLoop.Header->begin(); + KLoop.Index = &*KLoop.Header->begin(); return InnerBody; } -- 2.7.4