[mlir][shape] Migrate bufferization to BufferizableOpInterface
Author:     Matthias Springer <springerm@google.com>
AuthorDate: Mon, 7 Mar 2022 12:27:53 +0000 (21:27 +0900)
Commit:     Matthias Springer <springerm@google.com>
CommitDate: Mon, 7 Mar 2022 12:54:27 +0000 (21:54 +0900)
Differential Revision: https://reviews.llvm.org/D121043

mlir/include/mlir/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.h [new file with mode: 0644]
mlir/include/mlir/Dialect/Shape/Transforms/Passes.h
mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp [new file with mode: 0644]
mlir/lib/Dialect/Shape/Transforms/Bufferize.cpp
mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt
mlir/lib/Dialect/Shape/Transforms/StructuralTypeConversions.cpp [deleted file]
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

diff --git a/mlir/include/mlir/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.h b/mlir/include/mlir/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.h
new file mode 100644 (file)
index 0000000..bd89869
--- /dev/null
@@ -0,0 +1,20 @@
+//===- BufferizableOpInterfaceImpl.h - Impl. of BufferizableOpInterface ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SHAPE_BUFFERIZABLEOPINTERFACEIMPL_H
+#define MLIR_DIALECT_SHAPE_BUFFERIZABLEOPINTERFACEIMPL_H
+
+namespace mlir {
+class DialectRegistry;
+
+namespace shape {
+void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry);
+} // namespace shape
+} // namespace mlir
+
+#endif // MLIR_DIALECT_SHAPE_BUFFERIZABLEOPINTERFACEIMPL_H
index f2d69ac..157e1e2 100644 (file)
@@ -40,21 +40,6 @@ void populateShapeRewritePatterns(RewritePatternSet &patterns);
 void populateRemoveShapeConstraintsPatterns(RewritePatternSet &patterns);
 std::unique_ptr<OperationPass<FuncOp>> createRemoveShapeConstraintsPass();
 
-/// Populates patterns for shape dialect structural type conversions and sets up
-/// the provided ConversionTarget with the appropriate legality configuration
-/// for the ops to get converted properly.
-///
-/// A "structural" type conversion is one where the underlying ops are
-/// completely agnostic to the actual types involved and simply need to update
-/// their types consistently. An example of this is shape.assuming -- the
-/// shape.assuming op and the corresponding shape.assuming_yield op need to have
-/// consistent types, but the exact types don't matter. So all that we need to
-/// do for a structural type conversion is to update both of their types
-/// consistently to the new types prescribed by the TypeConverter.
-void populateShapeStructuralTypeConversionsAndLegality(
-    TypeConverter &typeConverter, RewritePatternSet &patterns,
-    ConversionTarget &target);
-
 // Bufferizes shape dialect ops.
 //
 // Note that most shape dialect ops must be converted to std before
diff --git a/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp
new file mode 100644 (file)
index 0000000..9e21e0b
--- /dev/null
@@ -0,0 +1,169 @@
+//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.h"
+
+#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Shape/IR/Shape.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/PatternMatch.h"
+
+using namespace mlir;
+using namespace mlir::bufferization;
+using namespace mlir::shape;
+
+namespace mlir {
+namespace shape {
+namespace {
+
+/// Bufferization of shape.assuming. Results alias the yielded values.
+struct AssumingOpInterface
+    : public BufferizableOpInterface::ExternalModel<AssumingOpInterface,
+                                                    shape::AssumingOp> {
+  SmallVector<OpOperand *>
+  getAliasingOpOperand(Operation *op, OpResult opResult,
+                       const BufferizationState &state) const {
+    // AssumingOps do not have tensor OpOperands. The yielded value can be any
+    // SSA value that is in scope. To allow for use-def chain traversal through
+    // AssumingOps in the analysis, the corresponding yield value is considered
+    // to be aliasing with the result.
+    auto assumingOp = cast<shape::AssumingOp>(op);
+    size_t resultNum = std::distance(op->getOpResults().begin(),
+                                     llvm::find(op->getOpResults(), opResult));
+    // TODO: Support multiple blocks.
+    assert(assumingOp.getDoRegion().getBlocks().size() == 1 &&
+           "expected exactly 1 block");
+    auto yieldOp = dyn_cast<shape::AssumingYieldOp>(
+        assumingOp.getDoRegion().front().getTerminator());
+    assert(yieldOp && "expected shape.assuming_yield terminator");
+    return {&yieldOp->getOpOperand(resultNum)};
+  }
+
+  // TODO: For better bufferization results, this could return `true` only if
+  // there is a memory write in the region.
+  bool isMemoryWrite(Operation *op, OpResult opResult,
+                     const BufferizationState &state) const {
+    // Similar to scf.if, results of this op are always considered memory writes
+    // in the analysis. This is a useful pattern for all ops that have tensor
+    // OpResults but no tensor OpOperands. By default, `isMemoryWrite` is
+    // implemented in terms of `bufferizesToMemoryWrite`, which does not work on
+    // ops without OpOperands.
+    return true;
+  }
+
+  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+                          const BufferizationState &state) const {
+    auto assumingOp = cast<shape::AssumingOp>(op);
+
+    // Compute new result types: tensor results become memref results.
+    SmallVector<Type> newResultTypes;
+    for (Type type : assumingOp->getResultTypes()) {
+      if (auto tensorType = type.dyn_cast<TensorType>()) {
+        newResultTypes.push_back(getMemRefType(tensorType, state.getOptions()));
+      } else {
+        newResultTypes.push_back(type);
+      }
+    }
+
+    // Create a new op with the new result types and steal the old op's region.
+    auto newOp = rewriter.create<shape::AssumingOp>(
+        op->getLoc(), newResultTypes, assumingOp.getWitness());
+    newOp.getDoRegion().takeBody(assumingOp.getRegion());
+
+    // Update terminator: cast yielded tensors to memrefs via to_memref.
+    assert(newOp.getDoRegion().getBlocks().size() == 1 &&
+           "only 1 block supported");
+    Block *newBlock = &newOp.getDoRegion().front();
+    auto yieldOp = cast<shape::AssumingYieldOp>(newBlock->getTerminator());
+    rewriter.setInsertionPoint(yieldOp);
+    SmallVector<Value> newYieldValues;
+    for (const auto &it : llvm::enumerate(yieldOp.operands())) {
+      Value val = it.value();
+      if (val.getType().isa<TensorType>()) {
+        newYieldValues.push_back(rewriter.create<bufferization::ToMemrefOp>(
+            yieldOp.getLoc(), newResultTypes[it.index()], val));
+      } else {
+        newYieldValues.push_back(val);
+      }
+    }
+    rewriter.replaceOpWithNewOp<shape::AssumingYieldOp>(yieldOp,
+                                                        newYieldValues);
+
+    // Rewire uses of the old op: cast memref results back via to_tensor.
+    rewriter.setInsertionPointAfter(newOp);
+    SmallVector<Value> newResults;
+    for (const auto &it : llvm::enumerate(assumingOp->getResultTypes())) {
+      if (it.value().isa<TensorType>()) {
+        newResults.push_back(rewriter.create<bufferization::ToTensorOp>(
+            assumingOp.getLoc(), newOp->getResult(it.index())));
+      } else {
+        newResults.push_back(newOp->getResult(it.index()));
+      }
+    }
+
+    // Replace old op with the rewired results.
+    rewriter.replaceOp(assumingOp, newResults);
+
+    return success();
+  }
+
+  BufferRelation bufferRelation(Operation *op, OpResult opResult,
+                                const BufferizationState &state) const {
+    return BufferRelation::Equivalent; // Result buffer == yielded buffer.
+  }
+};
+
+/// Bufferization of shape.assuming_yield. Bufferized as part of their enclosing
+/// ops, so this is for analysis only.
+struct AssumingYieldOpInterface
+    : public BufferizableOpInterface::ExternalModel<AssumingYieldOpInterface,
+                                                    shape::AssumingOp> {
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+                              const BufferizationState &state) const {
+    return true;
+  }
+
+  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+                               const BufferizationState &state) const {
+    return false;
+  }
+
+  SmallVector<OpResult>
+  getAliasingOpResult(Operation *op, OpOperand &opOperand,
+                      const BufferizationState &state) const {
+    assert(isa<shape::AssumingOp>(op->getParentOp()) &&
+           "expected that parent is an AssumingOp");
+    return {op->getParentOp()->getResult(opOperand.getOperandNumber())};
+  }
+
+  bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
+                            const BufferizationState &state) const {
+    // Yield operands always bufferize inplace. Otherwise, an alloc + copy
+    // may be generated inside the block. We should not return/yield allocations
+    // when possible.
+    return true;
+  }
+
+  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+                          const BufferizationState &state) const {
+    // Op is bufferized as part of AssumingOp.
+    return failure();
+  }
+};
+
+} // namespace
+} // namespace shape
+} // namespace mlir
+
+void mlir::shape::registerBufferizableOpInterfaceExternalModels(
+    DialectRegistry &registry) {
+  registry.addOpInterface<shape::AssumingOp, AssumingOpInterface>(); // shape.assuming
+  registry.addOpInterface<shape::AssumingYieldOp, AssumingYieldOpInterface>(); // shape.assuming_yield
+}
index 7b712dc..48c8202 100644 (file)
@@ -8,30 +8,32 @@
 
 #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
 #include "PassDetail.h"
+#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Shape/IR/Shape.h"
+#include "mlir/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.h"
 #include "mlir/Dialect/Shape/Transforms/Passes.h"
 #include "mlir/Pass/Pass.h"
 
 using namespace mlir;
+using namespace bufferization;
 
 namespace {
 struct ShapeBufferizePass : public ShapeBufferizeBase<ShapeBufferizePass> {
   void runOnOperation() override {
-    MLIRContext &ctx = getContext();
+    BufferizationOptions options = getPartialBufferizationOptions();
+    options.allowDialectInFilter<shape::ShapeDialect>(); // Bufferize shape ops only.
 
-    RewritePatternSet patterns(&ctx);
-    bufferization::BufferizeTypeConverter typeConverter;
-    ConversionTarget target(ctx);
-
-    bufferization::populateBufferizeMaterializationLegality(target);
-    populateShapeStructuralTypeConversionsAndLegality(typeConverter, patterns,
-                                                      target);
-
-    if (failed(applyPartialConversion(getOperation(), target,
-                                      std::move(patterns))))
+    if (failed(bufferizeOp(getOperation(), options))) // Interface-based bufferization.
       signalPassFailure();
   }
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<bufferization::BufferizationDialect, memref::MemRefDialect,
+                    shape::ShapeDialect>();
+    shape::registerBufferizableOpInterfaceExternalModels(registry); // Attach models.
+  }
 };
 } // namespace
 
index 3402085..c3d20e1 100644 (file)
@@ -1,8 +1,8 @@
 add_mlir_dialect_library(MLIRShapeOpsTransforms
+  BufferizableOpInterfaceImpl.cpp
   Bufferize.cpp
   RemoveShapeConstraints.cpp
   ShapeToShapeLowering.cpp
-  StructuralTypeConversions.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ShapeOps/Transforms
@@ -14,6 +14,7 @@ add_mlir_dialect_library(MLIRShapeOpsTransforms
 target_link_libraries(MLIRShapeOpsTransforms
   PUBLIC
   MLIRArithmetic
+  MLIRBufferization
   MLIRBufferizationTransforms
   MLIRIR
   MLIRMemRef
diff --git a/mlir/lib/Dialect/Shape/Transforms/StructuralTypeConversions.cpp b/mlir/lib/Dialect/Shape/Transforms/StructuralTypeConversions.cpp
deleted file mode 100644 (file)
index e368eca..0000000
+++ /dev/null
@@ -1,70 +0,0 @@
-//===- StructuralTypeConversions.cpp - Shape structural type conversions --===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "PassDetail.h"
-#include "mlir/Dialect/Shape/IR/Shape.h"
-#include "mlir/Dialect/Shape/Transforms/Passes.h"
-#include "mlir/Transforms/DialectConversion.h"
-
-using namespace mlir;
-using namespace mlir::shape;
-
-namespace {
-class ConvertAssumingOpTypes : public OpConversionPattern<AssumingOp> {
-public:
-  using OpConversionPattern::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(AssumingOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const final {
-    SmallVector<Type, 2> newResultTypes;
-    newResultTypes.reserve(op.getNumResults());
-    for (auto result : op.getResults()) {
-      auto originalType = result.getType();
-      Type convertedType = getTypeConverter()->convertType(originalType);
-      newResultTypes.push_back(convertedType);
-    }
-
-    auto newAssumingOp = rewriter.create<AssumingOp>(
-        op.getLoc(), newResultTypes, op.getWitness());
-    rewriter.inlineRegionBefore(op.getDoRegion(), newAssumingOp.getDoRegion(),
-                                newAssumingOp.getDoRegion().end());
-    rewriter.replaceOp(op, newAssumingOp.getResults());
-
-    return success();
-  }
-};
-} // namespace
-
-namespace {
-class ConvertAssumingYieldOpTypes
-    : public OpConversionPattern<AssumingYieldOp> {
-public:
-  using OpConversionPattern::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(AssumingYieldOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const final {
-    rewriter.replaceOpWithNewOp<AssumingYieldOp>(op, adaptor.getOperands());
-    return success();
-  }
-};
-} // namespace
-
-void mlir::populateShapeStructuralTypeConversionsAndLegality(
-    TypeConverter &typeConverter, RewritePatternSet &patterns,
-    ConversionTarget &target) {
-  patterns.add<ConvertAssumingOpTypes, ConvertAssumingYieldOpTypes>(
-      typeConverter, patterns.getContext());
-  target.addDynamicallyLegalOp<AssumingOp>([&](AssumingOp op) {
-    return typeConverter.isLegal(op.getResultTypes());
-  });
-  target.addDynamicallyLegalOp<AssumingYieldOp>([&](AssumingYieldOp op) {
-    return typeConverter.isLegal(op.getOperandTypes());
-  });
-}
index 5d76e2f..bd7cb46 100644 (file)
@@ -2702,7 +2702,10 @@ cc_library(
         "lib/Dialect/Shape/Transforms/*.cpp",
         "lib/Dialect/Shape/Transforms/*.h",
     ]),
-    hdrs = ["include/mlir/Dialect/Shape/Transforms/Passes.h"],
+    hdrs = [
+        "include/mlir/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.h",
+        "include/mlir/Dialect/Shape/Transforms/Passes.h",
+    ],
     includes = ["include"],
     deps = [
         ":ArithmeticDialect",