[Flang][OpenMP][MLIR] Filter emitted code depending on declare target and device
authorSergio Afonso <safonsof@amd.com>
Thu, 29 Jun 2023 11:20:25 +0000 (12:20 +0100)
committerSergio Afonso <safonsof@amd.com>
Mon, 17 Jul 2023 08:07:54 +0000 (09:07 +0100)
This patch adds support for selecting which functions are lowered to LLVM IR
from MLIR depending on declare target information and whether host or device
code is being generated.

The approach proposed by this patch is to perform the filtering in two stages:
  - An MLIR transformation pass, which is added to the Flang translation flow
    after the `OMPEarlyOutliningPass`. The functions that are kept are those
    that match the OpenMP processor (host or device) the compiler invocation
    is targeting, according to the presence of the `-fopenmp-is-target-device`
    compiler option and declare target information. All functions contaning an
    `omp.target` are also kept, regardless of the declare target information of
    the function, due to the need for keeping target regions visible for both
    host and device compilation.
  - A filtering step during translation to LLVM IR, which is peformed for those
    functions that were kept because of the presence of a target region inside.
    If the targeted OpenMP processor does not match the declare target
    information of the function, then it is removed from the LLVM IR after its
    contents have been processed and translated. Since they should only contain
    an omp.target operation which, in turn, should have been outlined into
    another LLVM IR function, the wrapper can be deleted at that point.

Depends on D150328 and D150329.

Differential Revision: https://reviews.llvm.org/D147641

12 files changed:
flang/include/flang/Optimizer/Transforms/Passes.h
flang/include/flang/Optimizer/Transforms/Passes.td
flang/lib/Frontend/FrontendActions.cpp
flang/lib/Optimizer/Transforms/CMakeLists.txt
flang/lib/Optimizer/Transforms/OMPFunctionFiltering.cpp [new file with mode: 0644]
flang/test/Lower/OpenMP/function-filtering.f90 [new file with mode: 0644]
flang/test/Lower/OpenMP/omp-declare-target-func-and-subr.f90
flang/test/Lower/OpenMP/omp-declare-target-program-var.f90
flang/test/Transforms/omp-function-filtering.mlir [new file with mode: 0644]
mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
mlir/test/Target/LLVMIR/omptarget-region-llvm-target-device.mlir
mlir/test/Target/LLVMIR/openmp-llvm.mlir

index 8d15046..3272cb3 100644 (file)
@@ -73,8 +73,11 @@ std::unique_ptr<mlir::Pass> createAlgebraicSimplificationPass();
 std::unique_ptr<mlir::Pass>
 createAlgebraicSimplificationPass(const mlir::GreedyRewriteConfig &config);
 std::unique_ptr<mlir::Pass> createPolymorphicOpConversionPass();
+
 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
 createOMPEarlyOutliningPass();
+std::unique_ptr<mlir::Pass> createOMPFunctionFilteringPass();
+
 // declarative passes
 #define GEN_PASS_REGISTRATION
 #include "flang/Optimizer/Transforms/Passes.h.inc"
index e6ecded..40a08c9 100644 (file)
@@ -311,4 +311,13 @@ def OMPEarlyOutliningPass
   let dependentDialects = ["mlir::omp::OpenMPDialect"];
 }
 
+def OMPFunctionFiltering : Pass<"omp-function-filtering"> {
+  let summary = "Filters out functions intended for the host when compiling "
+                "for the device and vice versa.";
+  let constructor = "::fir::createOMPFunctionFilteringPass()";
+  let dependentDialects = [
+    "mlir::func::FuncDialect"
+  ];
+}
+
 #endif // FLANG_OPTIMIZER_TRANSFORMS_PASSES
index 173c257..c03c3fd 100644 (file)
@@ -312,6 +312,7 @@ bool CodeGenAction::beginSourceFileAction() {
 
     if (isDevice)
       pm.addPass(fir::createOMPEarlyOutliningPass());
+    pm.addPass(fir::createOMPFunctionFilteringPass());
   }
 
   pm.enableVerifier(/*verifyPasses=*/true);
index bd4aee3..1808542 100644 (file)
@@ -17,6 +17,7 @@ add_flang_library(FIRTransforms
   PolymorphicOpConversion.cpp
   LoopVersioning.cpp
   OMPEarlyOutlining.cpp
+  OMPFunctionFiltering.cpp
 
   DEPENDS
   FIRDialect
diff --git a/flang/lib/Optimizer/Transforms/OMPFunctionFiltering.cpp b/flang/lib/Optimizer/Transforms/OMPFunctionFiltering.cpp
new file mode 100644 (file)
index 0000000..7784c90
--- /dev/null
@@ -0,0 +1,73 @@
+//===- OMPFunctionFiltering.cpp -------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements transforms to filter out functions intended for the host
+// when compiling for the device and vice versa.
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Optimizer/Transforms/Passes.h"
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+#include "mlir/Dialect/OpenMP/OpenMPInterfaces.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "llvm/ADT/SmallVector.h"
+
+namespace fir {
+#define GEN_PASS_DEF_OMPFUNCTIONFILTERING
+#include "flang/Optimizer/Transforms/Passes.h.inc"
+} // namespace fir
+
+using namespace fir;
+using namespace mlir;
+
+namespace {
+class OMPFunctionFilteringPass
+    : public fir::impl::OMPFunctionFilteringBase<OMPFunctionFilteringPass> {
+public:
+  OMPFunctionFilteringPass() = default;
+
+  void runOnOperation() override {
+    auto op = dyn_cast<omp::OffloadModuleInterface>(getOperation());
+    if (!op)
+      return;
+
+    bool isDeviceCompilation = op.getIsTargetDevice();
+    op->walk<WalkOrder::PostOrder>([&](func::FuncOp funcOp) {
+      // Do not filter functions with target regions inside, because they have
+      // to be available for both host and device so that regular and reverse
+      // offloading can be supported.
+      bool hasTargetRegion =
+          funcOp
+              ->walk<WalkOrder::PreOrder>(
+                  [&](omp::TargetOp) { return WalkResult::interrupt(); })
+              .wasInterrupted();
+      if (hasTargetRegion)
+        return;
+
+      omp::DeclareTargetDeviceType declareType =
+          omp::DeclareTargetDeviceType::host;
+      auto declareTargetOp =
+          dyn_cast<omp::DeclareTargetInterface>(funcOp.getOperation());
+      if (declareTargetOp && declareTargetOp.isDeclareTarget())
+        declareType = declareTargetOp.getDeclareTargetDeviceType();
+
+      if ((isDeviceCompilation &&
+           declareType == omp::DeclareTargetDeviceType::host) ||
+          (!isDeviceCompilation &&
+           declareType == omp::DeclareTargetDeviceType::nohost))
+        funcOp->erase();
+    });
+  }
+};
+} // namespace
+
+std::unique_ptr<Pass> fir::createOMPFunctionFilteringPass() {
+  return std::make_unique<OMPFunctionFilteringPass>();
+}
diff --git a/flang/test/Lower/OpenMP/function-filtering.f90 b/flang/test/Lower/OpenMP/function-filtering.f90
new file mode 100644 (file)
index 0000000..4386cb4
--- /dev/null
@@ -0,0 +1,44 @@
+! RUN: %flang_fc1 -fopenmp -emit-llvm %s -o - | FileCheck --check-prefixes=LLVM-HOST,LLVM-ALL %s
+! RUN: %flang_fc1 -fopenmp -emit-mlir %s -o - | FileCheck --check-prefix=MLIR-HOST %s
+! RUN: %flang_fc1 -fopenmp -fopenmp-is-target-device -emit-llvm %s -o - | FileCheck --check-prefixes=LLVM-DEVICE,LLVM-ALL %s
+! RUN: %flang_fc1 -fopenmp -fopenmp-is-target-device -emit-mlir %s -o - | FileCheck --check-prefix=MLIR-DEVICE %s
+
+! Check that the correct LLVM IR functions are kept for the host and device
+! after running the whole set of translation and transformation passes from
+! Fortran.
+
+! MLIR-HOST-NOT: func.func @{{.*}}device_fn(
+! MLIR-DEVICE: func.func @{{.*}}device_fn(
+! LLVM-HOST-NOT: define {{.*}} @{{.*}}device_fn{{.*}}(
+! LLVM-DEVICE: define {{.*}} @{{.*}}device_fn{{.*}}(
+function device_fn() result(x)
+  !$omp declare target to(device_fn) device_type(nohost)
+  integer :: x
+  x = 10
+end function device_fn
+
+! MLIR-HOST: func.func @{{.*}}host_fn(
+! MLIR-DEVICE-NOT: func.func @{{.*}}host_fn(
+! LLVM-HOST: define {{.*}} @{{.*}}host_fn{{.*}}(
+! LLVM-DEVICE-NOT: define {{.*}} @{{.*}}host_fn{{.*}}(
+function host_fn() result(x)
+  !$omp declare target to(host_fn) device_type(host)
+  integer :: x
+  x = 10
+end function host_fn
+
+! MLIR-HOST: func.func @{{.*}}target_subr(
+! MLIR-HOST-NOT: func.func @{{.*}}target_subr_omp_outline_0(
+! MLIR-DEVICE-NOT: func.func @{{.*}}target_subr(
+! MLIR-DEVICE: func.func @{{.*}}target_subr_omp_outline_0(
+
+! LLVM-ALL-NOT: define {{.*}} @{{.*}}target_subr_omp_outline_0{{.*}}(
+! LLVM-HOST: define {{.*}} @{{.*}}target_subr{{.*}}(
+! LLVM-DEVICE-NOT: define {{.*}} @{{.*}}target_subr{{.*}}(
+! LLVM-ALL: define {{.*}} @__omp_offloading_{{.*}}_{{.*}}_target_subr__{{.*}}(
+subroutine target_subr(x)
+  integer, intent(out) :: x
+  !$omp target map(from:x)
+    x = 10
+  !$omp end target
+end subroutine target_subr
index 6e197c5..26741c6 100644 (file)
@@ -1,51 +1,52 @@
-!RUN: %flang_fc1 -emit-fir -fopenmp %s -o - | FileCheck %s
+!RUN: %flang_fc1 -emit-fir -fopenmp %s -o - | FileCheck %s --check-prefixes ALL,HOST
+!RUN: %flang_fc1 -emit-fir -fopenmp -fopenmp-is-device %s -o - | FileCheck %s --check-prefixes ALL,DEVICE
 
 ! Check specification valid forms of declare target with functions 
 ! utilising device_type and to clauses as well as the default 
 ! zero clause declare target
 
-! CHECK-LABEL: func.func @_QPfunc_t_device()
-! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget<device_type = (nohost), capture_clause = (to)>{{.*}}
+! DEVICE-LABEL: func.func @_QPfunc_t_device()
+! DEVICE-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget<device_type = (nohost), capture_clause = (to)>{{.*}}
 FUNCTION FUNC_T_DEVICE() RESULT(I)
 !$omp declare target to(FUNC_T_DEVICE) device_type(nohost)
     INTEGER :: I
     I = 1
 END FUNCTION FUNC_T_DEVICE
 
-! CHECK-LABEL: func.func @_QPfunc_t_host()
-! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget<device_type = (host), capture_clause = (to)>{{.*}}
+! HOST-LABEL: func.func @_QPfunc_t_host()
+! HOST-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget<device_type = (host), capture_clause = (to)>{{.*}}
 FUNCTION FUNC_T_HOST() RESULT(I)
 !$omp declare target to(FUNC_T_HOST) device_type(host)
     INTEGER :: I
     I = 1
 END FUNCTION FUNC_T_HOST
 
-! CHECK-LABEL: func.func @_QPfunc_t_any()
-! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget<device_type = (any), capture_clause = (to)>{{.*}}
+! ALL-LABEL: func.func @_QPfunc_t_any()
+! ALL-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget<device_type = (any), capture_clause = (to)>{{.*}}
 FUNCTION FUNC_T_ANY() RESULT(I)
 !$omp declare target to(FUNC_T_ANY) device_type(any)
     INTEGER :: I
     I = 1
 END FUNCTION FUNC_T_ANY
 
-! CHECK-LABEL: func.func @_QPfunc_default_t_any()
-! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget<device_type = (any), capture_clause = (to)>{{.*}}
+! ALL-LABEL: func.func @_QPfunc_default_t_any()
+! ALL-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget<device_type = (any), capture_clause = (to)>{{.*}}
 FUNCTION FUNC_DEFAULT_T_ANY() RESULT(I)
 !$omp declare target to(FUNC_DEFAULT_T_ANY)
     INTEGER :: I
     I = 1
 END FUNCTION FUNC_DEFAULT_T_ANY
 
-! CHECK-LABEL: func.func @_QPfunc_default_any()
-! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget<device_type = (any), capture_clause = (to)>{{.*}}
+! ALL-LABEL: func.func @_QPfunc_default_any()
+! ALL-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget<device_type = (any), capture_clause = (to)>{{.*}}
 FUNCTION FUNC_DEFAULT_ANY() RESULT(I)
 !$omp declare target
     INTEGER :: I
     I = 1
 END FUNCTION FUNC_DEFAULT_ANY
 
-! CHECK-LABEL: func.func @_QPfunc_default_extendedlist()
-! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget<device_type = (any), capture_clause = (to)>{{.*}}
+! ALL-LABEL: func.func @_QPfunc_default_extendedlist()
+! ALL-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget<device_type = (any), capture_clause = (to)>{{.*}}
 FUNCTION FUNC_DEFAULT_EXTENDEDLIST() RESULT(I)
 !$omp declare target(FUNC_DEFAULT_EXTENDEDLIST)
     INTEGER :: I
@@ -58,46 +59,46 @@ END FUNCTION FUNC_DEFAULT_EXTENDEDLIST
 ! utilising device_type and to clauses as well as the default 
 ! zero clause declare target
 
-! CHECK-LABEL: func.func @_QPsubr_t_device()
-! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget<device_type = (nohost), capture_clause = (to)>{{.*}}
+! DEVICE-LABEL: func.func @_QPsubr_t_device()
+! DEVICE-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget<device_type = (nohost), capture_clause = (to)>{{.*}}
 SUBROUTINE SUBR_T_DEVICE()
 !$omp declare target to(SUBR_T_DEVICE) device_type(nohost)
 END
 
-! CHECK-LABEL: func.func @_QPsubr_t_host()
-! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget<device_type = (host), capture_clause = (to)>{{.*}}
+! HOST-LABEL: func.func @_QPsubr_t_host()
+! HOST-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget<device_type = (host), capture_clause = (to)>{{.*}}
 SUBROUTINE SUBR_T_HOST()
 !$omp declare target to(SUBR_T_HOST) device_type(host)
 END
 
-! CHECK-LABEL: func.func @_QPsubr_t_any()
-! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget<device_type = (any), capture_clause = (to)>{{.*}}
+! ALL-LABEL: func.func @_QPsubr_t_any()
+! ALL-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget<device_type = (any), capture_clause = (to)>{{.*}}
 SUBROUTINE SUBR_T_ANY()
 !$omp declare target to(SUBR_T_ANY) device_type(any)
 END
 
-! CHECK-LABEL: func.func @_QPsubr_default_t_any()
-! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget<device_type = (any), capture_clause = (to)>{{.*}}
+! ALL-LABEL: func.func @_QPsubr_default_t_any()
+! ALL-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget<device_type = (any), capture_clause = (to)>{{.*}}
 SUBROUTINE SUBR_DEFAULT_T_ANY()
 !$omp declare target to(SUBR_DEFAULT_T_ANY)
 END
 
-! CHECK-LABEL: func.func @_QPsubr_default_any()
-! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget<device_type = (any), capture_clause = (to)>{{.*}}
+! ALL-LABEL: func.func @_QPsubr_default_any()
+! ALL-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget<device_type = (any), capture_clause = (to)>{{.*}}
 SUBROUTINE SUBR_DEFAULT_ANY()
 !$omp declare target
 END
 
-! CHECK-LABEL: func.func @_QPsubr_default_extendedlist()
-! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget<device_type = (any), capture_clause = (to)>{{.*}}
+! ALL-LABEL: func.func @_QPsubr_default_extendedlist()
+! ALL-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget<device_type = (any), capture_clause = (to)>{{.*}}
 SUBROUTINE SUBR_DEFAULT_EXTENDEDLIST()
 !$omp declare target(SUBR_DEFAULT_EXTENDEDLIST)
 END
 
 !! -----
 
-! CHECK-LABEL: func.func @_QPrecursive_declare_target
-! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget<device_type = (nohost), capture_clause = (to)>{{.*}}
+! DEVICE-LABEL: func.func @_QPrecursive_declare_target
+! DEVICE-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget<device_type = (nohost), capture_clause = (to)>{{.*}}
 RECURSIVE FUNCTION RECURSIVE_DECLARE_TARGET(INCREMENT) RESULT(K)
 !$omp declare target to(RECURSIVE_DECLARE_TARGET) device_type(nohost)
     INTEGER :: INCREMENT, K
index ef39a98..0da76f6 100644 (file)
@@ -1,12 +1,12 @@
-!RUN: %flang_fc1 -emit-fir -fopenmp %s -o - | FileCheck %s 
-!RUN: %flang_fc1 -emit-fir -fopenmp -fopenmp-is-target-device %s -o - | FileCheck %s
+!RUN: %flang_fc1 -emit-fir -fopenmp %s -o - | FileCheck %s --check-prefixes=HOST,ALL
+!RUN: %flang_fc1 -emit-fir -fopenmp -fopenmp-is-target-device %s -o - | FileCheck %s --check-prefix=ALL
 
 PROGRAM main
-    ! CHECK-DAG: %0 = fir.alloca f32 {bindc_name = "i", uniq_name = "_QFEi"}
+    ! HOST-DAG: %0 = fir.alloca f32 {bindc_name = "i", uniq_name = "_QFEi"}
     REAL :: I
-    ! CHECK-DAG: fir.global internal @_QFEi {omp.declare_target = #omp.declaretarget<device_type = (any), capture_clause = (to)>} : f32 {
-    ! CHECK-DAG: %0 = fir.undefined f32
-    ! CHECK-DAG: fir.has_value %0 : f32
-    ! CHECK-DAG: }
+    ! ALL-DAG: fir.global internal @_QFEi {omp.declare_target = #omp.declaretarget<device_type = (any), capture_clause = (to)>} : f32 {
+    ! ALL-DAG: %0 = fir.undefined f32
+    ! ALL-DAG: fir.has_value %0 : f32
+    ! ALL-DAG: }
     !$omp declare target(I)
 END
diff --git a/flang/test/Transforms/omp-function-filtering.mlir b/flang/test/Transforms/omp-function-filtering.mlir
new file mode 100644 (file)
index 0000000..ccb11ca
--- /dev/null
@@ -0,0 +1,111 @@
+// RUN: fir-opt -split-input-file --omp-function-filtering %s | FileCheck %s
+
+// CHECK:     func.func @any
+// CHECK:     func.func @nohost
+// CHECK-NOT: func.func @host
+// CHECK-NOT: func.func @none
+// CHECK:     func.func @nohost_target
+// CHECK:     func.func @host_target
+// CHECK:     func.func @none_target
+module attributes {omp.is_target_device = true} {
+  func.func @any() -> ()
+      attributes {
+        omp.declare_target =
+          #omp.declaretarget<device_type = (any), capture_clause = (to)>
+      } {
+    func.return
+  }
+  func.func @nohost() -> ()
+      attributes {
+        omp.declare_target =
+          #omp.declaretarget<device_type = (nohost), capture_clause = (to)>
+      } {
+    func.return
+  }
+  func.func @host() -> ()
+      attributes {
+        omp.declare_target =
+          #omp.declaretarget<device_type = (host), capture_clause = (to)>
+      } {
+    func.return
+  }
+  func.func @none() -> () {
+    func.return
+  }
+  func.func @nohost_target() -> ()
+      attributes {
+        omp.declare_target =
+          #omp.declaretarget<device_type = (nohost), capture_clause = (to)>
+      } {
+    omp.target {}
+    func.return
+  }
+  func.func @host_target() -> ()
+      attributes {
+        omp.declare_target =
+          #omp.declaretarget<device_type = (host), capture_clause = (to)>
+      } {
+    omp.target {}
+    func.return
+  }
+  func.func @none_target() -> () {
+    omp.target {}
+    func.return
+  }
+}
+
+// -----
+
+// CHECK:     func.func @any
+// CHECK-NOT: func.func @nohost
+// CHECK:     func.func @host
+// CHECK:     func.func @none
+// CHECK:     func.func @nohost_target
+// CHECK:     func.func @host_target
+// CHECK:     func.func @none_target
+module attributes {omp.is_target_device = false} {
+  func.func @any() -> ()
+      attributes {
+        omp.declare_target =
+          #omp.declaretarget<device_type = (any), capture_clause = (to)>
+      } {
+    func.return
+  }
+  func.func @nohost() -> ()
+      attributes {
+          omp.declare_target =
+            #omp.declaretarget<device_type = (nohost), capture_clause = (to)>
+      } {
+    func.return
+  }
+  func.func @host() -> ()
+      attributes {
+        omp.declare_target =
+          #omp.declaretarget<device_type = (host), capture_clause = (to)>
+      } {
+    func.return
+  }
+  func.func @none() -> () {
+    func.return
+  }
+  func.func @nohost_target() -> ()
+      attributes {
+        omp.declare_target =
+          #omp.declaretarget<device_type = (nohost), capture_clause = (to)>
+      } {
+    omp.target {}
+    func.return
+  }
+  func.func @host_target() -> ()
+      attributes {
+        omp.declare_target =
+          #omp.declaretarget<device_type = (host), capture_clause = (to)>
+      } {
+    omp.target {}
+    func.return
+  }
+  func.func @none_target() -> () {
+    omp.target {}
+    func.return
+  }
+}
index 49df49d..efcb918 100644 (file)
@@ -11,6 +11,7 @@
 //
 //===----------------------------------------------------------------------===//
 #include "mlir/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
 #include "mlir/Dialect/OpenMP/OpenMPInterfaces.h"
 #include "mlir/IR/IRMapping.h"
@@ -1667,6 +1668,38 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
   return bodyGenStatus;
 }
 
+static LogicalResult
+convertDeclareTargetAttr(Operation *op,
+                         omp::DeclareTargetAttr declareTargetAttr,
+                         LLVM::ModuleTranslation &moduleTranslation) {
+  // Amend omp.declare_target by deleting the IR of the outlined functions
+  // created for target regions. They cannot be filtered out from MLIR earlier
+  // because the omp.target operation inside must be translated to LLVM, but the
+  // wrapper functions themselves must not remain at the end of the process.
+  // We know that functions where omp.declare_target does not match
+  // omp.is_target_device at this stage can only be wrapper functions because
+  // those that aren't are removed earlier as an MLIR transformation pass.
+  if (FunctionOpInterface funcOp = dyn_cast<FunctionOpInterface>(op)) {
+    if (auto offloadMod = dyn_cast<omp::OffloadModuleInterface>(
+            op->getParentOfType<ModuleOp>().getOperation())) {
+      bool isDeviceCompilation = offloadMod.getIsTargetDevice();
+      omp::DeclareTargetDeviceType declareType =
+          declareTargetAttr.getDeviceType().getValue();
+
+      if ((isDeviceCompilation &&
+           declareType == omp::DeclareTargetDeviceType::host) ||
+          (!isDeviceCompilation &&
+           declareType == omp::DeclareTargetDeviceType::nohost)) {
+        llvm::Function *llvmFunc =
+            moduleTranslation.lookupFunction(funcOp.getName());
+        llvmFunc->dropAllReferences();
+        llvmFunc->eraseFromParent();
+      }
+    }
+  }
+  return success();
+}
+
 namespace {
 
 /// Implementation of the dialect interface that converts operations belonging
@@ -1694,7 +1727,6 @@ public:
 LogicalResult OpenMPDialectLLVMIRTranslationInterface::amendOperation(
     Operation *op, NamedAttribute attribute,
     LLVM::ModuleTranslation &moduleTranslation) const {
-
   return llvm::TypeSwitch<Attribute, LogicalResult>(attribute.getValue())
       .Case([&](mlir::omp::FlagsAttr rtlAttr) {
         return convertFlagsAttr(op, rtlAttr, moduleTranslation);
@@ -1706,6 +1738,10 @@ LogicalResult OpenMPDialectLLVMIRTranslationInterface::amendOperation(
                                     versionAttr.getVersion());
         return success();
       })
+      .Case([&](mlir::omp::DeclareTargetAttr declareTargetAttr) {
+        return convertDeclareTargetAttr(op, declareTargetAttr,
+                                        moduleTranslation);
+      })
       .Default([&](Attribute attr) {
         // fall through for omp attributes that do not require lowering and/or
         // have no concrete definition and thus no type to define a case on
index a121538..bee77bb 100644 (file)
@@ -2,7 +2,7 @@
 // name stored in the omp.outline_parent_name attribute.
 // RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
 
-module attributes {omp.is_device = true} {
+module attributes {omp.is_target_device = true} {
   llvm.func @writeindex_omp_outline_0_(%arg0: !llvm.ptr<i32>, %arg1: !llvm.ptr<i32>) attributes {omp.outline_parent_name = "writeindex_"} {
     omp.target   map((from -> %arg0 : !llvm.ptr<i32>), (implicit -> %arg1: !llvm.ptr<i32>)) {
       %0 = llvm.mlir.constant(20 : i32) : i32
index 89a4578..15eb0b3 100644 (file)
@@ -2543,3 +2543,47 @@ module attributes {omp.flags = #omp.flags<debug_kind = 0, assume_teams_oversubsc
 // CHECK: @__omp_rtl_assume_no_thread_state = weak_odr hidden constant i32 1
 // CHECK: @__omp_rtl_assume_no_nested_parallelism = weak_odr hidden constant i32 0
 module attributes {omp.flags = #omp.flags<assume_teams_oversubscription = true, assume_no_thread_state = true>} {}
+
+// -----
+
+module attributes {omp.is_target_device = false} {
+  // CHECK-NOT: @filter_host_nohost
+  llvm.func @filter_host_nohost() -> ()
+      attributes {
+        omp.declare_target =
+          #omp.declaretarget<device_type = (nohost), capture_clause = (to)>
+      } {
+    llvm.return
+  }
+
+  // CHECK: @filter_host_host
+  llvm.func @filter_host_host() -> ()
+      attributes {
+        omp.declare_target =
+          #omp.declaretarget<device_type = (host), capture_clause = (to)>
+      } {
+    llvm.return
+  }
+}
+
+// -----
+
+module attributes {omp.is_target_device = true} {
+  // CHECK: @filter_device_nohost
+  llvm.func @filter_device_nohost() -> ()
+      attributes {
+        omp.declare_target =
+          #omp.declaretarget<device_type = (nohost), capture_clause = (to)>
+      } {
+    llvm.return
+  }
+
+  // CHECK-NOT: @filter_device_host
+  llvm.func @filter_device_host() -> ()
+      attributes {
+        omp.declare_target =
+          #omp.declaretarget<device_type = (host), capture_clause = (to)>
+      } {
+    llvm.return
+  }
+}