From 501fda0167341f2db0da5198f70defb017a36178 Mon Sep 17 00:00:00 2001 From: River Riddle Date: Wed, 28 Oct 2020 21:48:38 -0700 Subject: [PATCH] [mlir][Inliner] Add a new hook for checking if it is legal to inline a callable into a call In certain situations it isn't legal to inline a call operation, but this isn't something that is possible(at least not easily) to prevent with the current hooks. This revision adds a new hook so that dialects with call operations that shouldn't be inlined can prevent it. Differential Revision: https://reviews.llvm.org/D90359 --- mlir/docs/Tutorials/Toy/Ch-4.md | 7 +++++++ mlir/examples/toy/Ch4/mlir/Dialect.cpp | 5 +++++ mlir/examples/toy/Ch5/mlir/Dialect.cpp | 5 +++++ mlir/examples/toy/Ch6/mlir/Dialect.cpp | 5 +++++ mlir/examples/toy/Ch7/mlir/Dialect.cpp | 5 +++++ mlir/include/mlir/IR/Operation.h | 11 ++++++++++- mlir/include/mlir/Transforms/InliningUtils.h | 9 +++++++++ mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp | 5 +++++ mlir/lib/Dialect/StandardOps/IR/Ops.cpp | 5 +++++ mlir/lib/Transforms/Utils/InliningUtils.cpp | 10 ++++++++++ mlir/test/Transforms/inlining.mlir | 6 ++++++ mlir/test/lib/Dialect/Test/TestDialect.cpp | 4 ++++ 12 files changed, 76 insertions(+), 1 deletion(-) diff --git a/mlir/docs/Tutorials/Toy/Ch-4.md b/mlir/docs/Tutorials/Toy/Ch-4.md index 11e6ddf..0580413 100644 --- a/mlir/docs/Tutorials/Toy/Ch-4.md +++ b/mlir/docs/Tutorials/Toy/Ch-4.md @@ -61,6 +61,13 @@ In this case, the interface is `DialectInlinerInterface`. struct ToyInlinerInterface : public DialectInlinerInterface { using DialectInlinerInterface::DialectInlinerInterface; + /// This hook checks to see if the given callable operation is legal to inline + /// into the given call. For Toy this hook can simply return true, as the Toy + /// Call operation is always inlinable. + bool isLegalToInline(Operation *call, Operation *callable) const final { + return true; + } + /// This hook checks to see if the given operation is legal to inline into the /// given region. For Toy this hook can simply return true, as all Toy /// operations are inlinable. diff --git a/mlir/examples/toy/Ch4/mlir/Dialect.cpp b/mlir/examples/toy/Ch4/mlir/Dialect.cpp index ca568a5..462de2b 100644 --- a/mlir/examples/toy/Ch4/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch4/mlir/Dialect.cpp @@ -34,6 +34,11 @@ struct ToyInlinerInterface : public DialectInlinerInterface { // Analysis Hooks //===--------------------------------------------------------------------===// + /// All call operations within toy can be inlined. + bool isLegalToInline(Operation *call, Operation *callable) const final { + return true; + } + /// All operations within toy can be inlined. bool isLegalToInline(Operation *, Region *, BlockAndValueMapping &) const final { diff --git a/mlir/examples/toy/Ch5/mlir/Dialect.cpp b/mlir/examples/toy/Ch5/mlir/Dialect.cpp index d1a518e..87bd185 100644 --- a/mlir/examples/toy/Ch5/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch5/mlir/Dialect.cpp @@ -34,6 +34,11 @@ struct ToyInlinerInterface : public DialectInlinerInterface { // Analysis Hooks //===--------------------------------------------------------------------===// + /// All call operations within toy can be inlined. + bool isLegalToInline(Operation *call, Operation *callable) const final { + return true; + } + /// All operations within toy can be inlined. bool isLegalToInline(Operation *, Region *, BlockAndValueMapping &) const final { diff --git a/mlir/examples/toy/Ch6/mlir/Dialect.cpp b/mlir/examples/toy/Ch6/mlir/Dialect.cpp index d1a518e..87bd185 100644 --- a/mlir/examples/toy/Ch6/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch6/mlir/Dialect.cpp @@ -34,6 +34,11 @@ struct ToyInlinerInterface : public DialectInlinerInterface { // Analysis Hooks //===--------------------------------------------------------------------===// + /// All call operations within toy can be inlined. + bool isLegalToInline(Operation *call, Operation *callable) const final { + return true; + } + /// All operations within toy can be inlined. bool isLegalToInline(Operation *, Region *, BlockAndValueMapping &) const final { diff --git a/mlir/examples/toy/Ch7/mlir/Dialect.cpp b/mlir/examples/toy/Ch7/mlir/Dialect.cpp index 046637f..14d764e 100644 --- a/mlir/examples/toy/Ch7/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch7/mlir/Dialect.cpp @@ -35,6 +35,11 @@ struct ToyInlinerInterface : public DialectInlinerInterface { // Analysis Hooks //===--------------------------------------------------------------------===// + /// All call operations within toy can be inlined. + bool isLegalToInline(Operation *call, Operation *callable) const final { + return true; + } + /// All operations within toy can be inlined. bool isLegalToInline(Operation *, Region *, BlockAndValueMapping &) const final { diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h index 6de7677..d3dce86 100644 --- a/mlir/include/mlir/IR/Operation.h +++ b/mlir/include/mlir/IR/Operation.h @@ -323,11 +323,20 @@ public: template AttrClass getAttrOfType(Identifier name) { return getAttr(name).dyn_cast_or_null(); } - template AttrClass getAttrOfType(StringRef name) { return getAttr(name).dyn_cast_or_null(); } + /// Return true if the operation has an attribute with the provided name, + /// false otherwise. + bool hasAttr(Identifier name) { return static_cast(getAttr(name)); } + bool hasAttr(StringRef name) { return static_cast(getAttr(name)); } + template + bool hasAttrOfType(NameT &&name) { + return static_cast( + getAttrOfType(std::forward(name))); + } + /// If the an attribute exists with the specified name, change it to the new /// value. Otherwise, add a new attribute with the specified name/value. void setAttr(Identifier name, Attribute value) { attrs.set(name, value); } diff --git a/mlir/include/mlir/Transforms/InliningUtils.h b/mlir/include/mlir/Transforms/InliningUtils.h index a526d0f..9c4fdf2 100644 --- a/mlir/include/mlir/Transforms/InliningUtils.h +++ b/mlir/include/mlir/Transforms/InliningUtils.h @@ -47,6 +47,14 @@ public: // Analysis Hooks //===--------------------------------------------------------------------===// + /// Returns true if the given operation 'callable', that implements the + /// 'CallableOpInterface', can be inlined into the position given call + /// operation 'call', that is registered to the current dialect and implements + /// the `CallOpInterface`. + virtual bool isLegalToInline(Operation *call, Operation *callable) const { + return false; + } + /// Returns true if the given region 'src' can be inlined into the region /// 'dest' that is attached to an operation registered to the current dialect. /// 'valueMapping' contains any remapped values from within the 'src' region. @@ -146,6 +154,7 @@ public: // Analysis Hooks //===--------------------------------------------------------------------===// + virtual bool isLegalToInline(Operation *call, Operation *callable) const; virtual bool isLegalToInline(Region *dest, Region *src, BlockAndValueMapping &valueMapping) const; virtual bool isLegalToInline(Operation *op, Region *dest, diff --git a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp index ac6d615..8740876 100644 --- a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp @@ -56,6 +56,11 @@ namespace { struct SPIRVInlinerInterface : public DialectInlinerInterface { using DialectInlinerInterface::DialectInlinerInterface; + /// All call operations within SPIRV can be inlined. + bool isLegalToInline(Operation *call, Operation *callable) const final { + return true; + } + /// Returns true if the given region 'src' can be inlined into the region /// 'dest' that is attached to an operation registered to the current dialect. bool isLegalToInline(Region *dest, Region *src, diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp index 48c3155..9c8753d 100644 --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -46,6 +46,11 @@ struct StdInlinerInterface : public DialectInlinerInterface { // Analysis Hooks //===--------------------------------------------------------------------===// + /// All call operations within standard ops can be inlined. + bool isLegalToInline(Operation *call, Operation *callable) const final { + return true; + } + /// All operations within standard ops can be inlined. bool isLegalToInline(Operation *, Region *, BlockAndValueMapping &) const final { diff --git a/mlir/lib/Transforms/Utils/InliningUtils.cpp b/mlir/lib/Transforms/Utils/InliningUtils.cpp index 4b7ae80..4e0251b 100644 --- a/mlir/lib/Transforms/Utils/InliningUtils.cpp +++ b/mlir/lib/Transforms/Utils/InliningUtils.cpp @@ -57,6 +57,12 @@ static void remapInlinedOperands(iterator_range inlinedBlocks, // InlinerInterface //===----------------------------------------------------------------------===// +bool InlinerInterface::isLegalToInline(Operation *call, + Operation *callable) const { + auto *handler = getInterfaceFor(call); + return handler ? handler->isLegalToInline(call, callable) : false; +} + bool InlinerInterface::isLegalToInline( Region *dest, Region *src, BlockAndValueMapping &valueMapping) const { // Regions can always be inlined into functions. @@ -352,6 +358,10 @@ LogicalResult mlir::inlineCall(InlinerInterface &interface, castResult.getDefiningOp()->replaceUsesOfWith(castResult, callResult); } + // Check that it is legal to inline the callable into the call. + if (!interface.isLegalToInline(call, callable)) + return cleanupState(); + // Attempt to inline the call. if (failed(inlineRegion(interface, src, call, mapper, callResults, callableResultTypes, call.getLoc(), diff --git a/mlir/test/Transforms/inlining.mlir b/mlir/test/Transforms/inlining.mlir index 9c9ed70..54bf6c6 100644 --- a/mlir/test/Transforms/inlining.mlir +++ b/mlir/test/Transforms/inlining.mlir @@ -183,3 +183,9 @@ func @inline_simplify() -> i32 { %res = call_indirect %fn() : () -> i32 return %res : i32 } + +// CHECK-LABEL: func @no_inline_invalid_call +func @no_inline_invalid_call() -> i32 { + %res = "test.conversion_call_op"() { callee=@convert_callee_fn_multiblock, noinline } : () -> (i32) + return %res : i32 +} diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp index d2013d8c..8171367 100644 --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -77,6 +77,10 @@ struct TestInlinerInterface : public DialectInlinerInterface { // Analysis Hooks //===--------------------------------------------------------------------===// + bool isLegalToInline(Operation *call, Operation *callable) const final { + // Don't allow inlining calls that are marked `noinline`. + return !call->hasAttr("noinline"); + } bool isLegalToInline(Region *, Region *, BlockAndValueMapping &) const final { // Inlining into test dialect regions is legal. return true; -- 2.7.4