From c0b775a5b506408bcdd9ffe31a51400a99734f2c Mon Sep 17 00:00:00 2001 From: Mitch Phillips <31459023+hctim@users.noreply.github.com> Date: Wed, 21 Dec 2022 09:18:01 -0800 Subject: [PATCH] Revert "BEGIN_PUBLIC" This reverts commit a6d6d40d8bd062514fc379a6bf70fb1b7220be6f. Reason: Broke the ASan/MSan bots. More information in phabricator: https://reviews.llvm.org/D140406 --- mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 209 +++++++---------------- mlir/test/Dialect/Linalg/one-shot-bufferize.mlir | 4 +- mlir/test/Dialect/Linalg/roundtrip.mlir | 65 ++----- 3 files changed, 73 insertions(+), 205 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index 56c7d84..8b0540e 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -174,6 +174,16 @@ static void printCommonStructuredOpParts(OpAsmPrinter &p, ValueRange inputs, p << " outs(" << outputs << " : " << outputs.getTypes() << ")"; } +static void printCommonStructuredOpPartsWithNewLine(OpAsmPrinter &p, + ValueRange inputs, + ValueRange outputs) { + if (!inputs.empty()) { + p << " ins(" << inputs << " : " << inputs.getTypes() << ")"; + } + if (!outputs.empty()) { + p << " outs(" << outputs << " : " << outputs.getTypes() << ")"; + } +} //===----------------------------------------------------------------------===// // Specific parsing and printing for named structured ops created by ods-gen. //===----------------------------------------------------------------------===// @@ -1011,119 +1021,38 @@ void MapOp::build( inputs, /*outputs=*/{}, bodyBuild); } -static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result, - const OperationName &payloadOpName, - const NamedAttrList &payloadOpAttrs, - ArrayRef operands) { - OpBuilder b(parser.getContext()); - Region *body = result.addRegion(); - Block &block = body->emplaceBlock(); - b.setInsertionPointToStart(&block); - SmallVector bbArgs; - for (auto &operand : operands) { - block.addArgument(operand.getType().cast().getElementType(), - b.getUnknownLoc()); - } - - Operation *payloadOp = b.create( - result.location, b.getStringAttr(payloadOpName.getStringRef()), - block.getArguments(), - TypeRange{ - result.operands.back().getType().cast().getElementType()}, - payloadOpAttrs); - - b.create(result.location, payloadOp->getResults()); -} - ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) { - std::optional payloadOpName; - NamedAttrList payloadOpAttrs; - if (succeeded(parser.parseOptionalLBrace())) { - FailureOr operationName = parser.parseCustomOperationName(); - if (failed(operationName)) - return failure(); - if (parser.parseOptionalAttrDict(payloadOpAttrs)) - return failure(); - payloadOpName = operationName.value(); - if (parser.parseRBrace()) - return failure(); - } - if (parseDstStyleOp(parser, result)) return failure(); - if (payloadOpName.has_value()) { - addBodyWithPayloadOp(parser, result, payloadOpName.value(), payloadOpAttrs, - makeArrayRef(result.operands).drop_back()); - } else { - SmallVector regionArgs; - if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren, - /*allowType=*/true, /*allowAttrs=*/true)) { - return failure(); - } - Region *body = result.addRegion(); - if (parser.parseRegion(*body, regionArgs)) - return failure(); + SmallVector regionArgs; + if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren, + /*allowType=*/true, /*allowAttrs=*/true)) { + return failure(); } - return success(); -} -// Retrieve the operation from the body, if it is the only one (except -// yield) and if it gets the same amount of arguments as the body does. -static Operation *findPayloadOp(Block *body) { - if (body->getOperations().size() != 2) - return nullptr; - Operation &payload = body->getOperations().front(); - assert(isa(body->getOperations().back())); - - if (payload.getNumOperands() == 0 || - payload.getNumOperands() != body->getNumArguments()) - return nullptr; - for (const auto &[bbArg, operand] : - llvm::zip(payload.getOperands(), body->getArguments())) { - if (bbArg != operand) - return nullptr; - } - return &payload; -} + Region *body = result.addRegion(); + if (parser.parseRegion(*body, regionArgs)) + return failure(); -void printShortForm(OpAsmPrinter &p, Operation *payloadOp) { - SmallVector elidedAttrs; - p << " { " << payloadOp->getName().getStringRef(); - for (const auto &attr : payloadOp->getAttrs()) { - auto fastAttr = attr.getValue().dyn_cast(); - if (fastAttr && fastAttr.getValue() == mlir::arith::FastMathFlags::none) { - elidedAttrs.push_back(attr.getName().str()); - break; - } - } - p.printOptionalAttrDict(payloadOp->getAttrs(), elidedAttrs); - p << " }"; + return success(); } void MapOp::print(OpAsmPrinter &p) { - Block *mapper = getBody(); - Operation *payloadOp = findPayloadOp(mapper); - if (payloadOp) { - printShortForm(p, payloadOp); - } - - printCommonStructuredOpParts(p, SmallVector(getDpsInputOperands()), - SmallVector(getDpsInitOperands())); + printCommonStructuredOpPartsWithNewLine( + p, SmallVector(getDpsInputOperands()), + SmallVector(getDpsInitOperands())); p.printOptionalAttrDict((*this)->getAttrs()); - if (!payloadOp) { - // Print region if the payload op was not detected. - p.increaseIndent(); - p.printNewline(); - p << "("; - llvm::interleaveComma(mapper->getArguments(), p, - [&](auto arg) { p.printRegionArgument(arg); }); - p << ") "; - - p.printRegion(getMapper(), /*printEntryBlockArgs=*/false); - p.decreaseIndent(); - } + p.increaseIndent(); + p.printNewline(); + p << "("; + llvm::interleaveComma(getMapper().getArguments(), p, + [&](auto arg) { p.printRegionArgument(arg); }); + p << ") "; + + p.printRegion(getMapper(), /*printEntryBlockArgs=*/false); + p.decreaseIndent(); } LogicalResult MapOp::verify() { @@ -1136,7 +1065,7 @@ LogicalResult MapOp::verify() { "mapper, but got: " << getInputs().size() << " and " << blockArgs.size(); - // The parameters of mapper should all match the element type of inputs. + // The parameters of mapper should all match the element type // of inputs. for (const auto &[bbArgType, inputArg] : llvm::zip(bodyBlock->getArgumentTypes(), getInputs())) { auto inputElemType = inputArg.getType().cast().getElementType(); @@ -1258,40 +1187,22 @@ static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser, } ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) { - std::optional payloadOpName; - NamedAttrList payloadOpAttrs; - if (succeeded(parser.parseOptionalLBrace())) { - FailureOr operationName = parser.parseCustomOperationName(); - if (failed(operationName)) - return failure(); - if (parser.parseOptionalAttrDict(payloadOpAttrs)) - return failure(); - payloadOpName = operationName.value(); - if (parser.parseRBrace()) - return failure(); - } - if (parseDstStyleOp( parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) { return parseDenseI64ArrayAttr(parser, attributes, "dimensions"); })) return failure(); - if (payloadOpName.has_value()) { - addBodyWithPayloadOp(parser, result, payloadOpName.value(), payloadOpAttrs, - makeArrayRef(result.operands)); - } else { - SmallVector regionArgs; - if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren, - /*allowType=*/true, /*allowAttrs=*/true)) { - return failure(); - } - - Region *body = result.addRegion(); - if (parser.parseRegion(*body, regionArgs)) - return failure(); + SmallVector regionArgs; + if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren, + /*allowType=*/true, /*allowAttrs=*/true)) { + return failure(); } + Region *body = result.addRegion(); + if (parser.parseRegion(*body, regionArgs)) + return failure(); + return success(); } @@ -1301,28 +1212,22 @@ static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName, } void ReduceOp::print(OpAsmPrinter &p) { - Block *mapper = getBody(); - Operation *payloadOp = findPayloadOp(mapper); - if (payloadOp) { - printShortForm(p, payloadOp); - } + printCommonStructuredOpPartsWithNewLine( + p, SmallVector(getDpsInputOperands()), + SmallVector(getDpsInitOperands())); - printCommonStructuredOpParts(p, SmallVector(getDpsInputOperands()), - SmallVector(getDpsInitOperands())); printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions()); p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()}); - if (!payloadOp) { - // Print region if the payload op was not detected. - p.increaseIndent(); - p.printNewline(); - p << "("; - llvm::interleaveComma(mapper->getArguments(), p, - [&](auto arg) { p.printRegionArgument(arg); }); - p << ") "; - - p.printRegion(getCombiner(), /*printEntryBlockArgs=*/false); - p.decreaseIndent(); - } + + p.increaseIndent(); + p.printNewline(); + p << "("; + llvm::interleaveComma(getCombiner().getArguments(), p, + [&](auto arg) { p.printRegionArgument(arg); }); + p << ") "; + + p.printRegion(getCombiner(), /*printEntryBlockArgs=*/false); + p.decreaseIndent(); } LogicalResult ReduceOp::verify() { @@ -1471,8 +1376,9 @@ void TransposeOp::getAsmResultNames( } void TransposeOp::print(OpAsmPrinter &p) { - printCommonStructuredOpParts(p, SmallVector(getDpsInputOperands()), - SmallVector(getDpsInitOperands())); + printCommonStructuredOpPartsWithNewLine( + p, SmallVector(getDpsInputOperands()), + SmallVector(getDpsInitOperands())); printDenseI64ArrayAttr(p, getPermutationAttrName(), getPermutation()); p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrName()}); } @@ -1585,8 +1491,9 @@ void BroadcastOp::getAsmResultNames( } void BroadcastOp::print(OpAsmPrinter &p) { - printCommonStructuredOpParts(p, SmallVector(getDpsInputOperands()), - SmallVector(getDpsInitOperands())); + printCommonStructuredOpPartsWithNewLine( + p, SmallVector(getDpsInputOperands()), + SmallVector(getDpsInitOperands())); printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions()); p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()}); } diff --git a/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir b/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir index 87763c9..d418a92 100644 --- a/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir +++ b/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir @@ -340,7 +340,7 @@ func.func @op_is_reading_but_following_ops_are_not( // CHECK-SAME: %[[RHS:[0-9a-zA-Z]*]]: memref<64xf32 func.func @map_binary(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>, %init: tensor<64xf32>) -> tensor<64xf32> { - // CHECK: linalg.map { arith.addf } ins(%[[LHS]], %[[RHS]] : memref<64xf32 + // CHECK: linalg.map ins(%[[LHS]], %[[RHS]] : memref<64xf32 %add = linalg.map ins(%lhs, %rhs: tensor<64xf32>, tensor<64xf32>) outs(%init:tensor<64xf32>) @@ -357,7 +357,7 @@ func.func @map_binary(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>, // CHECK-SAME: %[[INPUT:.*]]: memref<16x32x64xf32 func.func @reduce(%input: tensor<16x32x64xf32>, %init: tensor<16x64xf32>) -> tensor<16x64xf32> { - // CHECK: linalg.reduce { arith.addf } ins(%[[INPUT]] : memref<16x32x64xf32 + // CHECK: linalg.reduce ins(%[[INPUT]] : memref<16x32x64xf32 %reduce = linalg.reduce ins(%input:tensor<16x32x64xf32>) outs(%init:tensor<16x64xf32>) diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir index 611d428..b1a614f 100644 --- a/mlir/test/Dialect/Linalg/roundtrip.mlir +++ b/mlir/test/Dialect/Linalg/roundtrip.mlir @@ -356,8 +356,12 @@ func.func @map_binary(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>, func.return %add : tensor<64xf32> } // CHECK-LABEL: func @map_binary -// CHECK: linalg.map { arith.addf } ins +// CHECK: linalg.map ins // CHECK-SAME: outs +// CHECK-NEXT: (%{{.*}}: f32, %{{.*}}: f32) { +// CHECK-NEXT: arith.addf +// CHECK-NEXT: linalg.yield +// CHECK-NEXT: } // ----- @@ -420,9 +424,13 @@ func.func @reduce(%input: tensor<16x32x64xf32>, func.return %reduce : tensor<16x64xf32> } // CHECK-LABEL: func @reduce -// CHECK: linalg.reduce { arith.addf } ins +// CHECK: linalg.reduce ins // CHECK-SAME: outs // CHECK-SAME: dimensions = [1] +// CHECK-NEXT: (%{{.*}}: f32, %{{.*}}: f32) { +// CHECK-NEXT: arith.addf +// CHECK-NEXT: linalg.yield +// CHECK-NEXT: } // ----- @@ -438,10 +446,8 @@ func.func @reduce_memref(%input: memref<16x32x64xf32>, } func.return } -// CHECK-LABEL: func @reduce -// CHECK: linalg.reduce { arith.addf } ins -// CHECK-SAME: outs -// CHECK-SAME: dimensions = [1] +// CHECK-LABEL: func @reduce_memref +// CHECK: linalg.reduce // ----- @@ -461,7 +467,6 @@ func.func @variadic_reduce(%input1: tensor<16x32x64xf32>, } // CHECK-LABEL: func @variadic_reduce // CHECK: linalg.reduce -// CHECK-NOT: { arith.addf // ----- @@ -479,9 +484,8 @@ func.func @variadic_reduce_memref(%input1: memref<16x32x64xf32>, } func.return } -// CHECK-LABEL: func @variadic_reduce_memref +// CHECK-LABEL: func @variadic_reduce_memref // CHECK: linalg.reduce -// CHECK-NOT: { arith.addf // ----- @@ -556,46 +560,3 @@ func.func @broadcast_memref(%input: memref<8x32xf32>, // CHECK: linalg.broadcast ins // CHECK-SAME: outs // CHECK-SAME: dimensions - -// ----- - -func.func @map_arith_with_attr(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>, - %init: tensor<64xf32>) -> tensor<64xf32> { - %add = linalg.map - ins(%lhs, %rhs: tensor<64xf32>, tensor<64xf32>) - outs(%init:tensor<64xf32>) - (%lhs_elem: f32, %rhs_elem: f32) { - %0 = arith.addf %lhs_elem, %rhs_elem fastmath : f32 - linalg.yield %0: f32 - } - func.return %add : tensor<64xf32> -} - -// CHECK-LABEL: func @map_arith_with_attr -// CHECK-NEXT: %[[MAPPED:.*]] = linalg.map -// CHECK-SAME: { arith.addf {fastmath = #arith.fastmath} } -// CHECK-SAME: ins -// CHECK-SAME: outs -// CHECK-NEXT: return %[[MAPPED]] : tensor<64xf32> - -// ----- - -func.func @reduce_arith_with_attr(%input: tensor<16x32x64xf32>, - %init: tensor<16x64xf32>) -> tensor<16x64xf32> { - %reduce = linalg.reduce - ins(%input:tensor<16x32x64xf32>) - outs(%init:tensor<16x64xf32>) - dimensions = [1] - (%in: f32, %out: f32) { - %0 = arith.addf %in, %out fastmath : f32 - linalg.yield %0: f32 - } - func.return %reduce : tensor<16x64xf32> -} -// CHECK-LABEL: func @reduce_arith_with_attr -// CHECK-NEXT: %[[REDUCED:.*]] = linalg.reduce -// CHECK-SAME: { arith.addf {fastmath = #arith.fastmath} } -// CHECK-SAME: ins -// CHECK-SAME: outs -// CHECK-SAME: dimensions = [1] -// CHECK-NEXT: return %[[REDUCED]] : tensor<16x64xf32> -- 2.7.4