[mlir][Inliner] Add a new hook for checking if it is legal to inline a callable into...
authorRiver Riddle <riddleriver@gmail.com>
Thu, 29 Oct 2020 04:48:38 +0000 (21:48 -0700)
committerRiver Riddle <riddleriver@gmail.com>
Thu, 29 Oct 2020 04:49:28 +0000 (21:49 -0700)
In certain situations it isn't legal to inline a call operation, but this isn't something that is possible(at least not easily) to prevent with the current hooks. This revision adds a new hook so that dialects with call operations that shouldn't be inlined can prevent it.

Differential Revision: https://reviews.llvm.org/D90359

12 files changed:
mlir/docs/Tutorials/Toy/Ch-4.md
mlir/examples/toy/Ch4/mlir/Dialect.cpp
mlir/examples/toy/Ch5/mlir/Dialect.cpp
mlir/examples/toy/Ch6/mlir/Dialect.cpp
mlir/examples/toy/Ch7/mlir/Dialect.cpp
mlir/include/mlir/IR/Operation.h
mlir/include/mlir/Transforms/InliningUtils.h
mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
mlir/lib/Dialect/StandardOps/IR/Ops.cpp
mlir/lib/Transforms/Utils/InliningUtils.cpp
mlir/test/Transforms/inlining.mlir
mlir/test/lib/Dialect/Test/TestDialect.cpp

index 11e6ddf..0580413 100644 (file)
@@ -61,6 +61,13 @@ In this case, the interface is `DialectInlinerInterface`.
 struct ToyInlinerInterface : public DialectInlinerInterface {
   using DialectInlinerInterface::DialectInlinerInterface;
 
+  /// This hook checks to see if the given callable operation is legal to inline
+  /// into the given call. For Toy this hook can simply return true, as the Toy
+  /// Call operation is always inlinable.
+  bool isLegalToInline(Operation *call, Operation *callable) const final {
+    return true;
+  }
+
   /// This hook checks to see if the given operation is legal to inline into the
   /// given region. For Toy this hook can simply return true, as all Toy
   /// operations are inlinable.
index ca568a5..462de2b 100644 (file)
@@ -34,6 +34,11 @@ struct ToyInlinerInterface : public DialectInlinerInterface {
   // Analysis Hooks
   //===--------------------------------------------------------------------===//
 
+  /// All call operations within toy can be inlined.
+  bool isLegalToInline(Operation *call, Operation *callable) const final {
+    return true;
+  }
+
   /// All operations within toy can be inlined.
   bool isLegalToInline(Operation *, Region *,
                        BlockAndValueMapping &) const final {
index d1a518e..87bd185 100644 (file)
@@ -34,6 +34,11 @@ struct ToyInlinerInterface : public DialectInlinerInterface {
   // Analysis Hooks
   //===--------------------------------------------------------------------===//
 
+  /// All call operations within toy can be inlined.
+  bool isLegalToInline(Operation *call, Operation *callable) const final {
+    return true;
+  }
+
   /// All operations within toy can be inlined.
   bool isLegalToInline(Operation *, Region *,
                        BlockAndValueMapping &) const final {
index d1a518e..87bd185 100644 (file)
@@ -34,6 +34,11 @@ struct ToyInlinerInterface : public DialectInlinerInterface {
   // Analysis Hooks
   //===--------------------------------------------------------------------===//
 
+  /// All call operations within toy can be inlined.
+  bool isLegalToInline(Operation *call, Operation *callable) const final {
+    return true;
+  }
+
   /// All operations within toy can be inlined.
   bool isLegalToInline(Operation *, Region *,
                        BlockAndValueMapping &) const final {
index 046637f..14d764e 100644 (file)
@@ -35,6 +35,11 @@ struct ToyInlinerInterface : public DialectInlinerInterface {
   // Analysis Hooks
   //===--------------------------------------------------------------------===//
 
+  /// All call operations within toy can be inlined.
+  bool isLegalToInline(Operation *call, Operation *callable) const final {
+    return true;
+  }
+
   /// All operations within toy can be inlined.
   bool isLegalToInline(Operation *, Region *,
                        BlockAndValueMapping &) const final {
index 6de7677..d3dce86 100644 (file)
@@ -323,11 +323,20 @@ public:
   template <typename AttrClass> AttrClass getAttrOfType(Identifier name) {
     return getAttr(name).dyn_cast_or_null<AttrClass>();
   }
-
   template <typename AttrClass> AttrClass getAttrOfType(StringRef name) {
     return getAttr(name).dyn_cast_or_null<AttrClass>();
   }
 
+  /// Return true if the operation has an attribute with the provided name,
+  /// false otherwise.
+  bool hasAttr(Identifier name) { return static_cast<bool>(getAttr(name)); }
+  bool hasAttr(StringRef name) { return static_cast<bool>(getAttr(name)); }
+  template <typename AttrClass, typename NameT>
+  bool hasAttrOfType(NameT &&name) {
+    return static_cast<bool>(
+        getAttrOfType<AttrClass>(std::forward<NameT>(name)));
+  }
+
   /// If the an attribute exists with the specified name, change it to the new
   /// value.  Otherwise, add a new attribute with the specified name/value.
   void setAttr(Identifier name, Attribute value) { attrs.set(name, value); }
index a526d0f..9c4fdf2 100644 (file)
@@ -47,6 +47,14 @@ public:
   // Analysis Hooks
   //===--------------------------------------------------------------------===//
 
+  /// Returns true if the given operation 'callable', that implements the
+  /// 'CallableOpInterface', can be inlined into the position given call
+  /// operation 'call', that is registered to the current dialect and implements
+  /// the `CallOpInterface`.
+  virtual bool isLegalToInline(Operation *call, Operation *callable) const {
+    return false;
+  }
+
   /// Returns true if the given region 'src' can be inlined into the region
   /// 'dest' that is attached to an operation registered to the current dialect.
   /// 'valueMapping' contains any remapped values from within the 'src' region.
@@ -146,6 +154,7 @@ public:
   // Analysis Hooks
   //===--------------------------------------------------------------------===//
 
+  virtual bool isLegalToInline(Operation *call, Operation *callable) const;
   virtual bool isLegalToInline(Region *dest, Region *src,
                                BlockAndValueMapping &valueMapping) const;
   virtual bool isLegalToInline(Operation *op, Region *dest,
index ac6d615..8740876 100644 (file)
@@ -56,6 +56,11 @@ namespace {
 struct SPIRVInlinerInterface : public DialectInlinerInterface {
   using DialectInlinerInterface::DialectInlinerInterface;
 
+  /// All call operations within SPIRV can be inlined.
+  bool isLegalToInline(Operation *call, Operation *callable) const final {
+    return true;
+  }
+
   /// Returns true if the given region 'src' can be inlined into the region
   /// 'dest' that is attached to an operation registered to the current dialect.
   bool isLegalToInline(Region *dest, Region *src,
index 48c3155..9c8753d 100644 (file)
@@ -46,6 +46,11 @@ struct StdInlinerInterface : public DialectInlinerInterface {
   // Analysis Hooks
   //===--------------------------------------------------------------------===//
 
+  /// All call operations within standard ops can be inlined.
+  bool isLegalToInline(Operation *call, Operation *callable) const final {
+    return true;
+  }
+
   /// All operations within standard ops can be inlined.
   bool isLegalToInline(Operation *, Region *,
                        BlockAndValueMapping &) const final {
index 4b7ae80..4e0251b 100644 (file)
@@ -57,6 +57,12 @@ static void remapInlinedOperands(iterator_range<Region::iterator> inlinedBlocks,
 // InlinerInterface
 //===----------------------------------------------------------------------===//
 
+bool InlinerInterface::isLegalToInline(Operation *call,
+                                       Operation *callable) const {
+  auto *handler = getInterfaceFor(call);
+  return handler ? handler->isLegalToInline(call, callable) : false;
+}
+
 bool InlinerInterface::isLegalToInline(
     Region *dest, Region *src, BlockAndValueMapping &valueMapping) const {
   // Regions can always be inlined into functions.
@@ -352,6 +358,10 @@ LogicalResult mlir::inlineCall(InlinerInterface &interface,
     castResult.getDefiningOp()->replaceUsesOfWith(castResult, callResult);
   }
 
+  // Check that it is legal to inline the callable into the call.
+  if (!interface.isLegalToInline(call, callable))
+    return cleanupState();
+
   // Attempt to inline the call.
   if (failed(inlineRegion(interface, src, call, mapper, callResults,
                           callableResultTypes, call.getLoc(),
index 9c9ed70..54bf6c6 100644 (file)
@@ -183,3 +183,9 @@ func @inline_simplify() -> i32 {
   %res = call_indirect %fn() : () -> i32
   return %res : i32
 }
+
+// CHECK-LABEL: func @no_inline_invalid_call
+func @no_inline_invalid_call() -> i32 {
+  %res = "test.conversion_call_op"() { callee=@convert_callee_fn_multiblock, noinline } : () -> (i32)
+  return %res : i32
+}
index d2013d8..8171367 100644 (file)
@@ -77,6 +77,10 @@ struct TestInlinerInterface : public DialectInlinerInterface {
   // Analysis Hooks
   //===--------------------------------------------------------------------===//
 
+  bool isLegalToInline(Operation *call, Operation *callable) const final {
+    // Don't allow inlining calls that are marked `noinline`.
+    return !call->hasAttr("noinline");
+  }
   bool isLegalToInline(Region *, Region *, BlockAndValueMapping &) const final {
     // Inlining into test dialect regions is legal.
     return true;