[mlir] make remaining memref dialect ops produce strided layouts
authorAlex Zinenko <zinenko@google.com>
Thu, 15 Sep 2022 16:29:38 +0000 (18:29 +0200)
committerAlex Zinenko <zinenko@google.com>
Fri, 16 Sep 2022 08:56:48 +0000 (10:56 +0200)
The three following ops in the memref dialect: transpose, expand_shape,
collapse_shape, have been originally designed to operate on memrefs with
strided layouts but had to go through the affine map representation as the type
did not support anything else. Make these ops produce memref values with
StridedLayoutAttr instead now that it is available.

Depends On D133938

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D133947

13 files changed:
mlir/include/mlir/IR/BuiltinTypes.h
mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/lib/IR/AsmPrinter.cpp
mlir/lib/IR/BuiltinTypes.cpp
mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-alloc-tensor-elimination.mlir
mlir/test/Dialect/Linalg/roundtrip.mlir
mlir/test/Dialect/MemRef/invalid.mlir
mlir/test/Dialect/MemRef/ops.mlir
mlir/test/Dialect/Tensor/bufferize.mlir
mlir/test/Dialect/Tensor/one-shot-bufferize.mlir

index 46f46f1..63eb409 100644 (file)
@@ -466,11 +466,6 @@ bool isStrided(MemRefType t);
 /// Return null if the layout is not compatible with a strided layout.
 AffineMap getStridedLinearLayoutMap(MemRefType t);
 
-/// Helper determining if a memref is static-shape and contiguous-row-major
-/// layout, while still allowing for an arbitrary offset (any static or
-/// dynamic value).
-bool isStaticShapeAndContiguousRowMajor(MemRefType memrefType);
-
 } // namespace mlir
 
 #endif // MLIR_IR_BUILTINTYPES_H
index caef749..65f533b 100644 (file)
@@ -961,7 +961,25 @@ struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
     auto srcType = op.getSource().getType().cast<BaseMemRefType>();
     auto targetType = op.getTarget().getType().cast<BaseMemRefType>();
 
-    auto isContiguousMemrefType = [](BaseMemRefType type) {
+    auto isStaticShapeAndContiguousRowMajor = [](MemRefType type) {
+      if (!type.hasStaticShape())
+        return false;
+
+      SmallVector<int64_t> strides;
+      int64_t offset;
+      if (failed(getStridesAndOffset(type, strides, offset)))
+        return false;
+
+      int64_t runningStride = 1;
+      for (unsigned i = strides.size(); i > 0; --i) {
+        if (strides[i - 1] != runningStride)
+          return false;
+        runningStride *= type.getDimSize(i - 1);
+      }
+      return true;
+    };
+
+    auto isContiguousMemrefType = [&](BaseMemRefType type) {
       auto memrefType = type.dyn_cast<mlir::MemRefType>();
       // We can use memcpy for memrefs if they have an identity layout or are
       // contiguous with an arbitrary offset. Ignore empty memrefs, which is a
index c236171..c6c031b 100644 (file)
@@ -1761,7 +1761,7 @@ SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
 
 /// Compute the layout map after expanding a given source MemRef type with the
 /// specified reassociation indices.
-static FailureOr<AffineMap>
+static FailureOr<StridedLayoutAttr>
 computeExpandedLayoutMap(MemRefType srcType, ArrayRef<int64_t> resultShape,
                          ArrayRef<ReassociationIndices> reassociation) {
   int64_t srcOffset;
@@ -1798,8 +1798,7 @@ computeExpandedLayoutMap(MemRefType srcType, ArrayRef<int64_t> resultShape,
   }
   auto resultStrides = llvm::to_vector<8>(llvm::reverse(reverseResultStrides));
   resultStrides.resize(resultShape.size(), 1);
-  return makeStridedLinearLayoutMap(resultStrides, srcOffset,
-                                    srcType.getContext());
+  return StridedLayoutAttr::get(srcType.getContext(), srcOffset, resultStrides);
 }
 
 static FailureOr<MemRefType>
@@ -1814,14 +1813,12 @@ computeExpandedType(MemRefType srcType, ArrayRef<int64_t> resultShape,
   }
 
   // Source may not be contiguous. Compute the layout map.
-  FailureOr<AffineMap> computedLayout =
+  FailureOr<StridedLayoutAttr> computedLayout =
       computeExpandedLayoutMap(srcType, resultShape, reassociation);
   if (failed(computedLayout))
     return failure();
-  auto computedType =
-      MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
-                      srcType.getMemorySpaceAsInt());
-  return canonicalizeStridedLayout(computedType);
+  return MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
+                         srcType.getMemorySpace());
 }
 
 void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
@@ -1855,10 +1852,9 @@ LogicalResult ExpandShapeOp::verify() {
     return emitOpError("invalid source layout map");
 
   // Check actual result type.
-  auto canonicalizedResultType = canonicalizeStridedLayout(resultType);
-  if (*expectedResultType != canonicalizedResultType)
+  if (*expectedResultType != resultType)
     return emitOpError("expected expanded type to be ")
-           << *expectedResultType << " but found " << canonicalizedResultType;
+           << *expectedResultType << " but found " << resultType;
 
   return success();
 }
@@ -1877,7 +1873,7 @@ void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
 /// not possible to check this by inspecting a MemRefType in the general case.
 /// If non-contiguity cannot be checked statically, the collapse is assumed to
 /// be valid (and thus accepted by this function) unless `strict = true`.
-static FailureOr<AffineMap>
+static FailureOr<StridedLayoutAttr>
 computeCollapsedLayoutMap(MemRefType srcType,
                           ArrayRef<ReassociationIndices> reassociation,
                           bool strict = false) {
@@ -1940,13 +1936,12 @@ computeCollapsedLayoutMap(MemRefType srcType,
         return failure();
     }
   }
-  return makeStridedLinearLayoutMap(resultStrides, srcOffset,
-                                    srcType.getContext());
+  return StridedLayoutAttr::get(srcType.getContext(), srcOffset, resultStrides);
 }
 
 bool CollapseShapeOp::isGuaranteedCollapsible(
     MemRefType srcType, ArrayRef<ReassociationIndices> reassociation) {
-  // MemRefs with standard layout are always collapsible.
+  // MemRefs with identity layout are always collapsible.
   if (srcType.getLayout().isIdentity())
     return true;
 
@@ -1978,14 +1973,12 @@ computeCollapsedType(MemRefType srcType,
   // Source may not be fully contiguous. Compute the layout map.
   // Note: Dimensions that are collapsed into a single dim are assumed to be
   // contiguous.
-  FailureOr<AffineMap> computedLayout =
+  FailureOr<StridedLayoutAttr> computedLayout =
       computeCollapsedLayoutMap(srcType, reassociation);
   assert(succeeded(computedLayout) &&
          "invalid source layout map or collapsing non-contiguous dims");
-  auto computedType =
-      MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
-                      srcType.getMemorySpaceAsInt());
-  return canonicalizeStridedLayout(computedType);
+  return MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
+                         srcType.getMemorySpace());
 }
 
 void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
@@ -2021,21 +2014,19 @@ LogicalResult CollapseShapeOp::verify() {
     // Source may not be fully contiguous. Compute the layout map.
     // Note: Dimensions that are collapsed into a single dim are assumed to be
     // contiguous.
-    FailureOr<AffineMap> computedLayout =
+    FailureOr<StridedLayoutAttr> computedLayout =
         computeCollapsedLayoutMap(srcType, getReassociationIndices());
     if (failed(computedLayout))
       return emitOpError(
           "invalid source layout map or collapsing non-contiguous dims");
-    auto computedType =
+    expectedResultType =
         MemRefType::get(resultType.getShape(), srcType.getElementType(),
-                        *computedLayout, srcType.getMemorySpaceAsInt());
-    expectedResultType = canonicalizeStridedLayout(computedType);
+                        *computedLayout, srcType.getMemorySpace());
   }
 
-  auto canonicalizedResultType = canonicalizeStridedLayout(resultType);
-  if (expectedResultType != canonicalizedResultType)
+  if (expectedResultType != resultType)
     return emitOpError("expected collapsed type to be ")
-           << expectedResultType << " but found " << canonicalizedResultType;
+           << expectedResultType << " but found " << resultType;
 
   return success();
 }
@@ -2709,24 +2700,26 @@ static MemRefType inferTransposeResultType(MemRefType memRefType,
                                            AffineMap permutationMap) {
   auto rank = memRefType.getRank();
   auto originalSizes = memRefType.getShape();
-  // Compute permuted sizes.
-  SmallVector<int64_t, 4> sizes(rank, 0);
-  for (const auto &en : llvm::enumerate(permutationMap.getResults()))
-    sizes[en.index()] =
-        originalSizes[en.value().cast<AffineDimExpr>().getPosition()];
-
-  // Compute permuted strides.
   int64_t offset;
-  SmallVector<int64_t, 4> strides;
-  auto res = getStridesAndOffset(memRefType, strides, offset);
-  assert(succeeded(res) && strides.size() == static_cast<unsigned>(rank));
+  SmallVector<int64_t, 4> originalStrides;
+  auto res = getStridesAndOffset(memRefType, originalStrides, offset);
+  assert(succeeded(res) &&
+         originalStrides.size() == static_cast<unsigned>(rank));
   (void)res;
-  auto map =
-      makeStridedLinearLayoutMap(strides, offset, memRefType.getContext());
-  map = permutationMap ? map.compose(permutationMap) : map;
+
+  // Compute permuted sizes and strides.
+  SmallVector<int64_t> sizes(rank, 0);
+  SmallVector<int64_t> strides(rank, 1);
+  for (const auto &en : llvm::enumerate(permutationMap.getResults())) {
+    unsigned position = en.value().cast<AffineDimExpr>().getPosition();
+    sizes[en.index()] = originalSizes[position];
+    strides[en.index()] = originalStrides[position];
+  }
+
   return MemRefType::Builder(memRefType)
       .setShape(sizes)
-      .setLayout(AffineMapAttr::get(map));
+      .setLayout(
+          StridedLayoutAttr::get(memRefType.getContext(), offset, strides));
 }
 
 void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
index 65a9375..c6945ff 100644 (file)
@@ -136,11 +136,10 @@ struct CollapseShapeOpInterface
         int64_t offset;
         if (failed(getStridesAndOffset(bufferType, strides, offset)))
           return failure();
-        AffineMap resultLayout =
-            makeStridedLinearLayoutMap({}, offset, op->getContext());
-        resultType =
-            MemRefType::get({}, tensorResultType.getElementType(), resultLayout,
-                            bufferType.getMemorySpaceAsInt());
+        resultType = MemRefType::get(
+            {}, tensorResultType.getElementType(),
+            StridedLayoutAttr::get(op->getContext(), offset, {}),
+            bufferType.getMemorySpace());
       }
 
       replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
index c27c685..f3010fd 100644 (file)
@@ -2250,7 +2250,8 @@ void AsmPrinter::Impl::printType(Type type) {
           os << 'x';
         }
         printType(memrefTy.getElementType());
-        if (!memrefTy.getLayout().isIdentity()) {
+        MemRefLayoutAttrInterface layout = memrefTy.getLayout();
+        if (!layout.isa<AffineMapAttr>() || !layout.isIdentity()) {
           os << ", ";
           printAttribute(memrefTy.getLayout(), AttrTypeElision::May);
         }
index b9b5bed..73421ae 100644 (file)
@@ -1027,40 +1027,3 @@ AffineMap mlir::getStridedLinearLayoutMap(MemRefType t) {
     return AffineMap();
   return makeStridedLinearLayoutMap(strides, offset, t.getContext());
 }
-
-/// Return the AffineExpr representation of the offset, assuming `memRefType`
-/// is a strided memref.
-static AffineExpr getOffsetExpr(MemRefType memrefType) {
-  SmallVector<AffineExpr> strides;
-  AffineExpr offset;
-  if (failed(getStridesAndOffset(memrefType, strides, offset)))
-    assert(false && "expected strided memref");
-  return offset;
-}
-
-/// Helper to construct a contiguous MemRefType of `shape`, `elementType` and
-/// `offset` AffineExpr.
-static MemRefType makeContiguousRowMajorMemRefType(MLIRContext *context,
-                                                   ArrayRef<int64_t> shape,
-                                                   Type elementType,
-                                                   AffineExpr offset) {
-  AffineExpr canonical = makeCanonicalStridedLayoutExpr(shape, context);
-  AffineExpr contiguousRowMajor = canonical + offset;
-  AffineMap contiguousRowMajorMap =
-      AffineMap::inferFromExprList({contiguousRowMajor})[0];
-  return MemRefType::get(shape, elementType, contiguousRowMajorMap);
-}
-
-/// Helper determining if a memref is static-shape and contiguous-row-major
-/// layout, while still allowing for an arbitrary offset (any static or
-/// dynamic value).
-bool mlir::isStaticShapeAndContiguousRowMajor(MemRefType memrefType) {
-  if (!memrefType.hasStaticShape())
-    return false;
-  AffineExpr offset = getOffsetExpr(memrefType);
-  MemRefType contiguousRowMajorMemRefType = makeContiguousRowMajorMemRefType(
-      memrefType.getContext(), memrefType.getShape(),
-      memrefType.getElementType(), offset);
-  return canonicalizeStridedLayout(memrefType) ==
-         canonicalizeStridedLayout(contiguousRowMajorMemRefType);
-}
index 021e92c..61c426a 100644 (file)
@@ -609,7 +609,7 @@ func.func @address_space(%arg0 : memref<32xf32, affine_map<(d0) -> (d0)>, 7>) {
 //       CHECK:   llvm.extractvalue {{.*}}[3, 2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
 //       CHECK:    llvm.insertvalue {{.*}}[3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
 func.func @transpose(%arg0: memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>>) {
-  %0 = memref.transpose %arg0 (i, j, k) -> (k, i, j) : memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>> to memref<?x?x?xf32, affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2 * s1 + s0 + d0 * s2 + d1)>>
+  %0 = memref.transpose %arg0 (i, j, k) -> (k, i, j) : memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>> to memref<?x?x?xf32, strided<[1, ?, ?], offset: ?>>
   return
 }
 
@@ -725,12 +725,12 @@ func.func @collapse_shape_static(%arg0: memref<1x3x4x1x5xf32>) -> memref<3x4x5xf
 // -----
 
 func.func @collapse_shape_dynamic_with_non_identity_layout(
-        %arg0 : memref<4x?x?xf32, affine_map<(d0, d1, d2)[s0, s1] -> (d0 * s1 + s0 + d1 * 4 + d2)>>) ->
-        memref<4x?xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>> {
+        %arg0 : memref<4x?x?xf32, strided<[?, 4, 1], offset: ?>>) ->
+        memref<4x?xf32, strided<[?, ?], offset: ?>> {
   %0 = memref.collapse_shape %arg0 [[0], [1, 2]]:
-    memref<4x?x?xf32, affine_map<(d0, d1, d2)[s0, s1] -> (d0 * s1 + s0 + d1 * 4 + d2)>> into
-    memref<4x?xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>>
-  return %0 : memref<4x?xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>>
+    memref<4x?x?xf32, strided<[?, 4, 1], offset: ?>> into
+    memref<4x?xf32, strided<[?, ?], offset: ?>>
+  return %0 : memref<4x?xf32, strided<[?, ?], offset: ?>>
 }
 // CHECK-LABEL:   func @collapse_shape_dynamic_with_non_identity_layout(
 //       CHECK:      llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
@@ -898,12 +898,12 @@ func.func @expand_shape_dynamic(%arg0 : memref<1x?xf32>) -> memref<1x2x?xf32> {
 // -----
 
 func.func @expand_shape_dynamic_with_non_identity_layout(
-            %arg0 : memref<1x?xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>>) ->
-            memref<1x2x?xf32, affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3)>> {
+            %arg0 : memref<1x?xf32, strided<[?, ?], offset: ?>>) ->
+            memref<1x2x?xf32, strided<[?, ?, ?], offset: ?>> {
   %0 = memref.expand_shape %arg0 [[0], [1, 2]]:
-    memref<1x?xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>> into
-    memref<1x2x?xf32, affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3)>>
-  return %0 : memref<1x2x?xf32, affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3)>>
+    memref<1x?xf32, strided<[?, ?], offset: ?>> into
+    memref<1x2x?xf32, strided<[?, ?, ?], offset: ?>>
+  return %0 : memref<1x2x?xf32, strided<[?, ?, ?], offset: ?>>
 }
 // CHECK-LABEL:   func @expand_shape_dynamic_with_non_identity_layout(
 // CHECK:           llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
@@ -982,10 +982,10 @@ func.func @atomic_rmw(%I : memref<10xi32>, %ival : i32, %F : memref<10xf32>, %fv
 // -----
 
 // CHECK-LABEL: func @collapse_static_shape_with_non_identity_layout
-func.func @collapse_static_shape_with_non_identity_layout(%arg: memref<1x1x8x8xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 64 + s0 + d1 * 64 + d2 * 8 + d3)>>) -> memref<64xf32, affine_map<(d0)[s0] -> (d0 + s0)>> {
+func.func @collapse_static_shape_with_non_identity_layout(%arg: memref<1x1x8x8xf32, strided<[64, 64, 8, 1], offset: ?>>) -> memref<64xf32, strided<[1], offset: ?>> {
 // CHECK-NOT: memref.collapse_shape
-  %1 = memref.collapse_shape %arg [[0, 1, 2, 3]] : memref<1x1x8x8xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 64 + s0 + d1 * 64 + d2 * 8 + d3)>> into memref<64xf32, affine_map<(d0)[s0] -> (d0 + s0)>>
-  return %1 : memref<64xf32, affine_map<(d0)[s0] -> (d0 + s0)>>
+  %1 = memref.collapse_shape %arg [[0, 1, 2, 3]] : memref<1x1x8x8xf32, strided<[64, 64, 8, 1], offset: ?>> into memref<64xf32, strided<[1], offset: ?>>
+  return %1 : memref<64xf32, strided<[1], offset: ?>>
 }
 
 // -----
@@ -1069,13 +1069,11 @@ func.func @memref_copy_contiguous(%in: memref<16x2xi32>, %offset: index) {
 // -----
 
 // CHECK-LABEL: func @memref_copy_0d_offset
-#map0 = affine_map<(d0) -> (d0 + 1)>
-#map1 = affine_map<() -> (1)>
 func.func @memref_copy_0d_offset(%in: memref<2xi32>) {
   %buf = memref.alloc() : memref<i32>
-  %sub = memref.subview %in[1] [1] [1] : memref<2xi32> to memref<1xi32, #map0>
-  %scalar = memref.collapse_shape %sub [] : memref<1xi32, #map0> into memref<i32, #map1>
-  memref.copy %scalar, %buf : memref<i32, #map1> to memref<i32>
+  %sub = memref.subview %in[1] [1] [1] : memref<2xi32> to memref<1xi32, strided<[1], offset: 1>>
+  %scalar = memref.collapse_shape %sub [] : memref<1xi32, strided<[1], offset: 1>> into memref<i32, strided<[], offset: 1>>
+  memref.copy %scalar, %buf : memref<i32, strided<[], offset: 1>> to memref<i32>
   // CHECK: llvm.intr.memcpy
   return
 }
index b062dd2..7533c46 100644 (file)
@@ -23,8 +23,8 @@ func.func @buffer_forwarding_conflict(
   %f = linalg.fill ins(%f0 : f32) outs(%a : tensor<?xf32>) -> tensor<?xf32>
 
   //     CHECK: memref.copy %[[FUNC_ARG]], %[[ALLOC]] : memref<?xf32> to memref<?xf32>
-  //     CHECK: %[[SV0_ALLOC:.*]] = memref.subview %[[ALLOC]][0] [%[[sz]]] [1] : memref<?xf32> to memref<?xf32>
-  //     CHECK: memref.copy %[[EXTRACT_SLICE_ALLOC]], %[[SV0_ALLOC]] : memref<?xf32> to memref<?xf32>
+  //     CHECK: %[[SV0_ALLOC:.*]] = memref.subview %[[ALLOC]][0] [%[[sz]]] [1] : memref<?xf32> to memref<?xf32, strided<[1]>>
+  //     CHECK: memref.copy %[[EXTRACT_SLICE_ALLOC]], %[[SV0_ALLOC]] : memref<?xf32> to memref<?xf32, strided<[1]>>
   %r0 = tensor.insert_slice %f into %t[0][%sz][1]: tensor<?xf32> into tensor<?xf32>
 
   //     CHECK: %[[T_SUBVIEW:.*]] =  memref.subview %[[FUNC_ARG]][42] [%[[sz]]] [1]
index 651c669..ae68fb3 100644 (file)
@@ -6,8 +6,6 @@
 // Test that we can lower all the way to LLVM without crashing, don't check results here.
 // DISABLED: mlir-opt %s --convert-linalg-to-llvm -o=/dev/null 2>&1
 
-// CHECK: #[[$strided3DT:.*]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2 * s1 + s0 + d1 * s2 + d0)>
-
 func.func @views(%arg0: index) {
   %c0 = arith.constant 0 : index
   %0 = arith.muli %arg0, %arg0 : index
@@ -70,12 +68,12 @@ func.func @fill_view(%arg0: memref<?xf32, strided<[1], offset: ?>>, %arg1: f32)
 // -----
 
 func.func @transpose(%arg0: memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>>) {
-  %0 = memref.transpose %arg0 (i, j, k) -> (k, j, i) : memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>> to memref<?x?x?xf32, affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2 * s1 + s0 + d1 * s2 + d0)>>
+  %0 = memref.transpose %arg0 (i, j, k) -> (k, j, i) : memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>> to memref<?x?x?xf32, strided<[1, ?, ?], offset: ?>>
   return
 }
 // CHECK-LABEL: func @transpose
 //       CHECK:   memref.transpose %{{.*}} ([[i:.*]], [[j:.*]], [[k:.*]]) -> ([[k]], [[j]], [[i]]) :
-//  CHECK-SAME:      memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>> to memref<?x?x?xf32, #[[$strided3DT]]>
+//  CHECK-SAME:      memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>> to memref<?x?x?xf32, strided<[1, ?, ?], offset: ?>>
 
 // -----
 
index 4c98b21..6dd5439 100644 (file)
@@ -424,7 +424,7 @@ func.func @expand_shape_to_smaller_rank(%arg0: memref<1xf32>) {
 
 func.func @expand_shape_invalid_result_layout(
     %arg0: memref<30x20xf32, strided<[4000, 2], offset: 100>>) {
-  // expected-error @+1 {{expected expanded type to be 'memref<2x15x20xf32, affine_map<(d0, d1, d2) -> (d0 * 60000 + d1 * 4000 + d2 * 2 + 100)>>' but found 'memref<2x15x20xf32, affine_map<(d0, d1, d2) -> (d0 * 5000 + d1 * 4000 + d2 * 2 + 100)>>'}}
+  // expected-error @+1 {{expected expanded type to be 'memref<2x15x20xf32, strided<[60000, 4000, 2], offset: 100>>' but found 'memref<2x15x20xf32, strided<[5000, 4000, 2], offset: 100>>'}}
   %0 = memref.expand_shape %arg0 [[0, 1], [2]] :
       memref<30x20xf32, strided<[4000, 2], offset: 100>>
       into memref<2x15x20xf32, strided<[5000, 4000, 2], offset: 100>>
index 69c7e17..7d469df 100644 (file)
@@ -104,10 +104,10 @@ func.func @expand_collapse_shape_static(
     %arg1: tensor<3x4x5xf32>,
     %arg2: tensor<3x?x5xf32>,
     %arg3: memref<30x20xf32, strided<[4000, 2], offset: 100>>,
-    %arg4: memref<1x5xf32, affine_map<(d0, d1)[s0] -> (d0 * 5 + s0 + d1)>>,
+    %arg4: memref<1x5xf32, strided<[5, 1], offset: ?>>,
     %arg5: memref<f32>,
     %arg6: memref<3x4x5xf32, strided<[240, 60, 10], offset: 0>>,
-    %arg7: memref<1x2049xi64, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>>) {
+    %arg7: memref<1x2049xi64, strided<[?, ?], offset: ?>>) {
   // Reshapes that collapse and expand back a contiguous buffer.
 //       CHECK:   memref.collapse_shape {{.*}} {{\[}}[0, 1], [2]]
 //  CHECK-SAME:     memref<3x4x5xf32> into memref<12x5xf32>
@@ -157,8 +157,8 @@ func.func @expand_collapse_shape_static(
 
 //       CHECK:   memref.expand_shape {{.*}} {{\[}}[0], [1, 2]]
   %r4 = memref.expand_shape %arg4 [[0], [1, 2]] :
-      memref<1x5xf32, affine_map<(d0, d1)[s0] -> (d0 * 5 + s0 + d1)>> into
-      memref<1x1x5xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 5 + s0 + d2 + d1 * 5)>>
+      memref<1x5xf32, strided<[5, 1], offset: ?>> into
+      memref<1x1x5xf32, strided<[5, 5, 1], offset: ?>>
 
   // Note: Only the collapsed two shapes are contiguous in the follow test case.
 //       CHECK:   memref.collapse_shape {{.*}} {{\[}}[0, 1], [2]]
@@ -168,8 +168,8 @@ func.func @expand_collapse_shape_static(
 
 //       CHECK:   memref.collapse_shape {{.*}} {{\[}}[0, 1]]
   %r7 = memref.collapse_shape %arg7 [[0, 1]] :
-      memref<1x2049xi64, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>> into
-      memref<2049xi64, affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>>
+      memref<1x2049xi64, strided<[?, ?], offset: ?>> into
+      memref<2049xi64, strided<[?], offset: ?>>
 
   // Reshapes that expand and collapse back a contiguous buffer with some 1's.
 //       CHECK:   memref.expand_shape {{.*}} {{\[}}[0, 1], [2], [3, 4]]
@@ -241,15 +241,15 @@ func.func @expand_collapse_shape_dynamic(%arg0: memref<?x?x?xf32>,
     memref<?x4x?xf32, strided<[?, ?, 1], offset: ?>>
 
 //       CHECK:   memref.collapse_shape {{.*}} {{\[}}[0, 1]]
-//  CHECK-SAME:     memref<?x42xf32, strided<[42, 1]>> into memref<?xf32>
+//  CHECK-SAME:     memref<?x42xf32, strided<[42, 1]>> into memref<?xf32, strided<[1]>>
   %3 = memref.collapse_shape %arg3 [[0, 1]] :
     memref<?x42xf32, strided<[42, 1], offset: 0>> into
-    memref<?xf32>
+    memref<?xf32, strided<[1]>>
 
 //       CHECK:   memref.expand_shape {{.*}} {{\[}}[0, 1]]
-//  CHECK-SAME:     memref<?xf32> into memref<?x42xf32>
+//  CHECK-SAME:     memref<?xf32, strided<[1]>> into memref<?x42xf32>
   %r3 = memref.expand_shape %3 [[0, 1]] :
-    memref<?xf32> into memref<?x42xf32>
+    memref<?xf32, strided<[1]>> into memref<?x42xf32>
   return
 }
 
index 4d3d26c..7d8c4fd 100644 (file)
@@ -372,8 +372,6 @@ func.func @tensor.expand_shape(%t1: tensor<?x10xf32>) -> tensor<2x?x10xf32> {
 
 // -----
 
-// CHECK-DAG: #[[$MAP2b:.*]] = affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 140 + d1 * 20 + d2 * 5 + d3 + s0)>
-
 // CHECK-LABEL: func @tensor.expand_shape_of_slice(
 //  CHECK-SAME:     %[[t1:.*]]: tensor<?x20xf32>
 func.func @tensor.expand_shape_of_slice(
@@ -383,7 +381,7 @@ func.func @tensor.expand_shape_of_slice(
   %0 = tensor.extract_slice %t1[%o1, 5][%s1, 10][1, 1] :
       tensor<?x20xf32> to tensor<?x10xf32>
   // CHECK: %[[expanded:.*]] = memref.expand_shape %[[subview]] [
-  // CHECK-SAME: [0, 1], [2, 3]] : memref<?x10xf32, strided<[20, 1], offset: ?>> into memref<?x7x2x5xf32, #[[$MAP2b]]>
+  // CHECK-SAME: [0, 1], [2, 3]] : memref<?x10xf32, strided<[20, 1], offset: ?>> into memref<?x7x2x5xf32, strided<[140, 20, 5, 1], offset: ?>>
   %1 = tensor.expand_shape %0 [[0, 1], [2, 3]] :
       tensor<?x10xf32> into tensor<?x7x2x5xf32>
   // CHECK: %[[r:.*]] = bufferization.to_tensor %[[expanded]]
@@ -393,8 +391,6 @@ func.func @tensor.expand_shape_of_slice(
 
 // -----
 
-// CHECK-DAG: #[[$MAP10:.*]] = affine_map<(d0)[s0] -> (d0 + s0)>
-
 // CHECK-LABEL: func @tensor.expand_shape_of_scalar_slice(
 //  CHECK-SAME:     %[[t1:.*]]: tensor<?xf32>
 func.func @tensor.expand_shape_of_scalar_slice(
@@ -402,7 +398,7 @@ func.func @tensor.expand_shape_of_scalar_slice(
   // CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]] : memref<?xf32>
   // CHECK: %[[subview:.*]] = memref.subview %[[m1]][%{{.*}}] [1] [1] :  memref<?xf32> to memref<f32, strided<[], offset: ?>>
   %0 = tensor.extract_slice %t1[%o1][1][1] : tensor<?xf32> to tensor<f32>
-  // CHECK: %[[expanded:.*]] = memref.expand_shape %[[subview]] [] : memref<f32, strided{{.*}}> into memref<1xf32, #[[$MAP10]]>
+  // CHECK: %[[expanded:.*]] = memref.expand_shape %[[subview]] [] : memref<f32, strided{{.*}}> into memref<1xf32, strided<[1], offset: ?>>
   %1 = tensor.expand_shape %0 [] : tensor<f32> into tensor<1xf32>
   // CHECK: %[[r:.*]] = bufferization.to_tensor %[[expanded]]
   // CHECK: return %[[r]]
@@ -442,13 +438,11 @@ func.func @tensor.collapse_shape_to_scalar(%t1: tensor<1x1x1xf32>) -> tensor<f32
 
 // -----
 
-// CHECK-DAG: #[[$MAP4:.*]] = affine_map<() -> (1)>
-
 // CHECK-LABEL: func @tensor.collapse_shape_of_slice(
 func.func @tensor.collapse_shape_of_slice(%arg0: tensor<2xi32>) -> tensor<i32> {
   // CHECK: memref.subview %{{.*}}[1] [1] [1] : memref<2xi32> to memref<1xi32, strided<[1], offset: 1>>
   %0 = tensor.extract_slice %arg0[1] [1] [1] : tensor<2xi32> to tensor<1xi32>
-  // CHECK: memref.collapse_shape %{{.*}} [] : memref<1xi32, strided<[1], offset: 1>> into memref<i32, #[[$MAP4]]>
+  // CHECK: memref.collapse_shape %{{.*}} [] : memref<1xi32, strided<[1], offset: 1>> into memref<i32, strided<[], offset: 1>>
   %1 = tensor.collapse_shape %0 [] : tensor<1xi32> into tensor<i32>
   return %1 : tensor<i32>
 }
@@ -474,23 +468,19 @@ func.func @tensor.collapse_shape_of_slice2(
 
 // -----
 
-// CHECK-DAG: #[[$MAP6:.*]] = affine_map<(d0) -> (d0 * 2)>
-
 // CHECK-LABEL: func @tensor.collapse_shape_of_slice3(
 //  CHECK-SAME:     %[[t1:.*]]: tensor<1x2xf32>
 func.func @tensor.collapse_shape_of_slice3(%t1: tensor<1x2xf32>) -> tensor<1xf32> {
   // CHECK: memref.subview {{.*}} : memref<1x2xf32> to memref<1x1xf32, strided<[2, 1]>>
   %0 = tensor.extract_slice %t1[0, 0][1, 1][1, 1] : tensor<1x2xf32> to tensor<1x1xf32>
   // CHECK: memref.collapse_shape %{{.*}} [
-  // CHECK-SAME: [0, 1]] : memref<1x1xf32, strided<[2, 1]>> into memref<1xf32, #[[$MAP6]]>
+  // CHECK-SAME: [0, 1]] : memref<1x1xf32, strided<[2, 1]>> into memref<1xf32, strided<[2]>>
   %1 = tensor.collapse_shape %0 [[0, 1]] : tensor<1x1xf32> into tensor<1xf32>
   return %1 : tensor<1xf32>
 }
 
 // -----
 
-// CHECK-DAG: #[[$MAP8:.*]] = affine_map<(d0)[s0] -> (d0 * 4 + s0)>
-
 // CHECK-LABEL:   func @tensor.collapse_shape_of_slice4(
 //  CHECK-SAME:     %[[t1:.*]]: tensor<?x2x4xf32>,
 // CHECK-SAME:      %[[OFFSET:.*]]: index) -> tensor<8xf32> {
@@ -498,7 +488,7 @@ func.func @tensor.collapse_shape_of_slice4(%arg0: tensor<?x2x4xf32>, %offset: in
   // CHECK: memref.subview %{{.*}} : memref<?x2x4xf32> to memref<4x2x1xf32, strided<[8, 4, 1], offset: ?>>
   %0 = tensor.extract_slice %arg0[0, 0, %offset] [4, 2, 1] [1, 1, 1] : tensor<?x2x4xf32> to tensor<4x2x1xf32>
   // CHECK: memref.collapse_shape %{{.*}} [
-  // CHECK-SAME: [0, 1, 2]] : memref<4x2x1xf32, strided<[8, 4, 1], offset: ?>> into memref<8xf32, #[[$MAP8]]>
+  // CHECK-SAME: [0, 1, 2]] : memref<4x2x1xf32, strided<[8, 4, 1], offset: ?>> into memref<8xf32, strided<[4], offset: ?>>
   %ret = tensor.collapse_shape %0 [[0, 1, 2]] : tensor<4x2x1xf32> into tensor<8xf32>
   return %ret: tensor<8xf32>
 }
index eb0f60c..8158f49 100644 (file)
@@ -124,8 +124,8 @@ func.func @insert_slice_fun_not_inplace(
 {
   //      CHECK: %[[ALLOC:.*]] = memref.alloc(%{{.*}}) {alignment = 128 : i64} : memref<?xf32>
   //      CHECK: memref.copy %[[A]], %[[ALLOC]] : memref<?xf32{{.*}} to memref<?xf32>
-  //      CHECK: %[[SV:.*]] = memref.subview %[[ALLOC]][0] [4] [1] : memref<?xf32> to memref<4xf32>
-  //      CHECK: memref.copy %[[t]], %[[SV]] : memref<4xf32, #map> to memref<4xf32>
+  //      CHECK: %[[SV:.*]] = memref.subview %[[ALLOC]][0] [4] [1] : memref<?xf32> to memref<4xf32, strided<[1]>>
+  //      CHECK: memref.copy %[[t]], %[[SV]] : memref<4xf32, #map> to memref<4xf32, strided<[1]>>
   %r0 = tensor.insert_slice %t into %A[0][4][1] : tensor<4xf32> into tensor<?xf32>
 
   //     CHECK: return %{{.*}} : memref<?xf32>