Revert "BEGIN_PUBLIC"
authorMitch Phillips <31459023+hctim@users.noreply.github.com>
Wed, 21 Dec 2022 17:18:01 +0000 (09:18 -0800)
committerMitch Phillips <31459023+hctim@users.noreply.github.com>
Wed, 21 Dec 2022 17:32:54 +0000 (09:32 -0800)
This reverts commit a6d6d40d8bd062514fc379a6bf70fb1b7220be6f.

Reason: Broke the ASan/MSan bots. More information in phabricator:
https://reviews.llvm.org/D140406

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/test/Dialect/Linalg/one-shot-bufferize.mlir
mlir/test/Dialect/Linalg/roundtrip.mlir

index 56c7d84..8b0540e 100644 (file)
@@ -174,6 +174,16 @@ static void printCommonStructuredOpParts(OpAsmPrinter &p, ValueRange inputs,
     p << " outs(" << outputs << " : " << outputs.getTypes() << ")";
 }
 
+static void printCommonStructuredOpPartsWithNewLine(OpAsmPrinter &p,
+                                                    ValueRange inputs,
+                                                    ValueRange outputs) {
+  if (!inputs.empty()) {
+    p << " ins(" << inputs << " : " << inputs.getTypes() << ")";
+  }
+  if (!outputs.empty()) {
+    p << " outs(" << outputs << " : " << outputs.getTypes() << ")";
+  }
+}
 //===----------------------------------------------------------------------===//
 // Specific parsing and printing for named structured ops created by ods-gen.
 //===----------------------------------------------------------------------===//
@@ -1011,119 +1021,38 @@ void MapOp::build(
                        inputs, /*outputs=*/{}, bodyBuild);
 }
 
-static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result,
-                                 const OperationName &payloadOpName,
-                                 const NamedAttrList &payloadOpAttrs,
-                                 ArrayRef<Value> operands) {
-  OpBuilder b(parser.getContext());
-  Region *body = result.addRegion();
-  Block &block = body->emplaceBlock();
-  b.setInsertionPointToStart(&block);
-  SmallVector<Value> bbArgs;
-  for (auto &operand : operands) {
-    block.addArgument(operand.getType().cast<ShapedType>().getElementType(),
-                      b.getUnknownLoc());
-  }
-
-  Operation *payloadOp = b.create(
-      result.location, b.getStringAttr(payloadOpName.getStringRef()),
-      block.getArguments(),
-      TypeRange{
-          result.operands.back().getType().cast<ShapedType>().getElementType()},
-      payloadOpAttrs);
-
-  b.create<YieldOp>(result.location, payloadOp->getResults());
-}
-
 ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
-  std::optional<OperationName> payloadOpName;
-  NamedAttrList payloadOpAttrs;
-  if (succeeded(parser.parseOptionalLBrace())) {
-    FailureOr<OperationName> operationName = parser.parseCustomOperationName();
-    if (failed(operationName))
-      return failure();
-    if (parser.parseOptionalAttrDict(payloadOpAttrs))
-      return failure();
-    payloadOpName = operationName.value();
-    if (parser.parseRBrace())
-      return failure();
-  }
-
   if (parseDstStyleOp(parser, result))
     return failure();
 
-  if (payloadOpName.has_value()) {
-    addBodyWithPayloadOp(parser, result, payloadOpName.value(), payloadOpAttrs,
-                         makeArrayRef(result.operands).drop_back());
-  } else {
-    SmallVector<OpAsmParser::Argument> regionArgs;
-    if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
-                                 /*allowType=*/true, /*allowAttrs=*/true)) {
-      return failure();
-    }
-    Region *body = result.addRegion();
-    if (parser.parseRegion(*body, regionArgs))
-      return failure();
+  SmallVector<OpAsmParser::Argument> regionArgs;
+  if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
+                               /*allowType=*/true, /*allowAttrs=*/true)) {
+    return failure();
   }
-  return success();
-}
 
-// Retrieve the operation from the body, if it is the only one (except
-// yield) and if it gets the same amount of arguments as the body does.
-static Operation *findPayloadOp(Block *body) {
-  if (body->getOperations().size() != 2)
-    return nullptr;
-  Operation &payload = body->getOperations().front();
-  assert(isa<YieldOp>(body->getOperations().back()));
-
-  if (payload.getNumOperands() == 0 ||
-      payload.getNumOperands() != body->getNumArguments())
-    return nullptr;
-  for (const auto &[bbArg, operand] :
-       llvm::zip(payload.getOperands(), body->getArguments())) {
-    if (bbArg != operand)
-      return nullptr;
-  }
-  return &payload;
-}
+  Region *body = result.addRegion();
+  if (parser.parseRegion(*body, regionArgs))
+    return failure();
 
-void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {
-  SmallVector<StringRef> elidedAttrs;
-  p << " { " << payloadOp->getName().getStringRef();
-  for (const auto &attr : payloadOp->getAttrs()) {
-    auto fastAttr = attr.getValue().dyn_cast<mlir::arith::FastMathFlagsAttr>();
-    if (fastAttr && fastAttr.getValue() == mlir::arith::FastMathFlags::none) {
-      elidedAttrs.push_back(attr.getName().str());
-      break;
-    }
-  }
-  p.printOptionalAttrDict(payloadOp->getAttrs(), elidedAttrs);
-  p << " }";
+  return success();
 }
 
 void MapOp::print(OpAsmPrinter &p) {
-  Block *mapper = getBody();
-  Operation *payloadOp = findPayloadOp(mapper);
-  if (payloadOp) {
-    printShortForm(p, payloadOp);
-  }
-
-  printCommonStructuredOpParts(p, SmallVector<Value>(getDpsInputOperands()),
-                               SmallVector<Value>(getDpsInitOperands()));
+  printCommonStructuredOpPartsWithNewLine(
+      p, SmallVector<Value>(getDpsInputOperands()),
+      SmallVector<Value>(getDpsInitOperands()));
   p.printOptionalAttrDict((*this)->getAttrs());
 
-  if (!payloadOp) {
-    // Print region if the payload op was not detected.
-    p.increaseIndent();
-    p.printNewline();
-    p << "(";
-    llvm::interleaveComma(mapper->getArguments(), p,
-                          [&](auto arg) { p.printRegionArgument(arg); });
-    p << ") ";
-
-    p.printRegion(getMapper(), /*printEntryBlockArgs=*/false);
-    p.decreaseIndent();
-  }
+  p.increaseIndent();
+  p.printNewline();
+  p << "(";
+  llvm::interleaveComma(getMapper().getArguments(), p,
+                        [&](auto arg) { p.printRegionArgument(arg); });
+  p << ") ";
+
+  p.printRegion(getMapper(), /*printEntryBlockArgs=*/false);
+  p.decreaseIndent();
 }
 
 LogicalResult MapOp::verify() {
@@ -1136,7 +1065,7 @@ LogicalResult MapOp::verify() {
                             "mapper, but got: "
                          << getInputs().size() << " and " << blockArgs.size();
 
-  // The parameters of mapper should all match the element type of inputs.
+  // The parameters of mapper should all match the element type // of inputs.
   for (const auto &[bbArgType, inputArg] :
        llvm::zip(bodyBlock->getArgumentTypes(), getInputs())) {
     auto inputElemType = inputArg.getType().cast<ShapedType>().getElementType();
@@ -1258,40 +1187,22 @@ static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser,
 }
 
 ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) {
-  std::optional<OperationName> payloadOpName;
-  NamedAttrList payloadOpAttrs;
-  if (succeeded(parser.parseOptionalLBrace())) {
-    FailureOr<OperationName> operationName = parser.parseCustomOperationName();
-    if (failed(operationName))
-      return failure();
-    if (parser.parseOptionalAttrDict(payloadOpAttrs))
-      return failure();
-    payloadOpName = operationName.value();
-    if (parser.parseRBrace())
-      return failure();
-  }
-
   if (parseDstStyleOp(
           parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
             return parseDenseI64ArrayAttr(parser, attributes, "dimensions");
           }))
     return failure();
 
-  if (payloadOpName.has_value()) {
-    addBodyWithPayloadOp(parser, result, payloadOpName.value(), payloadOpAttrs,
-                         makeArrayRef(result.operands));
-  } else {
-    SmallVector<OpAsmParser::Argument> regionArgs;
-    if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
-                                 /*allowType=*/true, /*allowAttrs=*/true)) {
-      return failure();
-    }
-
-    Region *body = result.addRegion();
-    if (parser.parseRegion(*body, regionArgs))
-      return failure();
+  SmallVector<OpAsmParser::Argument> regionArgs;
+  if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
+                               /*allowType=*/true, /*allowAttrs=*/true)) {
+    return failure();
   }
 
+  Region *body = result.addRegion();
+  if (parser.parseRegion(*body, regionArgs))
+    return failure();
+
   return success();
 }
 
@@ -1301,28 +1212,22 @@ static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName,
 }
 
 void ReduceOp::print(OpAsmPrinter &p) {
-  Block *mapper = getBody();
-  Operation *payloadOp = findPayloadOp(mapper);
-  if (payloadOp) {
-    printShortForm(p, payloadOp);
-  }
+  printCommonStructuredOpPartsWithNewLine(
+      p, SmallVector<Value>(getDpsInputOperands()),
+      SmallVector<Value>(getDpsInitOperands()));
 
-  printCommonStructuredOpParts(p, SmallVector<Value>(getDpsInputOperands()),
-                               SmallVector<Value>(getDpsInitOperands()));
   printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
   p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
-  if (!payloadOp) {
-    // Print region if the payload op was not detected.
-    p.increaseIndent();
-    p.printNewline();
-    p << "(";
-    llvm::interleaveComma(mapper->getArguments(), p,
-                          [&](auto arg) { p.printRegionArgument(arg); });
-    p << ") ";
-
-    p.printRegion(getCombiner(), /*printEntryBlockArgs=*/false);
-    p.decreaseIndent();
-  }
+
+  p.increaseIndent();
+  p.printNewline();
+  p << "(";
+  llvm::interleaveComma(getCombiner().getArguments(), p,
+                        [&](auto arg) { p.printRegionArgument(arg); });
+  p << ") ";
+
+  p.printRegion(getCombiner(), /*printEntryBlockArgs=*/false);
+  p.decreaseIndent();
 }
 
 LogicalResult ReduceOp::verify() {
@@ -1471,8 +1376,9 @@ void TransposeOp::getAsmResultNames(
 }
 
 void TransposeOp::print(OpAsmPrinter &p) {
-  printCommonStructuredOpParts(p, SmallVector<Value>(getDpsInputOperands()),
-                               SmallVector<Value>(getDpsInitOperands()));
+  printCommonStructuredOpPartsWithNewLine(
+      p, SmallVector<Value>(getDpsInputOperands()),
+      SmallVector<Value>(getDpsInitOperands()));
   printDenseI64ArrayAttr(p, getPermutationAttrName(), getPermutation());
   p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrName()});
 }
@@ -1585,8 +1491,9 @@ void BroadcastOp::getAsmResultNames(
 }
 
 void BroadcastOp::print(OpAsmPrinter &p) {
-  printCommonStructuredOpParts(p, SmallVector<Value>(getDpsInputOperands()),
-                               SmallVector<Value>(getDpsInitOperands()));
+  printCommonStructuredOpPartsWithNewLine(
+      p, SmallVector<Value>(getDpsInputOperands()),
+      SmallVector<Value>(getDpsInitOperands()));
   printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
   p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
 }
index 87763c9..d418a92 100644 (file)
@@ -340,7 +340,7 @@ func.func @op_is_reading_but_following_ops_are_not(
 // CHECK-SAME:  %[[RHS:[0-9a-zA-Z]*]]: memref<64xf32
 func.func @map_binary(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>,
                       %init: tensor<64xf32>) -> tensor<64xf32> {
-   // CHECK:      linalg.map { arith.addf } ins(%[[LHS]], %[[RHS]] : memref<64xf32
+   // CHECK:      linalg.map ins(%[[LHS]], %[[RHS]] : memref<64xf32
    %add = linalg.map
           ins(%lhs, %rhs: tensor<64xf32>, tensor<64xf32>)
           outs(%init:tensor<64xf32>)
@@ -357,7 +357,7 @@ func.func @map_binary(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>,
 // CHECK-SAME:  %[[INPUT:.*]]: memref<16x32x64xf32
 func.func @reduce(%input: tensor<16x32x64xf32>,
                   %init: tensor<16x64xf32>) -> tensor<16x64xf32> {
-  // CHECK:     linalg.reduce { arith.addf } ins(%[[INPUT]] : memref<16x32x64xf32
+  // CHECK:     linalg.reduce ins(%[[INPUT]] : memref<16x32x64xf32
   %reduce = linalg.reduce
       ins(%input:tensor<16x32x64xf32>)
       outs(%init:tensor<16x64xf32>)
index 611d428..b1a614f 100644 (file)
@@ -356,8 +356,12 @@ func.func @map_binary(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>,
   func.return %add : tensor<64xf32>
 }
 // CHECK-LABEL: func @map_binary
-//       CHECK:   linalg.map { arith.addf } ins
+//       CHECK:   linalg.map ins
 //  CHECK-SAME:   outs
+//  CHECK-NEXT:   (%{{.*}}: f32, %{{.*}}: f32) {
+//  CHECK-NEXT:     arith.addf
+//  CHECK-NEXT:     linalg.yield
+//  CHECK-NEXT:   }
 
 // -----
 
@@ -420,9 +424,13 @@ func.func @reduce(%input: tensor<16x32x64xf32>,
   func.return %reduce : tensor<16x64xf32>
 }
 // CHECK-LABEL: func @reduce
-//       CHECK:   linalg.reduce { arith.addf } ins
+//       CHECK:   linalg.reduce ins
 //  CHECK-SAME:   outs
 //  CHECK-SAME:   dimensions = [1]
+//  CHECK-NEXT:   (%{{.*}}: f32, %{{.*}}: f32) {
+//  CHECK-NEXT:     arith.addf
+//  CHECK-NEXT:     linalg.yield
+//  CHECK-NEXT:   }
 
 // -----
 
@@ -438,10 +446,8 @@ func.func @reduce_memref(%input: memref<16x32x64xf32>,
       }
   func.return
 }
-// CHECK-LABEL: func @reduce
-//       CHECK:   linalg.reduce { arith.addf } ins
-//  CHECK-SAME:   outs
-//  CHECK-SAME:   dimensions = [1]
+// CHECK-LABEL: func @reduce_memref
+//       CHECK:     linalg.reduce
 
 // -----
 
@@ -461,7 +467,6 @@ func.func @variadic_reduce(%input1: tensor<16x32x64xf32>,
 }
 // CHECK-LABEL: func @variadic_reduce
 //       CHECK:     linalg.reduce
-//   CHECK-NOT:     { arith.addf
 
 // -----
 
@@ -479,9 +484,8 @@ func.func @variadic_reduce_memref(%input1: memref<16x32x64xf32>,
       }
   func.return
 }
-//   CHECK-LABEL: func @variadic_reduce_memref
+// CHECK-LABEL: func @variadic_reduce_memref
 //       CHECK:     linalg.reduce
-//   CHECK-NOT:     { arith.addf
 
 // -----
 
@@ -556,46 +560,3 @@ func.func @broadcast_memref(%input: memref<8x32xf32>,
 //      CHECK:    linalg.broadcast ins
 // CHECK-SAME:    outs
 // CHECK-SAME:    dimensions
-
-// -----
-
-func.func @map_arith_with_attr(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>,
-                      %init: tensor<64xf32>) -> tensor<64xf32> {
-  %add = linalg.map
-          ins(%lhs, %rhs: tensor<64xf32>, tensor<64xf32>)
-          outs(%init:tensor<64xf32>)
-          (%lhs_elem: f32, %rhs_elem: f32) {
-            %0 = arith.addf %lhs_elem, %rhs_elem fastmath<fast> : f32
-            linalg.yield %0: f32
-          }
-  func.return %add : tensor<64xf32>
-}
-
-// CHECK-LABEL: func @map_arith_with_attr
-// CHECK-NEXT:    %[[MAPPED:.*]] = linalg.map
-// CHECK-SAME:    { arith.addf {fastmath = #arith.fastmath<fast>} }
-// CHECK-SAME:    ins
-// CHECK-SAME:    outs
-// CHECK-NEXT:    return %[[MAPPED]] : tensor<64xf32>
-
-// -----
-
-func.func @reduce_arith_with_attr(%input: tensor<16x32x64xf32>,
-                  %init: tensor<16x64xf32>) -> tensor<16x64xf32> {
-  %reduce = linalg.reduce
-      ins(%input:tensor<16x32x64xf32>)
-      outs(%init:tensor<16x64xf32>)
-      dimensions = [1]
-      (%in: f32, %out: f32) {
-        %0 = arith.addf %in, %out fastmath<fast> : f32
-        linalg.yield %0: f32
-      }
-  func.return %reduce : tensor<16x64xf32>
-}
-// CHECK-LABEL: func @reduce_arith_with_attr
-// CHECK-NEXT:    %[[REDUCED:.*]] = linalg.reduce
-// CHECK-SAME:    { arith.addf {fastmath = #arith.fastmath<fast>} }
-// CHECK-SAME:    ins
-// CHECK-SAME:    outs
-// CHECK-SAME:    dimensions = [1]
-// CHECK-NEXT:    return %[[REDUCED]] : tensor<16x64xf32>