From: Haruki Imai Date: Fri, 25 Sep 2020 17:19:23 +0000 (+0530) Subject: [MLIR] Fix for updating function signature in normalizing memrefs X-Git-Tag: llvmorg-13-init~10897 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=c1f856803142a113fa094411fa4760512b919ef6;p=platform%2Fupstream%2Fllvm.git [MLIR] Fix for updating function signature in normalizing memrefs Normalizing memrefs failed when a caller of symbolic use in a function can not be casted to `CallOp`. This patch avoids the failure by checking the result of the casting. If the caller can not be casted to `CallOp`, it is skipped. Differential Revision: https://reviews.llvm.org/D87746 --- diff --git a/mlir/lib/Transforms/NormalizeMemRefs.cpp b/mlir/lib/Transforms/NormalizeMemRefs.cpp index c4f91eb..ac02f0e 100644 --- a/mlir/lib/Transforms/NormalizeMemRefs.cpp +++ b/mlir/lib/Transforms/NormalizeMemRefs.cpp @@ -263,16 +263,23 @@ void NormalizeMemRefs::updateFunctionSignature(FuncOp funcOp, // type at the caller site. Optional symbolUses = funcOp.getSymbolUses(moduleOp); for (SymbolTable::SymbolUse symbolUse : *symbolUses) { - Operation *callOp = symbolUse.getUser(); - OpBuilder builder(callOp); - StringRef callee = cast(callOp).getCallee(); + Operation *userOp = symbolUse.getUser(); + OpBuilder builder(userOp); + // When `userOp` can not be casted to `CallOp`, it is skipped. This assumes + // that the non-CallOp has no memrefs to be replaced. + // TODO: Handle cases where a non-CallOp symbol use of a function deals with + // memrefs. + auto callOp = dyn_cast(userOp); + if (!callOp) + continue; + StringRef callee = callOp.getCallee(); Operation *newCallOp = builder.create( - callOp->getLoc(), resultTypes, builder.getSymbolRefAttr(callee), - callOp->getOperands()); + userOp->getLoc(), resultTypes, builder.getSymbolRefAttr(callee), + userOp->getOperands()); bool replacingMemRefUsesFailed = false; bool returnTypeChanged = false; - for (unsigned resIndex : llvm::seq(0, callOp->getNumResults())) { - OpResult oldResult = callOp->getResult(resIndex); + for (unsigned resIndex : llvm::seq(0, userOp->getNumResults())) { + OpResult oldResult = userOp->getResult(resIndex); OpResult newResult = newCallOp->getResult(resIndex); // This condition ensures that if the result is not of type memref or if // the resulting memref was already having a trivial map layout then we @@ -302,8 +309,8 @@ void NormalizeMemRefs::updateFunctionSignature(FuncOp funcOp, if (replacingMemRefUsesFailed) continue; // Replace all uses for other non-memref result types. - callOp->replaceAllUsesWith(newCallOp); - callOp->erase(); + userOp->replaceAllUsesWith(newCallOp); + userOp->erase(); if (returnTypeChanged) { // Since the return type changed it might lead to a change in function's // signature. diff --git a/mlir/test/Transforms/normalize-memrefs-ops.mlir b/mlir/test/Transforms/normalize-memrefs-ops.mlir index 0c67157..b112752 100644 --- a/mlir/test/Transforms/normalize-memrefs-ops.mlir +++ b/mlir/test/Transforms/normalize-memrefs-ops.mlir @@ -89,3 +89,7 @@ func @test_load_store(%arg0 : memref<1x16x14x14xf32>) -> () { // CHECK: dealloc %[[v1]] : memref<1x16x14x14xf32> return } + +// Test with an arbitrary op that references the function symbol. + +"test.op_funcref"() {func = @test_norm_mix} : () -> () diff --git a/mlir/test/lib/Dialect/Test/TestDialect.h b/mlir/test/lib/Dialect/Test/TestDialect.h index 09f84d1..e44b561 100644 --- a/mlir/test/lib/Dialect/Test/TestDialect.h +++ b/mlir/test/lib/Dialect/Test/TestDialect.h @@ -16,6 +16,7 @@ #include "mlir/Dialect/Traits.h" #include "mlir/IR/Dialect.h" +#include "mlir/IR/Function.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/RegionKindInterface.h" @@ -29,7 +30,6 @@ #include "TestOpEnums.h.inc" - #include "TestOpStructs.h.inc" #include "TestOpsDialect.h.inc" diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index 6d4e58c..3743e39 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -629,6 +629,17 @@ def OpNonNorm : TEST_Op<"op_nonnorm"> { let arguments = (ins AnyMemRef:$X, AnyMemRef:$Y); } +// Test for memrefs normalization of an op with a reference to a function +// symbol. +def OpFuncRef : TEST_Op<"op_funcref"> { + let summary = "Test op with a reference to a function symbol"; + let description = [{ + The "test.op_funcref" is a test op with a reference to a function symbol. + }]; + let builders = [OpBuilder<[{OpBuilder &builder, OperationState &state, + FuncOp function}]>]; +} + // Pattern add the argument plus a increasing static number hidden in // OpMTest function. That value is set into the optional argument. // That way, we will know if operations is called once or twice.