From 139c5b8a96342090efaee73840c4a5ea72ecadf3 Mon Sep 17 00:00:00 2001 From: Peixin-Qiao Date: Tue, 12 Apr 2022 10:15:15 +0800 Subject: [PATCH] [MLIR][OpenMP] Add support for threadprivate directive This supports the threadprivate directive in OpenMP dialect following the OpenMP 5.1 [2.21.2] standard. Also lowering to LLVM IR using OpenMP IRBduiler. Reviewed By: kiranchandramohan, shraiysh, arnamoy10 Differential Revision: https://reviews.llvm.org/D123350 --- mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td | 30 ++++++++++++++++ .../Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp | 37 ++++++++++++++++++++ mlir/test/Dialect/OpenMP/ops.mlir | 28 ++++++++++++++- mlir/test/Target/LLVMIR/openmp-llvm-invalid.mlir | 26 ++++++++++++++ mlir/test/Target/LLVMIR/openmp-llvm.mlir | 40 ++++++++++++++++++++++ 5 files changed, 160 insertions(+), 1 deletion(-) diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td index 4444ad2..fab8c2c 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td @@ -883,6 +883,36 @@ def AtomicCaptureOp : OpenMP_Op<"atomic.capture", } //===----------------------------------------------------------------------===// +// [5.1] 2.21.2 threadprivate Directive +//===----------------------------------------------------------------------===// + +def ThreadprivateOp : OpenMP_Op<"threadprivate"> { + let summary = "threadprivate directive"; + let description = [{ + The threadprivate directive specifies that variables are replicated, with + each thread having its own copy. + + The current implementation uses the OpenMP runtime to provide thread-local + storage (TLS). Using the TLS feature of the LLVM IR will be supported in + future. + + This operation takes in the address of a symbol that represents the original + variable and returns the address of its TLS. All occurrences of + threadprivate variables in a parallel region should use the TLS returned by + this operation. + + The `sym_addr` refers to the address of the symbol, which is a pointer to + the original variable. + }]; + + let arguments = (ins OpenMP_PointerLikeType:$sym_addr); + let results = (outs OpenMP_PointerLikeType:$tls_addr); + let assemblyFormat = [{ + $sym_addr `:` type($sym_addr) `->` type($tls_addr) attr-dict + }]; +} + +//===----------------------------------------------------------------------===// // 2.19.5.7 declare reduction Directive //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index f8fff69..e6ec8fd 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -1319,6 +1319,40 @@ convertOmpReductionOp(omp::ReductionOp reductionOp, return success(); } +/// Converts an OpenMP Threadprivate operation into LLVM IR using +/// OpenMPIRBuilder. +static LogicalResult +convertOmpThreadprivate(Operation &opInst, llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation) { + llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder); + auto threadprivateOp = cast(opInst); + + Value symAddr = threadprivateOp.sym_addr(); + auto symOp = symAddr.getDefiningOp(); + if (!isa(symOp)) + return opInst.emitError("Addressing symbol not found"); + LLVM::AddressOfOp addressOfOp = dyn_cast(symOp); + + LLVM::GlobalOp global = addressOfOp.getGlobal(); + llvm::GlobalValue *globalValue = moduleTranslation.lookupGlobal(global); + llvm::Value *data = + builder.CreateBitCast(globalValue, builder.getInt8PtrTy()); + llvm::Type *type = globalValue->getValueType(); + llvm::TypeSize typeSize = + builder.GetInsertBlock()->getModule()->getDataLayout().getTypeStoreSize( + type); + llvm::ConstantInt *size = builder.getInt64(typeSize.getFixedSize()); + llvm::StringRef suffix = llvm::StringRef(".cache", 6); + llvm::Twine cacheName = Twine(global.getSymName()).concat(suffix); + // Emit runtime function and bitcast its type (i8*) to real data type. + llvm::Value *callInst = + moduleTranslation.getOpenMPBuilder()->createCachedThreadPrivate( + ompLoc, data, size, cacheName); + llvm::Value *result = builder.CreateBitCast(callInst, globalValue->getType()); + moduleTranslation.mapValue(opInst.getResult(0), result); + return success(); +} + namespace { /// Implementation of the dialect interface that converts operations belonging @@ -1424,6 +1458,9 @@ LogicalResult OpenMPDialectLLVMIRTranslationInterface::convertOperation( // name for critical section names. return success(); }) + .Case([&](omp::ThreadprivateOp) { + return convertOmpThreadprivate(*op, builder, moduleTranslation); + }) .Default([&](Operation *inst) { return inst->emitError("unsupported OpenMP operation: ") << inst->getName(); diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir index 9149feb..7dcbaba 100644 --- a/mlir/test/Dialect/OpenMP/ops.mlir +++ b/mlir/test/Dialect/OpenMP/ops.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s | mlir-opt | FileCheck %s +// RUN: mlir-opt -split-input-file %s | mlir-opt | FileCheck %s func @omp_barrier() -> () { // CHECK: omp.barrier @@ -896,3 +896,29 @@ func @omp_single_allocate_nowait(%data_var: memref) { } return } + +// ----- + +func @omp_threadprivate() { + %0 = arith.constant 1 : i32 + %1 = arith.constant 2 : i32 + %2 = arith.constant 3 : i32 + + // CHECK: [[ARG0:%.*]] = llvm.mlir.addressof @_QFsubEx : !llvm.ptr + // CHECK: {{.*}} = omp.threadprivate [[ARG0]] : !llvm.ptr -> !llvm.ptr + %3 = llvm.mlir.addressof @_QFsubEx : !llvm.ptr + %4 = omp.threadprivate %3 : !llvm.ptr -> !llvm.ptr + llvm.store %0, %4 : !llvm.ptr + + // CHECK: omp.parallel + // CHECK: {{.*}} = omp.threadprivate [[ARG0]] : !llvm.ptr -> !llvm.ptr + omp.parallel { + %5 = omp.threadprivate %3 : !llvm.ptr -> !llvm.ptr + llvm.store %1, %5 : !llvm.ptr + omp.terminator + } + llvm.store %2, %4 : !llvm.ptr + return +} + +llvm.mlir.global internal @_QFsubEx() : i32 diff --git a/mlir/test/Target/LLVMIR/openmp-llvm-invalid.mlir b/mlir/test/Target/LLVMIR/openmp-llvm-invalid.mlir index 171db04..40b0588 100644 --- a/mlir/test/Target/LLVMIR/openmp-llvm-invalid.mlir +++ b/mlir/test/Target/LLVMIR/openmp-llvm-invalid.mlir @@ -67,3 +67,29 @@ llvm.func @omp_atomic_update_multiple_step_update(%x: !llvm.ptr, %v: !llvm. } llvm.return } + +// ----- + +llvm.func @omp_threadprivate() { + %0 = llvm.mlir.constant(1 : i64) : i64 + %1 = llvm.mlir.constant(1 : i32) : i32 + %2 = llvm.mlir.constant(2 : i32) : i32 + %3 = llvm.mlir.constant(3 : i32) : i32 + + %4 = llvm.alloca %0 x i32 {in_type = i32, name = "a"} : (i64) -> !llvm.ptr + + // expected-error @below {{Addressing symbol not found}} + // expected-error @below {{LLVM Translation failed for operation: omp.threadprivate}} + %5 = omp.threadprivate %4 : !llvm.ptr -> !llvm.ptr + + llvm.store %1, %5 : !llvm.ptr + + omp.parallel { + %6 = omp.threadprivate %4 : !llvm.ptr -> !llvm.ptr + llvm.store %2, %6 : !llvm.ptr + omp.terminator + } + + llvm.store %3, %5 : !llvm.ptr + llvm.return +} diff --git a/mlir/test/Target/LLVMIR/openmp-llvm.mlir b/mlir/test/Target/LLVMIR/openmp-llvm.mlir index 9186551..86f98fd 100644 --- a/mlir/test/Target/LLVMIR/openmp-llvm.mlir +++ b/mlir/test/Target/LLVMIR/openmp-llvm.mlir @@ -2050,3 +2050,43 @@ llvm.func @single_nowait(%x: i32, %y: i32, %zaddr: !llvm.ptr) { // CHECK: ret void llvm.return } + +// ----- + +// CHECK: @_QFsubEx = internal global i32 undef +// CHECK: @_QFsubEx.cache = common global i8** null + +// CHECK-LABEL: @omp_threadprivate +llvm.func @omp_threadprivate() { +// CHECK: [[THREAD:%.*]] = call i32 @__kmpc_global_thread_num(%struct.ident_t* @[[GLOB:[0-9]+]]) +// CHECK: [[TMP1:%.*]] = call i8* @__kmpc_threadprivate_cached(%struct.ident_t* @[[GLOB]], i32 [[THREAD]], i8* bitcast (i32* @_QFsubEx to i8*), i64 4, i8*** @_QFsubEx.cache) +// CHECK: [[TMP2:%.*]] = bitcast i8* [[TMP1]] to i32* +// CHECK: store i32 1, i32* [[TMP2]], align 4 +// CHECK: store i32 3, i32* [[TMP2]], align 4 + +// CHECK-LABEL: omp.par.region{{.*}} +// CHECK: [[THREAD2:%.*]] = call i32 @__kmpc_global_thread_num(%struct.ident_t* @[[GLOB2:[0-9]+]]) +// CHECK: [[TMP3:%.*]] = call i8* @__kmpc_threadprivate_cached(%struct.ident_t* @[[GLOB2]], i32 [[THREAD2]], i8* bitcast (i32* @_QFsubEx to i8*), i64 4, i8*** @_QFsubEx.cache) +// CHECK: [[TMP4:%.*]] = bitcast i8* [[TMP3]] to i32* +// CHECK: store i32 2, i32* [[TMP4]], align 4 + + %0 = llvm.mlir.constant(1 : i32) : i32 + %1 = llvm.mlir.constant(2 : i32) : i32 + %2 = llvm.mlir.constant(3 : i32) : i32 + + %3 = llvm.mlir.addressof @_QFsubEx : !llvm.ptr + %4 = omp.threadprivate %3 : !llvm.ptr -> !llvm.ptr + + llvm.store %0, %4 : !llvm.ptr + + omp.parallel { + %5 = omp.threadprivate %3 : !llvm.ptr -> !llvm.ptr + llvm.store %1, %5 : !llvm.ptr + omp.terminator + } + + llvm.store %2, %4 : !llvm.ptr + llvm.return +} + +llvm.mlir.global internal @_QFsubEx() : i32 -- 2.7.4