[mlir][bufferize] Provide default BufferizableOpInterface impl for destination style ops
authorMatthias Springer <springerm@google.com>
Thu, 27 Oct 2022 08:40:24 +0000 (10:40 +0200)
committerMatthias Springer <springerm@google.com>
Thu, 27 Oct 2022 08:52:47 +0000 (10:52 +0200)
tensor.insert and tensor.insert_slice (as destination style ops) do no longer need to implement the entire BufferizableOpInterface.

Differential Revision: https://reviews.llvm.org/D136347

mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
mlir/include/mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h [new file with mode: 0644]
mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt
mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

index ab8f3e2..a77af0b 100644 (file)
@@ -143,6 +143,10 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
           Return the OpResult that aliases with a given OpOperand when
           bufferized in-place. This method will never be called on OpOperands
           that do not have a tensor type.
+
+          Note: This method can return multiple OpResults, indicating that a
+          given OpOperand may at runtime alias with any (or multiple) of the
+          returned OpResults.
         }],
         /*retType=*/"SmallVector<OpResult>",
         /*methodName=*/"getAliasingOpResult",
@@ -165,8 +169,9 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
           return the OpOperands that are yielded by the terminator.
 
           Note: This method can return multiple OpOperands, indicating that the
-          given OpResult may at runtime alias with any of the OpOperands. This
-          is useful for branches and for ops such as `arith.select`.
+          given OpResult may at runtime alias with any (or multiple) of the
+          returned OpOperands. This can be useful for branches and for ops such
+          as `arith.select`.
         }],
         /*retType=*/"SmallVector<OpOperand *>",
         /*methodName=*/"getAliasingOpOperand",
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h b/mlir/include/mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h
new file mode 100644 (file)
index 0000000..f1b1c65
--- /dev/null
@@ -0,0 +1,59 @@
+//===- DstBufferizableOpInterfaceImpl.h - Dst Op Bufferization --*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_BUFFERIZATION_IR_DSTBUFFERIZABLEOPINTERFACEIMPL_H_
+#define MLIR_DIALECT_BUFFERIZATION_IR_DSTBUFFERIZABLEOPINTERFACEIMPL_H_
+
+#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
+#include "mlir/Interfaces/DestinationStyleOpInterface.h"
+
+namespace mlir {
+namespace bufferization {
+
+/// Bufferizable ops that implement the DestinationStyleOpInterface can use this
+/// external model base class. It provides default implementations for various
+/// required interface methods.
+template <typename ConcreteModel, typename ConcreteOp>
+struct DstBufferizableOpInterfaceExternalModel
+    : public BufferizableOpInterface::ExternalModel<ConcreteModel, ConcreteOp> {
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+                              const AnalysisState &state) const {
+    // All inputs and outputs bufferize to a memory read.
+    assert(isa<DestinationStyleOpInterface>(op) &&
+           "expected that op implements DestinationStyleOpInterface");
+    return true;
+  }
+
+  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+                               const AnalysisState &state) const {
+    // Only outputs bufferize to a memory write.
+    auto dstOp = cast<DestinationStyleOpInterface>(op);
+    return dstOp.isOutput(&opOperand);
+  }
+
+  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
+                                            const AnalysisState &state) const {
+    // Output operands alias with their respective tied OpResults.
+    auto dstOp = cast<DestinationStyleOpInterface>(op);
+    if (dstOp.isOutput(&opOperand))
+      return {dstOp.getTiedOpResult(&opOperand)};
+    return {};
+  }
+
+  BufferRelation bufferRelation(Operation *op, OpResult opResult,
+                                const AnalysisState &state) const {
+    assert(isa<DestinationStyleOpInterface>(op) &&
+           "expected that op implements DestinationStyleOpInterface");
+    return BufferRelation::Equivalent;
+  }
+};
+
+} // namespace bufferization
+} // namespace mlir
+
+#endif // MLIR_DIALECT_BUFFERIZATION_IR_DSTBUFFERIZABLEOPINTERFACEIMPL_H_
index 2e2f9fd..0c085a4 100644 (file)
@@ -13,6 +13,7 @@ add_mlir_dialect_library(MLIRBufferizationDialect
 
   LINK_LIBS PUBLIC
   MLIRAffineDialect
+  MLIRDestinationStyleOpInterface
   MLIRDialect
   MLIRFuncDialect
   MLIRIR
index 16e84a4..ac24d0d 100644 (file)
@@ -11,6 +11,7 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -563,32 +564,12 @@ struct GenerateOpInterface
 };
 
 /// Bufferization of tensor.insert. Replace with memref.store.
+///
+/// Note: DstBufferizableOpInterfaceExternalModel provides many default method
+/// implementations for DestinationStyle ops.
 struct InsertOpInterface
-    : public BufferizableOpInterface::ExternalModel<InsertOpInterface,
-                                                    tensor::InsertOp> {
-  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
-                              const AnalysisState &state) const {
-    return true;
-  }
-
-  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
-                               const AnalysisState &state) const {
-    return true;
-  }
-
-  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
-                                            const AnalysisState &state) const {
-    assert(&opOperand == &op->getOpOperand(1) /*dest*/ &&
-           "expected dest OpOperand");
-    return {op->getOpResult(0)};
-  }
-
-  SmallVector<OpOperand *>
-  getAliasingOpOperand(Operation *op, OpResult opResult,
-                       const AnalysisState &state) const {
-    return {&op->getOpOperand(1) /*dest*/};
-  }
-
+    : public DstBufferizableOpInterfaceExternalModel<InsertOpInterface,
+                                                     tensor::InsertOp> {
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           const BufferizationOptions &options) const {
     auto insertOp = cast<tensor::InsertOp>(op);
@@ -601,11 +582,6 @@ struct InsertOpInterface
     replaceOpWithBufferizedValues(rewriter, op, *destMemref);
     return success();
   }
-
-  BufferRelation bufferRelation(Operation *op, OpResult opResult,
-                                const AnalysisState &state) const {
-    return BufferRelation::Equivalent;
-  }
 };
 
 /// Return true if the (ExtractSliceOp, InsertSliceOp) pair match (i.e.
@@ -732,31 +708,12 @@ static bool isNotConflictingInsertSliceLikeOp(Operation *op, OpOperand *uRead,
 
 /// Bufferization of tensor.insert_slice. Replace with a memory copy. Under
 /// certain circumstances, this op can also be a no-op.
+///
+/// Note: DstBufferizableOpInterfaceExternalModel provides many default method
+/// implementations for DestinationStyle ops.
 struct InsertSliceOpInterface
-    : public BufferizableOpInterface::ExternalModel<InsertSliceOpInterface,
-                                                    tensor::InsertSliceOp> {
-  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
-                              const AnalysisState &state) const {
-    return true;
-  }
-
-  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
-                               const AnalysisState &state) const {
-    return &opOperand == &op->getOpOperand(1) /*dest*/;
-  }
-
-  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
-                                            const AnalysisState &state) const {
-    if (&opOperand == &op->getOpOperand(1) /*dest*/)
-      return {op->getResult(0)};
-    return {};
-  }
-
-  BufferRelation bufferRelation(Operation *op, OpResult opResult,
-                                const AnalysisState &state) const {
-    return BufferRelation::Equivalent;
-  }
-
+    : public DstBufferizableOpInterfaceExternalModel<InsertSliceOpInterface,
+                                                     tensor::InsertSliceOp> {
   bool isNotConflicting(Operation *op, OpOperand *uRead,
                         OpOperand *uConflictingWrite,
                         const AnalysisState &state) const {
index ad85002..4c75554 100644 (file)
@@ -9817,6 +9817,7 @@ cc_library(
     hdrs = [
         "include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h",
         "include/mlir/Dialect/Bufferization/IR/Bufferization.h",
+        "include/mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h",
     ],
     includes = ["include"],
     deps = [
@@ -9828,6 +9829,7 @@ cc_library(
         ":BufferizationOpsIncGen",
         ":ControlFlowInterfaces",
         ":CopyOpInterface",
+        ":DestinationStyleOpInterface",
         ":FuncDialect",
         ":IR",
         ":InferTypeOpInterface",