From 30bd11fab47f75e43ba9d0133978d964eef819ca Mon Sep 17 00:00:00 2001 From: Shraiysh Vaishay Date: Thu, 28 Oct 2021 11:04:40 +0530 Subject: [PATCH] [MLIR][OpenMP] Fixed the missing inclusive clause in omp.wsloop and fix order clause This patch adds the inclusive clause (which was missed in previous reorganization - https://reviews.llvm.org/D110903) in omp.wsloop operation. Added a test for validating it. Also fixes the order clause, which was not accepting any values. It now accepts "concurrent" as a value, as specified in the standard. Reviewed By: kiranchandramohan, peixin, clementval Differential Revision: https://reviews.llvm.org/D112198 --- llvm/include/llvm/Frontend/OpenMP/OMP.td | 4 +- llvm/unittests/Frontend/OpenMPParsingTest.cpp | 5 +- mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 35 +++---- mlir/test/Dialect/OpenMP/invalid.mlir | 129 +++++++++++++++++++++++++- mlir/test/Dialect/OpenMP/ops.mlir | 37 +++++++- 5 files changed, 188 insertions(+), 22 deletions(-) diff --git a/llvm/include/llvm/Frontend/OpenMP/OMP.td b/llvm/include/llvm/Frontend/OpenMP/OMP.td index fffd8d7..5fd3041 100644 --- a/llvm/include/llvm/Frontend/OpenMP/OMP.td +++ b/llvm/include/llvm/Frontend/OpenMP/OMP.td @@ -285,11 +285,13 @@ def OMPC_NonTemporal : Clause<"nontemporal"> { let isValueList = true; } -def OMP_ORDER_concurrent : ClauseVal<"default",2,0> { let isDefault = 1; } +def OMP_ORDER_concurrent : ClauseVal<"concurrent",1,1> {} +def OMP_ORDER_unknown : ClauseVal<"unknown",2,0> { let isDefault = 1; } def OMPC_Order : Clause<"order"> { let clangClass = "OMPOrderClause"; let enumClauseValue = "OrderKind"; let allowedClauseValues = [ + OMP_ORDER_unknown, OMP_ORDER_concurrent ]; } diff --git a/llvm/unittests/Frontend/OpenMPParsingTest.cpp b/llvm/unittests/Frontend/OpenMPParsingTest.cpp index ea06b34..227e08c 100644 --- a/llvm/unittests/Frontend/OpenMPParsingTest.cpp +++ b/llvm/unittests/Frontend/OpenMPParsingTest.cpp @@ -55,8 +55,9 @@ TEST(OpenMPParsingTest, isAllowedClauseForDirective) { } TEST(OpenMPParsingTest, getOrderKind) { - EXPECT_EQ(getOrderKind("foobar"), OMP_ORDER_concurrent); - EXPECT_EQ(getOrderKind("default"), OMP_ORDER_concurrent); + EXPECT_EQ(getOrderKind("foobar"), OMP_ORDER_unknown); + EXPECT_EQ(getOrderKind("unknown"), OMP_ORDER_unknown); + EXPECT_EQ(getOrderKind("concurrent"), OMP_ORDER_concurrent); } TEST(OpenMPParsingTest, getProcBindKind) { diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index 14e8994..e85a4b7 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -492,7 +492,6 @@ enum ClauseType { collapseClause, orderClause, orderedClause, - inclusiveClause, memoryOrderClause, hintClause, COUNT @@ -577,8 +576,7 @@ static ParseResult parseClauses(OpAsmParser &parser, OperationState &result, // segments if (clause == defaultClause || clause == procBindClause || clause == nowaitClause || clause == collapseClause || - clause == orderClause || clause == orderedClause || - clause == inclusiveClause) + clause == orderClause || clause == orderedClause) continue; pos[clause] = currPos++; @@ -596,7 +594,7 @@ static ParseResult parseClauses(OpAsmParser &parser, OperationState &result, bool allowRepeat = false) -> ParseResult { if (!llvm::is_contained(clauses, clause)) return parser.emitError(parser.getCurrentLocation()) - << clauseKeyword << "is not a valid clause for the " << opName + << clauseKeyword << " is not a valid clause for the " << opName << " operation"; if (done[clause] && !allowRepeat) return parser.emitError(parser.getCurrentLocation()) @@ -717,12 +715,7 @@ static ParseResult parseClauses(OpAsmParser &parser, OperationState &result, parser.parseKeyword(&order) || parser.parseRParen()) return failure(); auto attr = parser.getBuilder().getStringAttr(order); - result.addAttribute("order", attr); - } else if (clauseKeyword == "inclusive") { - if (checkAllowed(inclusiveClause)) - return failure(); - auto attr = UnitAttr::get(parser.getBuilder().getContext()); - result.addAttribute("inclusive", attr); + result.addAttribute("order_val", attr); } else if (clauseKeyword == "memory_order") { StringRef memoryOrder; if (checkAllowed(memoryOrderClause) || parser.parseLParen() || @@ -875,11 +868,11 @@ static ParseResult parseParallelOp(OpAsmParser &parser, /// /// wsloop ::= `omp.wsloop` loop-control clause-list /// loop-control ::= `(` ssa-id-list `)` `:` type `=` loop-bounds -/// loop-bounds := `(` ssa-id-list `)` to `(` ssa-id-list `)` steps +/// loop-bounds := `(` ssa-id-list `)` to `(` ssa-id-list `)` inclusive? steps /// steps := `step` `(`ssa-id-list`)` /// clause-list ::= clause clause-list | empty /// clause ::= private | firstprivate | lastprivate | linear | schedule | -// collapse | nowait | ordered | order | inclusive | reduction +// collapse | nowait | ordered | order | reduction static ParseResult parseWsLoopOp(OpAsmParser &parser, OperationState &result) { // Parse an opening `(` followed by induction variables followed by `)` @@ -906,6 +899,11 @@ static ParseResult parseWsLoopOp(OpAsmParser &parser, OperationState &result) { parser.resolveOperands(upper, loopVarType, result.operands)) return failure(); + if (succeeded(parser.parseOptionalKeyword("inclusive"))) { + auto attr = UnitAttr::get(parser.getBuilder().getContext()); + result.addAttribute("inclusive", attr); + } + // Parse step values. SmallVector steps; if (parser.parseKeyword("step") || @@ -936,7 +934,11 @@ static ParseResult parseWsLoopOp(OpAsmParser &parser, OperationState &result) { static void printWsLoopOp(OpAsmPrinter &p, WsLoopOp op) { auto args = op.getRegion().front().getArguments(); p << " (" << args << ") : " << args[0].getType() << " = (" << op.lowerBound() - << ") to (" << op.upperBound() << ") step (" << op.step() << ") "; + << ") to (" << op.upperBound() << ") "; + if (op.inclusive()) { + p << "inclusive "; + } + p << "step (" << op.step() << ") "; printDataVars(p, op.private_vars(), "private"); printDataVars(p, op.firstprivate_vars(), "firstprivate"); @@ -962,15 +964,14 @@ static void printWsLoopOp(OpAsmPrinter &p, WsLoopOp op) { if (auto ordered = op.ordered_val()) p << "ordered(" << ordered << ") "; + if (auto order = op.order_val()) + p << "order(" << order << ") "; + if (!op.reduction_vars().empty()) { p << "reduction("; printReductionVarList(p, op.reductions(), op.reduction_vars()); } - if (op.inclusive()) { - p << "inclusive "; - } - p.printRegion(op.region(), /*printEntryBlockArgs=*/false); } diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir index e57ddfc..36eee32 100644 --- a/mlir/test/Dialect/OpenMP/invalid.mlir +++ b/mlir/test/Dialect/OpenMP/invalid.mlir @@ -69,7 +69,62 @@ func @copyin_once(%n : memref) { } // ----- - + +func @lastprivate_not_allowed(%n : memref) { + // expected-error@+1 {{lastprivate is not a valid clause for the omp.parallel operation}} + omp.parallel lastprivate(%n : memref) {} + return +} + +// ----- + +func @nowait_not_allowed(%n : memref) { + // expected-error@+1 {{nowait is not a valid clause for the omp.parallel operation}} + omp.parallel nowait {} + return +} + +// ----- + +func @linear_not_allowed(%data_var : memref, %linear_var : i32) { + // expected-error@+1 {{linear is not a valid clause for the omp.parallel operation}} + omp.parallel linear(%data_var = %linear_var : memref) {} + return +} + +// ----- + +func @schedule_not_allowed() { + // expected-error@+1 {{schedule is not a valid clause for the omp.parallel operation}} + omp.parallel schedule(static) {} + return +} + +// ----- + +func @collapse_not_allowed() { + // expected-error@+1 {{collapse is not a valid clause for the omp.parallel operation}} + omp.parallel collapse(3) {} + return +} + +// ----- + +func @order_not_allowed() { + // expected-error@+1 {{order is not a valid clause for the omp.parallel operation}} + omp.parallel order(concurrent) {} + return +} + +// ----- + +func @ordered_not_allowed() { + // expected-error@+1 {{ordered is not a valid clause for the omp.parallel operation}} + omp.parallel ordered(2) {} +} + +// ----- + func @default_once() { // expected-error@+1 {{at most one default clause can appear on the omp.parallel operation}} omp.parallel default(private) default(firstprivate) { @@ -90,6 +145,78 @@ func @proc_bind_once() { // ----- +func @inclusive_not_a_clause(%lb : index, %ub : index, %step : index) { + // expected-error @below {{inclusive is not a valid clause}} + omp.wsloop (%iv) : index = (%lb) to (%ub) step (%step) nowait inclusive { + omp.yield + } +} + +// ----- + +func @order_value(%lb : index, %ub : index, %step : index) { + // expected-error @below {{attribute 'order_val' failed to satisfy constraint: OrderKind Clause}} + omp.wsloop (%iv) : index = (%lb) to (%ub) step (%step) order(default) { + omp.yield + } +} + +// ----- + +func @shared_not_allowed(%lb : index, %ub : index, %step : index, %var : memref) { + // expected-error @below {{shared is not a valid clause for the omp.wsloop operation}} + omp.wsloop (%iv) : index = (%lb) to (%ub) step (%step) shared(%var) { + omp.yield + } +} + +// ----- + +func @copyin(%lb : index, %ub : index, %step : index, %var : memref) { + // expected-error @below {{copyin is not a valid clause for the omp.wsloop operation}} + omp.wsloop (%iv) : index = (%lb) to (%ub) step (%step) copyin(%var) { + omp.yield + } +} + +// ----- + +func @if_not_allowed(%lb : index, %ub : index, %step : index, %bool_var : i1) { + // expected-error @below {{if is not a valid clause for the omp.wsloop operation}} + omp.wsloop (%iv) : index = (%lb) to (%ub) step (%step) if(%bool_var: i1) { + omp.yield + } +} + +// ----- + +func @num_threads_not_allowed(%lb : index, %ub : index, %step : index, %int_var : i32) { + // expected-error @below {{num_threads is not a valid clause for the omp.wsloop operation}} + omp.wsloop (%iv) : index = (%lb) to (%ub) step (%step) num_threads(%int_var: i32) { + omp.yield + } +} + +// ----- + +func @default_not_allowed(%lb : index, %ub : index, %step : index) { + // expected-error @below {{default is not a valid clause for the omp.wsloop operation}} + omp.wsloop (%iv) : index = (%lb) to (%ub) step (%step) default(private) { + omp.yield + } +} + +// ----- + +func @proc_bind_not_allowed(%lb : index, %ub : index, %step : index) { + // expected-error @below {{proc_bind is not a valid clause for the omp.wsloop operation}} + omp.wsloop (%iv) : index = (%lb) to (%ub) step (%step) proc_bind(close) { + omp.yield + } +} + +// ----- + // expected-error @below {{op expects initializer region with one argument of the reduction type}} omp.reduction.declare @add_f32 : f64 init { diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir index 0d7c7af..4d0801d 100644 --- a/mlir/test/Dialect/OpenMP/ops.mlir +++ b/mlir/test/Dialect/OpenMP/ops.mlir @@ -123,7 +123,27 @@ func @omp_parallel_pretty(%data_var : memref, %if_cond : i1, %num_threads : omp.terminator } - return + // CHECK: omp.parallel default(private) + omp.parallel default(private) { + omp.terminator + } + + // CHECK: omp.parallel default(firstprivate) + omp.parallel default(firstprivate) { + omp.terminator + } + + // CHECK: omp.parallel default(shared) + omp.parallel default(shared) { + omp.terminator + } + + // CHECK: omp.parallel default(none) + omp.parallel default(none) { + omp.terminator + } + + return } // CHECK-LABEL: omp_wsloop @@ -207,6 +227,21 @@ func @omp_wsloop_pretty(%lb : index, %ub : index, %step : index, omp.yield } + // CHECK: omp.wsloop (%{{.*}}) : index = (%{{.*}}) to (%{{.*}}) inclusive step (%{{.*}}) + omp.wsloop (%iv) : index = (%lb) to (%ub) inclusive step (%step) { + omp.yield + } + + // CHECK: omp.wsloop (%{{.*}}) : index = (%{{.*}}) to (%{{.*}}) step (%{{.*}}) nowait + omp.wsloop (%iv) : index = (%lb) to (%ub) step (%step) nowait { + omp.yield + } + + // CHECK: omp.wsloop (%{{.*}}) : index = (%{{.*}}) to (%{{.*}}) step (%{{.*}}) nowait order(concurrent) + omp.wsloop (%iv) : index = (%lb) to (%ub) step (%step) order(concurrent) nowait { + omp.yield + } + return } -- 2.7.4