[mlir][vector][bufferize] Implement DestinationStyleOpInterface on TransferWriteOp
authorMatthias Springer <springerm@google.com>
Thu, 27 Oct 2022 08:59:52 +0000 (10:59 +0200)
committerMatthias Springer <springerm@google.com>
Thu, 27 Oct 2022 09:02:19 +0000 (11:02 +0200)
This simplifies the BufferizableOpInterface implementation of vector.transfer_write.

Differential Revision: https://reviews.llvm.org/D136348

mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
mlir/lib/Dialect/Vector/IR/CMakeLists.txt
mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

index 661affd..5250c6b 100644 (file)
@@ -21,6 +21,7 @@
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
+#include "mlir/Interfaces/DestinationStyleOpInterface.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Interfaces/VectorInterfaces.h"
index 2b3ed46..7e35fd6 100644 (file)
@@ -16,6 +16,7 @@
 include "mlir/Dialect/Vector/Interfaces/MaskingInterfaces.td"
 include "mlir/IR/EnumAttr.td"
 include "mlir/Interfaces/ControlFlowInterfaces.td"
+include "mlir/Interfaces/DestinationStyleOpInterface.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/Interfaces/VectorInterfaces.td"
@@ -1270,7 +1271,8 @@ def Vector_TransferWriteOp :
       DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
       DeclareOpInterfaceMethods<MaskableOpInterface>,
       DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
-      AttrSizedOperandSegments
+      AttrSizedOperandSegments,
+      DestinationStyleOpInterface
   ]>,
     Arguments<(ins AnyVectorOfAnyRank:$vector,
                    AnyShaped:$source,
@@ -1393,6 +1395,10 @@ def Vector_TransferWriteOp :
     /// This method is added to maintain uniformity with load/store
     ///  ops of other dialects.
     Value getValue() { return getVector(); }
+
+    std::pair<int64_t, int64_t> getOutputsPositionRange() {
+      return {1, 2};  // `source` operand
+    }
   }];
 
   let hasFolder = 1;
index d71ea24..05f63cf 100644 (file)
@@ -12,6 +12,7 @@ add_mlir_dialect_library(MLIRVectorDialect
   MLIRArithDialect
   MLIRControlFlowInterfaces
   MLIRDataLayoutInterfaces
+  MLIRDestinationStyleOpInterface
   MLIRDialectUtils
   MLIRIR
   MLIRMaskingInterfaces
index 3ec95ed..b43114b 100644 (file)
@@ -9,6 +9,7 @@
 #include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h"
 
 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
+#include "mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/Operation.h"
@@ -63,35 +64,12 @@ struct TransferReadOpInterface
 
 /// Bufferization of vector.transfer_write. Replace with a new
 /// vector.transfer_write that operates on a memref.
+///
+/// Note: DstBufferizableOpInterfaceExternalModel provides many default method
+/// implementations for DestinationStyle ops.
 struct TransferWriteOpInterface
-    : public BufferizableOpInterface::ExternalModel<TransferWriteOpInterface,
-                                                    vector::TransferWriteOp> {
-  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
-                              const AnalysisState &state) const {
-    assert(opOperand.get().getType().isa<TensorType>() &&
-           "only tensor types expected");
-    return true;
-  }
-
-  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
-                               const AnalysisState &state) const {
-    assert(opOperand.get().getType().isa<TensorType>() &&
-           "only tensor types expected");
-    return true;
-  }
-
-  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
-                                            const AnalysisState &state) const {
-    assert(opOperand.get().getType().isa<TensorType>() &&
-           "only tensor types expected");
-    return {op->getOpResult(0)};
-  }
-
-  BufferRelation bufferRelation(Operation *op, OpResult opResult,
-                                const AnalysisState &state) const {
-    return BufferRelation::Equivalent;
-  }
-
+    : public DstBufferizableOpInterfaceExternalModel<TransferWriteOpInterface,
+                                                     vector::TransferWriteOp> {
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           const BufferizationOptions &options) const {
     auto writeOp = cast<vector::TransferWriteOp>(op);
index 4c75554..e60e712 100644 (file)
@@ -3246,6 +3246,7 @@ cc_library(
         ":ArithDialect",
         ":ArithUtils",
         ":ControlFlowInterfaces",
+        ":DestinationStyleOpInterface",
         ":DialectUtils",
         ":IR",
         ":InferTypeOpInterface",
@@ -8211,6 +8212,7 @@ td_library(
     includes = ["include"],
     deps = [
         ":ControlFlowInterfacesTdFiles",
+        ":DestinationStyleOpInterfaceTdFiles",
         ":InferTypeOpInterfaceTdFiles",
         ":MaskingInterfacesTdFiles",
         ":OpBaseTdFiles",