From 8a232632c5269bbae736ca58eb760c9e784ee309 Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Tue, 7 Dec 2021 20:07:13 +0900 Subject: [PATCH] [mlir][linalg][bufferize] Add FuncOp bufferization pass This passes bufferizes FuncOp bodies, but not FuncOp boundaries. Differential Revision: https://reviews.llvm.org/D114671 --- .../Linalg/comprehensive-function-bufferize.mlir | 69 ++++++++++++ .../Linalg/comprehensive-module-bufferize.mlir | 29 +++++ mlir/test/lib/Dialect/Linalg/CMakeLists.txt | 13 +++ .../Dialect/Linalg/TestComprehensiveBufferize.cpp | 124 +++++++++++++++++++++ mlir/tools/mlir-opt/mlir-opt.cpp | 2 + .../llvm-project-overlay/mlir/test/BUILD.bazel | 12 ++ 6 files changed, 249 insertions(+) create mode 100644 mlir/test/Dialect/Linalg/comprehensive-function-bufferize.mlir create mode 100644 mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp diff --git a/mlir/test/Dialect/Linalg/comprehensive-function-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-function-bufferize.mlir new file mode 100644 index 0000000..24f8de7 --- /dev/null +++ b/mlir/test/Dialect/Linalg/comprehensive-function-bufferize.mlir @@ -0,0 +1,69 @@ +// RUN: mlir-opt %s -test-comprehensive-function-bufferize="allow-return-memref allow-unknown-ops" -split-input-file | FileCheck %s + +// Run fuzzer with different seeds. +// RUN: mlir-opt %s -test-comprehensive-function-bufferize="test-analysis-only analysis-fuzzer-seed=23" -split-input-file -o /dev/null +// RUN: mlir-opt %s -test-comprehensive-function-bufferize="test-analysis-only analysis-fuzzer-seed=59" -split-input-file -o /dev/null +// RUN: mlir-opt %s -test-comprehensive-function-bufferize="test-analysis-only analysis-fuzzer-seed=91" -split-input-file -o /dev/null + +// CHECK-LABEL: func @use_tensor_func_arg( +// CHECK-SAME: %[[A:.*]]: tensor +func @use_tensor_func_arg(%A : tensor) -> (vector<4xf32>) { + %c0 = arith.constant 0 : index + %f0 = arith.constant 0.0 : f32 + + // CHECK: %[[A_memref:.*]] = bufferization.to_memref %[[A]] + // CHECK: %[[res:.*]] = vector.transfer_read %[[A_memref]] + %0 = vector.transfer_read %A[%c0], %f0 : tensor, vector<4xf32> + + // CHECK: return %[[res]] + return %0 : vector<4xf32> +} + +// ----- + +// CHECK-LABEL: func @return_tensor( +// CHECK-SAME: %[[A:.*]]: tensor +func @return_tensor(%A : tensor, %v : vector<4xf32>) -> (tensor) { + %c0 = arith.constant 0 : index + + // CHECK: %[[A_memref:.*]] = bufferization.to_memref %[[A]] + // CHECK: %[[dim:.*]] = tensor.dim %[[A]] + // CHECK: %[[alloc:.*]] = memref.alloc(%[[dim]]) + // CHECK: %[[casted:.*]] = memref.cast %[[alloc]] + // CHECK: memref.copy %[[A_memref]], %[[casted]] + // CHECK: vector.transfer_write %{{.*}}, %[[alloc]] + %0 = vector.transfer_write %v, %A[%c0] : vector<4xf32>, tensor + + // CHECK: %[[res_tensor:.*]] = bufferization.to_tensor %[[casted]] + // CHECK: return %[[res_tensor]] + return %0 : tensor +} + +// ----- + +// CHECK-LABEL: func @func_without_tensor_args +func @func_without_tensor_args(%v : vector<10xf32>) -> () { + // CHECK: %[[alloc:.*]] = memref.alloc() + %0 = linalg.init_tensor[10] : tensor<10xf32> + + %c0 = arith.constant 0 : index + // CHECK: vector.transfer_write %{{.*}}, %[[alloc]] + %1 = vector.transfer_write %v, %0[%c0] : vector<10xf32>, tensor<10xf32> + + %cst = arith.constant 0.0 : f32 + // CHECK: vector.transfer_read %[[alloc]] + %r = vector.transfer_read %1[%c0], %cst : tensor<10xf32>, vector<11xf32> + + vector.print %r : vector<11xf32> + return +} + +// ----- + +// CHECK-LABEL: func private @private_func +func private @private_func(tensor) -> () + +// CHECK-LABEL: func @empty_func() +func @empty_func() -> () { + return +} diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir index 2cda811..7f908b9 100644 --- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir +++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir @@ -979,3 +979,32 @@ func @equivalent_func_arg_2(%t0: tensor {linalg.inplaceable = true}, } return %1: tensor } + +// ----- + +// CHECK-LABEL: func @func_without_tensor_args +func @func_without_tensor_args(%v : vector<10xf32>) -> () { + // CHECK: %[[alloc:.*]] = memref.alloc() + %0 = linalg.init_tensor[10] : tensor<10xf32> + + %c0 = arith.constant 0 : index + // CHECK: vector.transfer_write %{{.*}}, %[[alloc]] + %1 = vector.transfer_write %v, %0[%c0] : vector<10xf32>, tensor<10xf32> + + %cst = arith.constant 0.0 : f32 + // CHECK: vector.transfer_read %[[alloc]] + %r = vector.transfer_read %1[%c0], %cst : tensor<10xf32>, vector<11xf32> + + vector.print %r : vector<11xf32> + return +} + +// ----- + +// CHECK-LABEL: func private @private_func +func private @private_func(tensor) -> () + +// CHECK-LABEL: func @empty_func() +func @empty_func() -> () { + return +} diff --git a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt index 6bd45c8..440d62e 100644 --- a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt +++ b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt @@ -1,5 +1,6 @@ # Exclude tests from libMLIR.so add_mlir_library(MLIRLinalgTestPasses + TestComprehensiveBufferize.cpp TestConvVectorization.cpp TestLinalgCodegenStrategy.cpp TestLinalgDistribution.cpp @@ -12,13 +13,25 @@ add_mlir_library(MLIRLinalgTestPasses LINK_LIBS PUBLIC MLIRAffine + MLIRAffineBufferizableOpInterfaceImpl + MLIRArithBufferizableOpInterfaceImpl + MLIRArithmetic + MLIRBufferizableOpInterface + MLIRComprehensiveBufferize MLIRGPUTransforms MLIRLinalg + MLIRLinalgBufferizableOpInterfaceImpl MLIRLinalgTransforms MLIRLLVMToLLVMIRTranslation + MLIRMemRef MLIRPass + MLIRSCF + MLIRSCFBufferizableOpInterfaceImpl MLIRStandard + MLIRTensor + MLIRTensorBufferizableOpInterfaceImpl MLIRTransformUtils MLIRVector + MLIRVectorBufferizableOpInterfaceImpl MLIRVectorToSCF ) diff --git a/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp b/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp new file mode 100644 index 0000000..5ac15a9 --- /dev/null +++ b/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp @@ -0,0 +1,124 @@ +//===- TestComprehensiveBufferize.cpp - Test Comprehensive Bufferize ------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements logic for testing Comprehensive Bufferize. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.h" +#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.h" +#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h" +#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.h" +#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h" +#include "mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h" +#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h" +#include "mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h" +#include "mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h" +#include "mlir/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.h" +#include "mlir/Dialect/Linalg/IR/LinalgOps.h" +#include "mlir/Dialect/Linalg/Passes.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/Dialect/Vector/VectorOps.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/Passes.h" + +using namespace mlir; +using namespace mlir::linalg; +using namespace mlir::linalg::comprehensive_bufferize; + +namespace { +/// A helper struct for FunctionBufferize and ModuleBufferize. Both passes are +/// mostly identical. +struct TestComprehensiveFunctionBufferize + : public PassWrapper { + StringRef getArgument() const final { + return "test-comprehensive-function-bufferize"; + } + + StringRef getDescription() const final { + return "Test Comprehensive Bufferize of FuncOps (body only)."; + } + + TestComprehensiveFunctionBufferize() = default; + TestComprehensiveFunctionBufferize( + const TestComprehensiveFunctionBufferize &pass) {} + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + affine_ext::registerBufferizableOpInterfaceExternalModels(registry); + arith_ext::registerBufferizableOpInterfaceExternalModels(registry); + bufferization_ext::registerBufferizableOpInterfaceExternalModels(registry); + linalg_ext::registerBufferizableOpInterfaceExternalModels(registry); + scf_ext::registerBufferizableOpInterfaceExternalModels(registry); + tensor_ext::registerBufferizableOpInterfaceExternalModels(registry); + vector_ext::registerBufferizableOpInterfaceExternalModels(registry); + } + + void runOnFunction() override; + + Option allowReturnMemref{ + *this, "allow-return-memref", + llvm::cl::desc("Allow returning/yielding memrefs from functions/blocks"), + llvm::cl::init(false)}; + Option allowUnknownOps{ + *this, "allow-unknown-ops", + llvm::cl::desc( + "Allows the return of memrefs (for testing purposes only)"), + llvm::cl::init(false)}; + Option testAnalysisOnly{ + *this, "test-analysis-only", + llvm::cl::desc( + "Only runs inplaceability analysis (for testing purposes only)"), + llvm::cl::init(false)}; + Option analysisFuzzerSeed{ + *this, "analysis-fuzzer-seed", + llvm::cl::desc("Analyze ops in random order with a given seed (fuzzer)"), + llvm::cl::init(0)}; +}; +} // namespace + +void TestComprehensiveFunctionBufferize::runOnFunction() { + BufferizationOptions options; + + // Enable InitTensorOp elimination. + options.addPostAnalysisStep< + linalg_ext::InsertSliceAnchoredInitTensorEliminationStep>(); + // TODO: Find a way to enable this step automatically when bufferizing + // tensor dialect ops. + options.addPostAnalysisStep(); + options.addPostAnalysisStep(); + + options.allowReturnMemref = allowReturnMemref; + options.allowUnknownOps = allowUnknownOps; + options.testAnalysisOnly = testAnalysisOnly; + options.analysisFuzzerSeed = analysisFuzzerSeed; + + Operation *op = getFunction().getOperation(); + if (failed(runComprehensiveBufferize(op, options))) + return; + + OpPassManager cleanupPipeline("builtin.func"); + cleanupPipeline.addPass(createCanonicalizerPass()); + cleanupPipeline.addPass(createCSEPass()); + cleanupPipeline.addPass(createLoopInvariantCodeMotionPass()); + (void)this->runPipeline(cleanupPipeline, op); +} + +namespace mlir { +namespace test { +void registerTestComprehensiveFunctionBufferize() { + PassRegistration(); +} +} // namespace test +} // namespace mlir diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp index 6b77c37..72333a9 100644 --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -64,6 +64,7 @@ void registerTestAffineLoopParametricTilingPass(); void registerTestAliasAnalysisPass(); void registerTestBuiltinAttributeInterfaces(); void registerTestCallGraphPass(); +void registerTestComprehensiveFunctionBufferize(); void registerTestConstantFold(); void registerTestConvVectorization(); void registerTestGpuSerializeToCubinPass(); @@ -159,6 +160,7 @@ void registerTestPasses() { #if MLIR_ROCM_CONVERSIONS_ENABLED mlir::test::registerTestGpuSerializeToHsacoPass(); #endif + mlir::test::registerTestComprehensiveFunctionBufferize(); mlir::test::registerTestConvVectorization(); mlir::test::registerTestDecomposeCallGraphTypes(); mlir::test::registerTestDataLayoutQuery(); diff --git a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel index b0ed24a..c69a315 100644 --- a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel @@ -381,15 +381,27 @@ cc_library( deps = [ "//llvm:Support", "//mlir:Affine", + "//mlir:AffineBufferizableOpInterfaceImpl", + "//mlir:ArithBufferizableOpInterfaceImpl", "//mlir:ArithmeticDialect", + "//mlir:BufferizableOpInterface", + "//mlir:BufferizationDialect", + "//mlir:ComprehensiveBufferize", "//mlir:GPUDialect", "//mlir:IR", + "//mlir:LinalgBufferizableOpInterfaceImpl", "//mlir:LinalgOps", "//mlir:LinalgTransforms", + "//mlir:MemRefDialect", "//mlir:Pass", + "//mlir:SCFBufferizableOpInterfaceImpl", + "//mlir:SCFDialect", "//mlir:SCFTransforms", "//mlir:StandardOps", + "//mlir:TensorBufferizableOpInterfaceImpl", + "//mlir:TensorDialect", "//mlir:TransformUtils", + "//mlir:VectorBufferizableOpInterfaceImpl", "//mlir:VectorOps", "//mlir:VectorToSCF", ], -- 2.7.4