Add a new dialect interface for the OperationFolder `OpFolderDialectInterface`.
authorRiver Riddle <riverriddle@google.com>
Mon, 2 Sep 2019 03:06:42 +0000 (20:06 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Mon, 2 Sep 2019 03:07:08 +0000 (20:07 -0700)
This interface will allow for providing hooks to interrop with operation folding. The first hook, 'shouldMaterializeInto', will allow for controlling which region to insert materialized constants into. The folder will generally materialize constants into the top-level isolated region, this allows for materializing into a lower level ancestor region if it is more profitable/correct.

PiperOrigin-RevId: 266702972

mlir/include/mlir/Transforms/FoldUtils.h
mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
mlir/lib/Dialect/Linalg/Transforms/LowerToLoops.cpp
mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
mlir/lib/Transforms/Utils/FoldUtils.cpp
mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
mlir/test/Transforms/constant-fold.mlir
mlir/test/lib/TestDialect/TestDialect.cpp
mlir/test/lib/TestDialect/TestOps.td
mlir/test/lib/Transforms/TestConstantFold.cpp

index 87a3e13..bbf2c0e 100644 (file)
 
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/Dialect.h"
+#include "mlir/IR/DialectInterface.h"
 
 namespace mlir {
 class Operation;
 class Value;
 
+//===--------------------------------------------------------------------===//
+// Operation Folding Interface
+//===--------------------------------------------------------------------===//
+
+/// This class defines a dialect interface used to assist the operation folder.
+/// It provides hooks for materializing and folding operations.
+class OpFolderDialectInterface
+    : public DialectInterface::Base<OpFolderDialectInterface> {
+public:
+  OpFolderDialectInterface(Dialect *dialect) : Base(dialect) {}
+
+  /// Registered hook to check if the given region, which is attached to an
+  /// operation that is *not* isolated from above, should be used when
+  /// materializing constants. The folder will generally materialize constants
+  /// into the top-level isolated region, this allows for materializing into a
+  /// lower level ancestor region if it is more profitable/correct.
+  virtual bool shouldMaterializeInto(Region *region) const { return false; }
+};
+
+//===--------------------------------------------------------------------===//
+// OperationFolder
+//===--------------------------------------------------------------------===//
+
 /// A utility class for folding operations, and unifying duplicated constants
 /// generated along the way.
 class OperationFolder {
 public:
+  OperationFolder(MLIRContext *ctx) : interfaces(ctx) {}
+
   /// Tries to perform folding on the given `op`, including unifying
   /// deduplicated constants. If successful, replaces `op`'s uses with
   /// folded results, and returns success. `preReplaceAction` is invoked on `op`
@@ -116,6 +142,9 @@ private:
   /// This map tracks all of the dialects that an operation is referenced by;
   /// given that many dialects may generate the same constant.
   DenseMap<Operation *, SmallVector<Dialect *, 2>> referencedDialects;
+
+  /// A collection of dialect folder interfaces.
+  DialectInterfaceCollection<OpFolderDialectInterface> interfaces;
 };
 
 } // end namespace mlir
index d486064..954f826 100644 (file)
@@ -232,7 +232,7 @@ static bool isStructurallyFusableProducer(LinalgOp producer, Value *readView,
 }
 
 static void fuseLinalgOps(FuncOp f, ArrayRef<int64_t> tileSizes) {
-  OperationFolder state;
+  OperationFolder state(f.getContext());
   DenseSet<Operation *> eraseSet;
 
   LLVM_DEBUG(f.print(dbgs() << "\nBefore linalg-fusion: \n"));
index c48437f..54c0350 100644 (file)
@@ -288,8 +288,8 @@ template <typename ConcreteOp>
 class LinalgRewritePattern : public RewritePattern {
 public:
   explicit LinalgRewritePattern(MLIRContext *context)
-      : RewritePattern(ConcreteOp::getOperationName(), /*benefit=*/1, context) {
-  }
+      : RewritePattern(ConcreteOp::getOperationName(), /*benefit=*/1, context),
+        folder(context) {}
 
   PatternMatchResult matchAndRewrite(Operation *op,
                                      PatternRewriter &rewriter) const override {
index 11b3334..99e42cf 100644 (file)
@@ -489,7 +489,7 @@ mlir::linalg::tileLinalgOp(LinalgOp op, ArrayRef<int64_t> tileSizes,
 
 static void tileLinalgOps(FuncOp f, ArrayRef<int64_t> tileSizes,
                           bool promoteViews) {
-  OperationFolder folder;
+  OperationFolder folder(f.getContext());
   f.walk([promoteViews, tileSizes, &folder](LinalgOp op) {
     // TODO(ntv) some heuristic here to decide what to promote. Atm it is all or
     // nothing.
index 6c313e2..5faca12 100644 (file)
@@ -31,7 +31,9 @@ using namespace mlir;
 
 /// Given an operation, find the parent region that folded constants should be
 /// inserted into.
-static Region *getInsertionRegion(Operation *op) {
+static Region *getInsertionRegion(
+    DialectInterfaceCollection<OpFolderDialectInterface> &interfaces,
+    Operation *op) {
   while (Region *region = op->getParentRegion()) {
     // Insert in this region for any of the following scenarios:
     //  * The parent is unregistered, or is known to be isolated from above.
@@ -40,6 +42,12 @@ static Region *getInsertionRegion(Operation *op) {
     if (!parentOp->isRegistered() || parentOp->isKnownIsolatedFromAbove() ||
         !parentOp->getBlock())
       return region;
+
+    // Otherwise, check if this region is a desired insertion region.
+    auto *interface = interfaces.getInterfaceFor(parentOp);
+    if (LLVM_UNLIKELY(interface && interface->shouldMaterializeInto(region)))
+      return region;
+
     // Traverse up the parent looking for an insertion region.
     op = parentOp;
   }
@@ -119,7 +127,7 @@ void OperationFolder::notifyRemoval(Operation *op) {
   assert(constValue);
 
   // Get the constant map that this operation was uniqued in.
-  auto &uniquedConstants = foldScopes[getInsertionRegion(op)];
+  auto &uniquedConstants = foldScopes[getInsertionRegion(interfaces, op)];
 
   // Erase all of the references to this operation.
   auto type = op->getResult(0)->getType();
@@ -161,12 +169,12 @@ LogicalResult OperationFolder::tryToFold(
 
   // Create a builder to insert new operations into the entry block of the
   // insertion region.
-  auto *insertionRegion = getInsertionRegion(op);
-  auto &entry = insertionRegion->front();
+  auto *insertRegion = getInsertionRegion(interfaces, op);
+  auto &entry = insertRegion->front();
   OpBuilder builder(&entry, entry.begin());
 
   // Get the constant map for the insertion region of this operation.
-  auto &uniquedConstants = foldScopes[insertionRegion];
+  auto &uniquedConstants = foldScopes[insertRegion];
 
   // Create the result constants and replace the results.
   auto *dialect = op->getDialect();
index ddb92a5..86e8848 100644 (file)
@@ -45,7 +45,7 @@ class GreedyPatternRewriteDriver : public PatternRewriter {
 public:
   explicit GreedyPatternRewriteDriver(MLIRContext *ctx,
                                       const OwningRewritePatternList &patterns)
-      : PatternRewriter(ctx), matcher(patterns) {
+      : PatternRewriter(ctx), matcher(patterns), folder(ctx) {
     worklist.reserve(64);
   }
 
index c1db19e..edf8e5d 100644 (file)
@@ -456,3 +456,18 @@ func @nested_isolated_region() {
   }) : () -> ()
   return
 }
+
+// -----
+
+// CHECK-LABEL: func @custom_insertion_position
+func @custom_insertion_position() {
+  // CHECK: test.one_region_op
+  // CHECK-NEXT: constant 2
+  "test.one_region_op"() ({
+
+    %0 = constant 1 : i32
+    %2 = addi %0, %0 : i32
+    "foo.yield"(%2) : (i32) -> ()
+  }) : () -> ()
+  return
+}
index af5c5c8..84d4ed8 100644 (file)
 #include "TestDialect.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
+#include "mlir/Transforms/FoldUtils.h"
 
 using namespace mlir;
 
 //===----------------------------------------------------------------------===//
+// TestDialect Interfaces
+//===----------------------------------------------------------------------===//
+
+namespace {
+struct TestOpFolderDialectInterface : public OpFolderDialectInterface {
+  using OpFolderDialectInterface::OpFolderDialectInterface;
+
+  /// Registered hook to check if the given region, which is attached to an
+  /// operation that is *not* isolated from above, should be used when
+  /// materializing constants.
+  virtual bool shouldMaterializeInto(Region *region) const {
+    // If this is a one region operation, then insert into it.
+    return isa<OneRegionOp>(region->getParentOp());
+  }
+};
+} // end anonymous namespace
+
+//===----------------------------------------------------------------------===//
 // TestDialect
 //===----------------------------------------------------------------------===//
 
@@ -31,6 +50,7 @@ TestDialect::TestDialect(MLIRContext *context)
 #define GET_OP_LIST
 #include "TestOps.cpp.inc"
       >();
+  addInterfaces<TestOpFolderDialectInterface>();
   allowUnknownOperations();
 }
 
index 0010e1d..ee7a396 100644 (file)
@@ -162,6 +162,10 @@ def I64EnumAttrOp : TEST_Op<"i64_enum_attr"> {
 // Test Regions
 //===----------------------------------------------------------------------===//
 
+def OneRegionOp : TEST_Op<"one_region_op", []> {
+  let regions = (region AnyRegion);
+}
+
 def TwoRegionOp : TEST_Op<"two_region_op", []> {
   let regions = (region AnyRegion, AnyRegion);
 }
index 9c54169..b1c8952 100644 (file)
@@ -61,7 +61,7 @@ void TestConstantFold::runOnFunction() {
   // folding are at the beginning. This creates somewhat of a linear ordering to
   // the newly generated constants that matches the operation order and improves
   // the readability of test cases.
-  OperationFolder helper;
+  OperationFolder helper(&getContext());
   for (Operation *op : llvm::reverse(ops))
     foldOperation(op, helper);