#include "mlir/IR/Builders.h"
#include "mlir/IR/Dialect.h"
+#include "mlir/IR/DialectInterface.h"
namespace mlir {
class Operation;
class Value;
+//===--------------------------------------------------------------------===//
+// Operation Folding Interface
+//===--------------------------------------------------------------------===//
+
+/// This class defines a dialect interface used to assist the operation folder.
+/// It provides hooks for materializing and folding operations.
+class OpFolderDialectInterface
+ : public DialectInterface::Base<OpFolderDialectInterface> {
+public:
+ OpFolderDialectInterface(Dialect *dialect) : Base(dialect) {}
+
+ /// Registered hook to check if the given region, which is attached to an
+ /// operation that is *not* isolated from above, should be used when
+ /// materializing constants. The folder will generally materialize constants
+ /// into the top-level isolated region, this allows for materializing into a
+ /// lower level ancestor region if it is more profitable/correct.
+ virtual bool shouldMaterializeInto(Region *region) const { return false; }
+};
+
+//===--------------------------------------------------------------------===//
+// OperationFolder
+//===--------------------------------------------------------------------===//
+
/// A utility class for folding operations, and unifying duplicated constants
/// generated along the way.
class OperationFolder {
public:
+ OperationFolder(MLIRContext *ctx) : interfaces(ctx) {}
+
/// Tries to perform folding on the given `op`, including unifying
/// deduplicated constants. If successful, replaces `op`'s uses with
/// folded results, and returns success. `preReplaceAction` is invoked on `op`
/// This map tracks all of the dialects that an operation is referenced by;
/// given that many dialects may generate the same constant.
DenseMap<Operation *, SmallVector<Dialect *, 2>> referencedDialects;
+
+ /// A collection of dialect folder interfaces.
+ DialectInterfaceCollection<OpFolderDialectInterface> interfaces;
};
} // end namespace mlir
}
static void fuseLinalgOps(FuncOp f, ArrayRef<int64_t> tileSizes) {
- OperationFolder state;
+ OperationFolder state(f.getContext());
DenseSet<Operation *> eraseSet;
LLVM_DEBUG(f.print(dbgs() << "\nBefore linalg-fusion: \n"));
class LinalgRewritePattern : public RewritePattern {
public:
explicit LinalgRewritePattern(MLIRContext *context)
- : RewritePattern(ConcreteOp::getOperationName(), /*benefit=*/1, context) {
- }
+ : RewritePattern(ConcreteOp::getOperationName(), /*benefit=*/1, context),
+ folder(context) {}
PatternMatchResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
static void tileLinalgOps(FuncOp f, ArrayRef<int64_t> tileSizes,
bool promoteViews) {
- OperationFolder folder;
+ OperationFolder folder(f.getContext());
f.walk([promoteViews, tileSizes, &folder](LinalgOp op) {
// TODO(ntv) some heuristic here to decide what to promote. Atm it is all or
// nothing.
/// Given an operation, find the parent region that folded constants should be
/// inserted into.
-static Region *getInsertionRegion(Operation *op) {
+static Region *getInsertionRegion(
+ DialectInterfaceCollection<OpFolderDialectInterface> &interfaces,
+ Operation *op) {
while (Region *region = op->getParentRegion()) {
// Insert in this region for any of the following scenarios:
// * The parent is unregistered, or is known to be isolated from above.
if (!parentOp->isRegistered() || parentOp->isKnownIsolatedFromAbove() ||
!parentOp->getBlock())
return region;
+
+ // Otherwise, check if this region is a desired insertion region.
+ auto *interface = interfaces.getInterfaceFor(parentOp);
+ if (LLVM_UNLIKELY(interface && interface->shouldMaterializeInto(region)))
+ return region;
+
// Traverse up the parent looking for an insertion region.
op = parentOp;
}
assert(constValue);
// Get the constant map that this operation was uniqued in.
- auto &uniquedConstants = foldScopes[getInsertionRegion(op)];
+ auto &uniquedConstants = foldScopes[getInsertionRegion(interfaces, op)];
// Erase all of the references to this operation.
auto type = op->getResult(0)->getType();
// Create a builder to insert new operations into the entry block of the
// insertion region.
- auto *insertionRegion = getInsertionRegion(op);
- auto &entry = insertionRegion->front();
+ auto *insertRegion = getInsertionRegion(interfaces, op);
+ auto &entry = insertRegion->front();
OpBuilder builder(&entry, entry.begin());
// Get the constant map for the insertion region of this operation.
- auto &uniquedConstants = foldScopes[insertionRegion];
+ auto &uniquedConstants = foldScopes[insertRegion];
// Create the result constants and replace the results.
auto *dialect = op->getDialect();
public:
explicit GreedyPatternRewriteDriver(MLIRContext *ctx,
const OwningRewritePatternList &patterns)
- : PatternRewriter(ctx), matcher(patterns) {
+ : PatternRewriter(ctx), matcher(patterns), folder(ctx) {
worklist.reserve(64);
}
}) : () -> ()
return
}
+
+// -----
+
+// CHECK-LABEL: func @custom_insertion_position
+func @custom_insertion_position() {
+ // CHECK: test.one_region_op
+ // CHECK-NEXT: constant 2
+ "test.one_region_op"() ({
+
+ %0 = constant 1 : i32
+ %2 = addi %0, %0 : i32
+ "foo.yield"(%2) : (i32) -> ()
+ }) : () -> ()
+ return
+}
#include "TestDialect.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Transforms/FoldUtils.h"
using namespace mlir;
//===----------------------------------------------------------------------===//
+// TestDialect Interfaces
+//===----------------------------------------------------------------------===//
+
+namespace {
+struct TestOpFolderDialectInterface : public OpFolderDialectInterface {
+ using OpFolderDialectInterface::OpFolderDialectInterface;
+
+ /// Registered hook to check if the given region, which is attached to an
+ /// operation that is *not* isolated from above, should be used when
+ /// materializing constants.
+ virtual bool shouldMaterializeInto(Region *region) const {
+ // If this is a one region operation, then insert into it.
+ return isa<OneRegionOp>(region->getParentOp());
+ }
+};
+} // end anonymous namespace
+
+//===----------------------------------------------------------------------===//
// TestDialect
//===----------------------------------------------------------------------===//
#define GET_OP_LIST
#include "TestOps.cpp.inc"
>();
+ addInterfaces<TestOpFolderDialectInterface>();
allowUnknownOperations();
}
// Test Regions
//===----------------------------------------------------------------------===//
+def OneRegionOp : TEST_Op<"one_region_op", []> {
+ let regions = (region AnyRegion);
+}
+
def TwoRegionOp : TEST_Op<"two_region_op", []> {
let regions = (region AnyRegion, AnyRegion);
}
// folding are at the beginning. This creates somewhat of a linear ordering to
// the newly generated constants that matches the operation order and improves
// the readability of test cases.
- OperationFolder helper;
+ OperationFolder helper(&getContext());
for (Operation *op : llvm::reverse(ops))
foldOperation(op, helper);