[mlir][linalg] Cleanup LinalgOp usage in test passes.
authorTobias Gysi <gysit@google.com>
Thu, 3 Jun 2021 11:33:30 +0000 (11:33 +0000)
committerTobias Gysi <gysit@google.com>
Thu, 3 Jun 2021 12:07:29 +0000 (12:07 +0000)
Replace the uses of deprecated Structured Op Interface methods in TestLinalgElementwiseFusion.cpp, TestLinalgFusionTransforms.cpp, and Transforms.cpp. The patch is based on https://reviews.llvm.org/D103394.

Differential Revision: https://reviews.llvm.org/D103528

mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp

index c85cd0c..15420cc 100644 (file)
@@ -167,9 +167,9 @@ static LogicalResult rewriteAsPaddedOp(PatternRewriter &rewriter,
 
   // If the op is fully static, it does not need padding.
   // TODO: there are cases where we may still want to pad to larger sizes.
-  if (llvm::all_of(opToPad.getShapedOperands(), [](Value v) {
-        return v.getType().cast<RankedTensorType>().hasStaticShape();
-      }))
+  assert(opToPad.hasTensorSemantics() &&
+         "expected operation to have tensor semantics");
+  if (!opToPad.hasDynamicShape())
     return success();
 
   OpBuilder::InsertionGuard g(rewriter);
@@ -177,16 +177,16 @@ static LogicalResult rewriteAsPaddedOp(PatternRewriter &rewriter,
   rewriter.setInsertionPointAfter(opToPad);
   // Make a copy of the shaped operands and update it.
   SmallVector<Value> newOperands;
-  newOperands.reserve(opToPad.getNumShapedOperands());
-  for (OpOperand &operand : opToPad.getShapedOpOperands()) {
+  newOperands.reserve(opToPad.getNumInputsAndOutputs());
+  for (OpOperand *opOperand : opToPad.getInputAndOutputOperands()) {
     Value paddedOperand;
     // If padding was requested but the shape cannot be bounded statically then
     // the pattern fails to apply.
-    if (failed(padOperandToSmallestStaticBoundingBox(rewriter, opToPad, operand,
-                                                     options, paddedOperand))) {
+    if (failed(padOperandToSmallestStaticBoundingBox(
+            rewriter, opToPad, *opOperand, options, paddedOperand))) {
       return failure();
     }
-    newOperands.push_back(paddedOperand ? paddedOperand : operand.get());
+    newOperands.push_back(paddedOperand ? paddedOperand : opOperand->get());
   }
 
   // Clone `opToPad` to operate on the statically padded shapes.
index d0812ab..d7ec756 100644 (file)
@@ -24,8 +24,8 @@ static void addOperands(Operation *op, SetVector<Value> &operandSet) {
     return;
   TypeSwitch<Operation *, void>(op)
       .Case<linalg::LinalgOp>([&](linalg::LinalgOp linalgOp) {
-        operandSet.insert(linalgOp.getInputs().begin(),
-                          linalgOp.getInputs().end());
+        SmallVector<Value> inputOperands = linalgOp.getInputOperands();
+        operandSet.insert(inputOperands.begin(), inputOperands.end());
       })
       .Default([&](Operation *operation) {
         operandSet.insert(operation->operand_begin(), operation->operand_end());
index 4413fac..26e4f6a 100644 (file)
@@ -147,14 +147,14 @@ static LogicalResult fuseLinalgOpsGreedily(FuncOp f) {
   // Tile and Fuse for tensors inputs (TODO: all tensor operands).
   bool changed = false;
   for (LinalgOp linalgOp : llvm::reverse(linalgOps)) {
-    for (OpOperand &opOperand : linalgOp.getShapedOpOperands()) {
-      if (opOperand.get().getType().isa<MemRefType>()) {
+    for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) {
+      if (opOperand->get().getType().isa<MemRefType>()) {
         // TODO: LinalgDependenceGraph should be able to update itself.
         // The current naive and expensive reconstruction of the graph should be
         // removed.
         linalg::Aliases aliases;
         linalg::LinalgDependenceGraph graph(aliases, linalgOps);
-        if (auto info = fuseProducerOfBuffer(b, opOperand, graph)) {
+        if (auto info = fuseProducerOfBuffer(b, *opOperand, graph)) {
           auto *originalOp = info->originalProducer.getOperation();
           eraseSet.insert(originalOp);
           auto *originalOpInLinalgOpsVector =
@@ -163,11 +163,11 @@ static LogicalResult fuseLinalgOpsGreedily(FuncOp f) {
           changed = true;
         }
       } else {
-        assert(opOperand.get().getType().isa<RankedTensorType>());
+        assert(opOperand->get().getType().isa<RankedTensorType>());
         // Tile and Fuse tensor input.
-        if (opOperand.getOperandNumber() >= linalgOp.getNumInputs())
+        if (opOperand->getOperandNumber() >= linalgOp.getNumInputs())
           continue;
-        if (auto info = fuseProducerOfTensor(b, opOperand)) {
+        if (auto info = fuseProducerOfTensor(b, *opOperand)) {
           auto *originalOp = info->originalProducer.getOperation();
           auto *originalOpInLinalgOpsVector =
               std::find(linalgOps.begin(), linalgOps.end(), originalOp);