[Flang][OpenMP][OMPIRBuilder] Add lowering of TargetOp for device codegen
authorSergio Afonso <safonsof@amd.com>
Mon, 10 Apr 2023 13:24:25 +0000 (14:24 +0100)
committerSergio Afonso <safonsof@amd.com>
Thu, 18 May 2023 14:14:03 +0000 (15:14 +0100)
This patch adds support in the `OpenMPIRBuilder` for generating working
device code for OpenMP target regions. It generates and handles the
result of a call to `__kmpc_target_init()` at the beginning of the
function resulting from outlining each target region, and it also
generates the matching `__kmpc_target_deinit()` call before returning.

It relies on the implementation of target region outlining for host
codegen to handle the production of the new function and the lowering of
its body based on the contents of the associated target region.

Depends on D147172

Differential Revision: https://reviews.llvm.org/D147940

llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
mlir/test/Target/LLVMIR/omptarget-region-device-llvm.mlir [new file with mode: 0644]

index 99f454e..4b4a710 100644 (file)
@@ -4148,8 +4148,8 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTargetData(
 }
 
 static Function *
-createOutlinedFunction(IRBuilderBase &Builder, StringRef FuncName,
-                       SmallVectorImpl<Value *> &Inputs,
+createOutlinedFunction(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
+                       StringRef FuncName, SmallVectorImpl<Value *> &Inputs,
                        OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc) {
   SmallVector<Type *> ParameterTypes;
   for (auto &Arg : Inputs)
@@ -4166,8 +4166,17 @@ createOutlinedFunction(IRBuilderBase &Builder, StringRef FuncName,
   // Generate the region into the function.
   BasicBlock *EntryBB = BasicBlock::Create(Builder.getContext(), "entry", Func);
   Builder.SetInsertPoint(EntryBB);
+
+  // Insert target init call in the device compilation pass.
+  if (OMPBuilder.Config.isEmbedded())
+    Builder.restoreIP(OMPBuilder.createTargetInit(Builder, /*IsSPMD*/ false));
+
   Builder.restoreIP(CBFunc(Builder.saveIP(), Builder.saveIP()));
 
+  // Insert target deinit call in the device compilation pass.
+  if (OMPBuilder.Config.isEmbedded())
+    OMPBuilder.createTargetDeinit(Builder, /*IsSPMD*/ false);
+
   // Insert return instruction.
   Builder.CreateRetVoid();
 
@@ -4197,8 +4206,9 @@ emitTargetOutlinedFunction(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
                            OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc) {
 
   OpenMPIRBuilder::FunctionGenCallback &&GenerateOutlinedFunction =
-      [&Builder, &Inputs, &CBFunc](StringRef EntryFnName) {
-        return createOutlinedFunction(Builder, EntryFnName, Inputs, CBFunc);
+      [&OMPBuilder, &Builder, &Inputs, &CBFunc](StringRef EntryFnName) {
+        return createOutlinedFunction(OMPBuilder, Builder, EntryFnName, Inputs,
+                                      CBFunc);
       };
 
   Constant *OutlinedFnID;
@@ -4209,7 +4219,7 @@ emitTargetOutlinedFunction(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
 
 static void emitTargetCall(IRBuilderBase &Builder, Function *OutlinedFn,
                            SmallVectorImpl<Value *> &Args) {
-  // TODO: Add kernel launch call when device codegen is supported.
+  // TODO: Add kernel launch call
   Builder.CreateCall(OutlinedFn, Args);
 }
 
@@ -4225,7 +4235,8 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTarget(
   Function *OutlinedFn;
   emitTargetOutlinedFunction(*this, Builder, EntryInfo, OutlinedFn, NumTeams,
                              NumThreads, Args, CBFunc);
-  emitTargetCall(Builder, OutlinedFn, Args);
+  if (!Config.isEmbedded())
+    emitTargetCall(Builder, OutlinedFn, Args);
   return Builder.saveIP();
 }
 
index 11637f4..d4c0bf8 100644 (file)
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "llvm/Frontend/OpenMP/OMPConstants.h"
+#include "llvm/Frontend/OpenMP/OMPDeviceConstants.h"
 #include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
 #include "llvm/IR/BasicBlock.h"
 #include "llvm/IR/DIBuilder.h"
@@ -5175,6 +5176,94 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) {
   EXPECT_FALSE(verifyModule(*M, &errs()));
 }
 
+TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) {
+  OpenMPIRBuilder OMPBuilder(*M);
+  OMPBuilder.setConfig(OpenMPIRBuilderConfig(true, false, false, false));
+  OMPBuilder.initialize();
+
+  F->setName("func");
+  IRBuilder<> Builder(BB);
+  OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
+
+  StoreInst *TargetStore = nullptr;
+  llvm::SmallVector<llvm::Value *, 2> CapturedArgs = {
+      Constant::getIntegerValue(Type::getInt32Ty(Ctx), APInt(32, 0)),
+      Constant::getNullValue(Type::getInt32PtrTy(Ctx))};
+
+  auto BodyGenCB = [&](OpenMPIRBuilder::InsertPointTy AllocaIP,
+                       OpenMPIRBuilder::InsertPointTy CodeGenIP)
+      -> OpenMPIRBuilder::InsertPointTy {
+    Builder.restoreIP(CodeGenIP);
+    TargetStore = Builder.CreateStore(CapturedArgs[0], CapturedArgs[1]);
+    return Builder.saveIP();
+  };
+
+  IRBuilder<>::InsertPoint EntryIP(&F->getEntryBlock(),
+                                   F->getEntryBlock().getFirstInsertionPt());
+  TargetRegionEntryInfo EntryInfo("parent", /*DeviceID=*/1, /*FileID=*/2,
+                                  /*Line=*/3, /*Count=*/0);
+
+  Builder.restoreIP(
+      OMPBuilder.createTarget(Loc, EntryIP, EntryInfo, /*NumTeams=*/-1,
+                              /*NumThreads=*/-1, CapturedArgs, BodyGenCB));
+  Builder.CreateRetVoid();
+  OMPBuilder.finalize();
+
+  // Check outlined function
+  EXPECT_FALSE(verifyModule(*M, &errs()));
+  EXPECT_NE(TargetStore, nullptr);
+  Function *OutlinedFn = TargetStore->getFunction();
+  EXPECT_NE(F, OutlinedFn);
+
+  EXPECT_TRUE(OutlinedFn->hasWeakODRLinkage());
+  EXPECT_EQ(OutlinedFn->arg_size(), 2U);
+  EXPECT_EQ(OutlinedFn->getName(), "__omp_offloading_1_2_parent_l3");
+  EXPECT_TRUE(OutlinedFn->getArg(0)->getType()->isIntegerTy(32));
+  EXPECT_TRUE(OutlinedFn->getArg(1)->getType()->isPointerTy());
+
+  // Check entry block
+  auto &EntryBlock = OutlinedFn->getEntryBlock();
+  Instruction *Init = EntryBlock.getFirstNonPHI();
+  EXPECT_NE(Init, nullptr);
+
+  auto *InitCall = dyn_cast<CallInst>(Init);
+  EXPECT_NE(InitCall, nullptr);
+  EXPECT_EQ(InitCall->getCalledFunction()->getName(), "__kmpc_target_init");
+  EXPECT_EQ(InitCall->arg_size(), 3U);
+  EXPECT_TRUE(isa<GlobalVariable>(InitCall->getArgOperand(0)));
+  EXPECT_EQ(InitCall->getArgOperand(1),
+            ConstantInt::get(Type::getInt8Ty(Ctx), OMP_TGT_EXEC_MODE_GENERIC));
+  EXPECT_EQ(InitCall->getArgOperand(2),
+            ConstantInt::get(Type::getInt1Ty(Ctx), true));
+
+  auto *EntryBlockBranch = EntryBlock.getTerminator();
+  EXPECT_NE(EntryBlockBranch, nullptr);
+  EXPECT_EQ(EntryBlockBranch->getNumSuccessors(), 2U);
+
+  // Check user code block
+  auto *UserCodeBlock = EntryBlockBranch->getSuccessor(0);
+  EXPECT_EQ(UserCodeBlock->getName(), "user_code.entry");
+  EXPECT_EQ(UserCodeBlock->getFirstNonPHI(), TargetStore);
+
+  auto *Deinit = TargetStore->getNextNode();
+  EXPECT_NE(Deinit, nullptr);
+
+  auto *DeinitCall = dyn_cast<CallInst>(Deinit);
+  EXPECT_NE(DeinitCall, nullptr);
+  EXPECT_EQ(DeinitCall->getCalledFunction()->getName(), "__kmpc_target_deinit");
+  EXPECT_EQ(DeinitCall->arg_size(), 2U);
+  EXPECT_TRUE(isa<GlobalVariable>(DeinitCall->getArgOperand(0)));
+  EXPECT_EQ(DeinitCall->getArgOperand(1),
+            ConstantInt::get(Type::getInt8Ty(Ctx), OMP_TGT_EXEC_MODE_GENERIC));
+
+  EXPECT_TRUE(isa<ReturnInst>(DeinitCall->getNextNode()));
+
+  // Check exit block
+  auto *ExitBlock = EntryBlockBranch->getSuccessor(1);
+  EXPECT_EQ(ExitBlock->getName(), "worker.exit");
+  EXPECT_TRUE(isa<ReturnInst>(ExitBlock->getFirstNonPHI()));
+}
+
 TEST_F(OpenMPIRBuilderTest, CreateTask) {
   using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
   OpenMPIRBuilder OMPBuilder(*M);
index 750f715..27e3cec 100644 (file)
@@ -1629,15 +1629,6 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
   if (!targetOpSupported(opInst))
     return failure();
 
-  bool isDevice = false;
-  if (auto offloadMod = dyn_cast<mlir::omp::OffloadModuleInterface>(
-          opInst.getParentOfType<mlir::ModuleOp>().getOperation())) {
-    isDevice = offloadMod.getIsDevice();
-  }
-
-  if (isDevice) // TODO: Implement device codegen.
-    return success();
-
   auto targetOp = cast<omp::TargetOp>(opInst);
   auto &targetRegion = targetOp.getRegion();
 
diff --git a/mlir/test/Target/LLVMIR/omptarget-region-device-llvm.mlir b/mlir/test/Target/LLVMIR/omptarget-region-device-llvm.mlir
new file mode 100644 (file)
index 0000000..3e385e0
--- /dev/null
@@ -0,0 +1,44 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+module attributes {omp.is_device = #omp.isdevice<is_device = true>} { // device-side module
+  llvm.func @omp_target_region_() {
+    %0 = llvm.mlir.constant(20 : i32) : i32
+    %1 = llvm.mlir.constant(10 : i32) : i32
+    %2 = llvm.mlir.constant(1 : i64) : i64
+    %3 = llvm.alloca %2 x i32 {bindc_name = "a", in_type = i32, operand_segment_sizes = array<i32: 0, 0>, uniq_name = "_QFomp_target_regionEa"} : (i64) -> !llvm.ptr<i32>
+    %4 = llvm.mlir.constant(1 : i64) : i64
+    %5 = llvm.alloca %4 x i32 {bindc_name = "b", in_type = i32, operand_segment_sizes = array<i32: 0, 0>, uniq_name = "_QFomp_target_regionEb"} : (i64) -> !llvm.ptr<i32>
+    %6 = llvm.mlir.constant(1 : i64) : i64
+    %7 = llvm.alloca %6 x i32 {bindc_name = "c", in_type = i32, operand_segment_sizes = array<i32: 0, 0>, uniq_name = "_QFomp_target_regionEc"} : (i64) -> !llvm.ptr<i32>
+    llvm.store %1, %3 : !llvm.ptr<i32> // a = 10
+    llvm.store %0, %5 : !llvm.ptr<i32> // b = 20
+    omp.target   { // region under test; uses a, b and c from the enclosing function
+      %8 = llvm.load %3 : !llvm.ptr<i32>
+      %9 = llvm.load %5 : !llvm.ptr<i32>
+      %10 = llvm.add %8, %9  : i32
+      llvm.store %10, %7 : !llvm.ptr<i32> // c = a + b
+      omp.terminator
+    }
+    llvm.return
+  }
+}
+
+// CHECK:      @[[SRC_LOC:.*]] = private unnamed_addr constant [23 x i8] c"{{[^"]*}}", align 1
+// CHECK:      @[[IDENT:.*]] = private unnamed_addr constant %struct.ident_t { i32 0, i32 2, i32 0, i32 22, ptr @[[SRC_LOC]] }, align 8
+// CHECK:      define weak_odr protected void @__omp_offloading_{{[^_]+}}_{{[^_]+}}_omp_target_region__l{{[0-9]+}}(ptr %[[ADDR_A:.*]], ptr %[[ADDR_B:.*]], ptr %[[ADDR_C:.*]])
+// CHECK:        %[[INIT:.*]] = call i32 @__kmpc_target_init(ptr @[[IDENT]], i8 1, i1 true)
+// CHECK-NEXT:   %[[CMP:.*]] = icmp eq i32 %[[INIT]], -1
+// CHECK-NEXT:   br i1 %[[CMP]], label %[[LABEL_ENTRY:.*]], label %[[LABEL_EXIT:.*]]
+// CHECK:        [[LABEL_ENTRY]]:
+// CHECK-NEXT:   br label %[[LABEL_TARGET:.*]]
+// CHECK:        [[LABEL_TARGET]]:
+// CHECK:        %[[A:.*]] = load i32, ptr %[[ADDR_A]], align 4
+// CHECK:        %[[B:.*]] = load i32, ptr %[[ADDR_B]], align 4
+// CHECK:        %[[C:.*]] = add i32 %[[A]], %[[B]]
+// CHECK:        store i32 %[[C]], ptr %[[ADDR_C]], align 4
+// CHECK:        br label %[[LABEL_DEINIT:.*]]
+// CHECK:        [[LABEL_DEINIT]]:
+// CHECK-NEXT:   call void @__kmpc_target_deinit(ptr @[[IDENT]], i8 1)
+// CHECK-NEXT:   ret void
+// CHECK:        [[LABEL_EXIT]]:
+// CHECK-NEXT:   ret void