From 183c4a391ef344220664d1d103d43639468bf103 Mon Sep 17 00:00:00 2001 From: "Tung D. Le" Date: Fri, 19 Aug 2022 08:27:20 +0530 Subject: [PATCH] [MLIR][normalize-memrefs] Non-normalizable operations with identity map layouts do not block normalization of the entire function The current approach is convervative in which whenever there is a non-normalizable operations in a function will the function be labelled as non-normalizable. It means it requires that all operations must have MemRefsNormalizable trait. This patch relaxes the requirement that if the memref map layouts of a non-normalizable operation are identity, this operation does not block the normalization of the other operations in the same function. Reviewed By: bondhugula Differential Revision: https://reviews.llvm.org/D125854 --- .../Dialect/MemRef/Transforms/NormalizeMemRefs.cpp | 23 +++++++++++++++------- mlir/test/Transforms/normalize-memrefs-ops.mlir | 18 +++++++++++++++++ 2 files changed, 34 insertions(+), 7 deletions(-) diff --git a/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp index b0b31c9..55ce128 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp @@ -145,10 +145,10 @@ void NormalizeMemRefs::setCalleesAndCallersNonNormalizable( /// Check whether all the uses of AllocOps, CallOps and function arguments of a /// function are either of dereferencing type or are uses in: DeallocOp, CallOp /// or ReturnOp. Only if these constraints are satisfied will the function -/// become a candidate for normalization. We follow a conservative approach here -/// wherein even if the non-normalizable memref is not a part of the function's -/// argument or return type, we still label the entire function as -/// non-normalizable. We assume external functions to be normalizable. +/// become a candidate for normalization. When the uses of a memref are +/// non-normalizable and the memref map layout is trivial (identity), we can +/// still label the entire function as normalizable. We assume external +/// functions to be normalizable. bool NormalizeMemRefs::areMemRefsNormalizable(func::FuncOp funcOp) { // We assume external functions to be normalizable. if (funcOp.isExternal()) @@ -157,7 +157,11 @@ bool NormalizeMemRefs::areMemRefsNormalizable(func::FuncOp funcOp) { if (funcOp .walk([&](memref::AllocOp allocOp) -> WalkResult { Value oldMemRef = allocOp.getResult(); - if (!isMemRefNormalizable(oldMemRef.getUsers())) + if (!oldMemRef.getType() + .cast() + .getLayout() + .isIdentity() && + !isMemRefNormalizable(oldMemRef.getUsers())) return WalkResult::interrupt(); return WalkResult::advance(); }) @@ -170,7 +174,11 @@ bool NormalizeMemRefs::areMemRefsNormalizable(func::FuncOp funcOp) { llvm::seq(0, callOp.getNumResults())) { Value oldMemRef = callOp.getResult(resIndex); if (oldMemRef.getType().isa()) - if (!isMemRefNormalizable(oldMemRef.getUsers())) + if (!oldMemRef.getType() + .cast() + .getLayout() + .isIdentity() && + !isMemRefNormalizable(oldMemRef.getUsers())) return WalkResult::interrupt(); } return WalkResult::advance(); @@ -181,7 +189,8 @@ bool NormalizeMemRefs::areMemRefsNormalizable(func::FuncOp funcOp) { for (unsigned argIndex : llvm::seq(0, funcOp.getNumArguments())) { BlockArgument oldMemRef = funcOp.getArgument(argIndex); if (oldMemRef.getType().isa()) - if (!isMemRefNormalizable(oldMemRef.getUsers())) + if (!oldMemRef.getType().cast().getLayout().isIdentity() && + !isMemRefNormalizable(oldMemRef.getUsers())) return false; } diff --git a/mlir/test/Transforms/normalize-memrefs-ops.mlir b/mlir/test/Transforms/normalize-memrefs-ops.mlir index a16ae14..b45b62a 100644 --- a/mlir/test/Transforms/normalize-memrefs-ops.mlir +++ b/mlir/test/Transforms/normalize-memrefs-ops.mlir @@ -41,6 +41,24 @@ func.func @test_nonnorm(%arg0 : memref<1x16x14x14xf32, #map0>) -> () { return } +// Test with op_nonnorm whose memref map layouts are identity. This op_nonnorm +// does not block the normalization of other operations. + +// CHECK-LABEL: test_nonnorm_identity_layout +// CHECK-SAME: (%[[ARG0:.*]]: memref<1x16x1x1x32x64xf32>) +func.func @test_nonnorm_identity_layout(%arg0 : memref<1x16x14x14xf32, #map0>) -> () { + %0 = memref.alloc() : memref<1x16x14x14xf32> + "test.op_nonnorm"(%0, %0) : (memref<1x16x14x14xf32>, memref<1x16x14x14xf32>) -> () + "test.op_norm"(%arg0, %0) : (memref<1x16x14x14xf32, #map0>, memref<1x16x14x14xf32>) -> () + memref.dealloc %0 : memref<1x16x14x14xf32> + + // CHECK: %[[v0:.*]] = memref.alloc() : memref<1x16x14x14xf32> + // CHECK: "test.op_nonnorm"(%[[v0]], %[[v0]]) : (memref<1x16x14x14xf32>, memref<1x16x14x14xf32>) -> () + // CHECK: "test.op_norm"(%[[ARG0]], %[[v0]]) : (memref<1x16x1x1x32x64xf32>, memref<1x16x14x14xf32>) -> () + // CHECK: memref.dealloc %[[v0]] : memref<1x16x14x14xf32> + return +} + // Test with op_norm, with maps in the operations in the function. // CHECK-LABEL: test_norm_mix -- 2.7.4