[OMPIRBuilder] Add support for simd (loop) directive.
authorArnamoy Bhattacharyya <arnamoy.bhattacharyya@huawei.com>
Wed, 19 Jan 2022 16:18:35 +0000 (11:18 -0500)
committerArnamoy Bhattacharyya <arnamoy.bhattacharyya@huawei.com>
Wed, 19 Jan 2022 16:32:17 +0000 (11:32 -0500)
This patch adds OMPIRBuilder support for the simd directive (without any clause).  This will be a first step towards lowering simd directive in LLVM_Flang.  The patch uses existing CanonicalLoop infrastructure of IRBuilder to add the support.  Also adds necessary code to add llvm.access.group and llvm.loop metadata wherever needed.

Reviewed By: Meinersbur

Differential Revision: https://reviews.llvm.org/D114379

clang/lib/CodeGen/CGStmtOpenMP.cpp
clang/test/OpenMP/irbuilder_simd.cpp [new file with mode: 0644]
llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp

index 11af281..0db59dd 100644 (file)
@@ -2584,7 +2584,67 @@ static void emitOMPSimdRegion(CodeGenFunction &CGF, const OMPLoopDirective &S,
   }
 }
 
+static bool isSupportedByOpenMPIRBuilder(const OMPExecutableDirective &S) {
+  // Check for unsupported clauses
+  if (!S.clauses().empty()) {
+    // Currently no clause is supported
+    return false;
+  }
+
+  // Check if we have a statement with the ordered directive.
+  // Visit the statement hierarchy to find a compound statement
+  // with a ordered directive in it.
+  if (const auto *CanonLoop = dyn_cast<OMPCanonicalLoop>(S.getRawStmt())) {
+    if (const Stmt *SyntacticalLoop = CanonLoop->getLoopStmt()) {
+      for (const Stmt *SubStmt : SyntacticalLoop->children()) {
+        if (!SubStmt)
+          continue;
+        if (const CompoundStmt *CS = dyn_cast<CompoundStmt>(SubStmt)) {
+          for (const Stmt *CSSubStmt : CS->children()) {
+            if (!CSSubStmt)
+              continue;
+            if (isa<OMPOrderedDirective>(CSSubStmt)) {
+              return false;
+            }
+          }
+        }
+      }
+    }
+  }
+  return true;
+}
+
 void CodeGenFunction::EmitOMPSimdDirective(const OMPSimdDirective &S) {
+  bool UseOMPIRBuilder =
+      CGM.getLangOpts().OpenMPIRBuilder && isSupportedByOpenMPIRBuilder(S);
+  if (UseOMPIRBuilder) {
+    auto &&CodeGenIRBuilder = [this, &S, UseOMPIRBuilder](CodeGenFunction &CGF,
+                                                          PrePostActionTy &) {
+      // Use the OpenMPIRBuilder if enabled.
+      if (UseOMPIRBuilder) {
+        // Emit the associated statement and get its loop representation.
+        llvm::DebugLoc DL = SourceLocToDebugLoc(S.getBeginLoc());
+        const Stmt *Inner = S.getRawStmt();
+        llvm::CanonicalLoopInfo *CLI =
+            EmitOMPCollapsedCanonicalLoopNest(Inner, 1);
+
+        llvm::OpenMPIRBuilder &OMPBuilder =
+            CGM.getOpenMPRuntime().getOMPBuilder();
+        // Add SIMD specific metadata
+        OMPBuilder.applySimd(DL, CLI);
+        return;
+      }
+    };
+    {
+      auto LPCRegion =
+          CGOpenMPRuntime::LastprivateConditionalRAII::disable(*this, S);
+      OMPLexicalScope Scope(*this, S, OMPD_unknown);
+      CGM.getOpenMPRuntime().emitInlinedDirective(*this, OMPD_simd,
+                                                  CodeGenIRBuilder);
+    }
+    return;
+  }
+
   ParentLoopDirectiveForScanRegion ScanRegion(*this, S);
   OMPFirstScanLoop = true;
   auto &&CodeGen = [&S](CodeGenFunction &CGF, PrePostActionTy &Action) {
diff --git a/clang/test/OpenMP/irbuilder_simd.cpp b/clang/test/OpenMP/irbuilder_simd.cpp
new file mode 100644 (file)
index 0000000..9f36929
--- /dev/null
@@ -0,0 +1,71 @@
+// RUN: %clang_cc1 -fopenmp-enable-irbuilder -verify -fopenmp -fopenmp-version=45 -x c++ -triple x86_64-unknown-unknown -emit-llvm %s -o - | FileCheck %s 
+// expected-no-diagnostics
+
+struct S {
+  int a, b;
+};
+
+struct P {
+  int a, b;
+};
+
+void simple(float *a, float *b, int *c) {
+  S s, *p;
+  P pp;
+#pragma omp simd
+  for (int i = 3; i < 32; i += 5) {
+    // llvm.access.group test
+    // CHECK: %[[A_ADDR:.+]] = alloca float*, align 8
+    // CHECK: %[[B_ADDR:.+]] = alloca float*, align 8
+    // CHECK: %[[S:.+]] = alloca %struct.S, align 4
+    // CHECK: %[[P:.+]] = alloca %struct.S*, align 8
+    // CHECK: %[[I:.+]] = alloca i32, align 4
+    // CHECK: %[[TMP3:.+]] = load float*, float** %[[B_ADDR:.+]], align 8, !llvm.access.group ![[META3:[0-9]+]]
+    // CHECK-NEXT: %[[TMP4:.+]] = load i32, i32* %[[I:.+]], align 4, !llvm.access.group ![[META3:[0-9]+]]
+    // CHECK-NEXT: %[[IDXPROM:.+]] = sext i32 %[[TMP4:.+]] to i64
+    // CHECK-NEXT: %[[ARRAYIDX:.+]] = getelementptr inbounds float, float* %[[TMP3:.+]], i64 %[[IDXPROM:.+]]
+    // CHECK-NEXT: %[[TMP5:.+]] = load float, float* %[[ARRAYIDX:.+]], align 4, !llvm.access.group ![[META3:[0-9]+]]
+    // CHECK-NEXT: %[[A2:.+]] = getelementptr inbounds %struct.S, %struct.S* %[[S:.+]], i32 0, i32 0
+    // CHECK-NEXT: %[[TMP6:.+]] = load i32, i32* %[[A2:.+]], align 4, !llvm.access.group ![[META3:[0-9]+]]
+    // CHECK-NEXT: %[[CONV:.+]] = sitofp i32 %[[TMP6:.+]] to float
+    // CHECK-NEXT: %[[ADD:.+]] = fadd float %[[TMP5:.+]], %[[CONV:.+]]
+    // CHECK-NEXT: %[[TMP7:.+]] = load %struct.S*, %struct.S** %[[P:.+]], align 8, !llvm.access.group ![[META3:[0-9]+]]
+    // CHECK-NEXT: %[[A3:.+]] = getelementptr inbounds %struct.S, %struct.S* %[[TMP7:.+]], i32 0, i32 0
+    // CHECK-NEXT: %[[TMP8:.+]] = load i32, i32* %[[A3:.+]], align 4, !llvm.access.group ![[META3:[0-9]+]]
+    // CHECK-NEXT: %[[CONV4:.+]] = sitofp i32 %[[TMP8:.+]] to float
+    // CHECK-NEXT: %[[ADD5:.+]] = fadd float %[[ADD:.+]], %[[CONV4:.+]]
+    // CHECK-NEXT: %[[TMP9:.+]] = load float*, float** %[[A_ADDR:.+]], align 8, !llvm.access.group ![[META3:[0-9]+]]
+    // CHECK-NEXT: %[[TMP10:.+]] = load i32, i32* %[[I:.+]], align 4, !llvm.access.group ![[META3:[0-9]+]]
+    // CHECK-NEXT: %[[IDXPROM6:.+]] = sext i32 %[[TMP10:.+]] to i64
+    // CHECK-NEXT: %[[ARRAYIDX7:.+]] = getelementptr inbounds float, float* %[[TMP9:.+]], i64 %[[IDXPROM6:.+]]
+    // CHECK-NEXT: store float %[[ADD5:.+]], float* %[[ARRAYIDX7:.+]], align 4, !llvm.access.group ![[META3:[0-9]+]]
+    // llvm.loop test
+    // CHECK: %[[OMP_LOOPDOTNEXT:.+]] = add nuw i32 %[[OMP_LOOPDOTIV:.+]], 1
+    // CHECK-NEXT: br label %omp_loop.header, !llvm.loop ![[META4:[0-9]+]]
+    a[i] = b[i] + s.a + p->a;
+  }
+
+#pragma omp simd
+  for (int j = 3; j < 32; j += 5) {
+    // test if unique access groups were used for a second loop
+    // CHECK: %[[A22:.+]] = getelementptr inbounds %struct.P, %struct.P* %[[PP:.+]], i32 0, i32 0
+    // CHECK-NEXT: %[[TMP14:.+]] = load i32, i32* %[[A22:.+]], align 4, !llvm.access.group ![[META7:[0-9]+]]
+    // CHECK-NEXT: %[[TMP15:.+]] = load i32*, i32** %[[C_ADDR:.+]], align 8, !llvm.access.group ![[META7:[0-9]+]]
+    // CHECK-NEXT: %[[TMP16:.+]] = load i32, i32* %[[J:.+]], align 4, !llvm.access.group ![[META7:[0-9]+]]
+    // CHECK-NEXT: %[[IDXPROM23:.+]] = sext i32 %[[TMP16:.+]] to i64
+    // CHECK-NEXT: %[[ARRAYIDX24:.+]] = getelementptr inbounds i32, i32* %[[TMP15:.+]], i64 %[[IDXPROM23:.+]]
+    // CHECK-NEXT: store i32 %[[TMP14:.+]], i32* %[[ARRAYIDX24:.+]], align 4, !llvm.access.group ![[META7:[0-9]+]]
+    // check llvm.loop metadata
+    // CHECK: %[[OMP_LOOPDOTNEXT:.+]] = add nuw i32 %[[OMP_LOOPDOTIV:.+]], 1
+    // CHECK-NEXT: br label %[[OMP_LLOP_BODY:.*]], !llvm.loop ![[META8:[0-9]+]]
+    c[j] = pp.a;
+  }
+}
+
+// CHECK: ![[META3:[0-9]+]] = distinct !{}
+// CHECK-NEXT: ![[META4]]  = distinct !{![[META4]], ![[META5:[0-9]+]], ![[META6:[0-9]+]]}
+// CHECK-NEXT: ![[META5]]  = !{!"llvm.loop.parallel_accesses", ![[META3]]}
+// CHECK-NEXT: ![[META6]]  = !{!"llvm.loop.vectorize.enable", i1 true}
+// CHECK-NEXT: ![[META7:[0-9]+]] = distinct !{}
+// CHECK-NEXT: ![[META8]]  = distinct !{![[META8]], ![[META9:[0-9]+]], ![[META6]]}
+// CHECK-NEXT: ![[META9]]  = !{!"llvm.loop.parallel_accesses", ![[META7]]}
\ No newline at end of file
index d80c521..cfebdbd 100644 (file)
@@ -517,6 +517,12 @@ public:
   void unrollLoopPartial(DebugLoc DL, CanonicalLoopInfo *Loop, int32_t Factor,
                          CanonicalLoopInfo **UnrolledCLI);
 
+  /// Add metadata to simd-ize a loop.
+  ///
+  /// \param DL   Debug location for instructions added by unrolling.
+  /// \param Loop The loop to simd-ize.
+  void applySimd(DebugLoc DL, CanonicalLoopInfo *Loop);
+
   /// Generator for '#omp flush'
   ///
   /// \param Loc The location where the flush directive was encountered
index 61194da..62bef52 100644 (file)
@@ -13,6 +13,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
+#include "llvm/ADT/SmallSet.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/ADT/Triple.h"
 #include "llvm/Analysis/AssumptionCache.h"
@@ -2145,6 +2146,19 @@ static void addLoopMetadata(CanonicalLoopInfo *Loop,
   Latch->getTerminator()->setMetadata(LLVMContext::MD_loop, LoopID);
 }
 
+/// Attach llvm.access.group metadata to the memref instructions of \p Block
+static void addSimdMetadata(BasicBlock *Block, MDNode *AccessGroup,
+                            LoopInfo &LI) {
+  for (Instruction &I : *Block) {
+    if (I.mayReadOrWriteMemory()) {
+      // TODO: This instruction may already have access group from
+      // other pragmas e.g. #pragma clang loop vectorize.  Append
+      // so that the existing metadata is not overwritten.
+      I.setMetadata(LLVMContext::MD_access_group, AccessGroup);
+    }
+  }
+}
+
 void OpenMPIRBuilder::unrollLoopFull(DebugLoc, CanonicalLoopInfo *Loop) {
   LLVMContext &Ctx = Builder.getContext();
   addLoopMetadata(
@@ -2160,6 +2174,53 @@ void OpenMPIRBuilder::unrollLoopHeuristic(DebugLoc, CanonicalLoopInfo *Loop) {
             });
 }
 
+void OpenMPIRBuilder::applySimd(DebugLoc, CanonicalLoopInfo *CanonicalLoop) {
+  LLVMContext &Ctx = Builder.getContext();
+
+  Function *F = CanonicalLoop->getFunction();
+
+  FunctionAnalysisManager FAM;
+  FAM.registerPass([]() { return DominatorTreeAnalysis(); });
+  FAM.registerPass([]() { return LoopAnalysis(); });
+  FAM.registerPass([]() { return PassInstrumentationAnalysis(); });
+
+  LoopAnalysis LIA;
+  LoopInfo &&LI = LIA.run(*F, FAM);
+
+  Loop *L = LI.getLoopFor(CanonicalLoop->getHeader());
+
+  SmallSet<BasicBlock *, 8> Reachable;
+
+  // Get the basic blocks from the loop in which memref instructions
+  // can be found.
+  // TODO: Generalize getting all blocks inside a CanonicalizeLoopInfo,
+  // preferably without running any passes.
+  for (BasicBlock *Block : L->getBlocks()) {
+    if (Block == CanonicalLoop->getCond() ||
+        Block == CanonicalLoop->getHeader())
+      continue;
+    Reachable.insert(Block);
+  }
+
+  // Add access group metadata to memory-access instructions.
+  MDNode *AccessGroup = MDNode::getDistinct(Ctx, {});
+  for (BasicBlock *BB : Reachable)
+    addSimdMetadata(BB, AccessGroup, LI);
+
+  // Use the above access group metadata to create loop level
+  // metadata, which should be distinct for each loop.
+  ConstantAsMetadata *BoolConst =
+      ConstantAsMetadata::get(ConstantInt::getTrue(Type::getInt1Ty(Ctx)));
+  // TODO:  If the loop has existing parallel access metadata, have
+  // to combine two lists.
+  addLoopMetadata(
+      CanonicalLoop,
+      {MDNode::get(Ctx, {MDString::get(Ctx, "llvm.loop.parallel_accesses"),
+                         AccessGroup}),
+       MDNode::get(Ctx, {MDString::get(Ctx, "llvm.loop.vectorize.enable"),
+                         BoolConst})});
+}
+
 /// Create the TargetMachine object to query the backend for optimization
 /// preferences.
 ///
index 4a75a8a..d00f799 100644 (file)
@@ -1662,6 +1662,37 @@ TEST_F(OpenMPIRBuilderTest, TileSingleLoopCounts) {
   EXPECT_FALSE(verifyModule(*M, &errs()));
 }
 
+TEST_F(OpenMPIRBuilderTest, ApplySimd) {
+  OpenMPIRBuilder OMPBuilder(*M);
+
+  CanonicalLoopInfo *CLI = buildSingleLoopFunction(DL, OMPBuilder);
+
+  // Simd-ize the loop.
+  OMPBuilder.applySimd(DL, CLI);
+
+  OMPBuilder.finalize();
+  EXPECT_FALSE(verifyModule(*M, &errs()));
+
+  PassBuilder PB;
+  FunctionAnalysisManager FAM;
+  PB.registerFunctionAnalyses(FAM);
+  LoopInfo &LI = FAM.getResult<LoopAnalysis>(*F);
+
+  const std::vector<Loop *> &TopLvl = LI.getTopLevelLoops();
+  EXPECT_EQ(TopLvl.size(), 1u);
+
+  Loop *L = TopLvl.front();
+  EXPECT_TRUE(findStringMetadataForLoop(L, "llvm.loop.parallel_accesses"));
+  EXPECT_TRUE(getBooleanLoopAttribute(L, "llvm.loop.vectorize.enable"));
+
+  // Check for llvm.access.group metadata attached to the printf
+  // function in the loop body.
+  BasicBlock *LoopBody = CLI->getBody();
+  EXPECT_TRUE(any_of(*LoopBody, [](Instruction &I) {
+    return I.getMetadata("llvm.access.group") != nullptr;
+  }));
+}
+
 TEST_F(OpenMPIRBuilderTest, UnrollLoopFull) {
   OpenMPIRBuilder OMPBuilder(*M);