[MLIR][OpenMP] Add support for threadprivate directive
authorPeixin-Qiao <qiaopeixin@huawei.com>
Tue, 12 Apr 2022 02:15:15 +0000 (10:15 +0800)
committerPeixin-Qiao <qiaopeixin@huawei.com>
Tue, 12 Apr 2022 02:15:15 +0000 (10:15 +0800)
This supports the threadprivate directive in OpenMP dialect following
the OpenMP 5.1 [2.21.2] standard. Also lowering to LLVM IR using OpenMP
IRBduiler.

Reviewed By: kiranchandramohan, shraiysh, arnamoy10

Differential Revision: https://reviews.llvm.org/D123350

mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
mlir/test/Dialect/OpenMP/ops.mlir
mlir/test/Target/LLVMIR/openmp-llvm-invalid.mlir
mlir/test/Target/LLVMIR/openmp-llvm.mlir

index 4444ad2..fab8c2c 100644 (file)
@@ -883,6 +883,36 @@ def AtomicCaptureOp : OpenMP_Op<"atomic.capture",
 }
 
 //===----------------------------------------------------------------------===//
+// [5.1] 2.21.2 threadprivate Directive
+//===----------------------------------------------------------------------===//
+
+def ThreadprivateOp : OpenMP_Op<"threadprivate"> {
+  let summary = "threadprivate directive";
+  let description = [{
+    The threadprivate directive specifies that variables are replicated, with
+    each thread having its own copy.
+
+    The current implementation uses the OpenMP runtime to provide thread-local
+    storage (TLS). Using the TLS feature of the LLVM IR will be supported in
+    future.
+
+    This operation takes in the address of a symbol that represents the original
+    variable and returns the address of its TLS. All occurrences of
+    threadprivate variables in a parallel region should use the TLS returned by
+    this operation.
+
+    The `sym_addr` refers to the address of the symbol, which is a pointer to
+    the original variable.
+  }];
+
+  let arguments = (ins OpenMP_PointerLikeType:$sym_addr);
+  let results = (outs OpenMP_PointerLikeType:$tls_addr);
+  let assemblyFormat = [{
+    $sym_addr `:` type($sym_addr) `->` type($tls_addr) attr-dict
+  }];
+}
+
+//===----------------------------------------------------------------------===//
 // 2.19.5.7 declare reduction Directive
 //===----------------------------------------------------------------------===//
 
index f8fff69..e6ec8fd 100644 (file)
@@ -1319,6 +1319,40 @@ convertOmpReductionOp(omp::ReductionOp reductionOp,
   return success();
 }
 
+/// Converts an OpenMP Threadprivate operation into LLVM IR using
+/// OpenMPIRBuilder.
+static LogicalResult
+convertOmpThreadprivate(Operation &opInst, llvm::IRBuilderBase &builder,
+                        LLVM::ModuleTranslation &moduleTranslation) {
+  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
+  auto threadprivateOp = cast<omp::ThreadprivateOp>(opInst);
+
+  Value symAddr = threadprivateOp.sym_addr();
+  auto symOp = symAddr.getDefiningOp();
+  if (!isa<LLVM::AddressOfOp>(symOp))
+    return opInst.emitError("Addressing symbol not found");
+  LLVM::AddressOfOp addressOfOp = dyn_cast<LLVM::AddressOfOp>(symOp);
+
+  LLVM::GlobalOp global = addressOfOp.getGlobal();
+  llvm::GlobalValue *globalValue = moduleTranslation.lookupGlobal(global);
+  llvm::Value *data =
+      builder.CreateBitCast(globalValue, builder.getInt8PtrTy());
+  llvm::Type *type = globalValue->getValueType();
+  llvm::TypeSize typeSize =
+      builder.GetInsertBlock()->getModule()->getDataLayout().getTypeStoreSize(
+          type);
+  llvm::ConstantInt *size = builder.getInt64(typeSize.getFixedSize());
+  llvm::StringRef suffix = llvm::StringRef(".cache", 6);
+  llvm::Twine cacheName = Twine(global.getSymName()).concat(suffix);
+  // Emit runtime function and bitcast its type (i8*) to real data type.
+  llvm::Value *callInst =
+      moduleTranslation.getOpenMPBuilder()->createCachedThreadPrivate(
+          ompLoc, data, size, cacheName);
+  llvm::Value *result = builder.CreateBitCast(callInst, globalValue->getType());
+  moduleTranslation.mapValue(opInst.getResult(0), result);
+  return success();
+}
+
 namespace {
 
 /// Implementation of the dialect interface that converts operations belonging
@@ -1424,6 +1458,9 @@ LogicalResult OpenMPDialectLLVMIRTranslationInterface::convertOperation(
         // name for critical section names.
         return success();
       })
+      .Case([&](omp::ThreadprivateOp) {
+        return convertOmpThreadprivate(*op, builder, moduleTranslation);
+      })
       .Default([&](Operation *inst) {
         return inst->emitError("unsupported OpenMP operation: ")
                << inst->getName();
index 9149feb..7dcbaba 100644 (file)
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s | mlir-opt | FileCheck %s
+// RUN: mlir-opt -split-input-file %s | mlir-opt | FileCheck %s
 
 func @omp_barrier() -> () {
   // CHECK: omp.barrier
@@ -896,3 +896,29 @@ func @omp_single_allocate_nowait(%data_var: memref<i32>) {
   }
   return
 }
+
+// -----
+
+func @omp_threadprivate() {
+  %0 = arith.constant 1 : i32
+  %1 = arith.constant 2 : i32
+  %2 = arith.constant 3 : i32
+
+  // CHECK: [[ARG0:%.*]] = llvm.mlir.addressof @_QFsubEx : !llvm.ptr<i32>
+  // CHECK: {{.*}} = omp.threadprivate [[ARG0]] : !llvm.ptr<i32> -> !llvm.ptr<i32>
+  %3 = llvm.mlir.addressof @_QFsubEx : !llvm.ptr<i32>
+  %4 = omp.threadprivate %3 : !llvm.ptr<i32> -> !llvm.ptr<i32>
+  llvm.store %0, %4 : !llvm.ptr<i32>
+
+  // CHECK:  omp.parallel
+  // CHECK:    {{.*}} = omp.threadprivate [[ARG0]] : !llvm.ptr<i32> -> !llvm.ptr<i32>
+  omp.parallel  {
+    %5 = omp.threadprivate %3 : !llvm.ptr<i32> -> !llvm.ptr<i32>
+    llvm.store %1, %5 : !llvm.ptr<i32>
+    omp.terminator
+  }
+  llvm.store %2, %4 : !llvm.ptr<i32>
+  return
+}
+
+llvm.mlir.global internal @_QFsubEx() : i32
index 171db04..40b0588 100644 (file)
@@ -67,3 +67,29 @@ llvm.func @omp_atomic_update_multiple_step_update(%x: !llvm.ptr<i32>, %v: !llvm.
   }
   llvm.return
 }
+
+// -----
+
+llvm.func @omp_threadprivate() {
+  %0 = llvm.mlir.constant(1 : i64) : i64
+  %1 = llvm.mlir.constant(1 : i32) : i32
+  %2 = llvm.mlir.constant(2 : i32) : i32
+  %3 = llvm.mlir.constant(3 : i32) : i32
+
+  %4 = llvm.alloca %0 x i32 {in_type = i32, name = "a"} : (i64) -> !llvm.ptr<i32>
+
+  // expected-error @below {{Addressing symbol not found}}
+  // expected-error @below {{LLVM Translation failed for operation: omp.threadprivate}}
+  %5 = omp.threadprivate %4 : !llvm.ptr<i32> -> !llvm.ptr<i32>
+
+  llvm.store %1, %5 : !llvm.ptr<i32>
+
+  omp.parallel  {
+    %6 = omp.threadprivate %4 : !llvm.ptr<i32> -> !llvm.ptr<i32>
+    llvm.store %2, %6 : !llvm.ptr<i32>
+    omp.terminator
+  }
+
+  llvm.store %3, %5 : !llvm.ptr<i32>
+  llvm.return
+}
index 9186551..86f98fd 100644 (file)
@@ -2050,3 +2050,43 @@ llvm.func @single_nowait(%x: i32, %y: i32, %zaddr: !llvm.ptr<i32>) {
   // CHECK: ret void
   llvm.return
 }
+
+// -----
+
+// CHECK: @_QFsubEx = internal global i32 undef
+// CHECK: @_QFsubEx.cache = common global i8** null
+
+// CHECK-LABEL: @omp_threadprivate
+llvm.func @omp_threadprivate() {
+// CHECK:  [[THREAD:%.*]] = call i32 @__kmpc_global_thread_num(%struct.ident_t* @[[GLOB:[0-9]+]])
+// CHECK:  [[TMP1:%.*]] = call i8* @__kmpc_threadprivate_cached(%struct.ident_t* @[[GLOB]], i32 [[THREAD]], i8* bitcast (i32* @_QFsubEx to i8*), i64 4, i8*** @_QFsubEx.cache)
+// CHECK:  [[TMP2:%.*]] = bitcast i8* [[TMP1]] to i32*
+// CHECK:  store i32 1, i32* [[TMP2]], align 4
+// CHECK:  store i32 3, i32* [[TMP2]], align 4
+
+// CHECK-LABEL: omp.par.region{{.*}}
+// CHECK:  [[THREAD2:%.*]] = call i32 @__kmpc_global_thread_num(%struct.ident_t* @[[GLOB2:[0-9]+]])
+// CHECK:  [[TMP3:%.*]] = call i8* @__kmpc_threadprivate_cached(%struct.ident_t* @[[GLOB2]], i32 [[THREAD2]], i8* bitcast (i32* @_QFsubEx to i8*), i64 4, i8*** @_QFsubEx.cache)
+// CHECK:  [[TMP4:%.*]] = bitcast i8* [[TMP3]] to i32*
+// CHECK:  store i32 2, i32* [[TMP4]], align 4
+
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  %1 = llvm.mlir.constant(2 : i32) : i32
+  %2 = llvm.mlir.constant(3 : i32) : i32
+
+  %3 = llvm.mlir.addressof @_QFsubEx : !llvm.ptr<i32>
+  %4 = omp.threadprivate %3 : !llvm.ptr<i32> -> !llvm.ptr<i32>
+
+  llvm.store %0, %4 : !llvm.ptr<i32>
+
+  omp.parallel  {
+    %5 = omp.threadprivate %3 : !llvm.ptr<i32> -> !llvm.ptr<i32>
+    llvm.store %1, %5 : !llvm.ptr<i32>
+    omp.terminator
+  }
+
+  llvm.store %2, %4 : !llvm.ptr<i32>
+  llvm.return
+}
+
+llvm.mlir.global internal @_QFsubEx() : i32