Add hook for dialect specializing processing blocks post inlining calls
authorJacques Pienaar <jpienaar@google.com>
Wed, 16 Jun 2021 19:53:21 +0000 (12:53 -0700)
committerJacques Pienaar <jpienaar@google.com>
Wed, 16 Jun 2021 19:53:21 +0000 (12:53 -0700)
This allows for dialects to do different post-processing depending on operations with the inliner (my use case requires different attribute propagation rules depending on call op). This hook runs before the regular processInlinedBlocks method.

Differential Revision: https://reviews.llvm.org/D104399

mlir/include/mlir/Transforms/InliningUtils.h
mlir/lib/Transforms/Utils/InliningUtils.cpp
mlir/test/Transforms/inlining.mlir
mlir/test/lib/Dialect/Test/TestDialect.cpp

index a86a6b9..8dcc1f5 100644 (file)
@@ -140,6 +140,11 @@ public:
                                                Location conversionLoc) const {
     return nullptr;
   }
+
+  /// Process a set of blocks that have been inlined for a call. This callback
+  /// is invoked before inlined terminator operations have been processed.
+  virtual void processInlinedCallBlocks(
+      Operation *call, iterator_range<Region::iterator> inlinedBlocks) const {}
 };
 
 /// This interface provides the hooks into the inlining interface.
@@ -178,6 +183,8 @@ public:
   virtual void handleTerminator(Operation *op, Block *newDest) const;
   virtual void handleTerminator(Operation *op,
                                 ArrayRef<Value> valuesToRepl) const;
+  virtual void processInlinedCallBlocks(
+      Operation *call, iterator_range<Region::iterator> inlinedBlocks) const;
 };
 
 //===----------------------------------------------------------------------===//
@@ -209,8 +216,7 @@ LogicalResult inlineRegion(InlinerInterface &interface, Region *src,
 /// providing the set of operands ('inlinedOperands') that should be used
 /// in-favor of the region arguments when inlining.
 LogicalResult inlineRegion(InlinerInterface &interface, Region *src,
-                           Operation *inlinePoint,
-                           ValueRange inlinedOperands,
+                           Operation *inlinePoint, ValueRange inlinedOperands,
                            ValueRange resultsToReplace,
                            Optional<Location> inlineLoc = llvm::None,
                            bool shouldCloneInlinedRegion = true);
index 7d18de0..5b50d21 100644 (file)
@@ -106,6 +106,13 @@ void InlinerInterface::handleTerminator(Operation *op,
   handler->handleTerminator(op, valuesToRepl);
 }
 
+void InlinerInterface::processInlinedCallBlocks(
+    Operation *call, iterator_range<Region::iterator> inlinedBlocks) const {
+  auto *handler = getInterfaceFor(call);
+  assert(handler && "expected valid dialect handler");
+  handler->processInlinedCallBlocks(call, inlinedBlocks);
+}
+
 /// Utility to check that all of the operations within 'src' can be inlined.
 static bool isLegalToInline(InlinerInterface &interface, Region *src,
                             Region *insertRegion, bool shouldCloneInlinedRegion,
@@ -137,13 +144,12 @@ static bool isLegalToInline(InlinerInterface &interface, Region *src,
 // Inline Methods
 //===----------------------------------------------------------------------===//
 
-LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src,
-                                 Operation *inlinePoint,
-                                 BlockAndValueMapping &mapper,
-                                 ValueRange resultsToReplace,
-                                 TypeRange regionResultTypes,
-                                 Optional<Location> inlineLoc,
-                                 bool shouldCloneInlinedRegion) {
+static LogicalResult
+inlineRegionImpl(InlinerInterface &interface, Region *src,
+                 Operation *inlinePoint, BlockAndValueMapping &mapper,
+                 ValueRange resultsToReplace, TypeRange regionResultTypes,
+                 Optional<Location> inlineLoc, bool shouldCloneInlinedRegion,
+                 Operation *call) {
   assert(resultsToReplace.size() == regionResultTypes.size());
   // We expect the region to have at least one block.
   if (src->empty())
@@ -198,6 +204,8 @@ LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src,
     remapInlinedOperands(newBlocks, mapper);
 
   // Process the newly inlined blocks.
+  if (call)
+    interface.processInlinedCallBlocks(call, newBlocks);
   interface.processInlinedBlocks(newBlocks);
 
   // Handle the case where only a single block was inlined.
@@ -232,15 +240,11 @@ LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src,
   return success();
 }
 
-/// This function is an overload of the above 'inlineRegion' that allows for
-/// providing the set of operands ('inlinedOperands') that should be used
-/// in-favor of the region arguments when inlining.
-LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src,
-                                 Operation *inlinePoint,
-                                 ValueRange inlinedOperands,
-                                 ValueRange resultsToReplace,
-                                 Optional<Location> inlineLoc,
-                                 bool shouldCloneInlinedRegion) {
+static LogicalResult
+inlineRegionImpl(InlinerInterface &interface, Region *src,
+                 Operation *inlinePoint, ValueRange inlinedOperands,
+                 ValueRange resultsToReplace, Optional<Location> inlineLoc,
+                 bool shouldCloneInlinedRegion, Operation *call) {
   // We expect the region to have at least one block.
   if (src->empty())
     return failure();
@@ -261,9 +265,33 @@ LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src,
   }
 
   // Call into the main region inliner function.
-  return inlineRegion(interface, src, inlinePoint, mapper, resultsToReplace,
-                      resultsToReplace.getTypes(), inlineLoc,
-                      shouldCloneInlinedRegion);
+  return inlineRegionImpl(interface, src, inlinePoint, mapper, resultsToReplace,
+                          resultsToReplace.getTypes(), inlineLoc,
+                          shouldCloneInlinedRegion, call);
+}
+
+LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src,
+                                 Operation *inlinePoint,
+                                 BlockAndValueMapping &mapper,
+                                 ValueRange resultsToReplace,
+                                 TypeRange regionResultTypes,
+                                 Optional<Location> inlineLoc,
+                                 bool shouldCloneInlinedRegion) {
+  return inlineRegionImpl(interface, src, inlinePoint, mapper, resultsToReplace,
+                          regionResultTypes, inlineLoc,
+                          shouldCloneInlinedRegion,
+                          /*call=*/nullptr);
+}
+
+LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src,
+                                 Operation *inlinePoint,
+                                 ValueRange inlinedOperands,
+                                 ValueRange resultsToReplace,
+                                 Optional<Location> inlineLoc,
+                                 bool shouldCloneInlinedRegion) {
+  return inlineRegionImpl(interface, src, inlinePoint, inlinedOperands,
+                          resultsToReplace, inlineLoc, shouldCloneInlinedRegion,
+                          /*call=*/nullptr);
 }
 
 /// Utility function used to generate a cast operation from the given interface,
@@ -371,9 +399,9 @@ LogicalResult mlir::inlineCall(InlinerInterface &interface,
     return cleanupState();
 
   // Attempt to inline the call.
-  if (failed(inlineRegion(interface, src, call, mapper, callResults,
-                          callableResultTypes, call.getLoc(),
-                          shouldCloneInlinedRegion)))
+  if (failed(inlineRegionImpl(interface, src, call, mapper, callResults,
+                              callableResultTypes, call.getLoc(),
+                              shouldCloneInlinedRegion, call)))
     return cleanupState();
   return success();
 }
index d568be0..e0368b2 100644 (file)
@@ -140,9 +140,9 @@ func @convert_callee_fn_multiblock() -> i32 {
 
 // CHECK-LABEL: func @inline_convert_result_multiblock
 func @inline_convert_result_multiblock() -> i16 {
-// CHECK:   br ^bb1
+// CHECK:   br ^bb1 {inlined_conversion}
 // CHECK: ^bb1:
-// CHECK:   %[[C:.+]] = constant 0 : i32
+// CHECK:   %[[C:.+]] = constant {inlined_conversion} 0 : i32
 // CHECK:   br ^bb2(%[[C]] : i32)
 // CHECK: ^bb2(%[[BBARG:.+]]: i32):
 // CHECK:   %[[CAST_RESULT:.+]] = "test.cast"(%[[BBARG]]) : (i32) -> i16
index a21e32a..8ef6ec6 100644 (file)
@@ -171,6 +171,20 @@ struct TestInlinerInterface : public DialectInlinerInterface {
       return nullptr;
     return builder.create<TestCastOp>(conversionLoc, resultType, input);
   }
+
+  void processInlinedCallBlocks(
+      Operation *call,
+      iterator_range<Region::iterator> inlinedBlocks) const final {
+    if (!isa<ConversionCallOp>(call))
+      return;
+
+    // Set attributed on all ops in the inlined blocks.
+    for (Block &block : inlinedBlocks) {
+      block.walk([&](Operation *op) {
+        op->setAttr("inlined_conversion", UnitAttr::get(call->getContext()));
+      });
+    }
+  }
 };
 
 struct TestReductionPatternInterface : public DialectReductionPatternInterface {