[mlir][linalg][bufferize][NFC] Move helper function to op interface
authorMatthias Springer <springerm@google.com>
Tue, 23 Nov 2021 02:20:27 +0000 (11:20 +0900)
committerMatthias Springer <springerm@google.com>
Tue, 23 Nov 2021 02:59:47 +0000 (11:59 +0900)
This is in preparation of changing the op traversal during bufferization.

Differential Revision: https://reviews.llvm.org/D114040

mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

index 7ef016f..491f8a5 100644 (file)
@@ -297,6 +297,11 @@ struct BufferizationState {
 /// bufferization is necessary.
 Value getResultBuffer(OpBuilder &b, OpResult result, BufferizationState &state);
 
+/// Bufferize the given op. If the op has no tensor OpOperands/OpResults, this
+/// function returns immediately. Otherwise, it calls the `bufferize` interface
+/// method of `BufferizableOpInterface`.
+LogicalResult bufferizeOp(Operation *op, BufferizationState &state);
+
 /// PostAnalysisSteps can be registered with `BufferizationOptions` and are
 /// executed after the analysis, but before bufferization. They can be used
 /// implement custom dialect-specific optimizations.
index 6db6ba6..a569eac 100644 (file)
@@ -24,9 +24,6 @@ static constexpr int64_t kBufferAlignments = 128;
 /// Return default allocation callbacks.
 std::unique_ptr<AllocationCallbacks> defaultAllocationCallbacks();
 
-/// Bufferize one particular op.
-LogicalResult bufferizeOp(Operation *op, BufferizationState &state);
-
 /// Register external models implemented for the `BufferizableOpInterface`.
 void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry);
 
index 630415b..3897734 100644 (file)
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/IR/AsmState.h"
 #include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/BuiltinOps.h"
@@ -390,6 +391,31 @@ Value mlir::linalg::comprehensive_bufferize::getResultBuffer(
   return operandBuffer;
 }
 
+LogicalResult
+mlir::linalg::comprehensive_bufferize::bufferizeOp(Operation *op,
+                                                   BufferizationState &state) {
+  OpBuilder b(op->getContext());
+
+  // Skip BufferCast and TensorLoad ops.
+  if (isa<memref::BufferCastOp, memref::TensorLoadOp>(op))
+    return success();
+
+  // Check if op has tensor results or operands.
+  auto isaTensor = [](Type t) { return t.isa<TensorType>(); };
+  bool hasTensorResult = any_of(op->getResultTypes(), isaTensor);
+  bool hasTensorOperand = any_of(op->getOperandTypes(), isaTensor);
+  if (!hasTensorResult && !hasTensorOperand)
+    return success();
+
+  // Bufferize using `BufferizableOpInterface`.
+  b.setInsertionPoint(op);
+  if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op))
+    return bufferizableOp.bufferize(b, state);
+
+  // Other op with tensors. No bufferization method specified.
+  return op->emitError() << "unsupported op with tensors";
+}
+
 //===----------------------------------------------------------------------===//
 // Bufferization-specific BlockAndValueMapping support with debugging.
 //===----------------------------------------------------------------------===//
index e0b7588..e556534 100644 (file)
@@ -12,6 +12,7 @@ add_mlir_dialect_library(MLIRBufferizableOpInterface
 
   LINK_LIBS PUBLIC
   MLIRIR
+  MLIRMemRef
 )
 
 add_mlir_dialect_library(MLIRLinalgBufferizableOpInterfaceImpl
index ea9309e..cf7ee07 100644 (file)
@@ -927,30 +927,6 @@ inPlaceAnalysisFuncOpBody(FuncOp funcOp, BufferizationAliasInfo &aliasInfo,
 // Bufferization entry-point for functions.
 //===----------------------------------------------------------------------===//
 
-LogicalResult
-mlir::linalg::comprehensive_bufferize::bufferizeOp(Operation *op,
-                                                   BufferizationState &state) {
-  OpBuilder b(op->getContext());
-
-  // Skip BufferCast and TensorLoad ops.
-  if (isa<memref::BufferCastOp, memref::TensorLoadOp>(op))
-    return success();
-
-  // Check if op has tensor results or operands.
-  auto isaTensor = [](Type t) { return t.isa<TensorType>(); };
-  bool hasTensorResult = any_of(op->getResultTypes(), isaTensor);
-  bool hasTensorOperand = any_of(op->getOperandTypes(), isaTensor);
-  if (!hasTensorResult && !hasTensorOperand)
-    return success();
-
-  // Bufferize using `BufferizableOpInterface`.
-  if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op))
-    return bufferizableOp.bufferize(b, state);
-
-  // Other op with tensors. No bufferization method specified.
-  return op->emitError() << "unsupported op with tensors";
-}
-
 static LogicalResult bufferizeFuncOpInternals(FuncOp funcOp,
                                               BufferizationState &state) {
   LLVM_DEBUG(llvm::dbgs() << "\n\n");
index b29c59e..499203e 100644 (file)
@@ -6299,6 +6299,7 @@ cc_library(
     deps = [
         ":BufferizableOpInterfaceIncGen",
         ":IR",
+        ":MemRefDialect",
         ":Support",
         "//llvm:Support",
     ],