This passes bufferizes FuncOp bodies, but not FuncOp boundaries.
Differential Revision: https://reviews.llvm.org/D114671
--- /dev/null
+// RUN: mlir-opt %s -test-comprehensive-function-bufferize="allow-return-memref allow-unknown-ops" -split-input-file | FileCheck %s
+
+// Run fuzzer with different seeds.
+// RUN: mlir-opt %s -test-comprehensive-function-bufferize="test-analysis-only analysis-fuzzer-seed=23" -split-input-file -o /dev/null
+// RUN: mlir-opt %s -test-comprehensive-function-bufferize="test-analysis-only analysis-fuzzer-seed=59" -split-input-file -o /dev/null
+// RUN: mlir-opt %s -test-comprehensive-function-bufferize="test-analysis-only analysis-fuzzer-seed=91" -split-input-file -o /dev/null
+
+// CHECK-LABEL: func @use_tensor_func_arg(
+// CHECK-SAME: %[[A:.*]]: tensor<?xf32>
+func @use_tensor_func_arg(%A : tensor<?xf32>) -> (vector<4xf32>) {
+ %c0 = arith.constant 0 : index
+ %f0 = arith.constant 0.0 : f32
+
+ // CHECK: %[[A_memref:.*]] = bufferization.to_memref %[[A]]
+ // CHECK: %[[res:.*]] = vector.transfer_read %[[A_memref]]
+ %0 = vector.transfer_read %A[%c0], %f0 : tensor<?xf32>, vector<4xf32>
+
+ // CHECK: return %[[res]]
+ return %0 : vector<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @return_tensor(
+// CHECK-SAME: %[[A:.*]]: tensor<?xf32>
+func @return_tensor(%A : tensor<?xf32>, %v : vector<4xf32>) -> (tensor<?xf32>) {
+ %c0 = arith.constant 0 : index
+
+ // CHECK: %[[A_memref:.*]] = bufferization.to_memref %[[A]]
+ // CHECK: %[[dim:.*]] = tensor.dim %[[A]]
+ // CHECK: %[[alloc:.*]] = memref.alloc(%[[dim]])
+ // CHECK: %[[casted:.*]] = memref.cast %[[alloc]]
+ // CHECK: memref.copy %[[A_memref]], %[[casted]]
+ // CHECK: vector.transfer_write %{{.*}}, %[[alloc]]
+ %0 = vector.transfer_write %v, %A[%c0] : vector<4xf32>, tensor<?xf32>
+
+ // CHECK: %[[res_tensor:.*]] = bufferization.to_tensor %[[casted]]
+ // CHECK: return %[[res_tensor]]
+ return %0 : tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @func_without_tensor_args
+func @func_without_tensor_args(%v : vector<10xf32>) -> () {
+ // CHECK: %[[alloc:.*]] = memref.alloc()
+ %0 = linalg.init_tensor[10] : tensor<10xf32>
+
+ %c0 = arith.constant 0 : index
+ // CHECK: vector.transfer_write %{{.*}}, %[[alloc]]
+ %1 = vector.transfer_write %v, %0[%c0] : vector<10xf32>, tensor<10xf32>
+
+ %cst = arith.constant 0.0 : f32
+ // CHECK: vector.transfer_read %[[alloc]]
+ %r = vector.transfer_read %1[%c0], %cst : tensor<10xf32>, vector<11xf32>
+
+ vector.print %r : vector<11xf32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func private @private_func
+func private @private_func(tensor<?xf32>) -> ()
+
+// CHECK-LABEL: func @empty_func()
+func @empty_func() -> () {
+ return
+}
}
return %1: tensor<?xf32>
}
+
+// -----
+
+// CHECK-LABEL: func @func_without_tensor_args
+func @func_without_tensor_args(%v : vector<10xf32>) -> () {
+ // CHECK: %[[alloc:.*]] = memref.alloc()
+ %0 = linalg.init_tensor[10] : tensor<10xf32>
+
+ %c0 = arith.constant 0 : index
+ // CHECK: vector.transfer_write %{{.*}}, %[[alloc]]
+ %1 = vector.transfer_write %v, %0[%c0] : vector<10xf32>, tensor<10xf32>
+
+ %cst = arith.constant 0.0 : f32
+ // CHECK: vector.transfer_read %[[alloc]]
+ %r = vector.transfer_read %1[%c0], %cst : tensor<10xf32>, vector<11xf32>
+
+ vector.print %r : vector<11xf32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func private @private_func
+func private @private_func(tensor<?xf32>) -> ()
+
+// CHECK-LABEL: func @empty_func()
+func @empty_func() -> () {
+ return
+}
# Exclude tests from libMLIR.so
add_mlir_library(MLIRLinalgTestPasses
+ TestComprehensiveBufferize.cpp
TestConvVectorization.cpp
TestLinalgCodegenStrategy.cpp
TestLinalgDistribution.cpp
LINK_LIBS PUBLIC
MLIRAffine
+ MLIRAffineBufferizableOpInterfaceImpl
+ MLIRArithBufferizableOpInterfaceImpl
+ MLIRArithmetic
+ MLIRBufferizableOpInterface
+ MLIRComprehensiveBufferize
MLIRGPUTransforms
MLIRLinalg
+ MLIRLinalgBufferizableOpInterfaceImpl
MLIRLinalgTransforms
MLIRLLVMToLLVMIRTranslation
+ MLIRMemRef
MLIRPass
+ MLIRSCF
+ MLIRSCFBufferizableOpInterfaceImpl
MLIRStandard
+ MLIRTensor
+ MLIRTensorBufferizableOpInterfaceImpl
MLIRTransformUtils
MLIRVector
+ MLIRVectorBufferizableOpInterfaceImpl
MLIRVectorToSCF
)
--- /dev/null
+//===- TestComprehensiveBufferize.cpp - Test Comprehensive Bufferize ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements logic for testing Comprehensive Bufferize.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.h"
+#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.h"
+#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
+#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.h"
+#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h"
+#include "mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h"
+#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h"
+#include "mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h"
+#include "mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h"
+#include "mlir/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/Passes.h"
+
+using namespace mlir;
+using namespace mlir::linalg;
+using namespace mlir::linalg::comprehensive_bufferize;
+
+namespace {
+/// A helper struct for FunctionBufferize and ModuleBufferize. Both passes are
+/// mostly identical.
+struct TestComprehensiveFunctionBufferize
+ : public PassWrapper<TestComprehensiveFunctionBufferize, FunctionPass> {
+ StringRef getArgument() const final {
+ return "test-comprehensive-function-bufferize";
+ }
+
+ StringRef getDescription() const final {
+ return "Test Comprehensive Bufferize of FuncOps (body only).";
+ }
+
+ TestComprehensiveFunctionBufferize() = default;
+ TestComprehensiveFunctionBufferize(
+ const TestComprehensiveFunctionBufferize &pass) {}
+
+ void getDependentDialects(DialectRegistry ®istry) const override {
+ registry.insert<bufferization::BufferizationDialect, linalg::LinalgDialect,
+ memref::MemRefDialect, tensor::TensorDialect,
+ vector::VectorDialect, scf::SCFDialect,
+ arith::ArithmeticDialect, AffineDialect>();
+ affine_ext::registerBufferizableOpInterfaceExternalModels(registry);
+ arith_ext::registerBufferizableOpInterfaceExternalModels(registry);
+ bufferization_ext::registerBufferizableOpInterfaceExternalModels(registry);
+ linalg_ext::registerBufferizableOpInterfaceExternalModels(registry);
+ scf_ext::registerBufferizableOpInterfaceExternalModels(registry);
+ tensor_ext::registerBufferizableOpInterfaceExternalModels(registry);
+ vector_ext::registerBufferizableOpInterfaceExternalModels(registry);
+ }
+
+ void runOnFunction() override;
+
+ Option<bool> allowReturnMemref{
+ *this, "allow-return-memref",
+ llvm::cl::desc("Allow returning/yielding memrefs from functions/blocks"),
+ llvm::cl::init(false)};
+ Option<bool> allowUnknownOps{
+ *this, "allow-unknown-ops",
+ llvm::cl::desc(
+ "Allows the return of memrefs (for testing purposes only)"),
+ llvm::cl::init(false)};
+ Option<bool> testAnalysisOnly{
+ *this, "test-analysis-only",
+ llvm::cl::desc(
+ "Only runs inplaceability analysis (for testing purposes only)"),
+ llvm::cl::init(false)};
+ Option<unsigned> analysisFuzzerSeed{
+ *this, "analysis-fuzzer-seed",
+ llvm::cl::desc("Analyze ops in random order with a given seed (fuzzer)"),
+ llvm::cl::init(0)};
+};
+} // namespace
+
+void TestComprehensiveFunctionBufferize::runOnFunction() {
+ BufferizationOptions options;
+
+ // Enable InitTensorOp elimination.
+ options.addPostAnalysisStep<
+ linalg_ext::InsertSliceAnchoredInitTensorEliminationStep>();
+ // TODO: Find a way to enable this step automatically when bufferizing
+ // tensor dialect ops.
+ options.addPostAnalysisStep<tensor_ext::InplaceInsertSliceOpAnalysis>();
+ options.addPostAnalysisStep<scf_ext::AssertDestinationPassingStyle>();
+
+ options.allowReturnMemref = allowReturnMemref;
+ options.allowUnknownOps = allowUnknownOps;
+ options.testAnalysisOnly = testAnalysisOnly;
+ options.analysisFuzzerSeed = analysisFuzzerSeed;
+
+ Operation *op = getFunction().getOperation();
+ if (failed(runComprehensiveBufferize(op, options)))
+ return;
+
+ OpPassManager cleanupPipeline("builtin.func");
+ cleanupPipeline.addPass(createCanonicalizerPass());
+ cleanupPipeline.addPass(createCSEPass());
+ cleanupPipeline.addPass(createLoopInvariantCodeMotionPass());
+ (void)this->runPipeline(cleanupPipeline, op);
+}
+
+namespace mlir {
+namespace test {
+void registerTestComprehensiveFunctionBufferize() {
+ PassRegistration<TestComprehensiveFunctionBufferize>();
+}
+} // namespace test
+} // namespace mlir
void registerTestAliasAnalysisPass();
void registerTestBuiltinAttributeInterfaces();
void registerTestCallGraphPass();
+void registerTestComprehensiveFunctionBufferize();
void registerTestConstantFold();
void registerTestConvVectorization();
void registerTestGpuSerializeToCubinPass();
#if MLIR_ROCM_CONVERSIONS_ENABLED
mlir::test::registerTestGpuSerializeToHsacoPass();
#endif
+ mlir::test::registerTestComprehensiveFunctionBufferize();
mlir::test::registerTestConvVectorization();
mlir::test::registerTestDecomposeCallGraphTypes();
mlir::test::registerTestDataLayoutQuery();
deps = [
"//llvm:Support",
"//mlir:Affine",
+ "//mlir:AffineBufferizableOpInterfaceImpl",
+ "//mlir:ArithBufferizableOpInterfaceImpl",
"//mlir:ArithmeticDialect",
+ "//mlir:BufferizableOpInterface",
+ "//mlir:BufferizationDialect",
+ "//mlir:ComprehensiveBufferize",
"//mlir:GPUDialect",
"//mlir:IR",
+ "//mlir:LinalgBufferizableOpInterfaceImpl",
"//mlir:LinalgOps",
"//mlir:LinalgTransforms",
+ "//mlir:MemRefDialect",
"//mlir:Pass",
+ "//mlir:SCFBufferizableOpInterfaceImpl",
+ "//mlir:SCFDialect",
"//mlir:SCFTransforms",
"//mlir:StandardOps",
+ "//mlir:TensorBufferizableOpInterfaceImpl",
+ "//mlir:TensorDialect",
"//mlir:TransformUtils",
+ "//mlir:VectorBufferizableOpInterfaceImpl",
"//mlir:VectorOps",
"//mlir:VectorToSCF",
],