/// Check whether all the uses of AllocOps, CallOps and function arguments of a
/// function are either of dereferencing type or are uses in: DeallocOp, CallOp
/// or ReturnOp. Only if these constraints are satisfied will the function
-/// become a candidate for normalization. We follow a conservative approach here
-/// wherein even if the non-normalizable memref is not a part of the function's
-/// argument or return type, we still label the entire function as
-/// non-normalizable. We assume external functions to be normalizable.
+/// become a candidate for normalization. When the uses of a memref are
+/// non-normalizable and the memref map layout is trivial (identity), we can
+/// still label the entire function as normalizable. We assume external
+/// functions to be normalizable.
bool NormalizeMemRefs::areMemRefsNormalizable(func::FuncOp funcOp) {
// We assume external functions to be normalizable.
if (funcOp.isExternal())
if (funcOp
.walk([&](memref::AllocOp allocOp) -> WalkResult {
Value oldMemRef = allocOp.getResult();
- if (!isMemRefNormalizable(oldMemRef.getUsers()))
+ if (!oldMemRef.getType()
+ .cast<MemRefType>()
+ .getLayout()
+ .isIdentity() &&
+ !isMemRefNormalizable(oldMemRef.getUsers()))
return WalkResult::interrupt();
return WalkResult::advance();
})
llvm::seq<unsigned>(0, callOp.getNumResults())) {
Value oldMemRef = callOp.getResult(resIndex);
if (oldMemRef.getType().isa<MemRefType>())
- if (!isMemRefNormalizable(oldMemRef.getUsers()))
+ if (!oldMemRef.getType()
+ .cast<MemRefType>()
+ .getLayout()
+ .isIdentity() &&
+ !isMemRefNormalizable(oldMemRef.getUsers()))
return WalkResult::interrupt();
}
return WalkResult::advance();
for (unsigned argIndex : llvm::seq<unsigned>(0, funcOp.getNumArguments())) {
BlockArgument oldMemRef = funcOp.getArgument(argIndex);
if (oldMemRef.getType().isa<MemRefType>())
- if (!isMemRefNormalizable(oldMemRef.getUsers()))
+ if (!oldMemRef.getType().cast<MemRefType>().getLayout().isIdentity() &&
+ !isMemRefNormalizable(oldMemRef.getUsers()))
return false;
}
return
}
+// Test with op_nonnorm whose memref map layouts are identity. This op_nonnorm
+// does not block the normalization of other operations.
+
+// CHECK-LABEL: test_nonnorm_identity_layout
+// CHECK-SAME: (%[[ARG0:.*]]: memref<1x16x1x1x32x64xf32>)
+func.func @test_nonnorm_identity_layout(%arg0 : memref<1x16x14x14xf32, #map0>) -> () {
+ %0 = memref.alloc() : memref<1x16x14x14xf32>
+ "test.op_nonnorm"(%0, %0) : (memref<1x16x14x14xf32>, memref<1x16x14x14xf32>) -> ()
+ "test.op_norm"(%arg0, %0) : (memref<1x16x14x14xf32, #map0>, memref<1x16x14x14xf32>) -> ()
+ memref.dealloc %0 : memref<1x16x14x14xf32>
+
+ // CHECK: %[[v0:.*]] = memref.alloc() : memref<1x16x14x14xf32>
+ // CHECK: "test.op_nonnorm"(%[[v0]], %[[v0]]) : (memref<1x16x14x14xf32>, memref<1x16x14x14xf32>) -> ()
+ // CHECK: "test.op_norm"(%[[ARG0]], %[[v0]]) : (memref<1x16x1x1x32x64xf32>, memref<1x16x14x14xf32>) -> ()
+ // CHECK: memref.dealloc %[[v0]] : memref<1x16x14x14xf32>
+ return
+}
+
// Test with op_norm, with maps in the operations in the function.
// CHECK-LABEL: test_norm_mix