Add support for lowering simd if clause to LLVM IR
authorDominik Adamski <dominik.adamski@amd.com>
Thu, 28 Jul 2022 08:57:40 +0000 (03:57 -0500)
committerDominik Adamski <dominik.adamski@amd.com>
Mon, 1 Aug 2022 09:43:32 +0000 (04:43 -0500)
Scope of changes:
  1) Added new function to generate loop versioning
  2) Added support for if clause to applySimd function
  2) Added tests which confirm that lowering is successful

If ifCond is specified, then collapsed loop is duplicated and if branch
is added. Duplicated loop is executed if simd ifCond is evaluated to false.

Reviewed By: Meinersbur

Differential Revision: https://reviews.llvm.org/D129368

Signed-off-by: Dominik Adamski <dominik.adamski@amd.com>
clang/lib/CodeGen/CGStmtOpenMP.cpp
llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
mlir/test/Target/LLVMIR/openmp-llvm.mlir

index aa55cda..962620f 100644 (file)
@@ -2646,7 +2646,9 @@ void CodeGenFunction::EmitOMPSimdDirective(const OMPSimdDirective &S) {
           auto *Val = cast<llvm::ConstantInt>(Len.getScalarVal());
           Simdlen = Val;
         }
-        OMPBuilder.applySimd(CLI, Simdlen);
+        // Add simd metadata to the collapsed loop. Do not generate
+        // another loop for if clause. Support for if clause is done earlier.
+        OMPBuilder.applySimd(CLI, /*IfCond*/ nullptr, Simdlen);
         return;
       }
     };
index 40ca2da..5ae9baa 100644 (file)
@@ -14,6 +14,7 @@
 #ifndef LLVM_FRONTEND_OPENMP_OMPIRBUILDER_H
 #define LLVM_FRONTEND_OPENMP_OMPIRBUILDER_H
 
+#include "llvm/Analysis/MemorySSAUpdater.h"
 #include "llvm/Frontend/OpenMP/OMPConstants.h"
 #include "llvm/IR/DebugLoc.h"
 #include "llvm/IR/IRBuilder.h"
@@ -467,6 +468,20 @@ private:
                                           bool NeedsBarrier,
                                           Value *Chunk = nullptr);
 
+  /// Create alternative version of the loop to support if clause
+  ///
+  /// OpenMP if clause can require to generate second loop. This loop
+  /// will be executed when if clause condition is not met. createIfVersion
+  /// adds branch instruction to the copied loop if \p  ifCond is not met.
+  ///
+  /// \param Loop       Original loop which should be versioned.
+  /// \param IfCond     Value which corresponds to if clause condition
+  /// \param VMap       Value to value map to define relation between
+  ///                   original and copied loop values and loop blocks.
+  /// \param NamePrefix Optional name prefix for if.then if.else blocks.
+  void createIfVersion(CanonicalLoopInfo *Loop, Value *IfCond,
+                       ValueToValueMapTy &VMap, const Twine &NamePrefix = "");
+
 public:
   /// Modifies the canonical loop to be a workshare loop.
   ///
@@ -597,11 +612,15 @@ public:
   void unrollLoopPartial(DebugLoc DL, CanonicalLoopInfo *Loop, int32_t Factor,
                          CanonicalLoopInfo **UnrolledCLI);
 
-  /// Add metadata to simd-ize a loop.
+  /// Add metadata to simd-ize a loop. If IfCond is not nullptr, the loop
+  /// is cloned. The metadata which prevents vectorization is added to
+  /// to the cloned loop. The cloned loop is executed when ifCond is evaluated
+  /// to false.
   ///
   /// \param Loop    The loop to simd-ize.
+  /// \param IfCond  The value which corresponds to the if clause condition.
   /// \param Simdlen The Simdlen length to apply to the simd loop.
-  void applySimd(CanonicalLoopInfo *Loop, ConstantInt *Simdlen);
+  void applySimd(CanonicalLoopInfo *Loop, Value *IfCond, ConstantInt *Simdlen);
 
   /// Generator for '#omp flush'
   ///
index cee4cdd..736976d 100644 (file)
@@ -34,6 +34,7 @@
 #include "llvm/Target/TargetMachine.h"
 #include "llvm/Target/TargetOptions.h"
 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
+#include "llvm/Transforms/Utils/Cloning.h"
 #include "llvm/Transforms/Utils/CodeExtractor.h"
 #include "llvm/Transforms/Utils/LoopPeel.h"
 #include "llvm/Transforms/Utils/UnrollLoop.h"
@@ -2839,32 +2840,40 @@ OpenMPIRBuilder::tileLoops(DebugLoc DL, ArrayRef<CanonicalLoopInfo *> Loops,
   return Result;
 }
 
-/// Attach loop metadata \p Properties to the loop described by \p Loop. If the
-/// loop already has metadata, the loop properties are appended.
-static void addLoopMetadata(CanonicalLoopInfo *Loop,
-                            ArrayRef<Metadata *> Properties) {
-  assert(Loop->isValid() && "Expecting a valid CanonicalLoopInfo");
-
+/// Attach metadata \p Properties to the basic block described by \p BB. If the
+/// basic block already has metadata, the basic block properties are appended.
+static void addBasicBlockMetadata(BasicBlock *BB,
+                                  ArrayRef<Metadata *> Properties) {
   // Nothing to do if no property to attach.
   if (Properties.empty())
     return;
 
-  LLVMContext &Ctx = Loop->getFunction()->getContext();
-  SmallVector<Metadata *> NewLoopProperties;
-  NewLoopProperties.push_back(nullptr);
+  LLVMContext &Ctx = BB->getContext();
+  SmallVector<Metadata *> NewProperties;
+  NewProperties.push_back(nullptr);
 
-  // If the loop already has metadata, prepend it to the new metadata.
-  BasicBlock *Latch = Loop->getLatch();
-  assert(Latch && "A valid CanonicalLoopInfo must have a unique latch");
-  MDNode *Existing = Latch->getTerminator()->getMetadata(LLVMContext::MD_loop);
+  // If the basic block already has metadata, prepend it to the new metadata.
+  MDNode *Existing = BB->getTerminator()->getMetadata(LLVMContext::MD_loop);
   if (Existing)
-    append_range(NewLoopProperties, drop_begin(Existing->operands(), 1));
+    append_range(NewProperties, drop_begin(Existing->operands(), 1));
 
-  append_range(NewLoopProperties, Properties);
-  MDNode *LoopID = MDNode::getDistinct(Ctx, NewLoopProperties);
-  LoopID->replaceOperandWith(0, LoopID);
+  append_range(NewProperties, Properties);
+  MDNode *BasicBlockID = MDNode::getDistinct(Ctx, NewProperties);
+  BasicBlockID->replaceOperandWith(0, BasicBlockID);
 
-  Latch->getTerminator()->setMetadata(LLVMContext::MD_loop, LoopID);
+  BB->getTerminator()->setMetadata(LLVMContext::MD_loop, BasicBlockID);
+}
+
+/// Attach loop metadata \p Properties to the loop described by \p Loop. If the
+/// loop already has metadata, the loop properties are appended.
+static void addLoopMetadata(CanonicalLoopInfo *Loop,
+                            ArrayRef<Metadata *> Properties) {
+  assert(Loop->isValid() && "Expecting a valid CanonicalLoopInfo");
+
+  // Attach metadata to the loop's latch
+  BasicBlock *Latch = Loop->getLatch();
+  assert(Latch && "A valid CanonicalLoopInfo must have a unique latch");
+  addBasicBlockMetadata(Latch, Properties);
 }
 
 /// Attach llvm.access.group metadata to the memref instructions of \p Block
@@ -2895,12 +2904,77 @@ void OpenMPIRBuilder::unrollLoopHeuristic(DebugLoc, CanonicalLoopInfo *Loop) {
             });
 }
 
-void OpenMPIRBuilder::applySimd(CanonicalLoopInfo *CanonicalLoop,
+void OpenMPIRBuilder::createIfVersion(CanonicalLoopInfo *CanonicalLoop,
+                                      Value *IfCond, ValueToValueMapTy &VMap,
+                                      const Twine &NamePrefix) {
+  Function *F = CanonicalLoop->getFunction();
+
+  // Define where if branch should be inserted
+  Instruction *SplitBefore;
+  if (Instruction::classof(IfCond)) {
+    SplitBefore = dyn_cast<Instruction>(IfCond);
+  } else {
+    SplitBefore = CanonicalLoop->getPreheader()->getTerminator();
+  }
+
+  // TODO: We should not rely on pass manager. Currently we use pass manager
+  // only for getting llvm::Loop which corresponds to given CanonicalLoopInfo
+  // object. We should have a method  which returns all blocks between
+  // CanonicalLoopInfo::getHeader() and CanonicalLoopInfo::getAfter()
+  FunctionAnalysisManager FAM;
+  FAM.registerPass([]() { return DominatorTreeAnalysis(); });
+  FAM.registerPass([]() { return LoopAnalysis(); });
+  FAM.registerPass([]() { return PassInstrumentationAnalysis(); });
+
+  // Get the loop which needs to be cloned
+  LoopAnalysis LIA;
+  LoopInfo &&LI = LIA.run(*F, FAM);
+  Loop *L = LI.getLoopFor(CanonicalLoop->getHeader());
+
+  // Create additional blocks for the if statement
+  BasicBlock *Head = SplitBefore->getParent();
+  Instruction *HeadOldTerm = Head->getTerminator();
+  llvm::LLVMContext &C = Head->getContext();
+  llvm::BasicBlock *ThenBlock = llvm::BasicBlock::Create(
+      C, NamePrefix + ".if.then", Head->getParent(), Head->getNextNode());
+  llvm::BasicBlock *ElseBlock = llvm::BasicBlock::Create(
+      C, NamePrefix + ".if.else", Head->getParent(), CanonicalLoop->getExit());
+
+  // Create if condition branch.
+  Builder.SetInsertPoint(HeadOldTerm);
+  Instruction *BrInstr =
+      Builder.CreateCondBr(IfCond, ThenBlock, /*ifFalse*/ ElseBlock);
+  InsertPointTy IP{BrInstr->getParent(), ++BrInstr->getIterator()};
+  // Then block contains branch to omp loop which needs to be vectorized
+  spliceBB(IP, ThenBlock, false);
+  ThenBlock->replaceSuccessorsPhiUsesWith(Head, ThenBlock);
+
+  Builder.SetInsertPoint(ElseBlock);
+
+  // Clone loop for the else branch
+  SmallVector<BasicBlock *, 8> NewBlocks;
+
+  VMap[CanonicalLoop->getPreheader()] = ElseBlock;
+  for (BasicBlock *Block : L->getBlocks()) {
+    BasicBlock *NewBB = CloneBasicBlock(Block, VMap, "", F);
+    NewBB->moveBefore(CanonicalLoop->getExit());
+    VMap[Block] = NewBB;
+    NewBlocks.push_back(NewBB);
+  }
+  remapInstructionsInBlocks(NewBlocks, VMap);
+  Builder.CreateBr(NewBlocks.front());
+}
+
+void OpenMPIRBuilder::applySimd(CanonicalLoopInfo *CanonicalLoop, Value *IfCond,
                                 ConstantInt *Simdlen) {
   LLVMContext &Ctx = Builder.getContext();
 
   Function *F = CanonicalLoop->getFunction();
 
+  // TODO: We should not rely on pass manager. Currently we use pass manager
+  // only for getting llvm::Loop which corresponds to given CanonicalLoopInfo
+  // object. We should have a method  which returns all blocks between
+  // CanonicalLoopInfo::getHeader() and CanonicalLoopInfo::getAfter()
   FunctionAnalysisManager FAM;
   FAM.registerPass([]() { return DominatorTreeAnalysis(); });
   FAM.registerPass([]() { return LoopAnalysis(); });
@@ -2911,6 +2985,24 @@ void OpenMPIRBuilder::applySimd(CanonicalLoopInfo *CanonicalLoop,
 
   Loop *L = LI.getLoopFor(CanonicalLoop->getHeader());
 
+  if (IfCond) {
+    ValueToValueMapTy VMap;
+    createIfVersion(CanonicalLoop, IfCond, VMap, "simd");
+    // Add metadata to the cloned loop which disables vectorization
+    Value *MappedLatch = VMap.lookup(CanonicalLoop->getLatch());
+    assert(MappedLatch &&
+           "Cannot find value which corresponds to original loop latch");
+    assert(isa<BasicBlock>(MappedLatch) &&
+           "Cannot cast mapped latch block value to BasicBlock");
+    BasicBlock *NewLatchBlock = dyn_cast<BasicBlock>(MappedLatch);
+    ConstantAsMetadata *BoolConst =
+        ConstantAsMetadata::get(ConstantInt::getFalse(Type::getInt1Ty(Ctx)));
+    addBasicBlockMetadata(
+        NewLatchBlock,
+        {MDNode::get(Ctx, {MDString::get(Ctx, "llvm.loop.vectorize.enable"),
+                           BoolConst})});
+  }
+
   SmallSet<BasicBlock *, 8> Reachable;
 
   // Get the basic blocks from the loop in which memref instructions
index 55afea6..7e3b548 100644 (file)
@@ -1771,7 +1771,7 @@ TEST_F(OpenMPIRBuilderTest, ApplySimd) {
   CanonicalLoopInfo *CLI = buildSingleLoopFunction(DL, OMPBuilder, 32);
 
   // Simd-ize the loop.
-  OMPBuilder.applySimd(CLI, nullptr);
+  OMPBuilder.applySimd(CLI, /* IfCond */ nullptr, /* Simdlen */ nullptr);
 
   OMPBuilder.finalize();
   EXPECT_FALSE(verifyModule(*M, &errs()));
@@ -1802,7 +1802,8 @@ TEST_F(OpenMPIRBuilderTest, ApplySimdlen) {
   CanonicalLoopInfo *CLI = buildSingleLoopFunction(DL, OMPBuilder, 32);
 
   // Simd-ize the loop.
-  OMPBuilder.applySimd(CLI, ConstantInt::get(Type::getInt32Ty(Ctx), 3));
+  OMPBuilder.applySimd(CLI, /*IfCond */ nullptr,
+                       ConstantInt::get(Type::getInt32Ty(Ctx), 3));
 
   OMPBuilder.finalize();
   EXPECT_FALSE(verifyModule(*M, &errs()));
@@ -1828,6 +1829,54 @@ TEST_F(OpenMPIRBuilderTest, ApplySimdlen) {
   }));
 }
 
+TEST_F(OpenMPIRBuilderTest, ApplySimdLoopIf) {
+  OpenMPIRBuilder OMPBuilder(*M);
+  IRBuilder<> Builder(BB);
+  AllocaInst *Alloc1 = Builder.CreateAlloca(Builder.getInt32Ty());
+  AllocaInst *Alloc2 = Builder.CreateAlloca(Builder.getInt32Ty());
+
+  // Generation of if condition
+  Builder.CreateStore(ConstantInt::get(Type::getInt32Ty(Ctx), 0U), Alloc1);
+  Builder.CreateStore(ConstantInt::get(Type::getInt32Ty(Ctx), 1U), Alloc2);
+  LoadInst *Load1 = Builder.CreateLoad(Alloc1->getAllocatedType(), Alloc1);
+  LoadInst *Load2 = Builder.CreateLoad(Alloc2->getAllocatedType(), Alloc2);
+
+  Value *IfCmp = Builder.CreateICmpNE(Load1, Load2);
+
+  CanonicalLoopInfo *CLI = buildSingleLoopFunction(DL, OMPBuilder, 32);
+
+  // Simd-ize the loop with if condition
+  OMPBuilder.applySimd(CLI, IfCmp, ConstantInt::get(Type::getInt32Ty(Ctx), 3));
+
+  OMPBuilder.finalize();
+  EXPECT_FALSE(verifyModule(*M, &errs()));
+
+  PassBuilder PB;
+  FunctionAnalysisManager FAM;
+  PB.registerFunctionAnalyses(FAM);
+  LoopInfo &LI = FAM.getResult<LoopAnalysis>(*F);
+
+  // Check if there are two loops (one with enabled vectorization)
+  const std::vector<Loop *> &TopLvl = LI.getTopLevelLoops();
+  EXPECT_EQ(TopLvl.size(), 2u);
+
+  Loop *L = TopLvl[0];
+  EXPECT_TRUE(findStringMetadataForLoop(L, "llvm.loop.parallel_accesses"));
+  EXPECT_TRUE(getBooleanLoopAttribute(L, "llvm.loop.vectorize.enable"));
+  EXPECT_EQ(getIntLoopAttribute(L, "llvm.loop.vectorize.width"), 3);
+
+  // The second loop should have disabled vectorization
+  L = TopLvl[1];
+  EXPECT_FALSE(findStringMetadataForLoop(L, "llvm.loop.parallel_accesses"));
+  EXPECT_FALSE(getBooleanLoopAttribute(L, "llvm.loop.vectorize.enable"));
+  // Check for llvm.access.group metadata attached to the printf
+  // function in the loop body.
+  BasicBlock *LoopBody = CLI->getBody();
+  EXPECT_TRUE(any_of(*LoopBody, [](Instruction &I) {
+    return I.getMetadata("llvm.access.group") != nullptr;
+  }));
+}
+
 TEST_F(OpenMPIRBuilderTest, UnrollLoopFull) {
   OpenMPIRBuilder OMPBuilder(*M);
 
index 34c5323..85ec47a 100644 (file)
@@ -912,11 +912,6 @@ convertOmpSimdLoop(Operation &opInst, llvm::IRBuilderBase &builder,
   SmallVector<llvm::CanonicalLoopInfo *> loopInfos;
   SmallVector<llvm::OpenMPIRBuilder::InsertPointTy> bodyInsertPoints;
   LogicalResult bodyGenStatus = success();
-
-  // TODO: The code generation for if clause is not supported yet.
-  if (loop.if_expr())
-    return failure();
-
   auto bodyGen = [&](llvm::OpenMPIRBuilder::InsertPointTy ip, llvm::Value *iv) {
     // Make sure further conversions know about the induction variable.
     moduleTranslation.mapValue(
@@ -975,7 +970,10 @@ convertOmpSimdLoop(Operation &opInst, llvm::IRBuilderBase &builder,
   if (llvm::Optional<uint64_t> simdlenVar = loop.simdlen())
     simdlen = builder.getInt64(simdlenVar.value());
 
-  ompBuilder->applySimd(loopInfo, simdlen);
+  ompBuilder->applySimd(
+      loopInfo,
+      loop.if_expr() ? moduleTranslation.lookupValue(loop.if_expr()) : nullptr,
+      simdlen);
 
   builder.restoreIP(afterIP);
   return success();
index 4b36ed0..a3e6aa2 100644 (file)
@@ -750,6 +750,34 @@ llvm.func @simdloop_simple_multiple_simdlen(%lb1 : i64, %ub1 : i64, %step1 : i64
 
 // -----
 
+// CHECK-LABEL: @simdloop_if
+llvm.func @simdloop_if(%arg0: !llvm.ptr<i32> {fir.bindc_name = "n"}, %arg1: !llvm.ptr<i32> {fir.bindc_name = "threshold"}) {
+  %0 = llvm.mlir.constant(1 : i64) : i64
+  %1 = llvm.alloca %0 x i32 {adapt.valuebyref, in_type = i32, operand_segment_sizes = dense<0> : vector<2xi32>} : (i64) -> !llvm.ptr<i32>
+  %2 = llvm.mlir.constant(1 : i64) : i64
+  %3 = llvm.alloca %2 x i32 {bindc_name = "i", in_type = i32, operand_segment_sizes = dense<0> : vector<2xi32>, uniq_name = "_QFtest_simdEi"} : (i64) -> !llvm.ptr<i32>
+  %4 = llvm.mlir.constant(0 : i32) : i32
+  %5 = llvm.load %arg0 : !llvm.ptr<i32>
+  %6 = llvm.mlir.constant(1 : i32) : i32
+  %7 = llvm.load %arg0 : !llvm.ptr<i32>
+  %8 = llvm.load %arg1 : !llvm.ptr<i32>
+  %9 = llvm.icmp "sge" %7, %8 : i32
+  omp.simdloop   if(%9) for  (%arg2) : i32 = (%4) to (%5) inclusive step (%6) {
+    // The form of the emitted IR is controlled by OpenMPIRBuilder and
+    // tested there. Just check that the right metadata is added.
+    // CHECK: llvm.access.group
+    llvm.store %arg2, %1 : !llvm.ptr<i32>
+    omp.yield
+  }
+  llvm.return
+}
+// Be sure that llvm.loop.vectorize.enable metadata appears twice
+// CHECK: llvm.loop.parallel_accesses
+// CHECK-NEXT: llvm.loop.vectorize.enable
+// CHECK: llvm.loop.vectorize.enable
+
+// -----
+
 llvm.func @body(i64)
 
 llvm.func @test_omp_wsloop_ordered(%lb : i64, %ub : i64, %step : i64) -> () {