[mlir][OpenMP] Translation to LLVM IR for omp.taskgroup
authorShraiysh Vaishay <shraiysh@gmail.com>
Wed, 31 Aug 2022 04:34:24 +0000 (04:34 +0000)
committerShraiysh <shraiysh@gmail.com>
Wed, 31 Aug 2022 04:55:01 +0000 (04:55 +0000)
This patch adds translation from OpenMP Dialect to LLVM IR for
omp.taskgroup. This patch also adds missing tests for the clauses in
omp.taskgroup operation.

Reviewed By: peixin

Differential Revision: https://reviews.llvm.org/D130157

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
mlir/test/Dialect/OpenMP/ops.mlir
mlir/test/Target/LLVMIR/openmp-llvm.mlir

index d390eae..2707567 100644 (file)
@@ -701,6 +701,27 @@ convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
   return bodyGenStatus;
 }
 
+/// Converts an OpenMP taskgroup construct into LLVM IR using OpenMPIRBuilder.
+static LogicalResult
+convertOmpTaskgroupOp(omp::TaskGroupOp tgOp, llvm::IRBuilderBase &builder,
+                      LLVM::ModuleTranslation &moduleTranslation) {
+  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
+  LogicalResult bodyGenStatus = success();
+  if (!tgOp.task_reduction_vars().empty() || !tgOp.allocate_vars().empty()) {
+    return tgOp.emitError("unhandled clauses for translation to LLVM IR");
+  }
+  auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
+    builder.restoreIP(codegenIP);
+    convertOmpOpRegions(tgOp.region(), "omp.taskgroup.region", builder,
+                        moduleTranslation, bodyGenStatus);
+  };
+  InsertPointTy allocaIP = findAllocaInsertPoint(builder, moduleTranslation);
+  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
+  builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createTaskgroup(
+      ompLoc, allocaIP, bodyCB));
+  return bodyGenStatus;
+}
+
 /// Converts an OpenMP workshare loop into LLVM IR using OpenMPIRBuilder.
 static LogicalResult
 convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder,
@@ -1406,6 +1427,9 @@ LogicalResult OpenMPDialectLLVMIRTranslationInterface::convertOperation(
       .Case([&](omp::TaskOp op) {
         return convertOmpTaskOp(op, builder, moduleTranslation);
       })
+      .Case([&](omp::TaskGroupOp op) {
+        return convertOmpTaskgroupOp(op, builder, moduleTranslation);
+      })
       .Case<omp::YieldOp, omp::TerminatorOp, omp::ReductionDeclareOp,
             omp::CriticalDeclareOp>([](auto op) {
         // `yield` and `terminator` can be just omitted. The block structure
index 4cc8fd5..51b7dba 100644 (file)
@@ -1572,6 +1572,30 @@ func.func @omp_taskgroup_multiple_tasks() -> () {
   return
 }
 
+// CHECK-LABEL: @omp_taskgroup_clauses
+func.func @omp_taskgroup_clauses() -> () {
+  %testmemref = "test.memref"() : () -> (memref<i32>)
+  %testf32 = "test.f32"() : () -> (!llvm.ptr<f32>)
+  // CHECK: omp.taskgroup task_reduction(@add_f32 -> %{{.+}}: !llvm.ptr<f32>) allocate(%{{.+}}: memref<i32> -> %{{.+}}: memref<i32>)
+  omp.taskgroup allocate(%testmemref : memref<i32> -> %testmemref : memref<i32>) task_reduction(@add_f32 -> %testf32 : !llvm.ptr<f32>) {
+    // CHECK: omp.task
+    omp.task {
+      "test.foo"() : () -> ()
+      // CHECK: omp.terminator
+      omp.terminator
+    }
+    // CHECK: omp.task
+    omp.task {
+      "test.foo"() : () -> ()
+      // CHECK: omp.terminator
+      omp.terminator
+    }
+    // CHECK: omp.terminator
+    omp.terminator
+  }
+  return
+}
+
 // CHECK-LABEL: @omp_taskloop
 func.func @omp_taskloop(%lb: i32, %ub: i32, %step: i32) -> () {
 
index 787575a..da6f99c 100644 (file)
@@ -2357,3 +2357,98 @@ module attributes {llvm.target_triple = "x86_64-unknown-linux-gnu"} {
 // CHECK:   call void @[[outlined_fn]](ptr %[[task_data]])
 // CHECK:   ret i32 0
 // CHECK: }
+
+// -----
+
+llvm.func @foo() -> ()
+
+llvm.func @omp_taskgroup(%x: i32, %y: i32, %zaddr: !llvm.ptr<i32>) {
+  omp.taskgroup {
+    llvm.call @foo() : () -> ()
+    omp.terminator
+  }
+  llvm.return
+}
+
+// CHECK-LABEL: define void @omp_taskgroup(
+// CHECK-SAME:                             i32 %[[x:.+]], i32 %[[y:.+]], ptr %[[zaddr:.+]]) 
+// CHECK:         br label %[[entry:[^,]+]]
+// CHECK:       [[entry]]:
+// CHECK:         %[[omp_global_thread_num:.+]] = call i32 @__kmpc_global_thread_num(ptr @{{.+}})
+// CHECK:         call void @__kmpc_taskgroup(ptr @{{.+}}, i32 %[[omp_global_thread_num]])
+// CHECK:         br label %[[omp_taskgroup_region:[^,]+]]
+// CHECK:       [[omp_taskgroup_region]]:
+// CHECK:         call void @foo()
+// CHECK:         br label %[[omp_region_cont:[^,]+]]
+// CHECK:       [[omp_region_cont]]:
+// CHECK:         br label %[[taskgroup_exit:[^,]+]]
+// CHECK:       [[taskgroup_exit]]:
+// CHECK:         call void @__kmpc_end_taskgroup(ptr @{{.+}}, i32 %[[omp_global_thread_num]])
+// CHECK:         ret void
+
+// -----
+
+llvm.func @foo() -> ()
+llvm.func @bar(i32, i32, !llvm.ptr<i32>) -> ()
+
+llvm.func @omp_taskgroup_task(%x: i32, %y: i32, %zaddr: !llvm.ptr<i32>) {
+  omp.taskgroup {
+    %c1 = llvm.mlir.constant(1) : i32
+    %ptr1 = llvm.alloca %c1 x i8 : (i32) -> !llvm.ptr<i8>
+    omp.task {
+      llvm.call @foo() : () -> ()
+      omp.terminator
+    }
+    omp.task {
+      llvm.call @bar(%x, %y, %zaddr) : (i32, i32, !llvm.ptr<i32>) -> ()
+      omp.terminator
+    }
+    llvm.br ^bb1
+  ^bb1:
+    llvm.call @foo() : () -> ()
+    omp.terminator
+  }
+  llvm.return
+}
+
+// CHECK-LABEL: define void @omp_taskgroup_task(
+// CHECK-SAME:                                  i32 %[[x:.+]], i32 %[[y:.+]], ptr %[[zaddr:.+]]) 
+// CHECK:         %[[structArg:.+]] = alloca { i32, i32, ptr }, align 8
+// CHECK:         br label %[[entry:[^,]+]]
+// CHECK:       [[entry]]:                                            ; preds = %3
+// CHECK:         %[[omp_global_thread_num:.+]] = call i32 @__kmpc_global_thread_num(ptr @{{.+}})
+// CHECK:         call void @__kmpc_taskgroup(ptr @{{.+}}, i32 %[[omp_global_thread_num]])
+// CHECK:         br label %[[omp_taskgroup_region:[^,]+]]
+// CHECK:       [[omp_taskgroup_region1:.+]]:
+// CHECK:         call void @foo()
+// CHECK:         br label %[[omp_region_cont:[^,]+]]
+// CHECK:       [[omp_taskgroup_region]]:
+// CHECK:         %{{.+}} = alloca i8, align 1
+// CHECK:         br label %[[codeRepl:[^,]+]]
+// CHECK:       [[codeRepl]]:
+// CHECK:         %[[omp_global_thread_num_t1:.+]] = call i32 @__kmpc_global_thread_num(ptr @{{.+}})
+// CHECK:         %[[t1_alloc:.+]] = call ptr @__kmpc_omp_task_alloc(ptr @{{.+}}, i32 %[[omp_global_thread_num_t1]], i32 1, i64 0, i64 0, ptr @omp_taskgroup_task..omp_par.wrapper)
+// CHECK:         %{{.+}} = call i32 @__kmpc_omp_task(ptr @{{.+}}, i32 %[[omp_global_thread_num_t1]], ptr %[[t1_alloc]])
+// CHECK:         br label %[[task_exit:[^,]+]]
+// CHECK:       [[task_exit]]:
+// CHECK:         br label %[[codeRepl9:[^,]+]]
+// CHECK:       [[codeRepl9]]:
+// CHECK:         %[[gep1:.+]] = getelementptr { i32, i32, ptr }, ptr %[[structArg]], i32 0, i32 0
+// CHECK:         store i32 %[[x]], ptr %[[gep1]], align 4
+// CHECK:         %[[gep2:.+]] = getelementptr { i32, i32, ptr }, ptr %[[structArg]], i32 0, i32 1
+// CHECK:         store i32 %[[y]], ptr %[[gep2]], align 4
+// CHECK:         %[[gep3:.+]] = getelementptr { i32, i32, ptr }, ptr %[[structArg]], i32 0, i32 2
+// CHECK:         store ptr %[[zaddr]], ptr %[[gep3]], align 8
+// CHECK:         %[[omp_global_thread_num_t2:.+]] = call i32 @__kmpc_global_thread_num(ptr @{{.+}})
+// CHECK:         %[[t2_alloc:.+]] = call ptr @__kmpc_omp_task_alloc(ptr @{{.+}}, i32 %[[omp_global_thread_num_t2]], i32 1, i64 16, i64 0, ptr @omp_taskgroup_task..omp_par.1.wrapper)
+// CHECK:         call void @llvm.memcpy.p0.p0.i64(ptr align 8 %[[t2_alloc]], ptr align 8 %[[structArg]], i64 16, i1 false)
+// CHECK:         %{{.+}} = call i32 @__kmpc_omp_task(ptr @{{.+}}, i32 %[[omp_global_thread_num_t2]], ptr %[[t2_alloc]])
+// CHECK:         br label %[[task_exit3:[^,]+]]
+// CHECK:       [[task_exit3]]:
+// CHECK:         br label %[[omp_taskgroup_region1]]
+// CHECK:       [[omp_region_cont]]:
+// CHECK:         br label %[[taskgroup_exit:[^,]+]]
+// CHECK:       [[taskgroup_exit]]:
+// CHECK:         call void @__kmpc_end_taskgroup(ptr @{{.+}}, i32 %[[omp_global_thread_num]])
+// CHECK:         ret void
+// CHECK:       }