From 5830f71a45df33e24c864bea4c5de070be47b488 Mon Sep 17 00:00:00 2001 From: River Riddle Date: Thu, 3 Oct 2019 23:10:25 -0700 Subject: [PATCH] Add support for inlining calls with different arg/result types from the callable. Some dialects have implicit conversions inherent in their modeling, meaning that a call may have a different type that the type that the callable expects. To support this, a hook is added to the dialect interface that allows for materializing conversion operations during inlining when there is a mismatch. A hook is also added to the callable interface to allow for introspecting the expected result types. PiperOrigin-RevId: 272814379 --- mlir/include/mlir/Analysis/CallInterfaces.td | 8 +- mlir/include/mlir/IR/Function.h | 7 ++ mlir/include/mlir/Transforms/InliningUtils.h | 52 +++++++---- mlir/lib/Transforms/Inliner.cpp | 8 +- mlir/lib/Transforms/Utils/InliningUtils.cpp | 134 ++++++++++++++++++++------- mlir/test/Transforms/inlining.mlir | 36 +++++++ mlir/test/lib/TestDialect/TestDialect.cpp | 17 +++- mlir/test/lib/TestDialect/TestOps.td | 23 +++++ 8 files changed, 228 insertions(+), 57 deletions(-) diff --git a/mlir/include/mlir/Analysis/CallInterfaces.td b/mlir/include/mlir/Analysis/CallInterfaces.td index fca7773..3ed802f 100644 --- a/mlir/include/mlir/Analysis/CallInterfaces.td +++ b/mlir/include/mlir/Analysis/CallInterfaces.td @@ -80,11 +80,17 @@ def CallableOpInterface : OpInterface<"CallableOpInterface"> { "Region *", "getCallableRegion", (ins "CallInterfaceCallable":$callable) >, InterfaceMethod<[{ - Returns all of the callable regions of this operation + Returns all of the callable regions of this operation. }], "void", "getCallableRegions", (ins "SmallVectorImpl &":$callables) >, + InterfaceMethod<[{ + Returns the results types that the given callable region produces when + executed. + }], + "ArrayRef", "getCallableResults", (ins "Region *":$callable) + >, ]; } diff --git a/mlir/include/mlir/IR/Function.h b/mlir/include/mlir/IR/Function.h index 6bf6e65..95920b3 100644 --- a/mlir/include/mlir/IR/Function.h +++ b/mlir/include/mlir/IR/Function.h @@ -128,6 +128,13 @@ public: callables.push_back(&getBody()); } + /// Returns the results types that the given callable region produces when + /// executed. + ArrayRef getCallableResults(Region *region) { + assert(!isExternal() && region == &getBody() && "invalid callable"); + return getType().getResults(); + } + private: // This trait needs access to `getNumFuncArguments` and `verifyType` hooks // defined below. diff --git a/mlir/include/mlir/Transforms/InliningUtils.h b/mlir/include/mlir/Transforms/InliningUtils.h index 7fe67e7..fd12624 100644 --- a/mlir/include/mlir/Transforms/InliningUtils.h +++ b/mlir/include/mlir/Transforms/InliningUtils.h @@ -30,7 +30,10 @@ namespace mlir { class Block; class BlockAndValueMapping; +class CallableOpInterface; +class CallOpInterface; class FuncOp; +class OpBuilder; class Operation; class Region; class Value; @@ -106,6 +109,27 @@ public: llvm_unreachable( "must implement handleTerminator in the case of one inlined block"); } + + /// Attempt to materialize a conversion for a type mismatch between a call + /// from this dialect, and a callable region. This method should generate an + /// operation that takes 'input' as the only operand, and produces a single + /// result of 'resultType'. If a conversion can not be generated, nullptr + /// should be returned. For example, this hook may be invoked in the following + /// scenarios: + /// func @foo(i32) -> i32 { ... } + /// + /// // Mismatched input operand + /// ... = foo.call @foo(%input : i16) -> i32 + /// + /// // Mismatched result type. + /// ... = foo.call @foo(%input : i32) -> i16 + /// + /// NOTE: This hook may be invoked before the 'isLegal' checks above. + virtual Operation *materializeCallConversion(OpBuilder &builder, Value *input, + Type resultType, + Location conversionLoc) const { + return nullptr; + } }; /// This interface provides the hooks into the inlining interface. @@ -115,7 +139,6 @@ class InlinerInterface : public DialectInterfaceCollection { public: using Base::Base; - virtual ~InlinerInterface(); /// Process a set of blocks that have been inlined. This callback is invoked /// *before* inlined terminator operations have been processed. @@ -178,24 +201,15 @@ LogicalResult inlineRegion(InlinerInterface &interface, Region *src, llvm::Optional inlineLoc = llvm::None, bool shouldCloneInlinedRegion = true); -/// This function inlines a FuncOp into another. This function returns failure -/// if it is not possible to inline this FuncOp. If the function returned -/// failure, then no changes to the module have been made. -/// -/// Note that this only does one level of inlining. For example, if the -/// instruction 'call B' is inlined into function 'A', and function 'B' also -/// calls 'C', then the call to 'C' now exists inside the body of 'A'. Similarly -/// this will inline a recursive FuncOp by one level. -/// -/// 'callOperands' must correspond, 1-1, with the arguments to the provided -/// FuncOp. 'callResults' must correspond, 1-1, with the results of the -/// provided FuncOp. These results will be replaced by the operands of any -/// return operations that are inlined. 'inlineLoc' should refer to the location -/// that the FuncOp is being inlined into. -LogicalResult inlineFunction(InlinerInterface &interface, FuncOp callee, - Operation *inlinePoint, - ArrayRef callOperands, - ArrayRef callResults, Location inlineLoc); +/// This function inlines a given region, 'src', of a callable operation, +/// 'callable', into the location defined by the given call operation. This +/// function returns failure if inlining is not possible, success otherwise. On +/// failure, no changes are made to the module. 'shouldCloneInlinedRegion' +/// corresponds to whether the source region should be cloned into the 'call' or +/// spliced directly. +LogicalResult inlineCall(InlinerInterface &interface, CallOpInterface call, + CallableOpInterface callable, Region *src, + bool shouldCloneInlinedRegion = true); } // end namespace mlir diff --git a/mlir/lib/Transforms/Inliner.cpp b/mlir/lib/Transforms/Inliner.cpp index afb2dcc..c5defa5 100644 --- a/mlir/lib/Transforms/Inliner.cpp +++ b/mlir/lib/Transforms/Inliner.cpp @@ -157,10 +157,10 @@ static void inlineCallsInSCC(Inliner &inliner, continue; CallOpInterface call = it.call; - LogicalResult inlineResult = inlineRegion( - inliner, it.targetNode->getCallableRegion(), call, - llvm::to_vector<8>(call.getArgOperands()), - llvm::to_vector<8>(call.getOperation()->getResults()), call.getLoc()); + Region *targetRegion = it.targetNode->getCallableRegion(); + LogicalResult inlineResult = inlineCall( + inliner, call, cast(targetRegion->getParentOp()), + targetRegion); if (failed(inlineResult)) continue; diff --git a/mlir/lib/Transforms/Utils/InliningUtils.cpp b/mlir/lib/Transforms/Utils/InliningUtils.cpp index 6ca875b2..fd08c53 100644 --- a/mlir/lib/Transforms/Utils/InliningUtils.cpp +++ b/mlir/lib/Transforms/Utils/InliningUtils.cpp @@ -22,6 +22,7 @@ #include "mlir/Transforms/InliningUtils.h" #include "mlir/IR/BlockAndValueMapping.h" +#include "mlir/IR/Builders.h" #include "mlir/IR/Function.h" #include "mlir/IR/Operation.h" #include "llvm/ADT/MapVector.h" @@ -65,8 +66,6 @@ remapInlinedOperands(llvm::iterator_range inlinedBlocks, // InlinerInterface //===----------------------------------------------------------------------===// -InlinerInterface::~InlinerInterface() {} - bool InlinerInterface::isLegalToInline( Region *dest, Region *src, BlockAndValueMapping &valueMapping) const { // Regions can always be inlined into functions. @@ -74,7 +73,7 @@ bool InlinerInterface::isLegalToInline( return true; auto *handler = getInterfaceFor(dest->getParentOp()); - return handler ? handler->isLegalToInline(src, dest, valueMapping) : false; + return handler ? handler->isLegalToInline(dest, src, valueMapping) : false; } bool InlinerInterface::isLegalToInline( @@ -253,38 +252,109 @@ LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src, inlineLoc, shouldCloneInlinedRegion); } -/// This function inlines a FuncOp into another. This function returns failure -/// if it is not possible to inline this FuncOp. If the function returned -/// failure, then no changes to the module have been made. -/// -/// Note that this only does one level of inlining. For example, if the -/// instruction 'call B' is inlined, and 'B' calls 'C', then the call to 'C' now -/// exists in the instruction stream. Similarly this will inline a recursive -/// FuncOp by one level. -/// -LogicalResult mlir::inlineFunction(InlinerInterface &interface, FuncOp callee, - Operation *inlinePoint, - ArrayRef callOperands, - ArrayRef callResults, - Location inlineLoc) { - // We don't inline if the provided callee function is a declaration. - assert(callee && "expected valid function to inline"); - if (callee.isExternal()) - return failure(); +/// Utility function used to generate a cast operation from the given interface, +/// or return nullptr if a cast could not be generated. +static Value *materializeConversion(const DialectInlinerInterface *interface, + SmallVectorImpl &castOps, + OpBuilder &castBuilder, Value *arg, + Type type, Location conversionLoc) { + if (!interface) + return nullptr; + + // Check to see if the interface for the call can materialize a conversion. + Operation *castOp = interface->materializeCallConversion(castBuilder, arg, + type, conversionLoc); + if (!castOp) + return nullptr; + castOps.push_back(castOp); + + // Ensure that the generated cast is correct. + assert(castOp->getNumOperands() == 1 && castOp->getOperand(0) == arg && + castOp->getNumResults() == 1 && *castOp->result_type_begin() == type); + return castOp->getResult(0); +} - // Verify that the provided arguments match the function arguments. - if (callOperands.size() != callee.getNumArguments()) +/// This function inlines a given region, 'src', of a callable operation, +/// 'callable', into the location defined by the given call operation. This +/// function returns failure if inlining is not possible, success otherwise. On +/// failure, no changes are made to the module. 'shouldCloneInlinedRegion' +/// corresponds to whether the source region should be cloned into the 'call' or +/// spliced directly. +LogicalResult mlir::inlineCall(InlinerInterface &interface, + CallOpInterface call, + CallableOpInterface callable, Region *src, + bool shouldCloneInlinedRegion) { + // We expect the region to have at least one block. + if (src->empty()) return failure(); + auto *entryBlock = &src->front(); + ArrayRef callableResultTypes = callable.getCallableResults(src); + + // Make sure that the number of arguments and results matchup between the call + // and the region. + SmallVector callOperands(call.getArgOperands()); + SmallVector callResults(call.getOperation()->getResults()); + if (callOperands.size() != entryBlock->getNumArguments() || + callResults.size() != callableResultTypes.size()) + return failure(); + + // A set of cast operations generated to matchup the signature of the region + // with the signature of the call. + SmallVector castOps; + castOps.reserve(callOperands.size() + callResults.size()); - // Verify that the provided values to replace match the function results. - auto funcResultTypes = callee.getType().getResults(); - if (callResults.size() != funcResultTypes.size()) + // Functor used to cleanup generated state on failure. + auto cleanupState = [&] { + for (auto *op : castOps) { + op->getResult(0)->replaceAllUsesWith(op->getOperand(0)); + op->erase(); + } return failure(); - for (unsigned i = 0, e = callResults.size(); i != e; ++i) - if (callResults[i]->getType() != funcResultTypes[i]) - return failure(); + }; - // Call into the main region inliner function. - return inlineRegion(interface, &callee.getBody(), inlinePoint, callOperands, - callResults, inlineLoc); + // Builder used for any conversion operations that need to be materialized. + OpBuilder castBuilder(call); + Location castLoc = call.getLoc(); + auto *callInterface = interface.getInterfaceFor(call.getDialect()); + + // Map the provided call operands to the arguments of the region. + BlockAndValueMapping mapper; + for (unsigned i = 0, e = callOperands.size(); i != e; ++i) { + BlockArgument *regionArg = entryBlock->getArgument(i); + Value *operand = callOperands[i]; + + // If the call operand doesn't match the expected region argument, try to + // generate a cast. + Type regionArgType = regionArg->getType(); + if (operand->getType() != regionArgType) { + if (!(operand = materializeConversion(callInterface, castOps, castBuilder, + operand, regionArgType, castLoc))) + return cleanupState(); + } + mapper.map(regionArg, operand); + } + + // Ensure that the resultant values of the call, match the callable. + castBuilder.setInsertionPointAfter(call); + for (unsigned i = 0, e = callResults.size(); i != e; ++i) { + Value *callResult = callResults[i]; + if (callResult->getType() == callableResultTypes[i]) + continue; + + // Generate a conversion that will produce the original type, so that the IR + // is still valid after the original call gets replaced. + Value *castResult = + materializeConversion(callInterface, castOps, castBuilder, callResult, + callResult->getType(), castLoc); + if (!castResult) + return cleanupState(); + callResult->replaceAllUsesWith(castResult); + castResult->getDefiningOp()->replaceUsesOfWith(castResult, callResult); + } + + // Attempt to inline the call. + if (failed(inlineRegion(interface, src, call, mapper, callResults, + call.getLoc(), shouldCloneInlinedRegion))) + return cleanupState(); + return success(); } diff --git a/mlir/test/Transforms/inlining.mlir b/mlir/test/Transforms/inlining.mlir index 9732992..4d855d0 100644 --- a/mlir/test/Transforms/inlining.mlir +++ b/mlir/test/Transforms/inlining.mlir @@ -105,3 +105,39 @@ func @no_inline_recursive() { }) : () -> (() -> ()) return } + +// Check that we can convert types for inputs and results as necessary. +func @convert_callee_fn(%arg : i32) -> i32 { + return %arg : i32 +} +func @convert_callee_fn_multi_arg(%a : i32, %b : i32) -> () { + return +} +func @convert_callee_fn_multi_res() -> (i32, i32) { + %res = constant 0 : i32 + return %res, %res : i32, i32 +} + +// CHECK-LABEL: func @inline_convert_call +func @inline_convert_call() -> i16 { + // CHECK: %[[INPUT:.*]] = constant + %test_input = constant 0 : i16 + + // CHECK: %[[CAST_INPUT:.*]] = "test.cast"(%[[INPUT]]) : (i16) -> i32 + // CHECK: %[[CAST_RESULT:.*]] = "test.cast"(%[[CAST_INPUT]]) : (i32) -> i16 + // CHECK-NEXT: return %[[CAST_RESULT]] + %res = "test.conversion_call_op"(%test_input) { callee=@convert_callee_fn } : (i16) -> (i16) + return %res : i16 +} + +// CHECK-LABEL: func @no_inline_convert_call +func @no_inline_convert_call() { + // CHECK: "test.conversion_call_op" + %test_input_i16 = constant 0 : i16 + %test_input_i64 = constant 0 : i64 + "test.conversion_call_op"(%test_input_i16, %test_input_i64) { callee=@convert_callee_fn_multi_arg } : (i16, i64) -> () + + // CHECK: "test.conversion_call_op" + %res_2:2 = "test.conversion_call_op"() { callee=@convert_callee_fn_multi_res } : () -> (i16, i64) + return +} diff --git a/mlir/test/lib/TestDialect/TestDialect.cpp b/mlir/test/lib/TestDialect/TestDialect.cpp index ca523d8..78a75f0 100644 --- a/mlir/test/lib/TestDialect/TestDialect.cpp +++ b/mlir/test/lib/TestDialect/TestDialect.cpp @@ -58,7 +58,7 @@ struct TestInlinerInterface : public DialectInlinerInterface { return true; } - bool shouldAnalyzeRecursively(Operation *op) const override { + bool shouldAnalyzeRecursively(Operation *op) const final { // Analyze recursively if this is not a functional region operation, it // froms a separate functional scope. return !isa(op); @@ -82,6 +82,21 @@ struct TestInlinerInterface : public DialectInlinerInterface { for (const auto &it : llvm::enumerate(returnOp.getOperands())) valuesToRepl[it.index()]->replaceAllUsesWith(it.value()); } + + /// Attempt to materialize a conversion for a type mismatch between a call + /// from this dialect, and a callable region. This method should generate an + /// operation that takes 'input' as the only operand, and produces a single + /// result of 'resultType'. If a conversion can not be generated, nullptr + /// should be returned. + Operation *materializeCallConversion(OpBuilder &builder, Value *input, + Type resultType, + Location conversionLoc) const final { + // Only allow conversion for i16/i32 types. + if (!(resultType.isInteger(16) || resultType.isInteger(32)) || + !(input->getType().isInteger(16) || input->getType().isInteger(32))) + return nullptr; + return builder.create(conversionLoc, resultType, input); + } }; } // end anonymous namespace diff --git a/mlir/test/lib/TestDialect/TestOps.td b/mlir/test/lib/TestDialect/TestOps.td index 944ce79..41e44f6 100644 --- a/mlir/test/lib/TestDialect/TestOps.td +++ b/mlir/test/lib/TestDialect/TestOps.td @@ -194,6 +194,26 @@ def SizedRegionOp : TEST_Op<"sized_region_op", []> { let regions = (region SizedRegion<2>:$my_region, SizedRegion<1>); } +//===----------------------------------------------------------------------===// +// Test Call Interfaces +//===----------------------------------------------------------------------===// + +def ConversionCallOp : TEST_Op<"conversion_call_op", + [CallOpInterface]> { + let arguments = (ins Variadic:$inputs, SymbolRefAttr:$callee); + let results = (outs Variadic); + + let extraClassDeclaration = [{ + /// Get the argument operands to the called function. + operand_range getArgOperands() { return inputs(); } + + /// Return the callee of this operation. + CallInterfaceCallable getCallableForCallee() { + return getAttrOfType("callee"); + } + }]; +} + def FunctionalRegionOp : TEST_Op<"functional_region_op", [CallableOpInterface]> { let regions = (region AnyRegion:$body); @@ -204,6 +224,9 @@ def FunctionalRegionOp : TEST_Op<"functional_region_op", void getCallableRegions(SmallVectorImpl &callables) { callables.push_back(&body()); } + ArrayRef getCallableResults(Region *) { + return getType().cast().getResults(); + } }]; } -- 2.7.4