[mlir][linalg][bufferize] Add FuncOp bufferization pass
authorMatthias Springer <springerm@google.com>
Tue, 7 Dec 2021 11:07:13 +0000 (20:07 +0900)
committerMatthias Springer <springerm@google.com>
Tue, 7 Dec 2021 12:44:26 +0000 (21:44 +0900)
This passes bufferizes FuncOp bodies, but not FuncOp boundaries.

Differential Revision: https://reviews.llvm.org/D114671

mlir/test/Dialect/Linalg/comprehensive-function-bufferize.mlir [new file with mode: 0644]
mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
mlir/test/lib/Dialect/Linalg/CMakeLists.txt
mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp [new file with mode: 0644]
mlir/tools/mlir-opt/mlir-opt.cpp
utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel

diff --git a/mlir/test/Dialect/Linalg/comprehensive-function-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-function-bufferize.mlir
new file mode 100644 (file)
index 0000000..24f8de7
--- /dev/null
@@ -0,0 +1,69 @@
+// RUN: mlir-opt %s -test-comprehensive-function-bufferize="allow-return-memref allow-unknown-ops" -split-input-file | FileCheck %s
+
+// Run fuzzer with different seeds.
+// RUN: mlir-opt %s -test-comprehensive-function-bufferize="test-analysis-only analysis-fuzzer-seed=23" -split-input-file -o /dev/null
+// RUN: mlir-opt %s -test-comprehensive-function-bufferize="test-analysis-only analysis-fuzzer-seed=59" -split-input-file -o /dev/null
+// RUN: mlir-opt %s -test-comprehensive-function-bufferize="test-analysis-only analysis-fuzzer-seed=91" -split-input-file -o /dev/null
+
+// CHECK-LABEL: func @use_tensor_func_arg(
+//  CHECK-SAME:     %[[A:.*]]: tensor<?xf32>
+func @use_tensor_func_arg(%A : tensor<?xf32>) -> (vector<4xf32>) {
+  %c0 = arith.constant 0 : index
+  %f0 = arith.constant 0.0 : f32
+
+  // CHECK: %[[A_memref:.*]] = bufferization.to_memref %[[A]]
+  // CHECK: %[[res:.*]] = vector.transfer_read %[[A_memref]]
+  %0 = vector.transfer_read %A[%c0], %f0 : tensor<?xf32>, vector<4xf32>
+
+  // CHECK: return %[[res]]
+  return %0 : vector<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @return_tensor(
+//  CHECK-SAME:     %[[A:.*]]: tensor<?xf32>
+func @return_tensor(%A : tensor<?xf32>, %v : vector<4xf32>) -> (tensor<?xf32>) {
+  %c0 = arith.constant 0 : index
+
+  // CHECK: %[[A_memref:.*]] = bufferization.to_memref %[[A]]
+  // CHECK: %[[dim:.*]] = tensor.dim %[[A]]
+  // CHECK: %[[alloc:.*]] = memref.alloc(%[[dim]])
+  // CHECK: %[[casted:.*]] = memref.cast %[[alloc]]
+  // CHECK: memref.copy %[[A_memref]], %[[casted]]
+  // CHECK: vector.transfer_write %{{.*}}, %[[alloc]]
+  %0 = vector.transfer_write %v, %A[%c0] : vector<4xf32>, tensor<?xf32>
+
+  // CHECK: %[[res_tensor:.*]] = bufferization.to_tensor %[[casted]]
+  // CHECK: return %[[res_tensor]]
+  return %0 : tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @func_without_tensor_args
+func @func_without_tensor_args(%v : vector<10xf32>) -> () {
+  // CHECK: %[[alloc:.*]] = memref.alloc()
+  %0 = linalg.init_tensor[10] : tensor<10xf32>
+
+  %c0 = arith.constant 0 : index
+  // CHECK: vector.transfer_write %{{.*}}, %[[alloc]]
+  %1 = vector.transfer_write %v, %0[%c0] : vector<10xf32>, tensor<10xf32>
+
+  %cst = arith.constant 0.0 : f32
+  // CHECK: vector.transfer_read %[[alloc]]
+  %r = vector.transfer_read %1[%c0], %cst : tensor<10xf32>, vector<11xf32>
+
+  vector.print %r : vector<11xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func private @private_func
+func private @private_func(tensor<?xf32>) -> ()
+
+// CHECK-LABEL: func @empty_func()
+func @empty_func() -> () {
+  return
+}
index 2cda811..7f908b9 100644 (file)
@@ -979,3 +979,32 @@ func @equivalent_func_arg_2(%t0: tensor<?xf32> {linalg.inplaceable = true},
   }
   return %1: tensor<?xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @func_without_tensor_args
+func @func_without_tensor_args(%v : vector<10xf32>) -> () {
+  // CHECK: %[[alloc:.*]] = memref.alloc()
+  %0 = linalg.init_tensor[10] : tensor<10xf32>
+
+  %c0 = arith.constant 0 : index
+  // CHECK: vector.transfer_write %{{.*}}, %[[alloc]]
+  %1 = vector.transfer_write %v, %0[%c0] : vector<10xf32>, tensor<10xf32>
+
+  %cst = arith.constant 0.0 : f32
+  // CHECK: vector.transfer_read %[[alloc]]
+  %r = vector.transfer_read %1[%c0], %cst : tensor<10xf32>, vector<11xf32>
+
+  vector.print %r : vector<11xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func private @private_func
+func private @private_func(tensor<?xf32>) -> ()
+
+// CHECK-LABEL: func @empty_func()
+func @empty_func() -> () {
+  return
+}
index 6bd45c8..440d62e 100644 (file)
@@ -1,5 +1,6 @@
 # Exclude tests from libMLIR.so
 add_mlir_library(MLIRLinalgTestPasses
+  TestComprehensiveBufferize.cpp
   TestConvVectorization.cpp
   TestLinalgCodegenStrategy.cpp
   TestLinalgDistribution.cpp
@@ -12,13 +13,25 @@ add_mlir_library(MLIRLinalgTestPasses
 
   LINK_LIBS PUBLIC
   MLIRAffine
+  MLIRAffineBufferizableOpInterfaceImpl
+  MLIRArithBufferizableOpInterfaceImpl
+  MLIRArithmetic
+  MLIRBufferizableOpInterface
+  MLIRComprehensiveBufferize
   MLIRGPUTransforms
   MLIRLinalg
+  MLIRLinalgBufferizableOpInterfaceImpl
   MLIRLinalgTransforms
   MLIRLLVMToLLVMIRTranslation
+  MLIRMemRef
   MLIRPass
+  MLIRSCF
+  MLIRSCFBufferizableOpInterfaceImpl
   MLIRStandard
+  MLIRTensor
+  MLIRTensorBufferizableOpInterfaceImpl
   MLIRTransformUtils
   MLIRVector
+  MLIRVectorBufferizableOpInterfaceImpl
   MLIRVectorToSCF
   )
diff --git a/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp b/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp
new file mode 100644 (file)
index 0000000..5ac15a9
--- /dev/null
@@ -0,0 +1,124 @@
+//===- TestComprehensiveBufferize.cpp - Test Comprehensive Bufferize ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements logic for testing Comprehensive Bufferize.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.h"
+#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.h"
+#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
+#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.h"
+#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h"
+#include "mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h"
+#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h"
+#include "mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h"
+#include "mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h"
+#include "mlir/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/Passes.h"
+
+using namespace mlir;
+using namespace mlir::linalg;
+using namespace mlir::linalg::comprehensive_bufferize;
+
+namespace {
+/// A helper struct for FunctionBufferize and ModuleBufferize. Both passes are
+/// mostly identical.
+struct TestComprehensiveFunctionBufferize
+    : public PassWrapper<TestComprehensiveFunctionBufferize, FunctionPass> {
+  StringRef getArgument() const final {
+    return "test-comprehensive-function-bufferize";
+  }
+
+  StringRef getDescription() const final {
+    return "Test Comprehensive Bufferize of FuncOps (body only).";
+  }
+
+  TestComprehensiveFunctionBufferize() = default;
+  TestComprehensiveFunctionBufferize(
+      const TestComprehensiveFunctionBufferize &pass) {}
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<bufferization::BufferizationDialect, linalg::LinalgDialect,
+                    memref::MemRefDialect, tensor::TensorDialect,
+                    vector::VectorDialect, scf::SCFDialect,
+                    arith::ArithmeticDialect, AffineDialect>();
+    affine_ext::registerBufferizableOpInterfaceExternalModels(registry);
+    arith_ext::registerBufferizableOpInterfaceExternalModels(registry);
+    bufferization_ext::registerBufferizableOpInterfaceExternalModels(registry);
+    linalg_ext::registerBufferizableOpInterfaceExternalModels(registry);
+    scf_ext::registerBufferizableOpInterfaceExternalModels(registry);
+    tensor_ext::registerBufferizableOpInterfaceExternalModels(registry);
+    vector_ext::registerBufferizableOpInterfaceExternalModels(registry);
+  }
+
+  void runOnFunction() override;
+
+  Option<bool> allowReturnMemref{
+      *this, "allow-return-memref",
+      llvm::cl::desc("Allow returning/yielding memrefs from functions/blocks"),
+      llvm::cl::init(false)};
+  Option<bool> allowUnknownOps{
+      *this, "allow-unknown-ops",
+      llvm::cl::desc(
+          "Allows the return of memrefs (for testing purposes only)"),
+      llvm::cl::init(false)};
+  Option<bool> testAnalysisOnly{
+      *this, "test-analysis-only",
+      llvm::cl::desc(
+          "Only runs inplaceability analysis (for testing purposes only)"),
+      llvm::cl::init(false)};
+  Option<unsigned> analysisFuzzerSeed{
+      *this, "analysis-fuzzer-seed",
+      llvm::cl::desc("Analyze ops in random order with a given seed (fuzzer)"),
+      llvm::cl::init(0)};
+};
+} // namespace
+
+void TestComprehensiveFunctionBufferize::runOnFunction() {
+  BufferizationOptions options;
+
+  // Enable InitTensorOp elimination.
+  options.addPostAnalysisStep<
+      linalg_ext::InsertSliceAnchoredInitTensorEliminationStep>();
+  // TODO: Find a way to enable this step automatically when bufferizing
+  // tensor dialect ops.
+  options.addPostAnalysisStep<tensor_ext::InplaceInsertSliceOpAnalysis>();
+  options.addPostAnalysisStep<scf_ext::AssertDestinationPassingStyle>();
+
+  options.allowReturnMemref = allowReturnMemref;
+  options.allowUnknownOps = allowUnknownOps;
+  options.testAnalysisOnly = testAnalysisOnly;
+  options.analysisFuzzerSeed = analysisFuzzerSeed;
+
+  Operation *op = getFunction().getOperation();
+  if (failed(runComprehensiveBufferize(op, options)))
+    return;
+
+  OpPassManager cleanupPipeline("builtin.func");
+  cleanupPipeline.addPass(createCanonicalizerPass());
+  cleanupPipeline.addPass(createCSEPass());
+  cleanupPipeline.addPass(createLoopInvariantCodeMotionPass());
+  (void)this->runPipeline(cleanupPipeline, op);
+}
+
+namespace mlir {
+namespace test {
+void registerTestComprehensiveFunctionBufferize() {
+  PassRegistration<TestComprehensiveFunctionBufferize>();
+}
+} // namespace test
+} // namespace mlir
index 6b77c37..72333a9 100644 (file)
@@ -64,6 +64,7 @@ void registerTestAffineLoopParametricTilingPass();
 void registerTestAliasAnalysisPass();
 void registerTestBuiltinAttributeInterfaces();
 void registerTestCallGraphPass();
+void registerTestComprehensiveFunctionBufferize();
 void registerTestConstantFold();
 void registerTestConvVectorization();
 void registerTestGpuSerializeToCubinPass();
@@ -159,6 +160,7 @@ void registerTestPasses() {
 #if MLIR_ROCM_CONVERSIONS_ENABLED
   mlir::test::registerTestGpuSerializeToHsacoPass();
 #endif
+  mlir::test::registerTestComprehensiveFunctionBufferize();
   mlir::test::registerTestConvVectorization();
   mlir::test::registerTestDecomposeCallGraphTypes();
   mlir::test::registerTestDataLayoutQuery();
index b0ed24a..c69a315 100644 (file)
@@ -381,15 +381,27 @@ cc_library(
     deps = [
         "//llvm:Support",
         "//mlir:Affine",
+        "//mlir:AffineBufferizableOpInterfaceImpl",
+        "//mlir:ArithBufferizableOpInterfaceImpl",
         "//mlir:ArithmeticDialect",
+        "//mlir:BufferizableOpInterface",
+        "//mlir:BufferizationDialect",
+        "//mlir:ComprehensiveBufferize",
         "//mlir:GPUDialect",
         "//mlir:IR",
+        "//mlir:LinalgBufferizableOpInterfaceImpl",
         "//mlir:LinalgOps",
         "//mlir:LinalgTransforms",
+        "//mlir:MemRefDialect",
         "//mlir:Pass",
+        "//mlir:SCFBufferizableOpInterfaceImpl",
+        "//mlir:SCFDialect",
         "//mlir:SCFTransforms",
         "//mlir:StandardOps",
+        "//mlir:TensorBufferizableOpInterfaceImpl",
+        "//mlir:TensorDialect",
         "//mlir:TransformUtils",
+        "//mlir:VectorBufferizableOpInterfaceImpl",
         "//mlir:VectorOps",
         "//mlir:VectorToSCF",
     ],