[mlir][memref] Make result normalization aware of the number symbols
authorKai Sasaki <lewuathe@gmail.com>
Thu, 29 Jun 2023 01:04:35 +0000 (10:04 +0900)
committerKai Sasaki <lewuathe@gmail.com>
Thu, 29 Jun 2023 01:04:53 +0000 (10:04 +0900)
Memref normalization fails to recognize the non-zero symbols used in the memref type itself with strided, offset information. It causes the crash with the type like `memref<128x512xf32, strided<[?, ?], offset: ?>>`. The original issue is here. https://github.com/llvm/llvm-project/issues/61345

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D150250

mlir/include/mlir/Dialect/Affine/Utils.h
mlir/lib/Dialect/Affine/Utils/Utils.cpp
mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp
mlir/test/Transforms/normalize-memrefs.mlir

index ec86f16..b3ccbff 100644 (file)
@@ -249,8 +249,7 @@ LogicalResult normalizeMemRef(memref::AllocOp *op);
 /// transformed to an identity map with a new shape being computed for the
 /// normalized memref type and returns it. The old memref type is simplify
 /// returned if the normalization failed.
-MemRefType normalizeMemRefType(MemRefType memrefType,
-                               unsigned numSymbolicOperands);
+MemRefType normalizeMemRefType(MemRefType memrefType);
 
 /// Given an operation, inserts one or more single result affine apply
 /// operations, results of which are exclusively used by this operation.
index d567093..1ba9c82 100644 (file)
@@ -1720,8 +1720,7 @@ LogicalResult mlir::affine::normalizeMemRef(memref::AllocOp *allocOp) {
 
   // Fetch a new memref type after normalizing the old memref to have an
   // identity map layout.
-  MemRefType newMemRefType =
-      normalizeMemRefType(memrefType, allocOp->getSymbolOperands().size());
+  MemRefType newMemRefType = normalizeMemRefType(memrefType);
   if (newMemRefType == memrefType)
     // Either memrefType already had an identity map or the map couldn't be
     // transformed to an identity map.
@@ -1772,8 +1771,7 @@ LogicalResult mlir::affine::normalizeMemRef(memref::AllocOp *allocOp) {
   return success();
 }
 
-MemRefType mlir::affine::normalizeMemRefType(MemRefType memrefType,
-                                             unsigned numSymbolicOperands) {
+MemRefType mlir::affine::normalizeMemRefType(MemRefType memrefType) {
   unsigned rank = memrefType.getRank();
   if (rank == 0)
     return memrefType;
@@ -1784,6 +1782,7 @@ MemRefType mlir::affine::normalizeMemRefType(MemRefType memrefType,
     return memrefType;
   }
   AffineMap layoutMap = memrefType.getLayout().getAffineMap();
+  unsigned numSymbolicOperands = layoutMap.getNumSymbols();
 
   // We don't do any checks for one-to-one'ness; we assume that it is
   // one-to-one.
index aa21497..33772cc 100644 (file)
@@ -367,8 +367,7 @@ void NormalizeMemRefs::normalizeFuncOpMemRefs(func::FuncOp funcOp,
     }
     // Fetch a new memref type after normalizing the old memref to have an
     // identity map layout.
-    MemRefType newMemRefType = normalizeMemRefType(memrefType,
-                                                   /*numSymbolicOperands=*/0);
+    MemRefType newMemRefType = normalizeMemRefType(memrefType);
     if (newMemRefType == memrefType || funcOp.isExternal()) {
       // Either memrefType already had an identity map or the map couldn't be
       // transformed to an identity map.
@@ -475,8 +474,7 @@ void NormalizeMemRefs::normalizeFuncOpMemRefs(func::FuncOp funcOp,
       }
       // Computing a new memref type after normalizing the old memref to have an
       // identity map layout.
-      MemRefType newMemRefType = normalizeMemRefType(memrefType,
-                                                     /*numSymbolicOperands=*/0);
+      MemRefType newMemRefType = normalizeMemRefType(memrefType);
       resultTypes.push_back(newMemRefType);
     }
 
@@ -513,9 +511,9 @@ Operation *NormalizeMemRefs::createOpResultsNormalized(func::FuncOp funcOp,
       resultTypes.push_back(resultType);
       continue;
     }
+
     // Fetch a new memref type after normalizing the old memref.
-    MemRefType newMemRefType = normalizeMemRefType(memrefType,
-                                                   /*numSymbolicOperands=*/0);
+    MemRefType newMemRefType = normalizeMemRefType(memrefType);
     if (newMemRefType == memrefType) {
       // Either memrefType already had an identity map or the map couldn't
       // be transformed to an identity map.
index 892d5e5..c7af033 100644 (file)
@@ -352,3 +352,14 @@ func.func @neg_map() -> memref<2x3xf32, #neg> {
   %0 = memref.alloc() : memref<2x3xf32, #neg>
   return %0 : memref<2x3xf32, #neg>
 }
+
+// CHECK-LABEL: func @memref_with_strided_offset
+func.func @memref_with_strided_offset(%arg0: tensor<128x512xf32>, %arg1: index, %arg2: index) -> tensor<16x512xf32> {
+  %c0 = arith.constant 0 : index
+  %0 = bufferization.to_memref %arg0 : memref<128x512xf32, strided<[?, ?], offset: ?>>
+  %subview = memref.subview %0[%arg2, 0] [%arg1, 512] [1, 1] : memref<128x512xf32, strided<[?, ?], offset: ?>> to memref<?x512xf32, strided<[?, ?], offset: ?>>
+  // CHECK: %{{.*}} = memref.cast %{{.*}} : memref<?x512xf32, strided<[?, ?], offset: ?>> to memref<16x512xf32, strided<[?, ?], offset: ?>>
+  %cast = memref.cast %subview : memref<?x512xf32, strided<[?, ?], offset: ?>> to memref<16x512xf32, strided<[?, ?], offset: ?>>
+  %1 = bufferization.to_tensor %cast : memref<16x512xf32, strided<[?, ?], offset: ?>>
+  return %1 : tensor<16x512xf32>
+}