[mlir] Add func-bufferize pass.
authorSean Silva <silvasean@google.com>
Mon, 26 Oct 2020 19:52:28 +0000 (12:52 -0700)
committerSean Silva <silvasean@google.com>
Mon, 2 Nov 2020 20:42:32 +0000 (12:42 -0800)
This is the most basic possible finalizing bufferization pass, which I
also think is sufficient for most new use cases. The more concentrated
nature of this pass also greatly clarifies the invariants that it
requires on its input to safely transform the program (see the
pass description in Passes.td).

With this pass, I have now upstreamed practically all of the
bufferizations from npcomp (the exception being std.constant, which can
be upstreamed when std.global_memref lands:
https://llvm.discourse.group/t/rfc-global-variables-in-mlir/2076/16 )

Differential Revision: https://reviews.llvm.org/D90205

mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td
mlir/include/mlir/Transforms/Bufferize.h
mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt
mlir/lib/Dialect/StandardOps/Transforms/FuncBufferize.cpp [new file with mode: 0644]
mlir/lib/Transforms/Bufferize.cpp
mlir/test/Dialect/Standard/func-bufferize.mlir [new file with mode: 0644]

index 714acdf..76fa79a 100644 (file)
@@ -38,6 +38,9 @@ void populateStdBufferizePatterns(MLIRContext *context,
 /// Creates an instance of std bufferization pass.
 std::unique_ptr<Pass> createStdBufferizePass();
 
+/// Creates an instance of func bufferization pass.
+std::unique_ptr<Pass> createFuncBufferizePass();
+
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//
index 1ccef1d..b0b172c 100644 (file)
@@ -22,4 +22,33 @@ def StdBufferize : FunctionPass<"std-bufferize"> {
   let dependentDialects = ["scf::SCFDialect"];
 }
 
+def FuncBufferize : Pass<"func-bufferize", "ModuleOp"> {
+  let summary = "Bufferize func/call/return ops";
+  let description = [{
+    A finalizing bufferize pass that bufferizes std.func and std.call ops.
+
+    Because this pass updates std.func ops, it must be a module pass. It is
+    useful to keep this pass separate from other bufferizations so that the
+    other ones can be run at function-level in parallel.
+
+    This pass must be done atomically for two reasons:
+    1. This pass changes func op signatures, which requires atomically updating
+       calls as well throughout the entire module.
+    2. This pass changes the type of block arguments, which requires that all
+       successor arguments of predecessors be converted. Terminators are not
+       a closed universe (and need not implement BranchOpInterface), and so we
+       cannot in general rewrite them.
+
+    Note, because this is a "finalizing" bufferize step, it can create
+    invalid IR because it will not create materializations. To avoid this
+    situation, the pass must only be run when the only SSA values of
+    tensor type are:
+    - block arguments
+    - the result of tensor_load
+    Other values of tensor type should be eliminated by earlier
+    bufferization passes.
+  }];
+  let constructor = "mlir::createFuncBufferizePass()";
+}
+
 #endif // MLIR_DIALECT_STANDARD_TRANSFORMS_PASSES
index e40c1bc..920eb6c 100644 (file)
@@ -150,9 +150,18 @@ private:
 /// This function should be called by all bufferization passes using
 /// BufferizeTypeConverter so that materializations work proprely. One exception
 /// is bufferization passes doing "full" conversions, where it can be desirable
-/// for even the materializations to remain illegal so that they are eliminated.
+/// for even the materializations to remain illegal so that they are eliminated,
+/// such as via the patterns in
+/// populateEliminateBufferizeMaterializationsPatterns.
 void populateBufferizeMaterializationLegality(ConversionTarget &target);
 
+/// Populate patterns to eliminate bufferize materializations.
+///
+/// In particular, these are the tensor_load/tensor_to_memref ops.
+void populateEliminateBufferizeMaterializationsPatterns(
+    MLIRContext *context, BufferizeTypeConverter &typeConverter,
+    OwningRewritePatternList &patterns);
+
 /// Helper conversion pattern that encapsulates a BufferizeTypeConverter
 /// instance.
 template <typename SourceOp>
index aabb81c..1334e7f 100644 (file)
@@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIRStandardOpsTransforms
   ExpandAtomic.cpp
   ExpandMemRefReshape.cpp
   ExpandTanh.cpp
+  FuncBufferize.cpp
   FuncConversions.cpp
 
   ADDITIONAL_HEADER_DIRS
diff --git a/mlir/lib/Dialect/StandardOps/Transforms/FuncBufferize.cpp b/mlir/lib/Dialect/StandardOps/Transforms/FuncBufferize.cpp
new file mode 100644 (file)
index 0000000..4aadb72
--- /dev/null
@@ -0,0 +1,56 @@
+//===- Bufferize.cpp - Bufferization for std ops --------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements bufferization of std.func's and std.call's.
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
+#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
+#include "mlir/Transforms/Bufferize.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+
+namespace {
+struct FuncBufferizePass : public FuncBufferizeBase<FuncBufferizePass> {
+  void runOnOperation() override {
+    auto module = getOperation();
+    auto *context = &getContext();
+
+    BufferizeTypeConverter typeConverter;
+    OwningRewritePatternList patterns;
+    ConversionTarget target(*context);
+
+    populateFuncOpTypeConversionPattern(patterns, context, typeConverter);
+    target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
+      return typeConverter.isSignatureLegal(op.getType()) &&
+             typeConverter.isLegal(&op.getBody());
+    });
+    populateCallOpTypeConversionPattern(patterns, context, typeConverter);
+    populateEliminateBufferizeMaterializationsPatterns(context, typeConverter,
+                                                       patterns);
+    target.addIllegalOp<TensorLoadOp, TensorToMemrefOp>();
+
+    // If all result types are legal, and all block arguments are legal (ensured
+    // by func conversion above), then all types in the program are legal.
+    target.markUnknownOpDynamicallyLegal([&](Operation *op) {
+      return typeConverter.isLegal(op->getResultTypes());
+    });
+
+    if (failed(applyFullConversion(module, target, std::move(patterns))))
+      signalPassFailure();
+  }
+};
+} // namespace
+
+std::unique_ptr<Pass> mlir::createFuncBufferizePass() {
+  return std::make_unique<FuncBufferizePass>();
+}
index d1f9109..1564290 100644 (file)
@@ -76,6 +76,45 @@ void mlir::populateBufferizeMaterializationLegality(ConversionTarget &target) {
   target.addLegalOp<TensorLoadOp, TensorToMemrefOp>();
 }
 
+namespace {
+// In a finalizing bufferize conversion, we know that all tensors have been
+// converted to memrefs, thus, this op becomes an identity.
+class BufferizeTensorLoadOp : public OpConversionPattern<TensorLoadOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(TensorLoadOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    TensorLoadOp::Adaptor adaptor(operands);
+    rewriter.replaceOp(op, adaptor.memref());
+    return success();
+  }
+};
+} // namespace
+
+namespace {
+// In a finalizing bufferize conversion, we know that all tensors have been
+// converted to memrefs, thus, this op becomes an identity.
+class BufferizeTensorToMemrefOp : public OpConversionPattern<TensorToMemrefOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(TensorToMemrefOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    TensorToMemrefOp::Adaptor adaptor(operands);
+    rewriter.replaceOp(op, adaptor.tensor());
+    return success();
+  }
+};
+} // namespace
+
+void mlir::populateEliminateBufferizeMaterializationsPatterns(
+    MLIRContext *context, BufferizeTypeConverter &typeConverter,
+    OwningRewritePatternList &patterns) {
+  patterns.insert<BufferizeTensorLoadOp, BufferizeTensorToMemrefOp>(
+      typeConverter, context);
+}
+
 //===----------------------------------------------------------------------===//
 // BufferizeFuncOpConverter
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Standard/func-bufferize.mlir b/mlir/test/Dialect/Standard/func-bufferize.mlir
new file mode 100644 (file)
index 0000000..20af66c
--- /dev/null
@@ -0,0 +1,64 @@
+// RUN: mlir-opt %s -func-bufferize -split-input-file -verify-diagnostics | FileCheck %s
+
+// CHECK-LABEL:   func @identity(
+// CHECK-SAME:        %[[ARG:.*]]: memref<f32>) -> memref<f32> {
+// CHECK:           return %[[ARG]] : memref<f32>
+func @identity(%arg0: tensor<f32>) -> tensor<f32> {
+  return %arg0 : tensor<f32>
+}
+
+// CHECK-LABEL:   func @block_arguments(
+// CHECK-SAME:        %[[ARG:.*]]: memref<f32>) -> memref<f32> {
+// CHECK:           br ^bb1(%[[ARG]] : memref<f32>)
+// CHECK:         ^bb1(%[[BBARG:.*]]: memref<f32>):
+// CHECK:           return %[[BBARG]] : memref<f32>
+func @block_arguments(%arg0: tensor<f32>) -> tensor<f32> {
+  br ^bb1(%arg0: tensor<f32>)
+^bb1(%bbarg: tensor<f32>):
+  return %bbarg : tensor<f32>
+}
+
+// CHECK-LABEL:   func @eliminate_target_materialization(
+// CHECK-SAME:        %[[ARG:.*]]: memref<f32>) -> memref<f32> {
+// CHECK:           return %[[ARG]] : memref<f32>
+func @eliminate_target_materialization(%arg0: tensor<f32>) -> memref<f32> {
+  %0 = tensor_to_memref %arg0 : memref<f32>
+  return %0 : memref<f32>
+}
+
+// CHECK-LABEL:   func @eliminate_source_materialization(
+// CHECK-SAME:        %[[ARG:.*]]: memref<f32>) -> memref<f32> {
+// CHECK:           return %[[ARG]] : memref<f32>
+func @eliminate_source_materialization(%arg0: memref<f32>) -> tensor<f32> {
+  %0 = tensor_load %arg0 : memref<f32>
+  return %0 : tensor<f32>
+}
+
+// CHECK-LABEL:   func @source() -> memref<f32>
+// CHECK-LABEL:   func @call_source() -> memref<f32> {
+// CHECK:           %[[RET:.*]] = call @source() : () -> memref<f32>
+// CHECK:           return %[[RET]] : memref<f32>
+func @source() -> tensor<f32>
+func @call_source() -> tensor<f32> {
+  %0 = call @source() : () -> tensor<f32>
+  return %0 : tensor<f32>
+}
+
+// CHECK-LABEL:   func @sink(memref<f32>)
+// CHECK-LABEL:   func @call_sink(
+// CHECK-SAME:        %[[ARG:.*]]: memref<f32>) {
+// CHECK:           call @sink(%[[ARG]]) : (memref<f32>) -> ()
+// CHECK:           return
+func @sink(tensor<f32>)
+func @call_sink(%arg0: tensor<f32>) {
+  call @sink(%arg0) : (tensor<f32>) -> ()
+  return
+}
+
+// -----
+
+func @failed_to_legalize() -> tensor<f32> {
+  // expected-error @+1 {{failed to legalize operation 'test.source'}}
+  %0 = "test.source"() : () -> (tensor<f32>)
+  return %0 : tensor<f32>
+}