[OpenMPIRBuilder] Implement collapseLoops.
authorMichael Kruse <llvm-project@meinersbur.de>
Wed, 3 Feb 2021 19:44:00 +0000 (13:44 -0600)
committerMichael Kruse <llvm-project@meinersbur.de>
Thu, 4 Feb 2021 01:12:02 +0000 (19:12 -0600)
The collapseLoops method implements a transformations facilitating the implementation of the collapse-clause. It takes a list of loops from a loop nest and reduces it to a single loop that can be used by other methods that are implemented on just a single loop, such as createStaticWorkshareLoop.

This patch shares some changes with D92974 (such as adding some getters to CanonicalLoopNest), used by both patches.

Reviewed By: jdoerfert

Differential Revision: https://reviews.llvm.org/D93268

llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp

index 22204d9..4b52588 100644 (file)
@@ -274,6 +274,70 @@ public:
                                          InsertPointTy ComputeIP = {},
                                          const Twine &Name = "loop");
 
+  /// Collapse a loop nest into a single loop.
+  ///
+  /// Merges loops of a loop nest into a single CanonicalLoopNest representation
+  /// that has the same number of innermost loop iterations as the origin loop
+  /// nest. The induction variables of the input loops are derived from the
+  /// collapsed loop's induction variable. This is intended to be used to
+  /// implement OpenMP's collapse clause. Before applying a directive,
+  /// collapseLoops normalizes a loop nest to contain only a single loop and the
+  /// directive's implementation does not need to handle multiple loops itself.
+  /// This does not remove the need to handle all loop nest handling by
+  /// directives, such as the ordered(<n>) clause or the simd schedule-clause
+  /// modifier of the worksharing-loop directive.
+  ///
+  /// Example:
+  /// \code
+  ///   for (int i = 0; i < 7; ++i) // Canonical loop "i"
+  ///     for (int j = 0; j < 9; ++j) // Canonical loop "j"
+  ///       body(i, j);
+  /// \endcode
+  ///
+  /// After collapsing with Loops={i,j}, the loop is changed to
+  /// \code
+  ///   for (int ij = 0; ij < 63; ++ij) {
+  ///     int i = ij / 9;
+  ///     int j = ij % 9;
+  ///     body(i, j);
+  ///   }
+  /// \endcode
+  ///
+  /// In the current implementation, the following limitations apply:
+  ///
+  ///  * All input loops have an induction variable of the same type.
+  ///
+  ///  * The collapsed loop will have the same trip count integer type as the
+  ///    input loops. Therefore it is possible that the collapsed loop cannot
+  ///    represent all iterations of the input loops. For instance, assuming a
+  ///    32 bit integer type, and two input loops both iterating 2^16 times, the
+  ///    theoretical trip count of the collapsed loop would be 2^32 iteration,
+  ///    which cannot be represented in an 32-bit integer. Behavior is undefined
+  ///    in this case.
+  ///
+  ///  * The trip counts of every input loop must be available at \p ComputeIP.
+  ///    Non-rectangular loops are not yet supported.
+  ///
+  ///  * At each nest level, code between a surrounding loop and its nested loop
+  ///    is hoisted into the loop body, and such code will be executed more
+  ///    often than before collapsing (or not at all if any inner loop iteration
+  ///    has a trip count of 0). This is permitted by the OpenMP specification.
+  ///
+  /// \param DL        Debug location for instructions added for collapsing,
+  ///                  such as instructions to compute derive the input loop's
+  ///                  induction variables.
+  /// \param Loops     Loops in the loop nest to collapse. Loops are specified
+  ///                  from outermost-to-innermost and every control flow of a
+  ///                  loop's body must pass through its directly nested loop.
+  /// \param ComputeIP Where additional instruction that compute the collapsed
+  ///                  trip count. If not set, defaults to before the generated
+  ///                  loop.
+  ///
+  /// \returns The CanonicalLoopInfo object representing the collapsed loop.
+  CanonicalLoopInfo *collapseLoops(DebugLoc DL,
+                                   ArrayRef<CanonicalLoopInfo *> Loops,
+                                   InsertPointTy ComputeIP);
+
   /// Modifies the canonical loop to be a statically-scheduled workshare loop.
   ///
   /// This takes a \p LoopInfo representing a canonical loop, such as the one
index 1f67aec..9286394 100644 (file)
@@ -1225,6 +1225,127 @@ static void removeUnusedBlocksFromParent(ArrayRef<BasicBlock *> BBs) {
   DeleteDeadBlocks(BBVec);
 }
 
+CanonicalLoopInfo *
+OpenMPIRBuilder::collapseLoops(DebugLoc DL, ArrayRef<CanonicalLoopInfo *> Loops,
+                               InsertPointTy ComputeIP) {
+  assert(Loops.size() >= 1 && "At least one loop required");
+  size_t NumLoops = Loops.size();
+
+  // Nothing to do if there is already just one loop.
+  if (NumLoops == 1)
+    return Loops.front();
+
+  CanonicalLoopInfo *Outermost = Loops.front();
+  CanonicalLoopInfo *Innermost = Loops.back();
+  BasicBlock *OrigPreheader = Outermost->getPreheader();
+  BasicBlock *OrigAfter = Outermost->getAfter();
+  Function *F = OrigPreheader->getParent();
+
+  // Setup the IRBuilder for inserting the trip count computation.
+  Builder.SetCurrentDebugLocation(DL);
+  if (ComputeIP.isSet())
+    Builder.restoreIP(ComputeIP);
+  else
+    Builder.restoreIP(Outermost->getPreheaderIP());
+
+  // Derive the collapsed' loop trip count.
+  // TODO: Find common/largest indvar type.
+  Value *CollapsedTripCount = nullptr;
+  for (CanonicalLoopInfo *L : Loops) {
+    Value *OrigTripCount = L->getTripCount();
+    if (!CollapsedTripCount) {
+      CollapsedTripCount = OrigTripCount;
+      continue;
+    }
+
+    // TODO: Enable UndefinedSanitizer to diagnose an overflow here.
+    CollapsedTripCount = Builder.CreateMul(CollapsedTripCount, OrigTripCount,
+                                           {}, /*HasNUW=*/true);
+  }
+
+  // Create the collapsed loop control flow.
+  CanonicalLoopInfo *Result =
+      createLoopSkeleton(DL, CollapsedTripCount, F,
+                         OrigPreheader->getNextNode(), OrigAfter, "collapsed");
+
+  // Build the collapsed loop body code.
+  // Start with deriving the input loop induction variables from the collapsed
+  // one, using a divmod scheme. To preserve the original loops' order, the
+  // innermost loop use the least significant bits.
+  Builder.restoreIP(Result->getBodyIP());
+
+  Value *Leftover = Result->getIndVar();
+  SmallVector<Value *> NewIndVars;
+  NewIndVars.set_size(NumLoops);
+  for (int i = NumLoops - 1; i >= 1; --i) {
+    Value *OrigTripCount = Loops[i]->getTripCount();
+
+    Value *NewIndVar = Builder.CreateURem(Leftover, OrigTripCount);
+    NewIndVars[i] = NewIndVar;
+
+    Leftover = Builder.CreateUDiv(Leftover, OrigTripCount);
+  }
+  // Outermost loop gets all the remaining bits.
+  NewIndVars[0] = Leftover;
+
+  // Construct the loop body control flow.
+  // We progressively construct the branch structure following in direction of
+  // the control flow, from the leading in-between code, the loop nest body, the
+  // trailing in-between code, and rejoining the collapsed loop's latch.
+  // ContinueBlock and ContinuePred keep track of the source(s) of next edge. If
+  // the ContinueBlock is set, continue with that block. If ContinuePred, use
+  // its predecessors as sources.
+  BasicBlock *ContinueBlock = Result->getBody();
+  BasicBlock *ContinuePred = nullptr;
+  auto ContinueWith = [&ContinueBlock, &ContinuePred, DL](BasicBlock *Dest,
+                                                          BasicBlock *NextSrc) {
+    if (ContinueBlock)
+      redirectTo(ContinueBlock, Dest, DL);
+    else
+      redirectAllPredecessorsTo(ContinuePred, Dest, DL);
+
+    ContinueBlock = nullptr;
+    ContinuePred = NextSrc;
+  };
+
+  // The code before the nested loop of each level.
+  // Because we are sinking it into the nest, it will be executed more often
+  // that the original loop. More sophisticated schemes could keep track of what
+  // the in-between code is and instantiate it only once per thread.
+  for (size_t i = 0; i < NumLoops - 1; ++i)
+    ContinueWith(Loops[i]->getBody(), Loops[i + 1]->getHeader());
+
+  // Connect the loop nest body.
+  ContinueWith(Innermost->getBody(), Innermost->getLatch());
+
+  // The code after the nested loop at each level.
+  for (size_t i = NumLoops - 1; i > 0; --i)
+    ContinueWith(Loops[i]->getAfter(), Loops[i - 1]->getLatch());
+
+  // Connect the finished loop to the collapsed loop latch.
+  ContinueWith(Result->getLatch(), nullptr);
+
+  // Replace the input loops with the new collapsed loop.
+  redirectTo(Outermost->getPreheader(), Result->getPreheader(), DL);
+  redirectTo(Result->getAfter(), Outermost->getAfter(), DL);
+
+  // Replace the input loop indvars with the derived ones.
+  for (size_t i = 0; i < NumLoops; ++i)
+    Loops[i]->getIndVar()->replaceAllUsesWith(NewIndVars[i]);
+
+  // Remove unused parts of the input loops.
+  SmallVector<BasicBlock *, 12> OldControlBBs;
+  OldControlBBs.reserve(6 * Loops.size());
+  for (CanonicalLoopInfo *Loop : Loops)
+    Loop->collectControlBlocks(OldControlBBs);
+  removeUnusedBlocksFromParent(OldControlBBs);
+
+#ifndef NDEBUG
+  Result->assertOK();
+#endif
+  return Result;
+}
+
 std::vector<CanonicalLoopInfo *>
 OpenMPIRBuilder::tileLoops(DebugLoc DL, ArrayRef<CanonicalLoopInfo *> Loops,
                            ArrayRef<Value *> TileSizes) {
index 3c2cc35..16695b0 100644 (file)
@@ -1160,6 +1160,99 @@ TEST_F(OpenMPIRBuilderTest, CanonicalLoopBounds) {
   EXPECT_FALSE(verifyModule(*M, &errs()));
 }
 
+TEST_F(OpenMPIRBuilderTest, CollapseNestedLoops) {
+  using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
+  OpenMPIRBuilder OMPBuilder(*M);
+  OMPBuilder.initialize();
+  F->setName("func");
+
+  IRBuilder<> Builder(BB);
+
+  Type *LCTy = F->getArg(0)->getType();
+  Constant *One = ConstantInt::get(LCTy, 1);
+  Constant *Two = ConstantInt::get(LCTy, 2);
+  Value *OuterTripCount =
+      Builder.CreateAdd(F->getArg(0), Two, "tripcount.outer");
+  Value *InnerTripCount =
+      Builder.CreateAdd(F->getArg(0), One, "tripcount.inner");
+
+  // Fix an insertion point for ComputeIP.
+  BasicBlock *LoopNextEnter =
+      BasicBlock::Create(M->getContext(), "loopnest.enter", F,
+                         Builder.GetInsertBlock()->getNextNode());
+  BranchInst *EnterBr = Builder.CreateBr(LoopNextEnter);
+  InsertPointTy ComputeIP{EnterBr->getParent(), EnterBr->getIterator()};
+
+  Builder.SetInsertPoint(LoopNextEnter);
+  OpenMPIRBuilder::LocationDescription OuterLoc(Builder.saveIP(), DL);
+
+  CanonicalLoopInfo *InnerLoop = nullptr;
+  CallInst *InbetweenLead = nullptr;
+  CallInst *InbetweenTrail = nullptr;
+  CallInst *Call = nullptr;
+  auto OuterLoopBodyGenCB = [&](InsertPointTy OuterCodeGenIP, Value *OuterLC) {
+    Builder.restoreIP(OuterCodeGenIP);
+    InbetweenLead =
+        createPrintfCall(Builder, "In-between lead i=%d\\n", {OuterLC});
+
+    auto InnerLoopBodyGenCB = [&](InsertPointTy InnerCodeGenIP,
+                                  Value *InnerLC) {
+      Builder.restoreIP(InnerCodeGenIP);
+      Call = createPrintfCall(Builder, "body i=%d j=%d\\n", {OuterLC, InnerLC});
+    };
+    InnerLoop = OMPBuilder.createCanonicalLoop(
+        Builder.saveIP(), InnerLoopBodyGenCB, InnerTripCount, "inner");
+
+    Builder.restoreIP(InnerLoop->getAfterIP());
+    InbetweenTrail =
+        createPrintfCall(Builder, "In-between trail i=%d\\n", {OuterLC});
+  };
+  CanonicalLoopInfo *OuterLoop = OMPBuilder.createCanonicalLoop(
+      OuterLoc, OuterLoopBodyGenCB, OuterTripCount, "outer");
+
+  // Finish the function.
+  Builder.restoreIP(OuterLoop->getAfterIP());
+  Builder.CreateRetVoid();
+
+  CanonicalLoopInfo *Collapsed =
+      OMPBuilder.collapseLoops(DL, {OuterLoop, InnerLoop}, ComputeIP);
+
+  OMPBuilder.finalize();
+  EXPECT_FALSE(verifyModule(*M, &errs()));
+
+  // Verify control flow and BB order.
+  BasicBlock *RefOrder[] = {
+      Collapsed->getPreheader(),   Collapsed->getHeader(),
+      Collapsed->getCond(),        Collapsed->getBody(),
+      InbetweenLead->getParent(),  Call->getParent(),
+      InbetweenTrail->getParent(), Collapsed->getLatch(),
+      Collapsed->getExit(),        Collapsed->getAfter(),
+  };
+  EXPECT_TRUE(verifyDFSOrder(F, RefOrder));
+  EXPECT_TRUE(verifyListOrder(F, RefOrder));
+
+  // Verify the total trip count.
+  auto *TripCount = cast<MulOperator>(Collapsed->getTripCount());
+  EXPECT_EQ(TripCount->getOperand(0), OuterTripCount);
+  EXPECT_EQ(TripCount->getOperand(1), InnerTripCount);
+
+  // Verify the changed indvar.
+  auto *OuterIV = cast<BinaryOperator>(Call->getOperand(1));
+  EXPECT_EQ(OuterIV->getOpcode(), Instruction::UDiv);
+  EXPECT_EQ(OuterIV->getParent(), Collapsed->getBody());
+  EXPECT_EQ(OuterIV->getOperand(1), InnerTripCount);
+  EXPECT_EQ(OuterIV->getOperand(0), Collapsed->getIndVar());
+
+  auto *InnerIV = cast<BinaryOperator>(Call->getOperand(2));
+  EXPECT_EQ(InnerIV->getOpcode(), Instruction::URem);
+  EXPECT_EQ(InnerIV->getParent(), Collapsed->getBody());
+  EXPECT_EQ(InnerIV->getOperand(0), Collapsed->getIndVar());
+  EXPECT_EQ(InnerIV->getOperand(1), InnerTripCount);
+
+  EXPECT_EQ(InbetweenLead->getOperand(1), OuterIV);
+  EXPECT_EQ(InbetweenTrail->getOperand(1), OuterIV);
+}
+
 TEST_F(OpenMPIRBuilderTest, TileSingleLoop) {
   using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
   OpenMPIRBuilder OMPBuilder(*M);