[mlir][normalize-memrefs] NFC Follow-up D125854
authorTung D. Le <tung@jp.ibm.com>
Sat, 20 Aug 2022 02:00:28 +0000 (07:30 +0530)
committerUday Bondhugula <uday@polymagelabs.com>
Sun, 21 Aug 2022 10:34:41 +0000 (16:04 +0530)
NFC follow-up D125854 to reflect some remaining comments in D125854

Reviewed By: bondhugula

Differential Revision: https://reviews.llvm.org/D132200

mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp

index 55ce128..345c71d 100644 (file)
@@ -157,10 +157,7 @@ bool NormalizeMemRefs::areMemRefsNormalizable(func::FuncOp funcOp) {
   if (funcOp
           .walk([&](memref::AllocOp allocOp) -> WalkResult {
             Value oldMemRef = allocOp.getResult();
-            if (!oldMemRef.getType()
-                     .cast<MemRefType>()
-                     .getLayout()
-                     .isIdentity() &&
+            if (!allocOp.getType().getLayout().isIdentity() &&
                 !isMemRefNormalizable(oldMemRef.getUsers()))
               return WalkResult::interrupt();
             return WalkResult::advance();
@@ -173,11 +170,9 @@ bool NormalizeMemRefs::areMemRefsNormalizable(func::FuncOp funcOp) {
             for (unsigned resIndex :
                  llvm::seq<unsigned>(0, callOp.getNumResults())) {
               Value oldMemRef = callOp.getResult(resIndex);
-              if (oldMemRef.getType().isa<MemRefType>())
-                if (!oldMemRef.getType()
-                         .cast<MemRefType>()
-                         .getLayout()
-                         .isIdentity() &&
+              if (auto oldMemRefType =
+                      oldMemRef.getType().dyn_cast<MemRefType>())
+                if (!oldMemRefType.getLayout().isIdentity() &&
                     !isMemRefNormalizable(oldMemRef.getUsers()))
                   return WalkResult::interrupt();
             }
@@ -188,8 +183,8 @@ bool NormalizeMemRefs::areMemRefsNormalizable(func::FuncOp funcOp) {
 
   for (unsigned argIndex : llvm::seq<unsigned>(0, funcOp.getNumArguments())) {
     BlockArgument oldMemRef = funcOp.getArgument(argIndex);
-    if (oldMemRef.getType().isa<MemRefType>())
-      if (!oldMemRef.getType().cast<MemRefType>().getLayout().isIdentity() &&
+    if (auto oldMemRefType = oldMemRef.getType().dyn_cast<MemRefType>())
+      if (!oldMemRefType.getLayout().isIdentity() &&
           !isMemRefNormalizable(oldMemRef.getUsers()))
         return false;
   }